diff --git a/.bazelrc b/.bazelrc index c70c57136102b483a4332ca22f775d7a2c5b849e..1a9c46362e530ab8345d40845b7c2a954133823e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -10,6 +10,9 @@ build:android_arm64 --config=android build:android_arm64 --cpu=arm64-v8a build:android_arm64 --fat_apk_cpu=arm64-v8a +# Sets the default Apple platform to macOS. +build --apple_platform_type=macos + # Config to use a mostly-static build and disable modular op registration # support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python). # By default, TensorFlow will build with a dependence on @@ -67,6 +70,7 @@ build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true build:gdr --define=with_gdr_support=true build:ngraph --define=with_ngraph_support=true build:verbs --define=with_verbs_support=true +build:numa --define=with_numa_support=true # Options to disable default on features build:noaws --define=no_aws_support=true @@ -90,11 +94,20 @@ build --define=grpc_no_ares=true build:dynamic_kernels --define=dynamic_loaded_kernels=true build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS +# Build TF with C++ 17 features. +build:c++17 --cxxopt=-std=c++1z +build:c++17 --cxxopt=-stdlib=libc++ +build:c++1z --cxxopt=-std=c++1z +build:c++1z --cxxopt=-stdlib=libc++ + # Default paths for TF_SYSTEM_LIBS build --define=PREFIX=/usr build --define=LIBDIR=$(PREFIX)/lib build --define=INCLUDEDIR=$(PREFIX)/include +# Disable MKL-DNN contraction kernels by default. +build --define=tensorflow_mkldnn_contraction_kernel=0 + # Default options should come above this line # Options from ./configure diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4a296f265f7b9521c46d350cec26ff199f43eb6c..73782143a3d4b1742f33bb96845ed300eedb6f50 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -55,24 +55,28 @@ TensorFlow coding style. 
#### General guidelines and philosophy for contribution -* Include unit tests when you contribute new features, as they help to - a) prove that your code works correctly, and b) guard against future breaking - changes to lower the maintenance cost. -* Bug fixes also generally require unit tests, because the presence of bugs - usually indicates insufficient test coverage. -* Keep API compatibility in mind when you change code in core TensorFlow, - e.g., code in [tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core) and [tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python). - TensorFlow has reached version 1 and hence cannot make - non-backward-compatible API changes without a major release. Reviewers of your - pull request will comment on any API compatibility issues. -* When you contribute a new feature to TensorFlow, the maintenance burden is (by - default) transferred to the TensorFlow team. This means that benefit of the - contribution must be compared against the cost of maintaining the feature. -* Full new features (e.g., a new op implementing a cutting-edge algorithm) - typically will live in - [tensorflow/contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib) - to get some airtime before decision is made regarding whether they are to be - migrated to the core. +* Include unit tests when you contribute new features, as they help to a) + prove that your code works correctly, and b) guard against future breaking + changes to lower the maintenance cost. +* Bug fixes also generally require unit tests, because the presence of bugs + usually indicates insufficient test coverage. +* Keep API compatibility in mind when you change code in core TensorFlow, + e.g., code in + [tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core) + and + [tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python). 
+ TensorFlow has reached version 1 and hence cannot make + non-backward-compatible API changes without a major release. Reviewers of + your pull request will comment on any API compatibility issues. +* When you contribute a new feature to TensorFlow, the maintenance burden is + (by default) transferred to the TensorFlow team. This means that the benefit of + the contribution must be compared against the cost of maintaining the + feature. +* Full new features (e.g., a new op implementing a cutting-edge algorithm) + typically will live in + [tensorflow/addons](https://github.com/tensorflow/addons) to get some + airtime before a decision is made regarding whether they are to be migrated to + the core. #### License @@ -150,41 +154,45 @@ may exist in your changes. There are two ways to run TensorFlow unit tests. -1. Using tools and libraries installed directly on your system. +1. Using tools and libraries installed directly on your system. - Refer to the - [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) and - [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) - for the required packages. Alternatively, use the said - [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., - `tensorflow/tensorflow:nightly-devel` and `tensorflow/tensorflow:nightly-devel-gpu` - for development to avoid installing the packages directly on your system. + Refer to the + [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) + and + [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) + for the required packages. 
Alternatively, use the said + [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., + `tensorflow/tensorflow:nightly-devel` and + `tensorflow/tensorflow:nightly-devel-gpu` for development to avoid + installing the packages directly on your system (in which case remember to + change directory from `/root` to `/tensorflow` once you get into the running + container so `bazel` can find the `tensorflow` workspace). - Once you have the packages installed, you can run a specific unit test in - bazel by doing as follows: + Once you have the packages installed, you can run a specific unit test in + bazel by doing as follows: - If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add - the `cuda` option flag + If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add + the `cuda` option flag - ```bash - export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" + ```bash + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" - export flags="--config=opt --config=cuda -k" - ``` + export flags="--config=opt --config=cuda -k" + ``` - For example, to run all tests under tensorflow/python, do: + For example, to run all tests under tensorflow/python, do: - ```bash - bazel test ${flags} //tensorflow/python/... - ``` + ```bash + bazel test ${flags} //tensorflow/python/... + ``` -2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. +2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. - ```bash - # Install Docker first, then this will build and run cpu tests - tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... - ``` - - See - [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) for details. 
+ ```bash + # Install Docker first, then this will build and run cpu tests + tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... + ``` + See + [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) + for details. diff --git a/METADATA b/METADATA index d452effd8df8eda2b118bdb05cdba5a5676a83b4..7c290b77762322a54842e44fdd5d25089f7e82e2 100644 --- a/METADATA +++ b/METADATA @@ -23,7 +23,7 @@ third_party { type: GIT value: "https://github.com/tensorflow/tensorflow" } - version: "v1.4.0" - last_upgrade_date { year: 2017 month: 11 day: 13 } + version: "v1.13.0" + last_upgrade_date { year: 2019 month: 3 day: 15 } license_type: NOTICE } diff --git a/README.md b/README.md index 4e37b239b16e6eeefc587aeb242a03e1f88eddbd..96a8ecf4f693d5634da63f4ecc6f4e9c35751f5b 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,8 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. -TensorFlow provides stable Python API and C APIs as well as without API backwards compatibility guarantee like C++, Go, Java, JavaScript and Swift. +TensorFlow provides stable Python and C APIs, as well as APIs without a backwards +compatibility guarantee for C++, Go, Java, JavaScript and Swift. Keep up to date with release announcements and security updates by subscribing to diff --git a/RELEASE.md b/RELEASE.md index 0a56e6909870e398c9d6349576cd2f8e6734f072..48d7e3140aacf526412dfc7999e3e0c0931df89e 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,118 @@ +# Release 1.13.0 + +## Major Features and Improvements + +* TensorFlow Lite has moved from contrib to core. This means that Python modules are under `tf.lite` and source code is now under `tensorflow/lite` rather than `tensorflow/contrib/lite`. +* TensorFlow GPU binaries are now built against CUDA 10 and TensorRT 5.0. +* Support for Python 3.7 on all operating systems. 
+* Moved NCCL to core. + +## Behavioral changes + +* Disallow conversion of python floating types to uint32/64 (matching behavior of other integer types) in `tf.constant`. +* Make the `gain` argument of convolutional orthogonal initializers (`convolutional_delta_orthogonal`, `convolutional_orthogonal_1D`, `convolutional_orthogonal_2D`, `convolutional_orthogonal_3D`) have consistent behavior with the `tf.initializers.orthogonal` initializer, i.e. scale the output l2-norm by `gain` and NOT by `sqrt(gain)`. (Note that these functions are currently in `tf.contrib` which is not guaranteed backward compatible). + +## Bug Fixes and Other Changes + +* Documentation + * Update the doc with the details about the rounding mode used in quantize_and_dequantize_v2. + * Clarify that tensorflow::port::InitMain() _should_ be called before using the TensorFlow library. Programs failing to do this are not portable to all platforms. +* Deprecations and Symbol renames. + * Removing deprecations for the following endpoints: `tf.acos`, `tf.acosh`, `tf.add`, `tf.as_string`, `tf.asin`, `tf.asinh`, `tf.atan`, `tf.atan2`, `tf.atanh`, `tf.cos`, `tf.cosh`, `tf.equal`, `tf.exp`, `tf.floor`, `tf.greater`, `tf.greater_equal`, `tf.less`, `tf.less_equal`, `tf.log`, `tf.log1p`, `tf.logical_and`, `tf.logical_not`, `tf.logical_or`, `tf.maximum`, `tf.minimum`, `tf.not_equal`, `tf.sin`, `tf.sinh`, `tf.tan` + * Deprecate `tf.data.Dataset.shard`. + * Deprecate `saved_model.loader.load` which is replaced by `saved_model.load` and `saved_model.main_op`, which will be replaced by `saved_model.main_op` in V2. + * Deprecate `tf.QUANTIZED_DTYPES`. The official new symbol is `tf.dtypes.QUANTIZED_DTYPES`. + * Update sklearn imports for deprecated packages. + * Deprecate `Variable.count_up_to` and `tf.count_up_to` in favor of `Dataset.range`. + * Export `confusion_matrix` op as `tf.math.confusion_matrix` instead of `tf.train.confusion_matrix`. 
+ * Add `tf.dtypes.` endpoint for every constant in dtypes.py; moving endpoints in versions.py to corresponding endpoints in `tf.sysconfig.` and `tf.version.`; moving all constants under `tf.saved_model` submodules to `tf.saved_model` module. New endpoints are added in V1 and V2 but existing endpoint removals are only applied in V2. + * Deprecates behavior where device assignment overrides collocation constraints inside a collocation context manager. +* Keras & Python API + * Add to Keras functionality analogous to `tf.register_tensor_conversion_function`. + * Subclassed Keras models can now be saved through `tf.contrib.saved_model.save_keras_model`. + * `LinearOperator.matmul` now returns a new `LinearOperator`. +* New ops and improved op functionality + * Add a Nearest Neighbor Resize op. + * Add an `ignore_unknown` argument to `parse_values` which suppresses ValueError for unknown hyperparameter types. Such hyperparameters are ignored. + * Add `tf.linalg.matvec` convenience function. + * `tf.einsum()` raises `ValueError` for unsupported equations like `"ii->"`. + * Add DCT-I and IDCT-I in `tf.signal.dct` and `tf.signal.idct`. + * Add LU decomposition op. + * Add quantile loss to gradient boosted trees in estimator. + * Add `round_mode` to `QuantizeAndDequantizeV2` op to select rounding algorithm. + * Add `unicode_encode`, `unicode_decode`, `unicode_decode_with_offsets`, `unicode_split`, `unicode_split_with_offset`, and `unicode_transcode` ops. Amongst other things, this Op adds the ability to encode, decode, and transcode a variety of input text encoding formats into the main Unicode encodings (UTF-8, UTF-16-BE, UTF-32-BE) + * Add "unit" attribute to the substr op, which allows obtaining the substring of a string containing unicode characters. + * Broadcasting support for Ragged Tensors. + * `SpaceToDepth` supports uint8 data type. + * Support multi-label quantile regression in estimator. 
+ * We now use "div" as the default partition_strategy in `tf.nn.safe_embedding_lookup_sparse`, `tf.nn.sampled_softmax` and `tf.nn.nce_loss`. +* Performance + * Improve performance of GPU cumsum/cumprod by up to 300x. + * Added support for weight decay in most TPU embedding optimizers, including AdamW and MomentumW. +* TensorFlow 2.0 Development + * Add a command line tool to convert to TF2.0, tf_upgrade_v2 + * Merge `tf.spectral` into `tf.signal` for TensorFlow 2.0. + * Change the default recurrent activation function for LSTM from 'hard_sigmoid' to 'sigmoid' in 2.0. Historically recurrent activation is 'hard_sigmoid' since it is faster than 'sigmoid'. With new unified backend between CPU and GPU mode, since the CuDNN kernel is using sigmoid, we change the default for CPU mode to sigmoid as well. With that, the default LSTM will be compatible with both CPU and GPU kernel. This will enable users with a GPU to use CuDNN kernel by default and get a 10x performance boost in training. Note that this is a checkpoint-breaking change. If users want to use their 1.x pre-trained checkpoints, please construct the layer with LSTM(recurrent_activation='hard_sigmoid') to fall back to 1.x behavior. +* TensorFlow Lite + * Move from `tensorflow/contrib/lite` to `tensorflow/lite`. + * Add experimental Java API for injecting TensorFlow Lite delegates + * Add support for strings in TensorFlow Lite Java API. +* `tf.contrib`: + * Add Apache Ignite Filesystem plugin to support accessing Apache IGFS. + * Dropout now takes `rate` argument, `keep_prob` is deprecated. 
+ * Occurrences of Estimator references to `tf.contrib.estimator` were changed to `tf.estimator`: + * `tf.contrib.estimator.BaselineEstimator` with `tf.estimator.BaselineEstimator` + * `tf.contrib.estimator.DNNLinearCombinedEstimator` with `tf.estimator.DNNLinearCombinedEstimator` + * `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator` + * `tf.contrib.estimator.LinearEstimator` with `tf.estimator.LinearEstimator` + * `tf.contrib.estimator.InMemoryEvaluatorHook` with `tf.estimator.experimental.InMemoryEvaluatorHook`. + * `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with `tf.estimator.experimental.make_stop_at_checkpoint_step_hook`. + * Expose `tf.distribute.Strategy` as the new name for `tf.contrib.distribute.DistributionStrategy`. + * Migrate linear optimizer from contrib to core. + * Move `tf.contrib.signal` to `tf.signal` (preserving aliases in `tf.contrib.signal`). + * Users of `tf.contrib.estimator.export_all_saved_models` and related should switch to `tf.estimator.Estimator.experimental_export_all_saved_models`. +* tf.data: + * Add `tf.data.experimental.StatsOptions()`, to configure options to collect statistics from `tf.data.Dataset` pipeline using `StatsAggregator`. Add nested option, `experimental_stats` (which takes a `tf.data.experimental.StatsOptions` object), to `tf.data.Options`. Deprecates `tf.data.experimental.set_stats_aggregator`. + * Performance optimizations: + * Add `tf.data.experimental.OptimizationOptions()`, to configure options to enable `tf.data` performance optimizations. Add nested option, `experimental_optimization` (which takes a `tf.data.experimental.OptimizationOptions` object), to `tf.data.Options`. Remove performance optimization options from `tf.data.Options`, and add them under `tf.data.experimental.OptimizationOptions` instead. + * Enable `map_and_batch_fusion` and `noop_elimination` optimizations by default. 
They can be disabled by configuring `tf.data.experimental.OptimizationOptions` to set `map_and_batch = False` or `noop_elimination = False` respectively. To disable all default optimizations, set `apply_default_optimizations = False`. + * Support parallel map in `map_and_filter_fusion`. + * Disable static optimizations for input pipelines that use non-resource `tf.Variable`s. + * Add NUMA-aware MapAndBatch dataset. + * Deprecate `tf.data.Dataset.make_one_shot_iterator()` in V1, removed it from V2, and added `tf.compat.v1.data.make_one_shot_iterator()`. + * Deprecate `tf.data.Dataset.make_initializable_iterator()` in V1, removed it from V2, and added `tf.compat.v1.data.make_initializable_iterator()`. + * Enable nested dataset support in core `tf.data` transformations. + * For `tf.data.Dataset` implementers: Added `tf.data.Dataset._element_structured` property to replace `Dataset.output_{types,shapes,classes}`. + * Make `num_parallel_calls` of `tf.data.Dataset.interleave` and `tf.data.Dataset.map` work in Eager mode. +* Toolchains + * Fixed OpenSSL compatibility by avoiding `EVP_MD_CTX_destroy`. + * Added bounds checking to printing deprecation warnings. + * Upgraded CUDA dependency to 10.0 + * To build with Android NDK r14b, add "#include <linux/compiler.h>" to android-ndk-r14b/platforms/android-14/arch-*/usr/include/linux/futex.h + * Removed `:android_tensorflow_lib_selective_registration*` targets, use `:android_tensorflow_lib_lite*` targets instead. +* XLA + * Move `RoundToEven` function to xla/client/lib/math.h. + * A new environment variable `TF_XLA_DEBUG_OPTIONS_PASSTHROUGH` set to "1" or "true" allows the debug options passed within an XRTCompile op to be passed directly to the XLA compilation backend. If this variable is not set (service side), only a restricted set will be passed through. + * Allow the XRTCompile op to return the ProgramShape resulting from the XLA compilation as a second return argument. + * XLA HLO graphs can now be rendered as SVG/HTML. 
+* Estimator + * Replace all occurrences of `tf.contrib.estimator.BaselineEstimator` with `tf.estimator.BaselineEstimator` + * Replace all occurrences of `tf.contrib.estimator.DNNLinearCombinedEstimator` with `tf.estimator.DNNLinearCombinedEstimator` + * Replace all occurrences of `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator` + * Replace all occurrences of `tf.contrib.estimator.LinearEstimator` with `tf.estimator.LinearEstimator` + * Users of `tf.contrib.estimator.export_all_saved_models` and related should switch to `tf.estimator.Estimator.experimental_export_all_saved_models`. + * Update `regression_head` to the new Head API for Canned Estimator V2. + * Switch `multi_class_head` to Head API for Canned Estimator V2. + * Replace all occurrences of `tf.contrib.estimator.InMemoryEvaluatorHook` and `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with `tf.estimator.experimental.InMemoryEvaluatorHook` and `tf.estimator.experimental.make_stop_at_checkpoint_step_hook` + * Migrate linear optimizer from contrib to core. + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Abhinav Upadhyay, Ag Ramesh, akikaaa, Alexis Louis, Anders Huss, Andreas Madsen, Andrew Banchich, Andy Craze, Anton Dmitriev, Artem Malykh, Avijit-Nervana, Balint Cristian, Benjamin Tan Wei Hao, Bhavani Subramanian, Brendan Finan, Brian Nemsick, Bryan Cutler, By Shen, Cao Zongyan, Castiel, Chris Antaki, Christian Goll, Cibifang, Clayne Robison, Codrut Grosu, Cong Xu, Dalmo Cirne, Daniel Hunter, Dougal J. 
Sutherland, Edvard Fagerholm, EFanZh, Erik Smistad, Evgeniy Polyakov, Feiyang Chen, franklin5, Fred Reiss, Gautam, gehring, Geoffrey Irving, George Sterpu, Gitea, Grzegorz George Pawelczak, Guozhong Zhuang, himkt, Hoeseong Kim, Huan Li (李卓桓), HuiyangFei, hyunyoung, Isaac Burbank, jackonan, Jacky Ko, Jason Furmanek, Jason Zaman, Javier Luraschi, Jiang,Zhoulong, joaak, John Lin, Jonathan Wyatt Hoech, josephyearsley, Josh Gordon, Julian Niedermeier, Karl Lessard, Keno Fischer, lanhin, Leon Graser, leondgarse, Li, Guizi, Li, Yiqiang, lxl910915, Mahmoud Abuzaina, manhyuk, Marcela Morales Quispe, margaretmz, Matt Conley, Max Pumperla, mbhuiyan, mdfaijul, Meng, Peng, Michael, Michael Gielda, mrTsjolder, Muhammad Wildan, neargye, Nehal J Wani, NEWPLAN, Niranjan Hasabnis, Nutti, olicht, Pan Daoxin, Pedro Monreal, Peng Yu, pillarpond, Pooya Davoodi, qiezi, Rholais Lii, Richard Yu, Rin Arakaki, Roger Iyengar, sahilbadyal, Sami Kama, Sandip Giri, Scott Leishman, Serge Panev, Seunghoon Park, Shafi Dayatar, shengfuintel, Shimin Guo, Siju, silent567, Stefan Dyulgerov, steven, Tao Wei, Thor Johnsen, Tingbo Lu, tomguluson92, Tongxuan Liu, Trevor Morris, Ubuntu, Vadim Borisov, vanderliang, wangsiyu, Wen Yun, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, Xiaoming (Jason) Cui, Yan Facai (颜发才), Yanbo Liang, Yaniv Blumenfeld, Yash Gaurkar, Yicheng Fan, Yong Tang, Yongjoon Lee, Yuan (Terry) Tang, Yuxin Wu, zldrobit + # Release 1.12.0 ## Major Features and Improvements diff --git a/WORKSPACE b/WORKSPACE index 957b8d8528dc9b5e2ea134921b28601aa6fed2d1..ef44c25553cc9ea00a6d73d89a7b9c39481efbdc 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -29,7 +29,7 @@ load( bazel_toolchains_repositories() load( - "@io_bazel_rules_docker//container:container.bzl", + "@io_bazel_rules_docker//repositories:repositories.bzl", container_repositories = "repositories", ) @@ -43,29 +43,47 @@ remote_config_workspace() # Apple and Swift rules. 
http_archive( name = "build_bazel_rules_apple", - sha256 = "4fe4ee824200b48821730f89ff260984332dc3551db587c24691235d1d96a8a7", - strip_prefix = "rules_apple-0.10.0", - urls = ["https://github.com/bazelbuild/rules_apple/archive/0.10.0.tar.gz"], + sha256 = "4b90786009fa8df25230442244bad2832ba8d6bc4987f68150a7de59c8827e90", + strip_prefix = "rules_apple-0.14.0", + urls = ["https://github.com/bazelbuild/rules_apple/archive/0.14.0.tar.gz"], ) -http_archive( - name = "build_bazel_rules_swift", - sha256 = "6544ff5615febec0342de1127144d2f3e43ea80fb7f9b1ade65e6a184e39e618", - strip_prefix = "rules_swift-0.5.0", - urls = ["https://github.com/bazelbuild/rules_swift/archive/0.5.0.tar.gz"], +http_file( + name = "xctestrunner", + executable = 1, + urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"], ) + http_archive( name = "bazel_skylib", - sha256 = "eb5c57e4c12e68c0c20bc774bfbc60a568e800d025557bc4ea022c6479acc867", - strip_prefix = "bazel-skylib-0.6.0", - urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.6.0.tar.gz"], + sha256 = "2c62d8cd4ab1e65c08647eb4afe38f51591f43f7f0885e7769832fa137633dcb", + strip_prefix = "bazel-skylib-0.7.0", + urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.7.0.tar.gz"], ) -http_file( - name = "xctestrunner", - executable = 1, - urls = ["https://github.com/google/xctestrunner/releases/download/0.2.5/ios_test_runner.par"], + +http_archive( + name = "build_bazel_apple_support", + sha256 = "835663c4bb02f4bf01dce8a2a176df7fa682dbb867d3698ae12258c1628bb8f0", + strip_prefix = "apple_support-0.5.0", + urls = ["https://github.com/bazelbuild/apple_support/archive/0.5.0.tar.gz"], +) + +http_archive( + name = "build_bazel_rules_swift", + sha256 = "32d124878cd49775d84f59ba90440c8b23b7c775aec8fec1978f751c76ddee8a", + strip_prefix = "rules_swift-0.7.0", + urls = ["https://github.com/bazelbuild/rules_swift/archive/0.7.0.tar.gz"], ) -load("@build_bazel_rules_apple//apple:repositories.bzl", 
"apple_rules_dependencies") -apple_rules_dependencies(ignore_version_differences = True) + +http_archive( + name = "com_github_apple_swift_swift_protobuf", + type = "zip", + strip_prefix = "swift-protobuf-1.2.0/", + urls = ["https://github.com/apple/swift-protobuf/archive/1.2.0.zip"], +) + +# Use swift_rules_dependencies to fetch the tolchains. +# Since we defined all the "git_repository" rules above, the following call will +# skip redefining them. load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") swift_rules_dependencies() @@ -134,4 +152,3 @@ http_archive( "http://download.tensorflow.org/models/speech_commands_v0.01.zip", ], ) - diff --git a/configure.py b/configure.py index adc9ef9caca8c0128c63896fdebbbadf7f86da81..fe724c9e6acf4eccda275799b4c23a6b58bfed3b 100644 --- a/configure.py +++ b/configure.py @@ -50,11 +50,18 @@ _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' _TF_WORKSPACE_ROOT = '' _TF_BAZELRC = '' +_TF_CURRENT_BAZEL_VERSION = None NCCL_LIB_PATHS = [ 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' ] +# List of files to be configured for using Bazel on Apple platforms. +APPLE_BAZEL_FILES = [ + 'tensorflow/lite/experimental/objc/BUILD', + 'tensorflow/lite/experimental/swift/BUILD' +] + if platform.machine() == 'ppc64le': _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/' else: @@ -256,6 +263,7 @@ def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() + def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -330,8 +338,8 @@ def get_var(environ_cp, 'Environment variable %s must be set as a boolean indicator.\n' 'The following are accepted as TRUE : %s.\n' 'The following are accepted as FALSE: %s.\n' - 'Current value is %s.' % (var_name, ', '.join(true_strings), - ', '.join(false_strings), var)) + 'Current value is %s.' 
% + (var_name, ', '.join(true_strings), ', '.join(false_strings), var)) while var is None: user_input_origin = get_input(question) @@ -764,11 +772,12 @@ def check_ndk_level(android_ndk_home_path): else: raise Exception('Unable to parse NDK revision.') if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: - print('WARNING: The API level of the NDK in %s is %s, which is not ' - 'supported by Bazel (officially supported versions: %s). Please use ' - 'another version. Compiling Android targets may result in confusing ' - 'errors.\n' % (android_ndk_home_path, ndk_api_level, - _SUPPORTED_ANDROID_NDK_VERSIONS)) + print( + 'WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % + (android_ndk_home_path, ndk_api_level, _SUPPORTED_ANDROID_NDK_VERSIONS)) return ndk_api_level @@ -785,8 +794,7 @@ def set_gcc_host_compiler_path(environ_cp): environ_cp, var_name='GCC_HOST_COMPILER_PATH', var_default=default_gcc_host_compiler_path, - ask_for_var= - 'Please specify which gcc should be used by nvcc as the host compiler.', + ask_for_var='Please specify which gcc should be used by nvcc as the host compiler.', check_success=os.path.exists, error_msg='Invalid gcc path. %s cannot be found.', ) @@ -1224,8 +1232,8 @@ def set_tf_nccl_install_path(environ_cp): # Reset and Retry print( 'Invalid path to NCCL %s toolkit, %s or %s not found. 
Please use the ' - 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path, - nccl_hdr_path)) + 'O/S agnostic package of NCCL 2' % + (tf_nccl_version, nccl_lib_path, nccl_hdr_path)) environ_cp['TF_NCCL_VERSION'] = '' else: @@ -1237,6 +1245,7 @@ def set_tf_nccl_install_path(environ_cp): environ_cp['TF_NCCL_VERSION'] = tf_nccl_version write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1273,13 +1282,15 @@ def set_tf_cuda_compute_capabilities(environ_cp): ask_cuda_compute_capabilities = ( 'Please specify a list of comma-separated ' - 'Cuda compute capabilities you want to ' + 'CUDA compute capabilities you want to ' 'build with.\nYou can find the compute ' 'capability of your device at: ' 'https://developer.nvidia.com/cuda-gpus.\nPlease' ' note that each additional compute ' 'capability significantly increases your ' - 'build time and binary size. [Default is: %s]: ' % + 'build time and binary size, and that ' + 'TensorFlow only supports compute ' + 'capabilities >= 3.5 [Default is: %s]: ' % default_cuda_compute_capabilities) tf_cuda_compute_capabilities = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', @@ -1292,13 +1303,18 @@ def set_tf_cuda_compute_capabilities(environ_cp): for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) if not m: - print('Invalid compute capability: ' % compute_capability) + print('Invalid compute capability: %s' % compute_capability) all_valid = False else: - ver = int(m.group(0).split('.')[0]) - if ver < 3: - print('Only compute capabilities 3.0 or higher are supported.') + ver = float(m.group(0)) + if ver < 3.0: + print('ERROR: TensorFlow only supports CUDA compute capabilities 3.0 ' + 'and higher. Please re-specify the list of compute ' + 'capabilities excluding version %s.' 
% ver) all_valid = False + if ver < 3.5: + print('WARNING: XLA does not support CUDA compute capabilities ' + 'lower than 3.5. Disable XLA when running on older GPUs.') if all_valid: break @@ -1482,7 +1498,36 @@ def set_other_mpi_vars(environ_cp): else: raise ValueError( 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' % - mpi_home, mpi_home, mpi_home) + (mpi_home, mpi_home, mpi_home)) + + +def system_specific_test_config(env): + """Add default test flags required for TF tests to bazelrc.""" + write_to_bazelrc('test --flaky_test_attempts=3') + write_to_bazelrc('test --test_size_filters=small,medium') + write_to_bazelrc( + 'test --test_tag_filters=-benchmark-test,-no_oss,-oss_serial') + write_to_bazelrc('test --build_tag_filters=-benchmark-test,-no_oss') + if is_windows(): + if env.get('TF_NEED_CUDA', None) == '1': + write_to_bazelrc( + 'test --test_tag_filters=-no_windows,-no_windows_gpu,-no_gpu') + write_to_bazelrc( + 'test --build_tag_filters=-no_windows,-no_windows_gpu,-no_gpu') + else: + write_to_bazelrc('test --test_tag_filters=-no_windows,-gpu') + write_to_bazelrc('test --build_tag_filters=-no_windows,-gpu') + elif is_macos(): + write_to_bazelrc('test --test_tag_filters=-gpu,-nomac,-no_mac') + write_to_bazelrc('test --build_tag_filters=-gpu,-nomac,-no_mac') + elif is_linux(): + if env.get('TF_NEED_CUDA', None) == '1': + write_to_bazelrc('test --test_tag_filters=-no_gpu') + write_to_bazelrc('test --build_tag_filters=-no_gpu') + write_to_bazelrc('test --test_env=LD_LIBRARY_PATH') + else: + write_to_bazelrc('test --test_tag_filters=-gpu') + write_to_bazelrc('test --build_tag_filters=-gpu') def set_system_libs_flag(environ_cp): @@ -1508,15 +1553,14 @@ def set_windows_build_flags(environ_cp): write_to_bazelrc('build --config monolithic') # Suppress warning messages write_to_bazelrc('build --copt=-w --host_copt=-w') + # Fix winsock2.h conflicts + write_to_bazelrc( + 'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN') # Output 
more verbose information when something goes wrong write_to_bazelrc('build --verbose_failures') # The host and target platforms are the same in Windows build. So we don't # have to distinct them. This avoids building the same targets twice. write_to_bazelrc('build --distinct_host_configuration=false') - # Enable short object file path to avoid long path issue on Windows. - # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0 - # Short object file path will be enabled by default. - write_to_bazelrc('build --experimental_shortened_obj_file_path=true') if get_var( environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', @@ -1537,9 +1581,30 @@ def config_info_line(name, help_text): print('\t--config=%-12s\t# %s' % (name, help_text)) +def configure_apple_bazel_rules(): + """Configures Bazel rules for building on Apple platforms. + + Enables analyzing and building Apple Bazel rules on Apple platforms. This + function will only be executed if `is_macos()` is true. + """ + if not is_macos(): + return + for filepath in APPLE_BAZEL_FILES: + print( + 'Configuring %s file to analyze and build Bazel rules on Apple platforms.' + % filepath) + existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple') + renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath) + os.rename(existing_filepath, renamed_filepath) + if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000: + print( + 'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.') + + def main(): global _TF_WORKSPACE_ROOT global _TF_BAZELRC + global _TF_CURRENT_BAZEL_VERSION parser = argparse.ArgumentParser() parser.add_argument( @@ -1556,7 +1621,8 @@ def main(): # environment variables. 
environ_cp = dict(os.environ) - check_bazel_version('0.19.0', '0.21.0') + current_bazel_version = check_bazel_version('0.19.0', '0.23.2') + _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version) reset_tf_configure_bazelrc() @@ -1577,6 +1643,8 @@ def main(): if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' + else: + environ_cp['TF_CONFIGURE_APPLE_BAZEL_RULES'] = '0' # The numpy package on ppc64le uses OpenBLAS which has multi-threading # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at @@ -1679,6 +1747,16 @@ def main(): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) + system_specific_test_config(os.environ) + + if get_var( + environ_cp, 'TF_CONFIGURE_APPLE_BAZEL_RULES', + 'Configure Bazel rules for Apple platforms', False, + ('Would you like to configure Bazel rules for building on Apple platforms?' + ), 'Configuring Bazel rules for Apple platforms.', + 'Not configuring Bazel rules for Apple platforms.'): + configure_apple_bazel_rules() + print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. 
See .bazelrc for more ' 'details.') @@ -1687,8 +1765,10 @@ def main(): config_info_line('gdr', 'Build with GDR support.') config_info_line('verbs', 'Build with libverbs support.') config_info_line('ngraph', 'Build with Intel nGraph support.') - config_info_line('dynamic_kernels', - '(Experimental) Build kernels into separate shared objects.') + config_info_line('numa', 'Build with NUMA support.') + config_info_line( + 'dynamic_kernels', + '(Experimental) Build kernels into separate shared objects.') print('Preconfigured Bazel build configs to DISABLE default on features:') config_info_line('noaws', 'Disable AWS S3 filesystem support.') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 29d71c323ab5ee860ebf48c332cfd7f607f3f0c3..24d34cf9c154c2eec8f840118e704018106c9848 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -16,6 +16,8 @@ exports_files([ ]) load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl") +load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") load( "//tensorflow/core:platform/default/build_config.bzl", "tf_additional_binary_deps", @@ -40,12 +42,16 @@ load( # @unused TENSORFLOW_API_INIT_FILES_V2 = ( - TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) + TENSORFLOW_API_INIT_FILES + + get_compat_files(TENSORFLOW_API_INIT_FILES, 2) + + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) # @unused -TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = ( - TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +TENSORFLOW_API_INIT_FILES_V1 = ( + TENSORFLOW_API_INIT_FILES_V1 + + get_compat_files(TENSORFLOW_API_INIT_FILES, 2) + + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) # Config setting used when building for products @@ -90,6 +96,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "emscripten", + values = {"crosstool_top": "//external:android/emscripten"}, + visibility = 
["//visibility:public"], +) + config_setting( name = "raspberry_pi_armeabi", values = { @@ -135,12 +147,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "darwin", - values = {"cpu": "darwin"}, - visibility = ["//visibility:public"], -) - config_setting( name = "windows", values = {"cpu": "x64_windows"}, @@ -153,9 +159,18 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "macos", + values = { + "apple_platform_type": "macos", + "cpu": "darwin", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "ios", - values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, + values = {"apple_platform_type": "ios"}, visibility = ["//visibility:public"], ) @@ -294,6 +309,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_numa_support", + define_values = {"with_numa_support": "true"}, + visibility = ["//visibility:public"], +) + # Crosses between framework_shared_object and a bunch of other configurations # due to limitations in nested select() statements. 
config_setting( @@ -343,6 +364,13 @@ config_setting( }, ) +config_setting( + name = "using_rocm_hipcc", + define_values = { + "using_rocm_hipcc": "true", + }, +) + config_setting( name = "with_mpi_support", values = {"define": "with_mpi_support=true"}, @@ -381,16 +409,7 @@ config_setting( package_group( name = "internal", - packages = [ - "-//third_party/tensorflow/python/estimator", - "//learning/deepmind/...", - "//learning/meta_rank/...", - "//tensorflow/...", - "//tensorflow_estimator/contrib/...", - "//tensorflow_fold/llgtm/...", - "//tensorflow_text/...", - "//third_party/py/tensor2tensor/...", - ], + packages = ["//tensorflow/..."], ) load( @@ -451,11 +470,10 @@ tf_cc_shared_object( name = "libtensorflow_framework.so", framework_so = [], linkopts = select({ - "//tensorflow:darwin": [], + "//tensorflow:macos": [], "//tensorflow:windows": [], "//conditions:default": [ - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow:tf_framework_version_script.lds)", + "-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)", ], }), linkstatic = 1, @@ -486,21 +504,27 @@ tf_cc_shared_object( # symbols in object files. 
tf_cc_shared_object( - name = "libtensorflow.so", + name = "tensorflow", linkopts = select({ - "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow/c:exported_symbols.lds)", + "//tensorflow:macos": [ + "-Wl,-exported_symbols_list,$(location //tensorflow/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflow.so", ], - "//tensorflow:windows": [], + "//tensorflow:windows": [ + ], "//conditions:default": [ "-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow/c:version_script.lds)", + "-Wl,--version-script,$(location //tensorflow/c:version_script.lds)", ], }), + per_os_targets = True, visibility = ["//visibility:public"], + # add win_def_file for tensorflow + win_def_file = select({ + # We need this DEF file to properly export symbols on Windows + "//tensorflow:windows": ":tensorflow_filtered_def_file", + "//conditions:default": None, + }), deps = [ "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", @@ -512,20 +536,25 @@ tf_cc_shared_object( ) tf_cc_shared_object( - name = "libtensorflow_cc.so", + name = "tensorflow_cc", linkopts = select({ - "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow:tf_exported_symbols.lds)", + "//tensorflow:macos": [ + "-Wl,-exported_symbols_list,$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], "//conditions:default": [ "-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow:tf_version_script.lds)", + "-Wl,--version-script,$(location //tensorflow:tf_version_script.lds)", ], }), + per_os_targets = True, visibility = ["//visibility:public"], + # add win_def_file for tensorflow_cc + win_def_file = select({ + # We need this DEF file to 
properly export symbols on Windows + "//tensorflow:windows": ":tensorflow_filtered_def_file", + "//conditions:default": None, + }), deps = [ "//tensorflow:tf_exported_symbols.lds", "//tensorflow:tf_version_script.lds", @@ -539,6 +568,92 @@ tf_cc_shared_object( ] + if_ngraph(["@ngraph_tf//:ngraph_tf"]), ) +# ** Targets for Windows build (start) ** + +# Build a shared library (DLL) by cc_binary from tf_custom_op_library_additional_deps_impl, +# it contains all object code from its dependencies. +# This target is only used for parsing the symbols to be exported in tensorflow.dll. +# Do NOT depend on it. +tf_native_cc_binary( + name = "tf_custom_op_library_additional_deps.dll", + linkshared = 1, + linkstatic = 1, + deps = tf_custom_op_library_additional_deps_impl(), +) + +# Get a DEF file generated by parsing all object files +# of tf_custom_op_library_additional_deps.so +filegroup( + name = "tensorflow_def_file", + srcs = [":tf_custom_op_library_additional_deps.dll"], + output_group = "def_file", +) + +# Filter the DEF file to reduce the number of symbols to 64K or less. +# Note that we also write the name of the pyd file into DEF file so that +# the dynamic libraries of custom ops can find it at runtime. +genrule( + name = "tensorflow_filtered_def_file", + srcs = [":tensorflow_def_file"], + outs = ["tensorflow_filtered_def_file.def"], + cmd = select({ + "//tensorflow:windows": """ + $(location @local_config_def_file_filter//:def_file_filter) \\ + --input $(location :tensorflow_def_file) \\ + --output $@ + """, + "//conditions:default": "touch $@", # Just a placeholder for Unix platforms + }), + tools = ["@local_config_def_file_filter//:def_file_filter"], + visibility = ["//visibility:public"], +) + +# The interface library (tensorflow.dll.if.lib) for linking tensorflow DLL library (tensorflow.dll) on Windows. 
+# To learn more about import library (called interface library in Bazel): +# https://docs.microsoft.com/en-us/cpp/build/linking-an-executable-to-a-dll?view=vs-2017#linking-implicitly +filegroup( + name = "get_tensorflow_dll_import_lib", + srcs = ["//tensorflow:tensorflow.dll"], + output_group = "interface_library", + visibility = ["//visibility:public"], +) + +# Rename the import library for tensorflow.dll from tensorflow.dll.if.lib to tensorflow.lib +genrule( + name = "tensorflow_dll_import_lib", + srcs = [":get_tensorflow_dll_import_lib"], + outs = ["tensorflow.lib"], + cmd = select({ + "//tensorflow:windows": "cp -f $< $@", + "//conditions:default": "touch $@", # Just a placeholder for Unix platforms + }), + visibility = ["//visibility:public"], +) + +# The interface library (tensorflow_cc.dll.if.lib) for linking tensorflow DLL library (tensorflow_cc.dll) on Windows. +# To learn more about import library (called interface library in Bazel): +# https://docs.microsoft.com/en-us/cpp/build/linking-an-executable-to-a-dll?view=vs-2017#linking-implicitly +filegroup( + name = "get_tensorflow_cc_dll_import_lib", + srcs = ["//tensorflow:tensorflow_cc.dll"], + output_group = "interface_library", + visibility = ["//visibility:public"], +) + +# Rename the import library for tensorflow.dll from tensorflow_cc.dll.if.lib to tensorflow.lib +genrule( + name = "tensorflow_cc_dll_import_lib", + srcs = [":get_tensorflow_cc_dll_import_lib"], + outs = ["tensorflow_cc.lib"], + cmd = select({ + "//tensorflow:windows": "cp -f $< $@", + "//conditions:default": "touch $@", # Just a placeholder for Unix platforms + }), + visibility = ["//visibility:public"], +) + +# ** Targets for Windows build (end) ** + exports_files( [ "tf_version_script.lds", @@ -599,13 +714,20 @@ gen_api_init_files( name = "tf_python_api_gen_v1", srcs = [ "api_template_v1.__init__.py", + "compat_template.__init__.py", "compat_template_v1.__init__.py", ], api_version = 1, - compat_api_versions = [1], - 
compat_init_templates = ["compat_template_v1.__init__.py"], + compat_api_versions = [ + 1, + 2, + ], + compat_init_templates = [ + "compat_template_v1.__init__.py", + "compat_template.__init__.py", + ], output_dir = "_api/v1/", - output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT, + output_files = TENSORFLOW_API_INIT_FILES_V1, output_package = "tensorflow._api.v1", root_file_name = "v1.py", root_init_template = "api_template_v1.__init__.py", @@ -615,11 +737,18 @@ gen_api_init_files( name = "tf_python_api_gen_v2", srcs = [ "api_template.__init__.py", + "compat_template.__init__.py", "compat_template_v1.__init__.py", ], api_version = 2, - compat_api_versions = [1], - compat_init_templates = ["compat_template_v1.__init__.py"], + compat_api_versions = [ + 1, + 2, + ], + compat_init_templates = [ + "compat_template_v1.__init__.py", + "compat_template.__init__.py", + ], output_dir = "_api/v2/", output_files = TENSORFLOW_API_INIT_FILES_V2, output_package = "tensorflow._api.v2", diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index a93799bfe84b0f9c4743e1ad0effd6e69ad7f3f2..7bd6b7223989cddfea935f0ed2bcf7536015feea 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -26,14 +26,28 @@ import sys as _sys # API IMPORTS PLACEHOLDER +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +# We're using bitwise, but there's nothing special about that. 
+_API_MODULE = bitwise # pylint: disable=undefined-variable +_current_module = _sys.modules[__name__] +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) + # pylint: disable=g-bad-import-order from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg="Limited tf.summary API due to missing TensorBoard installation") _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=( 'tensorflow_estimator.python.estimator.api._v2.estimator')) -_current_module = _sys.modules[__name__] if not hasattr(_current_module, 'estimator'): _component_api_helper.package_hook( parent_package_str=__name__, @@ -42,14 +56,6 @@ if not hasattr(_current_module, 'estimator'): _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow.python.keras.api._v2.keras')) -# Make sure directory containing top level submodules is in -# the __path__ so that "from tensorflow.foo import bar" works. -# We're using bitwise, but there's nothing special about that. 
-_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable -if not hasattr(_current_module, '__path__'): - __path__ = [_tf_api_dir] -elif _tf_api_dir not in __path__: - __path__.append(_tf_api_dir) # Enable TF2 behaviors from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top @@ -111,5 +117,11 @@ try: except NameError: pass +# Add module aliases +if hasattr(_current_module, 'keras'): + losses = keras.losses + metrics = keras.metrics + optimizers = keras.optimizers + initializers = keras.initializers # pylint: enable=undefined-variable diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index eeca8f0d566a6401cb64e4fe3f0ee3c5aeb4ece2..5eb25a81b7f765f551bc4f1b7ba99b35dbc6b7bb 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -70,7 +70,7 @@ _API_MODULE = app # pylint: disable=undefined-variable # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. 
-_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) # pylint: disable=undefined-variable +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) if not hasattr(_current_module, '__path__'): __path__ = [_tf_api_dir] elif _tf_api_dir not in __path__: diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 4e2fe34d28c9363ebba690c5491b258a4dba11b3..00fea495fba80cef49e71f724985a14abb3932da 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -39,14 +39,19 @@ filegroup( "python_api.h", "*test*", ], - ), + ) + [ + "//tensorflow/cc:srcs", + "//tensorflow/core/distributed_runtime:server_lib.h", + ], visibility = ["//visibility:public"], ) tf_cuda_library( name = "c_api_internal", - srcs = ["c_api.h"], - hdrs = ["c_api_internal.h"], + hdrs = [ + "c_api.h", + "c_api_internal.h", + ], visibility = [ "//tensorflow:internal", "//tensorflow/c:__subpackages__", @@ -59,6 +64,7 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_platform", "//tensorflow/core:op_gen_lib", "//tensorflow/core/distributed_runtime:server_lib", ], @@ -67,22 +73,37 @@ tf_cuda_library( tf_cuda_library( name = "c_api", - srcs = [ - "c_api.cc", - "c_api_function.cc", - ], hdrs = [ "c_api.h", ], copts = tf_copts(), visibility = ["//visibility:public"], - deps = select({ + deps = [ + ":c_api_no_xla", + ":c_api_internal", + ] + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + ], + "//conditions:default": [], + }), +) + +tf_cuda_library( + name = "c_api_no_xla", + srcs = [ + "c_api.cc", + "c_api_function.cc", + ], + hdrs = ["c_api.h"], + copts = tf_copts(), + visibility = ["//tensorflow/c:__subpackages__"], + deps = [":c_api_internal"] + select({ "//tensorflow:android": [ - ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":c_api_internal", 
"//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", @@ -97,13 +118,8 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/kernels:logging_ops", ], - }) + select({ - "//tensorflow:with_xla_support": [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/jit", - ], - "//conditions:default": [], }), ) @@ -129,6 +145,7 @@ tf_cuda_library( "//tensorflow/core:lib_platform", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", + "@com_google_absl//absl/strings", ], ) @@ -155,8 +172,8 @@ tf_cuda_library( hdrs = ["tf_status_helper.h"], visibility = ["//visibility:public"], deps = [ - ":c_api", ":c_api_internal", + ":c_api_no_xla", "//tensorflow/core:lib", ], ) @@ -212,13 +229,13 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - ":c_api", + ":c_api_no_xla", ":c_api_internal", ":tf_status_helper", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":c_api", + ":c_api_no_xla", ":c_api_internal", ":tf_status_helper", "//tensorflow/core:framework", @@ -270,7 +287,7 @@ tf_cuda_cc_test( ], kernels = [":test_op_kernel"], linkopts = select({ - "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], }), tags = [ @@ -288,13 +305,23 @@ tf_cuda_cc_test( "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", "//tensorflow/compiler/jit", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:bitwise_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + 
"//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:spectral_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/kernels:array", @@ -309,7 +336,7 @@ tf_cc_test( srcs = ["c_api_experimental_test.cc"], data = ["testdata/tf_record"], linkopts = select({ - "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], }), # We must ensure that the dependencies can be dynamically linked since @@ -318,6 +345,7 @@ tf_cc_test( deps = [ ":c_api", ":c_api_experimental", + ":c_api_internal", ":c_test_util", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_test_util", @@ -334,6 +362,7 @@ tf_cc_test( srcs = ["c_api_function_test.cc"], deps = [ ":c_api", + ":c_api_internal", ":c_test_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -376,7 +405,7 @@ tf_cuda_cc_test( size = "small", srcs = ["env_test.cc"], linkopts = select({ - "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], }), tags = ["noasan"], @@ -397,7 +426,7 @@ tf_cuda_cc_test( size = "small", srcs = ["kernels_test.cc"], linkopts = select({ - "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], }), tags = ["noasan"], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 9f2f83920cc73028fd2372afaf303e8b1c1c64f9..bbc16b85429ebaa38e7992878330c04c0bdb7f99 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -20,14 +20,19 @@ limitations under the License. 
#include #include -#ifndef __ANDROID__ +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" // NOLINT + +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/while_loop.h" #include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/op_gen_lib.h" -#endif +#include "tensorflow/core/kernels/logging_ops.h" +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eval_const_tensor.h" @@ -257,6 +262,74 @@ int64_t TF_Dim(const TF_Tensor* t, int dim_index) { size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); } void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); } +int64_t TF_TensorElementCount(const TF_Tensor* t) { + int64_t result = 1; + int rank = TF_NumDims(t); + for (int dim = 0; dim < rank; ++dim) { + result *= TF_Dim(t, dim); + } + return result; +} + +// Returns the number of elements that would be present in a tensor with the +// given shape. 
+static int64_t ShapeNumElements(const int64_t* dims, int num_dims) { + int64_t result = 1; + for (int dim = 0; dim < num_dims; ++dim) { + result *= dims[dim]; + } + return result; +} + +static void UnrefIfNonNull(::tensorflow::TensorBuffer* buf) { + if (buf != nullptr) { + buf->Unref(); + } +} + +static void RefIfNonNull(::tensorflow::TensorBuffer* buf) { + if (buf != nullptr) { + buf->Ref(); + } +} + +void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type, + TF_Tensor* to, const int64_t* new_dims, + int num_new_dims, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + size_t in_size = TF_DataTypeSize(TF_TensorType(from)); + if (in_size == 0) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "input tensor has a zero-sized data type"); + return; + } + size_t out_size = TF_DataTypeSize(type); + if (out_size == 0) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "output tensor has a zero-sized data type"); + return; + } + + if (ShapeNumElements(new_dims, num_new_dims) * out_size != + TF_TensorElementCount(from) * in_size) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "input tensor is not compatible with output shape"); + return; + } + + tensorflow::TensorShapeProto p; + for (int i = 0; i < num_new_dims; ++i) { + p.add_dim()->set_size(new_dims[i]); + } + to->shape = tensorflow::TensorShape(p); + to->dtype = type; + if (to->buffer != from->buffer) { + UnrefIfNonNull(to->buffer); + to->buffer = from->buffer; + RefIfNonNull(to->buffer); + } +} + // -------------------------------------------------------------------------- size_t TF_StringEncode(const char* src, size_t src_len, char* dst, size_t dst_len, TF_Status* status) { @@ -295,7 +368,7 @@ static Status TF_StringDecode_Impl(const char* src, size_t src_len, size_t TF_StringDecode(const char* src, size_t src_len, const char** dst, size_t* dst_len, TF_Status* status) { status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len); - if (!status->status.ok()) return 0; + if (TF_GetCode(status) != 
TF_OK) return 0; return static_cast(*dst - src) + *dst_len; } @@ -350,7 +423,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, TF_Status* status) { Session* session; status->status = NewSession(opt->options, &session); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { return new TF_DeprecatedSession({session}); } else { DCHECK_EQ(nullptr, session); @@ -542,7 +615,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, offsets++; const string& s = srcarray(i); size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { status->status = InvalidArgument( "invalid string tensor encoding (string #", i, " of ", srcarray.size(), "): ", status->status.error_message()); @@ -572,7 +645,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, dimvec.size(), base, size, DeleteArray, base); } -Status MessageToBuffer(const tensorflow::protobuf::Message& in, +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, TF_Buffer* out) { if (out->data != nullptr) { return InvalidArgument("Passing non-empty TF_Buffer is invalid."); @@ -702,7 +775,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { // TODO(nolivia): check this on a subset of the graph instead of all of // it. status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { session->graph->mu.unlock(); return false; } @@ -722,7 +795,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { *graph_def.mutable_library() = graph.flib_def().ToProto(); session->graph->mu.unlock(); status->status = session->session->Extend(graph_def); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { // Contract is we always delete input_values[i]. 
return false; } @@ -752,7 +825,7 @@ static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, const int ninputs = input_pairs->size(); for (int i = 0; i < ninputs; ++i) { status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); - if (!status->status.ok()) return false; + if (TF_GetCode(status) != TF_OK) return false; } return true; } @@ -790,7 +863,7 @@ static void TF_Run_Helper( // Serialize back to upstream client, who now owns the new buffer if (run_metadata != nullptr) { status->status = MessageToBuffer(run_metadata_proto, run_metadata); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; } } else { // NOTE(zongheng): PRun does not support RunOptions yet. @@ -810,7 +883,7 @@ static void TF_Run_Helper( continue; } c_outputs[i] = TF_TensorFromTensor(src, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; } } @@ -867,7 +940,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, string new_handle; status->status = s->session->PRunSetup(input_names, output_names, target_oper_names, &new_handle); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { char* buf = new char[new_handle.size() + 1]; memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; @@ -906,7 +979,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { status->status = tensorflow::LoadLibrary( library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, &lib_handle->op_list.length); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { delete lib_handle; return nullptr; } @@ -1010,7 +1083,7 @@ TensorId ToTensorId(const TF_Output& output) { return TensorId(output.oper->node.name(), output.index); } -#ifndef __ANDROID__ +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) std::vector OutputsFromTFOutputs(TF_Output* tf_outputs, int n) { std::vector outputs(n); @@ -1028,7 +1101,7 @@ void TFOutputsFromOutputs(const std::vector& outputs, 
tf_outputs[i].index = outputs[i].index(); } } -#endif // __ANDROID__ +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) } // namespace @@ -1242,6 +1315,13 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, reinterpret_cast(values), num_values)); } +void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name, + const char* placeholder) { + tensorflow::AttrValue attr_value; + attr_value.set_placeholder(placeholder); + desc->node_builder.Attr(attr_name, attr_value); +} + void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, const char* value, size_t length) { tensorflow::NameAttrList func_name; @@ -1327,7 +1407,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, TF_Tensor* value, TF_Status* status) { Tensor t; status->status = TF_TensorToTensor(value, &t); - if (status->status.ok()) desc->node_builder.Attr(attr_name, t); + if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t); } void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, @@ -1337,13 +1417,13 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, std::vector t; t.reserve(num_values); - for (int i = 0; i < num_values && status->status.ok(); ++i) { + for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) { Tensor v; status->status = TF_TensorToTensor(values[i], &v); t.emplace_back(v); } - if (status->status.ok()) desc->node_builder.Attr(attr_name, t); + if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t); } void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, @@ -1391,11 +1471,11 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, } status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { // Run shape inference function for newly added node. 
status->status = desc->graph->refiner.AddNode(ret); } - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { // Add the node to the name-to-node mapping. desc->graph->name_map[ret->name()] = ret; } else if (ret != nullptr) { @@ -1444,7 +1524,7 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, NameRangeMap name_ranges; status->status = NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); - if (!status->status.ok()) return -1; + if (TF_GetCode(status) != TF_OK) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { status->status = InvalidArgument("Input arg '", arg_name, "' not found"); @@ -1466,7 +1546,7 @@ int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, NameRangeMap name_ranges; status->status = NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); - if (!status->status.ok()) return -1; + if (TF_GetCode(status) != TF_OK) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { status->status = InvalidArgument("Input arg '", arg_name, "' not found"); @@ -1564,7 +1644,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, TF_Status* status) { TF_AttrMetadata metadata; const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return metadata; + if (TF_GetCode(status) != TF_OK) return metadata; switch (attr->value_case()) { #define SINGLE_CASE(kK, attr_type, size_expr) \ case tensorflow::AttrValue::kK: \ @@ -1671,7 +1751,7 @@ void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, void* value, size_t max_length, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; if (attr->value_case() != tensorflow::AttrValue::kS) { status->status = InvalidArgument("Attribute '", attr_name, "' is not a string"); @@ -1689,7 +1769,7 @@ void 
TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, int max_values, void* storage, size_t storage_size, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; if (attr->value_case() != tensorflow::AttrValue::kList) { status->status = InvalidArgument("Value for '", attr_name, "' is not a list"); @@ -1722,7 +1802,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ int max_values, TF_Status* status) { \ const auto* attr = GetAttrValue(oper, attr_name, status); \ - if (!status->status.ok()) return; \ + if (TF_GetCode(status) != TF_OK) return; \ if (attr->value_case() != tensorflow::AttrValue::kList) { \ status->status = \ InvalidArgument("Value for '", attr_name, "' is not a list."); \ @@ -1744,7 +1824,7 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, PartialTensorShape shape; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; auto len = std::min(shape.dims(), num_dims); for (int i = 0; i < len; ++i) { value[i] = shape.dim_size(i); @@ -1758,7 +1838,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, std::vector shapes; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; auto len = std::min(static_cast(shapes.size()), max_values); int64_t* p = storage; int storage_left = storage_size; @@ -1786,7 +1866,7 @@ void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, const char* attr_name, TF_Buffer* value, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; if 
(attr->value_case() != tensorflow::AttrValue::kShape) { status->status = InvalidArgument("Value for '", attr_name, "' is not a shape."); @@ -1800,7 +1880,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, TF_Buffer** values, int max_values, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; if (attr->value_case() != tensorflow::AttrValue::kList) { status->status = InvalidArgument("Value for '", attr_name, "' is not a list"); @@ -1810,7 +1890,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, for (int i = 0; i < len; ++i) { values[i] = TF_NewBuffer(); status->status = MessageToBuffer(attr->list().shape(i), values[i]); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { // Delete everything allocated to far, the operation has failed. for (int j = 0; j <= i; ++j) { TF_DeleteBuffer(values[j]); @@ -1825,7 +1905,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, *value = nullptr; Tensor t; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; *value = TF_TensorFromTensor(t, status); } @@ -1834,7 +1914,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, TF_Status* status) { std::vector ts; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; const auto len = std::min(max_values, static_cast(ts.size())); for (int i = 0; i < len; ++i) { values[i] = TF_TensorFromTensor(ts[i], status); @@ -1845,7 +1925,7 @@ void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != 
TF_OK) return; status->status = MessageToBuffer(*attr, output_attr_value); } @@ -1923,7 +2003,7 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, { mutex_lock l(graph->mu); status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; } status->status = MessageToBuffer(*op_def, output_op_def); } @@ -2041,7 +2121,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, tensorflow::ImportGraphDefResults results; status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, &graph->refiner, &results); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; // Add new nodes to name_map for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { @@ -2095,7 +2175,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( auto results = new TF_ImportGraphDefResults(); mutex_lock l(graph->mu); GraphImportGraphDefLocked(graph, def, options, results, status); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { delete results; return nullptr; } @@ -2143,7 +2223,7 @@ void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, namespace { -#ifndef __ANDROID__ +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) // Creates a placeholder representing an input to the cond or body graph. 
// TODO(skyewm): remove these from final graph @@ -2153,7 +2233,7 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name, TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input)); // TODO(skyewm): set placeholder shape TF_Operation* oper = TF_FinishOperation(desc, status); - if (!status->status.ok()) return false; + if (TF_GetCode(status) != TF_OK) return false; *input = {oper, 0}; return true; } @@ -2237,7 +2317,7 @@ bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) { return true; } -#endif // __ANDROID__ +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) void FreeWhileResources(const TF_WhileParams* params) { TF_DeleteGraph(params->cond_graph); @@ -2256,9 +2336,9 @@ TF_WhileParams EmptyWhileParams() { TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs, TF_Status* status) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "Creating while loops is not supported in Android. File a bug at " + "Creating while loops is not supported on mobile. 
File a bug at " "https://github.com/tensorflow/tensorflow/issues if this feature is " "important to you"); return EmptyWhileParams(); @@ -2298,15 +2378,15 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs, TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output, body_graph, body_inputs, body_outputs, name}; - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { FreeWhileResources(&params); return EmptyWhileParams(); } return params; -#endif // __ANDROID__ +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } -#ifndef __ANDROID__ +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) namespace { // TODO(skyewm): make nodes in while loop unfetchable like in Python version @@ -2381,13 +2461,13 @@ void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status, } } // namespace -#endif // __ANDROID__ +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status, TF_Output* outputs) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "Creating while loops is not supported in Android. File a bug at " + "Creating while loops is not supported on mobile. 
File a bug at " "https://github.com/tensorflow/tensorflow/issues if this feature is " "important to you"); #else @@ -2395,7 +2475,7 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status, if (!ValidateConstWhileParams(*params, status)) return; TF_FinishWhileHelper(params, status, outputs); FreeWhileResources(params); -#endif // __ANDROID__ +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); } @@ -2408,9 +2488,9 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "Adding gradients is not supported in Android. File a bug at " + "Adding gradients is not supported on mobile. File a bug at " "https://github.com/tensorflow/tensorflow/issues if this feature is " "important to you"); #else @@ -2490,7 +2570,7 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, // Unpack the results from grad_outputs_arg. 
TFOutputsFromOutputs(dy_arg, dy); -#endif // __ANDROID__ +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } // TF_Session functions ---------------------------------------------- @@ -2502,7 +2582,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { Session* session; status->status = NewSession(opt->options, &session); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); @@ -2519,11 +2599,11 @@ TF_Session* TF_LoadSessionFromSavedModel( const TF_SessionOptions* session_options, const TF_Buffer* run_options, const char* export_dir, const char* const* tags, int tags_len, TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) { -// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that -// the tensorflow/cc/saved_model:loader build target is Android friendly. -#ifdef __ANDROID__ +// TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring +// that the tensorflow/cc/saved_model:loader build target is mobile friendly. +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "Loading a SavedModel is not supported in Android. File a bug at " + "Loading a SavedModel is not supported on mobile. File a bug at " "https://github.com/tensorflow/tensorflow/issues if this feature is " "important to you"); return nullptr; @@ -2550,7 +2630,7 @@ TF_Session* TF_LoadSessionFromSavedModel( status->status = tensorflow::LoadSavedModel(session_options->options, run_options_proto, export_dir, tag_set, &bundle); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session // extends using GraphDefs. 
The Graph instance is different, but equivalent @@ -2567,7 +2647,7 @@ TF_Session* TF_LoadSessionFromSavedModel( if (meta_graph_def != nullptr) { status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; } TF_Session* session = new TF_Session(bundle.session.release(), graph); @@ -2575,7 +2655,7 @@ TF_Session* TF_LoadSessionFromSavedModel( graph->sessions[session] = ""; session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; -#endif // __ANDROID__ +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } void TF_CloseSession(TF_Session* s, TF_Status* status) { @@ -2667,7 +2747,7 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, string new_handle; status->status = session->session->PRunSetup(input_names, output_names, target_names, &new_handle); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { char* buf = new char[new_handle.size() + 1]; memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; @@ -2729,9 +2809,9 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, tensor, graph->refiner, *graph->graph.op_registry(), graph->graph.versions().producer(), &evaluated, &result_tensor); if (evaluated) { - DCHECK(status->status.ok()); + DCHECK(TF_GetCode(status) == TF_OK); *result = TF_TensorFromTensor(result_tensor, status); - if (!status->status.ok()) evaluated = false; + if (TF_GetCode(status) != TF_OK) evaluated = false; } return evaluated; } @@ -2750,9 +2830,9 @@ void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, size_t text_len, TF_Status* status) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "ApiDefMap is not supported in Android."); + "ApiDefMap is not supported on mobile."); #else 
mutex_lock l(api_def_map->lock); if (api_def_map->update_docs_called) { @@ -2763,14 +2843,14 @@ void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, } string api_def_text(text, text_len); status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); -#endif // __ANDROID__ +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, size_t name_len, TF_Status* status) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "ApiDefMap is not supported in Android."); + "ApiDefMap is not supported on mobile."); return nullptr; #else mutex_lock l(api_def_map->lock); @@ -2786,19 +2866,19 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(*api_def, ret); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { TF_DeleteBuffer(ret); return nullptr; } return ret; -#endif // __ANDROID__ +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) { tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels(); TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(kernel_list, ret); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { TF_DeleteBuffer(ret); return nullptr; } @@ -2810,7 +2890,7 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { tensorflow::GetRegisteredKernelsForOp(name); TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(kernel_list, ret); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { TF_DeleteBuffer(ret); return nullptr; } @@ -2819,16 +2899,16 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { // TF_Server functions ---------------------------------------------- -#ifndef __ANDROID__ +#if 
!defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) TF_Server::TF_Server(std::unique_ptr server) : target(server->target()), server(std::move(server)) {} -#endif // __ANDROID__ +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) TF_Server* TF_NewServer(const void* proto, size_t proto_len, TF_Status* status) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported in Android"); + "Server functionality is not supported on mobile"); return nullptr; #else tensorflow::ServerDef server_def; @@ -2840,41 +2920,41 @@ TF_Server* TF_NewServer(const void* proto, size_t proto_len, std::unique_ptr out_server; status->status = tensorflow::NewServer(server_def, &out_server); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; return new TF_Server(std::move(out_server)); -#endif +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } void TF_ServerStart(TF_Server* server, TF_Status* status) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported in Android"); + "Server functionality is not supported on mobile"); #else status->status = server->server->Start(); -#endif +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } void TF_ServerStop(TF_Server* server, TF_Status* status) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported in Android"); + "Server functionality is not supported on mobile"); #else status->status = server->server->Stop(); -#endif +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } void TF_ServerJoin(TF_Server* server, TF_Status* status) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || 
defined(IS_SLIM_BUILD) status->status = tensorflow::errors::Unimplemented( - "Server functionality is not supported in Android"); + "Server functionality is not supported on mobile"); #else status->status = server->server->Join(); -#endif +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } const char* TF_ServerTarget(TF_Server* server) { -#ifdef __ANDROID__ +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) return nullptr; #else return server->target.c_str(); @@ -2882,8 +2962,15 @@ const char* TF_ServerTarget(TF_Server* server) { } void TF_DeleteServer(TF_Server* server) { -#ifndef __ANDROID__ +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) delete server; -#endif +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) } + +void TF_RegisterLogListener(void (*listener)(const char*)) { +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + tensorflow::logging::RegisterListener(listener); +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +} + } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index c7abba85521fccec07983cd5ab4f94a8368d6181..051de3a7dc0f8c630b6c81d2cfa960e5279c93c0 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -272,6 +272,39 @@ TF_CAPI_EXPORT extern size_t TF_TensorByteSize(const TF_Tensor*); // Return a pointer to the underlying data buffer. TF_CAPI_EXPORT extern void* TF_TensorData(const TF_Tensor*); +// Returns the number of elements in the tensor. +TF_CAPI_EXPORT extern int64_t TF_TensorElementCount(const TF_Tensor* tensor); + +// Copy the internal data representation of `from` to `to`. `new_dims` and +// `num_new_dims` specify the new shape of the `to` tensor, `type` specifies its +// data type. On success, *status is set to TF_OK and the two tensors share the +// same data buffer. +// +// This call requires that the `from` tensor and the given type and shape (dims +// and num_dims) are "compatible" (i.e. 
they occupy the same number of bytes). +// Specifically, given from_type_size = TF_DataTypeSize(TF_TensorType(from)): +// +// TF_ShapeElementCount(dims, num_dims) * TF_DataTypeSize(type) +// +// must equal +// +// TF_TensorElementCount(from) * from_type_size +// +// where TF_ShapeElementCount would be the number of elements in a tensor with +// the given shape. +// +// In addition, this function requires: +// * TF_DataTypeSize(TF_TensorType(from)) != 0 +// * TF_DataTypeSize(type) != 0 +// +// If any of the requirements are not met, *status is set to +// TF_INVALID_ARGUMENT. +TF_CAPI_EXPORT extern void TF_TensorBitcastFrom(const TF_Tensor* from, + TF_DataType type, TF_Tensor* to, + const int64_t* new_dims, + int num_new_dims, + TF_Status* status); + // -------------------------------------------------------------------------- // Encode the string `src` (`src_len` bytes long) into `dst` in the format // required by TF_STRING tensors. Does not write to memory more than `dst_len` @@ -516,6 +549,10 @@ TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, const TF_DataType* values, int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrPlaceholder(TF_OperationDescription* desc, + const char* attr_name, + const char* placeholder); + // Set a 'func' attribute to the specified name. // `value` must point to a string of length `length` bytes. TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, @@ -1277,6 +1314,28 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); +// Similar to TF_GraphToFunction but allows specifying control outputs of the +// function. +// +// The arguments of TF_GraphToFunction have the same meaning, but the new +// arguments are as follows: +// +// ncontrol_outputs: Number of control outputs of the function. 
+// control_outputs: vector of TF_Operation objects to be marked as control +// outputs of the function. Operations marked as control outputs are +// guaranteed to execute. +// control_output_names: Optional. If not nullptr, vector of strings, one +// per control output, with their names to be added to the function's +// OpDef. +TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status); + // Returns the name of the graph function. // The return value points to memory that is only usable until the next // mutation to *func. @@ -1710,6 +1769,14 @@ TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); // it will be stopped and joined. TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); +// Register a listener method that processes printed messages. +// +// If any listeners are registered, the print operator will call all listeners +// with the printed messages and immediately return without writing to the +// logs. +TF_CAPI_EXPORT extern void TF_RegisterLogListener( + void (*listener)(const char*)); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index f04b285037dff403428ed74fe90eac60339fe36b..7ff4084decc686b067226ecaecf2af29d51d42f2 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/c/c_api_experimental.h" +#include "absl/strings/substitute.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api.h" @@ -128,6 +129,14 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { return ret; } +char* TF_FunctionDebugString(TF_Function* func, size_t* len) { + const auto& debug_str = func->fdef.DebugString(); + *len = debug_str.size(); + char* ret = static_cast<char*>(malloc(*len + 1)); + memcpy(ret, debug_str.c_str(), *len + 1); + return ret; +} + // On success, returns a set of TF_Function instances from `text_proto` of // GraphDef type. These functions must be deleted by calling TF_DeleteFunction. // @@ -8737,6 +8746,12 @@ static void CheckOk(TF_Status* status) { void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { auto* status = TF_NewStatus(); + if (!TFE_TensorHandleIsConcrete(handle)) { + VLOG(1) << "Symbolic tensor: " << handle; + TF_DeleteStatus(status); + return; + } + TF_Tensor* t = TFE_TensorHandleResolve(handle, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -8748,6 +8763,11 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { TF_DeleteStatus(status); } +void TFE_OpPrintDebugString(TFE_Op* op) { + VLOG(1) << "TFE_OpPrintDebugString() over " << op; + LOG(INFO) << op->operation.DebugString(); +} + struct TFE_ExecuteOpNotification { TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {} tensorflow::Notification n; @@ -8941,3 +8961,189 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, } status->status = EnableCollectiveOps(server_def, ctx); } + +std::string tensorflow::getTF_OutputDebugString(TF_Output node) { + return absl::Substitute("TF_Output($0, $1)", node.oper, node.index); +} + +using tensorflow::getTF_OutputDebugString; + +TFE_TensorHandle* TFE_NewTensorHandleFromTFOutput(TF_Output t, + TF_DataType dtype) { + auto ret = new TFE_TensorHandle(t, dtype); + VLOG(1) << 
"Storing TFOutput " << getTF_OutputDebugString(t) + << " into tensor handle " << ret << " with internal handle " + << ret->handle; + return ret; +} + +unsigned char TFE_TensorHandleIsConcrete(TFE_TensorHandle* handle) { + assert(handle->handle != nullptr); + return handle->handle->getSymbolicTensor() == nullptr; +} + +TF_Output TFE_GetTFOutputFromTensorHandle(TFE_TensorHandle* handle, + TF_Status* status) { + if (TFE_TensorHandleIsConcrete(handle)) { + status->status = + tensorflow::errors::Internal("Not a symbolic tensor: ", handle); + return TF_Output{nullptr, -1}; + } + + auto* sym_tensor = handle->handle->getSymbolicTensor(); + CHECK(sym_tensor != nullptr); + auto ret = TF_Output{sym_tensor->oper, sym_tensor->index}; + VLOG(1) << "Retrieving " << getTF_OutputDebugString(ret) + << " from tensor handle " << handle; + CHECK_GE(sym_tensor->index, 0); + return ret; +} + +TFE_TraceContext* TFE_NewTraceContext(TF_Graph* graph) { + return new TFE_TraceContext(graph); +} + +void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx) { delete trace_ctx; } + +// If `handle` is already symbolic, return it. Otherwise map it to a new +// symbolic tensor (a PlaceHolder op) and return that. 
+static TF_Output getOrCreateSymbolicTensor(TFE_TraceContext* trace_ctx, + tensorflow::TensorHandle* handle, + TF_Status* status) { + VLOG(1) << "Getting symbolic tensor for input tensor handle " << handle + << ": " << handle->DebugString(); + + auto* sym_tensor = handle->getSymbolicTensor(); + if (sym_tensor != nullptr) { + auto ret = TF_Output{sym_tensor->oper, sym_tensor->index}; + VLOG(1) << "This handle is a symbolic tensor " << sym_tensor << ": " + << getTF_OutputDebugString(ret); + return ret; + } + + auto find_it = trace_ctx->input_tensor_map.find(handle); + if (find_it != trace_ctx->input_tensor_map.end()) { + VLOG(1) << "There exists a map entry from this concrete tensor to: " + << getTF_OutputDebugString(find_it->second); + return find_it->second; + } + + auto node_name = tensorflow::strings::StrCat("additional_input_", + trace_ctx->node_counter++); + VLOG(1) << "Adding a place holder node named " << node_name; + auto* desc = + TF_NewOperation(trace_ctx->graph, "Placeholder", node_name.c_str()); + TF_SetAttrType(desc, "dtype", + static_cast(handle->dtype) /*TF_FLOAT*/); + auto* result = TF_FinishOperation(desc, status); + if (!status->status.ok()) { + return TF_Output{nullptr, -1}; + } + + auto ret = TF_Output{result, 0}; + VLOG(1) << "Creating a new map entry to map to: " + << getTF_OutputDebugString(ret); + trace_ctx->input_tensor_map[handle] = ret; + // `handle` could be destroyed before it's read from `input_tensor_map` (say + // during a subsequent TFE_FinalizeInputTensorsFromTraceContext() call), so we + // increment its ref count to extend its life span to that of `trace_ctx`. 
+ handle->Ref(); + VLOG(1) << "Ref count for handle " << handle + << " is 1?: " << handle->RefCountIsOne(); + return ret; +} + +TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, + TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status) { + VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": " + << op->operation.DebugString(); + + const auto& op_type = op->operation.Name(); + auto op_name = + tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++); + auto* desc = + TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()); + + VLOG(1) << "Adding attrs."; + tensorflow::AttrValueMap attrs; + op->operation.Attrs().FillAttrValueMap(&attrs); + for (const auto& attr : attrs) { + desc->node_builder.Attr(attr.first, attr.second); + } + + VLOG(1) << "Adding inputs."; + const auto& inputs = op->operation.Inputs(); + size_t inputIndex = 0; + const tensorflow::OpDef& op_def = desc->node_builder.op_def(); + for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) { + // TODO(bgogul): Add support for number attributes. 
+ DCHECK(input_arg.number_attr().empty()) + << "Number attributes is not implemented yet."; + if (input_arg.type_list_attr().empty()) { + auto symbolic_input = + getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); + if (!status->status.ok()) return nullptr; + TF_AddInput(desc, symbolic_input); + continue; + } + const std::string& type_list_attr = input_arg.type_list_attr(); + const auto& attr_value = attrs[type_list_attr]; + DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList) + << "Type list attribute should be a list!"; + std::vector list_inputs(attr_value.list().type_size()); + for (TF_Output& list_input : list_inputs) { + list_input = + getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); + if (!status->status.ok()) return nullptr; + } + TF_AddInputList(desc, list_inputs.data(), list_inputs.size()); + } + + auto* graph_op = TF_FinishOperation(desc, status); + if (!status->status.ok()) return nullptr; + + VLOG(1) << "Op finalized; setting return tensors."; + *num_retvals = TF_OperationNumOutputs(graph_op); + VLOG(1) << "This op has " << *num_retvals << " outputs."; + for (int i = 0; i < *num_retvals; ++i) { + auto output = TF_Output{graph_op, i}; + auto dtype = TF_OperationOutputType(output); + retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype); + } + return graph_op; +} + +int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) { + if (trace_ctx->input_tensors == nullptr) { + trace_ctx->input_tensors = + new std::vector>(); + trace_ctx->input_tensors->reserve(trace_ctx->input_tensor_map.size()); + + for (auto input : trace_ctx->input_tensor_map) { + trace_ctx->input_tensors->emplace_back(input.first, input.second); + } + } + return trace_ctx->input_tensor_map.size(); +} + +TF_Output TFE_GetInputGraphNodeFromTraceContext(TFE_TraceContext* trace_ctx, + unsigned int idx) { + CHECK(trace_ctx->input_tensors != nullptr); + CHECK(trace_ctx->input_tensors->size() > idx); + return 
trace_ctx->input_tensors->at(idx).second; +} + +TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext( + TFE_TraceContext* trace_ctx, unsigned int idx) { + CHECK(trace_ctx->input_tensors != nullptr); + CHECK(trace_ctx->input_tensors->size() > idx); + auto* handle = trace_ctx->input_tensors->at(idx).first; + VLOG(1) << "Ref count for internal handle " << handle + << " is 1?: " << handle->RefCountIsOne(); + handle->Ref(); + auto* ret = new TFE_TensorHandle(handle); + VLOG(1) << "Returning a new tensor handle " << ret << ": " + << handle->DebugString(); + return ret; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index e6d04d0c2b25a3f7b1ebf50c58268f003595a520..8d1a8b82fbaf9901b6d9aecf6d092ae298c8dba3 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -84,6 +84,15 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions( TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, size_t* len); +// Returns the function content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +// +// Do not return const char*, because some foreign language binding +// (e.g. swift) cannot then call free() on the returned pointer. +TF_CAPI_EXPORT extern char* TF_FunctionDebugString(TF_Function* func, + size_t* len); + // Creates a stack of data set + iterator nodes, currently hard-coded to return // a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success, // returns the IteratorGetNext node, which caller can run or feed into an node. 
@@ -181,6 +190,8 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( TFE_TensorHandle* handle); +TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op); + typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification; // Allows invoking a kernel asynchronously, and explicitly returns a @@ -255,6 +266,54 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, const void* proto, size_t proto_len, TF_Status* status); + +// Create a symbolic tensor from the input graph node. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromTFOutput( + TF_Output t, TF_DataType data_type); + +// Returns 0 if the input tensor handle represents a symbolic tensor (i.e., a +// graph node). Otherwise returns non-0. +TF_CAPI_EXPORT extern unsigned char TFE_TensorHandleIsConcrete( + TFE_TensorHandle* handle); + +// If `handle` is a symbolic tensor, return the corresponding graph node +// represented by TF_Output. Otherwise, return an error status. +TF_CAPI_EXPORT extern TF_Output TFE_GetTFOutputFromTensorHandle( + TFE_TensorHandle* handle, TF_Status* status); + +typedef struct TFE_TraceContext TFE_TraceContext; + +// A trace context contains a trace graph, to which TFE_AddEagerOpToGraph() +// calls add graph nodes as a way to symbolically execute the eager ops. +// +// It also contains a hash map from concrete input tensors to symbolic +// tensors. That map will be used to create input tensors to the trace graph. +TF_CAPI_EXPORT extern TFE_TraceContext* TFE_NewTraceContext(TF_Graph* graph); + +TF_CAPI_EXPORT extern void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx); + +// Symbolically executes `op`, by adding a corresponding node to the graph +// associated with `trace_ctx`. This graph node outputs a set of symbolic +// tensors in `retvals` and `num_retvals`. Returns the corresponding graph +// operation on success, otherwise returns nullptr. 
+TF_CAPI_EXPORT extern TF_Operation* TFE_AddEagerOpToGraph( + TFE_Op* op, TFE_TraceContext* trace_ctx, TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status); + +// Finalizes the trace graph and its inputs, and returns the number of inputs. +// After this call, the next two APIs can be called to iterate over the input +// tensors. +TF_CAPI_EXPORT extern int TFE_FinalizeInputTensorsFromTraceContext( + TFE_TraceContext* trace_ctx); + +TF_CAPI_EXPORT extern TF_Output TFE_GetInputGraphNodeFromTraceContext( + TFE_TraceContext* trace_ctx, unsigned int idx); + +// Each input tensor should be consumed at most once. +TF_CAPI_EXPORT extern TFE_TensorHandle* +TFE_ConsumeInputConcreteTensorFromTraceContext(TFE_TraceContext* trace_ctx, + unsigned int idx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index daa7701b7fe7e8ce757b6504329cf6434ad39778..2c92e38f03a9d01d285f475b1a8996c44475c5c2 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_test_util.h" @@ -296,5 +297,178 @@ TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) { TF_DeleteStatus(status); } +TEST(CAPI_EXPERIMENTAL, SymbolicTensor) { + TF_Status* status = TF_NewStatus(); + auto node = TF_Output{nullptr, 1}; + auto* sym_handle = TFE_NewTensorHandleFromTFOutput(node, TF_FLOAT); + TFE_TensorHandlePrintDebugString(sym_handle); + CHECK_EQ(TFE_TensorHandleDataType(sym_handle), TF_FLOAT); + ASSERT_FALSE(TFE_TensorHandleIsConcrete(sym_handle)); + auto same_node = TFE_GetTFOutputFromTensorHandle(sym_handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(same_node.oper, node.oper); + ASSERT_EQ(same_node.index, node.index); + TFE_DeleteTensorHandle(sym_handle); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + ASSERT_TRUE(TFE_TensorHandleIsConcrete(m)); + (void)TFE_GetTFOutputFromTensorHandle(m, status); + CHECK_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(m); + + TF_DeleteStatus(status); +} + +class AddEagerOpToGraphTest : public ::testing::Test { + protected: + AddEagerOpToGraphTest() + : status_(TF_NewStatus()), + eager_ctx_(nullptr), + graph_(TF_NewGraph()), + trace_ctx_(TFE_NewTraceContext(graph_)) { + TFE_ContextOptions* opts = TFE_NewContextOptions(); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + eager_ctx_ = TFE_NewContext(opts, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_DeleteContextOptions(opts); + } + + ~AddEagerOpToGraphTest() override { + TFE_DeleteTraceContext(trace_ctx_); + TF_DeleteGraph(graph_); + TFE_DeleteContext(eager_ctx_); + TF_DeleteStatus(status_); + } + + template + void AddEagerOpToGraphAndCheck(TFE_Op* op, 
Callable checker) { + TFE_TensorHandle* retvals[5]; + int num_retvals = 5; + // Symbolically execute this op, which adds a graph node to `trace_ctx_`. + TF_Operation* graph_op = + TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_NOTNULL(graph_op); + // Check the expectations. + checker(graph_op); + for (int i = 0; i < num_retvals; ++i) { + TFE_DeleteTensorHandle(retvals[i]); + } + } + + TF_Status* status_; + TFE_Context* eager_ctx_; + TF_Graph* graph_; + TFE_TraceContext* trace_ctx_; +}; + +TEST_F(AddEagerOpToGraphTest, DebugPrintAndSymbolicExecution) { + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* op = MatMulOp(eager_ctx_, m, m); + + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpPrintDebugString(op); + + TFE_TensorHandle* retvals[5]; + int num_retvals = 5; + // Symbolically execute this op, which adds a graph node to `trace_ctx`. + TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx_); + CHECK_EQ(num_inputs, 1); + auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx_, + /*idx*/ 0); + + LOG(INFO) << tensorflow::getTF_OutputDebugString(input_sym_tensor); + auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx_, + /*idx*/ 0); + TFE_TensorHandlePrintDebugString(handle); + TFE_DeleteTensorHandle(handle); + + CHECK_EQ(num_retvals, 1); + CHECK_EQ(TFE_TensorHandleDataType(retvals[0]), TF_FLOAT); + + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(m); + TFE_DeleteOp(op); +} + +TEST_F(AddEagerOpToGraphTest, ValueAttributesArePreserved) { + // Create MinOp + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* op = MinOp(eager_ctx_, axis, axis); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + // Check the attributes set by 
the call to MinOp above. + AddEagerOpToGraphAndCheck(op, [this, &axis](TF_Operation* graph_op) { + unsigned char value; + TF_OperationGetAttrBool(graph_op, "keep_dims", &value, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(value, 1); + TF_DataType dtype; + TF_OperationGetAttrType(graph_op, "Tidx", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TF_INT32); + TF_OperationGetAttrType(graph_op, "T", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TFE_TensorHandleDataType(axis)); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteOp(op); +} + +TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) { + // Create a "Squeeze" operator with list attributes. + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* squeeze = TFE_NewOp(eager_ctx_, "Squeeze", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpAddInput(squeeze, axis, status_); + TFE_OpSetAttrType(squeeze, "T", TF_INT32); + std::vector boundaries = {1, 2, 3, 4}; + TFE_OpSetAttrIntList(squeeze, "squeeze_dims", boundaries.data(), + boundaries.size()); + // Check attributes are preserved. 
+ AddEagerOpToGraphAndCheck( + squeeze, [this, &boundaries](TF_Operation* squeeze_graph_op) { + TF_DataType dtype; + TF_OperationGetAttrType(squeeze_graph_op, "T", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TF_INT32); + std::unique_ptr list(new int64_t[boundaries.size()]); + TF_OperationGetAttrIntList(squeeze_graph_op, "squeeze_dims", list.get(), + boundaries.size(), status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_TRUE(std::equal(list.get(), list.get() + boundaries.size(), + boundaries.begin())); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteOp(squeeze); +} + +TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) { + TFE_TensorHandle* scalar = TestScalarTensorHandle(static_cast(1)); + TFE_Op* identityn = TFE_NewOp(eager_ctx_, "IdentityN", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + constexpr size_t kNumInputs = 3; + for (size_t i = 0; i < kNumInputs; ++i) { + TFE_OpAddInput(identityn, scalar, status_); + } + TF_DataType types[kNumInputs] = {TF_FLOAT, TF_FLOAT, TF_FLOAT}; + TFE_OpSetAttrTypeList(identityn, "T", types, kNumInputs); + AddEagerOpToGraphAndCheck( + identityn, [this, kNumInputs](TF_Operation* graph_op) { + EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs); + EXPECT_EQ(TF_OperationInputListLength(graph_op, "input", status_), + kNumInputs); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_EQ(TF_OperationOutputListLength(graph_op, "output", status_), + kNumInputs); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + }); + TFE_DeleteTensorHandle(scalar); + TFE_DeleteOp(identityn); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 28b9f8df9c873ee394eb6a241dd9ac06ba6c8796..68b530b718bd50adc21cbd50d0d4b92d075fd013 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -162,6 +162,11 @@ 
Status FillFunctionBody( const std::vector& body_nodes, const std::unordered_map& tensor_renaming, FunctionDef* fdef) { + std::unordered_set func_attr_names; + for (const auto& func_attr : fdef->signature().attr()) { + func_attr_names.insert(func_attr.name()); + } + std::vector in_edges; std::vector control_edges; for (const Node* node : body_nodes) { @@ -243,6 +248,48 @@ Status FillFunctionBody( if (node->op_def().is_stateful()) { fdef->mutable_signature()->set_is_stateful(true); } + + // If this node has any attributes with placeholder value, add the + // attribute to FunctionDef signature. + for (const auto& iter : node->attrs()) { + if (iter.second.placeholder().empty()) { + continue; + } + + // If we already added the attribute, skip it. + string func_attr_name = iter.second.placeholder(); + if (func_attr_names.find(func_attr_name) != func_attr_names.end()) { + continue; + } + + // This node's attribute is a placeholder value, so it does not have type + // information. We check node's OpDef for attribute type. + string node_attr_name = iter.first; + const OpDef::AttrDef* node_attr_def = nullptr; + for (const auto& node_attr : node->op_def().attr()) { + if (node_attr.name() == node_attr_name) { + node_attr_def = &node_attr; + } + } + if (!node_attr_def) { +#ifdef TENSORFLOW_LITE_PROTOS + return errors::Unimplemented( + "Placeholder value is not supported for attributes not in OpDef. " + "Attribute: ", + node_attr_name); +#else + return errors::Unimplemented( + "Placeholder value is not supported for attributes not in OpDef. 
" + "Attribute: ", + node_attr_name, ", OpDef: ", node->op_def().DebugString()); +#endif + } + OpDef::AttrDef* attr_def = fdef->mutable_signature()->add_attr(); + attr_def->set_name(func_attr_name); + attr_def->set_type(node_attr_def->type()); + + func_attr_names.insert(func_attr_name); + } } return Status::OK(); } @@ -255,6 +302,8 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, const std::vector& inputs, const std::vector& outputs, const std::vector& output_names, + const std::vector& control_outputs, + const std::vector& control_output_names, const char* description, FunctionDef* fdef) { if (!output_names.empty()) { DCHECK_EQ(output_names.size(), outputs.size()); @@ -378,6 +427,29 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, fdef->mutable_signature()->set_name(fn_name); } + if (!control_output_names.empty() && + (control_outputs.size() != control_output_names.size())) { + return InvalidArgument( + "Expected number of control outputs (", control_outputs.size(), + ") and the number of control output names (", + control_output_names.size(), ") to match but they do not."); + } + std::unordered_set control_output_names_set; + for (int i = 0; i < control_outputs.size(); ++i) { + string signature_name; + if (!control_output_names.empty()) { + signature_name = control_output_names[i]; + } else { + signature_name = control_outputs[i]->name(); + } + if (!control_output_names_set.insert(signature_name).second) { + return errors::InvalidArgument("Repeated control output name: ", + signature_name); + } + fdef->mutable_signature()->add_control_output(signature_name); + (*fdef->mutable_control_ret())[signature_name] = control_outputs[i]->name(); + } + return Status::OK(); } @@ -485,14 +557,14 @@ Status ComputeBodyNodes( using tensorflow::Node; using tensorflow::string; -TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, - unsigned char append_hash_to_fn_name, - int num_opers, const TF_Operation* 
const* opers, - int ninputs, const TF_Output* inputs, - int noutputs, const TF_Output* outputs, - const char* const* output_names, - const TF_FunctionOptions* opts, - const char* description, TF_Status* status) { +TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status) { tensorflow::mutex_lock l(*const_cast(&fn_body->mu)); // Process inputs. @@ -500,13 +572,13 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, std::unordered_map> input_nodes; status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs, &input_tensors, &input_nodes); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; // Process outputs. std::vector output_tensors; status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs, outputs, &output_tensors); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; // Process output names. std::vector output_names_vec; @@ -517,26 +589,55 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, } } + // Process control output names. + std::vector control_output_names_vec; + if (control_output_names) { + control_output_names_vec.reserve(ncontrol_outputs); + for (int i = 0; i < ncontrol_outputs; ++i) { + control_output_names_vec.push_back(string(output_names[i])); + } + } + // Compute body nodes. 
std::vector body_nodes; status->status = tensorflow::ComputeBodyNodes( fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; + + // Collect control output nodes. + std::vector control_output_nodes; + for (int i = 0; i < ncontrol_outputs; ++i) { + control_output_nodes.push_back(&control_outputs[i]->node); + } // Do the actual function creation. TF_Function* tf_function = new TF_Function(); DCHECK(append_hash_to_fn_name <= 1); status->status = tensorflow::GraphToFunctionDef( fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes, - input_tensors, output_tensors, output_names_vec, description, - &tf_function->fdef); - if (!status->status.ok()) { + input_tensors, output_tensors, output_names_vec, control_output_nodes, + control_output_names_vec, description, &tf_function->fdef); + if (TF_GetCode(status) != TF_OK) { TF_DeleteFunction(tf_function); return nullptr; } return tf_function; } +TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, + int num_opers, const TF_Operation* const* opers, + int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, + const char* const* output_names, + const TF_FunctionOptions* opts, + const char* description, TF_Status* status) { + return TF_GraphToFunctionWithControlOutputs( + fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs, + inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts, + description, status); +} + const char* TF_FunctionName(TF_Function* func) { return func->fdef.signature().name().c_str(); } diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 73fe73769bc1219ce865149d67d333c53371ccc5..946f8c4a2c3fb25f908d809e00bf579b40a8668b 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -1230,6 +1231,53 @@ void DefineFunction(const char* name, TF_Function** func, ASSERT_NE(*func, nullptr); } +REGISTER_OP("CustomOp") + .Output("output: float32") + .Attr("index: int") + .SetShapeFn(tensorflow::shape_inference::UnknownShape); + +void NodeWithPlaceholderAttrHelper(TF_Graph* graph, TF_Status* s, + const char* name, const char* placeholder, + TF_Operation** op) { + TF_OperationDescription* desc = TF_NewOperation(graph, "CustomOp", name); + TF_SetAttrPlaceholder(desc, "index", placeholder); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Operation *node1, *node2, *node3; + NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node1", "v1", + &node1); + NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node2", "v1", + &node2); + NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2", + &node3); + + TF_Output inputs[] = {}; + TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}}; + func_ = TF_GraphToFunction( + func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1, + /*opers=*/nullptr, 0, inputs, 3, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, /*description=*/nullptr, s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(func_, nullptr); + + // Verify that FunctionDef has 2 attributes, "v1" and "v2". 
+ ASSERT_EQ(func_->fdef.signature().attr().size(), 2); + EXPECT_EQ(func_->fdef.signature().attr(0).name(), "v1"); + EXPECT_EQ(func_->fdef.signature().attr(0).type(), "int"); + EXPECT_EQ(func_->fdef.signature().attr(1).name(), "v2"); + EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int"); +} + TEST_F(CApiFunctionTest, SetGradientAndRun) { // Define the function and its grad DefineFunction(func_name_, &func_); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 5ba26d3c585350aa510f9970cbfc246a9a108543..9a69c58718b3514287256124629f59443f38fd39 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -24,10 +24,12 @@ limitations under the License. #include #include -#ifndef __ANDROID__ -#include "tensorflow/core/distributed_runtime/server_lib.h" +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" // NO_LINT + +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/core/framework/op_gen_lib.h" -#endif +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -42,6 +44,7 @@ limitations under the License. namespace tensorflow { class Device; class DeviceMgr; +class ServerInterface; } // namespace tensorflow // Internal structures used by the C API. 
These are likely to change and should @@ -167,27 +170,27 @@ struct TF_Function { struct TF_ApiDefMap { explicit TF_ApiDefMap(const tensorflow::OpList& op_list) : -#ifndef __ANDROID__ +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) api_def_map(op_list), -#endif +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) update_docs_called(false) { } -#ifndef __ANDROID__ +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) tensorflow::ApiDefMap api_def_map GUARDED_BY(lock); -#endif +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) bool update_docs_called GUARDED_BY(lock); tensorflow::mutex lock; }; -#ifndef __ANDROID__ +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) struct TF_Server { TF_Server(std::unique_ptr server); const tensorflow::string target; std::unique_ptr server; }; -#endif +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) namespace tensorflow { @@ -204,7 +207,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); -Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, + TF_Buffer* out); // Set the shapes and types of the output's handle. 
// @@ -228,6 +232,8 @@ void RecordMutation(TF_Graph* graph, const TF_Operation& op, bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) LOCKS_EXCLUDED(session->graph->mu, session->mu); +std::string getTF_OutputDebugString(TF_Output node); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index d5934a10395ae094f65d3bc8b6cd7b94dbd32410..2be03bf0de6277fc63c353ad6dc63bec096a6993 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -163,6 +163,7 @@ TEST(CAPI, AllocateTensor) { EXPECT_EQ(dims[0], TF_Dim(t, 0)); EXPECT_EQ(dims[1], TF_Dim(t, 1)); EXPECT_EQ(num_bytes, TF_TensorByteSize(t)); + EXPECT_EQ(6, TF_TensorElementCount(t)); TF_DeleteTensor(t); } @@ -1467,6 +1468,41 @@ TEST(CAPI, DeletingNullPointerIsSafe) { TF_DeleteStatus(status); } +TEST(CAPI, TestBitcastFrom_Reshape) { + int64_t dims[] = {2, 3}; + TF_Tensor* a = + TF_AllocateTensor(TF_UINT64, dims, 2, 6 * TF_DataTypeSize(TF_UINT64)); + TF_Tensor* b = + TF_AllocateTensor(TF_UINT64, nullptr, 0, TF_DataTypeSize(TF_UINT64)); + EXPECT_NE(a, nullptr); + EXPECT_NE(b, nullptr); + + EXPECT_EQ(6, TF_TensorElementCount(a)); + EXPECT_EQ(1, TF_TensorElementCount(b)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a)); + EXPECT_EQ(TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b)); + + int64_t new_dims[] = {3, 2}; + TF_Status* status = TF_NewStatus(); + TF_TensorBitcastFrom(a, TF_UINT64, b, new_dims, 2, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + + EXPECT_EQ(6, TF_TensorElementCount(a)); + EXPECT_EQ(6, TF_TensorElementCount(b)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b)); + + // Check that a write to one tensor shows up in the other. 
+ *(static_cast(TF_TensorData(a))) = 4; + EXPECT_EQ(4, *(static_cast(TF_TensorData(b)))); + *(static_cast(TF_TensorData(b))) = 6; + EXPECT_EQ(6, *(static_cast(TF_TensorData(a)))); + + TF_DeleteTensor(a); + TF_DeleteTensor(b); +} + REGISTER_OP("TestOpWithNoGradient") .Input("x: T") .Output("y: T") diff --git a/tensorflow/c/c_test.c b/tensorflow/c/c_test.c index b86d8eb8e300e02a3871ecd5f424a82c521b18fc..7468122cd567270c8454f886e478be34c2c15cbf 100644 --- a/tensorflow/c/c_test.c +++ b/tensorflow/c/c_test.c @@ -25,6 +25,16 @@ limitations under the License. #include "tensorflow/c/env.h" #include "tensorflow/c/kernels.h" +// A create function. This will never actually get called in this test, it's +// just nice to know that it compiles. +void* create(TF_OpKernelConstruction* ctx) { + TF_DataType type; + TF_Status* s = TF_NewStatus(); + TF_OpKernelConstruction_GetAttrType(ctx, "foobar", &type, s); + TF_DeleteStatus(s); + return NULL; +} + // A compute function. This will never actually get called in this test, it's // just nice to know that it compiles. void compute(void* kernel, TF_OpKernelContext* ctx) { @@ -32,12 +42,7 @@ void compute(void* kernel, TF_OpKernelContext* ctx) { TF_Status* s = TF_NewStatus(); TF_GetInput(ctx, 0, &input, s); TF_DeleteTensor(input); - - TF_DataType type; - TF_OpKernelContext_GetAttrType(ctx, "foobar", &type, s); - TF_DeleteStatus(s); - } // Exercises tensorflow's C API. 
@@ -80,7 +85,7 @@ int main(int argc, char** argv) { TF_StringStreamDone(s); TF_KernelBuilder* b = - TF_NewKernelBuilder("SomeOp", "SomeDevice", NULL, &compute, NULL); + TF_NewKernelBuilder("SomeOp", "SomeDevice", &create, &compute, NULL); TF_RegisterKernelBuilder("someKernel", b, status); TF_DeleteStatus(status); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 51665fb9db948e165129e5cbdf0dc3fb28f90f91..445b2cd25812e1d73fdd85b61f2a234150b880a6 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -70,7 +70,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", - "//tensorflow/core/profiler/lib:eager_profiler", + "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core:gpu_runtime", ], ) @@ -110,7 +110,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", - "//tensorflow/core/profiler/lib:eager_profiler", + "//tensorflow/core/profiler/lib:profiler_session", ], ) @@ -147,6 +147,7 @@ tf_cuda_cc_test( ], deps = [ ":c_api", + ":c_api_internal", ":c_api_test_util", "//tensorflow/c:c_test_util", "//tensorflow/core:lib", @@ -210,6 +211,8 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/profiler/rpc:profiler_server", + "//tensorflow/core/profiler/rpc/client:capture_profile", "//tensorflow/core:gpu_runtime", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index f0408677d403affaee66ebe4c7592f6bd6c74d09..1002f474fed20759c67eb765e8d4c81a0f529ecc 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -226,14 +226,84 @@ tensorflow::Status 
UpdateTFE_ContextWithServerDef( auto* device_mgr = grpc_server->worker_env()->device_mgr; - ctx->context.InitializeRemote(std::move(server), - std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts, - r, device_mgr, keep_alive_secs); + return ctx->context.InitializeRemote( + std::move(server), std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, r, device_mgr, + keep_alive_secs); +#undef LOG_AND_RETURN_IF_ERROR +} +tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op, + TFE_TensorHandle* input) { + TFE_OpInferenceContext* ictx = op->inference_ctx.get(); + const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++); + if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) { + // Some clients that are still setting their input attributes manually are + // adding input list to their op by calling `TFE_OpAddInput` for each of + // its elements instead of calling `TFE_OpAddInputList`. When this happens, + // we cannot detect the end of such list, thus lose track of the input + // arguments in the op definition. To guarantee backward compatibility with + // those clients, disable automatic inference in this case. 
+ op->inference_ctx.reset(nullptr); + return tensorflow::Status::OK(); + } + const std::string& type_attr = input_def.type_attr(); + if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) { + op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype); + ictx->attrs.insert(type_attr); + } + return tensorflow::Status::OK(); +} + +void OpInferSingleTypeInputListAttrs(TFE_Op* op, + const tensorflow::OpDef::ArgDef& input_def, + TFE_TensorHandle** inputs, + int num_inputs) { + TFE_OpInferenceContext* ictx = op->inference_ctx.get(); + if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) { + op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs); + ictx->attrs.insert(input_def.number_attr()); + } + if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) { + op->operation.MutableAttrs()->Set(input_def.type_attr(), + inputs[0]->handle->dtype); + ictx->attrs.insert(input_def.type_attr()); + } +} + +void OpInferMixedTypeInputListAttrs(TFE_Op* op, + const tensorflow::OpDef::ArgDef& input_def, + TFE_TensorHandle** inputs, int num_inputs) { + TFE_OpInferenceContext* ictx = op->inference_ctx.get(); + if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) { + std::unique_ptr dtypes( + new tensorflow::DataType[num_inputs]); + for (int i = 0; i < num_inputs; ++i) { + dtypes[i] = inputs[i]->handle->dtype; + } + op->operation.MutableAttrs()->Set( + input_def.type_list_attr(), + tensorflow::gtl::ArraySlice(dtypes.get(), + num_inputs)); + ictx->attrs.insert(input_def.type_list_attr()); + } +} + +tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs, + int num_inputs) { + TFE_OpInferenceContext* ictx = op->inference_ctx.get(); + const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++); + if (!input_def.type_list_attr().empty()) { + OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs); + } else if (!input_def.type_attr().empty() && + 
!input_def.number_attr().empty()) { + OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs); + } else { + return tensorflow::errors::InvalidArgument("Invalid input list definition"); + } return tensorflow::Status::OK(); -#undef LOG_AND_RETURN_IF_ERROR } + } // namespace extern "C" { @@ -249,6 +319,7 @@ void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, unsigned char enable) { options->async = enable; } + void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { options->policy = policy; @@ -302,7 +373,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { return list; } -void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } +void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->context.ClearCaches(); +} // Set server_def on the context, possibly updating it. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, @@ -356,6 +429,8 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { if (h == nullptr) return; + VLOG(1) << "Deleting tensor handle " << h << " with internal handle " + << h->handle; if (h->handle) { h->handle->Unref(); } @@ -490,20 +565,29 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, const tensorflow::AttrTypeMap* types; bool is_function = false; status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function); - if (status->status.ok()) { - if (is_function && !ctx->context.FindFunctionByName(name)) { - status->status = tensorflow::errors::NotFound( - "'", name, - "' is neither a type of a primitive operation nor a name " - "of a function registered in binary running on ", - tensorflow::port::Hostname(), - ". 
Make sure the operation or function is " - "registered in the binary running in this process."); + if (!status->status.ok()) { + return nullptr; + } + if (!is_function) { + const tensorflow::OpDef* op_def; + status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def); + if (!status->status.ok()) { return nullptr; } - return new TFE_Op(ctx, name, is_function, types); + return new TFE_Op(ctx, name, false, types, + new TFE_OpInferenceContext(op_def)); } - return nullptr; + if (!ctx->context.FindFunctionByName(name)) { + status->status = tensorflow::errors::NotFound( + "'", name, + "' is neither a type of a primitive operation nor a name " + "of a function registered in binary running on ", + tensorflow::port::Hostname(), + ". Make sure the operation or function is " + "registered in the binary running in this process."); + return nullptr; + } + return new TFE_Op(ctx, name, true, types, nullptr); } void TFE_DeleteOp(TFE_Op* op) { delete op; } @@ -527,8 +611,21 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { #endif // TENSORFLOW_EAGER_USE_XLA } -void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - op->operation.AddInput(h->handle); +void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { + op->operation.AddInput(input->handle); + if (op->inference_ctx) { + status->status = OpInferSingleInputAttrs(op, input); + } +} + +void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status) { + for (int i = 0; i < num_inputs; ++i) { + op->operation.AddInput(inputs[i]->handle); + } + if (op->inference_ctx) { + status->status = OpInferInputListAttrs(op, inputs, num_inputs); + } } TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, @@ -712,6 +809,7 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { + VLOG(1) << "Calling TFE_Execute() on op " << op; 
tensorflow::gtl::InlinedVector handle_retvals( *num_retvals); status->status = @@ -754,12 +852,18 @@ void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, status->status = ctx->context.AddFunctionDef(function->fdef); } +unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { + return ctx->context.FindFunctionDef(name) != nullptr; +} + void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(true); + ctx->context.SetShouldStoreGraphs(true); + ctx->context.SetShouldStoreStepStats(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(false); + ctx->context.SetShouldStoreGraphs(false); + ctx->context.SetShouldStoreStepStats(false); } } // extern "C" diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 120748ab763a3358b6e38e64bb3b6fd2ea32f7c3..442593e8475be8ecac427fc63ce131d648437305 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -98,7 +98,8 @@ TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, // Clears the internal caches in the TFE context. Useful when reseeding random // ops. -TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx); +TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx, + TF_Status* status); // Sets a thread-local device placement policy. 
After this call, other calls to // TFE_Execute in the same thread will use the device policy specified here @@ -282,9 +283,14 @@ TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op, TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable); -TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, +TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op, + TFE_TensorHandle** inputs, + int num_inputs, + TF_Status* status); + TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, @@ -393,6 +399,10 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status); +// Checks whether a function is registered under `name`. +TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx, + const char* name); + // Enables tracing of RunMetadata on the ops executed from this context. 
TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx); diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index ffcd5ace0b98597363abe63201bf6c328a03212f..b4192716c4fee66da5133fc592e39e26a66a98e8 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -32,13 +32,13 @@ std::vector TensorShapeAsVector(TFE_TensorHandle* handle, TF_Status* status) { std::vector shape; int rank = TFE_TensorHandleNumDims(handle, status); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { return shape; } shape.reserve(rank); for (int i = 0; i < rank; ++i) { shape.push_back(TFE_TensorHandleDim(handle, i, status)); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { return shape; } } @@ -53,7 +53,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( TFE_TensorHandle* handle, TF_Status* status) { const tensorflow::Tensor* tensor; status->status = handle->handle->Tensor(&tensor); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { return nullptr; } @@ -139,7 +139,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( // If the tensor is not an XLA tensor, the device shape is // the same as regular tensor shape. std::vector dev_dims = TensorShapeAsVector(handle, status); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { return nullptr; } return new TFE_TensorDebugInfo(dev_dims); diff --git a/tensorflow/c/eager/c_api_debug_test.cc b/tensorflow/c/eager/c_api_debug_test.cc index cddb9f6e00e9d639026f4bbe061d58f76771c0a9..4e987c745ecabd85c89d039468eb94ed51b4d00f 100644 --- a/tensorflow/c/eager/c_api_debug_test.cc +++ b/tensorflow/c/eager/c_api_debug_test.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/core/platform/test.h" TEST(CApiDebug, ScalarCPU) { - TFE_TensorHandle* h = TestScalarTensorHandle(); + TFE_TensorHandle* h = TestScalarTensorHandle(1.0f); TF_Status* status = TF_NewStatus(); TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 1ce03fb22693960627c27cd4aec58106a9ff3218..2e825341dfeae7cb2276add9006cc58ecdcdfe2b 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/profiler/rpc/client/capture_profile.h" +#include "tensorflow/core/profiler/rpc/profiler_server.h" using tensorflow::string; @@ -24,16 +26,20 @@ void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { op->operation.ConsumeInput(h->handle); } -TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx) { +TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx) { return new TFE_Profiler(ctx); } +bool TFE_ProfilerIsOk(TFE_Profiler* profiler) { + return profiler->profiler->Status().ok(); +} + void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; } void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, TF_Buffer* buf, TF_Status* status) { TFE_ContextAsyncWait(ctx, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; string content; status->status = profiler->profiler->SerializeToString(&content); void* data = tensorflow::port::Malloc(content.length()); @@ -44,3 +50,45 @@ void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, tensorflow::port::Free(data); }; } + +TFE_ProfilerContext* TFE_NewProfilerContext() { + return new TFE_ProfilerContext; +} + +void 
TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context, + TFE_Context* eager_context) { + profiler_context->profiler_context.eager_context = &eager_context->context; +} + +void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) { + delete profiler_context; +} + +void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) { + // Release child thread intentionally. The child thread can be terminate by + // terminating the main thread. + tensorflow::StartProfilerServer(&context->profiler_context, port).release(); +} + +void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { + ctx->context.SetShouldStoreGraphs(true); +} + +void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { + ctx->context.SetShouldStoreGraphs(false); +} + +bool TFE_ProfilerClientStartTracing(const char* service_addr, + const char* logdir, const char* worker_list, + bool include_dataset_ops, int duration_ms, + int num_tracing_attempts) { + tensorflow::Status s = + tensorflow::profiler::client::ValidateHostPortPair(service_addr); + if (!s.ok()) { + return false; + } + s = tensorflow::profiler::client::StartTracing( + service_addr, logdir, worker_list, include_dataset_ops, duration_ms, + num_tracing_attempts); + return s.ok(); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 9eb80f521624e0116dd8ea5e4dbbf7e3d350a09c..219b9f40720a4fc212bd6e191b5bb441cf2abeb8 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -25,6 +25,8 @@ extern "C" { TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); +typedef struct TFE_ProfilerContext TFE_ProfilerContext; + // A profiler which will start profiling when creating the object and will stop // when the object is destroyed. It will profile all operations run under the // given TFE_Context. 
Multiple instance of it can be created, but at most one @@ -32,17 +34,59 @@ TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, // Thread-safety: TFE_Profiler is thread-safe. typedef struct TFE_Profiler TFE_Profiler; -TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx); +TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx); +TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler); TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler); -// The output string is a binary string of tensorflow.tfprof.ProfileProto. -// User can write the string to file for offline analysis by tfprof command-line -// tools or graphical user interface. +// The output string is a binary string of tensorflow.tpu.Trace. User can write +// the string to file for offline analysis by tensorboard. TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, TF_Buffer* buf, TF_Status* status); +// Return a new profiler context object. +TF_CAPI_EXPORT extern TFE_ProfilerContext* TFE_NewProfilerContext(void); + +// Set the eager context in TFE_ProfilerServerOptions +TF_CAPI_EXPORT extern void TFE_ProfilerContextSetEagerContext( + TFE_ProfilerContext* profiler_context, TFE_Context* eager_context); + +// Destroy a profiler context object. +TF_CAPI_EXPORT extern void TFE_DeleteProfilerContext( + TFE_ProfilerContext* profiler_context); + +// Start a profiler grpc server which listens to specified port. It will start +// the server on its own thread. It can be shutdown by terminating tensorflow. +// It can be used in both Eager mode and graph mode. Creating multiple profiler +// server is allowed. The service defined in +// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use +// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture tracable +// file following +// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace. 
+TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_ProfilerContext* context, + int port); + +// Enables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx); + +// Disables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx); + +// Send a grpc request to profiler server (service_addr) to perform on-demand +// profiling and save the result into logdir which can be visualized by +// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set +// include_dataset_opts to false to profile longer traces. It will block the +// caller thread until receives tracing result. +// This API is designed for TensorBoard, for end user, please use +// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following +// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace. +TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing( + const char* service_addr, const char* logdir, const char* worker_list, + bool include_dataset_ops, int duration_ms, int num_tracing_attempts); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index c3d29bd2e094c1dca50e3e132bc9002503cb056b..d85048caa7c7f727271352883cb834a2575bd251 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/profiler/tfprof_log.pb.h" +#include "tensorflow/core/profiler/trace_events.pb.h" using tensorflow::string; @@ -41,9 +41,12 @@ void ExecuteWithProfiling(bool async) { TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); - TFE_Profiler* profiler = TFE_NewProfiler(ctx); + TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext(); + TFE_ProfilerContextSetEagerContext(profiler_context, ctx); + TFE_Profiler* profiler = TFE_NewProfiler(profiler_context); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_DeleteProfilerContext(profiler_context); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -70,17 +73,17 @@ void ExecuteWithProfiling(bool async) { TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status); TFE_DeleteProfiler(profiler); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - tensorflow::tfprof::ProfileProto profile_proto; + profiler::Trace profile_proto; EXPECT_TRUE(profile_proto.ParseFromString( {reinterpret_cast(profiler_result->data), profiler_result->length})); string profile_proto_str = profile_proto.DebugString(); if (!gpu_device_name.empty()) { - EXPECT_TRUE(HasSubstr(profile_proto_str, "gpu:0")); + EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0")); // device name with "stream:all" is collected by Device Tracer. 
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all")); } - EXPECT_TRUE(HasSubstr(profile_proto_str, "cpu:0")); + EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0")); TF_DeleteBuffer(profiler_result); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); @@ -100,5 +103,27 @@ void ExecuteWithProfiling(bool async) { TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); } TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); } +TEST(CAPI, MultipleProfilerSession) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(false)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext(); + TFE_ProfilerContextSetEagerContext(profiler_context, ctx); + + TFE_Profiler* profiler1 = TFE_NewProfiler(profiler_context); + EXPECT_TRUE(TFE_ProfilerIsOk(profiler1)); + + TFE_Profiler* profiler2 = TFE_NewProfiler(profiler_context); + EXPECT_FALSE(TFE_ProfilerIsOk(profiler2)); + + TFE_DeleteProfiler(profiler1); + TFE_DeleteProfiler(profiler2); + TFE_DeleteProfilerContext(profiler_context); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index d330aa4290ea7ddf0461a2277bb3d6eb215df241..35dafb9a7f14bfe1fad21bda35685598164f3895 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -52,7 +52,7 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/profiler/lib/eager_profiler.h" +#include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/public/version.h" struct TFE_ContextOptions { @@ -83,6 +83,12 @@ struct TFE_TensorHandle { TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} tensorflow::TensorHandle* handle; + + // Create a symbolic tensor. + TFE_TensorHandle(TF_Output t, TF_DataType dtype) + : handle(new tensorflow::TensorHandle( + tensorflow::OutputGraphNode{t.oper, t.index}, + static_cast(dtype))) {} }; struct TFE_TensorDebugInfo { @@ -93,19 +99,36 @@ struct TFE_TensorDebugInfo { std::vector dev_dims; }; +struct TFE_OpInferenceContext { + explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def) + : op_def(op_def) {} + + const tensorflow::OpDef* op_def; // op definition from protobuf + int input_arg_idx = 0; // arg definition index for the next input to be added + tensorflow::gtl::FlatSet attrs; // attributes inferred so far +}; + struct TFE_Op { TFE_Op(TFE_Context* ctx, const char* op, bool is_function, - const tensorflow::AttrTypeMap* t) - : operation(&ctx->context, op, is_function, t) {} + const tensorflow::AttrTypeMap* t, + TFE_OpInferenceContext* inference_ctx) + : operation(&ctx->context, op, is_function, t), + inference_ctx(inference_ctx) {} tensorflow::EagerOperation operation; + std::unique_ptr inference_ctx; +}; + +struct TFE_ProfilerContext { + tensorflow::ProfilerContext profiler_context; }; struct TFE_Profiler { - TFE_Profiler(TFE_Context* ctx) - : profiler(tensorflow::EagerProfiler::Create(&ctx->context)) {} + TFE_Profiler(TFE_ProfilerContext* ctx) { + profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context); + } - std::unique_ptr profiler; + std::unique_ptr profiler; }; namespace tensorflow { @@ -115,4 +138,24 @@ void SetOpAttrValueScalar(TFE_Context* 
ctx, TFE_Op* op, const char* attr_name, TF_Status* status); } // namespace tensorflow +struct TFE_TraceContext { + TF_Graph* const graph; + + unsigned int node_counter = 0; + // Each tensor handle will have its ref count incremented when it's added as a + // map key, and decremented when this object is destroyed. + std::map input_tensor_map; + std::vector>* input_tensors = + nullptr; + + TFE_TraceContext(TF_Graph* graph) : graph(graph) {} + + ~TFE_TraceContext() { + delete input_tensors; + for (auto input : input_tensor_map) { + input.first->Unref(); + } + } +}; + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 3d1ca4fb4b561a03ea9d879b1876fb1fd08a3139..b5e55420016bc9015ab71a515299838be953f5f4 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" #include #include "absl/strings/match.h" @@ -1626,4 +1627,158 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { TFE_DeleteTensorHandle(h); TFE_DeleteTensorHandle(h_shares_tensor); } + +TEST(CAPI, TestTFE_OpInferSingleInputAttrs) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* minOp = TFE_NewOp(ctx, "Min", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(minOp, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(minOp, axis, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + 
+ tensorflow::AttrValueMap attr_values; + minOp->operation.Attrs().FillAttrValueMap(&attr_values); + tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T"); + EXPECT_NE(attr_found, attr_values.cend()); + EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT); + attr_found = attr_values.find("Tidx"); + EXPECT_NE(attr_found, attr_values.cend()); + EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_INT32); + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(minOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + TFE_DeleteOp(minOp); + TFE_DeleteTensorHandle(input); + TFE_DeleteTensorHandle(axis); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); +} + +TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input1 = TestMatrixTensorHandle(); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(); + TFE_TensorHandle* dim = TestScalarTensorHandle(0); + TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* inputs[] = {input1, input2}; + TFE_OpAddInput(concatOp, dim, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInputList(concatOp, inputs, 2, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::AttrValueMap attr_values; + concatOp->operation.Attrs().FillAttrValueMap(&attr_values); + tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T"); + EXPECT_NE(attr_found, attr_values.cend()); + EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT); + attr_found = 
attr_values.find("N"); + EXPECT_NE(attr_found, attr_values.cend()); + EXPECT_EQ(attr_found->second.i(), 2); + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(concatOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + TFE_DeleteOp(concatOp); + TFE_DeleteTensorHandle(input1); + TFE_DeleteTensorHandle(input2); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(dim); + TFE_DeleteContext(ctx); +} + +TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* condition = TestScalarTensorHandle(true); + TFE_TensorHandle* t1 = TestMatrixTensorHandle(); + TFE_TensorHandle* t2 = TestAxisTensorHandle(); + TFE_Op* assertOp = TFE_NewOp(ctx, "Assert", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(assertOp, condition, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* data[] = {condition, t1, t2}; + TFE_OpAddInputList(assertOp, data, 3, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::AttrValueMap attr_values; + assertOp->operation.Attrs().FillAttrValueMap(&attr_values); + tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T"); + EXPECT_NE(attr_found, attr_values.cend()); + EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL); + EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); + EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(assertOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << 
TF_Message(status); + + TF_DeleteStatus(status); + TFE_DeleteOp(assertOp); + TFE_DeleteTensorHandle(condition); + TFE_DeleteTensorHandle(t1); + TFE_DeleteTensorHandle(t2); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); +} + +TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input1 = TestMatrixTensorHandle(); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(); + TFE_TensorHandle* dim = TestScalarTensorHandle(0); + TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* inputs[] = {input1, input2}; + TFE_OpAddInput(concatOp, dim, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + CHECK(concatOp->inference_ctx); + TFE_OpAddInput(concatOp, inputs[0], status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present"; + TFE_OpAddInput(concatOp, inputs[1], status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::AttrValueMap attr_values; + concatOp->operation.Attrs().FillAttrValueMap(&attr_values); + EXPECT_EQ(attr_values.find("T"), attr_values.end()); + EXPECT_EQ(attr_values.find("N"), attr_values.end()); + + TF_DeleteStatus(status); + TFE_DeleteOp(concatOp); + TFE_DeleteTensorHandle(input1); + TFE_DeleteTensorHandle(input2); + TFE_DeleteTensorHandle(dim); + TFE_DeleteContext(ctx); +} } // namespace diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index bd38127d50c171af801dd1b937acefdba491b4a6..17d17c0b7f7909e8dc1aaea61ade2cce1c466a3f 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ 
b/tensorflow/c/eager/c_api_test_util.cc @@ -21,8 +21,8 @@ limitations under the License. using tensorflow::string; -TFE_TensorHandle* TestScalarTensorHandle() { - float data[] = {1.0f}; +TFE_TensorHandle* TestScalarTensorHandle(float value) { + float data[] = {value}; TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float)); memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); @@ -33,6 +33,30 @@ TFE_TensorHandle* TestScalarTensorHandle() { return th; } +TFE_TensorHandle* TestScalarTensorHandle(int value) { + int data[] = {value}; + TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestScalarTensorHandle(bool value) { + bool data[] = {value}; + TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* DoubleTestMatrixTensorHandle() { int64_t dims[] = {2, 2}; double data[] = {1.0, 2.0, 3.0, 4.0}; diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 75ef9459e93b4f2ed471c423a34565594efc1714..4ff3ff4301f63c001bec1eb23fb9e098b78c6a5e 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -20,7 +20,13 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" // Return a tensor handle containing a float scalar -TFE_TensorHandle* TestScalarTensorHandle(); +TFE_TensorHandle* TestScalarTensorHandle(float value); + +// Return a tensor handle containing a int scalar +TFE_TensorHandle* TestScalarTensorHandle(int value); + +// Return a tensor handle containing a bool scalar +TFE_TensorHandle* TestScalarTensorHandle(bool value); // Return a tensor handle containing a 2x2 matrix of doubles TFE_TensorHandle* DoubleTestMatrixTensorHandle(); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 5c11f51e8749de84547ae873f5f55ebd42bc4b3d..29896b0d2bf6860775bed00284e7d6ff7992f474 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -639,7 +639,9 @@ Status GradientTape::ComputeGradient( } } } - CHECK(state.op_tape.empty()); + if (!state.op_tape.empty()) { + return tensorflow::errors::Internal("Invalid tape state."); + } result->reserve(source_tensor_ids.size()); gtl::FlatSet used_gradient_ids(source_tensor_ids.size()); for (auto is : source_tensor_ids) { diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index c45efd7fae252dc7c6c98354faa66c4d690bd5e7..71181ae430ab64106e2a75937bd54fbf2efc61ac 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -48,9 +48,10 @@ TF_KernelBuilder* TF_NewKernelBuilder( } void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) { - DCHECK_NE(builder, nullptr); - delete builder->cc_builder; - delete builder; + if (builder != nullptr) { + delete builder->cc_builder; + delete builder; + } } namespace tensorflow { @@ -172,9 +173,10 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { cc_ctx->CtxFailure(s); } -#define DEFINE_TF_GETATTR_(struct_name, func, c_type, cc_type) \ - void struct_name##_GetAttr##func(struct_name* ctx, const char* attr_name, \ - c_type* val, TF_Status* status) { \ +#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ + void 
TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \ + const char* attr_name, \ + c_type* val, TF_Status* status) { \ TF_SetStatus(status, TF_OK, ""); \ cc_type v; \ auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \ @@ -185,10 +187,6 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { } \ } -#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ - DEFINE_TF_GETATTR_(TF_OpKernelConstruction, func, c_type, cc_type) \ - DEFINE_TF_GETATTR_(TF_OpKernelContext, func, c_type, cc_type) - DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index b015d0103969355e8566242bfcc007f697c6ae18..c47bfa8aa3a721d422a0a1536b924f3e53793193 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -137,15 +137,6 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType( TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val, TF_Status* status); -// Interprets the named kernel context attribute as a TF_DataType and places it -// into *val. *status is set to TF_OK. -// -// If the attribute could not be found or could not be interpreted as -// TF_DataType, *status is populated with an error. 
-TF_CAPI_EXPORT extern void TF_OpKernelContext_GetAttrType( - TF_OpKernelContext* ctx, const char* attr_name, TF_DataType* val, - TF_Status* status); - #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..597182ab016a1b9564579ac0abf9667cf6d1dce9 --- /dev/null +++ b/tensorflow/c/kernels/BUILD @@ -0,0 +1,44 @@ +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_kernel_library", +) + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +tf_kernel_library( + name = "bitcast_op", + prefix = "bitcast_op", + deps = [ + "//tensorflow/c:kernels", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + ], +) + +tf_cc_test( + name = "bitcast_op_test", + srcs = ["bitcast_op_test.cc"], + deps = [ + ":bitcast_op", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# Changes to the Android srcs here should be replicated in +# tensorflow/contrib/makefile/tf_op_files.txt +# LINT.IfChange +filegroup( + name = "android_all_ops", + srcs = [ + "bitcast_op.cc", + ], +) +# LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt) diff --git a/tensorflow/c/kernels/bitcast_op.cc b/tensorflow/c/kernels/bitcast_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2f313af38677a3b93580fab1730363b43b32615 --- /dev/null +++ b/tensorflow/c/kernels/bitcast_op.cc @@ -0,0 +1,171 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/kernels.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.h" + +// BitcastOp implements a bitcast kernel, creating an output tensor that shares +// the same data buffer as the input but with a different shape and/or data +// type. Its inputs are: +// +// * the input tensor +// * an attribute named "T" containing the TF_DataType of the input tensor +// * an attribute named "type" containing the TF_DataType of the output tensor +// +// Given an input tensor of shape [...], if the input DataType "T" is larger +// than the output DataType "type", then the shape changes from [...] +// to [..., sizeof(T)/sizeof(type)]. +// +// If "T" is smaller than "type", the operator requires that the rightmost +// dimension be equal to sizeof(type)/sizeof(T). The shape then goes from +// [..., sizeof(type)/sizeof(T)] to [...]. +// +// Bitcast is implemented as a low-level cast, so machines with different endian +// orderings will give different results. 
+typedef struct BitcastOp { + TF_DataType input_data_type; + TF_DataType output_data_type; + size_t in_size; + size_t out_size; +} BitcastOp; + +static void* BitcastOp_Create(TF_OpKernelConstruction* ctx) { + auto* kernel = new BitcastOp; + + TF_Status* s = TF_NewStatus(); + TF_OpKernelConstruction_GetAttrType(ctx, "T", &kernel->input_data_type, s); + + if (TF_GetCode(s) == TF_OK) { + TF_OpKernelConstruction_GetAttrType(ctx, "type", &kernel->output_data_type, + s); + } + + if (TF_GetCode(s) == TF_OK) { + kernel->in_size = TF_DataTypeSize(kernel->input_data_type); + kernel->out_size = TF_DataTypeSize(kernel->output_data_type); + + size_t check_size = std::max(kernel->in_size, kernel->out_size) % + std::min(kernel->in_size, kernel->out_size); + if (check_size != 0) { + std::ostringstream err; + err << "cannot convert between datatype " << kernel->input_data_type + << " and " << kernel->output_data_type; + TF_SetStatus(s, TF_INVALID_ARGUMENT, err.str().c_str()); + } + } + + if (TF_GetCode(s) != TF_OK) { + TF_OpKernelConstruction_Failure(ctx, s); + delete kernel; + kernel = nullptr; + } + + TF_DeleteStatus(s); + return kernel; +} + +static void BitcastOp_Delete(void* kernel) { + delete static_cast(kernel); +} + +static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) { + auto* k = static_cast(kernel); + int dim_count = 0; + + TF_Tensor* tensor; + TF_Status* status = TF_NewStatus(); + TF_GetInput(ctx, 0, &tensor, status); + if (TF_GetCode(status) == TF_OK) { + dim_count = TF_NumDims(tensor); + if (!(k->in_size >= k->out_size || + (dim_count > 0 && + TF_Dim(tensor, dim_count - 1) == k->out_size / k->in_size))) { + std::ostringstream err; + err << "Cannot bitcast from " << k->input_data_type << " to " + << k->output_data_type; + TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str()); + } + } + + if (TF_GetCode(status) == TF_OK) { + auto* dims = new int64_t[dim_count + 1]; + int new_dim_count = dim_count; + for (int dim = 0; dim < dim_count; ++dim) { + 
dims[dim] = TF_Dim(tensor, dim); + } + if (k->out_size < k->in_size) { + dims[new_dim_count++] = static_cast(k->in_size / k->out_size); + } else if (k->out_size > k->in_size) { + --new_dim_count; + } + + TF_Tensor* output = TF_AllocateTensor(k->output_data_type, dims, 0, + TF_DataTypeSize(k->output_data_type)); + TF_TensorBitcastFrom(tensor, k->output_data_type, output, dims, + new_dim_count, status); + if (TF_GetCode(status) == TF_OK) { + TF_SetOutput(ctx, 0, output, status); + } + delete[] dims; + TF_DeleteTensor(output); + } + + if (TF_GetCode(status) != TF_OK) { + TF_OpKernelContext_Failure(ctx, status); + } + TF_DeleteStatus(status); + TF_DeleteTensor(tensor); +} + +static void RegisterBitcastOp() { + TF_Status* status = TF_NewStatus(); + + { + auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU, + &BitcastOp_Create, &BitcastOp_Compute, + &BitcastOp_Delete); + TF_RegisterKernelBuilder("BitcastOp", builder, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) + << "Error while registering bitcast kernel"; + } + +#if GOOGLE_CUDA + { + auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU, + &BitcastOp_Create, &BitcastOp_Compute, + &BitcastOp_Delete); + TF_RegisterKernelBuilder("BitcastOp", builder, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) + << "Error while registering CUDA bitcast kernel"; + } +#endif + + TF_DeleteStatus(status); +} + +// A dummy static variable initialized by a lambda whose side-effect is to +// register the bitcast kernel. +static bool BitcastOpIsRegistered = []() { + if (SHOULD_REGISTER_OP_KERNEL("BitcastOp")) { + RegisterBitcastOp(); + } + return true; +}(); diff --git a/tensorflow/c/kernels/bitcast_op_test.cc b/tensorflow/c/kernels/bitcast_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..06ffcca19dad13dc2fcae15125e25fc546562566 --- /dev/null +++ b/tensorflow/c/kernels/bitcast_op_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool RequiresRecordingAccessedTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +void TestBitcastOp(Tensor* input_tensor, DataType out_type, + TensorShape expected_shape, error::Code expected_code) { + Status status; + NodeDef def; + def.set_op("Bitcast"); + def.set_device(DEVICE_CPU); + + AttrValue typeAttr; + SetAttrValue(input_tensor->dtype(), &typeAttr); + + AttrValue outTypeAttr; + SetAttrValue(out_type, &outTypeAttr); + + (*def.mutable_attr())["T"] = typeAttr; + (*def.mutable_attr())["type"] = outTypeAttr; + + def.add_input( + strings::StrCat("input1: ", DataTypeString(input_tensor->dtype()))); + + std::unique_ptr kernel = + CreateOpKernel(DeviceType(DEVICE_CPU), nullptr, nullptr, def, 1, &status); + ASSERT_TRUE(status.ok()) << status.ToString(); + + OpKernelContext::Params params; + DummyDevice dummy_device(nullptr, 
false); + params.device = &dummy_device; + params.op_kernel = kernel.get(); + gtl::InlinedVector inputs; + inputs.emplace_back(input_tensor); + params.inputs = &inputs; + + OpKernelContext ctx(¶ms); + kernel->Compute(&ctx); + ASSERT_EQ(expected_code, ctx.status().code()); + if (expected_code == error::OK) { + ASSERT_EQ(expected_shape, ctx.mutable_output(0)->shape()) + << ctx.mutable_output(0)->shape().DebugString(); + } +} + +TEST(BitcastOpTest, TestUpcast) { + Tensor int8_input(DT_UINT8, {8}); + for (int i = 0; i < 8; i++) { + int8_input.vec()(i) = static_cast(1); + } + TestBitcastOp(&int8_input, DT_UINT64, TensorShape(), error::OK); +} + +TEST(BitcastOpTest, TestDowncast) { + Tensor int64_input(static_cast(1)); + TestBitcastOp(&int64_input, DT_UINT8, TensorShape({8}), error::OK); +} + +TEST(BitcastOpTest, TestCastToSameSize) { + Tensor int32_input(DT_UINT32, {4, 6}); + TestBitcastOp(&int32_input, DT_UINT8, TensorShape({4, 6, 4}), error::OK); +} + +TEST(BitcastOpTest, TestImpossibleCast) { + Tensor int8_input(DT_UINT8, {1}); + TestBitcastOp(&int8_input, DT_UINT32, TensorShape(), error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index 531eb598fa8b6ff3f7c87641091f8d5a11752728..608887722f7bca44c884a3426d5e378e9387a530 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -36,6 +36,15 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { struct MyCustomKernel* s = new struct MyCustomKernel; s->created = true; s->compute_called = false; + + // Exercise attribute reads. 
+ TF_DataType type; + TF_Status* status = TF_NewStatus(); + TF_OpKernelConstruction_GetAttrType(ctx, "SomeDataTypeAttr", &type, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + EXPECT_EQ(TF_FLOAT, type); + TF_DeleteStatus(status); + return s; } @@ -43,17 +52,7 @@ static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) { struct MyCustomKernel* s = static_cast(kernel); s->compute_called = true; if (ctx != nullptr) { - TF_Status* status = TF_NewStatus(); - EXPECT_EQ(43, TF_StepId(ctx)); - - // Exercise attribute reads. - TF_DataType type; - TF_OpKernelContext_GetAttrType(ctx, "SomeDataTypeAttr", &type, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - EXPECT_EQ(TF_FLOAT, type); - - TF_DeleteStatus(status); } } @@ -224,4 +223,8 @@ TEST(TestKernel, TestInputAndOutputCount) { } } +TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) { + TF_DeleteKernelBuilder(nullptr); +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 98d8393332269ae349cf8aa5c0b612c6f17172e6..2c9d9f3a15b4dfec9d8fe511c8005cafc1d71ef7 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -41,6 +41,15 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, RecordMutation(graph, *op, "setting attribute"); } +void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Status* status) { + AttrValue attr_val; + + mutex_lock l(graph->mu); + op->node.ClearAttr(attr_name); + RecordMutation(graph, *op, "clearing attribute"); +} + void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { mutex_lock l(graph->mu); op->node.set_requested_device(device); @@ -80,7 +89,7 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { // This modification only updates the destination node for // the purposes 
of running this graph in a session. Thus, we don't // record the source node as being modified. @@ -154,7 +163,7 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, tensorflow::shape_inference::ShapeHandle shape; status->status = ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); } ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); @@ -165,7 +174,7 @@ void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, mutex_lock l(graph->mu); status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, new_src.index, &dst->node); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { // This modification only updates the destination node for // the purposes of running this graph in a session. Thus, we don't // record the source node as being modified. diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 44779ca656165dd65590cb5e9ea3ccf71165ed63..f26c0cb2ae2f6e00a247660a02525901e87920cd 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -32,6 +32,11 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, TF_Buffer* attr_value_proto, TF_Status* status); +// Clears the attr in the node_def Protocol Buffer and sets a status upon +// completion. +void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Status* status); + void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); // Updates 'dst' to consume 'new_src'. 
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a09becc49b10d2c58f98fbcc11df5190f794c1d4..545e472115ad1a3f001754cbec37448696076c35 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -8,6 +8,19 @@ package( licenses(["notice"]) # Apache 2.0 +filegroup( + name = "srcs", + srcs = [ + "framework/gradients.h", + "framework/ops.h", + "framework/scope.h", + "framework/scope_internal.h", + "ops/array_ops.h", + "ops/while_loop.h", + "//tensorflow/cc/saved_model:loader.h", + ], +) + load( "//tensorflow:tensorflow.bzl", "cc_library_with_android_deps", @@ -150,6 +163,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", ], ) @@ -586,16 +600,32 @@ tf_gen_op_wrappers_cc( pkg = "//tensorflow/core", ) -cc_library_with_android_deps( +tf_gen_op_wrappers_cc( + name = "tpu_ops", + include_internal_ops = 1, + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + +cc_library( name = "cc_op_gen_main", srcs = [ "framework/cc_op_gen.cc", "framework/cc_op_gen.h", "framework/cc_op_gen_main.cc", ], - android_deps = [ - "//tensorflow/core:android_tensorflow_lib", - ], copts = tf_copts(), data = [ "//tensorflow/core/api_def:base_api_def", @@ -666,7 +696,7 @@ tf_cc_binary( copts = tf_copts(), linkopts = select({ "//tensorflow:windows": [], - "//tensorflow:darwin": [ + "//tensorflow:macos": [ "-lm", "-lpthread", ], diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 39593370d1c243e84dc5b6091724d1d404c102b0..43a33cbea6e1e4a50f61cc7d6d8d70cac6a603d2 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ 
b/tensorflow/cc/framework/cc_op_gen.cc @@ -321,6 +321,7 @@ std::pair AttrTypeName(StringPiece attr_type) { {"tensor", {"TensorProto", true}}, {"list(tensor)", {"gtl::ArraySlice", true}}, {"func", {"NameAttrList", true}}, + {"list(func)", {"gtl::ArraySlice", true}}, }; auto entry = attr_type_map->find(attr_type); diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index affd90b1bcc7cb4a8b3ffed6aeeb4bd480f5e314..303fdf64ec723864848096009a57dabda2fc93e4 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -96,7 +96,7 @@ class SymbolicGradientBuilder { // Used to identify nodes at which to stop backprop. std::unordered_set GetStopBackpropNodes( const std::vector& reachable_nodes, - std::unordered_set output_nodes); + const std::unordered_set& output_nodes); const Scope& scope_; const ops::GradOpRegistry* registry_; @@ -167,7 +167,6 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad, std::vector SymbolicGradientBuilder::GetReachableNodes() { std::vector reachable_nodes(scope_.graph()->num_node_ids(), false); std::deque queue; - std::vector visited(scope_.graph()->num_node_ids(), false); for (const Output& out : outputs_) { if (!reachable_nodes[out.node()->id()]) { queue.push_back(out.node()); @@ -180,10 +179,10 @@ std::vector SymbolicGradientBuilder::GetReachableNodes() { queue.pop_front(); for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) continue; - if (visited[e->src()->id()]) continue; - queue.push_back(e->src()); - reachable_nodes[e->src()->id()] = true; - visited[e->src()->id()] = true; + if (!reachable_nodes[e->src()->id()]) { + queue.push_back(e->src()); + reachable_nodes[e->src()->id()] = true; + } } } return reachable_nodes; @@ -191,7 +190,7 @@ std::vector SymbolicGradientBuilder::GetReachableNodes() { std::unordered_set SymbolicGradientBuilder::GetStopBackpropNodes( const std::vector& reachable_nodes, - std::unordered_set output_nodes) { + 
const std::unordered_set& output_nodes) { // Output nodes that get transitively consumed by other `outputs_` are stored // in `internal_outputs`. std::unordered_set internal_outputs; @@ -201,9 +200,9 @@ std::unordered_set SymbolicGradientBuilder::GetStopBackpropNodes( // `output_` node was encountered, pair.second will be nullptr. std::deque> queue; for (const Output& nout : inputs_) { - if (visited.find(nout.node()) == visited.end()) { + auto const& pair = visited.insert(nout.node()); + if (pair.second) { queue.push_back(std::make_pair(nout.node(), static_cast(nullptr))); - visited.insert(nout.node()); } } // BFS from nodes in 'inputs_' along out edges for the entire graph. Internal @@ -217,22 +216,23 @@ std::unordered_set SymbolicGradientBuilder::GetStopBackpropNodes( for (const Edge* e : n->out_edges()) { // If a node is not reachable from outputs_, we can stop. if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue; - if (visited.find(e->dst()) != visited.end()) continue; - - int node_id = e->dst()->id(); - Node* last_output_node = p.second; - if (output_nodes.find(node_id) != output_nodes.end()) { - // We reached an output node. - if (last_output_node != nullptr) { - // If we had already found an output node on this path so we mark - // it as an internal output. - internal_outputs.insert(last_output_node->id()); + + auto const& pair = visited.insert(e->dst()); + if (pair.second) { + int node_id = e->dst()->id(); + Node* last_output_node = p.second; + if (output_nodes.find(node_id) != output_nodes.end()) { + // We reached an output node. + if (last_output_node != nullptr) { + // If we had already found an output node on this path so we mark + // it as an internal output. + internal_outputs.insert(last_output_node->id()); + } + // Mark this newly found output node to insert in the queue. + last_output_node = e->dst(); } - // Mark this newly found output node to insert in the queue. 
- last_output_node = e->dst(); + queue.push_back(std::make_pair(e->dst(), last_output_node)); } - queue.push_back(std::make_pair(e->dst(), last_output_node)); - visited.insert(e->dst()); } } // Finally, we set stop_backprop_nodes to all output_nodes that aren't also @@ -286,9 +286,9 @@ Status SymbolicGradientBuilder::Initialize() { std::unordered_set visited; std::deque queue; for (const Output& nout : inputs_) { - if (visited.find(nout.node()) == visited.end()) { + auto const& pair = visited.insert(nout.node()); + if (pair.second) { queue.push_back(nout.node()); - visited.insert(nout.node()); } } @@ -309,9 +309,9 @@ Status SymbolicGradientBuilder::Initialize() { // we don't expect it to receive a backpropagated gradient. // It will not be counted in num_expected_backprops. if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue; - if (visited.find(e->dst()) == visited.end()) { + auto const& pair = visited.insert(e->dst()); + if (pair.second) { queue.push_back(e->dst()); - visited.insert(e->dst()); } ++num_expected_backprops; } diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 81785b2d89b3d36b46992b7ae376b5175a806027..134d64af140f394825bb75477e31639be1aa8d50 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -153,6 +152,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device) exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(device), + assigned_device_(other.impl()->assigned_device_), + xla_cluster_(other.impl()->xla_cluster_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -313,11 +314,10 @@ Status Scope::ToGraphDef(GraphDef* gdef) const { return Status::OK(); } -Status Scope::ToGraph(Graph* g) const { +Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const { if (ok()) { GraphDef graph_def; graph()->ToGraphDef(&graph_def); - GraphConstructorOptions opts; UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g)); } return *impl()->status_; diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 0a75f23725c143e6b22ee6dffae1428ed8209fe8..1e17b74bc8f05869e50aa4ec645e57f7bcebc9f6 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -205,14 +206,15 @@ class Scope { // START_SKIP_DOXYGEN - /// If status() is Status::OK(), construct a Graph object using the default + /// If status() is Status::OK(), construct a Graph object using `opts` as the /// GraphConstructorOptions, and return Status::OK if graph construction was /// successful. Otherwise, return the error status. 
// TODO(josh11b, keveman): Make this faster; right now it converts // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds // edges from the source and to the sink node, resolves back edges // by name), and makes sure the resulting graph is valid. - Status ToGraph(Graph* g) const; + Status ToGraph( + Graph* g, GraphConstructorOptions opts = GraphConstructorOptions{}) const; // Calls AddNode() using this scope's ShapeRefiner. This exists in the public // API to prevent custom op wrappers from needing access to shape_refiner.h or diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index 05c287bdc62cdb8be7208ce3975f280aaa816766..056eea7eb5a5c796da01cd7353662da35aecddf9 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -29,13 +29,17 @@ Status ResizeNearestNeighborGradHelper(const Scope& scope, const Operation& op, bool align_corners; TF_RETURN_IF_ERROR( GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + bool half_pixel_centers; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "half_pixel_centers", + &half_pixel_centers)); // The internal gradient implementation needs the shape of the input image. 
// x_shape = shape(x)[1:3] // = slice(shape(x), {1}, {3 - 1}) auto x_shape = Slice(scope, Shape(scope, op.input(0)), {1}, {2}); grad_outputs->push_back(internal::ResizeNearestNeighborGrad( scope, grad_inputs[0], x_shape, - internal::ResizeNearestNeighborGrad::AlignCorners(align_corners))); + internal::ResizeNearestNeighborGrad::AlignCorners(align_corners) + .HalfPixelCenters(half_pixel_centers))); grad_outputs->push_back(NoGradient()); return scope.status(); } @@ -47,9 +51,13 @@ Status ResizeBilinearGradHelper(const Scope& scope, const Operation& op, bool align_corners; TF_RETURN_IF_ERROR( GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + bool half_pixel_centers; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "half_pixel_centers", + &half_pixel_centers)); grad_outputs->push_back(internal::ResizeBilinearGrad( scope, grad_inputs[0], op.input(0), - internal::ResizeBilinearGrad::AlignCorners(align_corners))); + internal::ResizeBilinearGrad::AlignCorners(align_corners) + .HalfPixelCenters(half_pixel_centers))); grad_outputs->push_back(NoGradient()); return scope.status(); } @@ -61,9 +69,14 @@ Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op, bool align_corners; TF_RETURN_IF_ERROR( GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + bool half_pixel_centers; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "half_pixel_centers", + &half_pixel_centers)); + grad_outputs->push_back(internal::ResizeBicubicGrad( scope, grad_inputs[0], op.input(0), - internal::ResizeBicubicGrad::AlignCorners(align_corners))); + internal::ResizeBicubicGrad::AlignCorners(align_corners) + .HalfPixelCenters(half_pixel_centers))); grad_outputs->push_back(NoGradient()); return scope.status(); } @@ -75,17 +88,40 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, string kernel_type; TF_RETURN_IF_ERROR( GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type)); + bool antialias; + 
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "antialias", &antialias)); grad_outputs->push_back(internal::ScaleAndTranslateGrad( scope, grad_inputs[0], op.input(0), op.input(2), op.input(3), - internal::ScaleAndTranslateGrad::KernelType(kernel_type))); + internal::ScaleAndTranslateGrad::KernelType(kernel_type) + .Antialias(antialias))); grad_outputs->push_back(NoGradient()); grad_outputs->push_back(NoGradient()); grad_outputs->push_back(NoGradient()); return scope.status(); } + REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper); +Status CropAndResizeGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + DataType input_type; + string method; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "method", &method)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "T", &input_type)); + auto image_shape = Shape(scope, op.input(0)); + grad_outputs->push_back(CropAndResizeGradImage( + scope, grad_inputs[0], op.input(1), op.input(2), image_shape, input_type, + CropAndResizeGradImage::Method(method))); + grad_outputs->push_back(CropAndResizeGradBoxes( + scope, grad_inputs[0], op.input(0), op.input(1), op.input(2))); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} + +REGISTER_GRADIENT_OP("CropAndResize", CropAndResizeGradHelper); } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc index 1d150226538093467e092e02f38090a327f9c9b6..d50f4f5750a680ed2ac20451b5522ede28cc474d 100644 --- a/tensorflow/cc/gradients/image_grad_test.cc +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -27,6 +27,7 @@ namespace tensorflow { namespace { using ops::Const; +using ops::CropAndResize; using ops::ResizeBicubic; using ops::ResizeBilinear; using ops::ResizeNearestNeighbor; @@ -51,7 +52,8 @@ class ImageGradTest : public 
::testing::Test { template void MakeOp(const OpType op_type, const Tensor& x_data, const Input& y_shape, - const bool align_corners, Output* x, Output* y) { + const bool align_corners, const bool half_pixel_centers, + Output* x, Output* y) { *x = Const(scope_, x_data); switch (op_type) { case RESIZE_NEAREST: @@ -61,22 +63,26 @@ class ImageGradTest : public ::testing::Test { return; case RESIZE_BILINEAR: *y = ResizeBilinear(scope_, *x, y_shape, - ResizeBilinear::AlignCorners(align_corners)); + ResizeBilinear::AlignCorners(align_corners) + .HalfPixelCenters(half_pixel_centers)); return; case RESIZE_BICUBIC: *y = ResizeBicubic(scope_, *x, y_shape, - ResizeBicubic::AlignCorners(align_corners)); + ResizeBicubic::AlignCorners(align_corners) + .HalfPixelCenters(half_pixel_centers)); return; } assert(false); } template - void TestResizedShapeForType(const OpType op_type, const bool align_corners) { + void TestResizedShapeForType(const OpType op_type, const bool align_corners, + const bool half_pixel_centers) { TensorShape x_shape({1, 2, 2, 1}); Tensor x_data = MakeData(x_shape); Output x, y; - MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + MakeOp(op_type, x_data, {4, 6}, align_corners, half_pixel_centers, &x, + &y); ClientSession session(scope_); std::vector outputs; @@ -86,44 +92,64 @@ class ImageGradTest : public ::testing::Test { } void TestResizedShape(OpType op_type) { - for (const bool align_corners : {true, false}) { - TestResizedShapeForType(op_type, align_corners); - TestResizedShapeForType(op_type, align_corners); - TestResizedShapeForType(op_type, align_corners); + for (const bool half_pixel_centers : {true, false}) { + for (const bool align_corners : {true, false}) { + if (half_pixel_centers && align_corners) { + continue; + } + TestResizedShapeForType(op_type, align_corners, + half_pixel_centers); + TestResizedShapeForType(op_type, align_corners, + half_pixel_centers); + TestResizedShapeForType(op_type, align_corners, + half_pixel_centers); + } } } 
template void TestResizeToSmallerAndAlign(const OpType op_type, - const bool align_corners) { + const bool align_corners, + const bool half_pixel_centers) { TensorShape x_shape({1, 4, 6, 1}); Tensor x_data = MakeData(x_shape); Output x, y; - MakeOp(op_type, x_data, {2, 3}, align_corners, &x, &y); + MakeOp(op_type, x_data, {2, 3}, align_corners, half_pixel_centers, &x, + &y); JAC_T max_error; TF_ASSERT_OK((ComputeGradientError( scope_, x, x_data, y, {1, 2, 3, 1}, &max_error))); - EXPECT_LT(max_error, 1e-3); + EXPECT_LT(max_error, 1.5e-3); } template void TestResizeToLargerAndAlign(const OpType op_type, - const bool align_corners) { + const bool align_corners, + const bool half_pixel_centers) { TensorShape x_shape({1, 2, 3, 1}); Tensor x_data = MakeData(x_shape); Output x, y; - MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + MakeOp(op_type, x_data, {4, 6}, align_corners, half_pixel_centers, &x, + &y); JAC_T max_error; TF_ASSERT_OK((ComputeGradientError( scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); - EXPECT_LT(max_error, 1e-3); + EXPECT_LT(max_error, 1.5e-3); } template void TestResize(OpType op_type) { - for (const bool align_corners : {true, false}) { - TestResizeToSmallerAndAlign(op_type, align_corners); - TestResizeToLargerAndAlign(op_type, align_corners); + for (const bool half_pixel_centers : {true, false}) { + for (const bool align_corners : {true, false}) { + // if (!half_pixel_centers) continue; + if (half_pixel_centers && align_corners) { + continue; + } + TestResizeToSmallerAndAlign(op_type, align_corners, + half_pixel_centers); + TestResizeToLargerAndAlign(op_type, align_corners, + half_pixel_centers); + } } } @@ -170,29 +196,151 @@ class ScaleAndTranslateGradTest : public ::testing::Test { } template - void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x, - Output* y) { + void MakeOp(const Tensor& x_data, const Input& y_shape, Input scale, + Input translation, const string& kernel_type, bool antialias, + Output* x, Output* y) 
{ *x = Const(scope_, x_data); - *y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f}); + *y = ScaleAndTranslate(scope_, *x, y_shape, scale, translation, + ScaleAndTranslate::KernelType(kernel_type) + .Antialias(antialias) + .Antialias(antialias)); TF_ASSERT_OK(scope_.status()); } template - void TestResize() { - TensorShape x_shape({1, 2, 3, 1}); + void TestScaleAndTranslate(const TensorShape x_shape, const int out_height, + const int out_width, Input scale, + Input translation, const string& kernel_type, + bool antialias) { Tensor x_data = MakeData(x_shape); Output x, y; - MakeOp(x_data, {4, 6}, &x, &y); + MakeOp(x_data, {out_height, out_width}, scale, translation, + kernel_type, antialias, &x, &y); JAC_T max_error; TF_ASSERT_OK((ComputeGradientError( - scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); + scope_, x, x_data, y, {1, out_height, out_width, 1}, &max_error))); + EXPECT_LT(max_error, 2e-3); + } + + const std::vector kScales = {Input{1.0f, 1.0f}, Input{0.37f, 0.47f}, + Input{2.1f, 2.1f}}; + const std::vector kTranslations = { + Input{0.0f, 0.0f}, Input{3.14f, 1.19f}, Input{2.1f, 3.1f}, + Input{100.0f, 200.0f}}; + Scope scope_; +}; + +TEST_F(ScaleAndTranslateGradTest, TestGrads) { + const std::vector kKernelTypes = {"lanczos1", "lanczos3", + "lanczos5", "gaussian"}; + constexpr int kOutHeight = 4; + constexpr int kOutWidth = 6; + + const TensorShape kXShape = TensorShape({1, 2, 3, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + for (const std::string& kernel_type : kKernelTypes) { + TestScaleAndTranslate( + kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type, + true); + } + } + } +} + +TEST_F(ScaleAndTranslateGradTest, TestGradsWithoutAntialias) { + constexpr int kOutHeight = 4; + constexpr int kOutWidth = 6; + + const TensorShape kXShape = TensorShape({1, 2, 3, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + 
TestScaleAndTranslate(kXShape, kOutHeight, kOutWidth, + scale, translation, "lanczos3", + false); + } + } +} + +TEST_F(ScaleAndTranslateGradTest, TestGradsWithSameShape) { + const std::vector kKernelTypes = {"lanczos3", "gaussian"}; + + constexpr int kOutHeight = 2; + constexpr int kOutWidth = 3; + + const TensorShape kXShape = TensorShape({1, 2, 3, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + for (const std::string& kernel_type : kKernelTypes) { + TestScaleAndTranslate( + kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type, + true); + } + } + } +} + +TEST_F(ScaleAndTranslateGradTest, TestGradsWithSmallerShape) { + const std::vector kKernelTypes = {"lanczos3", "gaussian"}; + constexpr int kOutHeight = 2; + constexpr int kOutWidth = 3; + + const TensorShape kXShape = TensorShape({1, 4, 6, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + for (const std::string& kernel_type : kKernelTypes) { + TestScaleAndTranslate( + kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type, + true); + } + } + } +} + +class CropAndResizeGradTest : public ::testing::Test { + protected: + CropAndResizeGradTest() : scope_(Scope::NewRootScope()) {} + + template + Tensor MakeData(const TensorShape& data_shape) { + DataType data_type = DataTypeToEnum::v(); + Tensor data(data_type, data_shape); + auto data_flat = data.flat(); + for (int i = 0; i < data_flat.size(); ++i) { + data_flat(i) = T(i); + } + return data; + } + + template + void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind, + const Input& crop_size, Output* x, Output* y) { + *x = Const(scope_, x_data); + *y = CropAndResize(scope_, *x, boxes, box_ind, crop_size, + CropAndResize::Method("bilinear")); + TF_ASSERT_OK(scope_.status()); + } + + template + void TestCropAndResize() { + TensorShape x_shape({1, 4, 2, 1}); + Tensor x_data = MakeData(x_shape); + TensorShape box_shape({1, 4}); + Tensor boxes 
= MakeData(box_shape); + Output x, y; + MakeOp(x_data, boxes, {0}, {1, 1}, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 1, 1, 1}, &max_error))); EXPECT_LT(max_error, 1e-3); } Scope scope_; }; -TEST_F(ScaleAndTranslateGradTest, Works) { TestResize(); } +TEST_F(CropAndResizeGradTest, TestCrop) { + TestCropAndResize(); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD index cf65fe1ab99b49207a64e86310178141b30d07d7..e9838d9aba6554b40082187057851e9c896f8352 100644 --- a/tensorflow/cc/profiler/BUILD +++ b/tensorflow/cc/profiler/BUILD @@ -10,7 +10,7 @@ tf_cuda_cc_test( name = "profiler_test", srcs = ["profiler_test.cc"], tags = [ - "noguitar", # b/77649654 + "nogpu", # b/77649654 ], deps = [ ":profiler", diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 52345a376cc29ee47ccb9888c9bb26292468b5a9..dedd55f16afb879ea966dc89d14d88ee15d9e83e 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -81,6 +81,7 @@ cc_library( ] + if_not_mobile([ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", ]) + if_android([ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 10f7abf09e925c0c31cfd595ecee4605f189476f..66260fcf4a9b24f78d45010c6e86d4ee398b6d3d 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" -#include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 23e9dc40d23899b9cef168c9128b6d8ed1be3ee9..eeb910178902ca883ed211379ba3f188c139f92e 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -124,7 +124,9 @@ Status GetVariableNameToTensorMap( return Status::OK(); } std::vector variable_names; + variable_names.reserve(variable_names_set.size()); std::vector tensor_names; + tensor_names.reserve(variable_names_set.size()); for (const string& node_name : variable_names_set) { variable_names.push_back(node_name); NodeDef* node_def = name_to_node_map.at(node_name); diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49cb74f19ef325c5861b124e458dd7e3b7f436e9 --- /dev/null +++ b/tensorflow/compat_template.__init__.py @@ -0,0 +1,57 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import as _absolute_import +from __future__ import division as _division +from __future__ import print_function as _print_function + +import os as _os +import sys as _sys + +# pylint: disable=g-bad-import-order + +# API IMPORTS PLACEHOLDER + +from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg=( + "Limited tf.compat.v2.summary API due to missing TensorBoard " + "installation")) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v2.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v2.keras')) + +# We would like the following to work for fully enabling 2.0 in a 1.0 install: +# +# import tensorflow.compat.v2 as tf +# tf.enable_v2_behavior() +# +# This makes this one symbol available directly. 
+from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top + +# Add module aliases +_current_module = _sys.modules[__name__] +if hasattr(_current_module, 'keras'): + losses = keras.losses + metrics = keras.metrics + optimizers = keras.optimizers + initializers = keras.initializers diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index b966c22b2319aef3b87ef54a283911718d37cf84..9549a71c41a0ba2aac58abd8cfb182aa4eaf3b4f 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -28,7 +28,8 @@ from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v1.estimator')) _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow.python.keras.api._v1.keras')) diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 16151e77737429f4fbf690fc34b12a70bacebdc4..af016bf80e7a10d8729a1eb385466af48b5810cd 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -30,6 +30,7 @@ cc_library( "flags.h", ], deps = [ + ":aot_only_var_handle_op", ":embedded_protocol_buffers", "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:cpu_function_runtime", @@ -71,6 +72,7 @@ tf_cc_test( ":tfcompile_lib", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", @@ -205,6 +207,15 @@ cc_library( ], ) +cc_library( + name = "aot_only_var_handle_op", + srcs = ["aot_only_var_handle_op.cc"], + deps = [ + 
"//tensorflow/compiler/tf2xla:xla_compiler", + ], + alwayslink = 1, +) + tf_cc_test( name = "benchmark_test", srcs = ["benchmark_test.cc"], diff --git a/tensorflow/compiler/aot/aot_only_var_handle_op.cc b/tensorflow/compiler/aot/aot_only_var_handle_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ce36a979f424610a5aa952afa8db2245ed971a9 --- /dev/null +++ b/tensorflow/compiler/aot/aot_only_var_handle_op.cc @@ -0,0 +1,56 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +// Implementation of varhandle that binds a VarHandleOp to an XlaResource of the +// same name. It is not safe to use this op in a JIT context. +class XlaAotOnlyVarHandleOp : public XlaOpKernel { + public: + explicit XlaAotOnlyVarHandleOp(OpKernelConstruction* c); + void Compile(XlaOpKernelContext* context) override; + + private: + string name_; +}; + +XlaAotOnlyVarHandleOp::XlaAotOnlyVarHandleOp(OpKernelConstruction* c) + : XlaOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("shared_name", &name_)); +} + +void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) { + // Look for a resource of the same name. 
TF also keys that on the container + // and type attributes, but that doesn't seem necessary. + for (const auto& resource : context->xla_context()->resources()) { + if (resource->kind() == XlaResource::kVariable && + resource->name() == name_) { + context->SetResourceOutput(0, resource.get()); + return; + } + } + context->SetStatus( + errors::InvalidArgument("Variable: ", name_, " not configured")); +} +} // namespace + +REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index d016632da2a9d7c2c2f81c02dd573787a0502923..2355fad8802a490fafb702f53d88312611f9ebf4 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -168,12 +168,12 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); - if (config.feed_size() != num_args) { - return errors::InvalidArgument("mismatch between feed_size(", - config.feed_size(), ") and num_args(", - num_args, ")"); + if (config.feed_size() + config.variable_size() != num_args) { + return errors::InvalidArgument( + "mismatch between feed_size(", config.feed_size(), ")+variable_size(", + config.variable_size(), ") and num_args(", num_args, ")"); } - for (int i = 0; i < num_args; ++i) { + for (int i = 0; i < config.feed_size(); ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR( AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); @@ -212,12 +212,14 @@ Status GenResultMethods(const tf2xla::Config& config, // tuple result, and we rely on this to simplify code generation. 
return errors::Internal("codegen requires the XLA result to be a tuple"); } - if (config.fetch_size() != ps.result().tuple_shapes_size()) { + size_t num_results = ps.result().tuple_shapes_size(); + if (config.fetch_size() + config.variable_size() != num_results) { return errors::InvalidArgument("mismatch between fetch_size(", - config.feed_size(), ") and tuple_size(", + config.fetch_size(), ")+variable_size(", + config.variable_size(), ") and tuple_size(", ps.result().tuple_shapes_size(), ")"); } - for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { + for (int i = 0; i < config.fetch_size(); ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR(AddRewritesForShape( i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); @@ -245,6 +247,51 @@ Status GenResultMethods(const tf2xla::Config& config, return Status::OK(); } +// Generate methods for variables. +Status GenVariableMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, string* methods) { + size_t num_args = ps.parameters_size(); + for (int i = config.feed_size(); i < num_args; ++i) { + std::vector> rewrites; + TF_RETURN_IF_ERROR( + AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); + const string code = R"( + void set_var_{{NAME}}_data({{TYPE}}* data) { + set_arg_data({{I}}, data); + } +)"; + const tf2xla::Variable& var = config.variable(i - config.feed_size()); + *methods += RewriteWithName( + var.name().empty() ? 
var.node_name() : var.name(), code, rewrites); + } + size_t num_results = ps.result().tuple_shapes_size(); + for (int i = config.fetch_size(); i < num_results; ++i) { + std::vector> rewrites; + TF_RETURN_IF_ERROR(AddRewritesForShape( + i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); + string code = R"( + {{TYPE}}* var_{{NAME}}_data() { + return static_cast<{{TYPE}}*>(result_data({{I}})); + } + {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) { + return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( + result_data({{I}}))){{INDICES}}; + } + const {{TYPE}}* var_{{NAME}}_data() const { + return static_cast(result_data({{I}})); + } + const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const { + return (*static_cast( + result_data({{I}}))){{INDICES}}; + } +)"; + const tf2xla::Variable& var = config.variable(i - config.fetch_size()); + *methods += RewriteWithName( + var.name().empty() ? var.node_name() : var.name(), code, rewrites); + } + return Status::OK(); +} + // Generates code implementing {Arg,Result}Names(), where T is one of // tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string // literal in the array, with nullptr terminating the array. 
@@ -291,6 +338,14 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name")); } } + for (const tf2xla::Variable& variable : config.variable()) { + if (!variable.name().empty()) { + TF_RETURN_IF_ERROR(ValidateCppIdent(variable.name(), "variable name")); + } else { + TF_RETURN_IF_ERROR( + ValidateCppIdent(variable.node_name(), "variable name")); + } + } return Status::OK(); } @@ -339,9 +394,10 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, std::vector buffer_infos_for_temps = ExtractTempBufferInfos(buffer_infos); const xla::ProgramShapeProto& ps = compile_result.program_shape; - string methods_arg, methods_result; + string methods_arg, methods_result, methods_variable; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); + TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable)); const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( buffer_infos_for_args.data(), buffer_infos_for_args.size(), /*allocate_entry_params=*/true); @@ -523,6 +579,21 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // buffers are managed internally, and may change after each call to Run. {{METHODS_RESULT}} + // Methods for managing variable buffers. Buffers are in row-major order. The + // input and output buffers may or may not be identical. + // + // void set_var_X_data(T* data) + // Sets the buffer for variable X. + // + // T* var_X_data() + // Returns the buffer of type T for variable X. + // + // T& var_X(...dim indices...) + // Returns a reference to the value of type T for variable X, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. +{{METHODS_VARIABLE}} + private: // Number of buffers for the compiled computation. 
static constexpr size_t kNumBuffers = {{NUM_BUFFERS}}; @@ -589,6 +660,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { include_hlo_profile_printer_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, + {"{{METHODS_VARIABLE}}\n", methods_variable}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))}, diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index c1788ca32a1d099284eeb870f9513891051fd29e..5580e55b691bd10698b63d86bc0194b25da743b9 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" @@ -172,6 +174,15 @@ TEST(CodegenTest, Golden) { tf2xla::Fetch* fetch = config.add_fetch(); fetch->mutable_id()->set_node_name("fetch0"); fetch->set_name("myfetch"); + tf2xla::Variable* variable = config.add_variable(); + variable->set_node_name("myvar"); + variable->mutable_shape()->add_dim()->set_size(1); + variable->set_type(DT_FLOAT); + tf2xla::Variable* variable2 = config.add_variable(); + variable2->set_node_name("my/var"); + variable2->set_name("myvar2"); + variable2->mutable_shape()->add_dim()->set_size(5); + variable2->set_type(DT_INT32); CompileResult compile_result; compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( {}, @@ -186,9 +197,14 @@ TEST(CodegenTest, Golden) { { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), + xla::ShapeUtil::MakeShape(xla::F32, {1}), + 
xla::ShapeUtil::MakeShape(xla::S32, {5}), }, - xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})) + xla::ShapeUtil::MakeTupleShape({ + xla::ShapeUtil::MakeShape(xla::U32, {5, 6}), + xla::ShapeUtil::MakeShape(xla::F32, {1}), + xla::ShapeUtil::MakeShape(xla::S32, {5}), + })) .ToProto(); compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 35994fc785d3e1d5e883c49bec96de315e189d2e..8591df538779e3bc0f6e55607180a6d49009735e 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -52,7 +52,7 @@ namespace bar { // is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4]) -> (u32[5,6]) +// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5]) // // Memory stats: // arg bytes total: 104 @@ -214,6 +214,58 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { result_data(0)))[dim0][dim1]; } + // Methods for managing variable buffers. Buffers are in row-major order. The + // input and output buffers may or may not be identical. + // + // void set_var_X_data(T* data) + // Sets the buffer for variable X. + // + // T* var_X_data() + // Returns the buffer of type T for variable X. + // + // T& var_X(...dim indices...) + // Returns a reference to the value of type T for variable X, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. 
+ + void set_var_myvar_data(float* data) { + set_arg_data(2, data); + } + + void set_var_myvar2_data(tensorflow::int32* data) { + set_arg_data(3, data); + } + + float* var_myvar_data() { + return static_cast(result_data(1)); + } + float& var_myvar() { + return (*static_cast( + result_data(1)))[0]; + } + const float* var_myvar_data() const { + return static_cast(result_data(1)); + } + const float& var_myvar() const { + return (*static_cast( + result_data(1)))[0]; + } + + tensorflow::int32* var_myvar2_data() { + return static_cast(result_data(2)); + } + tensorflow::int32& var_myvar2(size_t dim0) { + return (*static_cast( + result_data(2)))[dim0]; + } + const tensorflow::int32* var_myvar2_data() const { + return static_cast(result_data(2)); + } + const tensorflow::int32& var_myvar2(size_t dim0) const { + return (*static_cast( + result_data(2)))[dim0]; + } + private: // Number of buffers for the compiled computation. static constexpr size_t kNumBuffers = 6; @@ -257,7 +309,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShapeProto* StaticProgramShape() { static const xla::ProgramShapeProto* kShape = []() { xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 64); + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 132); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index 7f7b96428572705f30144e6c95cd4cf9c44ce2a3..2884597abcf29583e6192296b0e4ce6825d7c01a 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 9fc223bdc7c0e207ce2005cb86250aa77e709df8..0e46a9f5e9d68fa2174f7bd9b9fa7c3a82dfb715 100644 --- 
a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -108,10 +108,13 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, computation.Snapshot()); // Serialize the HloSnapshot deterministically so that all the outputs of a // tf_library genrule are deterministic. - string proto; - TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto)); + const size_t size = module->ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK( + SerializeToBufferDeterministic(*module, serialized.get(), size)); TF_RETURN_IF_ERROR( - WriteStringToFile(Env::Default(), flags.out_session_module, proto)); + WriteStringToFile(Env::Default(), flags.out_session_module, + absl::string_view(serialized.get(), size))); } xla::cpu::CpuAotCompilationOptions aot_opts( flags.target_triple, flags.target_cpu, flags.target_features, diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 10fa33ab5e84dcbc1629bee6214e8969046f19c2..ce8dae4262913c975ca69dedd0420f1457e11ee9 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -26,6 +26,8 @@ test_suite( ":test_graph_tfmatmulandadd_test", ":test_graph_tfsplits_test", ":test_graph_tftop_k_test", + ":test_graph_tfvariable_sequential_updates_test", + ":test_graph_tfvariable_test", ":tfcompile_test", ], ) @@ -69,6 +71,8 @@ genrule( "test_graph_tfmatmulandadd.pb", "test_graph_tfsplits.pb", "test_graph_tftop_k.pb", + "test_graph_tfvariable.pb", + "test_graph_tfvariable_sequential_updates.pb", ], # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any # GPUs which might be present. 
This is important because builds may run @@ -222,6 +226,28 @@ tf_library( ], ) +tf_library( + name = "test_graph_tfvariable", + testonly = 1, + config = "test_graph_tfvariable.config.pbtxt", + cpp_class = "VariableComp", + graph = "test_graph_tfvariable.pb", + tags = [ + "manual", + ], +) + +tf_library( + name = "test_graph_tfvariable_sequential_updates", + testonly = 1, + config = "test_graph_tfvariable_sequential_updates.config.pbtxt", + cpp_class = "VariableSequentialUpdatesComp", + graph = "test_graph_tfvariable_sequential_updates.pb", + tags = [ + "manual", + ], +) + tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], @@ -241,6 +267,8 @@ tf_cc_test( ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", ":test_graph_tftop_k", + ":test_graph_tfvariable", + ":test_graph_tfvariable_sequential_updates", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 64b861a73091642b03573543a5c55618bf33915d..7f5e907e26365c0d9ec65e6f00d410a87f452241 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -50,7 +50,7 @@ def tfadd_with_ckpt(out_dir): y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') - init_op = variables.initialize_all_variables() + init_op = variables.global_variables_initializer() saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) with session.Session() as sess: sess.run(init_op) @@ -65,7 +65,7 @@ def tfadd_with_ckpt_saver(out_dir): y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') - init_op = variables.initialize_all_variables() + init_op = variables.global_variables_initializer() saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1) with 
session.Session() as sess: sess.run(init_op) @@ -149,6 +149,25 @@ def tftop_k(_): array_ops.identity(output[1], name='indices') +def tfvariable(_): + x = variables.Variable(1000.0, name='x') + old_x = x.value() + with ops.control_dependencies([old_x]): + new_x = x.assign_add(42.0) + array_ops.stack([old_x, new_x], name='result') + + +def tfvariable_sequential_updates(_): + x = variables.Variable(1.0, name='x') + updates = control_flow_ops.no_op() + for _ in range(3): + with ops.control_dependencies([updates]): + x_val = x.read_value() + 1.0 + updates = x.assign_sub(0.1 * x_val) + + array_ops.identity(updates, name='result') + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -171,6 +190,8 @@ def main(_): write_graph(tfmatmulandadd, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) write_graph(tftop_k, FLAGS.out_dir) + write_graph(tfvariable, FLAGS.out_dir) + write_graph(tfvariable_sequential_updates, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tfvariable.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfvariable.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9b4c4215a330b014f595edde001aba73ad7d8263 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfvariable.config.pbtxt @@ -0,0 +1,12 @@ +# Text form of tensorflow.tf2xla.Config proto. 
+fetch { + id { node_name: "result" } +} + +variable { + node_name: "x" + shape { + dim { size: 1 } + } + type: DT_FLOAT +} diff --git a/tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7312c40baf6957c273fc389efa11d08ed9f7a0dd --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.config.pbtxt @@ -0,0 +1,9 @@ +# Text form of tensorflow.tf2xla.Config proto. +fetch { + id { node_name: "result" } +} + +variable { + node_name: "x" + type: DT_FLOAT +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 4dd79e5882d7da61be029735ef2b165908c599f9..5bee7f2540a4177a9c4e726bb739d7b92a4dacfc 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -30,6 +30,8 @@ limitations under the License. 
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" #include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h" #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -473,6 +475,49 @@ TEST(TFCompileTest, TopK) { EXPECT_EQ(expected_indices[1], fn.result1(1)); } +TEST(TFCompileTest, Variable) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + VariableComp fn; + float x = 23; + fn.set_var_x_data(&x); + + fn.set_thread_pool(&device); + fn.Run(); + EXPECT_EQ(fn.result0(0, 0), 23); + EXPECT_EQ(fn.result0(1, 0), 65); + EXPECT_EQ(fn.var_x(), 65); + + EXPECT_EQ(fn.var_x_data(), &x); + EXPECT_EQ(x, 65); + fn.Run(); + EXPECT_EQ(fn.result0(0, 0), 65); + EXPECT_EQ(fn.result0(1, 0), 107); + EXPECT_EQ(fn.var_x(), 107); +} + +TEST(TFCompileTest, VariableSequentialUpdates) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + // This implements the recursion: + // x[0] = 1.0 + // x[n+1] = x[n] - 0.1*(x[n] + 1.0) + VariableSequentialUpdatesComp fn; + float x = 1; + fn.set_var_x_data(&x); + + fn.set_thread_pool(&device); + // First calculate x[3] + fn.Run(); + EXPECT_NEAR(x, 0.458f, 1e-6); + + // Then calculate x[6] + fn.Run(); + EXPECT_NEAR(x, 0.062882f, 1e-6); +} + TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the // two args are different. 
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 4051664c24cacad4a2d151ad3ac9009015900609..fd701ab7166eb6520ad9050abb5285c9d0e0b6bd 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -207,7 +207,7 @@ def tf_library( # # Note that setting the local=1 attribute on a *test target* causes the # test infrastructure to skip that test. However this is a genrule, not - # a test target, and runs with --genrule_strategy=forced_forge, meaning + # a test target, and runs with --strategy=Genrule=forced_forge, meaning # the local=1 attribute is ignored, and the genrule is still run. # # https://www.bazel.io/versions/master/docs/be/general.html#genrule @@ -392,6 +392,6 @@ def target_llvm_triple(): "//tensorflow:android_x86": "i686-none-android", "//tensorflow:ios": "arm64-none-ios", "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "//tensorflow:darwin": "x86_64-none-darwin", + "//tensorflow:macos": "x86_64-none-darwin", "//conditions:default": "x86_64-pc-linux", }) diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index d548de8c44285f6d21dd778db464a31e1b19645b..0b6ab7e723d6e3a55da2f1c30b75f44cbdaa75bb 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -136,6 +136,10 @@ int main(int argc, char** argv) { tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); + if (argc > 1 && absl::string_view(argv[1]) == "--help") { + std::cerr << usage << "\n"; + return 0; + } bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); QCHECK(parsed_flags_ok) << "\n" << usage; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 1d72d8c2d88c1d256341c46b04e620f2b198e7ea..4424c29e395adf550943c3da99606ad20ba1ad49 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ 
-20,6 +20,8 @@ package( ], ) +# NB! Removing the cc_header_only_library import breaks the OSS build since +# copybara injects some build rules that use it. load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") @@ -166,7 +168,6 @@ cc_library( ":xla_tensor", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -175,18 +176,29 @@ cc_library( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:stream_pool", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:host_constant_op", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", @@ -198,7 +210,9 @@ cc_library( "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels/data:generator_dataset_op", 
"//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:optional_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", + "//tensorflow/stream_executor/platform", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", @@ -253,11 +267,11 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -270,9 +284,7 @@ cc_library( hdrs = ["xla_compilation_cache.h"], deps = [ "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -282,7 +294,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -406,7 +417,6 @@ cc_library( hdrs = ["shape_inference.h"], deps = [ ":shape_inference_helpers", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -455,7 +465,6 @@ cc_library( "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", 
"@com_google_absl//absl/types:optional", @@ -481,6 +490,7 @@ cc_library( name = "compilation_passes", srcs = [ "build_xla_ops_pass.cc", + "clone_constants_for_better_clustering.cc", "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", @@ -493,6 +503,7 @@ cc_library( ], hdrs = [ "build_xla_ops_pass.h", + "clone_constants_for_better_clustering.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "encapsulate_xla_computations_pass.h", @@ -510,11 +521,11 @@ cc_library( ":union_find", ":xla_cluster_util", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/ops:xla_ops", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -526,11 +537,12 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:framework_bounds_check", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -546,12 +558,20 @@ cc_library( srcs = ["xla_cluster_util.cc"], hdrs = ["xla_cluster_util.h"], deps = [ + ":flags", ":resource_operation_safety_analysis", "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_bounds_check", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/kernels:bounds_check", + 
"//tensorflow/stream_executor/lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -592,11 +612,25 @@ tf_cc_test( ], ) +cc_library( + name = "compilation_passes_test_main", + testonly = True, + srcs = ["compilation_passes_test_main.cc"], + visibility = ["//visibility:public"], + deps = [ + ":flags", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "compilation_passes_test", size = "small", srcs = [ "build_xla_ops_pass_test.cc", + "clone_constants_for_better_clustering_test.cc", "encapsulate_subgraphs_pass_test.cc", "encapsulate_xla_computations_pass_test.cc", "extract_outside_compilation_pass_test.cc", @@ -607,7 +641,9 @@ tf_cc_test( deps = [ ":common", ":compilation_passes", + ":compilation_passes_test_main", ":encapsulate_util", + ":flags", ":node_matchers", ":xla_cluster_util", ":xla_cpu_device", @@ -636,7 +672,6 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", "//tensorflow/core:test", - "//tensorflow/core:test_main", "//tensorflow/core:testlib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", @@ -660,6 +695,7 @@ tf_cc_test( "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -667,6 +703,8 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -681,6 +719,7 @@ cc_library( ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", + 
"//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 9f4042630edaec1b9519b6434d859a48372e8b15..6058e4195425fadd8c101c6d41303ed0a6ea69e8 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -16,18 +16,20 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/graph_def_util.h" @@ -39,6 +41,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { @@ -115,6 +118,13 @@ void MergeOutgoingControlEdges(const Scope& s, Node* old_node, Node* new_node) { return; } + if (ctrl_edges.size() == 1 && ctrl_edges.front()->dst()->IsSink()) { + // Avoid creating a Merge node if we can just add an edge to _SINK + // instead. + s.graph()->AddControlEdge(new_node, s.graph()->sink_node()); + return; + } + // We can't merge control edges directly so we instead first "convert" them to // normal values that can be merged, merge the values and then "convert" the // merged value back into control. @@ -204,11 +214,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) { } } -// Returns true (into `result`) if `node` must be compiled. -Status NodeRequiresCompilation(Node* n, bool* result) { +// Returns true (into `result`) if a node placed on `device` must be compiled. +Status DeviceRequiresCompilation(const string& device, bool* result) { DeviceType device_type(""); - TF_RETURN_IF_ERROR( - DeviceToDeviceType(n->assigned_device_name(), &device_type)); + TF_RETURN_IF_ERROR(DeviceToDeviceType(device, &device_type)); const XlaOpRegistry::DeviceRegistration* registration = nullptr; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) { return errors::Internal("Could not find compilation device ", @@ -219,11 +228,97 @@ Status NodeRequiresCompilation(Node* n, bool* result) { return Status::OK(); } +// Replaces `n` with a `PartitionedCall` op that calls the same function. 
+Status ReplaceFunctionCallWithPartionedCall( + const GraphOptimizationPassOptions& options, + const FunctionLibraryDefinition& flib_def, Node* n, Graph* g, + const NameAttrList& func, const Scope& root) { + string config_string = options.session_options->config.SerializeAsString(); + + int input_count = absl::c_count_if( + n->in_edges(), [](const Edge* e) { return !e->IsControlEdge(); }); + + std::vector args(input_count); + for (const Edge* e : n->in_edges()) { + if (!e->IsControlEdge()) { + args[e->dst_input()] = Output(e->src(), e->src_output()); + } + } + + ops::PartitionedCall call( + root.WithOpName("partitioned_call"), args, n->output_types(), func, + ops::PartitionedCall::Attrs{}.ConfigProto(config_string)); + + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), call.operation.node()); + } + } + + std::vector edges_to_delete; + + for (const Edge* e : n->out_edges()) { + edges_to_delete.push_back(e); + if (e->IsControlEdge()) { + g->AddControlEdge(call.operation.node(), e->dst()); + } else { + g->AddEdge(call.operation.node(), e->src_output(), e->dst(), + e->dst_input()); + } + } + + for (const Edge* e : edges_to_delete) { + g->RemoveEdge(e); + } + + g->RemoveNode(n); + return Status::OK(); +} + +Status InferDeviceForCluster(Node* n, const string& function_name, + const FunctionLibraryDefinition& flib_def, + string* result) { + const FunctionDef* func_def = flib_def.Find(function_name); + TF_RET_CHECK(func_def) << "Could not find " << function_name; + + std::set device_names; + for (const NodeDef& ndef : func_def->node_def()) { + VLOG(3) << ndef.DebugString(); + if (!ndef.device().empty()) { + device_names.insert(ndef.device()); + } + } + + if (!n->assigned_device_name().empty()) { + // TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device + // assignment when constant folding. We should fix EncapsulateSubgraphsPass + // instead. 
+ device_names.insert(n->assigned_device_name()); + } + + std::vector device_names_vector; + absl::c_copy(device_names, std::back_inserter(device_names_vector)); + + Status s = PickDeviceForXla(device_names_vector, true, result); + if (s.ok()) { + VLOG(2) << "For " << function_name << " PickDeviceForXla(" + << absl::StrJoin(device_names_vector, ", ") << ") -> " << *result; + } + return s; +} + Status ReplaceNodeWithXlaCompileAndXlaRun( + const GraphOptimizationPassOptions& options, const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled, Graph* g, Node* n) { + XlaClusterInfo cluster_info; + TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); + + string device; + TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(), + flib_def, &device)); bool requires_compilation; - TF_RETURN_IF_ERROR(NodeRequiresCompilation(n, &requires_compilation)); + TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation)); if (!lazy_compilation_enabled) { requires_compilation = true; } @@ -232,10 +327,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) .NewSubScope(n->name()) .WithDevice(n->requested_device()) - .WithAssignedDevice(n->assigned_device_name()); - - XlaClusterInfo cluster_info; - TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); + .WithAssignedDevice(device); ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"), /*constants=*/cluster_info.constant_inputs, @@ -297,6 +389,9 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( g->AddControlEdge( DataToControl(root, inverse_predicated_compilation_key).node(), n); n->ClearAttr(kXlaCompiledKernelAttr); + + TF_RETURN_IF_ERROR(ReplaceFunctionCallWithPartionedCall( + options, flib_def, n, g, cluster_info.function, root)); } return Status::OK(); @@ -327,11 +422,11 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { for (Node* n : xla_compiled_kernels) { 
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( - *options.flib_def, lazy_compilation_enabled, graph, n)); + options, *options.flib_def, lazy_compilation_enabled, graph, n)); } if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def); + DumpGraphToFile("build_xla_ops", *graph, options.flib_def); } return Status::OK(); diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 390ffa694b6f127544d92f3024a02d877556aacd..902ec7f182b8f70dbd7f5b7f3138710845f4e3e1 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -54,9 +54,11 @@ using ::tensorflow::testing::matchers::Op; using ::tensorflow::testing::matchers::Out; using ::testing::_; -Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { +Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib, + std::unique_ptr* result) { auto graph = absl::make_unique(OpRegistry::Global()); TF_RETURN_IF_ERROR(s.ToGraph(graph.get())); + FunctionLibraryDefinition flib_def(graph->op_registry(), fdef_lib); // Assign all nodes to the CPU device. 
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; @@ -68,7 +70,12 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { } } + FixupSourceAndSinkEdges(graph.get()); + + SessionOptions session_options; GraphOptimizationPassOptions opt_options; + opt_options.session_options = &session_options; + opt_options.flib_def = &flib_def; opt_options.graph = &graph; BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true); TF_RETURN_IF_ERROR(pass.Run(opt_options)); @@ -112,23 +119,23 @@ Node* MakeWrite(const Scope& scope, const string& id) { } FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { - FunctionDefLibrary flib_def; + FunctionDefLibrary fdef_lib; FunctionDef func = FunctionDefHelper::Create( /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, /*attr_def*/ {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)}, /*ret_def=*/{{"out", "out:output:0"}}); - *flib_def.add_function() = std::move(func); - return flib_def; + *fdef_lib.add_function() = std::move(func); + return fdef_lib; } TEST_F(BuildXlaOpsTest, ControlDepsPreserved) { const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); - FunctionDefLibrary flib_def = + FunctionDefLibrary fdef_lib = CreateFunctionDefLibWithConstFunction("cluster_0"); - TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib)); Node* call; TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); call->set_requested_device(kXlaDeviceName); @@ -136,7 +143,7 @@ TEST_F(BuildXlaOpsTest, ControlDepsPreserved) { root.graph()->AddControlEdge(call, write_op); std::unique_ptr graph; - TF_ASSERT_OK(BuildXlaOps(root, &graph)); + TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph)); Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); ASSERT_NE(write_op_new, nullptr); @@ -146,9 +153,9 @@ 
TEST_F(BuildXlaOpsTest, ControlDepsPreserved) { TEST_F(BuildXlaOpsTest, CleanFailureOnBogusAttr) { Scope root = Scope::NewRootScope().ExitOnError(); - FunctionDefLibrary flib_def = + FunctionDefLibrary fdef_lib = CreateFunctionDefLibWithConstFunction("cluster_0"); - TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib)); Node* call; TF_ASSERT_OK( @@ -158,7 +165,7 @@ TEST_F(BuildXlaOpsTest, CleanFailureOnBogusAttr) { root.graph()->AddControlEdge(call, write_op); std::unique_ptr graph; - Status failure_status = BuildXlaOps(root, &graph); + Status failure_status = BuildXlaOps(root, fdef_lib, &graph); ASSERT_FALSE(failure_status.ok()); EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT); } @@ -166,9 +173,9 @@ TEST_F(BuildXlaOpsTest, CleanFailureOnBogusAttr) { TEST_F(BuildXlaOpsTest, OnNonXlaDevice) { Scope root = Scope::NewRootScope().ExitOnError(); - FunctionDefLibrary flib_def = + FunctionDefLibrary fdef_lib = CreateFunctionDefLibWithConstFunction("cluster_0"); - TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib)); Node* call; TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); @@ -182,14 +189,14 @@ TEST_F(BuildXlaOpsTest, OnNonXlaDevice) { auto xla_run = NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key))); auto tf_call = - NodeWith(Op("cluster_0"), + NodeWith(Op("PartitionedCall"), CtrlDeps(NodeWith(Op("Identity"), Inputs(Out(0, predicated_compilation_key))))); auto merge = NodeWith(Op("Merge"), Inputs(Out(tf_call), Out(xla_run))); auto assign_var = NodeWith(Op("AssignVariableOp"), Inputs(_, Out(merge))); std::unique_ptr graph; - TF_ASSERT_OK(BuildXlaOps(root, &graph)); + TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph)); Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); ASSERT_NE(write_op_new, nullptr); @@ -200,9 +207,9 @@ TEST_F(BuildXlaOpsTest, OnXlaDevice) { const char* 
kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); - FunctionDefLibrary flib_def = + FunctionDefLibrary fdef_lib = CreateFunctionDefLibWithConstFunction("cluster_0"); - TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib)); Node* call; TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); @@ -212,7 +219,7 @@ TEST_F(BuildXlaOpsTest, OnXlaDevice) { Node* write_op = MakeWrite(root, Output(call), "write_result"); std::unique_ptr graph; - TF_ASSERT_OK(BuildXlaOps(root, &graph)); + TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph)); auto xla_op = NodeWith(Op("_XlaRun"), Inputs(Out(NodeWith(Op("_XlaCompile"))))); @@ -223,5 +230,23 @@ TEST_F(BuildXlaOpsTest, OnXlaDevice) { ASSERT_NE(write_op_new, nullptr); EXPECT_THAT(write_op_new, assign_var); } + +TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary fdef_lib = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib)); + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph)); + + Node* sink_node = graph->sink_node(); + EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")), + NodeWith(Op("PartitionedCall")), + NodeWith(Op("NoOp"))))); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc new file mode 100644 index 0000000000000000000000000000000000000000..848a6362a4a8f506d233b126461911067f26d9f2 --- /dev/null +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc @@ -0,0 +1,187 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h" + +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { + +using se::port::StatusOr; + +string CloneConstantsForBetterClusteringPass::GenerateUniqueName( + const absl::flat_hash_set& name_set, absl::string_view prefix) { + string candidate; + do { + candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++); + } while (name_set.contains(candidate)); + return candidate; +} + +StatusOr CloneConstantsForBetterClusteringPass::CloneNode( + Graph* g, const absl::flat_hash_set& name_set, Node* n) { + NodeDef new_in_def = n->def(); + new_in_def.clear_input(); + new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name())); + Status s; + Node* new_in = g->AddNode(new_in_def, &s); + TF_RETURN_IF_ERROR(s); + + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), new_in); + } else { + g->AddEdge(e->src(), e->src_output(), new_in, e->dst_input()); + } + } + + new_in->set_assigned_device_name(n->assigned_device_name()); + return new_in; +} + +namespace { +// We only clone host constants for now since we want to avoid increasing memory +// pressure on GPUs. 
+StatusOr IsSmallHostConstant(Node* n) { + if (!n->IsConstant()) { + return false; + } + + DeviceNameUtils::ParsedName parsed; + TF_RET_CHECK( + DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed)); + if (parsed.type != DEVICE_CPU) { + return false; + } + + const TensorProto* proto = nullptr; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto)); + + // TODO(sanjoy): It may make sense to combine this threshold with XLA's "large + // constant" threshold, if there is one. + const int kSmallTensorThreshold = 16; + int64 total_elements = 1; + for (const auto& dim : proto->tensor_shape().dim()) { + if (dim.size() < 0) { + return errors::Internal("Unknown dimension size in constant tensor ", + n->name()); + } + total_elements *= dim.size(); + } + return total_elements < kSmallTensorThreshold; +} + +bool IsInPlaceOp(absl::string_view op_name) { + return op_name == "InplaceUpdate" || op_name == "InplaceAdd" || + op_name == "InplaceSub"; +} +} // namespace + +Status CloneConstantsForBetterClusteringPass::CloneSmallHostConstantInputs( + Graph* g, const absl::flat_hash_set& name_set, Node* n) { + std::vector in_edges; + absl::c_copy(n->in_edges(), std::back_inserter(in_edges)); + for (const Edge* e : in_edges) { + Node* input = e->src(); + TF_ASSIGN_OR_RETURN(bool is_small_host_constant, + IsSmallHostConstant(input)); + if (is_small_host_constant && input->out_edges().size() != 1) { + VLOG(2) << "Cloning small host constant " << input->name(); + TF_ASSIGN_OR_RETURN(Node* const input_cloned, + CloneNode(g, name_set, input)); + if (e->IsControlEdge()) { + g->AddControlEdge(input_cloned, e->dst()); + } else { + int dst_input = e->dst_input(); + TF_RET_CHECK(e->src_output() == 0) + << "expected constant to have exactly one non-control output, but " + "found output index = " + << e->src_output(); + g->RemoveEdge(e); + g->AddEdge(input_cloned, 0, n, dst_input); + } + } + } + return Status::OK(); +} + +Status CloneConstantsForBetterClusteringPass::Run( + const 
GraphOptimizationPassOptions& options) { + if (GetGlobalJitLevel(options) == OptimizerOptions::OFF) { + return Status::OK(); + } + + Graph* g = options.graph->get(); + absl::flat_hash_set name_set; + absl::c_transform(g->nodes(), std::inserter(name_set, name_set.begin()), + [](Node* n) { return n->name(); }); + std::vector nodes; + for (Node* n : g->nodes()) { + // We rely on the immutability of Tensors to safely clone Const operations. + // However, "in place" ops do not respect the immutability of Tensors so we + // avoid this transformation when such ops are present in the graph. + // + // In-place operations are problematic because they break the semantic + // illusion that tensorflow::Tensor instances are immutable. For instance + // if we have the following graph: + // + // digraph { + // SRC -> Const + // SRC -> I + // SRC -> V + // Const -> Identity + // Const -> InplaceAdd [label="x"] + // I -> InplaceAdd [label="i"] + // V -> InplaceAdd [label="v"] + // InplaceAdd -> Identity [style=dotted] + // } + // + // then the value produced by `Identity` is Const+I*V since InplaceAdd + // modifies the tensor in place. However, if we clone `Const` and turn the + // graph into: + // + // digraph { + // SRC -> "Const/clone_1" + // SRC -> "Const/clone_2" + // SRC -> I + // SRC -> V + // "Const/clone_1" -> Identity + // "Const/clone_2" -> InplaceAdd [label="x"] + // I -> InplaceAdd [label="i"] + // V -> InplaceAdd [label="v"] + // InplaceAdd -> Identity [style=dotted] + // } + // + // then `Identity` no longer produces Const+I*V because the InplaceAdd + // operation only modifies Const/clone_2 in place. + + if (IsInPlaceOp(n->type_string())) { + return Status::OK(); + } + nodes.push_back(n); + } + + // Iterate over a copy of the nodes to avoid iterating over g->nodes() while + // creating more nodes. 
+ for (Node* n : nodes) { + TF_RETURN_IF_ERROR(CloneSmallHostConstantInputs(g, name_set, n)); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.h b/tensorflow/compiler/jit/clone_constants_for_better_clustering.h new file mode 100644 index 0000000000000000000000000000000000000000..f67da75b34fb2e3eb9af9c16837fcc9a69bb94a8 --- /dev/null +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.h @@ -0,0 +1,74 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ +#define TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +// Clones small host constants in the graph to make it easier to form larger +// clusters. +// +// This helps us in two ways: +// +// - It reduces dependencies between clusters. Let's say a constant C is used +// by nodes X and Y. If X and Y are put in different clusters (for whatever +// reason) Y's cluster now has to wait for all the operations in X's cluster +// to finish before it starts running. 
+// +// - It lets us create bigger clusters in multi-GPU benchmarks. Consider the +// following graph: +// +// digraph { +// Const -> GPU_1 +// Const -> GPU_0_Y +// GPU_0_X -> GPU_0_Y +// } +// +// We'd cluster Const and GPU_1 together (and place it on GPU_1), and this +// will block us from clustering GPU_0_X and GPU_0_Y together since that +// would increase the amount of work on GPU 0 waiting on work on GPU 1. +// However, cloning Const into two copies, one for GPU_0_Y and one for GPU_1 +// will let us create one cluster containing {Const/copy_0, GPU_1} and +// another containing {Const/copy_1, GPU_0_X, GPU_0_Y}. +// +// We only clone small host constants now to avoid increasing memory consumption +// too much. Moreover, in practice the constants we have to duplicate are +// things like the `perm` input to `Transpose` and the `size` input to `Slice` +// which tend to be small anyway. + +class CloneConstantsForBetterClusteringPass : public GraphOptimizationPass { + public: + CloneConstantsForBetterClusteringPass() = default; + + Status Run(const GraphOptimizationPassOptions& options) override; + + private: + Status CloneSmallHostConstantInputs( + Graph* g, const absl::flat_hash_set& name_set, Node* n); + string GenerateUniqueName(const absl::flat_hash_set& name_set, + absl::string_view prefix); + se::port::StatusOr CloneNode( + Graph* g, const absl::flat_hash_set& name_set, Node* n); + + int unique_name_counter_ = 0; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..31543d1c3f8571be946868aa53ebad3c95ba9a5a --- /dev/null +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc @@ -0,0 +1,176 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { +using ::tensorflow::testing::FindNodeByName; + +Status CloneConstantsForBetterClustering(const Scope& s, + std::unique_ptr* result) { + auto graph = absl::make_unique(OpRegistry::Global()); + SessionOptions session_options; + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_global_jit_level(OptimizerOptions::ON_2); + GraphOptimizationPassOptions options; + options.graph = &graph; + options.session_options = &session_options; + + // Scope::ToGraph seems to drop assigned devices, probably because it goes + // through a GraphDef. So explicitly maintain the device assignment. 
+ // std::unordered_map assigned_device_names; + // for (Node* n : s.graph()->nodes()) { + // assigned_device_names[n->name()] = n->assigned_device_name(); + // } + GraphConstructorOptions opts; + opts.expect_device_spec = true; + TF_RETURN_IF_ERROR(s.ToGraph(graph.get(), opts)); + // for (Node* n : graph->nodes()) { + // n->set_assigned_device_name(assigned_device_names[n->name()]); + // } + + CloneConstantsForBetterClusteringPass rewriter; + TF_RETURN_IF_ERROR(rewriter.Run(options)); + *result = std::move(graph); + return Status::OK(); +} + +const char* kCPU = "/job:localhost/replica:0/task:0/device:CPU:0"; +const char* kGPU = "/job:localhost/replica:0/task:0/device:GPU:0"; + +TEST(CloneConstantsForBetterClusteringTest, Basic) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU); + + Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT); + Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT); + + Output perm = ops::Const(on_cpu.WithOpName("perm"), {3, 1, 2, 0}); + + { + Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm); + Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm); + } + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor tr0_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm)); + + OutputTensor tr1_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm)); + + EXPECT_NE(tr0_perm.node, tr1_perm.node); +} + +TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + + Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT); + Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT); + + Output perm = 
ops::Const(on_gpu.WithOpName("perm"), {3, 1, 2, 0}); + + { + Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm); + Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm); + } + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor tr0_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm)); + + OutputTensor tr1_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm)); + + EXPECT_EQ(tr0_perm.node, tr1_perm.node); +} + +TEST(CloneConstantsForBetterClusteringTest, DontCloneLargeConstants) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU); + + Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT); + Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT); + + Output perm = ops::Const( + on_cpu.WithOpName("perm"), + {17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + + { + Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm); + Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm); + } + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor tr0_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm)); + + OutputTensor tr1_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm)); + + EXPECT_EQ(tr0_perm.node, tr1_perm.node); +} + +TEST(CloneConstantsForBetterClusteringTest, InplaceOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + Scope on_cpu = root.WithAssignedDevice(kCPU).WithDevice(kCPU); + + Output in0 = ops::Placeholder(on_gpu.WithOpName("in0"), DT_FLOAT); + Output in1 = ops::Placeholder(on_gpu.WithOpName("in1"), DT_FLOAT); + + Output perm = 
ops::Const(on_cpu.WithOpName("perm"), {3, 1, 2, 0}); + + { + Output tr0 = ops::Transpose(on_gpu.WithOpName("tr0"), in0, perm); + Output tr1 = ops::Transpose(on_gpu.WithOpName("tr1"), in1, perm); + } + + Output in_place_add = + ops::InplaceAdd(on_cpu.WithOpName("tr0"), perm, + ops::Placeholder(on_cpu.WithOpName("i"), DT_INT32), perm); + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor tr0_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr0")->input_tensor(1, &tr0_perm)); + + OutputTensor tr1_perm; + TF_ASSERT_OK(FindNodeByName(result.get(), "tr1")->input_tensor(1, &tr1_perm)); + + EXPECT_EQ(tr0_perm.node, tr1_perm.node); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/compilation_passes_test_main.cc b/tensorflow/compiler/jit/compilation_passes_test_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..c73702fa6428de59262455db51ed3a3192ee2dc1 --- /dev/null +++ b/tensorflow/compiler/jit/compilation_passes_test_main.cc @@ -0,0 +1,58 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +GTEST_API_ int main(int real_argc, char** real_argv) { + std::vector flag_list; + tensorflow::AppendMarkForCompilationPassFlags(&flag_list); + auto usage = tensorflow::Flags::Usage(real_argv[0], flag_list); + + std::vector args; + + args.reserve(real_argc + 1); + for (int i = 0; i < real_argc; i++) { + args.push_back(real_argv[i]); + } + + struct FreeDeleter { + void operator()(char* ptr) { free(ptr); } + }; + + std::unique_ptr enable_global_jit_arg( + strdup("--tf_xla_cpu_global_jit=true")); + args.push_back(enable_global_jit_arg.get()); + + std::unique_ptr reduce_min_cluster_size_arg( + strdup("--tf_xla_min_cluster_size=2")); + args.push_back(reduce_min_cluster_size_arg.get()); + + int argc = args.size(); + + if (!tensorflow::Flags::Parse(&argc, &args.front(), flag_list)) { + LOG(ERROR) << "\n" << usage; + return 2; + } + + testing::InitGoogleTest(&argc, &args.front()); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 6f1ff85f24a4c1fd3e6d54fcff9f8868aee6f750..7021985affa494ed40c64825c2bd1d221db4e3bb 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -126,8 +126,9 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, const DataTypeVector& arg_types = (*fbody)->arg_types; std::vector const_args(arg_types.size()); // If we can't analyze the const args. Bail out. 
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis( - *((*fbody)->graph), &const_args, /*compile_time_const_nodes=*/nullptr)); + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(*((*fbody)->graph), &const_args, + /*compile_time_const_nodes=*/nullptr, flr)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { @@ -153,11 +154,14 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, std::unique_ptr* kernel) { TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def)); - VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString(); + VLOG(3) << "Attemping to create XlaLaunchOp for " << node_def.DebugString(); // Make sure that kernels have been registered on the JIT device. XlaOpRegistry::RegisterCompilationKernels(); if (!IsCompilable(flr, node_def)) { + VLOG(1) << "Not creating XlaLaunchOp because function invoked by the " + "following node is not compilable: " + << node_def.DebugString(); // node_def is calling a function that XLA can't compile. return errors::InvalidArgument("Not compilable: ", node_def.ShortDebugString()); diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0562838f628c66b1eb03af9d2a5139c01dca31c5..4856301cef4fd9426c04e1ff557e25ae37980575 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -20,7 +20,10 @@ limitations under the License. 
#include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/hash/hash.h" @@ -110,7 +113,11 @@ class Predicate { enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol }; virtual string ToString() const = 0; - int64 hash() const { return hash_; } + + // An ID assigned to the Predicate at construction time. Conceptually like a + // pointer, except that it is stable across runs. + int64 id() const { return id_; } + virtual absl::Span GetOperands() const = 0; virtual Kind kind() const = 0; @@ -123,29 +130,19 @@ class Predicate { static void Visit(Predicate* p, const FunctionTy& func); protected: - explicit Predicate(int64 hash) : hash_(hash) {} + explicit Predicate(int64 id) : id_(id) {} private: - const int64 hash_; + const int64 id_; TF_DISALLOW_COPY_AND_ASSIGN(Predicate); }; -int64 HashPredicateSequence(Predicate::Kind kind, - absl::Span preds) { - int64 hash = ::tensorflow::hash()(kind); - for (Predicate* pred : preds) { - hash = Hash64Combine(hash, pred->hash()); - } - return hash; -} - // Represents a logical conjunction of a set of predicates. class AndPredicate : public Predicate { public: - explicit AndPredicate(std::vector operands) - : Predicate(HashPredicateSequence(Kind::kAnd, operands)), - operands_(std::move(operands)) {} + explicit AndPredicate(int64 id, std::vector operands) + : Predicate(id), operands_(std::move(operands)) {} string ToString() const override { if (operands().empty()) { @@ -174,9 +171,8 @@ class AndPredicate : public Predicate { // Represents a logical disjunction of a set of predicates. 
class OrPredicate : public Predicate { public: - explicit OrPredicate(std::vector operands) - : Predicate(HashPredicateSequence(Kind::kOr, operands)), - operands_(std::move(operands)) {} + explicit OrPredicate(int64 id, std::vector operands) + : Predicate(id), operands_(std::move(operands)) {} string ToString() const override { if (operands().empty()) { @@ -204,9 +200,8 @@ class OrPredicate : public Predicate { // Represents a logical negation of a set of predicates. class NotPredicate : public Predicate { public: - explicit NotPredicate(Predicate* operand) - : Predicate(HashPredicateSequence(Kind::kNot, {operand})), - operands_({operand}) {} + explicit NotPredicate(int64 id, Predicate* operand) + : Predicate(id), operands_({operand}) {} string ToString() const override { return absl::StrCat("~", operand()->ToString()); @@ -222,29 +217,38 @@ class NotPredicate : public Predicate { std::array operands_; }; -// Represents an infinite list of predicates. +// Represents the liveness of an induction variable. For users inside the loop +// this represents the "current" liveness of the induction variable. For users +// outside the loop it represents the "last" liveness of the induction variable. +// +// More concretely, an and recurrence {S,&,X} represents the liveness of V +// in the following graph: // -// An AndRecurrence with start = S and step = X is printed as {S,&,X} and stands -// for the list of predicates: +// V = Merge(S', V_NextIt) +// V = Op(V, X') +// V_NextIt = NextIteration(V) // -// S, S & GenSym(X,1), S & GenSym(X,1) & GenSym(X,2), ... +// where Predicate(S') = S and Predicate(X') = X. // -// where GenSym(, ) renames every SymbolPredicate in -// by appending to it, in effect creating a "fresh" symbol. -// This means {P,&,Q} is not equal to "P on the first iteration; P&Q on -// subsequent iterations". +// `X` may contain symbolic predicates and the operations corresponding to these +// symbolic predicates are either in frame `loop` or outside it. 
The symbols +// that are inside frame `loop` are loop variant (i.e. can have different +// liveness in each loop iteration) and the symbols that are outside frame +// `loop` are loop invariant (i.e. have the same liveness across all +// iterations). class AndRecurrencePredicate : public Predicate { public: - explicit AndRecurrencePredicate(Predicate* start, Predicate* step) - : Predicate(HashPredicateSequence(Kind::kAndRecurrence, {start, step})), - operands_({start, step}) {} + explicit AndRecurrencePredicate(int64 id, Predicate* start, Predicate* step, + std::vector frame) + : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {} Predicate* start() const { return operands_[0]; } Predicate* step() const { return operands_[1]; } + absl::Span frame() const { return frame_; } string ToString() const override { return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), - "}"); + "}<", absl::StrJoin(frame(), ";"), ">"); } Kind kind() const override { return Kind::kAndRecurrence; } @@ -255,6 +259,7 @@ class AndRecurrencePredicate : public Predicate { private: std::array operands_; + std::vector frame_; }; // Represents an uninterpreted symbol in a logical predicate. @@ -264,8 +269,8 @@ class AndRecurrencePredicate : public Predicate { // symbols. class SymbolPredicate : public Predicate { public: - explicit SymbolPredicate(TensorId tensor_id, bool must_be_true) - : Predicate(Hash(tensor_id, must_be_true)), + explicit SymbolPredicate(int64 id, TensorId tensor_id, bool must_be_true) + : Predicate(id), tensor_id_(std::move(tensor_id)), must_be_true_(must_be_true) {} @@ -281,20 +286,13 @@ class SymbolPredicate : public Predicate { // "tensor_id() is live and evaluates to true". 
// // If `must_be_true()` is false then this SymbolPredicate represents the - // proposition "tensor_id() is live (and may evalutate to any value)" + // proposition "tensor_id() is live (and may evaluate to any value)" TensorId tensor_id() const { return tensor_id_; } bool must_be_true() const { return must_be_true_; } private: TensorId tensor_id_; bool must_be_true_; - - static int64 Hash(const TensorId tensor_id, bool must_be_true) { - return Hash64Combine( - ::tensorflow::hash()(must_be_true), - Hash64Combine(::tensorflow::hash()(Kind::kSymbol), - TensorId::Hasher{}(tensor_id))); - } }; template @@ -333,34 +331,58 @@ class PredicateFactory { } Predicate* MakeNotPredicate(Predicate* pred) { - SignatureForNot signature = pred; - auto it = interned_not_instances_.find(signature); - if (it == interned_not_instances_.end()) { - std::unique_ptr new_pred = Make(pred); - Predicate* new_pred_ptr = new_pred.get(); - interned_not_instances_.emplace(signature, std::move(new_pred)); - return new_pred_ptr; - } else { - return it->second.get(); + auto it = make_not_predicate_cache_.find(pred); + if (it != make_not_predicate_cache_.end()) { + return it->second; } + + Predicate* result = MakeNotPredicateImpl(pred); + + bool insert_successful = + make_not_predicate_cache_.insert({pred, result}).second; + (void)insert_successful; + DCHECK(insert_successful); + + return result; } - Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step) { - auto it = interned_and_rec_instances_.find({start, step}); + Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step, + std::vector frame) { + SignatureForAndRec signature(start, step, std::move(frame)); + auto it = interned_and_rec_instances_.find(signature); if (it != interned_and_rec_instances_.end()) { return it->second.get(); } - std::unique_ptr new_pred = - Make(start, step); + std::unique_ptr new_pred = Make( + std::get<0>(signature), std::get<1>(signature), std::get<2>(signature)); Predicate* new_pred_ptr 
= new_pred.get(); - CHECK(interned_and_rec_instances_ - .emplace(SignatureForAndRec(start, step), std::move(new_pred)) - .second); + bool inserted = + interned_and_rec_instances_.emplace(signature, std::move(new_pred)) + .second; + (void)inserted; + DCHECK(inserted); return new_pred_ptr; } - Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { + Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true, + Predicate** predicate) { + TensorId tensor_id(node->name(), output_idx); + + bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL; + TF_RET_CHECK(!must_be_true || is_boolean_tensor); + + if (node->type_string() == "Const" && must_be_true) { + const TensorProto* proto = nullptr; + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto)); + + Tensor tensor(proto->dtype()); + TF_RET_CHECK(tensor.FromProto(*proto)); + + *predicate = tensor.scalar()() ? MakeTrue() : MakeFalse(); + return Status::OK(); + } + SignatureForSymbol signature = {tensor_id, must_be_true}; auto it = interned_symbol_instances_.find(signature); if (it == interned_symbol_instances_.end()) { @@ -369,20 +391,70 @@ class PredicateFactory { Predicate* new_pred_ptr = new_pred.get(); interned_symbol_instances_.emplace(std::move(signature), std::move(new_pred)); - return new_pred_ptr; + *predicate = new_pred_ptr; } else { - return it->second.get(); + *predicate = it->second.get(); } + + return Status::OK(); } Predicate* MakeTrue() { return MakeAndPredicate({}); } Predicate* MakeFalse() { return MakeOrPredicate({}); } + ~PredicateFactory() { + DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?"; + } + private: + Predicate* MakeNotPredicateImpl(Predicate* pred) { + IncrementStackDepth stack_frame(this); + if (!stack_frame.HasOverflowed()) { + if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) { + return simplified; + } + + // ~~A => A + if (auto* not_pred = dynamic_cast(pred)) { + return not_pred->operand(); + } + } + + 
SignatureForNot signature = pred; + auto it = interned_not_instances_.find(signature); + if (it == interned_not_instances_.end()) { + std::unique_ptr new_pred = Make(pred); + Predicate* new_pred_ptr = new_pred.get(); + interned_not_instances_.emplace(signature, std::move(new_pred)); + return new_pred_ptr; + } else { + return it->second.get(); + } + } + + Predicate* SimplifyUsingDeMorgan(Predicate* pred) { + // ~(A & B & C & ...) => ~A | ~B | ~C | ~... + // ~(A | B | C | ...) -> ~A & ~B & ~C & ~... + Predicate::Kind kind = pred->kind(); + + if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) { + std::vector new_operands; + absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands), + [&](Predicate* p) { return MakeNotPredicate(p); }); + return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands) + : MakeOrPredicate(new_operands); + } + + return nullptr; + } + template std::unique_ptr Make(Args&&... args) { + // If we ever expose the Predicate class outside this .cc file then we may + // want to make this hard to misuse (by accidentally passing in an arbitrary + // integer to the Predicate constructor for instance). return std::unique_ptr( - new PredicateT(std::forward(args)...)); + new PredicateT(id_counter_++, std::forward(args)...)); } Predicate* MakeAndOrImpl(absl::Span operands, bool is_and); @@ -402,7 +474,8 @@ class PredicateFactory { using SignatureForAndOr = std::pair>; using SignatureForNot = Predicate*; - using SignatureForAndRec = std::pair; + using SignatureForAndRec = + std::tuple>; using SignatureForSymbol = std::pair; struct HashSignatureForAndOr { @@ -422,6 +495,36 @@ class PredicateFactory { } }; + // Used to limit recursion to avoid blowing up the stack and cap compile time. 
+ class IncrementStackDepth { + public: + explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) { + parent_->stack_depth_++; + } + + bool HasOverflowed() const { + const int kMaxStackDepth = 8; + return parent_->stack_depth_ >= kMaxStackDepth; + } + + ~IncrementStackDepth() { parent_->stack_depth_--; } + + private: + PredicateFactory* parent_; + }; + + // A cache for the MakeNotPredicate function. + // + // NB! This is *not* the same as `interned_not_instances_`. + // `interned_not_instances_` maps ensures pointer identity for `NotPredicate` + // instances, i.e., it ensures there at most one instance of Not(predicate) + // for any given predicate whereas `make_not_predicate_cache_` simply caches + // the result of the `MakeNotPredicate` function. The values in + // `interned_not_instances_` are always instance of `NotPredicate` whereas the + // values in `make_not_predicate_cache_` may not be (for instance it will map + // Not(Not(A)) to A). + absl::flat_hash_map make_not_predicate_cache_; + absl::flat_hash_map, HashSignatureForAndOr> interned_and_or_instances_; @@ -432,13 +535,15 @@ class PredicateFactory { absl::flat_hash_map, HashSignatureForSymbol> interned_symbol_instances_; + int64 id_counter_ = 0; + int stack_depth_ = 0; }; Predicate* PredicateFactory::MakeInternedAndOr( std::vector simplified_ops, Predicate::Kind pred_kind) { std::stable_sort( simplified_ops.begin(), simplified_ops.end(), - [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + [](Predicate* a, Predicate* b) { return a->id() < b->id(); }); auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); if (it != interned_and_or_instances_.end()) { @@ -466,6 +571,13 @@ Predicate* PredicateFactory::MakeAndOrImpl( absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? 
Predicate::Kind::kAnd : Predicate::Kind::kOr; + + IncrementStackDepth stack_frame(this); + if (stack_frame.HasOverflowed()) { + return MakeInternedAndOr( + std::vector(operands.begin(), operands.end()), pred_kind); + } + Predicate::Kind other_pred_kind = is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd; absl::flat_hash_set simplified_ops_set; @@ -494,16 +606,31 @@ Predicate* PredicateFactory::MakeAndOrImpl( // Simplify "A&~A=>False" and "A|~A=>True". absl::flat_hash_set negated_ops; - for (Predicate* op : simplified_ops) { - if (op->kind() == Predicate::Kind::kNot) { - negated_ops.insert(dynamic_cast(*op).operand()); - } - } - for (Predicate* op : simplified_ops) { if (negated_ops.count(op)) { + // Simple case: + // + // A & ~A & ... == False + // A | ~A | ... == True return is_and ? MakeFalse() : MakeTrue(); } + + Predicate* negated_op = MakeNotPredicate(op); + if (negated_op->kind() == pred_kind) { + // Slightly more complicated case: + // + // (~A | ~B | ~C) & A & B & C & ... == + // ~(A & B & C) & (A & B & C) & ... == False + // + // (~A & ~B & ~C) | A | B | C | ... == + // ~(A | B | C) | (A | B | C) | ... == True + if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) { + return simplified_ops_set.contains(p); + })) { + return is_and ? 
MakeFalse() : MakeTrue(); + } + } + negated_ops.insert(negated_op); } // If all ops contain the same subop, then factor it out thanks to the @@ -619,6 +746,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { const Graph& graph_; absl::flat_hash_map predicate_map_; PredicateFactory predicate_factory_; + std::vector control_flow_info_; bool vlog_; }; @@ -640,7 +768,8 @@ Status DeadnessAnalysisImpl::GetInputPreds( auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); if (it == predicate_map_.end()) { GraphCycles graph_cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); + TF_RETURN_IF_ERROR( + CreateCycleDetectionGraph(&graph_, &graph_cycles).status()); // If we didn't return with an error above then the graph is probably // fine and we have a bug in deadness analysis. @@ -661,9 +790,12 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); - Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( - TensorId(pred_edge->src()->name(), pred_edge->src_output()), - /*must_be_true=*/true); + + Predicate* true_switch; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + pred_edge->src(), pred_edge->src_output(), + /*must_be_true=*/true, &true_switch)); + Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch); // Output 0 is alive iff all inputs are alive and the condition is false. @@ -761,6 +893,23 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, return found_sym ? 
predicate_factory->MakeAndPredicate(and_ops) : nullptr; } + +Status GetFullFrame(const Node* n, absl::Span cfi_infos, + std::vector* frame) { + int depth = 0; + for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource(); + n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) { + frame->push_back(cfi_iter->frame_name); + + if (depth++ > 5000) { + return errors::Internal( + "Frame of depth > 5000: Probably malformed graph or a bug in " + "BuildControlFlowInfo"); + } + } + + return Status::OK(); +} } // namespace Status DeadnessAnalysisImpl::HandleMerge(Node* n, @@ -783,8 +932,10 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, if (has_unvisited_backedge) { // We're visiting this merge for the first time and it has an unvisited // backedge. - Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false); + Predicate* input_data_pred; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); return Status::OK(); @@ -825,8 +976,10 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, Predicate* start = predicate_factory_.MakeOrPredicate(non_recurrent_inputs); - Predicate* and_rec = - predicate_factory_.MakeAndRecurrencePredicate(start, step); + std::vector frame; + TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame)); + Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate( + start, step, std::move(frame)); SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); return Status::OK(); } @@ -841,8 +994,10 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, // acquire a dead signal from a _Send. 
std::vector input_preds; TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); - input_preds.push_back(predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false)); + Predicate* signal_is_alive; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive)); + input_preds.push_back(signal_is_alive); SetPredicate(n, {0, Graph::kControlSlot}, predicate_factory_.MakeAndPredicate(input_preds), should_revisit); @@ -892,6 +1047,24 @@ Status DeadnessAnalysisImpl::Populate() { Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( absl::Span rpo) { + std::vector unreachable_nodes; + // Compute the loop structure of the graph. + TF_RETURN_IF_ERROR( + BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes)); + + // Do some opportunistic error checking: + if (!unreachable_nodes.empty()) { + if (unreachable_nodes.size() > 5) { + unreachable_nodes.erase(unreachable_nodes.begin() + 5, + unreachable_nodes.end()); + } + + return errors::InvalidArgument( + "Found unreachable nodes, most likely source and sink nodes not " + "connected: ", + absl::StrJoin(unreachable_nodes, ", ")); + } + // This an abstract interpretation over the deadness propagation semantics of // the graph executor. 
// diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 8a73101c184e6190921fd7729742922bd96f4bcf..38a5118d9a721b814e1b52ce4202d4fb783e3ac3 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -123,10 +123,9 @@ InductionVarInfo CreateInductionVariable(const Scope& root, Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); Output loop_cond_expr = - ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value); - Output loop_cond = - ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); - ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, + loop_cond_expr); ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), latch.output_false); Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), @@ -140,7 +139,7 @@ InductionVarInfo CreateInductionVariable(const Scope& root, root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); root.graph()->AddControlEdge(iv.output.node(), final_value.node()); - return {iv.output, loop_cond}; + return {iv.output, loop_cond_expr}; } InductionVarInfo CreateInductionVariable(const Scope& root, @@ -515,24 +514,27 @@ TEST(DeadnessAnalysisTest, Loop) { // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0 // produce the same deadness. But we're not that smart today. 
- EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], "{#true,&,*iv0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], "{#true,&,*iv1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], "{#true,&,*iv2/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], + "{#true,&,*iv1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], + "{#true,&,*iv2/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); + "({#true,&,*iv0/cond:0} & {#true,&,*iv1/cond:0})"); EXPECT_EQ(predicate_map[ControlOutputFor(add1)], - "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); + "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); } } TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); - InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0); Output dependent_iv0 = - CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0) .induction_var; Output dependent_iv1 = - CreateDependentLoopInvariantValue(root, "div1", "frame", iv.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0) .induction_var; Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1); @@ -549,13 +551,13 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], - "{#true,&,*iv0/cond:0}"); + "{#true,&,*iv0/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(iv0/iv:0 & *iv0/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + 
"{#true,&,(iv0/iv:0 & *iv0/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(iv0/iv:0 & *iv0/cond:0)}"); } } @@ -595,32 +597,33 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); InductionVarInfo iv_outer = - CreateInductionVariable(root, "iv_outer", "frame", 0); + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); ops::Switch inner_value(root.WithOpName("outer_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer.loop_cond); + enter_constant_outer_loop, iv_outer.loop_cond); InductionVarInfo iv_inner = CreateInductionVariable( - root, "iv_inner", "frame", - ops::internal::Enter(root.WithOpName("iv_inner/enter"), - inner_value.output_true, "frame_inner")); + root, "iv_inner", "inner_loop", inner_value.output_true); Output dependent_outer_iv0 = - CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", "frame", - iv_outer.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", + "outer_loop", iv_outer.loop_cond, 0) .induction_var; Output dependent_outer_iv1 = - CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", "frame", - iv_outer.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", + "outer_loop", iv_outer.loop_cond, 0) .induction_var; - Output dependent_inner_iv0 = - CreateDependentLoopInvariantValue(root, "dependent_inner_iv0", "frame", - iv_inner.loop_cond, dependent_outer_iv0) - .induction_var; - Output dependent_inner_iv1 = - CreateDependentLoopInvariantValue(root, "dependent_inner_iv1", "frame", - iv_inner.loop_cond, dependent_outer_iv1) - .induction_var; + Output 
dependent_inner_iv0 = CreateDependentLoopInvariantValue( + root, "dependent_inner_iv0", "inner_loop", + iv_inner.loop_cond, dependent_outer_iv0) + .induction_var; + Output dependent_inner_iv1 = CreateDependentLoopInvariantValue( + root, "dependent_inner_iv1", "inner_loop", + iv_inner.loop_cond, dependent_outer_iv1) + .induction_var; Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0, dependent_inner_iv1); @@ -638,46 +641,51 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], - "{#true,&,*iv_outer/cond:0}"); + "{#true,&,*iv_outer/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], - "{(*iv_outer/cond:0 & {#true,&,*iv_outer/cond:0}),&," - "*iv_inner/cond:0}"); + "{(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " + "*iv_inner/cond:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " + "*iv_inner/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " + "*iv_inner/cond:0)}"); } } TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); - InductionVarInfo iv_outer_0 = - CreateInductionVariable(root, "iv_outer_0", "frame", 0); - ops::Switch inner_value_0(root.WithOpName("outer_0_is_live"), - 
ops::Const(root.WithOpName("constant"), 5), - iv_outer_0.loop_cond); - InductionVarInfo iv_inner_0 = CreateInductionVariable( - root, "iv_inner_0", "frame", - ops::internal::Enter(root.WithOpName("iv_inner_0/enter"), - inner_value_0.output_true, "frame_inner")); - - InductionVarInfo iv_outer_1 = - CreateInductionVariable(root, "iv_outer_1", "frame", 1); - ops::Switch inner_init_value_1(root.WithOpName("outer_1_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer_1.loop_cond); - InductionVarInfo iv_inner_1 = CreateInductionVariable( - root, "iv_inner_1", "frame", - ops::internal::Enter(root.WithOpName("iv_inner_1/enter"), - inner_init_value_1.output_true, "frame_inner")); - Output add0 = ops::Add(root.WithOpName("add0"), iv_inner_0.induction_var, - iv_inner_1.induction_var); + + std::array outer_iv; + std::array inner_iv; + + for (int i : {0, 1}) { + InductionVarInfo iv_outer = + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + ops::Switch inner_value(root.WithOpName("outer_is_live"), + enter_constant_outer_loop, iv_outer.loop_cond); + InductionVarInfo iv_inner = CreateInductionVariable( + root, "iv_inner", "inner_loop", inner_value.output_true); + + outer_iv[i] = iv_outer.induction_var; + inner_iv[i] = iv_inner.induction_var; + } + + Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]); VLogGraphIfAsked(*root.graph()); @@ -692,21 +700,77 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_0.induction_var)], - "{#true,&,*iv_outer_0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_0.induction_var)], - "{(*iv_outer_0/cond:0 & 
{#true,&,*iv_outer_0/cond:0}),&," - "*iv_inner_0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_1.induction_var)], - "{#true,&,*iv_outer_1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_1.induction_var)], - "{(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," - "*iv_inner_1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])], + "{#true,&,*iv_outer/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])], + "{(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])], + "{#true,&,*iv_outer/cond_1:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])], + "{(*iv_outer/cond_1:0 & " + "{#true,&,*iv_outer/cond_1:0}),&,*iv_inner/" + "cond_1:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "({(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," - "*iv_inner_1/cond:0} & " - "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," - "*iv_inner_0/cond:0})"); + "({(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0} & {(*iv_outer/cond_1:0 & " + "{#true,&,*iv_outer/cond_1:0}),&,*iv_inner/" + "cond_1:0})"); + } +} + +TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10); + InductionVarInfo iv_1 = CreateInductionVariable(root, "iv_1", "frame_1", 9); + + Output init = CreateSwitch(root, "init").output_true; + Output step = CreateSwitch(root, "step").output_true; + + std::array exits; + std::array next_iterations; + + for (int i : {0, 1}) { + Output init_enter = ops::internal::Enter( + root.WithOpName(absl::StrCat("init_enter_frame_", i)), init, + absl::StrCat("frame_", i), + ops::internal::Enter::Attrs().IsConstant(true)); + Output step_enter = ops::internal::Enter( + root.WithOpName(absl::StrCat("step_enter_frame_", i)), step, + absl::StrCat("frame_", i), + 
ops::internal::Enter::Attrs().IsConstant(true)); + + ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)), + {init_enter, init_enter}); + Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output, + step_enter); + next_iterations[i] = ops::NextIteration( + root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add); + EXPECT_TRUE( + root.graph() + ->UpdateEdge(next_iterations[i].node(), 0, iv.output.node(), 1) + .ok()); + exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)), + iv.output); + } + + FixupSourceAndSinkEdges(root.graph()); + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], + predicate_map[ControlOutputFor(exits[1])]); + EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], ""); + EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], ""); + + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], + predicate_map[ControlOutputFor(next_iterations[1])]); + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], ""); + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], ""); } } @@ -818,5 +882,82 @@ TEST(DeadnessAnalysisTest, RecvVsSwitchText) { EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)"); } +TEST(DeadnessAnalysisTest, DeMorgan) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL); + Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + + ops::Switch sw_0(root.WithOpName("switch_0"), value, cond_0); + ops::Switch sw_1(root.WithOpName("switch_1"), value, cond_1); + + Output and_0_1 = + ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true); + + Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"), + {sw_0.output_false, sw_1.output_false}) + .output; + + // 
Predicate(should_always_be_dead) = + // (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False + Output should_always_be_dead = + ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1); + + // Predicate(should_always_be_dead) = + // (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True + Output should_always_be_alive = + ops::Merge(root.WithOpName("should_always_be_alive"), + {and_0_1, or_not0_not1}) + .output; + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false"); + EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true"); +} + +TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output constant_true = ops::Const(root.WithOpName("const_true"), true); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), value, constant_true); + + Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false); + Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true); + + FixupSourceAndSinkEdges(root.graph()); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#false"); + EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#true"); +} + +TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output constant_false = ops::Const(root.WithOpName("const_false"), false); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), value, constant_false); + + Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false); + Output id_true = 
ops::Identity(root.WithOpName("id_true"), sw.output_true); + + FixupSourceAndSinkEdges(root.graph()); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#true"); + EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index a21e083131de9b1ba20872338311768569b982f3..c3a0b6521da2355f0b25eeac08e7fac999442438 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/shape_refiner.h" @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -108,14 +109,14 @@ void MarkGuaranteedConstants( for (const auto& src_arg : src_arg_pairs) { srcs.push_back(src_arg.first); } - ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, - /*leave=*/[&guaranteed_const_nodes](const Node* n) { - // TODO(vinuraja): Doesn't work in the presence of loops. 
- if (AreAllParentsGuaranteedConst(*n, - guaranteed_const_nodes)) { - guaranteed_const_nodes.insert(n); - } - }); + ReverseDFSFrom( + graph, srcs, /*enter=*/nullptr, + /*leave=*/[&guaranteed_const_nodes](const Node* n) { + // TODO(vinuraja): Doesn't work in the presence of loops. + if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) { + guaranteed_const_nodes.insert(n); + } + }); for (auto& src_arg : src_arg_pairs) { if (guaranteed_const_nodes.count(src_arg.first) != 0) { @@ -1008,13 +1009,15 @@ Status Encapsulator::Subgraph::AddHostComputes( // subgraph. for (const auto& src_node : oc_subgraph.control_inputs) { Node* src_image = node_images.at(src_node); - graph_->AddControlEdge(src_image, host_compute); + graph_->AddControlEdge(src_image, host_compute, + /* allow_duplicates= */ true); } // Connect the _HostCompute node to its ancestor host compute nodes. for (const auto& ancestor_name : host_compute_ancestors) { Node* ancestor = host_compute_node[ancestor_name]; - graph_->AddControlEdge(ancestor, host_compute); + graph_->AddControlEdge(ancestor, host_compute, + /* allow_duplicates= */ true); } // Connect the consumers in the subgraph to the _HostCompute node. @@ -1031,7 +1034,8 @@ Status Encapsulator::Subgraph::AddHostComputes( // node. 
for (const auto& dst_node : oc_subgraph.control_outputs) { Node* dst_image = node_images.at(dst_node); - graph_->AddControlEdge(host_compute, dst_image); + graph_->AddControlEdge(host_compute, dst_image, + /* allow_duplicates= */ true); } } } @@ -1059,7 +1063,8 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { VLOG(2) << "ConnectSequencerToCallNode"; - graph_out->AddControlEdge(sequencer_, call_node_); + graph_out->AddControlEdge(sequencer_, call_node_, + /* allow_duplicates= */ true); } } @@ -1120,10 +1125,9 @@ Status Encapsulator::Subgraph::BuildFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), - *graph_, library); - dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), - fdef); + DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_, + library); + DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef); } const FunctionDef* original_fdef = library->Find(name); @@ -1186,11 +1190,10 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Replace function def " << name; - dump_graph::DumpGraphToFile( - absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, - library); - dump_graph::DumpFunctionDefToFile( - absl::StrCat("replace_encapsulate_fdef_", name), fdef); + DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name), + *graph_, library); + DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name), + fdef); } TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); @@ -1279,7 +1282,8 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( // completes. This has no effect on execution order but prevents the // RecvAtHost being pruned. 
TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_); + graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_, + true /* skip duplicates check */); return Status::OK(); } @@ -1336,7 +1340,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( // subgraph completes. This has no effect on execution order but prevents the // RecvAtHost being pruned. TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_); + graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_, + /* allow_duplicates= */ true); return Status::OK(); } @@ -1446,7 +1451,8 @@ Status Encapsulator::CopySubgraphEdges( src_func_id == dst_func_id) { Graph* g = subgraphs_[src_func_id].GetGraph(); if (edge->IsControlEdge()) { - g->AddControlEdge(src_image, dst_image); + g->AddControlEdge(src_image, dst_image, + /* allow_duplicates= */ true); } else { g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); } @@ -1549,7 +1555,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { if (VLOG_IS_ON(1)) { // Dump subgraphs. 
for (auto& entry : subgraphs_) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first), *entry.second.GetGraph(), library); } @@ -1732,7 +1738,8 @@ Status Encapsulator::CopyEdgeToOutputGraph( if (edges_added ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1)) .second) { - graph_out->AddControlEdge(src_image, dst_image); + graph_out->AddControlEdge(src_image, dst_image, + /* allow_duplicates= */ true); } return Status::OK(); @@ -1761,7 +1768,8 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { const string& subgraph = ancestors.first; for (const string& ancestor : ancestors.second) { graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), - subgraphs_[subgraph].GetCallNode()); + subgraphs_[subgraph].GetCallNode(), + /* allow_duplicates= */ true); } } return Status::OK(); @@ -2129,7 +2137,8 @@ Status CheckClusterDependencyForCycles( const string& ancestor, const string& successor, const std::unordered_map>& ancestors, const std::unordered_map& node_ancestors_map, - GraphCycles* cycle_detector, std::map* cycle_detector_map) { + GraphCycles* cycle_detector, + std::unordered_map* cycle_detector_map) { if (cycle_detector_map->find(ancestor) == cycle_detector_map->end()) { (*cycle_detector_map)[ancestor] = cycle_detector->NewNode(); } @@ -2173,7 +2182,7 @@ Status Encapsulator::FindClusterDependencies() { // We check that clusters are acyclic using this cycle detector. GraphCycles cycle_detector; // Map from cluster name to cycle detector node id. - std::map cycle_detector_map; + std::unordered_map cycle_detector_map; // Process the nodes in topologically-sorted order. 
std::vector nodes; GetReversePostOrder(*graph_in_, &nodes); @@ -2311,13 +2320,18 @@ Status Encapsulator::MakePrunedGraphCopyAndInline( " in function library."); } FunctionBody* fbody = nullptr; - TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*fdef, node->attrs(), library, - [library](const string& op, const OpDef** sig) { - return library->LookUpOpDef(op, sig); - }, - &fbody)); - InlineFunctionBody(*library, pruned_graph->get(), node, fbody); + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fdef, node->attrs(), library, + [library](const string& op, const OpDef** sig) { + return library->LookUpOpDef(op, sig); + }, + &fbody)); + + InlineFunctionBodyOptions inline_opts; + inline_opts.override_device = false; + + TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node, + fbody, inline_opts)); delete fbody; } @@ -2383,8 +2397,7 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( &node_images, library)); if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference", - *pruned_graph, library); + DumpGraphToFile("pruned_graph_for_shape_inference", *pruned_graph, library); } for (auto& subgraph_entry : subgraphs_) { @@ -2515,19 +2528,49 @@ Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("encapsulate_subgraphs_before", **options.graph, - options.flib_def); + DumpGraphToFile("encapsulate_subgraphs_before", **options.graph, + options.flib_def); } std::unique_ptr graph_out; FunctionLibraryDefinition* const library = options.flib_def; + // Constant folding below might need to run part of the function to compute + // constants. Create an FunctionLibraryRuntime with a single CPU device + // that can run the part of the function. + // NOTE: If this turns out to be slow, we can cache the FLRs keyed by + // `options`. 
+ SessionOptions session_options; + auto* device_count = session_options.config.mutable_device_count(); + device_count->insert({"CPU", 1}); + std::vector> devices; + + DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU"); + if (!cpu_factory) { + return errors::NotFound( + "CPU Factory not registered. Can't run EncapsulateSubgraphsPass"); + } + TF_RETURN_IF_ERROR(cpu_factory->CreateDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + if (devices.empty()) { + return errors::NotFound( + "Failed to create a CPU device for EncapsulateSubgraphsPass"); + } + + std::unique_ptr device_mgr = + absl::make_unique(std::move(devices)); OptimizerOptions opts; std::unique_ptr pflr( - new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env, + new ProcessFunctionLibraryRuntime(device_mgr.get(), + options.session_options->env, TF_GRAPH_DEF_VERSION, library, opts)); FunctionLibraryRuntime* flr = - pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0"); + if (flr == nullptr) { + return errors::Internal( + "Failed to create and retrieve function library runtime to run " + "constant folding"); + } auto rewrite_subgraph = [flr](const std::vector& arg_source_tensors, @@ -2565,8 +2608,9 @@ Status EncapsulateSubgraphsPass::Run( const int num_args = input_permutation->size(); std::vector const_args(num_args); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis( - **subgraph, &const_args, /*compile_time_const_nodes=*/nullptr)); + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(**subgraph, &const_args, + /*compile_time_const_nodes=*/nullptr, flr)); DataTypeVector arg_types(num_args); TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); @@ -2625,8 +2669,8 @@ Status EncapsulateSubgraphsPass::Run( "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("encapsulate_subgraphs_after", *graph_out, - options.flib_def); + 
DumpGraphToFile("encapsulate_subgraphs_after", *graph_out, + options.flib_def); } *options.graph = std::move(graph_out); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 1f8ec09e19c01d0a8b2a3761135ed53dfb2ad3b0..261519de3478c8b3e30d206a15944b5a686598e2 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -307,22 +307,6 @@ REGISTER_OP("XlaHostCompute") .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); -REGISTER_OP("_XlaSendFromHost") - .Input("inputs: Tinputs") - .Input("dynamic_key: string") - .Attr("Tinputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - -REGISTER_OP("_XlaRecvAtHost") - .Input("dynamic_key: string") - .Output("outputs: Toutputs") - .Attr("Toutputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - REGISTER_OP("InputTest") .Output("o: float") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 3bb979e0698d2d6be42ed5bae66c25267928192c..6d1661222e3eaf9df4f9f91f2b426c80b55245b2 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index ec745cdbb7e237f8b4935dd41e9791fc75f5355d..4e65971191aab69b5f4df780dfed939613d38c0f 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -15,17 +15,21 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -36,6 +40,25 @@ namespace { const char* const kXlaClusterOutput = "XlaClusterOutput"; +bool IsCpuGpuCompile(const Graph* graph) { + for (Node* n : graph->nodes()) { + string name; + // Only consider nodes being compiled. + if (!GetNodeAttr(n->attrs(), + EncapsulateXlaComputationsPass::kXlaClusterAttr, &name) + .ok()) + continue; + // Early return for any node with a device that is not a CPU or GPU. 
+ DeviceNameUtils::ParsedName parsed; + if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) { + if (parsed.type != DEVICE_CPU && parsed.type != DEVICE_GPU) { + return false; + } + } + } + return true; +} + // Checks if a graph node is marked to be a guaranteed constant. bool is_guaranteed_constant(const Node& n) { bool guaranteed_constant = false; @@ -173,10 +196,11 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // Nondeterminism in serialization would not lead to incorrect results, but // may cause spurious cache misses. DeterministicSerialization is a // best-effort deterministic serialization. - string serialized; - TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); - uint64 fingerprint = Fingerprint64(serialized); - LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + const size_t size = gdef.ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK(SerializeToBufferDeterministic(gdef, serialized.get(), size)); + uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size)); + VLOG(1) << "Subgraph fingerprint:" << fingerprint; call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); return Status::OK(); } @@ -348,18 +372,25 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, Status EncapsulateXlaComputationsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateXlaComputations(): " - << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", - **options.graph, options.flib_def); + << DumpGraphToFile("encapsulate_xla_computations_before", + **options.graph, options.flib_def); + + const char* additional_help = + IsCpuGpuCompile(options.graph->get()) + ? 
xla::status_macros::kPossibleAutoJitAlternative + : ""; - TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(Encapsulate(options.graph, options.flib_def), + additional_help); VLOG(1) << "EncapsulateXlaComputations() half-way: " - << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", - **options.graph, options.flib_def); + << DumpGraphToFile("encapsulate_xla_computations_halfway", + **options.graph, options.flib_def); - TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); + TF_RETURN_WITH_CONTEXT_IF_ERROR(BuildXlaLaunchOps(options.graph->get()), + additional_help); VLOG(1) << "EncapsulateXlaComputations() finished: " - << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", - **options.graph, options.flib_def); + << DumpGraphToFile("encapsulate_xla_computations_after", + **options.graph, options.flib_def); return Status::OK(); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 2a770c527b2fae91352fd17dacb13495a3a73f34..4d383698d3a33e391c713a855b3f1521e9aab844 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_util.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" @@ -31,6 +30,7 @@ limitations under the License. 
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -581,10 +581,9 @@ Status ConstructHostGraph( &host_graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("extract_outside_compilation_host_graph_for_", - xla_cluster_name), - host_graph, fld); + DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_", + xla_cluster_name), + host_graph, fld); } FunctionDef host_graph_fdef; @@ -789,7 +788,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, std::unordered_set{send_from_host}); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile(shape_inference_graph_name, *g, fld); + DumpGraphToFile(shape_inference_graph_name, *g, fld); } // Replace original shape inference graph. @@ -1620,7 +1619,7 @@ Status ExtractOutsideCompilationForFunction( TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( fbody->graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("extract_outside_compilation_for_func_before_", func_name), *fbody->graph, fld); } @@ -1705,7 +1704,7 @@ Status ExtractOutsideCompilationForFunction( TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef)); } if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("extract_outside_compilation_for_func_after_", func_name), *graph_out, fld); } @@ -1719,7 +1718,7 @@ Status ExtractOutsideCompilation( const std::unordered_map& clusters, Graph* g, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) { if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("extract_outside_compilation_before", *g, fld); + DumpGraphToFile("extract_outside_compilation_before", *g, fld); } std::vector shape_inference_graphs; @@ -1747,7 +1746,7 @@ Status ExtractOutsideCompilation( } if 
(VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("extract_outside_compilation_after", *g, fld); + DumpGraphToFile("extract_outside_compilation_after", *g, fld); } return Status::OK(); } diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 98e344b3a080aa8aab27cd41564a90427bac151e..7fcf2b42e4315de853a4116968ae1a24647f512a 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -23,7 +23,6 @@ namespace tensorflow { namespace { BuildXlaOpsPassFlags* build_ops_flags; -DumpGraphFlags* dump_graph_flags; MarkForCompilationPassFlags* mark_for_compilation_flags; XlaDeviceFlags* device_flags; XlaOpsCommonFlags* ops_flags; @@ -31,15 +30,6 @@ XlaOpsCommonFlags* ops_flags; std::vector* flag_list; std::once_flag flags_init; -void AppendDumpGraphFlagsInternal(std::vector* flag_list) { - std::vector new_flags = { - Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix, - "Path prefix to which graphs dumped during debugging should be " - "written."), - }; - flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); -} - void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { std::vector new_flags = { Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit, @@ -68,7 +58,12 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { Flag("tf_xla_fusion_only", &mark_for_compilation_flags->tf_xla_fusion_only, "enable fusion of element-wise operations only using XLA when " - "global_jit_level is ON*.")}; + "global_jit_level is ON*."), + Flag("tf_xla_disable_deadness_safety_checks_for_debugging", + &mark_for_compilation_flags + ->tf_xla_disable_deadness_safety_checks_for_debugging, + "Disable deadness related safety checks when clustering (this is " + "unsound).")}; flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); } @@ -76,12 +71,9 @@ void AllocateAndParseFlags() { build_ops_flags = new BuildXlaOpsPassFlags; 
build_ops_flags->tf_xla_enable_lazy_compilation = true; - dump_graph_flags = new DumpGraphFlags; - dump_graph_flags->tf_dump_graph_prefix = "/tmp/"; - mark_for_compilation_flags = new MarkForCompilationPassFlags; mark_for_compilation_flags->tf_xla_auto_jit = 0; - mark_for_compilation_flags->tf_xla_min_cluster_size = 2; + mark_for_compilation_flags->tf_xla_min_cluster_size = 4; mark_for_compilation_flags->tf_xla_max_cluster_size = std::numeric_limits::max(); mark_for_compilation_flags->tf_xla_clustering_debug = false; @@ -89,6 +81,8 @@ void AllocateAndParseFlags() { mark_for_compilation_flags->tf_xla_clustering_fuel = std::numeric_limits::max(); mark_for_compilation_flags->tf_xla_fusion_only = false; + mark_for_compilation_flags + ->tf_xla_disable_deadness_safety_checks_for_debugging = false; device_flags = new XlaDeviceFlags; device_flags->tf_xla_compile_on_demand = false; @@ -107,7 +101,6 @@ void AllocateAndParseFlags() { Flag("tf_xla_always_defer_compilation", &ops_flags->tf_xla_always_defer_compilation, ""), }); - AppendDumpGraphFlagsInternal(flag_list); AppendMarkForCompilationPassFlagsInternal(flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); } @@ -119,11 +112,6 @@ const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { return *build_ops_flags; } -DumpGraphFlags* GetDumpGraphFlags() { - std::call_once(flags_init, &AllocateAndParseFlags); - return dump_graph_flags; -} - MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { std::call_once(flags_init, &AllocateAndParseFlags); return mark_for_compilation_flags; @@ -144,9 +132,4 @@ void AppendMarkForCompilationPassFlags(std::vector* flag_list) { AppendMarkForCompilationPassFlagsInternal(flag_list); } -void AppendDumpGraphFlags(std::vector* flag_list) { - std::call_once(flags_init, &AllocateAndParseFlags); - AppendDumpGraphFlagsInternal(flag_list); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 
5ddea588eef5270880d91623dc05893da265960a..f87edcc1a33af081e74e65ed551c84a4fbc1163a 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -25,27 +25,39 @@ namespace tensorflow { // Flags associated with the XLA bridge's mark_for_compilation_pass module. struct MarkForCompilationPassFlags { - int32 tf_xla_auto_jit; // Control compilation of operators into XLA - // computations on CPU and GPU devices. 0 = use - // ConfigProto setting; -1 = off; 1 = on for things - // very likely to be improved; 2 = on for everything. - // Experimental. - int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA - // compilation. Ignored for operators placed - // on an XLA device or operators explicitly - // marked for compilation. - int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA - // compilation. - bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. - bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU - // via SessionOptions. - int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this - // many ops will be marked as eligible for - // clustering. - bool tf_xla_fusion_only; // This flag is effective only when global_jit_level - // is set to ON* and overrides its behavior. If - // true, enable fusion of element-wise operations - // only using XLA. + // Control compilation of operators into XLA computations on CPU and GPU + // devices. 0 = use ConfigProto setting; -1 = off; 1 = on for things very + // likely to be improved; 2 = on for everything. + // + // Experimental. + int32 tf_xla_auto_jit; + + // Minimum number of operators in an XLA compilation. Ignored for operators + // placed on an XLA device or operators explicitly marked for compilation. + int32 tf_xla_min_cluster_size; + + // Maximum number of operators in an XLA compilation. + int32 tf_xla_max_cluster_size; + + // Dump graphs during XLA compilation. 
+ bool tf_xla_clustering_debug; + + // Enables global JIT compilation for CPU via SessionOptions. + bool tf_xla_cpu_global_jit; + + // "Compiler fuel" for clustering. Only this many ops will be marked as + // eligible for clustering. + int64 tf_xla_clustering_fuel; + + // tf_xla_fusion_only is effective only when global_jit_level is set to ON* + // and overrides its behavior. If true, enable fusion of element-wise + // operations only using XLA. + bool tf_xla_fusion_only; + + // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then + // we do not do deadness related safety checks. This is unsound in general, + // but can be used as a debugging aid. + bool tf_xla_disable_deadness_safety_checks_for_debugging; }; // Flags associated with the XLA bridge's xla_device module. @@ -71,12 +83,6 @@ struct BuildXlaOpsPassFlags { bool tf_xla_enable_lazy_compilation; }; -// Flags for the XLA bridge's dump_graph module. -struct DumpGraphFlags { - // Path prefix to which graphs dumped during debugging should be written. - string tf_dump_graph_prefix; -}; - // Return a pointer to the DumpGraphFlags struct; // repeated calls return the same pointer. // This should be called only after Flags::Parse() has returned. @@ -88,7 +94,6 @@ MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); XlaDeviceFlags* GetXlaDeviceFlags(); const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); -DumpGraphFlags* GetDumpGraphFlags(); // Appends the flag definitions associated with // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. @@ -96,8 +101,6 @@ DumpGraphFlags* GetDumpGraphFlags(); // Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. 
void AppendMarkForCompilationPassFlags( std::vector* flag_list); -void AppendDumpGraphFlags(std::vector* flag_list); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index ce53f70b79d97ab087fefe542920b33f883632a2..23931a0d7cd4e6cb1f8ba99869db6a3b25b49b97 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" +#include #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" @@ -26,12 +27,12 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { @@ -144,7 +145,9 @@ SliceInputs MakeSliceIndexAndSizeInt64(const Scope& host_scope, // same constant value. This helps make the generated GraphDef more readable. 
class ConstantCache { public: - explicit ConstantCache(const Scope& s) : scope_(s) {} + explicit ConstantCache(const Scope& s, + const std::vector& control_deps) + : scope_(s), control_deps_(control_deps) {} Output Get1DHostConstant(int64 constant) { auto it = cache_.find(constant); @@ -152,6 +155,9 @@ class ConstantCache { Output new_const = ops::Const(scope_.WithOpName("const_", constant), {constant}); it = cache_.insert({constant, new_const}).first; + for (const Edge* e : control_deps_) { + scope_.graph()->AddControlEdge(e->src(), new_const.node()); + } } return it->second; } @@ -159,11 +165,13 @@ class ConstantCache { private: Scope scope_; std::unordered_map cache_; + std::vector control_deps_; }; // Returns a node computing the size of the Slice op with inputs `slice_inputs`. Status ComputeSliceSize(const Scope& host_scope, - const SliceInputs& slice_inputs, Output* size) { + const SliceInputs& slice_inputs, + std::vector control_deps, Output* size) { // If slice_size[i] >= 0 then slice_size[i] = slice_size[i]. // // If slice_size[i] == -1 then slice_size[i] = input_size[i] - @@ -183,7 +191,7 @@ Status ComputeSliceSize(const Scope& host_scope, ops::Shape(host_scope.WithOpName("input_shape"), slice_inputs.input, ops::Shape::OutType(DT_INT64)); - ConstantCache constant_pool(host_scope); + ConstantCache constant_pool(host_scope, control_deps); std::vector slice_size; for (int i = 0; i < slice_inputs.size_as_vector.size(); i++) { @@ -209,11 +217,16 @@ Status ComputeSliceSize(const Scope& host_scope, } // Trivial ConcatV2 nodes (with exactly one input) are disallowed. - *size = - slice_size.size() == 1 - ? 
slice_size[0] - : ops::Concat(host_scope.WithOpName("slice_size"), slice_size, - ops::Const(host_scope.WithOpName("concat_axis"), 0)); + if (slice_size.size() == 1) { + *size = slice_size[0]; + } else { + auto concat_axis = ops::Const(host_scope.WithOpName("concat_axis"), 0); + for (const Edge* e : control_deps) { + host_scope.graph()->AddControlEdge(e->src(), concat_axis.node()); + } + *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size, + concat_axis); + } return Status::OK(); } @@ -234,12 +247,21 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice")); Scope host_scope = main_scope.WithAssignedDevice(host_name); + // In the future we may want to be clever here and avoid the extra Cast ops. SliceInputs slice_inputs_int64 = MakeSliceIndexAndSizeInt64(host_scope, slice_inputs); + // Create a list of all control dependencies to be copied when possibly + // replacing nodes related to slice_size. + Node* old_size; + std::vector old_size_ctrl_deps; + TF_RETURN_IF_ERROR(slice->input_node(2, &old_size)); + absl::c_copy_if(old_size->in_edges(), std::back_inserter(old_size_ctrl_deps), + [](const Edge* e) { return e->IsControlEdge(); }); + Output slice_size; - TF_RETURN_IF_ERROR( - ComputeSliceSize(host_scope, slice_inputs_int64, &slice_size)); + TF_RETURN_IF_ERROR(ComputeSliceSize(host_scope, slice_inputs_int64, + old_size_ctrl_deps, &slice_size)); *result = ops::Slice(main_scope.WithAssignedDevice(slice->assigned_device_name()) @@ -291,9 +313,9 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, return Status::OK(); } -// Return true if `n` is a slice we can rewrite to have a static shape +// Return true if `n` is a slice we should rewrite to have a static shape // (i.e. have the output shape only depend on the "size" input). 
-xla::StatusOr IsRewritableSlice(Node* n) { +xla::StatusOr ShouldRewriteSlice(Node* n) { if (n->type_string() != "Slice") { return false; } @@ -311,14 +333,20 @@ xla::StatusOr IsRewritableSlice(Node* n) { // If slice_size[i] < -1 for any i then executing the slice will throw an // error, and we don't do anything here. - return absl::c_all_of(slice_inputs->size_as_vector, - [](int64 size_i) { return size_i >= -1; }); + bool slice_size_has_error = absl::c_all_of( + slice_inputs->size_as_vector, [](int64 size_i) { return size_i >= -1; }); + if (!slice_size_has_error) { + return false; + } + + // No point in rewriting slices that have both size and begin as constants. + return !slice_inputs->begin.node()->IsConstant(); } Status FindAndRewriteSlices(Graph* g, bool* changed) { std::vector slices_to_rewrite; for (Node* n : g->nodes()) { - TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n)); + TF_ASSIGN_OR_RETURN(bool is_rewritable, ShouldRewriteSlice(n)); if (is_rewritable) { slices_to_rewrite.push_back(n); } @@ -347,15 +375,15 @@ Status IncreaseDynamismForAutoJitPass::Run( const GraphOptimizationPassOptions& options) { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_clustering_debug) { - dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", - **options.graph, options.flib_def); + DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", + **options.graph, options.flib_def); } bool changed; TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed)); if (changed && flags->tf_xla_clustering_debug) { - dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass", - **options.graph, options.flib_def); + DumpGraphToFile("increase_dynamism_for_auto_jit_pass", **options.graph, + options.flib_def); } return Status::OK(); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index 
a2f1b831ad7605237e23c15cc43b337e06265553..2add2c13f92f561904163012ee16cc17ce5badce 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -401,5 +401,57 @@ TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) { Name("begin/static_shaped_slice/static_shaped_slice"))), _))); } + +// New constants being created need to have control dependencies copied to +// ensure correct control flow analysis in TF V2. +TEST(SliceToDynamicSliceRewriteTest, WithControlDepsToConstant) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + // Add an additional dependency that should still exist in with the new size + // variables. + Output dependency = ops::Placeholder(root.WithOpName("dependency"), DT_BOOL); + root.graph()->AddControlEdge(dependency.node(), size.node()); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + // Check that the new constants have control dependencies. 
+ Node* const_0 = testing::FindNodeByName(result.get(), + "slice/static_shaped_slice/const_0"); + EXPECT_NE(const_0, nullptr); + EXPECT_THAT(const_0, + NodeWith(Op("Const"), CtrlDeps(NodeWith(Op("Placeholder"), + Name("dependency"))))); +} + +TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithConstBegin) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Const(root.WithOpName("begin"), {10, 10}); + Output size = ops::Const(root.WithOpName("size"), {-1, 500}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* slice_node = testing::FindNodeByName(result.get(), "slice"); + EXPECT_THAT(slice_node, + NodeWith(Op("Slice"), Inputs(Out(NodeWith(Op("Placeholder"))), + Out(NodeWith(Op("Const"))), + Out(NodeWith(Op("Const")))))); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index f79bdc1e2e8d82c9144d1bb9923ad36d8541cbdb..7326b6c222b30fc929d87d6b56d2de624dc0b24d 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/build_xla_ops_pass.h" +#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" @@ -41,6 +42,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, // POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA: +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 5, + CloneConstantsForBetterClusteringPass); + REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 0583774714c6db7a2fa515fc8a0d304e1898db97..3524da23fb396d59b92aafa2892c8ca1d94d01ac 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -19,14 +19,16 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:stream_executor_no_cuda", - "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index ad71df5a694a5f8da94675049df1062a7edb6253..88d00f7f8e1b8bb5372dff3508dddcc216297e97 100644 --- 
a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" @@ -34,7 +36,8 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" @@ -207,6 +210,28 @@ static Status BuildCompilationCache(OpKernelContext* ctx, if (!platform.ok()) { return platform.status(); } + + xla::StatusOr compiler_for_platform = + xla::Compiler::GetForPlatform(platform.ValueOrDie()); + if (!compiler_for_platform.ok()) { + // In some rare cases (usually in unit tests with very small clusters) we + // may end up transforming an XLA cluster with at least one GPU operation + // (which would normally force the cluster to be compiled using XLA:GPU) + // into an XLA cluster with no GPU operations (i.e. containing only CPU + // operations). Such a cluster can fail compilation (in way that + // MarkForCompilation could not have detected) if the CPU JIT is not linked + // in. + // + // So bail out of _XlaCompile in this case, and let the executor handle the + // situation for us. 
+ const Status& status = compiler_for_platform.status(); + if (status.code() == error::NOT_FOUND) { + return errors::Unimplemented("Could not find compiler for platform ", + platform.ValueOrDie()->Name(), ": ", + status.ToString()); + } + } + xla::LocalClientOptions client_options; client_options.set_platform(platform.ValueOrDie()); client_options.set_intra_op_parallelism_threads( @@ -304,10 +329,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; std::map variables; - OP_REQUIRES_OK( - ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, - constants_, /*lazy=*/false, &client, - &variables, &kernel, &executable)); + { + Status s = CompileToLocalExecutable( + ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false, + &client, &variables, &kernel, &executable); + if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU || + platform_info_.device_type().type_string() == DEVICE_GPU)) { + // Suggest auto jit if the failure was with GPU or CPU. + errors::AppendToMessage(&s, + xla::status_macros::kPossibleAutoJitAlternative); + } + + OP_REQUIRES_OK(ctx, s); + } se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index cb01690845a5205cf2f8a03db8b34fb279daecb7..6da181352337a0782e7fd8548680cf9c76ce7f33 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -21,7 +21,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" @@ -29,11 +31,11 @@ limitations under the License. 
#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -42,14 +44,30 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { +// The clusters we create here are eventually lowered into an +// _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the +// PartitionedCall op to execute the cluster in the regular graph executor if +// need be. PartitionedCall, however, reruns the entire TF graph optimization +// pipeline over the cluster which includes this mark for compilation pass. To +// avoid endlessly recursing we tag nodes that we've already visited with this +// attribute so that we can bail out if we see them a second time. +// +// TODO(sanjoy): This method is not robust since it is possible that the +// optimizations run by PartitionedCall can mutate the cluster arbitrarily, +// dropping the kXlaAlreadyClustered attributes from all nodes in the process. 
+// The correct fix is to use the ConfigProto to pass in some sort of flag into +// the PartitionedCall kernel that tells it to not rerun auto-clustering on the +// cluster. +const char* kXlaAlreadyClustered = "_XlaAlreadyClustered"; + // Aggregates information about what kinds of ops are allowed. struct OperationFilter { // Whether resource variable ops are allowed. We do not allow resource @@ -209,10 +227,9 @@ bool IsCompilableCall(const NodeDef& call_def, } FunctionLibraryRuntime::Handle handle; - Status status = - lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); + Status status = InstantiateFunctionCall(call_def, *lib_runtime, &handle); if (!status.ok()) { - VLOG(2) << "Rejecting " << call_def.op() + VLOG(2) << "Rejecting " << call_def.DebugString() << ": could not instantiate: " << status; return false; } @@ -441,7 +458,7 @@ Status FindCompilationCandidates( std::vector compile_time_const_nodes(graph.num_node_ids(), false); TF_RETURN_IF_ERROR( BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, - &compile_time_const_nodes)); + &compile_time_const_nodes, lib_runtime)); int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; @@ -490,7 +507,7 @@ Status FindCompilationCandidates( XlaOpRegistry::AutoclusteringPolicy::kAlways; OperationFilter op_filter; - op_filter.allow_resource_ops = registration->compile_resource_ops; + op_filter.allow_resource_ops = registration->compile_all_resource_ops; op_filter.allow_stateful_rng_ops = always_auto_cluster; op_filter.allow_control_trigger = always_auto_cluster; op_filter.allow_dummy_ops = always_auto_cluster; @@ -525,7 +542,7 @@ Status FindCompilationCandidates( continue; } - if (!op_filter.allow_resource_ops && + if (!registration->compile_all_resource_ops && (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { // We don't have a way of returning values of type DT_RESOURCE from XLA // computations so we avoid auto-clustering nodes producing 
DT_RESOURCE. @@ -591,8 +608,8 @@ Status FindCompilationCandidates( } // We don't auto-cluster functional control flow nodes containing resource // operations because safety checks are trickier in this case. - // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not - // for CPU/GPU. + // registration->compile_all_resource_ops is true for XLA_CPU/XLA_GPU but + // not for CPU/GPU. if (node->type_string() == "While" && !IsCompilableWhile(*node, jit_device_type, op_filter, 0, lib_runtime)) { continue; @@ -614,34 +631,21 @@ Status FindCompilationCandidates( return Status::OK(); } -// Determine the global jit level which is ON if either the -// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag -// is true. -OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( - const GraphOptimizationPassOptions& options) { - OptimizerOptions::GlobalJitLevel global_jit_level = - options.session_options->config.graph_options() - .optimizer_options() - .global_jit_level(); - if (global_jit_level == OptimizerOptions::DEFAULT) { - // To set compilation to be on by default, change the following line. - global_jit_level = OptimizerOptions::OFF; - } - MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - if (flags->tf_xla_auto_jit == -1 || - (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { - // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides - // the setting in ConfigProto. - global_jit_level = - static_cast(flags->tf_xla_auto_jit); - } - return global_jit_level; -} - struct Cluster { // Identifies the node that represents this cluster in the cycle detection // graph. int representative = -1; + + // The set of devices the nodes in this cluster are placed on. + absl::flat_hash_set devices; + + // If there are resource operation in the cluster then this is the device that + // resource operations are placed on. All resource operations in a cluster + // must be placed on the same device. 
+ string resource_op_device; + + // True if any node in the cluster has an _XlaCompile attribute set to true. + bool has_xla_compile_attr; }; } // anonymous namespace @@ -683,12 +687,30 @@ Status MarkForCompilationPass::Run( // So fix up the source and sink edges before calling into deadness analysis. FixupSourceAndSinkEdges(options.graph->get()); + // See explanation on `kXlaAlreadyClustered`. + for (Node* n : options.graph->get()->nodes()) { + if (n->attrs().Find(kXlaAlreadyClustered)) { + return Status::OK(); + } + } + std::unique_ptr deadness; { XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness)); } + bool deadness_analysis_disabled = + GetMarkForCompilationPassFlags() + ->tf_xla_disable_deadness_safety_checks_for_debugging; + + if (deadness_analysis_disabled) { + LOG(WARNING) << "Deadness analysis was manually disabled via " + "--tf_xla_disable_deadness_safety_checks_for_debugging; " + "auto-clustering " + "is unsound!"; + } + auto is_compilable = [&](const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), @@ -721,9 +743,12 @@ Status MarkForCompilationPass::Run( // and some are dead) then don't compile it. XLA cannot represent the // deadness semantics of these nodes correctly and auto-clustering these // nodes can cause deadness to propagate to nodes that should be live. - if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { - VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; - return false; + if (!deadness_analysis_disabled) { + if (node->IsMerge() || + deadness->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; + return false; + } } // Check for fusable ops only if requested. 
@@ -733,25 +758,9 @@ Status MarkForCompilationPass::Run( return false; } - // Otherwise use the value of global_jit_level and the device's - // autoclustering policy. - bool should_compile = - registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kAlways || - (registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && - global_jit_level != OptimizerOptions::OFF); - if (!should_compile) { - if (global_jit_level == OptimizerOptions::OFF) { - VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; - } else { - VLOG(2) - << "Rejecting " << node->name() - << ": autoclustering for device only when requested explicitly."; - } - } - return should_compile; + return true; }; + return RunImpl(options, is_compilable); } @@ -927,7 +936,7 @@ static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) { if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { *ignore = true; } else { - *ignore = registration->compile_resource_ops; + *ignore = registration->compile_all_resource_ops; } return Status::OK(); } @@ -935,6 +944,134 @@ static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) { // Sequence number generator to ensure clusters have unique names. static std::atomic cluster_sequence_num; +// Returns true if the devices in `cluster_a` and `cluster_b` are compatible and +// therefore not a hindrance for combining the two clusters into a larger +// cluster. 
+static Status AreDevicesCompatible( + const Cluster& cluster_a, const Cluster& cluster_b, + OptimizerOptions::GlobalJitLevel global_jit_level, bool* result) { + std::vector devices; + absl::c_remove_copy(cluster_a.devices, std::back_inserter(devices), ""); + absl::c_remove_copy(cluster_b.devices, std::back_inserter(devices), ""); + absl::c_sort(devices); + + if (devices.empty()) { + *result = false; + return Status::OK(); + } + + // First check if we will even be able to pick a device for the larger + // combined cluster. + bool can_pick_device; + TF_RETURN_IF_ERROR(CanPickDeviceForXla( + devices, /*allow_mixing_unknown_and_cpu=*/false, &can_pick_device)); + if (!can_pick_device) { + *result = false; + return Status::OK(); + } + + string chosen_device; + TF_RETURN_IF_ERROR(PickDeviceForXla( + devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device)); + + // If we are able to pick a device `chosen_device` for the larger cluster, the + // resource operations in `cluster_a` and `cluster_b` must be placed on the + // same device as `chosen_device`. This is because the _XlaCompile and + // _XlaRun kernels are going to run on and therefore try to access the + // resource variables from `chosen_device`, which will be an error if the + // resource variables are placed on some other device. + auto resource_op_device_ok = [&](const string& resource_op_device) { + return resource_op_device.empty() || resource_op_device == chosen_device; + }; + + *result = resource_op_device_ok(cluster_a.resource_op_device) && + resource_op_device_ok(cluster_b.resource_op_device); + if (!*result) { + return Status::OK(); + } + + // We will check this again later, but here we prune out clusters that would + // never have been sent to XLA to save compile time. Without this change we + // will e.g. create a CPU cluster only to later notice that the user did not + // enable the CPU JIT via --tf_xla_cpu_global_jit. With this change we avoid + // creating the cluster to begin with. 
+ // + // TODO(b/126629785): It is possible that this is just papering over O(n^2) + // behavior in our clustering algorithm. + const XlaOpRegistry::DeviceRegistration* registration; + DeviceType device_type(""); + TF_RETURN_IF_ERROR(DeviceToDeviceType(chosen_device, &device_type)); + TF_RET_CHECK( + XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) + << "chosen device = " << chosen_device + << "; device type = " << device_type.type() << "; devices (" + << devices.size() << ") = " << absl::StrJoin(devices, ", "); + + *result = cluster_a.has_xla_compile_attr || cluster_b.has_xla_compile_attr || + registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways || + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && + global_jit_level != OptimizerOptions::OFF); + + return Status::OK(); +} + +// Returns `true` iff we should compile `cluster`. +static Status ShouldCompileClusterImpl( + const Cluster& cluster, OptimizerOptions::GlobalJitLevel global_jit_level, + bool* should_compile, string* device) { + std::vector devices; + absl::c_remove_copy(cluster.devices, std::back_inserter(devices), ""); + absl::c_sort(devices); + + string chosen_device; + TF_RETURN_IF_ERROR(PickDeviceForXla( + devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device)); + + const XlaOpRegistry::DeviceRegistration* registration; + DeviceType device_type(""); + TF_RETURN_IF_ERROR(DeviceToDeviceType(chosen_device, &device_type)); + TF_RET_CHECK( + XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) + << "chosen device = " << chosen_device + << "; device type = " << device_type.type() << "; devices (" + << devices.size() << ") = " << absl::StrJoin(devices, ", "); + + *should_compile = + cluster.has_xla_compile_attr || + registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways || + (registration->autoclustering_policy == + 
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && + global_jit_level != OptimizerOptions::OFF); + + VLOG(3) << (*should_compile ? "Compiling" : "Not compiling") + << " cluster with device " << chosen_device; + + *device = std::move(chosen_device); + return Status::OK(); +} + +static Status ShouldCompileCluster( + absl::flat_hash_map>* cache, + OptimizerOptions::GlobalJitLevel global_jit_level, const Cluster& cluster, + bool* should_compile, string* device) { + auto it = cache->find(cluster.representative); + if (it != cache->end()) { + *should_compile = it->second.first; + *device = it->second.second; + return Status::OK(); + } + + string device_s; + TF_RETURN_IF_ERROR(ShouldCompileClusterImpl(cluster, global_jit_level, + should_compile, &device_s)); + cache->insert({cluster.representative, {*should_compile, device_s}}); + *device = std::move(device_s); + return Status::OK(); +} + Status MarkForCompilationPass::RunImpl( const GraphOptimizationPassOptions& options, const std::function& @@ -960,7 +1097,11 @@ Status MarkForCompilationPass::RunImpl( } GraphCycles cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); + TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, + CreateCycleDetectionGraph(graph, &cycles)); + if (!cycle_detection_graph_ok) { + return Status::OK(); + } TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); @@ -972,6 +1113,23 @@ Status MarkForCompilationPass::RunImpl( for (Node* node : compilation_candidates) { Cluster& cluster = clusters[node->id()].Get(); cluster.representative = node->id(); + const string& device = !node->assigned_device_name().empty() + ? 
node->assigned_device_name() + : node->requested_device(); + if (HasResourceInput(*node) || HasResourceOutput(*node)) { + cluster.resource_op_device = device; + } + cluster.has_xla_compile_attr = false; + bool xla_compile_attr; + if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) { + cluster.has_xla_compile_attr |= xla_compile_attr; + } + if (options.flib_def->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr) + .ok()) { + cluster.has_xla_compile_attr |= xla_compile_attr; + } + + cluster.devices.insert(device); worklist.push_back(&clusters[node->id()]); } @@ -985,7 +1143,8 @@ Status MarkForCompilationPass::RunImpl( // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for // example, from the Grappler fusion pass). while (!worklist.empty()) { - int from = worklist.front()->Get().representative; + Cluster* cluster_from = &worklist.front()->Get(); + int from = cluster_from->representative; worklist.pop_front(); Node* node_from = graph->FindNodeId(from); @@ -1009,13 +1168,17 @@ Status MarkForCompilationPass::RunImpl( // graph. No clustering is possible. 
continue; } + + const Cluster& cluster_to = clusters[to].Get(); Node* node_to = graph->FindNodeId(to); if (compilation_candidates.find(node_to) == compilation_candidates.cend()) { continue; } - if (node_from->assigned_device_name() != - node_to->assigned_device_name()) { + bool devices_compatible; + TF_RETURN_IF_ERROR(AreDevicesCompatible( + *cluster_from, cluster_to, global_jit_level, &devices_compatible)); + if (!devices_compatible) { continue; } if (isolated_nodes.count(node_to)) { @@ -1064,9 +1227,14 @@ Status MarkForCompilationPass::RunImpl( if (in_id >= graph->num_node_ids()) continue; Node* in = graph->FindNodeId(in_id); - if (compilation_candidates.find(in) != compilation_candidates.cend() && - in->assigned_device_name() != node_to->assigned_device_name()) { - found_split = true; + const Cluster& cluster_in = clusters[in_id].Get(); + if (compilation_candidates.find(in) != compilation_candidates.cend()) { + bool devices_compatible; + TF_RETURN_IF_ERROR(AreDevicesCompatible( + cluster_to, cluster_in, global_jit_level, &devices_compatible)); + if (!devices_compatible) { + found_split = true; + } } } if (found_split) continue; @@ -1080,6 +1248,12 @@ Status MarkForCompilationPass::RunImpl( // Merge the clusters. ContractEdge uses 'from' as the number of the // merged node, so make sure 'from' is the chosen representative. + cluster_from->devices.insert(cluster_to.devices.begin(), + cluster_to.devices.end()); + if (!cluster_to.resource_op_device.empty()) { + cluster_from->resource_op_device = cluster_to.resource_op_device; + } + cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr; clusters[from].Merge(&clusters[to]); worklist.push_back(&clusters[from]); @@ -1089,23 +1263,37 @@ Status MarkForCompilationPass::RunImpl( // Count the number of non-trivial elements in each cluster. std::vector effective_cluster_sizes(graph->num_node_ids()); + + // has_functional_control_flow remembers if a cluster contains a functional + // control flow node. 
+ std::vector has_functional_control_flow(graph->num_node_ids()); + for (const Node* n : compilation_candidates) { int cluster = clusters[n->id()].Get().representative; - // Identity nodes will be removed if the node gets marked for compilation. - // Therefore we don't want to count them towards the effective cluster size. - if (n->def().op() != "Identity") { + // We want clusters to be big enough that the benefit from XLA's + // optimizations offsets XLA related overhead (for instance we add some + // Switch/Merge nodes into the graph to implement lazy compilation). To + // this end, we don't count Identity and Constant nodes because they do not + // enable interesting optimizations by themselves. + if (!n->IsIdentity() && !n->IsConstant()) { effective_cluster_sizes[cluster]++; } + if (n->type_string() == "While" || n->type_string() == "If") { + has_functional_control_flow[cluster] = true; + } } // Names for each cluster. std::unordered_map cluster_names; if (flags->tf_xla_clustering_debug) { - dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph, - options.flib_def); + DumpGraphToFile("before_mark_for_compilation", **options.graph, + options.flib_def); } + absl::flat_hash_map> + should_compile_cluster_cache; + // Mark clusters for compilation that: // * are placed on a device that requires compilation (an XlaDevice), // * are explicitly marked for compilation (_XlaCompile=true), or @@ -1113,7 +1301,17 @@ Status MarkForCompilationPass::RunImpl( // if compilation is enabled, otherwise there will be no such candidates). 
const int min_cluster_size = flags->tf_xla_min_cluster_size; for (Node* n : compilation_candidates) { - int cluster = clusters[n->id()].Get().representative; + const Cluster& cluster = clusters[n->id()].Get(); + bool should_compile; + string device; + TF_RETURN_IF_ERROR(ShouldCompileCluster(&should_compile_cluster_cache, + global_jit_level, cluster, + &should_compile, &device)); + if (!should_compile) { + continue; + } + + int cluster_repr = cluster.representative; // Compile if the user marked this node _XlaCompile=true bool compile_attr = false; @@ -1125,32 +1323,26 @@ Status MarkForCompilationPass::RunImpl( marked_for_compilation = compile_attr; } - // Compile if this operator is placed on a device that requires - // compilation. - DeviceType device_type(""); - TF_RETURN_IF_ERROR( - DeviceToDeviceType(n->assigned_device_name(), &device_type)); - const XlaOpRegistry::DeviceRegistration* registration; - XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); + // We assume that functional If and While nodes have at least + // min_cluster_size non-trivial nodes in them. It would be more principled + // to (recursively) verify this fact, but that's probably not worth the + // trouble. - // Compile if this is a cluster of >= min_cluster_size compilable operators. - // Also, always compile if it contains at least one op that is marked for - // compilation that is not an Identity op. 
- if (effective_cluster_sizes[cluster] >= min_cluster_size || - (effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) { - string& name = cluster_names[cluster]; + if (effective_cluster_sizes[cluster_repr] >= min_cluster_size || + has_functional_control_flow[cluster_repr] || marked_for_compilation) { + string& name = cluster_names[cluster_repr]; if (name.empty()) { name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); + n->AddAttr(kXlaAlreadyClustered, true); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; } } if (flags->tf_xla_clustering_debug) { - dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph, - options.flib_def); + DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def); // We also dump out an annoated version of the TF graph where the nodes // names are prefixed with the cluster names. This can help visualizing the @@ -1162,6 +1354,8 @@ Status MarkForCompilationPass::RunImpl( if (absl::optional cluster_name = GetXlaClusterForNode(*n)) { n->set_name(absl::StrCat(*cluster_name, "/", n->name())); + } else if (n->type_string() == "VarHandleOp") { + n->set_name(absl::StrCat("varhandle/", n->name())); } else { // There is room for improvement here. 
In particular, it may help to // split these unclustered nodes into classes where every node in a @@ -1170,8 +1364,8 @@ Status MarkForCompilationPass::RunImpl( } } - dump_graph::DumpGraphToFile("mark_for_compilation_annotated", new_graph, - options.flib_def); + DumpGraphToFile("mark_for_compilation_annotated", new_graph, + options.flib_def); } VLogClusteringSummary(*graph); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index c2b6250f738fafa35b2c5f79e97cf1281b50a316..da0fbf35de5d6594d00b20fe6ee698050075fd8f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" @@ -38,6 +39,8 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +using ::tensorflow::testing::FindNodeByName; + namespace tensorflow { namespace { @@ -192,35 +195,6 @@ TEST(XlaCompilationTest, HalfSupported) { EXPECT_FALSE(clusters.empty()); } -TEST(XlaCompilationTest, ConcatWithConstArg) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - GraphDef graphdef; - { - Tensor t(DT_INT32, TensorShape()); - t.scalar()() = 0; - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* dim = ops::SourceOp("Const", builder.opts() - .WithName("Dim") - .WithAttr("dtype", DT_INT32) - .WithAttr("value", t)); - Node* a = ops::SourceOp("Const", builder.opts() - .WithName("A") - .WithAttr("dtype", DT_FLOAT) - .WithAttr("value", t)); - - NodeBuilder concat_builder("Concat", "Concat", - builder.opts().op_registry()); - concat_builder.Input(dim).Input({a, a}).Attr("N", 2); - builder.opts().FinalizeBuilder(&concat_builder); - - TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); - } - - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - auto clusters = GetClusters(*graph); - EXPECT_EQ(3, clusters.size()); // Everything should be compiled. 
-} - TEST(XlaCompilationTest, FunctionCalls) { FunctionDef compilable = FunctionDefHelper::Define( "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {}, @@ -424,12 +398,8 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) { FunctionDefLibrary flib; FunctionLibraryDefinition flib_def(graph->op_registry(), flib); - SessionOptions session_options; - session_options.config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_global_jit_level(OptimizerOptions::ON_2); - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( - &graph, &flib_def, &session_options)); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def)); auto clusters = GetClusters(*graph); // The computation is: C = A + relu(A) @@ -460,7 +430,8 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false)); auto clusters = GetClusters(*graph); // The computation is: C = A + relu(A) @@ -478,20 +449,28 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { .WithName("A") .WithAttr("dtype", DT_FLOAT) .WithAttr("value", Tensor()) + .WithAttr(kXlaCompileAttr, true) .WithAttr(kXlaScopeAttr, "Scope1")); - Node* b = ops::UnaryOp( - "Relu", a, - builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "Scope1")); - Node* c = ops::BinaryOp( - "MatMul", a, b, - builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "Scope2")); - ops::BinaryOp( - "Add", b, c, - builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2")); + Node* b = ops::UnaryOp("Relu", a, + builder.opts() + .WithName("B") + .WithAttr(kXlaCompileAttr, true) + .WithAttr(kXlaScopeAttr, "Scope1")); + Node* c = ops::BinaryOp("MatMul", a, b, + builder.opts() + .WithName("C") + .WithAttr(kXlaCompileAttr, true) + .WithAttr(kXlaScopeAttr, 
"Scope2")); + ops::BinaryOp("Add", b, c, + builder.opts() + .WithName("D") + .WithAttr(kXlaCompileAttr, true) + .WithAttr(kXlaScopeAttr, "Scope2")); TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false)); auto clusters = GetClusters(*graph); // The computation is: D = relu(A) + (A @ relu(A)) @@ -513,31 +492,39 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { .WithName("A") .WithAttr("dtype", DT_FLOAT) .WithAttr("value", Tensor()) + .WithAttr(kXlaCompileAttr, true) .WithAttr(kXlaScopeAttr, "ScopeA")); - Node* b = ops::UnaryOp( - "Relu", a, - builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); + Node* b = ops::UnaryOp("Relu", a, + builder.opts() + .WithName("B") + .WithAttr(kXlaCompileAttr, true) + .WithAttr(kXlaScopeAttr, "ScopeB")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false)); auto clusters = GetClusters(*graph); // The computation is: C = A @ relu(A) // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC. // In this case, we cannot fuse anything. 
- EXPECT_EQ(2, clusters.size()); + EXPECT_EQ(3, clusters.size()); EXPECT_NE(clusters["A"], clusters["B"]); EXPECT_EQ(clusters["B"], clusters["C"]); } namespace { -Node* MakeRead(const Scope& scope, const string& id) { +Node* MakeRead(const Scope& scope, const string& id, + Node** var_handle_op = nullptr) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); Output read = ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); + if (var_handle_op) { + *var_handle_op = var_handle.node(); + } return read.node(); } @@ -590,10 +577,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); - ASSERT_EQ(cluster_sets.size(), 1); - std::vector expected_clustered_nodes = {"AssignmentW", - "ValueToAssignW"}; - ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); + ASSERT_EQ(cluster_sets.size(), 0); } TEST(XlaCompilationTest, ChainOfOps) { @@ -621,15 +605,11 @@ TEST(XlaCompilationTest, ChainOfOps) { absl::flat_hash_map> cluster_sets = GetClusterSets(*graph, &cluster_names); - ASSERT_EQ(cluster_sets.size(), 2); - - std::vector expected_clustered_nodes_a = {"AssignmentW0", "ConstN0", - "ValueToAssignW0"}; - ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); + ASSERT_EQ(cluster_sets.size(), 1); - std::vector expected_clustered_nodes_b = { + std::vector expected_clustered_nodes_a = { "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"}; - ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b); + ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); } TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { @@ -688,9 +668,7 @@ TEST(XlaCompilationTest, Retval) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); - EXPECT_EQ(2, clusters.size()); - 
EXPECT_TRUE(clusters.find("R") == clusters.cend()); - EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_TRUE(clusters.empty()); } TEST(XlaCompilationTest, DontCountIdentityOps) { @@ -709,22 +687,6 @@ TEST(XlaCompilationTest, DontCountIdentityOps) { EXPECT_TRUE(clusters.empty()); } -TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - Scope root = Scope::NewRootScope().ExitOnError(); - { - auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0); - auto b = ops::Identity(root.WithOpName("B"), a); - b.node()->AddAttr(kXlaCompileAttr, true); - auto r = ops::_Retval(root.WithOpName("R"), b, 0); - } - TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - auto clusters = GetClusters(*graph); - - EXPECT_TRUE(clusters.empty()); -} - TEST(XlaCompilationTest, ConstOp) { // valid data type { @@ -980,8 +942,10 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) { absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:XLA_GPU:1"; std::unique_ptr graph(new Graph(OpRegistry::Global())); - Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}); - Output b = ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}); + Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"), + ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2})); + Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"), + ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2})); Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a); Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b); @@ -1223,5 +1187,132 @@ TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { EXPECT_NE(clusters["test/tensor_list_reserve"], ""); } +const char* kCPU0 = "/job:worker/replica:0/task:0/device:CPU:0"; +const char* kGPU0 = "/job:worker/replica:0/task:0/device:GPU:0"; +const char* kXLA_GPU0 = "/job:worker/replica:0/task:0/device:XLA_GPU:0"; +const char* kGPU1 
= "/job:worker/replica:0/task:0/device:GPU:1"; + +TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + + Output x = ops::Add(root.WithOpName("test/x"), a, b); + Output y = ops::MatMul(root.WithOpName("test/y"), a, b); + Output z = ops::Add(root.WithOpName("test/z"), x, y); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0); + FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0); + FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + EXPECT_NE(clusters["test/x"], ""); + + EXPECT_EQ(clusters["test/x"], clusters["test/y"]); + EXPECT_EQ(clusters["test/y"], clusters["test/z"]); +} + +TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + + Output x = ops::Add(root.WithOpName("test/x"), a, b); + Output y = ops::Add(root.WithOpName("test/y"), x, x); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0); + FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU1); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + EXPECT_EQ(clusters["test/x"], ""); + EXPECT_EQ(clusters["test/y"], ""); +} + +TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) { + Scope root = 
Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + + Output x = ops::Add(root.WithOpName("test/x"), a, b); + Output y = ops::Add(root.WithOpName("test/y"), x, x); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kCPU0); + FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kXLA_GPU0); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + EXPECT_EQ(clusters["test/x"], ""); + EXPECT_EQ(clusters["test/y"], ""); +} + +TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Node* var_handle; + Node* resource_read = MakeRead(root, "read", &var_handle); + Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a); + + string resource_read_name = resource_read->name(); + string var_handle_name = var_handle->name(); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kCPU0); + FindNodeByName(graph.get(), resource_read_name) + ->set_assigned_device_name(kGPU0); + FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kGPU0); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + EXPECT_NE(clusters["test/b"], ""); + EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]); +} + +TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Node* var_handle; + Node* 
resource_read = MakeRead(root, "read", &var_handle); + Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a); + + string resource_read_name = resource_read->name(); + string var_handle_name = var_handle->name(); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kGPU0); + FindNodeByName(graph.get(), resource_read_name) + ->set_assigned_device_name(kCPU0); + FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kCPU0); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + EXPECT_EQ(clusters["test/b"], ""); + EXPECT_EQ(clusters[resource_read_name], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index 64a3301745790132fe3149bf8fb52d6c45ecc3c1..5f0ebe150fa0300940d52e036f7a60ca9fef22e5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, - SessionOptions* session_options) { + bool enable_global_jit) { // Assign all unassigned nodes to the CPU device. static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : (*graph)->nodes()) { @@ -30,31 +30,32 @@ namespace tensorflow { } } + SessionOptions session_options; + if (enable_global_jit) { + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_global_jit_level(OptimizerOptions::ON_2); + } + // Call AddDevices to register the XLA devices. 
// // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to // make this more direct, but probably not worth it solely for this test. std::vector> devices; - TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); + TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(session_options, "", &devices)); GraphOptimizationPassOptions opt_options; opt_options.graph = graph; - opt_options.session_options = session_options; + opt_options.session_options = &session_options; opt_options.flib_def = flib_def; MarkForCompilationPass pass; return pass.RunImpl(opt_options); } /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( - std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { - SessionOptions session_options; - return MarkForCompilation(graph, flib_def, &session_options); -} - -/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( - std::unique_ptr* graph) { + std::unique_ptr* graph, bool enable_global_jit) { FunctionDefLibrary flib; FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); - return MarkForCompilation(graph, &flib_def); + return MarkForCompilation(graph, &flib_def, enable_global_jit); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h index 216baaf933dc1f7e694289eea5d23996b595f4d4..df751978562aab8b89aa3966a461c614b1adde5b 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -23,17 +23,14 @@ class MarkForCompilationPassTestHelper { public: // Runs the MarkForCompilation pass on `graph` after assigning all nodes in // `graph` to the CPU device. To make testing easier, ignores device - // registration, _XlaCompile attributes, input deadness and global jit level. + // registration, _XlaCompile attributes and input deadness. 
static Status MarkForCompilation(std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, - SessionOptions* session_options); - - // Like `MarkForCompilation` but creates a default SessionOptions. - static Status MarkForCompilation(std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def); + bool enable_global_jit = true); // Like `MarkForCompilation` but creates `flib_def` from the op registry. - static Status MarkForCompilation(std::unique_ptr* graph); + static Status MarkForCompilation(std::unique_ptr* graph, + bool enable_global_jit = true); }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index e1fd2aaee2822daeffb415d053c9c4f56002a856..ffc5d0edbcc7668d5ee137c3c8bbe74167e37a1a 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -20,9 +20,13 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -272,12 +276,20 @@ Status MustCompileNode(const Node* n, bool* must_compile) { // We assume here that the extra repeated (repeated compared to a clustered f // where it will always be constant folded) host-side computation of f does not // regress performance in any significant manner. 
We will have to revisit this -// algorith with a more complex cost model if this assumption turns out to be +// algorithm with a more complex cost model if this assumption turns out to be // incorrect. -Status PartiallyDeclusterGraph(Graph* graph) { +Status PartiallyDeclusterGraph(Graph* graph, + const FunctionLibraryDefinition* flib_def, + Env* env) { std::vector compile_time_const_nodes(graph->num_node_ids()); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis( - *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); + OptimizerOptions opts; + auto pflr = absl::make_unique( + nullptr, env, TF_GRAPH_DEF_VERSION, flib_def, opts); + FunctionLibraryRuntime* lib_runtime = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr, + &compile_time_const_nodes, + lib_runtime, IsIntraClusterEdge)); std::vector rpo; GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(), @@ -341,7 +353,19 @@ Status PartiallyDeclusterPass::Run( TF_RETURN_IF_ERROR( reduce_device_to_host_copies::PartiallyDeclusterGraph(graph)); - TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(graph)); + if (options.flib_def == nullptr) { + return errors::InvalidArgument( + "GraphOptimizationPassOptions::flib_def must be set for " + "PartiallyDeclusterPass."); + } + if (options.session_options == nullptr || + options.session_options->env == nullptr) { + return errors::InvalidArgument( + "GraphOptimizationPassOptions::session_options::env must be set for " + "PartiallyDeclusterPass."); + } + TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph( + graph, options.flib_def, options.session_options->env)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 1d81a8f4fcbf050663626b1f7660afd71f4027bc..3494d0ee7efb51a5620f68bc1772e111db493c8d 100644 --- 
a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/algorithm.h" @@ -90,6 +92,12 @@ Status PartiallyDecluster(std::unique_ptr* graph) { GraphOptimizationPassOptions opt_options; opt_options.graph = graph; + FunctionDefLibrary fdef_lib; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + opt_options.flib_def = &flib_def; + SessionOptions session_options; + session_options.env = Env::Default(); + opt_options.session_options = &session_options; PartiallyDeclusterPass pass; return pass.Run(opt_options); } diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index a27e0d9f2a6ecddfdbdb29be673084d77a178d8a..a9c53a943bee58355b634586806c5bedd6fe67b5 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -16,10 +16,10 @@ limitations under the License. 
#include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index fef28fc810cb4e544fe3f271f0b96cebd8a96779..cb8ac06207e00395ef35c23dc96f8461182aa45f 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,11 +17,19 @@ limitations under the License. #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -43,7 +51,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, return ""; } - auto node_name = [cycles, &graph](int node_id) { + auto node_name = [&graph](int node_id) { if (!FastBoundsCheck(node_id, graph.num_node_ids())) { return string("(null)"); } @@ -103,7 +111,8 @@ bool HasForwardedRefInput(const Node& node) { return false; } -Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { +xla::StatusOr CreateCycleDetectionGraph(const Graph* graph, + GraphCycles* 
cycles) { for (int i = 0; i < graph->num_node_ids(); ++i) { // We rely on the node IDs in the cycle detection graph being consecutive // integers starting from 0. @@ -166,9 +175,11 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { } if (!cycles->InsertEdge(src, dst)) { - return errors::Internal( - "Cycle detected when adding ", src_type, "->", dst_type, - " edge: ", DescribeCycle(cycles, *graph, src, dst)); + // TODO(b/127521408): We can probably handle this situation with a more + // sophisticated SCC based algorithm, but for now we bail out. + VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type + << " edge: " << DescribeCycle(cycles, *graph, src, dst); + return false; } // Drop the original edge. continue; @@ -186,7 +197,8 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); } } - return Status::OK(); + + return true; } absl::optional GetXlaClusterForNode(const Node& node) { @@ -236,4 +248,126 @@ Status AdjustCycleDetectionGraphForResourceOps( return Status::OK(); } +Status PickDeviceForXlaImpl(absl::Span device_names, + bool allow_mixing_unknown_and_cpu, + bool* out_can_pick_device, + string* out_device_picked) { + if (out_can_pick_device) { + *out_can_pick_device = true; + } + +#define FAILED_TO_PICK_DEVICE(failing_status) \ + do { \ + if (out_can_pick_device) { \ + *out_can_pick_device = false; \ + return Status::OK(); \ + } else { \ + return failing_status; \ + } \ + } while (false) + + TF_RET_CHECK(!device_names.empty()) << "No devices to choose from"; + DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr); + + absl::flat_hash_set device_names_set; + for (absl::string_view device_name : device_names) { + if (!device_name.empty()) { + device_names_set.insert(device_name); + } + } + + absl::optional maybe_gpu_device; + absl::optional maybe_cpu_device; + absl::optional maybe_unknown_device; + + for 
(absl::string_view device_name : device_names_set) { + DeviceNameUtils::ParsedName parsed_name; + TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name)) + << device_name; + if (parsed_name.type == "GPU") { + if (maybe_gpu_device) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Multiple GPU devices ", absl::StrJoin(device_names, ", "))); + } + maybe_gpu_device = device_name; + } else if (parsed_name.type == "CPU") { + if (maybe_cpu_device) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Multiple CPU devices ", absl::StrJoin(device_names, ", "))); + } + maybe_cpu_device = device_name; + } else { + if (maybe_unknown_device) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Multiple unknown devices ", absl::StrJoin(device_names, ", "))); + } + maybe_unknown_device = device_name; + } + } + + if (maybe_unknown_device && maybe_gpu_device) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Found both unknown and GPU devices: ", *maybe_unknown_device, ", ", + *maybe_gpu_device)); + } + + if (!allow_mixing_unknown_and_cpu) { + if (maybe_unknown_device && maybe_cpu_device) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Found both unknown and CPU devices: ", *maybe_unknown_device, ", ", + *maybe_cpu_device)); + } + } + + if (out_device_picked) { + if (maybe_gpu_device) { + *out_device_picked = string(*maybe_gpu_device); + } else if (maybe_unknown_device) { + *out_device_picked = string(*maybe_unknown_device); + } else { + *out_device_picked = string(*maybe_cpu_device); + } + } + + return Status::OK(); + +#undef FAILED_TO_PICK_DEVICE +} + +Status PickDeviceForXla(absl::Span device_names, + bool allow_mixing_unknown_and_cpu, + string* out_device_picked) { + return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu, + /*out_can_pick_device=*/nullptr, + out_device_picked); +} + +Status CanPickDeviceForXla(absl::Span device_names, + bool allow_mixing_unknown_and_cpu, + bool* out_can_pick_device) { + return PickDeviceForXlaImpl(device_names, 
allow_mixing_unknown_and_cpu, + out_can_pick_device, + /*out_device_picked=*/nullptr); +} + +OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( + const GraphOptimizationPassOptions& options) { + OptimizerOptions::GlobalJitLevel global_jit_level = + options.session_options->config.graph_options() + .optimizer_options() + .global_jit_level(); + if (global_jit_level == OptimizerOptions::DEFAULT) { + // To set compilation to be on by default, change the following line. + global_jit_level = OptimizerOptions::OFF; + } + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_auto_jit != OptimizerOptions::DEFAULT) { + // If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides + // the setting in ConfigProto. + global_jit_level = + static_cast(flags->tf_xla_auto_jit); + } + return global_jit_level; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index fa6eaab3900b37baf7271c8c431c8384ceeda59f..af01e1d30231b0bb2c4aac5183e8b02c2e595135 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -20,7 +20,10 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -52,7 +55,11 @@ bool HasForwardedRefInput(const Node& node); // Creates a graph representation to enable cycle detection when clustering. // This representation handles loops in graph by disconnecting each loop from // the enclosing graph. -Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); +// +// Returns true for success and false for valid graphs that we can't handle yet +// (b/127521408). 
+xla::StatusOr CreateCycleDetectionGraph(const Graph* graph, + GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. @@ -74,6 +81,57 @@ Status AdjustCycleDetectionGraphForResourceOps( const std::function& resource_ops_to_ignore, GraphCycles* cycles); +// Picks the device for which XLA should compile a cluster that contains +// operations placed in devices in `device_names`. For instance a cluster that +// contains operations solely placed on the CPU will be compiled into a CPU +// executable by XLA, whereas a cluster that contains operations placed on the +// CPU and also operations placed on the GPU will be compiled into a GPU +// executable. +// +// Returns a non-OK Status if no unambiguous choice of device exists. +// +// We choose the device using the following rules: +// +// - It is an error for `device_names` to contain more than one device of the +// same type. +// - GPU is preferred over CPU. +// - If `allow_mixing_unknown_and_cpu` is true then unknown devices are +// preferred over CPU. +// - XLA devices count as "unrecognized devices". +// +// This set of rules above implicitly assume that XLA:GPU can compile all +// operations in the cluster that XLA:CPU can compile, and if +// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile +// all operations in the cluster that XLA:CPU can compile. +// +// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of +// the following things: +// +// - Let MarkForCompilationPass not inject CPU-placed operations into clusters +// that will run on unknown devices (because the unknown XLA backend may not +// support every operation supported by CPU). +// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster +// that contains nodes placed on both the CPU and on unknown devices. 
In this +// case it is the responsibility of the optimization pass that injected the +// CPU nodes into the cluster to ensure that these nodes can be compiled by +// the unknown XLA backend. +Status PickDeviceForXla(absl::Span device_names, + bool allow_mixing_unknown_and_cpu, + string* out_device_picked); + +// This is like `PickDeviceForXla` except that it returns false (instead of a +// non-OK Status) in `out_can_pick_device` if no unambiguous choice of device +// exists. +Status CanPickDeviceForXla(absl::Span device_names, + bool allow_mixing_unknown_and_cpu, + bool* out_can_pick_device); + +// Determine the global jit level which is ON if either the +// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag +// is true. +OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( + const GraphOptimizationPassOptions& options); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 65bbf3efe85ba30f44531ff6d54b041786dca0a5..cbaac719f2e62cb85e8ac5893d0f4d50be0dae5c 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -15,9 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/algorithm.h" @@ -42,7 +45,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) { FixupSourceAndSinkEdges(root.graph()); GraphCycles cycles; - TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); } @@ -61,8 +64,93 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) { FixupSourceAndSinkEdges(root.graph()); GraphCycles cycles; - TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); } + +TEST(CreateCycleDetectionGraph, ReachingEnterExit) { + // TODO(b/127521408): We can lift this limitation with some work. 
+ Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter_0 = + ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0"); + Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0); + + Output add = ops::Add(root.WithOpName("add"), exit_0, exit_0); + + Output enter_1 = + ops::internal::Enter(root.WithOpName("enter_1"), add, "frame_0"); + Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK_AND_ASSIGN(bool ok, + CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(ok); +} + +void CheckPickDeviceResult(absl::string_view expected_result, + bool allow_mixing_unknown_and_cpu, + absl::Span inputs) { + std::vector inputs_string; + absl::c_transform(inputs, std::back_inserter(inputs_string), + [](absl::string_view sv) { return string(sv); }); + string result; + TF_ASSERT_OK( + PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, &result)) + << "inputs = [" << absl::StrJoin(inputs, ", ") + << "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu + << ", expected_result=" << expected_result; + EXPECT_EQ(result, expected_result); +} + +void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu, + absl::Span inputs) { + std::vector inputs_string; + absl::c_transform(inputs, std::back_inserter(inputs_string), + [](absl::string_view sv) { return string(sv); }); + string result; + EXPECT_FALSE( + PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, &result) + .ok()); +} + +const char* kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0"; +const char* kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0"; +const char* kXPU0 = "/job:localhost/replica:0/task:0/device:XPU:0"; + +const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1"; +const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1"; +const char* kXPU1 
= "/job:localhost/replica:0/task:0/device:XPU:1"; + +TEST(PickDeviceForXla, UniqueDevice) { + CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0}); +} + +TEST(PickDeviceForXla, DeviceOrder) { + CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0}); + CheckPickDeviceResult(kXPU0, true, {kXPU0, kCPU0}); +} + +TEST(PickDeviceForXla, MultipleUnknownDevices) { + CheckPickDeviceHasError(false, {kXPU0, kXPU1}); +} + +TEST(PickDeviceForXla, GpuAndUnknown) { + CheckPickDeviceHasError(false, {kGPU0, kXPU1}); +} + +TEST(PickDeviceForXla, UnknownAndCpu) { + CheckPickDeviceHasError(false, {kXPU0, kCPU1}); +} + +TEST(PickDeviceForXla, MultipleDevicesOfSameType) { + CheckPickDeviceHasError(false, {kCPU0, kCPU1}); + CheckPickDeviceHasError(false, {kGPU0, kGPU1}); + CheckPickDeviceHasError(false, {kXPU0, kXPU1}); + CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0}); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index bff4cc57ee1f3ac0fc12aaa93b1588553aec8c45..f53a1e5d403156c5a81925878927a45b8ff51716 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -30,11 +30,11 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -70,9 +70,9 @@ string XlaCompilationCache::DebugString() const { // arguments in the supplied list. string XlaCompilationCache::Signature::HumanString() const { string result = name; - for (const auto& a : arg_types) { - absl::StrAppend(&result, ",", DataTypeString(a.first), - a.second.DebugString()); + for (const auto& a : arg_shapes) { + absl::StrAppend(&result, ",", DataTypeString(a.first)); + absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]"); } for (const auto& v : arg_values) { @@ -83,7 +83,7 @@ string XlaCompilationCache::Signature::HumanString() const { bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (name != other.name) return false; - if (arg_types != other.arg_types) return false; + if (arg_shapes != other.arg_shapes) return false; if (arg_values.size() != other.arg_values.size()) return false; for (int i = 0; i < arg_values.size(); ++i) { @@ -99,10 +99,10 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { uint64 XlaCompilationCache::Signature::Hash::operator()( const XlaCompilationCache::Signature& signature) const { uint64 h = std::hash()(signature.name); - for (const auto& arg : signature.arg_types) { + for (const auto& arg : signature.arg_shapes) { h = Hash64Combine(h, std::hash()(static_cast(arg.first))); - h = Hash64Combine(h, std::hash()(arg.second.dims())); - for (int dim : arg.second.dim_sizes()) { + h = Hash64Combine(h, std::hash()(arg.second.size())); + for (int dim : arg.second) { h = Hash64Combine(h, std::hash()(dim)); } } @@ 
-126,7 +126,7 @@ XlaCompilationCache::BuildSignature( break; case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kResource: - signature.arg_types.emplace_back(arg.type, arg.shape); + signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes()); break; default: return errors::InvalidArgument( @@ -205,6 +205,10 @@ Status XlaCompilationCache::CompileSingleOp( NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); + // Remove the "_class" attribute from the attribute set used to create the + // compilation cache key. This attribute is information for the colocator + // and causes false uniqueness between nodes. + name.mutable_attr()->erase("_class"); auto compile_op = [&](XlaCompiler* compiler, XlaCompiler::CompilationResult* result) { std::vector result_dtypes(ctx->num_outputs()); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 02aa8f8839e2c033e06d043b0f17d89a08d5d9e6..7748b4700f39da4f952278ca6c6d2cadff4d3fb8 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -95,7 +95,9 @@ class XlaCompilationCache : public ResourceBase { struct Signature { string name; - std::vector> arg_types; + // List of Tensor types & shapes for compile-time constant arguments to the + // compilation, ordered by argument number. + std::vector>> arg_shapes; // List of Tensor values for compile-time constant arguments to the // compilation, ordered by argument number. Tensors must be in host memory. diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 94dc61d55fb047c0ea81d98fde24cb55387c27d7..f6e73ab7fecebee80cfbea5f26cc12dd7fc32094 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -46,7 +46,7 @@ Status XlaCpuDeviceFactory::CreateDevices( compile_on_demand ? 
XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested : XlaOpRegistry::AutoclusteringPolicy::kAlways; - registration.compile_resource_ops = true; + registration.compile_all_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); static XlaDeviceOpRegistrations* registrations = @@ -83,9 +83,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { +constexpr std::array kAllXlaCpuTypes = { {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, - DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL}}; + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, + DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index e2397f6fcb8677f4bd5151646f9ebacd3e23af5b..0c4a1ce80b9b7e731a92bf80f22a1705a7064c68 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -51,6 +50,7 @@ limitations under the License. 
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" #include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/stream_executor_util.h" @@ -102,7 +102,8 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - absl::make_unique(); + absl::make_unique( + backend->stream_executors()[device_ordinal]); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; @@ -289,17 +290,17 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", &host_to_device_stream_, &need_new_device_context)); - TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", - &device_to_host_stream_, - &need_new_device_context)); for (std::shared_ptr& stream : device_to_device_streams_) { TF_RETURN_IF_ERROR( EnsureStreamOkLocked(backend, "device_to_device_stream", &stream, &need_new_device_context)); } host_to_device_stream = host_to_device_stream_; - device_to_host_stream = device_to_host_stream_; device_to_device_streams = device_to_device_streams_; + // The data transfer requests from device to host could arrive out of order, + // so a single stream would cause deadlock. For this case, + // xla_device_context would borrow a stream for each transfer request. + device_to_host_stream = nullptr; } else { host_to_device_stream = stream_; device_to_host_stream = stream_; @@ -428,7 +429,7 @@ void XlaDevice::Sync(const DoneCallback& done) { // moment--when ThenEnqueueOnBackgroundThread is called--will have finished. // This achieves a device-wide sync. 
stream->ThenEnqueueOnBackgroundThread( - [this, stream, done](se::StreamExecutor*) { + [stream, done](se::StreamExecutor*) { tracing::ScopedActivity activity("XlaDevice::Sync::Callback", /*is_expensive=*/true); done(stream->ok() ? Status::OK() @@ -479,7 +480,24 @@ bool XlaDevice::AllowsSyncOnCompletion() const { return sync_on_completion_; } -Status XlaDevice::CurrentStatus() { +void XlaDevice::SetHandleDeviceErrorCallback(std::function callback) { + mutex_lock lock(mu_); + device_error_callback_ = callback; +} + +Status XlaDevice::HandleDeviceError() { + std::function local_device_error_callback; + { + mutex_lock lock(mu_); + local_device_error_callback = device_error_callback_; + } + if (local_device_error_callback != nullptr) { + return local_device_error_callback(); + } + return Status::OK(); +} + +Status XlaDevice::RefreshStatus() { std::shared_ptr stream; { mutex_lock lock(mu_); @@ -488,7 +506,14 @@ Status XlaDevice::CurrentStatus() { if (!stream) { return Status::OK(); } - return stream->ok() ? Status::OK() : errors::Internal("XlaDevice is not OK."); + Status status = stream->RefreshStatus(); + if (!status.ok()) { + // Ignore errors from HandleDeviceError, since by definition the status is + // already non-ok, so there's nothing extra to report if HandleDeviceError + // itself returns an error. + HandleDeviceError().IgnoreError(); + } + return status; } XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index e35a1c7d29514dc5777bdbd3858c56401d7b9044..51910c6fabc7e3565ef89b8eb2852f3257162055 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -169,10 +169,12 @@ class XlaDevice : public LocalDevice { // Instructs this XlaDevice to return 'sync_on_completion' for // AllowsSyncOnCompletion(). 
void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); - bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); - Status CurrentStatus() override LOCKS_EXCLUDED(mu_); + // Installs an error handling callback when RefreshStatus sees !status.ok(). + void SetHandleDeviceErrorCallback(std::function callback); + + Status RefreshStatus() override LOCKS_EXCLUDED(mu_); private: xla::LocalClient* client() const; @@ -188,6 +190,9 @@ class XlaDevice : public LocalDevice { static Status GetMetadataFromDevice(DeviceBase* device, const XlaDevice::Metadata** metadata); + // Handles error when RefreshStatus sees !status.ok(). + Status HandleDeviceError(); + mutable mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; @@ -207,14 +212,12 @@ class XlaDevice : public LocalDevice { std::shared_ptr stream_ GUARDED_BY(mu_); // If false, only stream_ is valid and all computation and transfers use // stream_. If true, computation is performed by stream_ and transfers are - // performed by host_to_device/device_to_host_stream. + // performed by host_to_device/device_to_device stream or borrowing a stream + // for each device to host transfer. const bool use_multiple_streams_; // If use_multiple_streams_, host to device transfers are performed using this // stream. std::shared_ptr host_to_device_stream_ GUARDED_BY(mu_); - // If use_multiple_streams_, device to host transfers are performed using this - // stream. - std::shared_ptr device_to_host_stream_ GUARDED_BY(mu_); // If use_multiple_streams_, transfers between different devices are performed // using these streams. std::vector> device_to_device_streams_ @@ -238,6 +241,9 @@ class XlaDevice : public LocalDevice { // regardless of status. bool sync_on_completion_ GUARDED_BY(mu_) = true; + // A callback that will be invoked when RefreshStatus sees a status error. + std::function device_error_callback_ GUARDED_BY(mu_); + // Set of devices to use. 
This controls which of the devices on the given // platform will have resources allocated. For GPUs this will be // filled from visible_gpu_devices list from session configuration. diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 1f3afe8822d441a5ce37617fe18d7767e9bc72e4..b273cbb36e12b9d3d3750c5f21ff4c2ff6e7a6a5 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -25,11 +25,15 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/stream_executor/platform/port.h" namespace tensorflow { // The allocator used for Tensors assigned to the XLA device. -XlaDeviceAllocator::XlaDeviceAllocator() {} +XlaDeviceAllocator::XlaDeviceAllocator( + stream_executor::StreamExecutor* stream_executor) + : stream_executor_(stream_executor) {} + XlaDeviceAllocator::~XlaDeviceAllocator() = default; string XlaDeviceAllocator::Name() { return "xla"; } @@ -48,7 +52,21 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { delete XlaTensor::FromOpaquePointer(ptr); } -void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } +absl::optional XlaDeviceAllocator::GetStats() { + absl::optional se_stats = + stream_executor_->GetAllocatorStats(); + if (!se_stats) { + return absl::nullopt; + } + + tensorflow::AllocatorStats tf_stats; + tf_stats.num_allocs = se_stats->num_allocs; + tf_stats.bytes_in_use = se_stats->bytes_in_use; + tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use; + tf_stats.largest_alloc_size = se_stats->largest_alloc_size; + tf_stats.bytes_limit = se_stats->bytes_limit; + return tf_stats; +} XlaDeviceContext::XlaDeviceContext( std::shared_ptr compute_stream, @@ -67,7 +85,6 @@ XlaDeviceContext::XlaDeviceContext( shape_representation_fn_(std::move(shape_representation_fn)), 
thread_pool_(thread_pool) { CHECK(host_to_device_stream_ != nullptr); - CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); if (!shape_representation_fn_) { shape_representation_fn_ = [](const TensorShape& shape, @@ -131,7 +148,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, xla::ShapeUtil::MakeShape(shape.element_type(), xla::AsInt64Slice(shape.dimensions()))); - VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " " << xla_tensor->shaped_buffer().ToString(); if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow( @@ -196,8 +213,23 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, << cpu_tensor->shape().DebugString() << " " << device_tensor->shape().DebugString(); + std::shared_ptr device_to_host_stream; + if (device_to_host_stream_) { + device_to_host_stream = device_to_host_stream_; + } else { + stream_executor::port::StatusOr ptr_or_status = + client_->mutable_backend()->BorrowStream( + stream_->parent()->device_ordinal()); + if (!ptr_or_status.status().ok()) { + done(ptr_or_status.status()); + return; + } + device_to_host_stream = + std::shared_ptr(std::move(ptr_or_status.ValueOrDie())); + } + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); - xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get()); + xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream.get()); // Transfer manager requires the shape of the shaped buffer to be the same as // literal shape except for the layout. Set the literal to use xla_tensor's @@ -210,11 +242,13 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, cpu_tensor, &literal)); TensorReference ref(*device_tensor); + // Explicitly capture device_to_host_stream to make sure the stream is alive + // before the transfer finishes. 
transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_.get(), xla_tensor->shaped_buffer(), literal, - [ref, xla_tensor, done](xla::Status status) { + device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal, + [ref, xla_tensor, done, device_to_host_stream](xla::Status status) { done([&]() -> Status { - VLOG(1) << "Transfer from device as literal: " + VLOG(2) << "Transfer from device as literal: " << xla_tensor->shaped_buffer().ToString(); return status; }()); diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index e45db989fac720df6c3458c93a6b8dbb0919f930..efbc4bc148acaf4ed0bb0617084e946b6a8dbc00 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -34,14 +34,18 @@ namespace tensorflow { // empty, XlaTensor. class XlaDeviceAllocator : public Allocator { public: - XlaDeviceAllocator(); + XlaDeviceAllocator(se::StreamExecutor* stream_executor); ~XlaDeviceAllocator() override; string Name() override; void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void GetStats(AllocatorStats* stats) override; + absl::optional GetStats() override; + + private: + // The stream executor of the device. + se::StreamExecutor* stream_executor_; }; // Helper class for managing data transfers between host and XLA devices. @@ -71,9 +75,6 @@ class XlaDeviceContext : public DeviceContext { se::Stream* host_to_device_stream() const { return host_to_device_stream_.get(); } - se::Stream* device_to_host_stream() const { - return device_to_host_stream_.get(); - } se::Stream* device_to_device_stream(int index) const { return device_to_device_streams_.at(index).get(); } @@ -95,7 +96,8 @@ class XlaDeviceContext : public DeviceContext { // idential to stream_, but must not be nullptr. std::shared_ptr host_to_device_stream_; // The stream to use for transferring data from device to host. 
Can be - // idential to stream_, but must not be nullptr. + // idential to stream_. If nullptr, borrow a stream from backend for each + // transfer request to support out-of-order requests. std::shared_ptr device_to_host_stream_; // Streams to use for transferring data directly between different devices, // e.g., over NVLINK. diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 927f983ba9ef23c8509523f42366c0c89c29db9f..eac6586203eec777d2bccc4b64bc44fa1a3813a1 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -25,9 +25,11 @@ limitations under the License. #include "tensorflow/core/kernels/control_flow_ops.h" #include "tensorflow/core/kernels/data/generator_dataset_op.h" #include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" #include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/function_ops.h" +#include "tensorflow/core/kernels/host_constant_op.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" @@ -92,11 +94,22 @@ class XlaAssignVariableOp : public OpKernel { REGISTER_KERNEL_BUILDER( \ Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \ ConstantOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("HostConst").Device(DEVICE).HostMemory("output"), _HostConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", DT_STRING), \ IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Identity").Device(DEVICE).TypeConstraint("T"), \ + IdentityOp); \ + REGISTER_KERNEL_BUILDER(Name("Identity") \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input") \ + .HostMemory("output"), \ + IdentityOp); \ 
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ @@ -195,9 +208,7 @@ class XlaAssignVariableOp : public OpKernel { Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \ \ REGISTER_KERNEL_BUILDER( \ - Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \ - TYPES), \ - ArgOp); \ + Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp); \ REGISTER_KERNEL_BUILDER(Name(kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ @@ -206,11 +217,8 @@ class XlaAssignVariableOp : public OpKernel { REGISTER_KERNEL_BUILDER( \ Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \ \ - REGISTER_KERNEL_BUILDER(Name(kRetOp) \ - .Device(DEVICE) \ - .TypeConstraint("T", TYPES) \ - .HostMemory("input"), \ - RetvalOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp); \ REGISTER_KERNEL_BUILDER(Name(kRetOp) \ .Device(DEVICE) \ .TypeConstraint("T") \ @@ -241,6 +249,8 @@ class XlaAssignVariableOp : public OpKernel { data::AnonymousIteratorHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ data::IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ + data::IteratorGetNextAsOptionalOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ data::IteratorGetNextSyncOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ @@ -251,6 +261,15 @@ class XlaAssignVariableOp : public OpKernel { .Device(DEVICE) \ .HostMemory("string_handle"), \ data::IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE), \ + data::OptionalNoneOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE), \ + data::OptionalFromValueOp); \ + REGISTER_KERNEL_BUILDER( \ + 
Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"), \ + data::OptionalHasValueOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE), \ + data::OptionalGetValueOp); \ REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index bc0db558d8d0b7c666efcfac5c4926144b830380..a2a06f57698538b03ff08c99f570661b7312b0ec 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -208,7 +209,12 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, } GraphCycles cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, + CreateCycleDetectionGraph(&graph, &cycles)); + if (!cycle_detection_graph_ok) { + return Status::OK(); + } + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles)); diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index b29f6a009b9e9fdba76ac55386a4bec2f339cc0e..b37926073ac9da835dbad8911f998081d5f2c143 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -66,7 +66,7 @@ Status XlaGpuDeviceFactory::CreateDevices( registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.autoclustering_policy = 
XlaOpRegistry::AutoclusteringPolicy::kAlways; - registration.compile_resource_ops = true; + registration.compile_all_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); static XlaDeviceOpRegistrations* registrations = diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index e1a582406153d2af447fa9d4ebcaf0bf0842b132..15f5ddbd7ba845af8fe8796f69bb0db93a0ef6be 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -47,7 +47,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices( registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; registration.autoclustering_policy = XlaOpRegistry::AutoclusteringPolicy::kAlways; - registration.compile_resource_ops = true; + registration.compile_all_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, registration); diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 554227f09de0ab4d9e07f199b957657f3121ff06..c915b7118d09abe467ebf0b1d74a1efab94fd724 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -26,9 +26,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index f80cb1812f00d36ddb7c28ae0e77c58498058ef3..0c3adb0bcf97a30d3d195546f5958e5b0527fcf2 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -72,7 +72,7 @@ py_test( tf_xla_py_test( name = "adadelta_test", - size = "large", + size = "medium", srcs = ["adadelta_test.py"], deps = [ ":xla_test", @@ -138,6 +138,22 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "add_n_test", + size = "small", + srcs = ["add_n_test.py"], + # TensorList ops are not implemented in the on-demand compilation model yet. + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:list_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "addsign_test", size = "small", @@ -170,13 +186,6 @@ tf_xla_py_test( name = "argminmax_test", size = "small", srcs = ["argminmax_test.py"], - # ArgMax needs CustomCall on CPU, which is not available in normal - # (not precompiled) TensorFlow. The flag below excludes the CPU - # backend. 
- disabled_backends = [ - "cpu", - "cpu_ondemand", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -243,9 +252,44 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python:map_fn", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "self_adjoint_eig_op_test", + size = "medium", + srcs = ["self_adjoint_eig_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:map_fn", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", + ], +) + +tf_xla_py_test( + name = "svd_op_test", + size = "medium", + srcs = ["svd_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:map_fn", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", ], ) @@ -278,10 +322,9 @@ tf_xla_py_test( ], ) -# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors tf_xla_py_test( name = "concat_ops_test", - size = "large", + size = "medium", srcs = ["concat_ops_test.py"], deps = [ ":xla_test", @@ -407,7 +450,7 @@ tf_xla_py_test( tf_xla_py_test( name = "eager_test", - size = "large", + size = "medium", srcs = ["eager_test.py"], deps = [ ":xla_test", @@ -439,7 +482,7 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], - shard_count = 3, + shard_count = 6, tags = ["optonly"], deps = [ ":xla_test", @@ -819,6 +862,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "stateful_random_ops_test", + size = "small", + srcs = ["stateful_random_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + 
"//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:standard_ops", + "//tensorflow/python:stateful_random_ops", + ], +) + tf_xla_py_test( name = "stateless_random_ops_test", size = "small", @@ -835,7 +892,7 @@ tf_xla_py_test( tf_xla_py_test( name = "tensor_array_ops_test", - size = "small", + size = "medium", srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], @@ -1060,6 +1117,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", ], + shard_count = 5, ) cuda_py_test( @@ -1094,7 +1152,6 @@ cc_library( "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:test", "//tensorflow/core:testlib", - "//tensorflow/core/kernels:ops_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", @@ -1193,10 +1250,6 @@ tf_xla_py_test( name = "quantized_ops_test", size = "medium", srcs = ["quantized_ops_test.py"], - disabled_backends = [ - "cpu", - "cpu_ondemand", - ], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py index b7b7fda293b69d6f0cec61d0d234277636a3670d..6cf16cc07ff503c4f3e008cfb720224abe5e9166 100644 --- a/tensorflow/compiler/tests/adadelta_test.py +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -32,10 +32,18 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase): def testBasic(self): num_updates = 4 # number of ADADELTA steps to perform + if "CPU" in self.device: + # To avoid timeout on CPU. 
+ all_grad = [0.2, 0.01] + all_lr = [1.0, 0.1] + else: + all_grad = [0.2, 0.1, 0.01] + all_lr = [1.0, 0.5, 0.1] + for dtype in self.float_types: with self.cached_session(), self.test_scope(): - for grad in [0.2, 0.1, 0.01]: - for lr in [1.0, 0.5, 0.1]: + for grad in all_grad: + for lr in all_lr: var0_init = [1.0, 2.0] var1_init = [3.0, 4.0] var0 = resource_variable_ops.ResourceVariable( diff --git a/tensorflow/compiler/tests/add_n_test.py b/tensorflow/compiler/tests/add_n_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ba11c6d322f126130410a82ef760ce40e2c4e430 --- /dev/null +++ b/tensorflow/compiler/tests/add_n_test.py @@ -0,0 +1,84 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for AddN.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class XlaAddNTest(xla_test.XLATestCase): + + def testAddTensorLists(self): + with self.cached_session(), self.test_scope(): + l1 = list_ops.tensor_list_reserve( + element_shape=[], element_dtype=dtypes.float32, num_elements=3) + l2 = list_ops.tensor_list_reserve( + element_shape=[], element_dtype=dtypes.float32, num_elements=3) + l1 = list_ops.tensor_list_set_item(l1, 0, 5.) + l2 = list_ops.tensor_list_set_item(l2, 2, 10.) + + l = math_ops.add_n([l1, l2]) + self.assertAllEqual( + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32), + [5.0, 0.0, 10.0]) + + def testAddTensorListsFailsIfLeadingDimsMismatch(self): + with self.cached_session(), self.test_scope(): + l1 = list_ops.tensor_list_reserve( + element_shape=[], element_dtype=dtypes.float32, num_elements=2) + l2 = list_ops.tensor_list_reserve( + element_shape=[], element_dtype=dtypes.float32, num_elements=3) + l = math_ops.add_n([l1, l2]) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "TensorList arguments to AddN must all have the same shape"): + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32).eval() + + def testAddTensorListsFailsIfElementShapesMismatch(self): + with self.cached_session() as session, self.test_scope(): + # Use placeholders instead of constant values for shapes to prevent TF's + # shape inference from catching this early. 
+ l1_element_shape = array_ops.placeholder(dtype=dtypes.int32) + l2_element_shape = array_ops.placeholder(dtype=dtypes.int32) + l1 = list_ops.tensor_list_reserve( + element_shape=l1_element_shape, + element_dtype=dtypes.float32, + num_elements=3) + l2 = list_ops.tensor_list_reserve( + element_shape=l2_element_shape, + element_dtype=dtypes.float32, + num_elements=3) + l = math_ops.add_n([l1, l2]) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "TensorList arguments to AddN must all have the same shape"): + session.run( + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32), { + l1_element_shape: [], + l2_element_shape: [2] + }) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index a3651b4b0de2bd34f57464c9552a8fd577866510..c829c50b5518b29c96c0b0117a6cd143911bd1fc 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -311,6 +311,30 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) + if dtype in [np.float32, np.float64]: + x = np.array([ + -0.0, 0.0, -0.0, +0.0, np.inf, np.inf, -np.inf, -np.inf, 2.0, 2.0, + 1.0 + ], + dtype=dtype) + y = np.array( + [-0.0, 0.0, +0.0, -0.0, 1.0, -1.0, 1.0, -1.0, 2.0, 1.0, 2.0], + dtype=dtype) + expected = np.nextafter(x, y) + + # We use assertAllEqual to expose any bugs hidden by relative or + # absolute error tolerances. 
+ def NextAfterEqualityTest(result, expected, rtol): + del rtol + return self.assertAllEqual(result, expected) + + self._testBinary( + math_ops.nextafter, + x, + y, + expected=expected, + equality_test=NextAfterEqualityTest) + # min/max not supported for complex if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 5d5e486f616937601214aa169a4c329ab78932c8..eec69ea7d2d9af9ff570f927fb25b668ccce2b97 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -119,7 +119,7 @@ class CategoricalTest(xla_test.XLATestCase): def testSamplingCorrectness(self): np.random.seed(1618) # Make it reproducible. - num_samples = 21000 + num_samples = 40000 rand_probs = np.random.dirichlet([1., 1., 2., 3.]) rand_probs2 = np.random.dirichlet([1., 4., 5.], size=3) # batched diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 2187f57960f80300d631bdc7eb8fe5e9c8dddeea..76750decd2963ea12680a46d7340f48e8b011fa9 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -294,6 +294,9 @@ class ConcatTest(xla_test.XLATestCase): # The purpose of this is to ensure that XLA on GPU will not run out of memory # with too many arguments. 
def testConcatLargeNumberOfTensors(self): + if "CPU" in self.device: + self.skipTest("This test can time out on CPU, so we will just allow " + "other backends to catch this specific error.") with self.cached_session(): with self.test_scope(): for concat_dim in range(2): diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index bf5ea7b1fb6fb3c774c4db20d059f131990d20d3..b7d08df9f7d144b71fd0b09535e10b8f596ea6ca 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase): x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index c9fce39f6c5111f93a54708b59b4c42c3ba844b6..632eccbb097b4e84f10f926e89d7fa439c8a38cd 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -34,6 +34,7 @@ from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import 
functional_ops from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -607,6 +608,21 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertEqual(11.0, plus_one.numpy()) self.assertEqual(9.0, minus_one.numpy()) + def testScanInDefun(self): + with self.test_scope(): + elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='data') + v = constant_op.constant(2.0, name='v') + + @def_function.function + def f(y): + # pylint: disable=unnecessary-lambda + return functional_ops.scan( + lambda a, x: math_ops.multiply(a, x), y, initializer=v) + # pylint: enable=unnecessary-lambda + + r = f(elems) + self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 374942a0b339b816944ea5529e4f84134b60017b..56a8e1b1667f154f6cec475ee0f4f8b308121c09 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -191,6 +191,20 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 + + # The TensorFlow FusedBatchNormGrad training operation takes two inputs with + # implementation defined values. In theory the only correct values for + # these inputs are the corresponding reserve_space_{1|2} outputs from the + # FusedBatchNorm training operation. However, in practice, we rely on the + # first one being mean on {C|G}PU, and the second one being variance on CPU + # and inverse(sqrt(variance + epsilon)) on GPU (we test this assumption + # separately).
+ reserve_space_1_val = mean_val + if self.device == "XLA_GPU": + reserve_space_2_val = np.reciprocal(np.sqrt(var_val + epsilon)) + else: + reserve_space_2_val = var_val + data_format_src = "NHWC" grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) @@ -207,18 +221,26 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): np.float32, shape=x_val_converted.shape, name="grad") x = array_ops.placeholder( np.float32, shape=x_val_converted.shape, name="x") - mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") - var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") + reserve_space_1 = array_ops.placeholder( + np.float32, shape=scale_shape, name="reserve_space_1") + reserve_space_2 = array_ops.placeholder( + np.float32, shape=scale_shape, name="reserve_space_2") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format=data_format, is_training=True) + grad, + x, + scale, + reserve_space_1, + reserve_space_2, + data_format=data_format, + is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { grad: grad_val_converted, x: x_val_converted, - mean: mean_val, - var: var_val, + reserve_space_1: reserve_space_1_val, + reserve_space_2: reserve_space_2_val, scale: scale_val }) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 12741c4d4af86a8c669afc01e92ca5a761c692ab..c8b71f802c41ef8143f3e7b951b3e0134dcda97b 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -68,8 +68,8 @@ class RGBToHSVTest(xla_test.XLATestCase): {batch0: inp}) # Verify that processing batch elements together is the same as separate - self.assertAllClose(batch1, join1) - 
self.assertAllClose(batch2, join2) + self.assertAllCloseAccordingToType(batch1, join1, half_rtol=0.000002) + self.assertAllCloseAccordingToType(batch2, join2, half_rtol=0.000002) self.assertAllCloseAccordingToType( batch2, inp, bfloat16_atol=0.03, half_rtol=0.02) @@ -423,7 +423,7 @@ class ResizeNearestNeighborTest(xla_test.XLATestCase): out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) if large_tolerance: self.assertAllClose( - expected[np.newaxis, :, :, np.newaxis], out, rtol=0.03, atol=0.1) + expected[np.newaxis, :, :, np.newaxis], out, rtol=2e-4, atol=2e-4) else: self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) @@ -452,7 +452,7 @@ class ResizeNearestNeighborTest(xla_test.XLATestCase): np.array([[1, 2], [3, 4]], dtype=np.float32), [4, 4], expected=np.array( [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]], - dtype=np.float32)) + dtype=np.float32), large_tolerance=True) def testAlignCorners3x3To2x2(self): self._assertForwardOpMatchesExpected( diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index dbea9849e217519874352b789588a2af62f1c826..777a15629804207a8873e3e16b370c8b65056e7b 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -513,9 +513,10 @@ class ElementWiseFusionTest(test.TestCase): def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = ( - "--tf_xla_fusion_only=true " - "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) + old_tf_xla_flags = os.environ.get("TF_XLA_FLAGS", "") + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " + "--tf_xla_min_cluster_size=2 " + "--tf_xla_cpu_global_jit " + old_tf_xla_flags) tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) @@ -525,6 +526,7 @@ class ElementWiseFusionTest(test.TestCase): self.assertEqual(2, 
tfef_count) self.assertAllClose(tf_op, tfef_op, rtol=1e-1) + os.environ["TF_XLA_FLAGS"] = old_tf_xla_flags class LazyCompilationTest(test.TestCase): diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index c61965b97fc142ce452cf28def8c937f692d2f84..0eec070a906670ff36c772edda22f8291b5b734a 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -167,6 +167,11 @@ class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase): }, ) def testMatrixBandPart(self, batch_shape, rows, cols): + # TODO(b/125505881): Disabled due to LLVM backend crash. + if self.device == 'XLA_CPU' and cols == 7 and rows == 1 and batch_shape == [ + 1, 3, 2 + ]: + pass for dtype in self.float_types: with self.cached_session(): mat = np.ones(batch_shape + [rows, cols]).astype(dtype) diff --git a/tensorflow/compiler/tests/plugin.bzl b/tensorflow/compiler/tests/plugin.bzl index fbc8781a3e59faecf985cde5114bf56a041c4be0..46a854d1459b7ea9d9fe3cf7689faee557c2cf84 100644 --- a/tensorflow/compiler/tests/plugin.bzl +++ b/tensorflow/compiler/tests/plugin.bzl @@ -18,13 +18,12 @@ # git update-index --assume-unchanged tensorflow/compiler/tests/plugin.bzl plugins = { - #"example": { - # "device":"XLA_MY_DEVICE", - # "types":"DT_FLOAT,DT_HALF,DT_INT32", - # "tags":[], - # "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"], - # "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"], - # "deps":[], - #}, + #"example": { + # "device":"XLA_MY_DEVICE", + # "types":"DT_FLOAT,DT_HALF,DT_INT32", + # "tags":[], + # "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"], + # "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"], + # "deps":[], + #}, } - diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 
97ffad34c00b8ec16eb1ec109ba5d980e0ce673d..0611d6749fad053657c62368bc230cfe7c929e66 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -36,7 +36,7 @@ class RandomOpsTest(xla_test.XLATestCase): def _random_types(self): return set(self.numeric_types) - set( - self.complex_types) - {np.uint8, np.int8} + self.complex_types) - {np.uint64, np.int64, np.uint8, np.int8} def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. @@ -122,8 +122,8 @@ class RandomOpsTest(xla_test.XLATestCase): beta = (b - mu) / sigma z = normal_cdf(beta) - normal_cdf(alpha) - self.assertTrue((y >= a).sum() == count) - self.assertTrue((y <= b).sum() == count) + self.assertEqual((y >= a).sum(), count) + self.assertEqual((y <= b).sum(), count) # For more information on these calculations, see: # Burkardt, John. "The Truncated Normal Distribution". diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 1521cc760b85b176acb27c1489640e92ef90e247..7623ab79371a9cd6b3512cda392e0e401242a7a1 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -62,7 +62,6 @@ limitations under the License. 
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 693f8513bc54e30060a2e963abd504768535a50a..a9a87b8fb3104f8b9870c41e2aa28b0c48c12921 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -134,6 +134,12 @@ class ScatterNdTest(xla_test.XLATestCase): expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.int32) self.assertAllEqual(expected, self._runScatterNd(indices, updates, [8])) + def testRepeatedIndices(self): + indices = np.array([[0], [1], [0], [1]], dtype=np.int32) + updates = np.array([9, 10, 11, 12], dtype=np.float32) + expected = np.array([20, 22], dtype=np.int32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2])) + def testSimple2(self): indices = np.array([[1, 0], [1, 1]], dtype=np.int32) updates = np.array([11., 12.], dtype=np.float32) diff --git a/tensorflow/compiler/tests/self_adjoint_eig_op_test.py b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb5c82b22ea1d7400b54045edee0ca0782ce979 --- /dev/null +++ b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py @@ -0,0 +1,62 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.self_adjoint_eig.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.platform import test + + +class SelfAdjointEigOpTest(xla_test.XLATestCase, parameterized.TestCase): + + def _test(self, dtype, shape): + np.random.seed(1) + x_np = np.random.uniform( + low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) + x_np = x_np + np.swapaxes(x_np, -1, -2) + n = shape[-1] + + e_np, _ = np.linalg.eigh(x_np) + with self.cached_session() as sess: + x_tf = array_ops.placeholder(dtype) + with self.test_scope(): + e, v = linalg_ops.self_adjoint_eig(x_tf) + e_val, v_val = sess.run([e, v], feed_dict={x_tf: x_np}) + + v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n) + self.assertAlmostEqual(np.mean(v_diff**2), 0.0, delta=1e-6) + self.assertAlmostEqual(np.mean((e_val - e_np)**2), 0.0, delta=1e-6) + + SIZES = [1, 2, 5, 10, 32] + DTYPES = [np.float32] + PARAMS = itertools.product(SIZES, DTYPES) + + @parameterized.parameters(*PARAMS) + def testSelfAdjointEig(self, n, dtype): + for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10): + self._test(dtype, batch_dims + (n, n)) + + +if __name__ == "__main__": + test.main() diff --git 
a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f0535579bf29641d34f818c7fe079a9c2d59073c --- /dev/null +++ b/tensorflow/compiler/tests/stateful_random_ops_test.py @@ -0,0 +1,282 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for stateful random-number generation ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.client import device_lib +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gen_stateful_random_ops +from tensorflow.python.ops import stateful_random_ops as \ +random +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def xla_device_name(): + devices = device_lib.list_local_devices() + def find_type(device_type): + for d in devices: + if d.device_type == device_type: + return d.name + return None + name = find_type("TPU") or 
find_type("XLA_GPU") or find_type("XLA_CPU") + if name is None: + raise ValueError( + "Can't find any XLA device. Available devices:\n%s" % devices) + return str(name) + + +class StatefulRandomOpsTest(xla_test.XLATestCase): + """Test cases for stateful random-number generator operators.""" + + @test_util.run_v2_only + def testSimple(self): + """A simple test. + """ + with ops.device(xla_device_name()): + gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY) + gen.normal(shape=(3,)) + gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32) + gen.uniform_full_int(shape=(3,)) + + @test_util.run_v2_only + def testDefun(self): + """Test for defun. + """ + with ops.device(xla_device_name()): + gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY) + @def_function.function + def f(): + x = gen.normal(shape=(3,)) + y = gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32) + z = gen.uniform_full_int(shape=(3,)) + return (x, y, z) + f() + + @test_util.run_v2_only + def testThreefry2x32(self): + """Tests ThreeFry2x32 conforms to known results. 
+ """ + # Based on + # https://github.com/google/jax/blob/8565a3486adf16beb388b2364c9cd930d7a0d92d/tests/random_test.py#L65-L85 + # which is in turn based on + # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32 + + def uint32s_to_uint64(a, b): + return b << 32 | a + + def verify(counter1, counter2, key1, key2, expect1, expect2): + counter = uint32s_to_uint64(counter1, counter2) + key = uint32s_to_uint64(key1, key2) + random.get_global_generator().reset([counter, key]) + got = random.get_global_generator().uniform_full_int( + shape=(2,), dtype=dtypes.uint32) + expect = [expect1, expect2] + self.assertAllEqual(expect, got) + random.get_global_generator().reset([counter, key]) + got = random.get_global_generator().uniform_full_int( + shape=(), dtype=dtypes.uint64) + self.assertAllEqual(uint32s_to_uint64(*expect), got) + + with ops.device(xla_device_name()): + random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY) + verify(0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x6b200159, 0x99ba4efe) + verify(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0x1cb996fc, 0xbb002be7) + verify(0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344, + 0xc4923a9c, 0x483df7a0) + + @test_util.run_v2_only + def testNewState(self): + """Tests that the new state is correct. + """ + with ops.device(xla_device_name()): + counter = 57 + key = 0x1234 + size = 46 + seed = [counter, key] + gen = random.Generator( + seed=seed, algorithm=random.RNG_ALG_THREEFRY) + gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32) + self.assertAllEqual([counter+(size+1)//2, key], gen.state.read_value()) + gen.reset(seed=seed) + gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64) + self.assertAllEqual([counter+size, key], gen.state.read_value()) + + def _testRngIsNotConstant(self, rng, dtype): + # Tests that 'rng' does not always return the same value. 
+ # The random-number generator, if working correctly, should produce the + # same output multiple times with low probability. + x = rng(dtype).numpy() + y = rng(dtype).numpy() + self.assertFalse(np.array_equal(x, y)) + + @test_util.run_v2_only + def testUniformIsNotConstant(self): + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + def rng(dtype): + maxval = dtype.max + # Workaround for b/125364959 + if dtype == dtypes.uint64: + maxval = 10000000 + return gen.uniform(shape=[2], dtype=dtype, maxval=maxval) + + for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + self._testRngIsNotConstant(rng, dtype) + + @test_util.run_v2_only + def testNormalIsNotConstant(self): + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + def rng(dtype): + return gen.normal(shape=[2], dtype=dtype) + + for dtype in {dtypes.float32}: + self._testRngIsNotConstant(rng, dtype) + + @test_util.run_v2_only + def testUniformIntIsInRange(self): + minval = 2 + maxval = 33 + size = 1000 + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + x = gen.uniform( + shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy() + self.assertTrue(np.all(x >= minval)) + self.assertTrue(np.all(x < maxval)) + + @test_util.run_v2_only + def testNormalIsFinite(self): + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + for dtype in {dtypes.float32}: + x = gen.normal(shape=[10000], dtype=dtype).numpy() + self.assertTrue(np.all(np.isfinite(x))) + + def _chi_squared(self, x, bins): + """Pearson's Chi-squared test.""" + x = np.ravel(x) + n = len(x) + histogram, _ = np.histogram(x, bins=bins, range=(0, 1)) + expected = n / float(bins) + return np.sum(np.square(histogram - expected) / expected) 
+ + @test_util.run_v2_only + def testDistributionOfUniform(self): + """Use Pearson's Chi-squared test to test for uniformity.""" + with ops.device(xla_device_name()): + n = 1000 + seed = 12 + for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_THREEFRY) + maxval = 1 + if dtype.is_integer: + maxval = 100 + x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy() + if maxval > 1: + # Normalize x to range [0, 1). + x = x.astype(float) / maxval + # Tests that the values are distributed amongst 10 bins with equal + # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with + # p=0.05. This test is probabilistic and would be flaky if the random + # seed were not fixed. + val = self._chi_squared(x, 10) + self.assertLess(val, 16.92) + + def _normal_cdf(self, x): + """Cumulative distribution function for a standard normal distribution.""" + return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) + + def _anderson_darling(self, x): + """Anderson-Darling test for a standard normal distribution.""" + x = np.sort(np.ravel(x)) + n = len(x) + i = np.linspace(1, n, n) + z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) + + (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x))) + return -n - z / n + + @test_util.run_v2_only + def testDistributionOfNormal(self): + """Use Anderson-Darling test to test distribution appears normal.""" + with ops.device(xla_device_name()): + n = 1000 + for dtype in {dtypes.float32}: + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + x = gen.normal(shape=[n], dtype=dtype).numpy() + # The constant 2.492 is the 5% critical value for the Anderson-Darling + # test where the mean and variance are known. This test is probabilistic + # so to avoid flakiness the seed is fixed. + self.assertLess(self._anderson_darling(x.astype(float)), 2.492) + + @test_util.run_v2_only + def testErrors(self): + """Tests that proper errors are raised.
+ """ + shape = [2, 3] + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + r"algorithm must be of shape \[\], not"): + gen_stateful_random_ops.stateful_standard_normal_v2( + gen.state.handle, [0, 0], shape) + with self.assertRaisesWithPredicateMatch( + TypeError, "Requested dtype: int64"): + gen_stateful_random_ops.stateful_standard_normal_v2( + gen.state.handle, 1.1, shape) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + "Unsupported algorithm id"): + gen_stateful_random_ops.stateful_standard_normal_v2( + gen.state.handle, 123, shape) + var = variables.Variable([0, 0], dtype=dtypes.uint32) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + "Type mismatch for read of variable .* Expected int64; got"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_THREEFRY, shape) + var = variables.Variable([[0]], dtype=dtypes.int64) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + "RNG state must have one and only one dimension, not"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_THREEFRY, shape) + var = variables.Variable([0], dtype=dtypes.int64) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + "For the ThreeFry algorithm, the size of state must be at least"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_THREEFRY, shape) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index ee7ca7e6f196e114ff18e2597145e5c198980b08..df5914a518e06e4190c623a14287de8daefebd40 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py 
@@ -167,8 +167,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): beta = (b - mu) / sigma z = normal_cdf(beta) - normal_cdf(alpha) - self.assertTrue((y >= a).sum() == n) - self.assertTrue((y <= b).sum() == n) + self.assertEqual((y >= a).sum(), n) + self.assertEqual((y <= b).sum(), n) # For more information on these calculations, see: # Burkardt, John. "The Truncated Normal Distribution". diff --git a/tensorflow/compiler/tests/svd_op_test.py b/tensorflow/compiler/tests/svd_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..73bec949f3d4a08076853b537d610c156897757c --- /dev/null +++ b/tensorflow/compiler/tests/svd_op_test.py @@ -0,0 +1,81 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.svd.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.platform import test + + +class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase): + + def _compute_usvt(self, s, u, v): + m = u.shape[-1] + n = v.shape[-1] + if m <= n: + v = v[..., :m] + else: + u = u[..., :n] + + return np.matmul(u * s[..., None, :], np.swapaxes(v, -1, -2)) + + def _testSvdCorrectness(self, dtype, shape): + np.random.seed(1) + x_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype) + m, n = shape[-2], shape[-1] + _, s_np, _ = np.linalg.svd(x_np) + with self.cached_session() as sess: + x_tf = array_ops.placeholder(dtype) + with self.test_scope(): + s, u, v = linalg_ops.svd(x_tf, full_matrices=True) + s_val, u_val, v_val = sess.run([s, u, v], feed_dict={x_tf: x_np}) + u_diff = np.matmul(u_val, np.swapaxes(u_val, -1, -2)) - np.eye(m) + v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n) + # Check u_val and v_val are orthogonal matrices. + self.assertLess(np.linalg.norm(u_diff), 1e-2) + self.assertLess(np.linalg.norm(v_diff), 1e-2) + # Check that the singular values are correct, i.e., close to the ones from + # numpy.linalg.svd. + self.assertLess(np.linalg.norm(s_val - s_np), 1e-2) + # The tolerance is set based on our tests on numpy's svd. As our tests + # have batch dimensions and all our operations are on float32, we set the + # tolerance a bit larger. Numpy's svd calls LAPACK's svd, which operates + # on double precision.
+ self.assertLess( + np.linalg.norm(self._compute_usvt(s_val, u_val, v_val) - x_np), 2e-2) + + SIZES = [1, 2, 5, 10, 32, 64] + DTYPES = [np.float32] + PARAMS = itertools.product(SIZES, DTYPES) + + @parameterized.parameters(*PARAMS) + def testSvd(self, n, dtype): + for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10): + self._testSvdCorrectness(dtype, batch_dims + (n, n)) + self._testSvdCorrectness(dtype, batch_dims + (2 * n, n)) + self._testSvdCorrectness(dtype, batch_dims + (n, 2 * n)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index d7e26d79c4c054860ade5c8960a3bca984e020b0..e64aa26cd4bb5f9130def4e0f3a1799db9f2428e 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.compiler.tests import xla_test @@ -25,7 +26,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops @@ -41,8 +44,10 @@ def _make_converter(dtype): return _converter +@test_util.with_control_flow_v2 class TensorArrayTest(xla_test.XLATestCase): + @test_util.disable_control_flow_v2("Tries to evaluate flow") def testTensorArrayWriteRead(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -117,6 +122,7 @@ class TensorArrayTest(xla_test.XLATestCase): 
convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), self.evaluate(c0)) + @test_util.disable_control_flow_v2("b/122315751 (concat)") def testTensorArrayWriteConcat(self): for dtype in self.numeric_tf_types: self._testTensorArrayWriteConcat(dtype) @@ -224,10 +230,12 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(convert([[2.0, 201.0]]), d1) self.assertAllEqual(convert([[3.0, 301.0]]), d2) + @test_util.disable_control_flow_v2("b/122315872 (split)") def testTensorArraySplitRead(self): for dtype in self.numeric_tf_types: self._testTensorArraySplitRead(dtype) + @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") def testTensorGradArrayWriteRead(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -261,6 +269,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[2.0]], g_d1) self.assertAllEqual([[-2.0]], g_d2) + @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") def testTensorGradArrayDynamicWriteRead(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -300,6 +309,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(3, vs) self.assertAllEqual(3, g_vs) + @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") def testTensorGradAccessTwiceReceiveSameObject(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -317,6 +327,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(t_g_ta_0, t_g_ta_1) self.assertAllEqual([[4.0, 5.0]], d_r1_0) + @test_util.disable_control_flow_v2("b/124334470") def testTensorArrayWriteWrongIndexOrDataTypeFails(self): with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -327,6 +338,7 @@ class TensorArrayTest(xla_test.XLATestCase): "TensorArray dtype is float but op has dtype int32"): 
ta.write(-1, np.int32(7)).flow.eval() + @test_util.disable_control_flow_v2("b/124334096 verify dtype") def testTensorArrayReadWrongIndexOrDataTypeFails(self): # Find two different floating point types, create an array of # the first type, but try to read the other type. @@ -347,6 +359,7 @@ class TensorArrayTest(xla_test.XLATestCase): # Test reading from a different index than the one we wrote to w0.read(1) + @test_util.disable_control_flow_v2("b/122315872 (split)") def testTensorArraySplitIncompatibleShapesFails(self): with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -406,6 +419,7 @@ class TensorArrayTest(xla_test.XLATestCase): r"Mismatched TensorArray sizes"): wb1_grad.flow.eval() + @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") def testTensorArrayWriteGradientAddMultipleAdds(self): for dtype in self.numeric_tf_types: self._testTensorArrayWriteGradientAddMultipleAdds(dtype) @@ -510,6 +524,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0]) self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1]) + @test_util.disable_control_flow_v2("b/122315751 (concat)") def testTensorArrayGradientWritePackConcatAndRead(self): self._testTensorArrayGradientWritePackConcatAndRead() @@ -557,6 +572,7 @@ class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayGradientUnpackRead(self): self._testTensorArrayGradientUnpackRead() + @test_util.disable_control_flow_v2("b/122315751(concat), b/122315872(split)") def testTensorArrayGradientSplitConcat(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -903,6 +919,7 @@ class TensorArrayTest(xla_test.XLATestCase): "zero-size TensorArrays."): ta.stack().eval() + @test_util.disable_control_flow_v2("b/124335246") def testTensorArrayEvalEmpty(self): self._testTensorArrayEvalEmpty() @@ -916,7 +933,9 @@ class TensorArrayTest(xla_test.XLATestCase): 
self.assertAllEqual([0, 3, 5], self.evaluate(packed).shape) # Concatenating zero tensors along their first dimension gives a # first dimension of zero - self.assertAllEqual([0, 5], ta.concat().eval().shape) + if not control_flow_util.ENABLE_CONTROL_FLOW_V2: + # TODO(b/122315751): Enable this. + self.assertAllEqual([0, 5], ta.concat().eval().shape) def testTensorArrayEvalEmptyWithDefault(self): self._testTensorArrayEvalEmptyWithDefault() @@ -944,11 +963,13 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(convert([1.0, -1.0]), read_vals[0]) self.assertAllEqual(convert([10.0, -10.0]), read_vals[1]) + @test_util.disable_control_flow_v2("b/122315734 (scatter)") def testTensorArrayScatterRead(self): for dtype in self.numeric_tf_types: self._testTensorArrayScatterRead(dtype) self._testTensorArrayScatterRead(dtypes.bool) + @test_util.disable_control_flow_v2("b/122315734 (scatter)") def testTensorArrayScatterReadAndGradients(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -977,6 +998,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([10.0, -10.0], read_vals[1]) self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) + @test_util.disable_control_flow_v2("b/122315378 (gather)") def testTensorArrayWriteGatherAndGradients(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -1052,4 +1074,6 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertEqual(size1_v, 4) if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 1ecdb22cd0bc7e42d7ff67d20544fd26a65f6204..e07b150d6013582a9a9d3bb14e82cf9a4e962bc1 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -18,29 
+18,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -def scalar_shape(): - return ops.convert_to_tensor([], dtype=dtypes.int32) - - class ListOpsTest(xla_test.XLATestCase): def testElementShape(self): with self.cached_session() as sess, self.test_scope(): dim = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=(dim, 15), num_elements=20, - element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(dim, 15), + element_dtype=dtypes.float32, + max_num_elements=20) e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32) e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64) self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15)) @@ -48,8 +45,10 @@ class ListOpsTest(xla_test.XLATestCase): def testPushPop(self): with self.cached_session() as sess, self.test_scope(): - l = list_ops.tensor_list_reserve( - element_shape=(7, 15), num_elements=10, element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(7, 15), + element_dtype=dtypes.float32, + max_num_elements=10) l = list_ops.tensor_list_push_back( l, constant_op.constant(1.0, shape=(7, 15))) l = list_ops.tensor_list_push_back( @@ -62,8 +61,10 @@ class ListOpsTest(xla_test.XLATestCase): def testDoNotConstantFoldVariants(self): with self.cached_session() as sess, self.test_scope(): val = array_ops.placeholder(dtype=dtypes.float32) - l = list_ops.tensor_list_reserve( - element_shape=(7, 15), num_elements=10, element_dtype=dtypes.float32) + l = 
list_ops.empty_tensor_list( + element_shape=(7, 15), + element_dtype=dtypes.float32, + max_num_elements=10) # Note: Pushing a Placeholder will force the constant folding code # to build a Const node with a DT_VARIANT output. This tests that XLA # passes a cf_consider_fn which prevent folding such nodes. @@ -78,10 +79,10 @@ class ListOpsTest(xla_test.XLATestCase): def testPushPopSeparateLists(self): with self.cached_session() as sess, self.test_scope(): - l = list_ops.tensor_list_reserve( - element_shape=scalar_shape(), - num_elements=20, - element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=[], + element_dtype=dtypes.float32, + max_num_elements=20) l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0)) @@ -102,7 +103,7 @@ class ListOpsTest(xla_test.XLATestCase): _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Set the max number of elements"): - self.assertEqual(sess.run(e), 1.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15))) def testEmptyTensorListMax(self): with self.cached_session() as sess, self.test_scope(): @@ -114,6 +115,103 @@ class ListOpsTest(xla_test.XLATestCase): _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15))) + def testListFromTensor(self): + with self.cached_session(), self.test_scope(): + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e, 1.0) + l, e0 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 2.0) + l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(e1, 1.0) + 
self.assertAllEqual(list_ops.tensor_list_length(l), 0) + + def testGetSet(self): + with self.cached_session(), self.test_scope(): + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 1.0) + l = list_ops.tensor_list_set_item(l, 0, 3.0) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [3.0, 2.0]) + + def testSetDoesNotUpdatePushIndex(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_shape=[], element_dtype=dtypes.float32, max_num_elements=2) + # SetItem should not change the push index. + l = list_ops.tensor_list_set_item(l, 1, 3.) + l = list_ops.tensor_list_push_back(l, 5.) + l = list_ops.tensor_list_push_back(l, 7.) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [5., 7.]) + + def testGetSetReserved(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=[], num_elements=2) + e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 0.0) + l = list_ops.tensor_list_set_item(l, 0, 3.0) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [3.0, 0.0]) + + def testSetStackReservedUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=None, num_elements=2) + l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]]) + + def testPushInEmptyListWithUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, element_shape=None, max_num_elements=2) + l = 
list_ops.tensor_list_push_back(l, [3.0, 4.0]) + # Pushing an element with a different shape should raise an error. + with self.assertRaisesRegexp(errors.InternalError, "shape"): + l = list_ops.tensor_list_push_back(l, 5.) + self.evaluate( + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)) + + def testGetSetReservedNonScalar(self): + with self.cached_session() as sess, self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, + element_shape=(7, 15), + num_elements=2) + l = list_ops.tensor_list_set_item( + l, 0, constant_op.constant(1.0, shape=(7, 15))) + e1 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + e2 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e1), np.ones((7, 15))) + self.assertAllEqual(sess.run(e2), np.zeros((7, 15))) + + def testStack(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=[], + max_num_elements=2) + l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) + e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e, 1.0) + l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t.shape.as_list(), [None]) + self.assertAllEqual(t, [1.0, 2.0]) + + def testStackWithUninitializedTensors(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=[], num_elements=3) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [0., 0., 0.]) if __name__ == "__main__": + os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' + + os.environ.get('TF_XLA_FLAGS', '')) test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 
3c2875ba477fa71e9e56a18d10efe0808533dd03..159fa6685b5c333c4669f1f141a4d41f267255f4 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -72,8 +72,9 @@ class UnaryOpsTest(xla_test.XLATestCase): output = op(pinp) result = session.run(output, {pinp: inp}) if equality_test is None: + self.assertEqual(output.dtype, expected.dtype) self.assertAllCloseAccordingToType( - result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) else: equality_test(result, expected, rtol=rtol, atol=atol) @@ -260,7 +261,8 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), - expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)), + expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], + dtype=dtype)).astype(dtype), rtol=1e-4, atol=1e-6) @@ -391,6 +393,11 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.sign, + np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0]], dtype=dtype), + expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0]], dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.is_finite, np.array( @@ -705,7 +712,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), - expected=np.array([[2, 1]], dtype=dtype)) + expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) self._assertOpOutputMatchesExpected( math_ops.negative, @@ -743,6 +750,10 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array( [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool)) + self._assertOpOutputMatchesExpected( + math_ops.sign, + np.array([[np.nan]], dtype=dtype), + expected=np.array([[0.0]], dtype=dtype)) def 
testLogicalOps(self): self._assertOpOutputMatchesExpected( @@ -760,7 +771,7 @@ class UnaryOpsTest(xla_test.XLATestCase): lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"), np.array( [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), - expected=np.array([10., 26.], dtype=np.float32)) + expected=np.array([14., 22.], dtype=np.float32)) def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] @@ -811,6 +822,12 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1, 2, 0], np.int32), expected=np.array([2, 0, 1], dtype=np.int32)) + def testInvertPermutationTwiceIsNoop(self): + self._assertOpOutputMatchesExpected( + lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)), + np.array([1, 2, 0], np.int32), + expected=np.array([1, 2, 0], dtype=np.int32)) + def testRank(self): rank_op = lambda x: array_ops.rank_internal(x, optimize=False) for dtype in self.numeric_types: @@ -865,6 +882,17 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-1], [1], [4]], dtype=dtype), expected=np.int32(3)) + def testSizeWithInt64OutType(self): + + def size_op(x): + return array_ops.size_internal(x, optimize=False, out_type=np.int64) + + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + size_op, + np.array([[-1], [1], [4]], dtype=dtype), + expected=np.int64(3)) + def testUnpack(self): self._assertOpOutputMatchesExpected( array_ops.unstack, @@ -928,6 +956,15 @@ class UnaryOpsTest(xla_test.XLATestCase): [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], dtype=dtype), data_format)) + self._assertOpOutputMatchesExpected( + make_op("NCHW_VECT_C"), + np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)), + expected=np.array([[[[[0, 1], [8, 9]], [[16, 17], [24, 25]]], + [[[2, 3], [10, 11]], [[18, 19], [26, 27]]], + [[[4, 5], [12, 13]], [[20, 21], [28, 29]]], + [[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]], + dtype=dtype)) + def testSpaceToDepth(self): def make_op(data_format): @@ -971,10 +1008,19 @@ class 
UnaryOpsTest(xla_test.XLATestCase): [13, 14, 15, 16]]]], dtype=dtype), data_format)) + self._assertOpOutputMatchesExpected( + make_op("NCHW_VECT_C"), + np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)), + expected=np.array([[[[[0, 1, 2, 3, 16, 17, 18, 19]]], + [[[4, 5, 6, 7, 20, 21, 22, 23]]], + [[[8, 9, 10, 11, 24, 25, 26, 27]]], + [[[12, 13, 14, 15, 28, 29, 30, 31]]]]], + dtype=dtype)) + def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) zero = np.asarray(0).astype(dtype) - expected = np.logaddexp(zero, features) + expected = np.logaddexp(zero, features).astype(dtype) self._assertOpOutputMatchesExpected( nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index fcd7ac5ba1ca5049246e93e6f5f76746fb28c6b8..18c5870e0decb686f4df1c16bbb4a340c93ad21d 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -485,7 +485,7 @@ class SliceAssignTest(xla_test.XLATestCase): checker2[None] = [6] # new axis def testUninitialized(self): - with self.assertRaisesRegexp(errors.InvalidArgumentError, + with self.assertRaisesRegexp(errors.FailedPreconditionError, "uninitialized variable"): with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable([1, 2]) diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index 4ee144beb7f3243be069d59ee4a613484fe183b3..15a31111cb6b51f6d6e501b86d906d9ba53d1c22 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.compiler.tests import xla_test @@ -25,7 +26,12 @@ from tensorflow.compiler.tf2xla.python import xla from 
tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import map_fn +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -125,6 +131,121 @@ class WhileTest(xla_test.XLATestCase): result = sess.run(loop_outputs, {init_index: 0}) self.assertAllClose(result, [10, 7], rtol=1e-3) - -if __name__ == '__main__': + def _testMaxItersSimple(self): + if is_compile_on_demand(): + self.skipTest("list_ops are not supported in cpu_ondemand") + with self.cached_session() as sess, self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + v = constant_op.constant(1.0) + p = array_ops.placeholder(dtype=dtypes.int32) + + def create_while_loop(): + iterations = array_ops.size(p, name="iterations") + r = control_flow_ops.while_loop( + lambda *_: True, + lambda i, x: (i + 1, v * x), (0, 1.0), + maximum_iterations=iterations, + name="outer") + return array_ops.identity(r[1]) + + output = create_while_loop() + output = gradients_impl.gradients(output, v)[0] + + result = sess.run(output, feed_dict={p: [0, 0, 0]}) + print(result) + xla_context.Exit() + + def testMaxItersSimple(self): + self.skipTest("Fails with v1 control flow") + # This fails with old control. 
+ # self._testMaxItersSimple() + + @test_util.enable_control_flow_v2 + def testMaxItersSimpleV2(self): + self._testMaxItersSimple() + + def _testNestedWhileLoopWithMaxItersFromOuterContext(self): + if is_compile_on_demand(): + self.skipTest("list_ops are not supported in cpu_ondemand") + with self.cached_session() as sess, self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + v = constant_op.constant(1.0) + p = array_ops.placeholder(dtype=dtypes.int32) + + def mid_body_builder(iterations): + + def mid_body(i, x): + r = control_flow_ops.while_loop( + lambda *_: True, + lambda i, x: (i + 1, v * x), (0, x), + maximum_iterations=iterations, + name="inner") + return (i + 1, gradients_impl.gradients(x + r[1], v)[0]) + + return mid_body + + def outer_body(i, x): + iterations = array_ops.size(p, name="iterations") + return (i + 1, x + control_flow_ops.while_loop( + lambda *_: True, + mid_body_builder(iterations), (0, x), + maximum_iterations=iterations, + name="mid")[1]) + + def create_while_loop(): + r = control_flow_ops.while_loop( + lambda *_: True, + outer_body, (0, 1.0), + maximum_iterations=5, + name="outer") + return array_ops.identity(r[1]) + + # p:placeholder + # j = 0 + # i, x = 0, 1. 
+ # while j++ < 5: + # i1, x1 = 0, x + # while i1++ < len(p): + # i2, x2 = 0, x1 + # while i2++ < len(p): + # x2 = v * x2 + # x1 = grad(x1 + x2, v) + # x = x1 + # output = x + output = create_while_loop() + sess.run(output, feed_dict={p: [0, 0, 0]}) + xla_context.Exit() + + def testNestedWhileLoopWithMaxItersFromOuterContext(self): + self._testNestedWhileLoopWithMaxItersFromOuterContext() + + @test_util.enable_control_flow_v2 + def testNestedWhileLoopWithMaxItersFromOuterContextV2(self): + self._testNestedWhileLoopWithMaxItersFromOuterContext() + + @test_util.enable_control_flow_v2 + def testMap(self): + if is_compile_on_demand(): + self.skipTest("list_ops are not supported in cpu_ondemand") + with self.cached_session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, name="data") + r = map_fn.map_fn(lambda x: math_ops.multiply(math_ops.add(x, 3), 2), + elems) + self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums])) + xla_context.Exit() + + +def is_compile_on_demand(): + return ("TF_XLA_FLAGS" in os.environ and + "tf_xla_compile_on_demand" in os.environ["TF_XLA_FLAGS"]) + + +if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..12a51f7d32dc08f9ed40bdf57350ae451c90ce83 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -0,0 +1,447 @@ +# Description: +# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow +# and provide TensorRT operators and converter package. +# APIs are meant to change over time. 
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_shared_object", + "tf_cc_test", + "tf_copts", + "tf_cuda_library", + "tf_custom_op_library", + "tf_custom_op_library_additional_deps", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) + +tf_cuda_cc_test( + name = "tensorrt_test_cc", + size = "small", + srcs = ["tensorrt_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + "//tensorflow/core:gpu_init", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +cc_library( + name = "trt_op_kernels", + srcs = [ + "kernels/get_serialized_resource_op.cc", + "kernels/trt_engine_op.cc", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":trt_allocator", + ":trt_conversion", + ":trt_logging", + ":trt_plugins", + ":trt_resources", + ":utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/core/grappler/costs:graph_properties", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), + alwayslink = 1, +) + +tf_cc_shared_object( + name = "python/ops/libtftrt.so", + copts = tf_copts(is_external = True), + linkopts = ["-lm"], + deps = [ + ":trt_op_kernels", + ":trt_op_libs", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), +) + 
+tf_cuda_cc_test( + name = "get_serialized_resource_op_test", + size = "small", + srcs = ["kernels/get_serialized_resource_op_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + # TODO(laigd): consider splitting get_serialized_resource_op out from + # TF-TRT. + ":trt_op_kernels", + ":trt_op_libs", + ":trt_resources", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "trt_engine_op", + "get_serialized_resource_op", + ], +) + +cc_library( + name = "trt_op_libs", + deps = [ + ":get_serialized_resource_op_op_lib", + ":trt_engine_op_op_lib", + ], +) + +tf_cuda_library( + name = "trt_logging", + srcs = ["utils/trt_logger.cc"], + hdrs = ["utils/trt_logger.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_gen_op_wrapper_py( + name = "trt_ops", + deps = [ + ":trt_op_libs", + ], +) + +tf_custom_op_py_library( + name = "trt_ops_loader", + srcs = ["python/ops/trt_ops.py"], + dso = [ + "python/ops/libtftrt.so", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), + kernels = [ + ":trt_op_kernels", + ":trt_op_libs", + ], + srcs_version = "PY2AND3", + deps = [ + ":trt_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:resources", + ], +) + +tf_cuda_library( + name = "trt_resources", + srcs = [ + "utils/trt_int8_calibrator.cc", + "utils/trt_resources.cc", + ], + hdrs = [ + "utils/trt_int8_calibrator.h", + "utils/trt_lru_cache.h", + "utils/trt_resources.h", + ], + deps = [ + ":trt_allocator", + ":trt_logging", + ":utils", + "//tensorflow/core:framework_headers_lib", + 
"//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_library( + name = "trt_allocator", + srcs = ["utils/trt_allocator.cc"], + hdrs = ["utils/trt_allocator.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cc_test( + name = "trt_allocator_test", + size = "small", + srcs = ["utils/trt_allocator_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_allocator", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "trt_lru_cache_test", + size = "small", + srcs = ["utils/trt_lru_cache_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_resources", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# Library for the node-level conversion portion of TensorRT operation creation +tf_cuda_library( + name = "trt_conversion", + srcs = [ + "convert/convert_graph.cc", + "convert/convert_nodes.cc", + "convert/trt_optimization_pass.cc", + ], + hdrs = [ + "convert/convert_graph.h", + "convert/convert_nodes.h", + "convert/trt_optimization_pass.h", + ], + deps = [ + ":segment", + ":trt_allocator", + ":trt_plugins", + ":trt_logging", + ":trt_resources", + ":utils", + "@com_google_absl//absl/strings", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:graph", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:devices", + 
"//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), + alwayslink = 1, +) + +tf_cuda_cc_test( + name = "convert_graph_test", + size = "medium", + srcs = ["convert/convert_graph_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_conversion", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_cc_test( + name = "convert_nodes_test", + size = "medium", + srcs = ["convert/convert_nodes_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_logging", + ":trt_conversion", + ":trt_plugins", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + 
"@local_config_tensorrt//:tensorrt", + ]), +) + +# Library for the segmenting portion of TensorRT operation creation +cc_library( + name = "segment", + srcs = ["segment/segment.cc"], + hdrs = [ + "segment/segment.h", + "segment/union_find.h", + ], + copts = tf_copts(), + deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cuda_cc_test( + name = "segment_test", + size = "small", + srcs = ["segment/segment_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":segment", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +# Library for the plugin factory +tf_cuda_library( + name = "trt_plugins", + srcs = [ + "plugin/trt_plugin.cc", + "plugin/trt_plugin_factory.cc", + "plugin/trt_plugin_utils.cc", + ], + hdrs = [ + "plugin/trt_plugin.h", + "plugin/trt_plugin_factory.h", + "plugin/trt_plugin_utils.h", + ], + deps = [ + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_cc_test( + name = "trt_plugin_factory_test", + size = "small", + srcs = ["plugin/trt_plugin_factory_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_plugins", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +cc_library( + name = "utils", + srcs = ["convert/utils.cc"], + hdrs = ["convert/utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + 
"//tensorflow/core:lib_proto_parsing", + ], +) + +cc_library( + name = "py_utils", + srcs = ["utils/py_utils.cc"], + hdrs = ["utils/py_utils.h"], + copts = tf_copts(), + deps = if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc similarity index 73% rename from tensorflow/contrib/tensorrt/convert/convert_graph.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index e2350f69a6ace005c319aa97efa99f5e36e6a831..f2a6b74ec248cf9f2ee66de830900a9d7841b55e 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include #include @@ -24,13 +24,12 @@ limitations under the License. 
#include #include -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/contrib/tensorrt/segment/segment.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" @@ -63,100 +62,15 @@ limitations under the License. 
namespace tensorflow { namespace tensorrt { namespace convert { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - -// Returns compiled TRT version information {Maj, Min, Patch} -std::vector GetLinkedTensorRTVersion() { - return {NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH}; -} - -// Returns loaded TRT library version {Maj, Min, Patch} -std::vector GetLoadedTensorRTVersion() { - int ver = getInferLibVersion(); - int ver_major = ver / 1000; - ver = ver - ver_major * 1000; - int ver_minor = ver / 100; - int ver_patch = ver - ver_minor * 100; - return {ver_major, ver_minor, ver_patch}; -} +using absl::StrAppend; +using absl::StrCat; TrtCandidateSelector::TrtCandidateSelector( - const grappler::GraphProperties& graph_properties, int precision_mode) + const grappler::GraphProperties& graph_properties, + TrtPrecisionMode precision_mode) : graph_properties_(graph_properties), precision_mode_(precision_mode) {} -Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { - // TODO(laigd): move this set to TrtNodeValidator where it should belong. 
- // LINT.IfChange - static const std::set candidate_ops = { - "Abs", - "Add", - "AvgPool", - "BatchMatMul", - "BiasAdd", - "ConcatV2", - "Const", - "Conv2D", - "DepthwiseConv2dNative", - "Div", - "Exp", - "ExpandDims", - "FusedBatchNorm", - "FusedBatchNormV2", - "Identity", - "Log", - "MatMul", - "Max", - "MaxPool", - "Maximum", - "Mean", - "Min", - "Minimum", - "Mul", - "Neg", - "Pad", - "Prod", - "RealDiv", - "Reciprocal", - "Relu", - "Relu6", - "Reshape", - "Rsqrt", - "Rsqrt", - "Sigmoid", - "Snapshot", - "Softmax", - "Sqrt", - "Square", - "Squeeze", - "StridedSlice", - "Sub", - "Sum", - "Tanh", - "TopKV2", - "Transpose", - }; - bool is_supported_op_type = - (candidate_ops.count(node->type_string()) || - PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); - static const std::set quantize_ops = { - "QuantizeAndDequantizeV2", - "QuantizeAndDequantizeV3", - "FakeQuantWithMinMaxVars", - "FakeQuantWithMinMaxArgs", - }; - // In INT8 mode, we will always apply the quantization ranges provided by - // these ops to the relevant tensors. This happens regardless of the value of - // use_calibration. 
- if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) { - is_supported_op_type = true; - } - // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) - if (!is_supported_op_type) { - return errors::Unimplemented("Op type ", node->type_string(), - " is not supported"); - } - +Status TrtCandidateSelector::IsTensorRTCandidate(const Node* node) { std::vector input_edges; TF_RETURN_IF_ERROR(node->input_edges(&input_edges)); std::vector> input_node_and_ports; @@ -166,83 +80,32 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { input_edge->src_output()); } return validator_.ValidateNode(node->def(), input_node_and_ports, - graph_properties_); + precision_mode_, graph_properties_); } namespace { -tensorflow::Status BuildNodeMap( - const tensorflow::Graph& graph, - std::unordered_map* node_map) { +Status BuildNodeMap(const Graph& graph, + std::unordered_map* node_map) { for (auto* node : graph.op_nodes()) { if (!node_map->insert({node->name(), node}).second) { - return tensorflow::errors::AlreadyExists( - "Node name is not unique in graph: " + node->name()); + return errors::AlreadyExists("Node name is not unique in graph: " + + node->name()); } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace -// Function to get calibration from ResourceMgr and put them into nodedef. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, - bool is_dyn_op) { - LOG(INFO) << "Starting Calib Conversion"; - infer_graph->CopyFrom(graph_def); - auto trt_rm = TRTResourceManager::instance(); - auto calib_rm = trt_rm->getManager("TRTCalibration"); - int num_nodes = infer_graph->node_size(); - if (!is_dyn_op) { - LOG(WARNING) << "Construction of static int8 engine is not implemented " - "yet!. 
Dynamic engine will be constructed"; - } - for (int i = 0; i < num_nodes; ++i) { - auto n = infer_graph->mutable_node(i); - if (n->op() == "TRTEngineOp") { - VLOG(1) << "Processing " << n->name(); - const string& container_name = n->attr().at("segment_funcdef_name").s(); - TRTCalibrationResource* cres = nullptr; - auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); - if (!status.ok()) { - LOG(ERROR) << "Could not get Calibration information. Did you run with " - "calibration data?"; - return tensorflow::errors::FailedPrecondition( - "Need to run graph with calibration data first!"); - } - if (cres->calibrator_) { - cres->calibrator_->waitAndSetDone(); - cres->thr_->join(); - const auto& calibration_table = - cres->calibrator_->getCalibrationTableAsString(); - if (!calibration_table.size()) { - LOG(ERROR) << "Calibration table is empty"; - return tensorflow::errors::Unknown( - "Calibration table is missing. This shouldn't have happened!"); - } - n->mutable_attr()->at("calibration_data").set_s(calibration_table); - } else { - LOG(ERROR) << "Can't get TRTCalibrator from resource manager!"; - return tensorflow::errors::Unknown( - "Can't get TRTCalibrator from resource manager!"); - } - cres->Unref(); - TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name)); - } - } - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertGraphDefToTensorRT( - const tensorflow::GraphDef& graph_def, - const std::vector& output_names, size_t max_batch_size, - size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode, int minimum_segment_size, bool is_dyn_op, - int max_cached_engines, std::vector cached_engine_batches, - bool use_calibration) { +Status ConvertGraphDefToTensorRT( + const GraphDef& graph_def, const std::vector& output_names, + size_t max_batch_size, size_t max_workspace_size_bytes, + GraphDef* new_graph_def, TrtPrecisionMode precision_mode, + int minimum_segment_size, bool is_dyn_op, int max_cached_engines, + 
std::vector cached_engine_batches, bool use_calibration) { // Create GrapplerItem. - tensorflow::grappler::GrapplerItem item; + grappler::GrapplerItem item; item.fetch = output_names; item.graph = graph_def; @@ -256,13 +119,13 @@ tensorflow::Status ConvertGraphDefToTensorRT( // Create single machine cluster. Note that this will create a session and // initialize the gpu devices. const int num_cpu_cores = - tensorflow::grappler::GetNumAvailableLogicalCPUCores(); - const int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); + grappler::GetNumAvailableLogicalCPUCores(); + const int num_gpus = grappler::GetNumAvailableGPUs(); VLOG(2) << "cpu_cores: " << num_cpu_cores; VLOG(2) << "gpus: " << num_gpus; const int timeout_s = 60 * 10; - std::unique_ptr cluster( - new tensorflow::grappler::SingleMachine( + std::unique_ptr cluster( + new grappler::SingleMachine( timeout_s, num_cpu_cores, num_gpus)); // These settings are the defaults in tensorflow/python/grappler/cluster.py. cluster->DisableDetailedStats(true); @@ -273,18 +136,17 @@ tensorflow::Status ConvertGraphDefToTensorRT( // Create virtual cluster. Grappler requires a virtual cluster with a proper // GPU device in order to calculate flops>0 or fails with FATAL in dbg mode. // We add numbers from a Pascal card here to have flops>0. - tensorflow::DeviceProperties device_properties; + DeviceProperties device_properties; device_properties.set_type("GPU"); device_properties.mutable_environment()->insert({"architecture", "6"}); device_properties.set_num_cores(3584); device_properties.set_frequency(1531); - std::unique_ptr cluster( - new tensorflow::grappler::VirtualCluster( - {{"/GPU:0", device_properties}})); + std::unique_ptr cluster( + new grappler::VirtualCluster({{"/GPU:0", device_properties}})); #endif // Create RewriterConfig. 
- tensorflow::ConfigProto config_proto; + ConfigProto config_proto; auto& rw_cfg = *config_proto.mutable_graph_options()->mutable_rewrite_options(); // TODO(aaroey): use only const folding and layout for the time being since @@ -298,7 +160,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( parameters["max_batch_size"].set_i(max_batch_size); parameters["is_dynamic_op"].set_b(is_dyn_op); parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes); - TF_RETURN_IF_ERROR(GetPrecisionModeName( + TF_RETURN_IF_ERROR(TrtPrecisionModeToName( precision_mode, parameters["precision_mode"].mutable_s())); parameters["maximum_cached_engines"].set_i(max_cached_engines); if (!cached_engine_batches.empty()) { @@ -310,7 +172,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( parameters["use_calibration"].set_b(use_calibration); // Run optimizer. - tensorflow::grappler::MetaOptimizer meta_opt(nullptr, config_proto); + grappler::MetaOptimizer meta_opt(nullptr, config_proto); TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def)); if (VLOG_IS_ON(5)) { @@ -324,20 +186,18 @@ tensorflow::Status ConvertGraphDefToTensorRT( } struct EdgePtrCompare { - bool operator()(const tensorflow::Edge* lhs, - const tensorflow::Edge* rhs) const { + bool operator()(const Edge* lhs, const Edge* rhs) const { return lhs->id() < rhs->id(); } }; // Function to get subsegment information structure. -tensorflow::Status GetEngineInfo( - const tensorflow::Graph* g, - const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& segment_nodes, - const std::unordered_map& node_map, - const std::vector& reverse_topo_order, - EngineInfo* info) { +Status GetEngineInfo(const Graph* g, + const grappler::GraphProperties& graph_properties, + const std::set& segment_nodes, + const std::unordered_map& node_map, + const std::vector& reverse_topo_order, + EngineInfo* info) { std::vector subgraph_nodes; // Topologically sorted nodes. 
std::set added_const_nodes; // Used to prevent double insertion. std::set segment_devices; @@ -384,8 +244,8 @@ tensorflow::Status GetEngineInfo( // Create input connections. Sort edges first to make determnistic since // in_edges is a set of pointers. - std::vector in_edges(node->in_edges().begin(), - node->in_edges().end()); + std::vector in_edges(node->in_edges().begin(), + node->in_edges().end()); std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare()); for (const auto edge : in_edges) { auto input_node = edge->src(); @@ -436,8 +296,8 @@ tensorflow::Status GetEngineInfo( } // Create output connections. Sort edges first to make determnistic since // out_edges is a set of pointers. - std::vector out_edges(node->out_edges().begin(), - node->out_edges().end()); + std::vector out_edges(node->out_edges().begin(), + node->out_edges().end()); std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare()); for (const auto edge : out_edges) { auto output_node = edge->dst(); @@ -471,9 +331,13 @@ tensorflow::Status GetEngineInfo( // Construct the const nodes first. subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(), added_const_nodes.end()); + string scope_name; TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( g, graph_properties, subgraph_nodes, &info->connections, - &info->segment_graph_def, &info->engine_name)); + &info->segment_graph_def, &scope_name)); + info->engine_name = StrCat(scope_name, info->engine_name); + VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name + << "' to a GraphDef"; // TODO(sami): This should not happen once segmenter is updated. 
if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -496,7 +360,7 @@ void UpdateToEngineNode(const std::vector& infos, const size_t my_engine_id, const std::vector& engine_nodes, const bool is_input_edge, const string& node_name, - tensorflow::Node** node, int* port) { + Node** node, int* port) { for (size_t t = 0; t < infos.size(); ++t) { if (t == my_engine_id) { continue; @@ -533,20 +397,20 @@ void UpdateToEngineNode(const std::vector& infos, // one). Connect to the pre-existing engine node instead. // 3. In this way, we ensure the graph is topologically sort-able after each // invocation of CreateTRTNode(). -tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, - int max_batch_size, tensorflow::Graph* graph, - nvinfer1::IGpuAllocator* alloc, - std::vector* engine_nodes) { +Status CreateTRTNode(const ConversionParams& params, + const std::vector& infos, int pos, + int max_batch_size, Graph* graph, + nvinfer1::IGpuAllocator* alloc, + std::vector* engine_nodes) { const auto& info = infos.at(pos); - TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail"); - std::vector output_shape_protos; - std::vector input_shape_protos; - std::vector input_shapes; - std::vector inputs; - std::vector input_nodes; - std::vector control_input_nodes; + std::vector output_shape_protos; + std::vector input_shape_protos; + std::vector input_shapes; + std::vector inputs; + std::vector input_nodes; + std::vector control_input_nodes; std::unordered_set control_input_names; - std::vector out_types; + std::vector out_types; VLOG(1) << "Processing " << info.engine_name; // Collect needed info for creating the engine node in the graph @@ -558,8 +422,8 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, if (!conn.is_input_edge) continue; // Rewrire control input if it's not found in original graph. 
- tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); - int port = tensorflow::Graph::kControlSlot; + Node* input_node = graph->FindNodeId(conn.outside_id); + int port = Graph::kControlSlot; if (!input_node) { UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, conn.outside_node_name, &input_node, &port); @@ -575,7 +439,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, // Data edges if (!conn.is_input_edge) { // Set the shapes and data types of output edge. - tensorflow::TensorShapeProto out_shape; + TensorShapeProto out_shape; // shape of the output node inside segment conn.inside_shape.AsProto(&out_shape); if (output_shape_protos.size() <= conn.port_number) { @@ -586,7 +450,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, out_types.at(conn.port_number) = conn.connection_type; } else { // Set the shapes and data types of input edge. - tensorflow::TensorShapeProto in_shape; + TensorShapeProto in_shape; conn.outside_shape.AsProto(&in_shape); if (input_shape_protos.size() <= conn.port_number) { input_shape_protos.resize(conn.port_number + 1); @@ -599,7 +463,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, if (info.engine_type == EngineInfo::EngineType::TRTStatic) { for (int i = 1; i < conn.outside_shape.dims(); i++) { if (conn.outside_shape.dim_size(i) <= 0) { - return tensorflow::errors::Internal( + return errors::Internal( "Input shapes must be fully defined when in static mode. " "Please try is_dynamic_op=True (shape was ", conn.outside_shape.DebugString(), ")"); @@ -608,7 +472,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } // Rewrire data input if it's not found in original graph. 
- tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + Node* input_node = graph->FindNodeId(conn.outside_id); int port = conn.outside_port; if (!input_node) { UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, @@ -631,13 +495,12 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, // avoid crash later. Constant folding should've folded the ops that make up // these segments. if (inputs.empty()) { - return tensorflow::errors::Internal( - "Segment has no inputs (possible " - "constfold failure)"); + return errors::Internal( + "Segment has no inputs (possible constfold failure)"); } const bool calibrate_int8 = - (info.precision_mode == INT8MODE && info.use_calibration); + (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration); // Build the engine and get its serialized representation. string segment_string; if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) { @@ -650,14 +513,15 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, TrtUniquePtrType engine; // TODO(sami): What happens if 1st dim is not batch? TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( - info.segment_graph_def, calibrate_int8 ? FP32MODE : info.precision_mode, + info.segment_graph_def, + calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode, max_batch_size, info.max_workspace_size_bytes, input_shapes, &trt_logger, alloc, /*calibrator=*/nullptr, &engine, info.use_calibration, /*convert_successfully=*/nullptr)); TrtUniquePtrType engine_data(engine->serialize()); - segment_string = - string((const char*)engine_data->data(), engine_data->size()); + segment_string = string(static_cast(engine_data->data()), + engine_data->size()); if (calibrate_int8) { // See above comment about why not putting this inside the 'else' branch. 
segment_string = info.segment_graph_def.SerializeAsString(); @@ -666,15 +530,9 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, segment_string = info.segment_graph_def.SerializeAsString(); } - // TODO(aaroey): use enum instead, and add a helper method to do the - // conversion. string prec_string; - TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string)); - if (info.precision_mode == INT8MODE && calibrate_int8 && - !TRTResourceManager::instance()->getManager("TRTCalibration")) { - LOG(ERROR) << "Failed to construct calibration storage"; - } - tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); + TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string)); + NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); if (!info.device.empty()) node_builder.Device(info.device); if (VLOG_IS_ON(1)) { string ins = StrCat(info.engine_name, " inputs= "); @@ -692,14 +550,16 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, !info.cached_engine_batches.empty()) { LOG(WARNING) << "Cached engine batches are ignored for static engines"; } - tensorflow::NodeDef trt_node; - tensorflow::Status status = + NodeDef trt_node; + Status status = node_builder.Attr("input_shapes", input_shape_protos) .Attr("output_shapes", output_shape_protos) .Attr("static_engine", info.engine_type == EngineInfo::EngineType::TRTStatic) .Attr("segment_funcdef_name", - StrCat(info.engine_name, "_native_segment")) + params.use_function_backup + ? StrCat(info.engine_name, "_native_segment") + : "") .Attr("serialized_segment", segment_string) .Attr("calibration_data", "") .Attr("max_cached_engines_count", info.maximum_cached_engines) @@ -718,7 +578,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, // here, this segment will be skipped // TODO(aaroey): let it return proper error status for the following logic // instead of checking fail. 
- tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); + Node* engine_node = graph->AddNode(trt_node, &status); (*engine_nodes)[pos] = engine_node; if (!status.ok()) { LOG(ERROR) << "Adding node failed " << status; @@ -745,7 +605,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, if (conn.is_input_edge) { continue; } - tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id); + Node* output_node = graph->FindNodeId(conn.outside_id); int port = conn.outside_port; if (!output_node) { UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false, @@ -768,20 +628,19 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } // Function to construct a funcdef from the segment and add it to the graph. -tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( - tensorflow::Graph* graph, const tensorflow::GraphDef& segment, - const string& engine_name) { - tensorflow::Graph sgraph(graph->flib_def()); - tensorflow::GraphConstructorOptions gcopts; - TF_RETURN_IF_ERROR( - tensorflow::ConvertGraphDefToGraph(gcopts, segment, &sgraph)); - std::map io_nodes; +Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph, + const GraphDef& segment, + const string& engine_name) { + Graph sgraph(graph->flib_def()); + GraphConstructorOptions gcopts; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(gcopts, segment, &sgraph)); + std::map io_nodes; int num_inputs = 0; for (auto n : sgraph.op_nodes()) { - if (tensorflow::str_util::StartsWith(n->name(), kInputPHName)) { + if (str_util::StartsWith(n->name(), kInputPHName)) { num_inputs++; io_nodes.insert({n->name(), n}); - } else if (tensorflow::str_util::StartsWith(n->name(), kOutputPHName)) { + } else if (str_util::StartsWith(n->name(), kOutputPHName)) { io_nodes.insert({n->name(), n}); } } @@ -789,14 +648,14 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( for (int i = 0; i < num_inputs; ++i) { auto name = StrCat(kInputPHName, i); auto node = 
io_nodes[name]; - tensorflow::NodeDef nd; - tensorflow::NodeDefBuilder node_builder( - StrCat(name, "_Arg"), tensorflow::FunctionLibraryDefinition::kArgOp); + NodeDef nd; + NodeDefBuilder node_builder(StrCat(name, "_Arg"), + FunctionLibraryDefinition::kArgOp); VLOG(1) << "Adding " << StrCat(name, "_Arg"); TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) .Attr("index", i) .Finalize(&nd)); - tensorflow::Status s; + Status s; auto node_arg = sgraph.AddNode(nd, &s); if (!s.ok()) { LOG(ERROR) << "Couldn't add _Arg node for " << name; @@ -816,15 +675,14 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( for (int i = 0; i < io_nodes.size() - num_inputs; ++i) { auto name = StrCat(kOutputPHName, i); auto node = io_nodes[name]; - tensorflow::NodeDef nd; - tensorflow::NodeDefBuilder node_builder( - StrCat(name, "_Ret"), tensorflow::FunctionLibraryDefinition::kRetOp); + NodeDef nd; + NodeDefBuilder node_builder(StrCat(name, "_Ret"), + FunctionLibraryDefinition::kRetOp); auto edge = *(node->in_edges().begin()); - tensorflow::NodeDefBuilder::NodeOut nout( - edge->src()->name(), edge->src_output(), - edge->src()->output_type(edge->src_output())); + NodeDefBuilder::NodeOut nout(edge->src()->name(), edge->src_output(), + edge->src()->output_type(edge->src_output())); VLOG(1) << " input " << nout.node << ":" << nout.index - << " dtype=" << tensorflow::DataTypeString(nout.data_type); + << " dtype=" << DataTypeString(nout.data_type); // nvcc complains that Input() is // ambiguous, so do not use Input({nout}). 
node_builder.Input(nout); @@ -834,7 +692,7 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( if (VLOG_IS_ON(3)) { VLOG(3) << nd.DebugString(); } - tensorflow::Status s; + Status s; auto node_ret = sgraph.AddNode(nd, &s); if (!s.ok()) { LOG(ERROR) << "Couldn't add _Ret node for " << name; @@ -850,23 +708,29 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( } sgraph.RemoveNode(node); } - tensorflow::FunctionDefLibrary fdeflib; + FunctionDefLibrary fdeflib; auto native_segment = fdeflib.add_function(); - TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( + TF_RETURN_IF_ERROR(GraphToFunctionDef( sgraph, StrCat(engine_name, "_native_segment"), native_segment)); + // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on + // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32 + // would be on host if the op generating the tensor has host memory tag set. + (*native_segment + ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr] + .set_b(true); if (VLOG_IS_ON(7)) { VLOG(7) << engine_name << " Function_Def "; VLOG(7) << native_segment->DebugString(); } VLOG(1) << "Adding funcdef to graphlib"; TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib)); - return tensorflow::Status::OK(); + return Status::OK(); } -std::pair GetDeviceAndAllocator( - const ConversionParams& params, const EngineInfo& engine) { +std::pair GetDeviceAndAllocator(const ConversionParams& params, + const EngineInfo& engine) { int cuda_device_id = -1; - tensorflow::Allocator* dev_allocator = nullptr; + Allocator* dev_allocator = nullptr; if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr || engine.device.empty()) { // If device is not set, use the first found GPU device for the conversion. @@ -894,7 +758,7 @@ std::pair GetDeviceAndAllocator( // Use the device requested by the engine. 
auto device_set = params.cluster->GetDeviceSet(); - std::vector devices; + std::vector devices; DeviceNameUtils::ParsedName parsed_name; if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) && parsed_name.has_id) { @@ -908,7 +772,7 @@ std::pair GetDeviceAndAllocator( StrAppend(&msg, ". Will get the allocator from first one."); LOG(WARNING) << msg; } - tensorflow::AllocatorAttributes alloc_attr; + AllocatorAttributes alloc_attr; cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id; dev_allocator = devices[0]->GetAllocator(alloc_attr); VLOG(1) << "Using allocator " << dev_allocator->Name() @@ -921,26 +785,38 @@ std::pair GetDeviceAndAllocator( } // Entry function from optimization pass. -// TODO(aaeory): parameter should use pointer type. -tensorflow::Status ConvertAfterShapes(ConversionParams& params) { +Status ConvertAfterShapes(const ConversionParams& params) { + // Sanity checks. + if (params.precision_mode == TrtPrecisionMode::INT8) { + if (params.use_calibration && !params.use_function_backup) { + return errors::InvalidArgument( + "Calibration requires enabling fallback to TF function execution."); + } + } else { + if (params.use_calibration) { + return errors::InvalidArgument( + "Calibration with FP32 or FP16 is not supported."); + } + } + // Convert graphdef to graph. 
- tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), - params.input_graph_def->library()); - tensorflow::Graph graph(flib); - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), *params.input_graph_def, &graph)); + FunctionLibraryDefinition flib(OpRegistry::Global(), + params.input_graph_def->library()); + Graph graph(flib); + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), + *params.input_graph_def, &graph)); // Segment the graph into subgraphs that can be converted to TensorRT - tensorflow::tensorrt::segment::SegmentOptions segment_options; + segment::SegmentOptions segment_options; // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT) for (auto node : *(params.output_names)) { segment_options.exclude_node_list.insert(node); } segment_options.minimum_segment_size = params.minimum_segment_size; - tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; + segment::SegmentNodesVector initial_segments; TrtCandidateSelector candidate_selector(*params.graph_properties, params.precision_mode); - TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( + TF_RETURN_IF_ERROR(segment::SegmentGraph( &graph, std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector, std::placeholders::_1), @@ -952,20 +828,21 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { << initial_segments.size(); // Get the EngineInfo for each segment. 
- std::unordered_map node_map; + std::unordered_map node_map; TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); float total_num_nodes_in_segments = 0.; std::vector engine_segments; engine_segments.reserve(initial_segments.size()); - std::vector reverse_topo_order; - tensorflow::GetPostOrder(graph, &reverse_topo_order); + std::vector reverse_topo_order; + GetPostOrder(graph, &reverse_topo_order); size_t total_engine_bytes_size = 0; std::vector engine_bytes_size; - tensorflow::tensorrt::segment::SegmentNodesVector converted_segments; + segment::SegmentNodesVector converted_segments; converted_segments.reserve(initial_segments.size()); for (size_t t = 0; t < initial_segments.size(); t++) { auto& curr_segment = initial_segments.at(t); EngineInfo curr_engine; + curr_engine.engine_name = StrCat("TRTEngineOp_", t); Status status = GetEngineInfo(&graph, *params.graph_properties, curr_segment.first, node_map, reverse_topo_order, &curr_engine); @@ -975,23 +852,20 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { continue; } curr_engine.precision_mode = params.precision_mode; - if (params.use_calibration && params.precision_mode != INT8MODE) { - return errors::InvalidArgument( - "Calibration with FP32 or FP16 is not supported."); - } curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration) ? 
EngineInfo::EngineType::TRTDynamic : EngineInfo::EngineType::TRTStatic); curr_engine.use_calibration = params.use_calibration; curr_engine.cached_engine_batches = params.cached_engine_batches; curr_engine.maximum_cached_engines = params.max_cached_engines; - StrAppend(&curr_engine.engine_name, "TRTEngineOp_", t); - status = RegisterSegmentFunctionToFunctionLibrary( - &graph, curr_engine.segment_graph_def, curr_engine.engine_name); - if (!status.ok()) { - LOG(WARNING) << "Failed to register segment graphdef as a function " << t - << ": " << status; - continue; + if (params.use_function_backup) { + status = RegisterSegmentFunctionToFunctionLibrary( + &graph, curr_engine.segment_graph_def, curr_engine.engine_name); + if (!status.ok()) { + LOG(WARNING) << "Failed to register segment graphdef as a function " + << t << ": " << status; + continue; + } } engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); @@ -1042,8 +916,9 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(WARNING) << "Can't identify the cuda device. 
Running on device 0 "; } cudaSetDevice(cuda_device_id); - auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, - &graph, alloc.get(), &engine_nodes); + auto status = + CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph, + alloc.get(), &engine_nodes); string msg = StrCat("TensorRT node ", engine.engine_name, " added for segment ", i, " consisting of ", @@ -1073,7 +948,7 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { cudaSetDevice(old_cuda_device); graph.ToGraphDef(params.output_graph_def); VLOG(1) << "Returning from conversion"; - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace convert diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..88fda49d38ef38522b8d21b73c0fd0061509c8d1 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ + +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +// Helper class for the segmenter to determine whether given TF node is +// supported by TRT. +class TrtCandidateSelector { + public: + TrtCandidateSelector(const grappler::GraphProperties& graph_properties, + TrtPrecisionMode precision_mode); + + // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added + // to TRT subgraph and later converted into TRT engine. + Status IsTensorRTCandidate(const Node* node); + + private: + // The TF-TRT node converter used to verify whether individual node is + // supported. It will operate in validation-only mode. + TrtNodeValidator validator_; + + // GraphProperties of the graph whose nodes are to be validated by + // IsTensorRTCandidate(). + const grappler::GraphProperties& graph_properties_; + + // Quantization ops are only converted when using quantized precisions. 
+ const TrtPrecisionMode precision_mode_; +}; + +struct ConversionParams { + const GraphDef* input_graph_def = nullptr; + const std::vector* output_names = nullptr; + size_t max_batch_size = 1; + size_t max_workspace_size_bytes = 1 << 30; + GraphDef* output_graph_def = nullptr; + TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32; + int minimum_segment_size = 3; + const grappler::GraphProperties* graph_properties = nullptr; + const grappler::Cluster* cluster = nullptr; + // Whether to create engine on conversion or execution time + bool is_dyn_op = false; + // maximum number of cached engines + int max_cached_engines = 1; + bool use_calibration = true; + // list of cached engines + std::vector cached_engine_batches; + // Whether to use function fallback for TRTEngineOp + bool use_function_backup = true; +}; + +// - max_batch_size: maximum batch size which can be used for inference for +// optimization targets inference run with max batch size. +// - max_workspace_size_bytes: The upper bound of memory allowance for engine +// building. +Status ConvertGraphDefToTensorRT( + const GraphDef& graph_def, const std::vector& output_names, + size_t max_batch_size, size_t max_workspace_size_bytes, + GraphDef* new_graph_def, + TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32, + int minimum_segment_size = 3, bool is_dyn_op = false, + int max_cached_engines = 1, std::vector cached_engine_batches = {}, + bool use_calibration = true); + +// Method to call from optimization pass +Status ConvertAfterShapes(const ConversionParams& params); + +// Helper method for the conversion, expose for testing. 
+std::pair GetDeviceAndAllocator(const ConversionParams& params, + const EngineInfo& engine); + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc similarity index 95% rename from tensorflow/contrib/tensorrt/convert/convert_graph_test.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 2d2bfeb192c1893824c7b30bfad593c62c203392..1a754181debf41865190aa7f9ca6a76efea98181 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include #include #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -75,7 +75,7 @@ TEST(TrtCandidateSelector, Basics) { feed, const_1, matmul_attrs); // Unsupported op. - auto unsupported_op = ops::Sin(s.WithOpName("sin"), feed); + auto unsupported_op = ops::Erf(s.WithOpName("sin"), feed); // Incompatible input. 
auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE); @@ -98,7 +98,8 @@ TEST(TrtCandidateSelector, Basics) { grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - for (const int precision_mode : {FP32MODE, INT8MODE}) { + for (const TrtPrecisionMode precision_mode : + {TrtPrecisionMode::FP32, TrtPrecisionMode::INT8}) { TrtCandidateSelector selector(graph_properties, precision_mode); TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); ExpectStatus( @@ -107,13 +108,13 @@ TEST(TrtCandidateSelector, Basics) { "transpose_a is not supported for TensorRT FullyConnected " "(op: MatMul), at: incompatible_matmul"); ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), - error::UNIMPLEMENTED, "Op type Sin is not supported"); + error::UNIMPLEMENTED, "Op type Erf is not supported"); ExpectStatus( selector.IsTensorRTCandidate( matmul_with_incompatible_input.operation.node()), error::INTERNAL, "Failed to convert input with index 0 to a TRT_TensorOrWeights"); - if (precision_mode == INT8MODE) { + if (precision_mode == TrtPrecisionMode::INT8) { TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node())); } else { ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()), diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc similarity index 68% rename from tensorflow/contrib/tensorrt/convert/convert_nodes.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 8b7279ad0336b3c791c130a840b7926e6b99df94..8aeecaff92531ffceda4ae878bb104830c8649d9 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include #include @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -43,6 +45,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/platform/types.h" @@ -54,10 +57,10 @@ limitations under the License. // would work! 
#define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) -#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ - do { \ - return tensorflow::errors::Internal( \ - "TFTRT::", __FUNCTION__, " failed to add TRT layer, at: ", node); \ +#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ + do { \ + return errors::Internal("TFTRT::", __FUNCTION__, \ + " failed to add TRT layer, at: ", node); \ } while (0) #define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \ @@ -80,32 +83,119 @@ namespace tensorrt { const char* const kInputPHName = "TensorRTInputPH_"; const char* const kOutputPHName = "TensorRTOutputPH_"; +bool IsEngineInput(absl::string_view name) { + return absl::StartsWith(name, kInputPHName); +} +bool IsEngineOutput(absl::string_view name) { + return absl::StartsWith(name, kOutputPHName); +} + namespace convert { -using ::tensorflow::str_util::Split; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; -inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, - nvinfer1::DataType* trt_dtype) { +inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) { switch (tf_dtype) { - case tensorflow::DataType::DT_FLOAT: + case DataType::DT_FLOAT: *trt_dtype = nvinfer1::DataType::kFLOAT; break; // TODO(aaroey): this should be DT_QINT8 which is not a well supported type. 
- case tensorflow::DataType::DT_INT8: + case DataType::DT_INT8: *trt_dtype = nvinfer1::DataType::kINT8; break; - case tensorflow::DataType::DT_HALF: + case DataType::DT_HALF: *trt_dtype = nvinfer1::DataType::kHALF; break; - case tensorflow::DataType::DT_INT32: + case DataType::DT_INT32: *trt_dtype = nvinfer1::DataType::kINT32; break; default: - return tensorflow::errors::InvalidArgument( - "Unsupported data type ", tensorflow::DataTypeString(tf_dtype)); + return errors::InvalidArgument("Unsupported data type ", + DataTypeString(tf_dtype)); + } + return Status::OK(); +} + +class TFAttrs { + public: + explicit TFAttrs(const NodeDef& tf_node) { + for (const auto& attr : tf_node.attr()) { + attrs_.insert({attr.first, &attr.second}); + } + } + + bool count(const string& key) const { return attrs_.count(key); } + + AttrValue const* at(const string& key) const { + if (!attrs_.count(key)) { + LOG(FATAL) << "Attribute not found: " << key; + } + return attrs_.at(key); + } + + template + T get(const string& key) const; + + template + T get(const string& key, const T& default_value) const { + return attrs_.count(key) ? 
this->get(key) : default_value; + } + + std::vector GetAllAttrKeys() const { + std::vector attr_list; + for (const auto& attr_item : attrs_) { + attr_list.emplace_back(attr_item.first); + } + return attr_list; } - return tensorflow::Status::OK(); + + private: + typedef std::map AttrMap; + AttrMap attrs_; +}; + +template <> +string TFAttrs::get(const string& key) const { + return this->at(key)->s(); +} + +template <> +std::vector TFAttrs::get>(const string& key) const { + auto attr = this->at(key)->list().i(); + return std::vector(attr.begin(), attr.end()); +} + +template <> +std::vector TFAttrs::get>(const string& key) const { + auto attr = this->at(key)->list().f(); + return std::vector(attr.begin(), attr.end()); +} + +template <> +nvinfer1::DataType TFAttrs::get(const string& key) const { + nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); + TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); + return trt_dtype; +} + +template <> +DataType TFAttrs::get(const string& key) const { + return this->at(key)->type(); +} + +template <> +float TFAttrs::get(const string& key) const { + return this->at(key)->f(); +} + +template <> +bool TFAttrs::get(const string& key) const { + return this->at(key)->b(); +} + +template <> +int64 TFAttrs::get(const string& key) const { + return this->at(key)->i(); } template @@ -126,13 +216,23 @@ Status TensorShapeArrayToTrtDims(const std::vector& shape, PartialTensorShape tensor_shape; TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape)); *out = TensorShapeToTrtDims(tensor_shape, ignore_first_dim); - return tensorflow::Status::OK(); + return Status::OK(); +} + +// TODO(laigd): use this utility function in more places. 
+Status RemoveBatchDimension(nvinfer1::Dims* dims) { + if (dims->nbDims < 2) { + return errors::InvalidArgument( + "Dropping batch dimension requires dims with rank>=2."); + } + std::copy(dims->d + 1, dims->d + dims->nbDims, dims->d); + dims->nbDims--; + return Status::OK(); } void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int out_port, - PartialTensorShape* shape, - tensorflow::DataType* dtype) { + PartialTensorShape* shape, DataType* dtype) { if (graph_properties.HasOutputProperties(node->name())) { auto output_params = graph_properties.GetOutputProperties(node->name()); auto out_shape = output_params.at(out_port); @@ -146,8 +246,7 @@ void GetOutputProperties(const grappler::GraphProperties& graph_properties, void GetInputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int in_port, - PartialTensorShape* shape, - tensorflow::DataType* dtype) { + PartialTensorShape* shape, DataType* dtype) { if (graph_properties.HasInputProperties(node->name())) { auto input_params = graph_properties.GetInputProperties(node->name()); auto in_shape = input_params.at(in_port); @@ -159,7 +258,7 @@ void GetInputProperties(const grappler::GraphProperties& graph_properties, } Status ValidateTensorProperties(const string& producer_node_type, - const tensorflow::DataType dtype, + const DataType dtype, const PartialTensorShape& shape, bool validation_only, nvinfer1::DataType* trt_dtype, @@ -183,6 +282,15 @@ Status ValidateTensorProperties(const string& producer_node_type, *trt_dims = TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true); *batch_size = shape.dim_size(0); + // Don't convert empty tensors (dim value of 0). 
+ for (int d = 1; d < shape.dims(); ++d) { + if (shape.dim_size(d) == 0) { + return errors::Unimplemented( + "Input tensor with shape ", shape.DebugString(), + " is an empty tensor, which is not supported by TRT"); + } + } + if (validation_only) return Status::OK(); // Following are validations at runtime. @@ -285,31 +393,31 @@ Status Converter::GetTrtBroadcastShape( } const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; - auto compute_output_dims = - [max_nb_dims](const TRT_TensorOrWeights& input, int broadcast_num_dims, - int* output_dims_array, nvinfer1::Dims* output_dims) { - const nvinfer1::Dims input_dims = input.GetTrtDims(); - std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); - std::copy(input_dims.d, input_dims.d + input_dims.nbDims, - output_dims_array + broadcast_num_dims - input_dims.nbDims); - if (input.is_tensor()) { - const int true_input_dims = input_dims.nbDims + 1; - if (true_input_dims < broadcast_num_dims) { - return errors::InvalidArgument( - "Broadcasting beyond batch dimension is not supported ", - "(tensor #dims ", true_input_dims, " vs broadcast #dims ", - broadcast_num_dims, ")"); - } - // Set the batch dimension to -1, since batch size is not supposed to - // be broadcasted. - output_dims_array[0] = -1; - } - // Copy to output dimensions (stripping the batch dimension). 
- output_dims->nbDims = broadcast_num_dims - 1; - std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, - output_dims->d); - return Status::OK(); - }; + auto compute_output_dims = [](const TRT_TensorOrWeights& input, + int broadcast_num_dims, int* output_dims_array, + nvinfer1::Dims* output_dims) { + const nvinfer1::Dims input_dims = input.GetTrtDims(); + std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); + std::copy(input_dims.d, input_dims.d + input_dims.nbDims, + output_dims_array + broadcast_num_dims - input_dims.nbDims); + if (input.is_tensor()) { + const int true_input_dims = input_dims.nbDims + 1; + if (true_input_dims < broadcast_num_dims) { + return errors::InvalidArgument( + "Broadcasting beyond batch dimension is not supported ", + "(tensor #dims ", true_input_dims, " vs broadcast #dims ", + broadcast_num_dims, ")"); + } + // Set the batch dimension to -1, since batch size is not supposed to + // be broadcasted. + output_dims_array[0] = -1; + } + // Copy to output dimensions (stripping the batch dimension). + output_dims->nbDims = broadcast_num_dims - 1; + std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, + output_dims->d); + return Status::OK(); + }; // Compute the output dimensions. const int broadcast_num_dims = @@ -341,14 +449,80 @@ nvinfer1::ITensor* Converter::CreateConstantLayer( if (!layer) return nullptr; const nvinfer1::DataType trt_dtype = trt_weights.type; nvinfer1::ITensor* trt_tensor = layer->getOutput(0); +#if !IS_TRT_VERSION_GE(5, 1, 3) // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set // the data type below, it will always be kFLOAT regardless what the data type // of the weights is. Once NVIDIA fixes this bug, we should remove the data // type setting logic below and test should still pass. 
trt_tensor->setType(trt_dtype); +#endif return trt_tensor; } +Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value, + const nvinfer1::Dims& dims, + const nvinfer1::ITensor** tensor, + const char* dtype_attr_name = "T") { + TFAttrs attrs(params->node_def); + DataType dtype; + if (attrs.count(dtype_attr_name)) { + dtype = attrs.get(dtype_attr_name); + } else { + dtype = DT_FLOAT; // Default to FP32. + } + + // In order to be broadcastable, the number of dims has to match. + nvinfer1::Dims broadcastable_dims(dims); + for (int i = 0; i < broadcastable_dims.nbDims; i++) { + broadcastable_dims.d[i] = 1; + } + TRT_ShapedWeights weights = + params->weight_store->GetTempWeights(dtype, broadcastable_dims); + void* raw_ptr = const_cast(weights.GetValues()); + switch (dtype) { + case DataType::DT_FLOAT: + static_cast(raw_ptr)[0] = value; + break; + case DataType::DT_HALF: + static_cast(raw_ptr)[0] = Eigen::half(value); + break; + default: + return errors::InvalidArgument("Unsupported data type ", + DataTypeString(dtype)); + } + *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name()); + params->converter->ProvideQuantizationRange( + const_cast(*tensor), value, value); + return Status::OK(); +} + +// Convert an axis from TF format to TRT format while validating. TF format +// includes the batch dimension, while TRT does not. TF can also use negative +// indices. +// TODO(tmorris): Use this method in more ops. +Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name, + int* trt_axis) { + const int tf_nb_dims = trt_nb_dims + 1; + // Check bounds. + if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) { + return errors::InvalidArgument( + "Axis value of ", tf_axis, " is out of bounds, must be in range [", + -tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name); + } + // Make negative axis positive. 
+ if (tf_axis < 0) tf_axis += tf_nb_dims; + // Don't allow axis to be the batch dimension. + if (tf_axis == 0) { + return errors::Unimplemented( + "TensorRT does not allow manipulation of the batch dimension, at ", + node_name); + } + // Remove batch dimension. + *trt_axis = tf_axis - 1; + return Status::OK(); +} + inline bool DimsEqual(const nvinfer1::Dims& dim_l, const nvinfer1::Dims& dim_r) { if (dim_l.nbDims != dim_r.nbDims) { @@ -362,7 +536,16 @@ inline bool DimsEqual(const nvinfer1::Dims& dim_l, return true; } -inline nvinfer1::Dims GetTrtDimsForTensor(const tensorflow::Tensor& tensor) { +bool AllLengthsEqual(const std::vector>& inputs) { + if (inputs.size() == 0) return true; + int length = inputs.at(0).size(); + for (int i = 1; i < inputs.size(); i++) { + if (inputs.at(i).size() != length) return false; + } + return true; +} + +inline nvinfer1::Dims GetTrtDimsForTensor(const Tensor& tensor) { nvinfer1::Dims dims; dims.nbDims = tensor.dims(); for (int i = 0; i < dims.nbDims; i++) { @@ -445,7 +628,7 @@ nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const { } size_t TRT_ShapedWeights::size_bytes() const { - return this->count() * tensorflow::DataTypeSize(this->type_); + return this->count() * DataTypeSize(this->type_); } string TRT_ShapedWeights::DebugString() const { @@ -493,12 +676,22 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { void setLocation(nvinfer1::TensorLocation location) override {} -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0) bool setDynamicRange(float min, float max) override { return true; } float getDynamicRange() const override { return 0; } #endif +#if IS_TRT_VERSION_GE(5, 1, 0) + bool dynamicRangeIsSet() const override { return true; } + + void resetDynamicRange() override {} + + float getDynamicRangeMin() const override { return 0.f; } + + float getDynamicRangeMax() const override { return 0.f; } +#endif + private: nvinfer1::DataType trt_dtype_; nvinfer1::Dims trt_dims_; @@ -569,89 +762,6 @@ 
string TRT_TensorOrWeights::DebugString() const { return output; } -class TFAttrs { - public: - explicit TFAttrs(const tensorflow::NodeDef& tf_node) { - for (const auto& attr : tf_node.attr()) { - attrs_.insert({attr.first, &attr.second}); - } - } - - bool count(const string& key) const { return attrs_.count(key); } - - tensorflow::AttrValue const* at(const string& key) const { - if (!attrs_.count(key)) { - LOG(FATAL) << "Attribute not found: " << key; - } - return attrs_.at(key); - } - - template - T get(const string& key) const; - - template - T get(const string& key, const T& default_value) const { - return attrs_.count(key) ? this->get(key) : default_value; - } - - std::vector GetAllAttrKeys() const { - std::vector attr_list; - for (const auto& attr_item : attrs_) { - attr_list.emplace_back(attr_item.first); - } - return attr_list; - } - - private: - typedef std::map AttrMap; - AttrMap attrs_; -}; - -template <> -string TFAttrs::get(const string& key) const { - return this->at(key)->s(); -} - -template <> -std::vector TFAttrs::get>(const string& key) const { - auto attr = this->at(key)->list().i(); - return std::vector(attr.begin(), attr.end()); -} - -template <> -std::vector TFAttrs::get>(const string& key) const { - auto attr = this->at(key)->list().f(); - return std::vector(attr.begin(), attr.end()); -} - -template <> -nvinfer1::DataType TFAttrs::get(const string& key) const { - nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); - TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); - return trt_dtype; -} - -template <> -tensorflow::DataType TFAttrs::get( - const string& key) const { - return this->at(key)->type(); -} - -template <> -float TFAttrs::get(const string& key) const { - return this->at(key)->f(); -} - -template <> -bool TFAttrs::get(const string& key) const { - return this->at(key)->b(); -} - -template <> -int TFAttrs::get(const string& key) const { - return this->at(key)->i(); -} - // TODO(jie): reorder4 & reorder2 should be 
merged? // TODO(aaroey): fix the order of parameters. template @@ -693,7 +803,7 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, const nvinfer1::DimsHW istrides = {1, k}; const nvinfer1::DimsHW ostrides = {c, 1}; switch (iweights.type_) { - case tensorflow::DataType::DT_FLOAT: { + case DataType::DT_FLOAT: { Reorder2({k, c}, static_cast(iweights.GetValues()), istrides, // TODO(aaroey): get rid of all the const_cast like this. @@ -701,7 +811,7 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, ostrides); break; } - case tensorflow::DataType::DT_HALF: { + case DataType::DT_HALF: { Reorder2( {k, c}, static_cast(iweights.GetValues()), istrides, @@ -737,14 +847,14 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; switch (iweights.type_) { - case tensorflow::DataType::DT_FLOAT: { + case DataType::DT_FLOAT: { Reorder4({k, c, r, s}, static_cast(iweights.GetValues()), istrides, static_cast(const_cast(oweights->GetValues())), ostrides); break; } - case tensorflow::DataType::DT_HALF: { + case DataType::DT_HALF: { Reorder4( {k, c, r, s}, static_cast(iweights.GetValues()), istrides, @@ -759,7 +869,7 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, } } -TRT_ShapedWeights TrtWeightStore::GetTempWeights(tensorflow::DataType type, +TRT_ShapedWeights TrtWeightStore::GetTempWeights(DataType type, const nvinfer1::Dims& dims) { TensorShape shape; // TODO(laigd): make it return a status. 
@@ -771,6 +881,13 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(tensorflow::DataType type, return weights; } +const std::set* TrtNodeValidator::quantize_ops = new std::set{ + "QuantizeAndDequantizeV2", + "QuantizeAndDequantizeV3", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxArgs", +}; + TrtNodeValidator::TrtNodeValidator() { RegisterOpValidators(); } Status TrtNodeValidator::ConvertToTensorOrWeights( @@ -816,9 +933,27 @@ Status TrtNodeValidator::ConvertToTensorOrWeights( } Status TrtNodeValidator::ValidateNode( - const tensorflow::NodeDef& node_def, + const NodeDef& node_def, const std::vector>& input_node_and_ports, + const TrtPrecisionMode precision_mode, const grappler::GraphProperties& graph_properties) { + const string& op = node_def.op(); + // It doesn't support validation of plugins. + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) return Status::OK(); + + // In INT8 mode, we will always apply the quantization ranges provided by + // these ops to the relevant tensors. This happens regardless of the value of + // use_calibration. + bool is_supported_op = false; + if (quantize_ops->count(op)) { + is_supported_op = (precision_mode == TrtPrecisionMode::INT8); + } else { + is_supported_op = op_validators_.count(node_def.op()); + } + if (!is_supported_op) { + return errors::Unimplemented("Op type ", op, " is not supported."); + } + // Convert input NodeDef and corresponding output ports to // TRT_TensorOrWeights. std::vector inputs; @@ -835,14 +970,7 @@ Status TrtNodeValidator::ValidateNode( inputs.push_back(tensor_or_weights); } - // Validate the node. - const auto iter = op_validators_.find(node_def.op()); - if (iter == op_validators_.end()) { - // If validator is not registered, it means no validation is needed. 
- return Status::OK(); - } - - OpConverter validator = iter->second; + OpConverter validator = op_validators_[node_def.op()]; OpConverterParams params( /*arg_converter=*/nullptr, node_def, inputs, /*arg_outputs=*/nullptr, /*arg_validation_only=*/true, &weight_store_); @@ -863,7 +991,7 @@ Status TrtNodeValidator::ConvertConstToWeights( } Converter::Converter(nvinfer1::INetworkDefinition* trt_network, - int precision_mode, bool use_calibration) + TrtPrecisionMode precision_mode, bool use_calibration) : trt_network_(trt_network), precision_mode_(precision_mode), use_calibration_(use_calibration) { @@ -881,7 +1009,7 @@ Status Converter::ConvertNode(const NodeDef& node_def) { TF_RETURN_IF_ERROR(plugin_converter_(¶ms)); } else { if (!op_registry_.count(op)) { - return errors::Unimplemented("No converter registered for op: " + op); + return errors::Unimplemented("No converter registered for op: ", op); } OpConverter op_converter = op_registry_.at(op); TF_RETURN_IF_ERROR(op_converter(¶ms)); @@ -890,7 +1018,7 @@ Status Converter::ConvertNode(const NodeDef& node_def) { for (size_t i = 0; i < outputs.size(); ++i) { TRT_TensorOrWeights& output = outputs[i]; string output_name = node_def.name(); - if (i != 0) output_name = StrCat(output_name, ":", i); + if (i != 0) absl::StrAppend(&output_name, ":", i); // We need to check the name before setting it. If the input is one of the // engine input, setting the name here will overwrite engine input // bindings which will cause runtime error. @@ -898,7 +1026,7 @@ Status Converter::ConvertNode(const NodeDef& node_def) { // in ConvertIdentity. if (output.is_tensor()) { const char* tensor_name = output.tensor()->getName(); - if (!tensorflow::str_util::StartsWith(tensor_name, kInputPHName)) { + if (!IsEngineInput(tensor_name)) { // TRT initializes tensor names as "(Unnamed ITensor* N)". We rename // them to match their corresponding TensorFlow name. 
// Note: ITensors that we create internally within TF-TRT which are @@ -944,27 +1072,31 @@ Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype, } Status Converter::RenameAndMarkOutputTensors( - const std::vector>& output_tensors) { + const std::vector& output_tensors) { for (const auto& output : output_tensors) { TRT_TensorOrWeights tensor_or_weights; - TF_RETURN_IF_ERROR(GetTensorOrWeights(output.first, &tensor_or_weights)); + TF_RETURN_IF_ERROR( + GetTensorOrWeights(output.source_tensor_name, &tensor_or_weights)); if (!tensor_or_weights.is_tensor()) { - return errors::InvalidArgument("Output ", output.first, + return errors::InvalidArgument("Output ", output.source_tensor_name, " is weights not tensor"); } nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); if (tensor == nullptr) { - return errors::NotFound("Output tensor not found: ", output.first); + return errors::NotFound("Output tensor not found: ", + output.source_tensor_name); } - // Check if this tensor has already been marked as an output. + // Check if this tensor has already been marked as an input or output. + // // ConvertIdentity can cause the same tensor to be repeated in // output_tensors, which can cause us to overwrite the name of the output // tensor binding. For example, if we rename OutputPH_0 to OutputPH_1 then // we won't be able to locate OutputPH_0 during runtime. To fix this, // duplicate the tensor using no-op shuffle. + // // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer // in ConvertIdentity. - if (tensorflow::str_util::StartsWith(tensor->getName(), kOutputPHName)) { + if (IsEngineInput(tensor->getName()) || IsEngineOutput(tensor->getName())) { // Using shuffle layer for identity by not setting reshape or transpose. 
nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR( @@ -972,10 +1104,13 @@ Status Converter::RenameAndMarkOutputTensors( MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); tensor = layer->getOutput(0); } - tensor->setName(output.second.c_str()); - VLOG(1) << "Marking output tensor " << output.first << ", as output tensor " - << output.second; + tensor->setName(output.dest_node_name.c_str()); network()->markOutput(*tensor); + // Set type after marking as output. TRT only supports setType for engine + // outputs and inputs (type is inferred otherwise). + tensor->setType(output.trt_dtype); + VLOG(1) << "Marking output TRT tensor " << output.source_tensor_name + << ", which feeds TF node " << output.dest_node_name; } return Status::OK(); } @@ -1022,11 +1157,11 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, const auto dims = input_tensor->getDimensions(); if (order_with_batch_dim.size() - 1 != size_t(dims.nbDims)) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Rank of perm for transpose does not match with that of the input."); } if (order_with_batch_dim[0] != 0) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Transpose at batch dimension is not supported."); } @@ -1052,7 +1187,7 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, layer->setReshapeDimensions(reshape_dims); *output_tensor = layer->getOutput(0); - return tensorflow::Status::OK(); + return Status::OK(); } Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, @@ -1089,6 +1224,7 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& dims, + const bool validation_only, const nvinfer1::ITensor** tensor) { // If -1 is not used for one of the dims, we can check if the shapes are // compatible. 
@@ -1105,6 +1241,10 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, DebugString(input.GetTrtDims()), " vs ", DebugString(dims), ")"); } + if (validation_only) { + *tensor = nullptr; + return Status::OK(); + } if (input.is_tensor()) { if (DimsEqual(input.GetTrtDims(), dims)) { @@ -1121,7 +1261,7 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, } else { *tensor = CreateConstantLayer(input.weights(), dims); TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, "TF-TRT Internal Reshape"); - if (precision_mode() == INT8MODE && !use_calibration()) { + if (precision_mode() == TrtPrecisionMode::INT8 && !use_calibration()) { // If we are in int8 mode and not calibrating, we need to explicitly set a // quantization range for the output tensor of the IConstantLayer. Here we // set the range to [min(weights), max(weights)]. @@ -1140,7 +1280,7 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, min_range, max_range); } } - return tensorflow::Status::OK(); + return Status::OK(); } void Converter::MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input, @@ -1156,12 +1296,12 @@ void Converter::ProvideQuantizationRange(nvinfer1::ITensor* tensor, } void Converter::MaybeApplyQuantizationRanges() { - if (precision_mode() != INT8MODE) return; + if (precision_mode() != TrtPrecisionMode::INT8) return; // Infer ranges across marked ops. PropagateQuantizationRanges(); // Apply ranges. 
-#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0) for (auto pair : quantization_ranges_) { nvinfer1::ITensor* tensor = pair.first; const float range = pair.second; @@ -1238,7 +1378,7 @@ void Converter::PropagateQuantizationRanges() { } } -Status Converter::GetInputs(const tensorflow::NodeDef& node_def, +Status Converter::GetInputs(const NodeDef& node_def, std::vector* inputs) const { for (auto const& input_name : node_def.input()) { /************************************************************************* @@ -1273,15 +1413,73 @@ Status Converter::GetInputs(const tensorflow::NodeDef& node_def, StrAppend(&msg, node_def.name(), " should have an input named '", name, "' but it is not available"); LOG(ERROR) << msg; - return tensorflow::errors::InvalidArgument(msg); + return errors::InvalidArgument(msg); + } + } + return Status::OK(); +} + +// Checks that the number of inputs match, and enforces that the inputs marked +// as true are constant weights. true means that the input must be a weight, +// while false means the input must be a tensor. In the future, false will mean +// the input can be a tensor or weight. +Status CheckInputsWeights( + const OpConverterParams& params, + const std::vector>& inputs_is_weight) { + const auto& inputs = params.inputs; + const auto& node_def = params.node_def; + if (inputs.size() != inputs_is_weight.size()) { + return errors::InvalidArgument( + node_def.op(), " got ", inputs.size(), " inputs but expected ", + inputs_is_weight.size(), ", at ", node_def.name()); + } + for (int i = 0; i < inputs.size(); i++) { + if (inputs_is_weight[i].second && inputs.at(i).is_tensor()) { + return errors::Unimplemented("The input \"", inputs_is_weight[i].first, + "\" for ", node_def.op(), + " must be a constant, at ", node_def.name()); + } + // TODO(tmorris): Remove this check and provide a method to automatically + // retrive an input as a tensor, converting via CreateConstantLayer if it + // was originally a weight. 
We will want a caching mechanism to prevent many + // duplicate constants from being created. + if (!inputs_is_weight[i].second && inputs.at(i).is_weights()) { + return errors::Unimplemented("The input \"", inputs_is_weight[i].first, + "\" for ", node_def.op(), + " must be a tensor, at ", node_def.name()); } } - return tensorflow::Status::OK(); + return Status::OK(); +} + +Status AllowDataTypes(const OpConverterParams& params, + const std::set& allowed_dtypes, + const char* dtype_attr_name = "T") { + const auto& node_def = params.node_def; + TFAttrs attrs(node_def); + if (!attrs.count(dtype_attr_name)) { + return errors::InvalidArgument("Attribute with name ", dtype_attr_name, + " not found."); + } + const auto op_dtype = attrs.get(dtype_attr_name); + if (!allowed_dtypes.count(op_dtype)) { + // Build string list of allowed types. + std::ostringstream ss; + for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) { + if (it != allowed_dtypes.begin()) ss << ", "; + ss << DataTypeString(*it); + } + return errors::Unimplemented("Data type ", DataTypeString(op_dtype), + " is not supported for ", node_def.op(), + ", must be one of [", ss.str(), "], at ", + node_def.name()); + } + return Status::OK(); } TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store, const TRT_ShapedWeights& weights_src) { - auto dtype_new = tensorflow::DataType::DT_HALF; + auto dtype_new = DataType::DT_HALF; TRT_ShapedWeights weights = store->GetTempWeights(dtype_new, weights_src.shape_); const float* src = static_cast(weights_src.GetValues()); @@ -1340,18 +1538,17 @@ std::function LambdaFactory::unary() { } } -tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, - TRT_ShapedWeights* oweights, - LambdaFactory unary_op) { +Status UnaryCompute(const TRT_ShapedWeights& iweights, + TRT_ShapedWeights* oweights, LambdaFactory unary_op) { CHECK_EQ(iweights.type_, oweights->type_); switch (iweights.type_) { - case tensorflow::DataType::DT_FLOAT: { + case 
DataType::DT_FLOAT: { auto inp = static_cast(iweights.GetValues()); auto oup = static_cast(const_cast(oweights->GetValues())); std::transform(inp, inp + iweights.count(), oup, unary_op.unary()); break; } - case tensorflow::DataType::DT_HALF: { + case DataType::DT_HALF: { auto inp = static_cast(iweights.GetValues()); auto oup = static_cast(const_cast(oweights->GetValues())); @@ -1360,11 +1557,10 @@ tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, break; } default: - return tensorflow::errors::Unimplemented( - "Data type not supported: " + - tensorflow::DataTypeString(iweights.type_)); + return errors::Unimplemented("Data type not supported: " + + DataTypeString(iweights.type_)); } - return tensorflow::Status::OK(); + return Status::OK(); } // If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the @@ -1471,7 +1667,7 @@ Status BinaryTensorOpWeight(OpConverterParams* params, const_cast(tensor), permutation, &tensor)); } - if (params->converter->precision_mode() == FP16MODE) { + if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { weights = ConvertFP32ToFP16(params->weight_store, weights); } @@ -1514,7 +1710,7 @@ Status BinaryTensorOpWeight(OpConverterParams* params, // Because of this issue, fall back to BinaryTensorOpTensor if we are // doing INT8 with no calibration. There is most likely no performance // penalty by falling back here. 
- if (params->converter->precision_mode() == INT8MODE && + if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && !params->converter->use_calibration()) { return errors::Unimplemented( "Intermediate quantization range cannot be determined without" @@ -1561,67 +1757,71 @@ Status BinaryTensorOpWeight(OpConverterParams* params, // Pass the output params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); - return tensorflow::Status::OK(); + return Status::OK(); } -enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV }; - -tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { +Status ConvertConv2DHelper(OpConverterParams* params, int group, + bool is_conv2d_backprop_input) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2) { - return tensorflow::errors::InvalidArgument("Two inputs are expected for ", - node_def.op(), ", at ", - node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - node_def.op(), " is only implemented for tensors, not weights, at ", - node_def.name()); - } - if (inputs.at(1).is_tensor()) { - return tensorflow::errors::Unimplemented("Kernel for ", node_def.op(), - " must be constant weights, at ", - node_def.name()); + TRT_TensorOrWeights backprop_output_size; + const nvinfer1::ITensor* tensor = nullptr; + if (is_conv2d_backprop_input) { + // In the case when Conv2dBackpropInput is used for conv2d_transpose, these + // inputs correspond to: output size, filter, and input. 
+ TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}})); + backprop_output_size = inputs.at(0); + tensor = inputs.at(2).tensor(); + } else { + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"filter", true}})); + tensor = inputs.at(0).tensor(); } + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); if (weights_rsck.shape_.nbDims != 4) { - return tensorflow::errors::InvalidArgument( - "Conv2D expects kernel of dimension 4, at " + node_def.name()); + return errors::InvalidArgument("Conv2D expects kernel of dimension 4, at " + + node_def.name()); } TFAttrs attrs(node_def); auto data_format = attrs.get("data_format"); int c_index = (data_format == "NHWC") ? 3 : 1; int h_index = (data_format == "NHWC") ? 1 : 2; int w_index = (data_format == "NHWC") ? 2 : 3; - auto tf_dilations = attrs.get>("dilations"); + auto tf_dilations = attrs.get>("dilations"); if (tf_dilations.size() != 4) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Convolution dilations field must specify 4 dimensions, at ", node_def.name()); } if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Dilation rate must be 1 for batch and channel dimensions, at ", node_def.name()); } const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]); - - const auto tf_stride = attrs.get>("strides"); + if (is_conv2d_backprop_input && (dilation.d[0] != 1 || dilation.d[1] != 1)) { + return errors::Unimplemented( + "Dilation with Conv2DBackpropInput (conv2d_transpose) is not supported", + ", at ", node_def.name()); + } + + const auto tf_stride = attrs.get>("strides"); if (tf_stride.size() != 4) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Convolution strides field must 
specify 4 dimensions, at ", node_def.name()); } if (tf_stride[0] != 1 || tf_stride[c_index] != 1) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Stride must be 1 for batch and channel dimensions, at ", node_def.name()); } const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); - if (params->validation_only) return tensorflow::Status::OK(); - - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + if (params->validation_only) return Status::OK(); // Transpose to NCHW (NCHW is required for IConvLayer). const bool need_transpose = (data_format == "NHWC"); @@ -1632,19 +1832,23 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); - // For depthwise convolution, group will be 0 so set num_groups to size of - // input's channel dim. For a non-depthwise conv, num_groups will be 1. + // group == 0 signifies that this is a depthwise convolution, so set + // num_groups to size of input's channel dim. For a non-depthwise conv, + // num_groups will be 1. const int num_groups = (group == 0) ? tensor_dim.d[0] : group; - if (params->converter->precision_mode() == FP16MODE) { - weights_rsck = - ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); + if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { + weights_rsck = ConvertFP32ToFP16(params->weight_store, weights_rsck); } + // For conv, TF weights are RSCK, and TRT expects KCRS. + // For backprop, TF weights are RSKC, and TRT expects CKRS. + // Therefore, this reorder will work for both cases. TRT_ShapedWeights weights = params->weight_store->GetTempWeights(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); TRT_ShapedWeights biases(weights.type_); - const int noutput = weights.shape_.d[0] * num_groups; + const int output_axis = is_conv2d_backprop_input ? 
1 : 0; + const int noutput = weights.shape_.d[output_axis] * num_groups; nvinfer1::DimsHW kernel_size; kernel_size.h() = weights.shape_.d[2]; kernel_size.w() = weights.shape_.d[3]; @@ -1655,9 +1859,23 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { nvinfer1::DimsHW effective_kernel_size = kernel_size; effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1); effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1); - padding = CreateSamePadding( - stride, effective_kernel_size, - {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); + std::vector input_dims; + if (is_conv2d_backprop_input) { + // For backprop, calculate padding based on "input_sizes" input, which + // actually corresponds to output size. ("input_sizes" makes sense in the + // context of Conv2DBackpropInput). + // We use h_index and w_index instead of 1 and 2 because we havent + // transposed backprop_output_size along with the input. + auto output_size_weights = static_cast( + const_cast(backprop_output_size.weights().GetValues())); + input_dims = {output_size_weights[h_index], output_size_weights[w_index]}; + } else { + // Use 1 and 2 because tensor_dim has the dimensions of the transposed + // input. + input_dims = {static_cast(tensor_dim.d[1]), + static_cast(tensor_dim.d[2])}; + } + padding = CreateSamePadding(stride, effective_kernel_size, input_dims); } else { padding = {{0, 0}, {0, 0}}; } @@ -1676,17 +1894,32 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { } // Add convolution. 
- nvinfer1::IConvolutionLayer* layer = - params->converter->network()->addConvolution( - *const_cast(tensor), noutput, kernel_size, - weights.GetTrtWeights(), biases.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setStride(stride); - layer->setPadding({padding[0].first, padding[1].first}); - layer->setName(node_def.name().c_str()); - layer->setNbGroups(num_groups); - layer->setDilation(dilation); - const nvinfer1::ITensor* output_tensor = layer->getOutput(0); + nvinfer1::ILayer* conv_layer = nullptr; + if (is_conv2d_backprop_input) { + nvinfer1::IDeconvolutionLayer* layer = + params->converter->network()->addDeconvolution( + *const_cast(tensor), noutput, kernel_size, + weights.GetTrtWeights(), biases.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + layer->setNbGroups(num_groups); + conv_layer = layer; + } else { + nvinfer1::IConvolutionLayer* layer = + params->converter->network()->addConvolution( + *const_cast(tensor), noutput, kernel_size, + weights.GetTrtWeights(), biases.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + layer->setNbGroups(num_groups); + layer->setDilation(dilation); + conv_layer = layer; + } + const nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); // Restore transpose. 
if (need_transpose) { @@ -1696,19 +1929,7 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { } params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, - ConvolutionType type) { - switch (type) { - case ConvolutionType::DEFAULT: - return ConvertConv2DHelper(params, 1); - case ConvolutionType::DEPTHWISE_CONV: - return ConvertConv2DHelper(params, 0); - } - return tensorflow::errors::Unimplemented("Unsupported convolution type, at ", - params->node_def.name()); + return Status::OK(); } Status BinaryTensorOpTensor(OpConverterParams* params, @@ -1723,6 +1944,7 @@ Status BinaryTensorOpTensor(OpConverterParams* params, {"RealDiv", nvinfer1::ElementWiseOperation::kDIV}, {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, + {"Pow", nvinfer1::ElementWiseOperation::kPOW}, }; auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) { @@ -1750,10 +1972,10 @@ Status BinaryTensorOpTensor(OpConverterParams* params, const nvinfer1::ITensor* tensor_l = nullptr; const nvinfer1::ITensor* tensor_r = nullptr; status = params->converter->PrepareTensorForShape( - operand_l, broadcasted_dims_l, &tensor_l); + operand_l, broadcasted_dims_l, /*validation_only=*/false, &tensor_l); if (status.ok()) { status = params->converter->PrepareTensorForShape( - operand_r, broadcasted_dims_r, &tensor_r); + operand_r, broadcasted_dims_r, /*validation_only=*/false, &tensor_r); } if (!status.ok()) { return errors::Internal("Failed to convert binary op ", node_def.name(), @@ -1776,10 +1998,10 @@ Status BinaryTensorOpTensor(OpConverterParams* params, // Pass the output params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertPlugin(OpConverterParams* params) { +Status ConvertPlugin(OpConverterParams* 
params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; // prepare input @@ -1804,7 +2026,7 @@ tensorflow::Status ConvertPlugin(OpConverterParams* params) { size_t size_data = data.size() * sizeof(float); if (!plugin->SetAttribute(attr_key, static_cast(data.data()), size_data)) { - return tensorflow::errors::InvalidArgument("plugin SetAttribute failed"); + return errors::InvalidArgument("plugin SetAttribute failed"); } } @@ -1815,17 +2037,15 @@ tensorflow::Status ConvertPlugin(OpConverterParams* params) { nvinfer1::ITensor* output_tensor = layer->getOutput(i); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertTranspose(OpConverterParams* params) { +Status ConvertTranspose(OpConverterParams* params) { const auto& inputs = params->inputs; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at ", params->node_def.name()); - } - + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"x", false}, {"perm", true}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); // Get the permutation from weights. 
TRT_ShapedWeights weights = inputs.at(1).weights(); const int* weights_ptr = @@ -1852,22 +2072,21 @@ tensorflow::Status ConvertTranspose(OpConverterParams* params) { params->converter->TransposeTensor(input_tensor, perm, &output_tensor)); params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertReshape(OpConverterParams* params) { +Status ConvertReshape(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects weights for shape, at ", node_def.name()); - } - + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); TRT_TensorOrWeights input_tensor = inputs.at(0); TRT_ShapedWeights weights = inputs.at(1).weights(); if (weights.count() == 0) { - return tensorflow::errors::Unimplemented( - "Reshape to shape=[] is not supported, at ", node_def.name()); + return errors::Unimplemented("Reshape to shape=[] is not supported, at ", + node_def.name()); } const int* weights_ptr = @@ -1949,27 +2168,19 @@ tensorflow::Status ConvertReshape(OpConverterParams* params) { // Start conversion. 
const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input_tensor, reshape_dims, &output_tensor)); + input_tensor, reshape_dims, /*validation_only=*/false, &output_tensor)); params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertExpandDims(OpConverterParams* params) { +Status ConvertExpandDims(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2) { - return tensorflow::errors::InvalidArgument( - "Two inputs expected for ExpandDims, at ", node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "ExpandDims expects tensor for input, at ", node_def.name()); - } - if (!inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "ExpandDims expects weights for axis, at ", node_def.name()); - } + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"axis", true}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); // Get input shape as vector. TRT_TensorOrWeights input_tensor = inputs.at(0); const nvinfer1::Dims dims = input_tensor.GetTrtDims(); @@ -1980,15 +2191,15 @@ tensorflow::Status ConvertExpandDims(OpConverterParams* params) { // Get axis to expand on. TRT_ShapedWeights weights = inputs.at(1).weights(); if (weights.count() != 1) { - return tensorflow::errors::InvalidArgument( - "ExpandDims axis must be a scalar, at ", node_def.name()); + return errors::InvalidArgument("ExpandDims axis must be a scalar, at ", + node_def.name()); } const int* weights_ptr = static_cast(const_cast(weights.GetValues())); int axis = weights_ptr[0]; // Make sure axis is valid. 
if ((axis < (-input_rank - 1)) || (axis > input_rank)) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Axis for ExpandDims is invalid, must be in the range " "[-rank(input) - 1, rank(input)], at ", node_def.name()); @@ -1996,7 +2207,7 @@ tensorflow::Status ConvertExpandDims(OpConverterParams* params) { // Convert negative axis to corresponding positive axis. if (axis < 0) axis += input_rank + 1; if (axis == 0) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Modifying batch dimension is not supported for ExpandDims, at ", node_def.name()); } @@ -2010,23 +2221,18 @@ tensorflow::Status ConvertExpandDims(OpConverterParams* params) { /*ignore_first_dim=*/true)); const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input_tensor, new_dims, &output_tensor)); + input_tensor, new_dims, /*validation_only=*/false, &output_tensor)); params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertSqueeze(OpConverterParams* params) { +Status ConvertSqueeze(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 1) { - return tensorflow::errors::InvalidArgument( - "One input expected for Squeeze, at ", node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "Squeeze expects tensor for input, at ", node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); // Get input shape. 
TRT_TensorOrWeights input_tensor = inputs.at(0); const nvinfer1::Dims dims = input_tensor.GetTrtDims(); @@ -2036,15 +2242,15 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { const int input_rank = input_dims.size(); // Mark axes to remove by setting them to 0. TFAttrs attrs(node_def); - auto squeeze_dims = attrs.get>("squeeze_dims"); - if (squeeze_dims.size() == 0) { - return tensorflow::errors::Unimplemented( + auto squeeze_dims = attrs.get>("squeeze_dims"); + if (squeeze_dims.empty()) { + return errors::Unimplemented( "Squeeze is only implemented for explicit dims, at ", node_def.name()); } for (int axis : squeeze_dims) { // Make sure axis is valid. if ((axis < -input_rank) || (axis >= input_rank)) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Axis for Squeeze is invalid, must be in the range " "[-rank(input), rank(input)), at ", node_def.name()); @@ -2053,14 +2259,14 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { if (axis < 0) axis += input_rank; // Don't squeeze batch dim. if (axis == 0) { - return tensorflow::errors::Unimplemented( - "Cannot squeeze batch dimension, at ", node_def.name()); + return errors::Unimplemented("Cannot squeeze batch dimension, at ", + node_def.name()); } // Make sure target dimension is size 1. if (input_dims[axis] != 1) { - return tensorflow::errors::InvalidArgument( - "Cannot squeeze a dimension which isn't size 1, at ", - node_def.name()); + return errors::InvalidArgument( + "Cannot squeeze ", axis, "th dimension ", input_dims[axis], + " which isn't size 1, at ", node_def.name()); } // Mark dim for removal by setting to 0. 
input_dims[axis] = 0; @@ -2076,117 +2282,79 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { /*ignore_first_dim=*/true)); const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input_tensor, new_dims, &output_tensor)); + input_tensor, new_dims, /*validation_only=*/false, &output_tensor)); params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); - return tensorflow::Status::OK(); -} - -// Gets the bounds (start or end) from the weights of a StridedSlice op. -tensorflow::Status GetStridedSliceBound(const std::vector& input_dims, - const TRT_ShapedWeights& bound_weights, - int mask, bool begin, string node_name, - std::vector* output_bound) { - const string bound_name = (begin) ? "begin" : "end"; - const int* weights_ptr = static_cast(bound_weights.GetValues()); - *output_bound = - std::vector(weights_ptr, weights_ptr + bound_weights.count()); - if (output_bound->size() != input_dims.size()) { - return tensorflow::errors::InvalidArgument( - "StridedSlice \"", bound_name, "\" specified ", - std::to_string(output_bound->size()), " dimensions, but input rank is ", - std::to_string(input_dims.size()), ", at ", node_name); - } - for (int i = 0; i < output_bound->size(); i++) { - if ((1 << i) & mask) { - // Apply mask. - (*output_bound)[i] = (begin) ? 0 : input_dims[i]; - // Masked bound will always result in a valid, non-negative bound, so we - // don't need the following checks. For the common case of using masks on - // a undefined batch dim (-1), we specifically don't want to do the - // following checks because they will erroneously detect an out of range - // bound or try to correct the negative value. - continue; - } - // Make sure bound is valid. 
- if (((*output_bound)[i] < -input_dims[i]) || - ((*output_bound)[i] > input_dims[i])) { - return tensorflow::errors::InvalidArgument( - bound_name, " value of ", std::to_string((*output_bound)[i]), - " for StridedSlice is invalid, must be in the range " - "[-dim_size(i), dim_size(i)], at ", - node_name); - } - // Convert negative values to their positive equivalent. - if ((*output_bound)[i] < 0) { - (*output_bound)[i] += input_dims[i]; - } - } - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { - const auto& inputs = params->inputs; + return Status::OK(); +} + +Status ConvertStridedSliceHelper(OpConverterParams* params, + const TRT_TensorOrWeights& input, + std::vector begin, std::vector size, + const std::vector& stride) { const auto& node_def = params->node_def; - if (inputs.size() != 4) { - return tensorflow::errors::InvalidArgument( - "StridedSlice expects 4 inputs, at ", node_def.name()); - } - if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights() || - !inputs.at(3).is_weights()) { - return tensorflow::errors::InvalidArgument( - "StridedSlice expects weights for begin, end, and strides, at ", - node_def.name()); - } - if (!inputs.at(0).is_tensor()) { - return tensorflow::errors::Unimplemented( - "StridedSlice is only implemented for tensors, at ", node_def.name()); - } // Get input dims. - nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + nvinfer1::Dims dims = input.GetTrtDims(); std::vector input_dims(dims.d, dims.d + dims.nbDims); - if (inputs.at(0).is_tensor()) { - // Temporarily add batch dimension so that indexes line up properly. - input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); - } - if (input_dims.size() > 4) { - return tensorflow::errors::Unimplemented( - "StridedSlice is not implemented for tensors with rank > 4, at ", - node_def.name()); + // Temporarily add batch dimension so that indexes line up properly. 
+ input_dims.insert(input_dims.begin(), -1); + // Check bounds. + for (int i = 1; i < input_dims.size(); i++) { + if (begin[i] < 0 || begin[i] > input_dims[i]) { + return errors::InvalidArgument("\"begin\" for dimension ", + std::to_string(i), " in ", node_def.op(), + " is out of range, at ", node_def.name()); + } + const int end = begin[i] + size[i]; + if (end < 0 || end > input_dims[i]) { + return errors::InvalidArgument("\"begin\" + \"size\" for dimension ", + std::to_string(i), " in ", node_def.op(), + " is out of range, at ", node_def.name()); + } + if (size[i] <= 0) { + return errors::InvalidArgument("\"size\" cannot be negative or zero for ", + node_def.op(), ", at ", node_def.name()); + } } - TFAttrs attrs(node_def); - // Get begin and end bounds per axis. - std::vector begin, end; - TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(), - attrs.get("begin_mask"), true, - node_def.name(), &begin)); - TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(), - attrs.get("end_mask"), false, - node_def.name(), &end)); - // Get strides per axis (must all be 1). - TRT_ShapedWeights stride_weights = inputs.at(3).weights(); - const int* stride_weights_ptr = static_cast(stride_weights.GetValues()); - std::vector strides(stride_weights_ptr, - stride_weights_ptr + stride_weights.count()); - for (int x : strides) { +// TRT 5.1 adds a slice layer. For older versions, we attempt to use the +// padding layer with negative padding. +#if IS_TRT_VERSION_GE(5, 1, 0) && 0 + // TODO(laigd): TRT 5.1 RC has a bug when ISliceLayer is used along with + // IConcatenationLayer, so disable ISliceLayer for now until it's fixed. + // Use ISliceLayer. 
+ nvinfer1::Dims begin_dims, size_dims, stride_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(begin, &begin_dims, + /*ignore_first_dim=*/true)); + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(size, &size_dims, + /*ignore_first_dim=*/true)); + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(stride, &stride_dims, + /*ignore_first_dim=*/true)); + if (params->validation_only) return Status::OK(); + + nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice( + *const_cast(input.tensor()), begin_dims, size_dims, + stride_dims); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +#else + // Use IPaddingLayer. + // Strides must be 1 in this case. + for (int x : stride) { if (x != 1) { - return tensorflow::errors::Unimplemented( - "StridedSlice is only implemented for stride of 1, at ", + return errors::Unimplemented( + "Strides other than 1 are not supported with this version of TRT, " + "at ", node_def.name()); } } - // Unsupported mask options. - for (const string& attr : - {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { - int attr_val = attrs.get(attr); - if (attr_val != 0) { - return tensorflow::errors::Unimplemented( - attr, " is not supported for StridedSlice, at ", node_def.name()); - } + // Rank must be 2, 3 or 4. + if (input_dims.size() > 4) { + return errors::Unimplemented(node_def.op(), + " for tensors with rank > 4 is " + "not supported in this version of " + "TRT, at ", + node_def.name()); } - - nvinfer1::ITensor* tensor = - const_cast(inputs.at(0).tensor()); // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input. 
const bool need_reshape = (input_dims.size() != 4); int reshape_dims_added = 0; @@ -2196,7 +2364,7 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { while (input_dims.size() < 4) { input_dims.insert(input_dims.begin() + 1, 1); begin.insert(begin.begin() + 1, 0); - end.insert(end.begin() + 1, 1); + size.insert(size.begin() + 1, 1); reshape_dims_added++; } TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims, @@ -2204,24 +2372,23 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { } // Find dimensions which need to be sliced. std::vector pad_dims; - for (int i = 0; i < input_dims.size(); i++) { - if ((begin[i] != 0) || (end[i] != input_dims[i])) { - if (i == 0) { - return tensorflow::errors::Unimplemented( - "StridedSlice can't modify batch dim, at ", node_def.name()); - } else if ((end[i] - begin[i]) < 0) { - return tensorflow::errors::InvalidArgument( - "New size of sliced dimension is negative, at ", node_def.name()); - } + for (int i = 1; i < input_dims.size(); i++) { + if ((begin[i] != 0) || (begin[i] + size[i] != input_dims[i])) { pad_dims.push_back(i); } } - if (pad_dims.size() == 0) { - // No dimensions are changed. We could create a padding layer anyway with - // values of 0. + if (pad_dims.empty()) { + // No dimensions are changed, so this is a no-op. We could just return the + // input without creating a new layer. TRT will crash if an empty engine + // with no layers is attempted to be created, so we add a no-op shuffle to + // prevent our unit tests from breaking. + // TODO(tmorris): Allow empty engines in the unit tests and return the input + // as output here. 
if (params->validation_only) return Status::OK(); - params->outputs->push_back(inputs.at(0)); - return tensorflow::Status::OK(); + nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle( + *const_cast(input.tensor())); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); } else if (pad_dims.size() == 1) { // Only one dim is modified but we have to have 2, mark a second dim which // will have padding of 0. The dim we add is chosen to avoid an unecessary @@ -2232,17 +2399,20 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { pad_dims.push_back(3); } } else if (pad_dims.size() > 2) { - return tensorflow::errors::Unimplemented( - "StridedSlice can only modify 2 dimensions, at ", node_def.name()); + return errors::Unimplemented( + node_def.op(), + " can only modify up to 2 dimensions in this version of TRT, at ", + node_def.name()); } std::sort(pad_dims.begin(), pad_dims.end()); // Convert to pre/post padding values. Since TRT does not have a StridedSlice - // or Slice layer, we instead create an IPaddingLayer with negative padding. + // or Slice layer prior to 5.1, we instead create an IPaddingLayer with + // negative padding. nvinfer1::DimsHW pre_padding, post_padding; for (int i = 0; i < pad_dims.size(); i++) { const int axis = pad_dims[i]; pre_padding.d[i] = -begin[axis]; - post_padding.d[i] = end[axis] - input_dims[axis]; + post_padding.d[i] = (begin[axis] + size[axis]) - input_dims[axis]; } // IPaddingLayer will always apply the padding to dims 2,3 (input format is @@ -2262,10 +2432,11 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { if (params->validation_only) return Status::OK(); // Start conversion. 
+ nvinfer1::ITensor* tensor = const_cast(input.tensor()); if (need_reshape) { const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), reshape_dims, &output_tensor)); + input, reshape_dims, /*validation_only=*/false, &output_tensor)); tensor = const_cast(output_tensor); } if (need_transpose) { @@ -2274,7 +2445,6 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { tensor, transpose_order, &output_tensor)); tensor = const_cast(output_tensor); } - // Add padding layer nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( *const_cast(tensor), pre_padding, post_padding); @@ -2282,7 +2452,6 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { params->converter->MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); tensor = layer->getOutput(0); - // Restore transpose if (need_transpose) { const nvinfer1::ITensor* output_tensor = nullptr; @@ -2295,14 +2464,14 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { // Calculate output dimensions for (int i = 0; i < pad_dims.size(); i++) { const int axis = pad_dims[i]; - input_dims[axis] = end[axis] - begin[axis]; + input_dims[axis] = size[axis]; } // Remove added 1 dimensions for (int i = 0; i < reshape_dims_added; i++) { int value = input_dims[1]; if (value != 1) { - return tensorflow::errors::Internal( - "StridedSlice error when reshaping, at ", node_def.name()); + return errors::Internal("StridedSlice error when reshaping, at ", + node_def.name()); } input_dims.erase(input_dims.begin() + 1); } @@ -2312,45 +2481,181 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { /*ignore_first_dim=*/true)); const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(tensor), new_dims, &output_tensor)); + TRT_TensorOrWeights(tensor), new_dims, /*validation_only=*/false, + &output_tensor)); 
tensor = const_cast(output_tensor); } params->outputs->push_back( TRT_TensorOrWeights(const_cast(tensor))); - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertConv2D(OpConverterParams* params) { - return ConvertConv2DHelper(params, ConvolutionType::DEFAULT); + return Status::OK(); +#endif } -tensorflow::Status ConvertConv2DDepthwise(OpConverterParams* params) { - return ConvertConv2DHelper(params, ConvolutionType::DEPTHWISE_CONV); +Status ConvertSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"begin", true}, {"size", true}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); + std::vector begin = inputs.at(1).weights().ToVector(); + std::vector size = inputs.at(2).weights().ToVector(); + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + if (!AllLengthsEqual({input_dims, begin, size})) { + return errors::InvalidArgument( + "Length of begin and size arguments must equal rank of input for " + "Slice, at ", + node_def.name()); + } + // Check that batch dimension is unmodified. + const bool begin_is_modified = begin[0] != 0; + // If size[0]s is not -1, we can only know if the batch dimension is + // unmodified when the batch size is defined. When the batch size is + // undefined, we don't convert to be safe. 
+ const bool batch_size_is_defined = input_dims[0] > 0; + const bool size_is_modified = + size[0] != -1 && (!batch_size_is_defined || + (batch_size_is_defined && size[0] != input_dims[0])); + if (begin_is_modified || size_is_modified) { + return errors::Unimplemented( + "TensorRT does not allow modifications to the batch dimension, at ", + node_def.name()); + } + // Size of -1 signifies to take all remaining elements. + for (int i = 1; i < input_dims.size(); i++) { + if (size[i] == -1) { + size[i] = input_dims[i] - begin[i]; + } + } + // Stride is 1 for all dims. + std::vector stride(begin.size(), 1); + return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride); } -tensorflow::Status ConvertPool(OpConverterParams* params) { +Status ConvertStridedSlice(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - node_def.op(), " is only implemented for tensors, not weights, at ", + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + // Get begin and end bounds per axis. + std::vector begin = inputs.at(1).weights().ToVector(); + std::vector end = inputs.at(2).weights().ToVector(); + std::vector stride = inputs.at(3).weights().ToVector(); + if (!AllLengthsEqual({input_dims, begin, end, stride})) { + return errors::InvalidArgument( + "Length of begin, end, and stride arguments must equal rank of input " + "for StridedSlice, at ", node_def.name()); } + // Unsupported mask options. 
+ TFAttrs attrs(node_def); + for (const string& attr : + {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + int attr_val = attrs.get(attr); + if (attr_val != 0) { + return errors::Unimplemented( + attr, " is not supported for StridedSlice, at ", node_def.name()); + } + } + const int begin_mask = attrs.get("begin_mask"); + const int end_mask = attrs.get("end_mask"); + // Check that batch dimension is unmodified. + const bool begin_is_modified = !(begin_mask & 1) && begin[0] != 0; + const bool stride_is_modified = stride[0] != 1; + // If the batch size is -1 and the end mask is not set, we can only know if + // the batch dimension is unmodified when the batch size is defined. When the + // batch size is undefined, we don't convert to be safe. + const bool batch_size_is_defined = input_dims[0] > 0; + const bool end_is_modified = + !(end_mask & 1) && (!batch_size_is_defined || + (batch_size_is_defined && end[0] != input_dims[0])); + if (begin_is_modified || stride_is_modified || end_is_modified) { + return errors::Unimplemented( + "TensorRT does not allow modifications to the batch dimension, at ", + node_def.name()); + } + // Standarize begin and end bounds by applying masks, making negative values + // positive, and correcting out of bounds ranges (StridedSlice does this + // silently). + for (int i = 1; i < input_dims.size(); i++) { + // Begin + if ((1 << i) & begin_mask) { + begin[i] = 0; + } else if (begin[i] < 0) { + begin[i] += input_dims[i]; + } + begin[i] = std::max(0, std::min(begin[i], input_dims[i])); + // End + if ((1 << i) & end_mask) { + end[i] = input_dims[i]; + } else if (end[i] < 0) { + end[i] += input_dims[i]; + } + end[i] = std::max(0, std::min(end[i], input_dims[i])); + } + // Negative or zero strides currently not supported. 
+ for (int i = 0; i < input_dims.size(); i++) { + if (stride[i] <= 0) { + return errors::Unimplemented( + "Negative or zero stride values are not supported for StridedSlice, " + "at ", + node_def.name()); + } + } + // TRT Slice layer uses (begin, size) instead of (begin, end) + std::vector size(input_dims.size()); + for (int i = 0; i < input_dims.size(); i++) { + // Divide by stride (round up) + size[i] = (end[i] - begin[i] + stride[i] - 1) / stride[i]; + } + return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride); +} + +Status ConvertConv2D(OpConverterParams* params) { + return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/false); +} + +Status ConvertConv2DDepthwise(OpConverterParams* params) { + return ConvertConv2DHelper(params, 0, /*is_conv2d_backprop_input=*/false); +} + +Status ConvertConv2DBackpropInput(OpConverterParams* params) { + return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/true); +} + +Status ConvertPool(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); nvinfer1::PoolingType type; if (node_def.op() == "MaxPool") { type = nvinfer1::PoolingType::kMAX; } else if (node_def.op() == "AvgPool") { type = nvinfer1::PoolingType::kAVERAGE; } else { - return tensorflow::errors::Unimplemented( - "Unsupported pooling type: ", node_def.op(), ", at ", node_def.name()); + return errors::Unimplemented("Unsupported pooling type: ", node_def.op(), + ", at ", node_def.name()); } TFAttrs attrs(node_def); const string padding_type = attrs.get("padding"); if ((padding_type != "SAME") && (padding_type != "VALID")) { - return tensorflow::errors::Unimplemented( - "Unsupported padding type: ", padding_type, ", at ", node_def.name()); + return errors::Unimplemented("Unsupported padding type: ", padding_type, 
+ ", at ", node_def.name()); } if (params->validation_only) return Status::OK(); @@ -2365,10 +2670,10 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { const_cast(tensor), {0, 3, 1, 2}, &tensor)); } - const auto tf_stride = attrs.get>("strides"); + const auto tf_stride = attrs.get>("strides"); const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); - const auto tf_kernel = attrs.get>("ksize"); + const auto tf_kernel = attrs.get>("ksize"); const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); auto tensor_dim = tensor->getDimensions(); @@ -2420,21 +2725,61 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { } params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertActivation(OpConverterParams* params) { +// TODO(tmorris): Use ActivationType::kLEAKY_RELU in TRT 5.1+ once perf +// improves. +Status ConvertLeakyRelu(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 1) { - return tensorflow::errors::InvalidArgument( - node_def.op(), " expects one input, at ", node_def.name()); - } - if (!inputs.at(0).is_tensor()) { - return tensorflow::errors::Unimplemented( - node_def.op(), " is only implemented for tensors, at ", + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + + TFAttrs attrs(node_def); + const float alpha = attrs.get("alpha"); + if (alpha < 0.0f || alpha > 1.0f) { + return errors::Unimplemented( + "Alpha value for LeakyRelu must be between 0 and 1, at ", node_def.name()); } + if (params->validation_only) return Status::OK(); + + // Input Tensor + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + // Create const for alpha. 
+ const nvinfer1::ITensor* const_alpha_tensor = nullptr; + TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( + params, alpha, tensor->getDimensions(), &const_alpha_tensor)); + // alpha * x + nvinfer1::IElementWiseLayer* mul_layer = + params->converter->network()->addElementWise( + *const_cast(tensor), + *const_cast(const_alpha_tensor), + nvinfer1::ElementWiseOperation::kPROD); + TFTRT_RETURN_ERROR_IF_NULLPTR(mul_layer, node_def.name()); + // max(x, alpha * x) + nvinfer1::IElementWiseLayer* max_layer = + params->converter->network()->addElementWise( + *const_cast(tensor), + *const_cast(mul_layer->getOutput(0)), + nvinfer1::ElementWiseOperation::kMAX); + TFTRT_RETURN_ERROR_IF_NULLPTR(max_layer, node_def.name()); + nvinfer1::ITensor* output_tensor = max_layer->getOutput(0); + params->converter->MarkQuantizationRangesAsInferrable( + output_tensor, const_cast(mul_layer->getOutput(0))); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +} + +Status ConvertActivation(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); static const std::unordered_map ops{ {"Relu", nvinfer1::ActivationType::kRELU}, {"Sigmoid", nvinfer1::ActivationType::kSIGMOID}, @@ -2442,11 +2787,10 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { }; auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) { - return tensorflow::errors::Unimplemented( - "Activation op: ", node_def.op(), - " not supported at: ", node_def.name()); + return errors::Unimplemented("Activation op: ", node_def.op(), + " not supported at: ", node_def.name()); } - if (params->validation_only) return tensorflow::Status::OK(); + if (params->validation_only) return Status::OK(); // Start conversion. 
const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); @@ -2462,25 +2806,25 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } Status ConvertQuantize(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if ((inputs.size() == 0) || - (node_def.op() == "FakeQuantWithMinMaxArgs" && inputs.size() != 1) || - (node_def.op() == "FakeQuantWithMinMaxVars" && inputs.size() != 3) || - (node_def.op() == "QuantizeAndDequantizeV2" && inputs.size() != 3) || - (node_def.op() == "QuantizeAndDequantizeV3" && inputs.size() != 4)) { - return errors::InvalidArgument("Invalid number of inputs for ", - node_def.op(), ", at ", node_def.name()); - } - if (inputs.at(0).is_weights()) { - // TensorRT will automatically quantize weights, so we will ignore ranges - // for weights. 
- params->outputs->push_back(inputs.at(0)); - return Status::OK(); + if (node_def.op() == "FakeQuantWithMinMaxArgs") { + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + } else if (node_def.op() == "FakeQuantWithMinMaxVars") { + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"min", true}, {"max", true}})); + } else if (node_def.op() == "QuantizeAndDequantizeV2") { + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"input_min", true}, {"input_max", true}})); + } else if (node_def.op() == "QuantizeAndDequantizeV3") { + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}, + {"input_min", true}, + {"input_max", true}, + {"num_bits", true}})); } float min_range = 0.0f; float max_range = 0.0f; @@ -2497,11 +2841,6 @@ Status ConvertQuantize(OpConverterParams* params) { node_def.op() == "QuantizeAndDequantizeV2" || node_def.op() == "QuantizeAndDequantizeV3") { // Get ranges via inputs. - if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights()) { - return errors::InvalidArgument("Min and max inputs for ", node_def.op(), - " must be weights not tensors, at ", - node_def.name()); - } auto get_weights_value = [&inputs](int index) { auto raw_weights = static_cast( const_cast(inputs.at(index).weights().GetValues())); @@ -2532,20 +2871,13 @@ Status ConvertQuantize(OpConverterParams* params) { return Status::OK(); } -// TODO(pdavoodi): we should update relu6 implementation once TensorRT supports -// Relu6 natively. -tensorflow::Status ConvertRelu6(OpConverterParams* params) { +// TODO(tmorris): Use ActivationType::kCLIP in TRT 5.1+ once perf improves. 
+Status ConvertRelu6(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 1) { - return tensorflow::errors::InvalidArgument( - "Invalid number of inputs for Relu6, at ", node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "Relu6 is only implemented for tensors, not weights, at ", - node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); if (params->validation_only) return Status::OK(); // *************************************************************************** // TensorRT does not implement Relu6 natively. This function converts Relu6 op @@ -2569,24 +2901,10 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f, 6.0f); - // Create a constant layer to store the floating point weight i.e. 6.0f This - // tensor will be broadcasted uniformly during elementwise `min` operation. - // The constant has to have the same rank as the input in order for TRT to - // broadcast - nvinfer1::Dims dims; - dims.nbDims = relu_layer->getOutput(0)->getDimensions().nbDims; - for (int i = 0; i < dims.nbDims; i++) { - dims.d[i] = 1; - } - TRT_ShapedWeights weights = params->weight_store->GetTempWeights( - tensorflow::DataType::DT_FLOAT, dims); - auto weights_ptr = - static_cast(const_cast(weights.GetValues())); - weights_ptr[0] = 6.0f; - nvinfer1::ITensor* const6_tensor = - params->converter->CreateConstantLayer(weights, dims); - TFTRT_RETURN_ERROR_IF_NULLPTR(const6_tensor, node_def.name()); - params->converter->ProvideQuantizationRange(const6_tensor, 0.0f, 6.0f); + // Create a constant layer to store the floating point weight i.e. 
6.0f + const nvinfer1::ITensor* const6_tensor = nullptr; + TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( + params, 6.0f, relu_layer->getOutput(0)->getDimensions(), &const6_tensor)); // ElementWise Min Operation // Min op is a nop for INT8 execution path, as the input tensor @@ -2594,7 +2912,8 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { nvinfer1::IElementWiseLayer* relu6_layer = params->converter->network()->addElementWise( *const_cast(relu_layer->getOutput(0)), - *const6_tensor, nvinfer1::ElementWiseOperation::kMIN); + *const_cast(const6_tensor), + nvinfer1::ElementWiseOperation::kMIN); TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name()); nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); @@ -2603,26 +2922,19 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { return Status::OK(); } -tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { +Status ConvertBiasAdd(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return errors::InvalidArgument("Input expects tensor and weights, at ", - node_def.name()); - } - TFAttrs attrs(node_def); - tensorflow::DataType tf_dtype = attrs.get("T"); - if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { - return errors::Unimplemented("Data type is not supported, for node ", - node_def.name(), " got ", - DataTypeString(tf_dtype)); - } + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"value", false}, {"bias", true}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); if (params->validation_only) return Status::OK(); nvinfer1::ITensor* tensor = const_cast(inputs.at(0).tensor()); const nvinfer1::Dims original_dims = tensor->getDimensions(); + TFAttrs attrs(node_def); const string data_format = 
attrs.get("data_format"); const int channel_index = (data_format == "NHWC" ? original_dims.nbDims - 1 : 0); @@ -2668,7 +2980,7 @@ tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { } TRT_ShapedWeights weights = inputs.at(1).weights(); - if (params->converter->precision_mode() == FP16MODE) { + if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { weights = ConvertFP32ToFP16(params->weight_store, weights); } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; @@ -2712,43 +3024,69 @@ tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { return Status::OK(); } -Status GetTensorDimsWithProtoShape(const Tensor& tensor, - int tensor_proto_array_len, - nvinfer1::Dims* dims) { +void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) { if (tensor.dims() > 0) { *dims = GetTrtDimsForTensor(tensor); - if (TrtDimsNumElements(*dims) != tensor_proto_array_len && - tensor_proto_array_len != 1) { - return errors::InvalidArgument( - "Broadcast on weights only supports kCHANNEL and kUNIFORM"); - } } else { dims->nbDims = 1; // No dimension provided. Flatten it. - dims->d[0] = tensor_proto_array_len; + dims->d[0] = tensor.NumElements(); dims->type[0] = nvinfer1::DimensionType::kSPATIAL; for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; ++i) { dims->d[i] = 0; } } - return Status::OK(); } -template -Status TfTensorToTrtWeights(const DataType dtype, const Tensor& tensor, - const CType* tensor_proto_array, - int tensor_proto_array_len, TrtWeightStore* store, +Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, TRT_ShapedWeights* weights) { + const DataType dtype = tensor.dtype(); + + // We always convert the integer constants to INT32, since TRT INT8 is for + // quantized inference. + // + // TODO(aaroey): FP16 will remain in half format and is not converted to + // FP32, but the converter currently uses all float weights as FP32. Fix + // this. 
+ const DataType converted_dtype = + (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32 + : dtype); + + // Verify that the dtype is supported by TensorRT. Otherwise, return an error. + nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); + + if (tensor.NumElements() == 0) { + // Return empty weights having converted dtype. + *weights = TRT_ShapedWeights(converted_dtype); + return Status::OK(); + } + nvinfer1::Dims weight_dims; - TF_RETURN_IF_ERROR(GetTensorDimsWithProtoShape(tensor, tensor_proto_array_len, - &weight_dims)); - *weights = store->GetTempWeights(dtype, weight_dims); - void* dst = const_cast(weights->GetValues()); - if (tensor_proto_array_len == 1) { - std::fill_n((CType*)dst, TrtDimsNumElements(weight_dims), - *tensor_proto_array); + GetTensorDimsWithProtoShape(tensor, &weight_dims); + *weights = weight_store->GetTempWeights(converted_dtype, weight_dims); + + // Copy the tensor directly if the tensor does not require cast to the + // supported type. + if (converted_dtype == dtype) { + char* dst = static_cast(const_cast(weights->GetValues())); + memcpy(dst, tensor.tensor_data().data(), tensor.TotalBytes()); + return Status::OK(); + } + + // Copy tensor elements after casting them to the converted DataType. + int32* dst = static_cast(const_cast(weights->GetValues())); + if (dtype == DT_INT16) { + const int16* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); + } else if (dtype == DT_INT8) { + const int8* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); } else { - memcpy(dst, tensor_proto_array, weights->size_bytes()); + // dtype can only be DT_UINT8 at this point. 
+ TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8); + const uint8* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); } return Status::OK(); } @@ -2758,7 +3096,7 @@ Status TfTensorToTrtWeights(const DataType dtype, const Tensor& tensor, // weights to params->outputs. We did this since TrtNodeValidator needs the // weights as input to other nodes, and use it to determine whether those nodes // are supported by TRT. -tensorflow::Status ConvertConst(OpConverterParams* params) { +Status ConvertConst(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (!inputs.empty()) { @@ -2766,117 +3104,55 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { "Constant node is expected to have empty input list: ", node_def.name()); } - TFAttrs attrs(node_def); - const DataType dtype = attrs.get("dtype"); - // We always convert the integer constants to kINT32, since TRT kINT8 is for - // quantized inference. - const DataType converted_dtype = - (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32 - : dtype); - nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); // Create shaped weights as output const auto& tensor_proto = node_def.attr().at("value").tensor(); - tensorflow::Tensor tensor; + Tensor tensor; if (!tensor.FromProto(tensor_proto)) { - return tensorflow::errors::Internal("Cannot parse weight tensor proto: ", - node_def.name()); + return errors::Internal("Cannot parse weight tensor proto: ", + node_def.name()); } - TRT_ShapedWeights weights(converted_dtype); - if (tensor.NumElements() == 0) { - // Do nothing. 
- } else if (!tensor_proto.float_val().empty()) { - TF_RETURN_IF_ERROR(TfTensorToTrtWeights( - converted_dtype, tensor, tensor_proto.float_val().begin(), - tensor_proto.float_val_size(), params->weight_store, &weights)); - } else if (!tensor_proto.int_val().empty()) { - TF_RETURN_IF_ERROR(TfTensorToTrtWeights( - converted_dtype, tensor, tensor_proto.int_val().begin(), - tensor_proto.int_val_size(), params->weight_store, &weights)); - } else if (!tensor_proto.half_val().empty()) { - // TODO(aaroey): implement fp16 conversion. - return errors::Unimplemented("fp16 constant is not supported yet."); - } else if (!tensor_proto.tensor_content().empty()) { - // TODO(aaroey): fp16 will remain in half format and is not converted to - // fp32, but the converter currently uses all float weights as fp32. Fix - // this. - const auto& content = tensor_proto.tensor_content(); - if (content.size() > 0) { - const int dtype_size = tensorflow::DataTypeSize(dtype); - if (content.size() % dtype_size != 0) { - return errors::FailedPrecondition("Tensor content size ", - content.size(), - " is not a multiple of ", dtype_size); - } - nvinfer1::Dims weights_dim; - TF_RETURN_IF_ERROR(GetTensorDimsWithProtoShape( - tensor, content.size() / dtype_size, &weights_dim)); - const int64_t size_bytes = TrtDimsNumElements(weights_dim) * dtype_size; - if (content.size() != size_bytes) { - return errors::FailedPrecondition( - "Tensor size and TensorProto content size mismatch: ", size_bytes, - " vs ", content.size()); - } else if (tensor.NumElements() != content.size() / dtype_size) { - return errors::FailedPrecondition( - "Tensor elements count and TensorProto content size mismatch: ", - tensor.NumElements(), " vs ", content.size() / dtype_size); - } - weights = - params->weight_store->GetTempWeights(converted_dtype, weights_dim); - if (dtype_size == tensorflow::DataTypeSize(converted_dtype)) { - port::CopyToArray(content, static_cast( - const_cast(weights.GetValues()))); - } else { - // Copy out the 
weights as original data type. - std::vector temp_weights(content.size()); - port::CopyToArray(content, - reinterpret_cast(temp_weights.data())); - int32* dst = - static_cast(const_cast(weights.GetValues())); - // Copy to the weight store as converted data type. - if (dtype == DT_INT16) { - int16* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else if (dtype == DT_INT8) { - int8* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else if (dtype == DT_UINT8) { - uint8* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else { - return errors::FailedPrecondition( - "Unexpected data type: ", DataTypeString(dtype), - " at: ", node_def.name()); - } - } - } - } else { - return errors::Unimplemented("Not supported constant type, at ", - node_def.name()); + TFAttrs attrs(node_def); + const DataType dtype = attrs.get("dtype"); + if (dtype != tensor.dtype()) { + return errors::InvalidArgument("DataType mismatch between attr (", + DataTypeString(dtype), ") and tensor (", + DataTypeString(tensor.dtype()), ")"); } + + TRT_ShapedWeights weights; + TF_RETURN_IF_ERROR( + TfTensorToTrtWeights(tensor, params->weight_store, &weights)); + if (params->outputs != nullptr) { params->outputs->push_back(TRT_TensorOrWeights(weights)); } return Status::OK(); } -tensorflow::Status ConvertIdentity(OpConverterParams* params) { +Status ConvertIdentity(OpConverterParams* params) { // TODO(tmorris): TRT's Identity layer does not get optimized away as of TRT // 5.0, however once we know that it does it would be nice to use that // instead. 
+ if (params->validation_only) return Status::OK(); params->outputs->push_back(params->inputs.at(0)); - return tensorflow::Status::OK(); + return Status::OK(); } Status ConvertBinary(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + // TODO(tmorris): Enable once false is updated to mean either tensor or weight + // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", + // false}})); if (inputs.size() != 2) { - return errors::InvalidArgument("Binary ops require two inputs, at ", + return errors::InvalidArgument(node_def.op(), " got ", inputs.size(), + " inputs but expected 2, at ", node_def.name()); } + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); // Constant folding should have been done by TensorFlow if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) { @@ -2908,135 +3184,165 @@ Status ConvertBinary(OpConverterParams* params) { // If both input are tensors, or one of them is weights but the conversion // above failed, try the conversion using BinaryTensorOpTensor. 
if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { - if (!status.ok()) VLOG(1) << status; + if (!status.ok()) VLOG(2) << status; status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1)); } return status; } -tensorflow::Status ConvertUnary(OpConverterParams* params) { +Status ConvertRsqrt(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - static const std::unordered_map ops{ - {"Neg", nvinfer1::UnaryOperation::kNEG}, - {"Exp", nvinfer1::UnaryOperation::kEXP}, - {"Log", nvinfer1::UnaryOperation::kLOG}, - {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, - {"Abs", nvinfer1::UnaryOperation::kABS}, - {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, - }; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + if (params->validation_only) return Status::OK(); - if (inputs.size() != 1) { - return tensorflow::errors::FailedPrecondition( - "Unary ops require single tensor input, at ", node_def.name()); + // TODO(tmorris): params->converter is null during validation. Allow + // precision_mode and use_calibration to be accessed during validation and + // include this check in validation. + // We will need a quantization range for intermediate tensor if not using + // calibration. + // + // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) + // ^ + // need range here + if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && + !params->converter->use_calibration()) { + return errors::Unimplemented( + "Intermediate quantization range cannot be determined without" + " calibration for Rsqrt, consider replacing with " + "Sqrt -> FakeQuant -> Reciprocal ops, at ", + node_def.name()); } + // Start conversion. 
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + // Sqrt + nvinfer1::IUnaryLayer* sqrt_layer = params->converter->network()->addUnary( + *const_cast(tensor), nvinfer1::UnaryOperation::kSQRT); + TFTRT_RETURN_ERROR_IF_NULLPTR(sqrt_layer, node_def.name()); + // Recip + nvinfer1::IUnaryLayer* recip_layer = params->converter->network()->addUnary( + *sqrt_layer->getOutput(0), nvinfer1::UnaryOperation::kRECIP); + TFTRT_RETURN_ERROR_IF_NULLPTR(recip_layer, node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(recip_layer->getOutput(0))); + return Status::OK(); +} - // TODO(jie): check type - const nvinfer1::ITensor* tensor = nullptr; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), inputs.at(0).GetTrtDims(), &tensor)); +const std::unordered_map* +UnaryOperationMap() { + static auto* const m = + new std::unordered_map({ + {"Neg", nvinfer1::UnaryOperation::kNEG}, + {"Exp", nvinfer1::UnaryOperation::kEXP}, + {"Log", nvinfer1::UnaryOperation::kLOG}, + {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"Abs", nvinfer1::UnaryOperation::kABS}, + {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, +#if IS_TRT_VERSION_GE(5, 1, 0) + {"Sin", nvinfer1::UnaryOperation::kSIN}, + {"Cos", nvinfer1::UnaryOperation::kCOS}, + {"Tan", nvinfer1::UnaryOperation::kTAN}, + {"Sinh", nvinfer1::UnaryOperation::kSINH}, + {"Cosh", nvinfer1::UnaryOperation::kCOSH}, + {"Asin", nvinfer1::UnaryOperation::kASIN}, + {"Acos", nvinfer1::UnaryOperation::kACOS}, + {"Atan", nvinfer1::UnaryOperation::kATAN}, + {"Asinh", nvinfer1::UnaryOperation::kASINH}, + {"Acosh", nvinfer1::UnaryOperation::kACOSH}, + {"Atanh", nvinfer1::UnaryOperation::kATANH}, + {"Ceil", nvinfer1::UnaryOperation::kCEIL}, + {"Floor", nvinfer1::UnaryOperation::kFLOOR}, +#endif + }); + return m; +} - nvinfer1::IUnaryLayer* layer; - if (node_def.op() == "Rsqrt") { - // We will need a quantization range for intermediate tensor if not using - // calibration. 
- // - // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) - // ^ - // need range here - if (params->converter->precision_mode() == INT8MODE && - !params->converter->use_calibration()) { - return errors::Unimplemented( - "Intermediate quantization range cannot be determined without" - " calibration for Rsqrt, consider replacing with " - "Sqrt -> FakeQuant -> Reciprocal ops, at ", - node_def.name()); - } - layer = params->converter->network()->addUnary( - *const_cast(tensor), - nvinfer1::UnaryOperation::kSQRT); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - tensor = layer->getOutput(0); - layer = params->converter->network()->addUnary( - *const_cast(tensor), - nvinfer1::UnaryOperation::kRECIP); - } else if (ops.count(node_def.op()) != 0) { - layer = params->converter->network()->addUnary( - *const_cast(tensor), ops.at(node_def.op())); - } else { - return tensorflow::errors::InvalidArgument( - "Binary op: ", node_def.op(), " not supported, at ", node_def.name()); +Status ConvertUnary(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + auto op_pair = UnaryOperationMap()->find(node_def.op()); + if (op_pair == UnaryOperationMap()->end()) { + return errors::Unimplemented("Unary op: ", node_def.op(), + " not supported at: ", node_def.name()); } + if (params->validation_only) return Status::OK(); + // Start conversion. + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( + *const_cast(tensor), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + // Set quantization ranges. 
+ if (node_def.op() == "Sin" || node_def.op() == "Cos") { + params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f); + } else if (node_def.op() == "Asin" || node_def.op() == "Atan") { + params->converter->ProvideQuantizationRange(output_tensor, -M_PI_2, M_PI_2); + } else if (node_def.op() == "Acos") { + params->converter->ProvideQuantizationRange(output_tensor, 0.0f, M_PI); + } else if (node_def.op() == "Neg" || node_def.op() == "Abs") { + // Neg and Abs will have same range as input since TRT uses symmetric + // quantization. + // TODO(tmorris): Should we infer ranges for Ceil and Floor as well? + params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), output_tensor); + } params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertSquare(OpConverterParams* params) { +Status ConvertSquare(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 1) { - return tensorflow::errors::InvalidArgument("Square expects one input, at ", - node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "Square is only implemented for tensors, at ", node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); if (params->validation_only) return Status::OK(); // Constant 2 with same rank as input - nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); - for (int i = 0; i < dims.nbDims; i++) { - dims.d[i] = 1; - } - TRT_ShapedWeights weights = params->weight_store->GetTempWeights( - tensorflow::DataType::DT_FLOAT, dims); - auto weights_ptr = - static_cast(const_cast(weights.GetValues())); - weights_ptr[0] = 2.f; - nvinfer1::ITensor* const2_tensor = - params->converter->CreateConstantLayer(weights, dims); - 
TFTRT_RETURN_ERROR_IF_NULLPTR(const2_tensor, node_def.name()); + const nvinfer1::ITensor* const2_tensor = nullptr; + TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( + params, 2.0f, inputs.at(0).GetTrtDims(), &const2_tensor)); // ElementWise Pow Operation nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( *const_cast(inputs.at(0).tensor()), - *const2_tensor, nvinfer1::ElementWiseOperation::kPOW); + *const_cast(const2_tensor), + nvinfer1::ElementWiseOperation::kPOW); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertReduce(OpConverterParams* params) { +Status ConvertReduce(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at", node_def.name()); - } + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"axis", true}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TRT_ShapedWeights index_list = inputs.at(1).weights(); TFAttrs attrs(node_def); - auto index_type = attrs.get("Tidx"); - // Only expect to handle INT32 as attributes for now - if (index_type != tensorflow::DataType::DT_INT32) { - return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); + if (attrs.get("Tidx") != DataType::DT_INT32) { + return errors::Unimplemented("Tidx supports only DT_INT32"); } int axes = 0; if (index_list.count() == 0) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "TRT cannot support reduce on all (batch) dimensions, at", node_def.name()); } else { 
@@ -3046,7 +3352,7 @@ tensorflow::Status ConvertReduce(OpConverterParams* params) { int axis = index_list_data[i]; if (axis < 0) axis += tensor->getDimensions().nbDims + 1; if (axis == 0) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "TRT cannot reduce at batch dimension, at", node_def.name()); } axes |= (1 << (axis - 1)); @@ -3065,9 +3371,10 @@ tensorflow::Status ConvertReduce(OpConverterParams* params) { } else if (node_def.op() == "Mean") { reduce_operation = nvinfer1::ReduceOperation::kAVG; } else { - return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(), - " , at ", node_def.name()); + return errors::Unimplemented("Op not supported ", node_def.op(), ", at ", + node_def.name()); } + if (params->validation_only) return Status::OK(); const auto keep_dims = attrs.get("keep_dims"); nvinfer1::ILayer* layer = params->converter->network()->addReduce( @@ -3076,18 +3383,16 @@ tensorflow::Status ConvertReduce(OpConverterParams* params) { TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertPad(OpConverterParams* params) { +Status ConvertPad(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - // TODO(aaroey): make a routine for this check and reuse it. 
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at", node_def.name()); - } + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"tensor", false}, {"paddings", true}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); // Implement tensor binaryOp weight [channel wise] for now; const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); @@ -3100,19 +3405,18 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { TFAttrs attrs(node_def); // Padding type here is done through TF type // so I can leverage their EnumToDataType for my cast - auto padding_type = attrs.get("Tpaddings"); + auto padding_type = attrs.get("Tpaddings"); // TODO(jie): handle data type conversion for TRT? if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Pad only supports explicit padding on 4 dimensional tensor, at ", node_def.name()); } // Only expect to handle INT32 as attributes for now - if (padding_type != tensorflow::DataType::DT_INT32) { - return tensorflow::errors::Unimplemented( - "Tpaddings supports only DT_INT32"); + if (padding_type != DataType::DT_INT32) { + return errors::Unimplemented("Tpaddings supports only DT_INT32"); } auto pad_data = static_cast(const_cast(pads.GetValues())); @@ -3124,27 +3428,27 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { } // No padding at all, we should exit - if (pad_index.size() == 0) { + if (pad_index.empty()) { params->outputs->push_back(inputs.at(0)); - return tensorflow::Status::OK(); + return Status::OK(); } // Only supports padding on less than 2 axis GIE-2579 if (pad_index.size() > 2) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Padding layer does not support padding on > 2"); } // Padding on batch dimension is not supported if 
(pad_index[0] == 0) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Padding layer does not support padding on batch dimension"); } // Not doing the legit thing here. ignoring padding on dim 1 and 3; // TODO(jie): implement pad as uff parser if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Padding layer does not support padding on dimension 1 and 3 yet"); } if (params->validation_only) return Status::OK(); @@ -3185,17 +3489,21 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertConcat(OpConverterParams* params) { +Status ConvertConcat(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + // TODO(tmorris): There is a bug with Concat and INT32 in TRT - it is supposed + // to be supported. 
+ TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); // not including the last input (axis) here int input_size = static_cast(inputs.size()) - 1; if (!inputs.at(0).is_tensor()) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Concat in TRT support only Tensor input, at ", node_def.name()); } @@ -3203,13 +3511,13 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { TRT_ShapedWeights axis = inputs.at(input_size).weights(); TFAttrs attrs(node_def); - auto index_type = attrs.get("Tidx"); + auto index_type = attrs.get("Tidx"); // TODO(jie): handle data type // Only expect to handle INT32 as index attributes for now - if (index_type != tensorflow::DataType::DT_INT32) - return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32, at ", - node_def.name()); + if (index_type != DataType::DT_INT32) + return errors::Unimplemented("Tidx supports only DT_INT32, at ", + node_def.name()); int index = *(static_cast(const_cast(axis.GetValues()))); @@ -3218,11 +3526,11 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { auto dim = inputs.at(0).tensor()->getDimensions(); // dimension check if (index > dim.nbDims + 1) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Concatenate on axis out of dimension range, at ", node_def.name()); } if (index == 0) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Concatenate on batch dimension not supported, at ", node_def.name()); } if (index < 0) { @@ -3236,14 +3544,14 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { auto tensor_i = inputs.at(i).tensor(); auto dim_i = tensor_i->getDimensions(); if (dim_i.nbDims != dim.nbDims) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Concatenate receives inputs with inconsistent dimensions, at ", node_def.name()); } for (int j = 0; j < dim.nbDims; j++) { // check dimension consistency on 
non-concatenate axis if (j != index - 1 && dim_i.d[j] != dim.d[j]) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Concatenate receives inputs with inconsistent shape, at", node_def.name()); } @@ -3251,7 +3559,7 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { inputs_vec.push_back(tensor_i); } - if (params->validation_only) return tensorflow::Status::OK(); + if (params->validation_only) return Status::OK(); // nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); nvinfer1::IConcatenationLayer* layer = @@ -3262,17 +3570,24 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { layer->setAxis(index - 1); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { +Status ConvertFusedBatchNorm(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, + {"scale", true}, + {"offset", true}, + {"mean", true}, + {"variance", true}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); TFAttrs attrs(node_def); float epsilon = attrs.get("epsilon"); auto data_format = attrs.get("data_format"); if (data_format != "NCHW") { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( node_def.op(), " only supports data_format=NCHW, at ", node_def.name()); } bool is_training = attrs.get("is_training"); @@ -3284,38 +3599,23 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { << "are using Keras, please call " << "keras.backend.set_learning_phase(0) before constructing " << "your model. 
At " << node_def.name(); - return tensorflow::errors::Unimplemented( - node_def.op(), " only supports is_training=false, at ", - node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - node_def.op(), - " is only implemented for tensor inputs, not weights, at ", - node_def.name()); - } - for (int i = 1; i < 5; i++) { - if (inputs.at(i).is_tensor()) { - return tensorflow::errors::Unimplemented( - node_def.op(), - " must have constant inputs for scale, offset, mean and variance, " - "at ", - node_def.name()); - } + return errors::Unimplemented(node_def.op(), + " only supports is_training=false, at ", + node_def.name()); } nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); // Check parameter types auto parameter_type = inputs.at(1).weights().type_; - if ((parameter_type != tensorflow::DataType::DT_FLOAT) && - (parameter_type != tensorflow::DataType::DT_HALF)) { - return tensorflow::errors::Unimplemented( + if ((parameter_type != DataType::DT_FLOAT) && + (parameter_type != DataType::DT_HALF)) { + return errors::Unimplemented( "only float32 or float16 weight data type is supported, for node " + - node_def.name() + " got " + tensorflow::DataTypeString(parameter_type)); + node_def.name() + " got " + DataTypeString(parameter_type)); } for (int i = 1; i < 5; i++) { if (inputs.at(i).weights().type_ != parameter_type) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Inconsistent parameter type for batchnorm is not supported, at: " + node_def.name()); } @@ -3324,7 +3624,7 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { TRT_ShapedWeights dummy_power_weights(parameter_type); size_t nweight = 0; for (int i = 1; i < 5; i++) { - nweight = std::max(nweight, (size_t)inputs.at(i).weights().count()); + nweight = std::max(nweight, inputs.at(i).weights().count()); } TRT_ShapedWeights* ptr_shape_weights = nullptr; for (int i = 1; i < 5; i++) { @@ -3332,7 +3632,7 @@ tensorflow::Status 
ConvertFusedBatchNorm(OpConverterParams* params) { ptr_shape_weights = const_cast(&(inputs.at(i).weights())); } else if (inputs.at(i).weights().count() != 1) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Inconsistent batchnorm parameter count, at: " + node_def.name()); } } @@ -3366,16 +3666,16 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { float batchnorm_data[4]; for (int j = 0; j < 4; j++) { if (inputs.at(j + 1).weights().count() != 1) { - if (parameter_type == tensorflow::DT_FLOAT) { + if (parameter_type == DT_FLOAT) { batchnorm_data[j] = vals_array[j][i]; - } else if (parameter_type == tensorflow::DT_HALF) { + } else if (parameter_type == DT_HALF) { batchnorm_data[j] = Eigen::half_impl::half_to_float(cast_vals_array[j][i]); } } else { - if (parameter_type == tensorflow::DT_FLOAT) { + if (parameter_type == DT_FLOAT) { batchnorm_data[j] = vals_array[j][0]; - } else if (parameter_type == tensorflow::DT_HALF) { + } else if (parameter_type == DT_HALF) { batchnorm_data[j] = Eigen::half_impl::half_to_float(cast_vals_array[j][0]); } @@ -3387,10 +3687,10 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { float variance = batchnorm_data[3]; float combined_scale_val = scale / sqrtf(variance + epsilon); float combined_offset_val = offset - mean * combined_scale_val; - if (parameter_type == tensorflow::DT_FLOAT) { + if (parameter_type == DT_FLOAT) { combined_scale_vals[i] = combined_scale_val; combined_offset_vals[i] = combined_offset_val; - } else if (parameter_type == tensorflow::DT_HALF) { + } else if (parameter_type == DT_HALF) { cast_combined_scale_vals[i] = Eigen::half(combined_scale_val); cast_combined_offset_vals[i] = Eigen::half(combined_offset_val); } @@ -3406,17 +3706,88 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); 
params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertMatMulHelper(OpConverterParams* params, - TRT_TensorOrWeights tensor_input, - TRT_ShapedWeights weights_raw, - bool transpose_weight, - string node_name) { +Status ConvertGather(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"params", false}, {"indices", false}, {"axis", true}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}, + /*dtype_attr_name=*/"Tparams")); + absl::Span axis = inputs.at(2).weights().GetSpan(); + if (axis.size() != 1) { + return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ", + node_def.name()); + } + int trt_axis = 0; + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, + node_def.name(), &trt_axis)); + TRT_TensorOrWeights params_tensor = inputs.at(0); + TRT_TensorOrWeights indices_tensor = inputs.at(1); + if (indices_tensor.batch_size() != 1) { + return errors::InvalidArgument("Only indices with batch 1 are supported."); + } + // Both input are tensors, and the TF gather result will have rank: + // (params.nbDims + 1) + (indices.nbDims + 1) - 1, + // where "+ 1" adds the batch dim. 
+ const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims + + indices_tensor.GetTrtDims().nbDims + 1; + if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) { + return errors::InvalidArgument( + "Result of gather has dimension greater than ", + nvinfer1::Dims::MAX_DIMS + 1); + } + if (params->validation_only) return Status::OK(); + + // Note on how IGatherLayer works: if both the data and indices tensors have + // a batch size dimension of size N, it performs: + // for batchid in xrange(N): + // output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = ( + // data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn]) + nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( + *const_cast(params_tensor.tensor()), + *const_cast(indices_tensor.tensor()), trt_axis); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + + nvinfer1::ITensor* gather_output = layer->getOutput(0); + nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions(); + // Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT, + // and the other is for the output dimension that is squeezed by IGatherLayer + // because of the implicit batch dim in the indices (see the above note). + if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) { + return errors::Internal( + "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ", + tf_gather_output_rank - 2, + ", actual nbDims: ", trt_gather_output_dims.nbDims); + } + // Reshape the output so after adding the implicit batch dim it'll match the + // output shape of TF GatherV2. 
+ for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) { + trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1]; + } + trt_gather_output_dims.d[trt_axis] = 1; + ++trt_gather_output_dims.nbDims; + + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(gather_output), trt_gather_output_dims, + /*validation_only=*/false, &output_tensor)); + + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return Status::OK(); +} + +Status ConvertMatMulHelper(OpConverterParams* params, + TRT_TensorOrWeights tensor_input, + TRT_ShapedWeights weights_raw, bool transpose_weight, + string node_name) { nvinfer1::ITensor* output_tensor; if (!tensor_input.is_tensor()) { - return tensorflow::errors::InvalidArgument("Input 0 expects tensor"); + return errors::InvalidArgument("Input 0 expects tensor"); } const nvinfer1::ITensor* tensor = tensor_input.tensor(); @@ -3436,7 +3807,7 @@ tensorflow::Status ConvertMatMulHelper(OpConverterParams* params, input_dim.d[input_dim.nbDims++] = 1; } TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - tensor_input, input_dim, &tensor)); + tensor_input, input_dim, /*validation_only=*/false, &tensor)); nvinfer1::IFullyConnectedLayer* layer = params->converter->network()->addFullyConnected( @@ -3449,29 +3820,22 @@ tensorflow::Status ConvertMatMulHelper(OpConverterParams* params, auto output_dim = output_tensor->getDimensions(); output_dim.nbDims = 1; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(output_tensor), output_dim, &temp_tensor)); + TRT_TensorOrWeights(output_tensor), output_dim, /*validation_only=*/false, + &temp_tensor)); output_tensor = const_cast(temp_tensor); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } -// inputs are both two dimensional (tensorflow::ops::MatMul) -tensorflow::Status 
ConvertMatMul(OpConverterParams* params) { +// inputs are both two dimensional (ops::MatMul) +Status ConvertMatMul(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return errors::InvalidArgument("Input expects tensor and weights, at ", - node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"a", false}, {"b", true}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); TFAttrs attrs(node_def); - tensorflow::DataType tf_dtype = attrs.get("T"); - if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { - return errors::Unimplemented("Data type is not supported, for node ", - node_def.name(), " got ", - DataTypeString(tf_dtype)); - } bool transpose_a = attrs.get("transpose_a"); bool transpose_b = attrs.get("transpose_b"); @@ -3486,66 +3850,64 @@ tensorflow::Status ConvertMatMul(OpConverterParams* params) { transpose_b, node_def.name()); } -tensorflow::Status ConvertBatchMatMul(OpConverterParams* params) { +Status ConvertBatchMatMul(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - TFAttrs attrs(node_def); - - tensorflow::DataType tf_dtype = attrs.get("T"); - if (tf_dtype != tensorflow::DataType::DT_FLOAT && - tf_dtype != tensorflow::DataType::DT_HALF) { - return tensorflow::errors::Unimplemented( - "data type is not supported, for node " + node_def.name() + " got " + - tensorflow::DataTypeString(tf_dtype)); + // TODO(tmorris): Enable once false is updated to mean either tensor or weight + // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", + // false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + if (inputs.size() != 2) { + return errors::InvalidArgument(node_def.op(), " got ", inputs.size(), + " inputs but expected 2, at 
", + node_def.name()); } - - bool transpose_a = attrs.get("adj_x"); - bool transpose_b = attrs.get("adj_y"); - - auto dims = inputs.at(0).GetTrtDims(); + if (inputs[0].is_weights() && inputs[1].is_weights()) { + return errors::InvalidArgument( + "All inputs are weights, but Grappler is expected to fold them."); + } + TFAttrs attrs(node_def); + const bool transpose_a = attrs.get("adj_x"); + const bool transpose_b = attrs.get("adj_y"); + const auto dims = inputs.at(0).GetTrtDims(); if (dims.nbDims == 1) { // NC * CK is only supported through fully connected if (transpose_a == false && inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1).weights(), transpose_b, node_def.name()); } else { - return tensorflow::errors::InvalidArgument( - "Invalid configuration for MatMul, at: " + node_def.name()); + return errors::InvalidArgument("Invalid configuration for MatMul, at: ", + node_def.name()); } } - const nvinfer1::ITensor* tensor_l; - const nvinfer1::ITensor* tensor_r; - auto dims_l = inputs.at(0).GetTrtDims(); - auto dims_r = inputs.at(1).GetTrtDims(); - if (inputs.at(0).is_weights()) { - if (inputs.at(0).GetTrtDims().d[0] != 1) { - return tensorflow::errors::InvalidArgument( - "Input 0 as weight assumes broadcast across batch for MatMul, at: " + - node_def.name()); - } else { - for (int i = 0; i < dims_l.nbDims - 1; i++) { - dims_l.d[i] = dims_l.d[i + 1]; + auto get_tensor_with_proper_dims = [params]( + const TRT_TensorOrWeights& input, + const nvinfer1::ITensor** tensor) { + auto dims = input.GetTrtDims(); + if (input.is_weights()) { + // The other operand must be a tensor, this is ensured by earlier checks. + // Checks that the batch dimension is not changed by broadcasting. 
+ if (dims.d[0] != 1) { + return errors::InvalidArgument( + "Input weight attempts to broadcast across batch dimension for " + "BatchMatMul, at ", + params->node_def.name()); } - dims_l.nbDims--; + // Remove the batch dimension from the weights. + TF_RETURN_IF_ERROR(RemoveBatchDimension(&dims)); } - } - if (inputs.at(1).is_weights()) { - if (inputs.at(1).GetTrtDims().d[0] != 1) { - return tensorflow::errors::InvalidArgument( - "Input 1 as weight assumes broadcast across batch for MatMul, at: " + - node_def.name()); - } else { - for (int i = 0; i < dims_r.nbDims - 1; i++) { - dims_r.d[i] = dims_r.d[i + 1]; - } - dims_r.nbDims--; - } - } - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), dims_l, &tensor_l)); - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), dims_r, &tensor_r)); + // Create tensor and reshape if necessary. + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input, dims, params->validation_only, tensor)); + return Status::OK(); + }; + const nvinfer1::ITensor* tensor_l; + const nvinfer1::ITensor* tensor_r; + TF_RETURN_IF_ERROR(get_tensor_with_proper_dims(inputs.at(0), &tensor_l)); + TF_RETURN_IF_ERROR(get_tensor_with_proper_dims(inputs.at(1), &tensor_r)); + if (params->validation_only) return Status::OK(); nvinfer1::IMatrixMultiplyLayer* layer = params->converter->network()->addMatrixMultiply( @@ -3554,20 +3916,25 @@ tensorflow::Status ConvertBatchMatMul(OpConverterParams* params) { TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertSoftmax(OpConverterParams* params) { +Status ConvertSoftmax(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}})); + 
TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int nbDims = tensor->getDimensions().nbDims; if (nbDims == 0) { - return tensorflow::errors::InvalidArgument( - "TensorRT Softmax cannot apply on batch dimension, at" + + return errors::InvalidArgument( + "TensorRT Softmax cannot apply on batch dimension, at", node_def.name()); } + if (params->validation_only) return Status::OK(); + nvinfer1::ISoftMaxLayer* layer = params->converter->network()->addSoftMax( *const_cast(tensor)); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); @@ -3578,65 +3945,80 @@ tensorflow::Status ConvertSoftmax(OpConverterParams* params) { // Quantization range for SoftMax is always (0, 1) params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertTopK(OpConverterParams* params) { +Status ConvertTopK(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"k", true}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - - int nbDims = tensor->getDimensions().nbDims; - if (nbDims == 0) { - return tensorflow::errors::InvalidArgument( - "TensorRT TopK cannot apply on batch dimension, at" + node_def.name()); + const int num_dims = tensor->getDimensions().nbDims; + if (num_dims == 0) { + return errors::InvalidArgument( + "TensorRT TopK cannot apply on batch dimension, at", node_def.name()); } TRT_ShapedWeights k_w = inputs.at(1).weights(); - int k = *(static_cast(const_cast(k_w.GetValues()))); - - nvinfer1::TopKOperation op; - uint32_t reducedAxes = 0; - if (node_def.op() == "TopKV2") { - op = 
nvinfer1::TopKOperation::kMAX; - reducedAxes |= 1 << (nbDims - 1); - } else { - return tensorflow::errors::Unimplemented( - "Operation: " + node_def.op() + - " not implemented, at: " + node_def.name()); + if (k_w.count() != 1) { + return errors::InvalidArgument("k value of TopK should be a scalar, at", + node_def.name()); } + // Note that ITopKLayer always have sorted outputs, so we don't need to handle + // the 'sorted' attribute of the node. + if (params->validation_only) return Status::OK(); + const nvinfer1::TopKOperation op = nvinfer1::TopKOperation::kMAX; + const int k = *(static_cast(const_cast(k_w.GetValues()))); + const uint32_t reduce_axes = 1 << (num_dims - 1); nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK( - *const_cast(tensor), op, k, reducedAxes); + *const_cast(tensor), op, k, reduce_axes); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); - // Tensor type for network output is not inferred. Indices should be INT32 - // (default is float). - output_indices_tensor->setType(nvinfer1::DataType::kINT32); params->outputs->push_back(TRT_TensorOrWeights(output_value_tensor)); params->outputs->push_back(TRT_TensorOrWeights(output_indices_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } static void RegisterValidatableOpConverters( std::unordered_map* registration) { - // TODO(laigd): support all op types. 
(*registration)["BiasAdd"] = ConvertBiasAdd; (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; (*registration)["Conv2D"] = ConvertConv2D; + (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput; (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; (*registration)["ExpandDims"] = ConvertExpandDims; + (*registration)["GatherV2"] = ConvertGather; + (*registration)["LeakyRelu"] = ConvertLeakyRelu; (*registration)["MatMul"] = ConvertMatMul; (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; (*registration)["Reshape"] = ConvertReshape; + (*registration)["Rsqrt"] = ConvertRsqrt; + (*registration)["Slice"] = ConvertSlice; (*registration)["Square"] = ConvertSquare; (*registration)["Squeeze"] = ConvertSqueeze; (*registration)["StridedSlice"] = ConvertStridedSlice; (*registration)["Transpose"] = ConvertTranspose; + (*registration)["TopKV2"] = ConvertTopK; + + // TODO(ben,jie): this is a temp hack. + (*registration)["Identity"] = ConvertIdentity; // Identity should be removed + (*registration)["Snapshot"] = ConvertIdentity; // Snapshot should be removed + + (*registration)["Sum"] = ConvertReduce; + (*registration)["Prod"] = ConvertReduce; + (*registration)["Max"] = ConvertReduce; + (*registration)["Min"] = ConvertReduce; + (*registration)["Mean"] = ConvertReduce; + (*registration)["Softmax"] = ConvertSoftmax; + (*registration)["BatchMatMul"] = ConvertBatchMatMul; for (auto quantization_op_type : {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", @@ -3644,7 +4026,7 @@ static void RegisterValidatableOpConverters( (*registration)[quantization_op_type] = ConvertQuantize; } for (auto binary_op_type : - {"Add", "Mul", "Sub", "Div", "RealDiv", "Maximum", "Minimum"}) { + {"Add", "Mul", "Sub", "Div", "RealDiv", "Maximum", "Minimum", "Pow"}) { (*registration)[binary_op_type] = ConvertBinary; } for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) { @@ -3656,6 +4038,9 @@ static void 
RegisterValidatableOpConverters( for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) { (*registration)[normalization_op_type] = ConvertFusedBatchNorm; } + for (auto unary_op_pair : *UnaryOperationMap()) { + (*registration)[unary_op_pair.first] = ConvertUnary; + } } void TrtNodeValidator::RegisterOpValidators() { @@ -3664,36 +4049,14 @@ void TrtNodeValidator::RegisterOpValidators() { void Converter::RegisterOpConverters() { RegisterValidatableOpConverters(&op_registry_); - // TODO(ben,jie): this is a temp hack. - op_registry_["Identity"] = ConvertIdentity; // Identity should be removed - op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - - op_registry_["Rsqrt"] = ConvertUnary; - op_registry_["Reciprocal"] = ConvertUnary; - op_registry_["Exp"] = ConvertUnary; - op_registry_["Log"] = ConvertUnary; - op_registry_["Sqrt"] = ConvertUnary; - op_registry_["Abs"] = ConvertUnary; - op_registry_["Neg"] = ConvertUnary; - - op_registry_["Sum"] = ConvertReduce; - op_registry_["Prod"] = ConvertReduce; - op_registry_["Max"] = ConvertReduce; - op_registry_["Min"] = ConvertReduce; - op_registry_["Mean"] = ConvertReduce; - op_registry_["Softmax"] = ConvertSoftmax; - op_registry_["BatchMatMul"] = ConvertBatchMatMul; - op_registry_["TopKV2"] = ConvertTopK; - plugin_converter_ = ConvertPlugin; } -tensorflow::Status ConvertGraphDefToEngine( - const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, +Status ConvertGraphDefToEngine( + const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size, size_t max_workspace_size_bytes, - const std::vector& input_shapes, - Logger* logger, nvinfer1::IGpuAllocator* allocator, - TRTInt8Calibrator* calibrator, + const std::vector& input_shapes, Logger* logger, + nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, TrtUniquePtrType* engine, bool use_calibration, bool* convert_successfully) { engine->reset(); @@ -3705,9 +4068,13 @@ tensorflow::Status 
ConvertGraphDefToEngine( builder->setMaxBatchSize(max_batch_size); builder->setMaxWorkspaceSize(max_workspace_size_bytes); builder->setGpuAllocator(allocator); - if (precision_mode == FP16MODE) { - builder->setHalf2Mode(true); - } else if (precision_mode == INT8MODE) { + if (precision_mode == TrtPrecisionMode::FP16) { + builder->setFp16Mode(true); + } else if (precision_mode == TrtPrecisionMode::INT8) { + // Setting FP16 mode as well allows TRT to also consider FP16 kernels and + // use them in situations where they are faster than INT8 or where INT8 is + // not supported for a given layer. + builder->setFp16Mode(true); builder->setInt8Mode(true); if (use_calibration) { builder->setInt8Calibrator(calibrator); @@ -3720,25 +4087,23 @@ tensorflow::Status ConvertGraphDefToEngine( auto trt_network = TrtUniquePtrType(builder->createNetwork()); if (!trt_network) { - return tensorflow::errors::Internal( - "Failed to create TensorRT network object"); + return errors::Internal("Failed to create TensorRT network object"); } // Build the network VLOG(1) << "Starting engine conversion "; Converter converter(trt_network.get(), precision_mode, use_calibration); - std::vector> output_tensors; + std::vector output_tensors; // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { string node_name = node_def.name(); VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op(); - if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && - (node_def.op() == "Placeholder")) { + if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32( + if (!strings::safe_strto32( // non-absl ok node_name.c_str() + strlen(kInputPHName), &slot_number)) { - return tensorflow::errors::InvalidArgument( - "Failed to parse slot number from ", node_name); + return errors::InvalidArgument("Failed to parse slot number from ", + node_name); } nvinfer1::DataType 
trt_dtype; nvinfer1::Dims trt_dims; @@ -3761,18 +4126,23 @@ tensorflow::Status ConvertGraphDefToEngine( // engines offline, by calling sess.run() and cache/serialize the engines. TF_RETURN_IF_ERROR( converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size)); - } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && - (node_def.op() == "Identity")) { + } else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32( + if (!strings::safe_strto32( // non-absl ok node_name.c_str() + strlen(kOutputPHName), &slot_number)) { - return tensorflow::errors::InvalidArgument( - "Failed to parse slot number from ", node_name); + return errors::InvalidArgument("Failed to parse slot number from ", + node_name); } + // Get output type that TensorFlow expects + TFAttrs attrs(node_def); + DataType tf_dtype = attrs.get("T"); + nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); if (output_tensors.size() <= slot_number) { output_tensors.resize(slot_number + 1); } - output_tensors.at(slot_number) = {node_def.input(0), node_name}; + output_tensors.at(slot_number) = {node_def.input(0), node_name, + trt_dtype}; } else { VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op(); @@ -3789,18 +4159,17 @@ tensorflow::Status ConvertGraphDefToEngine( VLOG(1) << "Starting engine creation"; engine->reset(builder->buildCudaEngine(*converter.network())); if (engine->get() == nullptr) { - return tensorflow::errors::Internal("Failed to build TensorRT engine"); + return errors::Internal("Failed to build TensorRT engine"); } VLOG(1) << "Finished conversion"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertSegmentToGraphDef( - const tensorflow::Graph* graph, - const tensorflow::grappler::GraphProperties& graph_properties, +Status ConvertSegmentToGraphDef( + const Graph* graph, const grappler::GraphProperties& 
graph_properties, const std::vector& subgraph_nodes, // In topological order - std::vector* connections, - tensorflow::GraphDef* segment_def, string* common_scope) { + std::vector* connections, GraphDef* segment_def, + string* scope_name) { std::set marker_nodes; // Update connection shapes/data types and add corresponding input/output // nodes in the segment graphdef. @@ -3810,12 +4179,12 @@ tensorflow::Status ConvertSegmentToGraphDef( auto outside_node = graph->FindNodeId(connection.outside_id); if (!outside_node) { // This should never happen, unless the original graph is problematic. - return tensorflow::errors::NotFound( - "Cannot find node with id ", connection.outside_id, " in the graph."); + return errors::NotFound("Cannot find node with id ", + connection.outside_id, " in the graph."); } // Updates the shape and data types of input/output connections. - tensorflow::DataType dtype; - tensorflow::PartialTensorShape partial_shape; + DataType dtype; + PartialTensorShape partial_shape; if (connection.is_input_edge) { GetOutputProperties(graph_properties, graph->FindNodeId(connection.outside_id), @@ -3841,7 +4210,7 @@ tensorflow::Status ConvertSegmentToGraphDef( } marker_nodes.insert(node_name); auto seg_node = segment_def->add_node(); - tensorflow::NodeDefBuilder builder(node_name, "Placeholder"); + NodeDefBuilder builder(node_name, "Placeholder"); auto status = builder.Attr("shape", partial_shape) .Attr("dtype", dtype) .Finalize(seg_node); @@ -3860,7 +4229,7 @@ tensorflow::Status ConvertSegmentToGraphDef( } marker_nodes.insert(node_name); auto seg_node = segment_def->add_node(); - tensorflow::NodeDefBuilder builder(node_name, "Identity"); + NodeDefBuilder builder(node_name, "Identity"); auto status = builder .Input(connection.inside_node_name, connection.inside_port, dtype) @@ -3879,7 +4248,7 @@ tensorflow::Status ConvertSegmentToGraphDef( local_scope = GetCommonNameScope(local_scope, node->name()); old_to_new_id_map[node->id()] = segment_def->node_size(); 
auto snode = segment_def->add_node(); - snode->CopyFrom(node->def()); + *snode = node->def(); VLOG(2) << "Copying " << snode->name() << " to subgraph"; } // Update the inputs of the new input nodes to point to placeholder nodes. @@ -3910,14 +4279,14 @@ tensorflow::Status ConvertSegmentToGraphDef( TensorId input = ParseTensorName(snode->input(input_idx)); if (!subgraph_node_names.count( string(input.first.data(), input.first.size())) && - !str_util::StartsWith(input.first, kInputPHName)) { + !IsEngineInput(input.first)) { if (input.second == Graph::kControlSlot) { VLOG(1) << "... removing control inputs " << input.first << " from subgraph."; ++input_idx; continue; } else { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Found non control input outside the segment that is not an " "engine connection to ", snode->name(), ": ", input.first); @@ -3933,13 +4302,11 @@ tensorflow::Status ConvertSegmentToGraphDef( snode->mutable_input()->RemoveLast(); } } - *common_scope = local_scope; - VLOG(1) << "Converted TensorRT candidate segment @scope '" << local_scope - << "' to a GraphDef"; - return tensorflow::Status::OK(); + *scope_name = local_scope; + return Status::OK(); } -bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { +bool OutputEdgeValidator::operator()(const Edge* out_edge) const { if (out_edge->IsControlEdge()) return true; if (out_edge->src()->type_string() == "Const") { VLOG(1) << "--> Need to remove output node " << out_edge->src()->name() diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h similarity index 82% rename from tensorflow/contrib/tensorrt/convert/convert_nodes.h rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 4ea5775f04ca8faa9e5019d43ba146a8e5ff41b7..068482a3f64b474199ca73b4e46a3938316e6880 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ 
b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ #include #include @@ -22,11 +22,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -43,6 +43,12 @@ extern const char* const kOutputPHName; namespace convert { +#define IS_TRT_VERSION_GE(major, minor, patch) \ + ((NV_TENSORRT_MAJOR > major) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ + NV_TENSORRT_PATCH >= patch)) + struct EngineConnection { // Constructs a non-control edge. 
EngineConnection(const string& outside, int out_id, int out_port, @@ -74,14 +80,14 @@ struct EngineConnection { const string outside_node_name; const int outside_id; const int outside_port; - tensorflow::PartialTensorShape outside_shape; // Only set for input edge. + PartialTensorShape outside_shape; // Only set for input edge. const string inside_node_name; const int inside_id; const int inside_port; - tensorflow::PartialTensorShape inside_shape; // Only set for output edge. + PartialTensorShape inside_shape; // Only set for output edge. - tensorflow::DataType connection_type; + DataType connection_type; const bool is_input_edge; // The port number of the TRT node connected with this edge. @@ -92,12 +98,12 @@ struct EngineInfo { EngineInfo() : engine_type(EngineType::TRTStatic), max_workspace_size_bytes(0), - precision_mode(FP32MODE), + precision_mode(TrtPrecisionMode::FP32), use_calibration(true) {} string engine_name; string device; - tensorflow::GraphDef segment_graph_def; + GraphDef segment_graph_def; // Non-control input connections inside this vector are sorted in a way such // that, the segment nodes connecting to them are topological sorted. @@ -109,7 +115,7 @@ struct EngineInfo { int64 max_workspace_size_bytes; int maximum_cached_engines; std::vector cached_engine_batches; - int precision_mode; + TrtPrecisionMode precision_mode; bool use_calibration; }; @@ -123,14 +129,14 @@ struct EngineInfo { // topological order. // - segment_def: the output GraphDef, whose non-input/output nodedefs will be // sorted in topological order. +// - scope_name: the name of the scope where the TRTEngineOp will be placed. // // TODO(aaroey): add tests to validate these properties. 
-tensorflow::Status ConvertSegmentToGraphDef( - const tensorflow::Graph* graph, - const tensorflow::grappler::GraphProperties& graph_properties, +Status ConvertSegmentToGraphDef( + const Graph* graph, const grappler::GraphProperties& graph_properties, const std::vector& subgraph_nodes, - std::vector* connections, - tensorflow::GraphDef* segment_def, string* common_scope); + std::vector* connections, GraphDef* segment_def, + string* scope_name); // Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff // 'builder' successfully build the engine. If the result is not ok, 'engine' @@ -140,12 +146,11 @@ tensorflow::Status ConvertSegmentToGraphDef( // - convert_successfully: indicates whether the converson to TensorRT network // is successful. This is different than successfully building the engine: // building can still fail afterwards. -tensorflow::Status ConvertGraphDefToEngine( - const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, +Status ConvertGraphDefToEngine( + const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size, size_t max_workspace_size_bytes, - const std::vector& input_shapes, - Logger* logger, nvinfer1::IGpuAllocator* allocator, - TRTInt8Calibrator* calibrator, + const std::vector& input_shapes, Logger* logger, + nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, TrtUniquePtrType* engine, bool use_calibration, bool* convert_successfully); @@ -155,7 +160,7 @@ class OutputEdgeValidator { public: // Return true if the specified edge is eligible to be an output edge of the // TRT segment. - bool operator()(const tensorflow::Edge* out_edge) const; + bool operator()(const Edge* out_edge) const; }; string DebugString(const nvinfer1::DimensionType type); @@ -178,6 +183,8 @@ class TRT_ShapedWeights { nvinfer1::Weights GetTrtWeights() const; + // Returns the raw pointer to the underlying buffer which holds the weights + // value. 
void* GetValues() const { return const_cast(tensor_.tensor_data().data()); } @@ -188,15 +195,30 @@ class TRT_ShapedWeights { string DebugString() const; + template + absl::Span GetSpan() const { + return absl::Span(tensor_.flat().data(), count()); + } + + template + std::vector ToVector() const { + auto span = GetSpan(); + return std::vector(span.data(), span.data() + span.size()); + } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; // Note: shape.type[] is not used. - tensorflow::DataType type_; + DataType type_; private: // This constructor is only used by TrtWeightStore, which creates the // underlying buffer. TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor); + // All weights should be stored inside TrtWeightStore to make sure lifetime of + // all the underlying tensors are available until the engine is built. For + // this reason, tensor_ should never be reassigned to a different value that + // is not already present in the TrtWeightStore. Tensor tensor_; friend class TrtWeightStore; @@ -212,8 +234,7 @@ class TRT_ShapedWeights { class TrtWeightStore { public: // Get a TRT_ShapedWeights with 'type' and 'dims'. - TRT_ShapedWeights GetTempWeights(tensorflow::DataType type, - const nvinfer1::Dims& dims); + TRT_ShapedWeights GetTempWeights(DataType type, const nvinfer1::Dims& dims); // Get a TRT_ShapedWeights with the same data type and dimensions as // 'weights'. @@ -324,8 +345,7 @@ class Converter; // Parameters for each op converter. 
struct OpConverterParams { - OpConverterParams(Converter* arg_converter, - const tensorflow::NodeDef& arg_node_def, + OpConverterParams(Converter* arg_converter, const NodeDef& arg_node_def, const std::vector& arg_inputs, std::vector* arg_outputs, bool arg_validation_only, TrtWeightStore* arg_weight_store) @@ -337,7 +357,7 @@ struct OpConverterParams { weight_store(arg_weight_store) {} Converter* converter; - const tensorflow::NodeDef& node_def; + const NodeDef& node_def; const std::vector& inputs; std::vector* outputs; const bool validation_only; @@ -362,9 +382,12 @@ class TrtNodeValidator { Status ValidateNode( const NodeDef& node_def, const std::vector>& input_node_and_ports, + const TrtPrecisionMode precision_mode, const grappler::GraphProperties& graph_properties); private: + static const std::set* quantize_ops; + void RegisterOpValidators(); // Convert a Const node to a TRT_TensorOrWeights. @@ -396,28 +419,38 @@ class TrtNodeValidator { // Class to convert TF nodes to TRT network. class Converter { public: - Converter(nvinfer1::INetworkDefinition* trt_network, int precision_mode, - bool use_calibration); + // Used for Converter::RenameAndMarkOutputTensors() + struct EngineOutputInfo { + // The TRT tensor name which produces the output. + string source_tensor_name; + // The TensorFlow node name which is receiving the output from the TRT + // engine. This should always be the Identity node created in + // ConvertSegmentToGraphDef. + string dest_node_name; + // Output type. TensorRT requires this to be explicitly set for engine + // outputs. + nvinfer1::DataType trt_dtype; + }; + + Converter(nvinfer1::INetworkDefinition* trt_network, + TrtPrecisionMode precision_mode, bool use_calibration); ////////////////////////////////////////////////////////////////////////////// // Methods used by the TRT engine builder to build a TRT network from a TF // function/subgraph. // Convert the node to TRT network. 
- Status ConvertNode(const tensorflow::NodeDef& node_def); + Status ConvertNode(const NodeDef& node_def); // Add input tensor to the TRT network with given 'name', 'dtype', 'dims' and // 'batch_size'. Status AddInputTensor(const string& name, nvinfer1::DataType dtype, const nvinfer1::Dims& dims, int batch_size); - // Mark the tensors with names specified by output_tensors[i].first as output - // of the TRT network, and set their names in the TRT network as - // output_tensors[i].second. The tensor names (output_tensors[i].first) are - // standard TF tensor names, i.e. node names followed by output slot number - // (or just the node name if the tensor is the first output of the node). + // Mark the tensors with names specified by source_tensor_name as output of + // the TRT network, and set their names in the TRT network as dest_node_name. Status RenameAndMarkOutputTensors( - const std::vector>& output_tensors); + const std::vector& output_tensors); ////////////////////////////////////////////////////////////////////////////// // Methods used by op converters to convert individual TF node and add layers @@ -428,7 +461,7 @@ class Converter { nvinfer1::INetworkDefinition* network() { return trt_network_; } // What precision are we targeting? - int precision_mode() const { return precision_mode_; } + TrtPrecisionMode precision_mode() const { return precision_mode_; } // Calibration will be or was previously performed on this network? bool use_calibration() const { return use_calibration_; } @@ -460,8 +493,13 @@ class Converter { const nvinfer1::ITensor** output_tensor); // Converts 'input' into 'tensor' with shape specified by 'dims'. + // + // If validation_only is true, it doesn't do the conversion but only do some + // minimum validation for the eligibility of the conversion, and *tensor will + // be set to nullptr. 
Status PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& dims, + const bool validation_only, const nvinfer1::ITensor** tensor); // Return OK if the broadcast scheme is supported and compute the shapes after @@ -488,7 +526,7 @@ class Converter { Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output); // Get the inputs of 'node_def' from trt_tensors_. - Status GetInputs(const tensorflow::NodeDef& node_def, + Status GetInputs(const NodeDef& node_def, std::vector* inputs) const; void RegisterOpConverters(); @@ -530,7 +568,7 @@ class Converter { std::vector> quantization_infer_; - const int precision_mode_; + const TrtPrecisionMode precision_mode_; const bool use_calibration_; @@ -544,6 +582,9 @@ class Converter { friend class OpConverterTest; }; +// Map of all supported UnaryOperations +const std::unordered_map* UnaryOperationMap(); + } // namespace convert } // namespace tensorrt } // namespace tensorflow @@ -551,4 +592,4 @@ class Converter { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc similarity index 62% rename from tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index c8739b46c150b04c0e28bf01b821ae5f6ea5c1d2..853b313367c9439c81ecb57b33e97bf8a1c1c481 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include #include @@ -21,11 +21,17 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -35,7 +41,9 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/public/session.h" @@ -50,9 +58,10 @@ namespace tensorflow { namespace tensorrt { namespace convert { -using ::tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::ElementsAre; using ::testing::ElementsAreArray; +using ::testing::NanSensitiveFloatNear; // TODO(laigd): put this into some test utils file. 
void ExpectStatus(Status status, error::Code code = error::OK, @@ -101,13 +110,17 @@ DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) { } NodeDef MakeNodeDef(const string& name, const string& op, - const std::vector& inputs) { + const std::vector& inputs, + const std::map attrs = {}) { NodeDef node_def; node_def.set_name(name); node_def.set_op(op); for (const string& input : inputs) { node_def.add_input(input); } + for (const auto& attr : attrs) { + (*node_def.mutable_attr())[attr.first] = attr.second; + } return node_def; } @@ -115,7 +128,7 @@ template NodeDef MakeConstNodeDef(const string& name, const std::vector& vals, const TensorShape& shape) { Scope s = Scope::NewRootScope(); - Tensor t = ::tensorflow::test::AsTensor(vals, shape); + Tensor t = test::AsTensor(vals, shape); auto const_op = ops::Const(s.WithOpName(name), t); return const_op.node()->def(); } @@ -152,7 +165,7 @@ void ExpectTrtDimsEqualsArray(const std::vector& lhs, } template -void ExpectArrayNear(const std::vector& lhs, const std::vector& rhs) { +void ExpectArrayNear(const std::vector& lhs, absl::Span rhs) { ASSERT_EQ(lhs.size(), rhs.size()); for (int i = 0; i < lhs.size(); i++) { EXPECT_FLOAT_EQ(lhs[i], rhs[i]); @@ -163,7 +176,7 @@ void ExpectArrayNear(const std::vector& lhs, const std::vector& rhs) { // EXPECT_FLOAT_EQ. 
template <> void ExpectArrayNear(const std::vector& lhs, - const std::vector& rhs) { + absl::Span rhs) { ASSERT_EQ(lhs.size(), rhs.size()); for (int i = 0; i < lhs.size(); i++) { EXPECT_FLOAT_EQ(Eigen::half_impl::half_to_float(lhs[i]), @@ -225,7 +238,7 @@ class FakeITensor : public nvinfer1::ITensor { location_ = location; } -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0) bool setDynamicRange(float min, float max) override { dynamic_range_ = std::max(std::abs(min), std::abs(max)); return true; @@ -234,6 +247,16 @@ class FakeITensor : public nvinfer1::ITensor { float getDynamicRange() const override { return dynamic_range_; } #endif +#if IS_TRT_VERSION_GE(5, 1, 0) + bool dynamicRangeIsSet() const override { return true; } + + void resetDynamicRange() override {} + + float getDynamicRangeMin() const override { return 0.f; } + + float getDynamicRangeMax() const override { return 0.f; } +#endif + private: string name_; nvinfer1::Dims dims_; @@ -371,8 +394,8 @@ TEST(TRT_TensorOrWeights_Test, Basic) { class ValidatorTest : public ::testing::Test { public: - void AddOpValidator(const string& op_name, OpConverter op_validator) { - validator_.op_validators_[op_name] = op_validator; + std::unordered_map& op_validators() { + return validator_.op_validators_; } Status ConvertToTensorOrWeights( @@ -383,10 +406,18 @@ class ValidatorTest : public ::testing::Test { node_def, output_port, graph_properties, tensor_or_weights); } + const std::set* GetQuantizeOps() { return validator_.quantize_ops; } + protected: TrtNodeValidator validator_; }; +TEST_F(ValidatorTest, QuantizeOpsAreRegistered) { + for (const string& quantize_op : *GetQuantizeOps()) { + QCHECK(op_validators().count(quantize_op)); + } +} + TEST_F(ValidatorTest, ConvertToTensorOrWeights) { // Convert Const. { @@ -459,18 +490,30 @@ TEST_F(ValidatorTest, ValidateNode) { }; NodeDef node_def = MakeNodeDef("my_op", "MyOp", {}); - // Validator not registered, validation should pass. 
- TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties)); + // Validator not registered. + ExpectStatus(validator_.ValidateNode(node_def, {}, TrtPrecisionMode::FP32, + graph_properties), + error::UNIMPLEMENTED, "Op type MyOp is not supported."); // Register validator. - AddOpValidator("MyOp", op_converter); - TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties)); + op_validators()["MyOp"] = op_converter; + TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, TrtPrecisionMode::FP32, + graph_properties)); EXPECT_EQ(false, start_conversion); // Let the converter return error. should_fail = true; - ExpectStatus(validator_.ValidateNode(node_def, {}, graph_properties), + ExpectStatus(validator_.ValidateNode(node_def, {}, TrtPrecisionMode::FP32, + graph_properties), error::INVALID_ARGUMENT); + + // Test quantization ops, they're only supported in INT8 mode. The success + // case is tested in OpConverterTest.ConvertQuantize. + node_def = MakeNodeDef("my_op", "FakeQuantWithMinMaxArgs", {}); + ExpectStatus(validator_.ValidateNode(node_def, {}, TrtPrecisionMode::FP32, + graph_properties), + error::UNIMPLEMENTED, + "Op type FakeQuantWithMinMaxArgs is not supported."); } class ConverterTest : public ::testing::Test { @@ -478,8 +521,7 @@ class ConverterTest : public ::testing::Test { ConverterTest() { builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); - converter_.reset(new Converter(network_.get(), - /*precision_mode=*/FP32MODE, + converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32, /*use_calibration=*/false)); weight_store_ = &converter_->weight_store_; } @@ -674,23 +716,34 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { TRT_TensorOrWeights tw(input_tensor); const nvinfer1::ITensor* output_tensor = nullptr; - // Shape size doesn't match. 
- ExpectStatus(converter_->PrepareTensorForShape(tw, GetTestDims({2, 3, 6}), - &output_tensor), - error::INVALID_ARGUMENT, "Reshape shapes are not compatible"); - - // TODO(aaroey): we should check the case where uninferred dimensions are not - // an exact divisor of input dim ensions, e.g. for dims {-1, 7}. - - // Infer shape, ok. - TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({-1, 2}), - &output_tensor)); - ExpectTrtDimsEqualsArray({15, 2}, output_tensor->getDimensions()); + for (bool validation_only : {false, true}) { + // Shape size doesn't match. + ExpectStatus( + converter_->PrepareTensorForShape(tw, GetTestDims({2, 3, 6}), + validation_only, &output_tensor), + error::INVALID_ARGUMENT, "Reshape shapes are not compatible"); + + // TODO(aaroey): we should check the case where uninferred dimensions are + // not an exact divisor of input dim ensions, e.g. for dims {-1, 7}. + + // Infer shape, ok. + TF_EXPECT_OK(converter_->PrepareTensorForShape( + tw, GetTestDims({-1, 2}), validation_only, &output_tensor)); + if (validation_only) { + EXPECT_EQ(nullptr, output_tensor); + } else { + ExpectTrtDimsEqualsArray({15, 2}, output_tensor->getDimensions()); + } - // Regular shape. - TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}), - &output_tensor)); - ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); + // Regular shape. 
+ TF_EXPECT_OK(converter_->PrepareTensorForShape( + tw, GetTestDims({10, 3}), validation_only, &output_tensor)); + if (validation_only) { + EXPECT_EQ(nullptr, output_tensor); + } else { + ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); + } + } } TEST_F(ConverterTest, PrepareTensorForShape_Weights) { @@ -698,9 +751,15 @@ TEST_F(ConverterTest, PrepareTensorForShape_Weights) { weight_store_->GetTempWeights(DT_FLOAT, GetTestDims({2, 3, 5})); TRT_TensorOrWeights tw(weights); const nvinfer1::ITensor* output_tensor = nullptr; - TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}), - &output_tensor)); - ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); + for (bool validation_only : {false, true}) { + TF_EXPECT_OK(converter_->PrepareTensorForShape( + tw, GetTestDims({10, 3}), validation_only, &output_tensor)); + if (validation_only) { + EXPECT_EQ(nullptr, output_tensor); + } else { + ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); + } + } } TEST_F(ConverterTest, MaybeUpdateBatchSize) { @@ -781,7 +840,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { // input -> infer1 -> infer2 -> infer3 FakeITensor input, infer_1, infer_2, infer_3; FakeITensor not_infer; - Converter int8_converter(/*trt_network=*/nullptr, INT8MODE, + Converter int8_converter(/*trt_network=*/nullptr, TrtPrecisionMode::INT8, /*use_calibration=*/true); int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f); int8_converter.ProvideQuantizationRange(¬_infer, -100.0f, 100.0f); @@ -791,7 +850,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { // Input range should be inferred along the chain and applied to tensors. 
int8_converter.MaybeApplyQuantizationRanges(); -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0) EXPECT_EQ(input.getDynamicRange(), 5.0f); EXPECT_EQ(infer_1.getDynamicRange(), 5.0f); EXPECT_EQ(infer_2.getDynamicRange(), 5.0f); @@ -926,6 +985,83 @@ TEST_F(ConverterTest, CreateConstantLayer) { } } +class ConvertGraphDefToEngineTest : public ::testing::Test { + public: + Status RunConvertGraphDefToEngine(Scope* s) { + GraphDef gdef; + TF_EXPECT_OK(s->ToGraphDef(&gdef)); + std::vector input_shapes; + int batch_size = -1; + for (const NodeDef& node : gdef.node()) { + absl::string_view node_name(node.name()); + if (str_util::ConsumePrefix(&node_name, kInputPHName)) { + int port = -1; + EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name(); + if (input_shapes.size() < port + 1) input_shapes.resize(port + 1); + input_shapes[port] = + PartialTensorShape(node.attr().at("shape").shape()); + if (batch_size == -1) { + batch_size = input_shapes[port].dim_size(0); + } else { + EXPECT_EQ(batch_size, input_shapes[port].dim_size(0)); + } + } + } + // TODO(laigd): execute the engine and get outputs. 
+ return ConvertGraphDefToEngine( + gdef, TrtPrecisionMode::FP32, /*max_batch_size=*/1, + /*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_, + /*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_, + /*use_calibration=*/false, /*convert_successfully=*/nullptr); + } + + protected: + TrtUniquePtrType engine_; + + private: + Logger logger_; +}; + +TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT, + ops::Placeholder::Shape({1, 1})); + auto output = ops::Identity(s.WithOpName("identity1"), input); + output = ops::Identity(s.WithOpName("identity2"), output); + output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output); + // If the converter marks the input tensor as output tensor, the conversion + // below will fail with: + // > TensorRTOutputPH_0 cannot be both input and output + // > Network must have at least one output + TF_EXPECT_OK(RunConvertGraphDefToEngine(&s)); +} + +// Input/output data format for OpConverterTest::BuildAndRun(). +struct InputOutputData { + void* Buffer() const { + return const_cast(tensor.tensor_data().data()); + } + + size_t TotalBytes() const { return tensor.TotalBytes(); } + + const char* name; + Tensor tensor; +}; + +template +Tensor ConstructTensor(int data_size, const T& value = T()) { + std::vector values(data_size, value); + return test::AsTensor(values); +} + +using DataVec = std::vector; + +template +inline absl::Span GetSpanForData(const InputOutputData& data) { + const auto& tensor_map = data.tensor.flat(); + return absl::Span(tensor_map.data(), tensor_map.size()); +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. 
class OpConverterTest : public ::testing::Test { @@ -951,11 +1087,11 @@ class OpConverterTest : public ::testing::Test { builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); builder_->setMaxBatchSize(1); + builder_->setMaxWorkspaceSize(1 << 26); // Reset the validator and converter. validator_.reset(new TrtNodeValidator); - converter_.reset(new Converter(network_.get(), - /*precision_mode=*/FP32MODE, + converter_.reset(new Converter(network_.get(), precision_mode_to_test_, /*use_calibration=*/false)); // Reset other related artifacts. @@ -963,47 +1099,85 @@ class OpConverterTest : public ::testing::Test { validator_inputs_.clear(); } - // TODO(laigd): test fp16 and int8 support. - template - void BuildAndRun( - const std::vector>>& - input_data, - const char* output_name, std::vector* output_data) { + void CheckDataTypeMatches(const DataVec& datas) { + for (const auto& data : datas) { + const int input_index = engine_->getBindingIndex(data.name); + ASSERT_NE(-1, input_index); + const nvinfer1::DataType trt_dtype = + engine_->getBindingDataType(input_index); + const DataType tf_dtype = TrtDataTypeToTf(trt_dtype); + ASSERT_EQ(data.tensor.dtype(), tf_dtype) + << DataTypeString(data.tensor.dtype()) << " vs. " + << DataTypeString(tf_dtype); + } + } + + // TODO(laigd): test fp16 and int8 support for more converters. + void BuildAndRun(const DataVec& input_data, DataVec* output_data, + TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32) { // Mark the output tensor as TRT engine output. - TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors( - {{string(output_name), string(output_name)}})); + std::vector output_info; + for (const auto& data : *output_data) { + output_info.push_back( + {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())}); + } + TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info)); // Build the TRT engine. 
+ if (precision_mode == TrtPrecisionMode::FP16) { + builder_->setFp16Mode(true); + } else if (precision_mode == TrtPrecisionMode::INT8) { + // Setting FP16 mode as well allows TRT to also consider FP16 kernels and + // use them in situations where they are faster than INT8 or where INT8 is + // not supported for a given layer. + builder_->setFp16Mode(true); + builder_->setInt8Mode(true); + } ASSERT_EQ(nullptr, engine_.get()); engine_.reset(builder_->buildCudaEngine(*converter_->network())); CHECK_NOTNULL(engine_.get()); + CheckDataTypeMatches(input_data); + CheckDataTypeMatches(*output_data); // Execute the TRT engine. - ASSERT_LE(input_data.size() + 1, 3); - void* buffers[3]; - for (const auto name_and_data : input_data) { - const int input_size = name_and_data.second.size() * sizeof(T); - const int input_index = engine_->getBindingIndex(name_and_data.first); - ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size)); - ASSERT_EQ( - 0, cudaMemcpyAsync(buffers[input_index], name_and_data.second.data(), - input_size, cudaMemcpyHostToDevice, stream_)); + const int num_bindings = input_data.size() + output_data->size(); + std::vector buffers(num_bindings); + + for (const auto& data : input_data) { + const int input_index = engine_->getBindingIndex(data.name); + ASSERT_EQ(0, cudaMalloc(&buffers[input_index], data.TotalBytes())); + ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], data.Buffer(), + data.TotalBytes(), cudaMemcpyHostToDevice, + stream_)); + } + struct SizeAndIndex { + SizeAndIndex(int in_size, int in_index) + : size(in_size), index(in_index) {} + int size; + int index; + }; + std::vector output_infos; + for (const auto& data : *output_data) { + const int output_index = engine_->getBindingIndex(data.name); + output_infos.emplace_back(data.TotalBytes(), output_index); + ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes())); } - const int output_size = output_data->size() * sizeof(T); - const int output_index = 
engine_->getBindingIndex(output_name); - ASSERT_EQ(0, cudaMalloc(&buffers[output_index], output_size)); - - ASSERT_EQ(engine_->getNbBindings(), input_data.size() + 1); - + ASSERT_EQ(engine_->getNbBindings(), num_bindings); TrtUniquePtrType execution_context( engine_->createExecutionContext()); - execution_context->enqueue(/*batchSize=*/1, buffers, stream_, nullptr); - ASSERT_EQ(0, cudaMemcpyAsync(output_data->data(), buffers[output_index], - output_size, cudaMemcpyDeviceToHost, stream_)); + execution_context->enqueue(/*batchSize=*/1, buffers.data(), stream_, + nullptr); + + for (int i = 0; i < output_infos.size(); ++i) { + const auto& output_info = output_infos[i]; + ASSERT_EQ(0, cudaMemcpyAsync(output_data->at(i).Buffer(), + buffers[output_info.index], output_info.size, + cudaMemcpyDeviceToHost, stream_)); + } cudaStreamSynchronize(stream_); - for (int i = 0; i < input_data.size() + 1; ++i) { + for (int i = 0; i < num_bindings; ++i) { ASSERT_EQ(0, cudaFree(buffers[i])); } } @@ -1075,9 +1249,10 @@ class OpConverterTest : public ::testing::Test { grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - ExpectStatus(validator_->ValidateNode(node_def, input_node_and_ports, - graph_properties), - expected_code, expected_msg_substr); + ExpectStatus( + validator_->ValidateNode(node_def, input_node_and_ports, + precision_mode_to_test_, graph_properties), + expected_code, expected_msg_substr); } void RunConversion(const NodeDef& node_def, @@ -1107,6 +1282,10 @@ class OpConverterTest : public ::testing::Test { std::unique_ptr converter_; std::unique_ptr validator_; + protected: + // TODO(laigd): parameterize the test and make the precision mode a parameter. 
+ TrtPrecisionMode precision_mode_to_test_ = TrtPrecisionMode::FP32; + private: Logger logger_; TrtUniquePtrType builder_; @@ -1122,6 +1301,30 @@ class OpConverterTest : public ::testing::Test { std::unordered_map validator_inputs_; }; +template +void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField* out) { + out->Clear(); + if (tensor.NumElements() == 0) return; + + // TensorProto does not need to have all the elements present and can truncate + // trailing elements with the same value for compressed representation. Such + // elements are derived based on the tensor shape. + const auto flat = tensor.flat(); + int64 last_index = 0; + for (int64 i = 0; i < tensor.NumElements(); ++i) { + if (flat(i) != flat(last_index)) { + last_index = i; + } + } + + int num_out_elements = last_index + 1; + out->Reserve(num_out_elements); + out->AddNAlreadyReserved(num_out_elements); + const T* src = flat.data(); + T* dst = out->mutable_data(); + std::copy(src, src + num_out_elements, dst); +} + template void TestConvertConst(OpConverterTest* test) { NodeDef node_def; @@ -1134,11 +1337,23 @@ void TestConvertConst(OpConverterTest* test) { const std::vector& expected_value) { test->Reset(); - auto& attr = *node_def.mutable_attr(); + TensorProto* tensor_attr = + (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor_attr->Clear(); + if (as_tensor_content) { - tensor.AsProtoTensorContent(attr["value"].mutable_tensor()); + tensor.AsProtoTensorContent(tensor_attr); } else { - tensor.AsProtoField(attr["value"].mutable_tensor()); + tensor.shape().AsProto(tensor_attr->mutable_tensor_shape()); + tensor_attr->set_dtype(tensor.dtype()); + + if (tensor.dtype() == DT_FLOAT) { + CopyTensorElements(tensor, tensor_attr->mutable_float_val()); + } else if (tensor.dtype() == DT_INT32) { + CopyTensorElements(tensor, tensor_attr->mutable_int_val()); + } else { + tensor.AsProtoField(tensor_attr); + } } test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; @@ 
-1151,26 +1366,41 @@ void TestConvertConst(OpConverterTest* test) { { // By default empty tensor will pick DT_FLOAT as data type and we fix it // here. - attr["value"].mutable_tensor()->set_dtype(dtype); - Tensor t; // Empty tensor. + Tensor t(dtype); // Empty tensor. reset_and_test(t, false, {}, {}); } { - Tensor t = ::tensorflow::test::AsScalar(12); + Tensor t = test::AsScalar(12); reset_and_test(t, false, {1}, {12}); reset_and_test(t, true, {1}, {12}); } { - Tensor t = ::tensorflow::test::AsTensor({1, 2}); + Tensor t = test::AsTensor({1, 2}); reset_and_test(t, false, {2}, {1, 2}); reset_and_test(t, true, {2}, {1, 2}); } { - Tensor t = ::tensorflow::test::AsTensor({1, 2, 3, 4, 5, 6}, - TensorShape({2, 3})); + Tensor t = + test::AsTensor({1, 2, 3, 4, 5, 6}, TensorShape({2, 3})); reset_and_test(t, false, {2, 3}, {1, 2, 3, 4, 5, 6}); reset_and_test(t, true, {2, 3}, {1, 2, 3, 4, 5, 6}); } + { + // Set all tensor elements to the same value. Such tensors are encoded + // using a single element list in tensor proto. + Tensor t = + test::AsTensor({1, 1, 1, 1, 1, 1}, TensorShape({2, 3})); + reset_and_test(t, false, {2, 3}, {1, 1, 1, 1, 1, 1}); + reset_and_test(t, true, {2, 3}, {1, 1, 1, 1, 1, 1}); + } + { + // Set trailing tensor elements to the same value. Such tensors are + // encoded by truncating all equal elements except the first one. + Tensor t = + test::AsTensor({2, 2, 1, 1, 1, 1}, TensorShape({2, 3})); + reset_and_test(t, false, {2, 3}, {2, 2, 1, 1, 1, 1}); + reset_and_test(t, true, {2, 3}, {2, 2, 1, 1, 1, 1}); + } } TEST_F(OpConverterTest, ConvertConst) { @@ -1200,7 +1430,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { NodeDef node_def = MakeNodeDef("my_transpose", "Transpose", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_transpose"); + "Transpose got 0 inputs but expected 2, at my_transpose"); } // Get the NodeDef for Transpose. 
@@ -1216,8 +1446,8 @@ TEST_F(OpConverterTest, ConvertTranspose) { AddTestTensor("input", {1, 2, 3}); AddTestTensor("weights", {3}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_transpose"); + node_def, error::UNIMPLEMENTED, + "The input \"perm\" for Transpose must be a constant, at my_transpose"); } { // Transpose at batch dimension, should fail. @@ -1247,10 +1477,12 @@ TEST_F(OpConverterTest, ConvertTranspose) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_transpose", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 4, 2, 5, 3, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_transpose", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 4, 2, 5, 3, 6)); } } @@ -1260,7 +1492,7 @@ TEST_F(OpConverterTest, ConvertReshape) { NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects weights for shape, at my_reshape"); + "Reshape got 0 inputs but expected 2, at my_reshape"); } // Get the NodeDef for Reshape. @@ -1276,8 +1508,8 @@ TEST_F(OpConverterTest, ConvertReshape) { AddTestTensor("input", {1, 2, 3}); AddTestTensor("weights", {3}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Input expects weights for shape, at my_reshape"); + node_def, error::UNIMPLEMENTED, + "The input \"shape\" for Reshape must be a constant, at my_reshape"); } { // Reshape to scalar, should fail. 
@@ -1290,11 +1522,6 @@ TEST_F(OpConverterTest, ConvertReshape) { } struct TestParams { - TestParams(int input_batch_size, const std::vector& input_tensor_dims, - const std::vector& input_shape) - : batch_size(input_batch_size), - tensor_dims(input_tensor_dims), - shape(input_shape) {} int batch_size; std::vector tensor_dims; std::vector shape; @@ -1337,10 +1564,12 @@ TEST_F(OpConverterTest, ConvertReshape) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_reshape", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_reshape", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 2, 3, 4, 5, 6)); } } @@ -1350,7 +1579,7 @@ TEST_F(OpConverterTest, ConvertMatMul) { NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_matmul"); + "MatMul got 0 inputs but expected 2, at my_matmul"); } // Get the NodeDef for MatMul. @@ -1372,9 +1601,9 @@ TEST_F(OpConverterTest, ConvertMatMul) { NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false); AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32); AddTestWeights("weights", {2, 1}, {3, 5}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "Data type is not supported, for node my_matmul got int32"); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Data type int32 is not supported for MatMul, " + "must be one of [float, half], at my_matmul"); } // transpose_a is set. 
for (bool transpose_b : {false, true}) { @@ -1400,12 +1629,13 @@ TEST_F(OpConverterTest, ConvertMatMul) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); - std::vector output_data(2); - BuildAndRun({{"input", {0, 1}}}, "my_matmul", &output_data); + const DataVec input_data{{"input", test::AsTensor({0, 1})}}; + DataVec output_data{{"my_matmul", ConstructTensor(2)}}; + BuildAndRun(input_data, &output_data); if (transpose_b) { - EXPECT_THAT(output_data, ElementsAre(1, 3)); + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); } else { - EXPECT_THAT(output_data, ElementsAre(2, 3)); + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(2, 3)); } } } @@ -1459,23 +1689,28 @@ void TestConvertBiasAdd(OpConverterTest* test) { const int num_input = TrtDimsNumElements(GetTestDims(dims_array)); ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2), num_input); - std::vector output_data(num_input); - test->BuildAndRun( - {{"input", std::vector(num_input, CType(0))}}, "my_biasadd", - &output_data); + + const DataVec input_data{ + {"input", ConstructTensor(num_input, CType(0))}}; + DataVec output_data{{"my_biasadd", ConstructTensor(num_input)}}; + test->BuildAndRun(input_data, &output_data); if (trt_input_rank == 1) { if (data_format == "NHWC") { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(2), CType(3))); } else { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(2))); } } else { if (data_format == "NHWC") { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3), - CType(1), CType(2), CType(3))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(2), CType(3), CType(1), + CType(2), CType(3))); } else { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(1), CType(1), - 
CType(2), CType(2), CType(2))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(1), CType(1), CType(2), + CType(2), CType(2))); } } } @@ -1488,7 +1723,7 @@ TEST_F(OpConverterTest, ConvertBiasAdd) { NodeDef node_def = MakeNodeDef("my_biasadd", "BiasAdd", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_biasadd"); + "BiasAdd got 0 inputs but expected 2, at my_biasadd"); } // OK. Note that kINT32 is not supported by IScaleLayer, so we don't test @@ -1553,21 +1788,27 @@ void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(2); - test->BuildAndRun( - {{"input", - /*input_data=*/swap_inputs ? operand2 : operand1}}, - "my_binary", &output_data); + const DataVec input_data{ + {"input", test::AsTensor(swap_inputs ? operand2 : operand1)}}; + DataVec output_data{{"my_binary", ConstructTensor(2)}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? 
TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); if (node_def.op() == "Add") { - EXPECT_THAT(output_data, ElementsAre(CType(5), CType(10.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(5), CType(10.5))); } else if (node_def.op() == "Sub") { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(4.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(4.5))); } else if (node_def.op() == "Mul") { - EXPECT_THAT(output_data, ElementsAre(CType(6), CType(22.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(6), CType(22.5))); } else if (node_def.op() == "Div") { - EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1.5), CType(2.5))); } else if (node_def.op() == "RealDiv") { - EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1.5), CType(2.5))); } else { ASSERT_TRUE(false); } @@ -1602,13 +1843,14 @@ void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(4); - test->BuildAndRun({{"input", input}}, "my_binary", &output_data); + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; + test->BuildAndRun(input_data, &output_data); if (weights_dims.size() == 1) { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(11), CType(22), CType(13), CType(24))); } else { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(11), CType(12), CType(23), CType(24))); } } @@ -1636,9 +1878,10 @@ void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - 
std::vector output_data(4); - test->BuildAndRun({{"input", input}}, "my_binary", &output_data); - EXPECT_THAT(output_data, + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; + test->BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(11), CType(12), CType(13), CType(14))); } @@ -1686,17 +1929,19 @@ void TestBinaryTensorOpWeightFallback(OpConverterTest* test, // Check the result of running the engine. const int expected_num_outputs = TrtDimsNumElements(GetTestDims(expected_output_dims)); - std::vector output_data(expected_num_outputs); - test->BuildAndRun( - {{"input", - /*input_data=*/std::vector(num_inputs, CType(2))}}, - "my_binary", &output_data); + const DataVec input_data{ + {"input", ConstructTensor(num_inputs, CType(2))}}; + DataVec output_data{ + {"my_binary", ConstructTensor(expected_num_outputs)}}; + test->BuildAndRun(input_data, &output_data); if (node_def.op() == "Add") { - EXPECT_THAT(output_data, ElementsAreArray(std::vector( - expected_num_outputs, CType(3)))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(std::vector(expected_num_outputs, CType(3)))); } else if (node_def.op() == "Minimum") { - EXPECT_THAT(output_data, ElementsAreArray(std::vector( - expected_num_outputs, CType(1)))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(std::vector(expected_num_outputs, CType(1)))); } else { ASSERT_TRUE(false); } @@ -1723,51 +1968,64 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); - std::vector output_data(4); + const DataVec input_data{ + {"input1", test::AsTensor({CType(3), CType(6)})}, + {"input2", test::AsTensor({CType(2), CType(3)})}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 
2, 3}. - test->BuildAndRun( - {{"input1", {CType(3), CType(6)}}, {"input2", {CType(2), CType(3)}}}, - "my_binary", &output_data); + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); if (node_def.op() == "Add") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(5), CType(8), CType(6), CType(9))); } else if (node_def.op() == "Sub") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(1), CType(4), CType(0), CType(3))); } else if (node_def.op() == "Mul") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(6), CType(12), CType(9), CType(18))); } else if (node_def.op() == "Div") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); } else if (node_def.op() == "RealDiv") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); } else if (node_def.op() == "Minimum") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(2), CType(2), CType(3), CType(3))); } else if (node_def.op() == "Maximum") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(3), CType(6), CType(3), CType(6))); + } else if (node_def.op() == "Pow") { + ExpectArrayNear( + std::vector{CType(9), CType(36), CType(27), CType(216)}, + GetSpanForData(output_data[0])); } else { ASSERT_TRUE(false); } } TEST_F(OpConverterTest, ConvertBinary) { + AttrValue dtype; + dtype.set_type(DT_FLOAT); // Input size doesn't match, should fail. 
for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) { Reset(); - NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"}); + NodeDef node_def = + MakeNodeDef("my_add", "Add", {num_inputs, "input"}, {{"T", dtype}}); AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Binary ops require two inputs, at my_add"); + StrCat("Add got ", std::to_string(num_inputs), + " inputs but expected 2, at my_add") + .c_str()); } { // Both inputs are weights. Reset(); - NodeDef node_def = MakeNodeDef("my_add", "Add", {"weights1", "weights2"}); + NodeDef node_def = + MakeNodeDef("my_add", "Add", {"weights1", "weights2"}, {{"T", dtype}}); AddTestWeights("weights1", {1}, {1}); AddTestWeights("weights2", {1}, {1}); RunValidationAndConversion( @@ -1782,15 +2040,12 @@ TEST_F(OpConverterTest, ConvertBinary) { TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); -#if 0 - // TODO(b/119560144): it doesn't support FP16 constants and the following test - // will fail. + TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); -#endif // Test BinaryTensorOpWeight() with channel-wise broadcasting. 
TestBinaryTensorOpWeightWithChannelWiseBroadcast(this); @@ -1821,6 +2076,7 @@ TEST_F(OpConverterTest, ConvertBinary) { TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); @@ -1829,17 +2085,23 @@ TEST_F(OpConverterTest, ConvertBinary) { TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); } TEST_F(OpConverterTest, ConvertQuantize) { - for (const string& op : - {"FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars", - "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"}) { + precision_mode_to_test_ = TrtPrecisionMode::INT8; + const std::pair op_with_num_inputs[4] = { + {"FakeQuantWithMinMaxArgs", 1}, + {"FakeQuantWithMinMaxVars", 3}, + {"QuantizeAndDequantizeV2", 3}, + {"QuantizeAndDequantizeV3", 4}}; + for (const auto& pair : op_with_num_inputs) { // Input list is empty, should fail. - NodeDef node_def = MakeNodeDef("my_quantize", op, {}); + NodeDef node_def = MakeNodeDef("my_quantize", pair.first, {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - StrCat("Invalid number of inputs for ", op, ", at my_quantize") + StrCat(pair.first, " got 0 inputs but expected ", + std::to_string(pair.second), ", at my_quantize") .c_str()); } { @@ -1926,9 +2188,9 @@ TEST_F(OpConverterTest, ConvertQuantize) { AddTestTensor("weights_min", {1}); AddTestTensor("weights_max", {1}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Min and max inputs for QuantizeAndDequantizeV2 must be weights not " - "tensors, at my_quantize"); + node_def, error::UNIMPLEMENTED, + "The input \"input_min\" for QuantizeAndDequantizeV2 must be a constant" + ", at my_quantize"); } { // QuantizeAndDequantizeV3 ranges set via inputs, ok. 
@@ -1955,46 +2217,6 @@ TEST_F(OpConverterTest, ConvertQuantize) { } } -TEST_F(OpConverterTest, ConvertRelu6) { - { - // Input list is empty, should fail. - NodeDef node_def = MakeNodeDef("my_relu6", "Relu6", {}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Invalid number of inputs for Relu6, at my_relu6"); - } - - // Get the NodeDef for Relu6. - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto relu6 = ops::Relu6(s.WithOpName("my_relu6"), input); - const NodeDef node_def = relu6.operation.node()->def(); - { - // Input is weights, should fail. - Reset(); - AddTestWeights("input", {1}, {1.0f}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "Relu6 is only implemented for tensors, not weights, at my_relu6"); - } - { - // Clip tensor values and set quantization ranges, ok. - Reset(); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_relu6", &output)); - EXPECT_TRUE(output.is_tensor()); - auto ranges = quantization_ranges(); - EXPECT_EQ(ranges[output.tensor()], 6.0f); - - std::vector output_data(6); - BuildAndRun({{"input", {-100, -1, 0, 3, 5, 9}}}, "my_relu6", - &output_data); - EXPECT_THAT(output_data, ElementsAre(0, 0, 0, 3, 5, 6)); - } -} - template void TestConvertSquare(OpConverterTest* test) { test->Reset(); @@ -2005,7 +2227,8 @@ void TestConvertSquare(OpConverterTest* test) { auto square = ops::Square(s.WithOpName("my_square"), input); NodeDef node_def = square.operation.node()->def(); - test->AddTestTensor("input", {1, 20}); + test->AddTestTensor("input", {1, 20}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output)); @@ -2013,24 +2236,30 @@ void TestConvertSquare(OpConverterTest* test) { ExpectTrtDimsEqualsArray({1, 20}, 
output.tensor()->getDimensions()); const int num_inputs = 20; - std::vector input_data(num_inputs); - std::vector expected_output_data(num_inputs); - for (int i = 0; i < 20; i++) { + std::vector inputs(num_inputs); + std::vector expected_outputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { const CType value = CType(i - 9); - input_data[i] = value; - expected_output_data[i] = value * value; - } - std::vector output_data(num_inputs); - test->BuildAndRun({{"input", input_data}}, "my_square", &output_data); - ExpectArrayNear(expected_output_data, output_data); + inputs[i] = value; + expected_outputs[i] = value * value; + } + const DataVec input_data{{"input", test::AsTensor(inputs)}}; + // Engine outputs are converted to FP16 automatically if we set FP16 mode in + // the builder. + DataVec output_data{{"my_square", ConstructTensor(num_inputs)}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + ExpectArrayNear(expected_outputs, GetSpanForData(output_data[0])); } TEST_F(OpConverterTest, ConvertSquare) { { // Input list is empty, should fail. NodeDef node_def = MakeNodeDef("my_square", "Square", {}); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Square expects one input, at my_square"); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Square got 0 inputs but expected 1, at my_square"); } { // Input is weights, should fail. @@ -2042,15 +2271,13 @@ TEST_F(OpConverterTest, ConvertSquare) { AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Square is only implemented for tensors, at my_square"); + "The input \"x\" for Square must be a tensor, at my_square"); } // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't // test DT_INT32 type here. TestConvertSquare(this); - // TODO(tmorris): Looks like there may be a bug with this layer for FP16 - // inputs. 
Disabling for now. - // TestConvertSquare(this); + TestConvertSquare(this); } TEST_F(OpConverterTest, ConvertActivation) { @@ -2058,7 +2285,7 @@ TEST_F(OpConverterTest, ConvertActivation) { // Input list is empty, should fail. NodeDef node_def = MakeNodeDef("my_act", "Relu", {}); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Relu expects one input, at my_act"); + "Relu got 0 inputs but expected 1, at my_act"); } { // Input is weights, should fail. @@ -2070,16 +2297,26 @@ TEST_F(OpConverterTest, ConvertActivation) { AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Relu is only implemented for tensors, at my_act"); + "The input \"input\" for Relu must be a tensor, at my_act"); } + constexpr float kAlpha = 0.2f; + // Get nodedef for activation layer. auto get_act_nodedef = [](string op_name) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - if (op_name == "Relu") { + if (op_name == "LeakyRelu") { + auto act = + ops::internal::LeakyRelu(s.WithOpName("my_act"), input, + ops::internal::LeakyRelu::Alpha(kAlpha)); + return act.operation.node()->def(); + } else if (op_name == "Relu") { auto act = ops::Relu(s.WithOpName("my_act"), input); return act.operation.node()->def(); + } else if (op_name == "Relu6") { + auto act = ops::Relu6(s.WithOpName("my_act"), input); + return act.operation.node()->def(); } else if (op_name == "Sigmoid") { auto act = ops::Sigmoid(s.WithOpName("my_act"), input); return act.operation.node()->def(); @@ -2092,8 +2329,12 @@ TEST_F(OpConverterTest, ConvertActivation) { }; // Get expected output for activation layer. auto get_act_output = [](string op_name, float input) -> float { - if (op_name == "Relu") { + if (op_name == "LeakyRelu") { + return (input > 0.0f) ? input : input * kAlpha; + } else if (op_name == "Relu") { return (input > 0.0f) ? 
input : 0.0f; + } else if (op_name == "Relu6") { + return std::min(std::max(input, 0.0f), 6.0f); } else if (op_name == "Sigmoid") { return 1.0f / (1.0f + std::exp(-input)); } else if (op_name == "Tanh") { @@ -2104,7 +2345,8 @@ TEST_F(OpConverterTest, ConvertActivation) { }; // Ok. - for (string op_name : {"Relu", "Sigmoid", "Tanh"}) { + for (const string& op_name : + {"LeakyRelu", "Relu", "Relu6", "Sigmoid", "Tanh"}) { Reset(); NodeDef node_def = get_act_nodedef(op_name); AddTestTensor("input", {1, 2, 3}); @@ -2113,13 +2355,20 @@ TEST_F(OpConverterTest, ConvertActivation) { TF_EXPECT_OK(GetTensorOrWeights("my_act", &output)); EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + if (op_name == "Relu6") { + // Relu6 should set quantization range automatically. + auto ranges = quantization_ranges(); + EXPECT_EQ(ranges[output.tensor()], 6.0f); + } - const std::vector input_data = {-100, -2, -1, 0, 1, 100}; - std::vector output_data(6); - BuildAndRun({{"input", input_data}}, "my_act", &output_data); - for (int i = 0; i < input_data.size(); i++) { - const float expected_output = get_act_output(op_name, input_data[i]); - EXPECT_FLOAT_EQ(output_data[i], expected_output); + const std::vector input = {-100, -2, -1, 0, 1, 100}; + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_act", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + for (int i = 0; i < input.size(); i++) { + const float expected_output = get_act_output(op_name, input[i]); + EXPECT_FLOAT_EQ(GetSpanForData(output_data[0])[i], + expected_output); } } } @@ -2130,7 +2379,7 @@ TEST_F(OpConverterTest, ConvertExpandDims) { NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Two inputs expected for ExpandDims, at my_expanddims"); + "ExpandDims got 0 inputs but expected 2, at my_expanddims"); } // Get the NodeDef for ExpandDims. 
@@ -2145,18 +2394,18 @@ TEST_F(OpConverterTest, ConvertExpandDims) { Reset(); AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); AddTestWeights("weights", {1}, {1}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "ExpandDims expects tensor for input, at my_expanddims"); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"input\" for ExpandDims must be a " + "tensor, at my_expanddims"); } { // Axis is a tensor, should fail. Reset(); AddTestTensor("input", {1, 2, 3}); AddTestTensor("weights", {3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "ExpandDims expects weights for axis, at my_expanddims"); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"axis\" for ExpandDims must be a " + "constant, at my_expanddims"); } { // Add dim at batch dimension, should fail. @@ -2203,11 +2452,6 @@ TEST_F(OpConverterTest, ConvertExpandDims) { } struct TestParams { - TestParams(const std::vector& input_dims, int axis, - const std::vector& expected_output_dims) - : input_dims(input_dims), - axis(axis), - expected_output_dims(expected_output_dims) {} std::vector input_dims; int axis; std::vector expected_output_dims; @@ -2232,10 +2476,12 @@ TEST_F(OpConverterTest, ConvertExpandDims) { ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_expanddims", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_expanddims", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 2, 3, 4, 5, 6)); } } @@ -2243,8 +2489,9 @@ TEST_F(OpConverterTest, ConvertSqueeze) { { // Input list is empty, should fail. 
NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {}); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "One input expected for Squeeze, at my_squeeze"); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Squeeze got 0 inputs but expected 1, at my_squeeze"); } { // No attrs, should fail. @@ -2264,7 +2511,7 @@ TEST_F(OpConverterTest, ConvertSqueeze) { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); ops::Squeeze::Attrs squeeze_attrs; - squeeze_attrs.axis_ = gtl::ArraySlice(axis); + squeeze_attrs.axis_ = gtl::ArraySlice(axis); // non-absl ok auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); return squeeze.operation.node()->def(); @@ -2277,7 +2524,7 @@ TEST_F(OpConverterTest, ConvertSqueeze) { AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Squeeze expects tensor for input, at my_squeeze"); + "The input \"input\" for Squeeze must be a tensor, at my_squeeze"); } { // Squeeze batch dim, should fail. 
@@ -2317,11 +2564,6 @@ TEST_F(OpConverterTest, ConvertSqueeze) { } struct TestParams { - TestParams(const std::vector& input_dims, const std::vector& axis, - const std::vector& expected_output_dims) - : input_dims(input_dims), - axis(axis), - expected_output_dims(expected_output_dims) {} std::vector input_dims; std::vector axis; std::vector expected_output_dims; @@ -2352,10 +2594,12 @@ TEST_F(OpConverterTest, ConvertSqueeze) { ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_squeeze", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_squeeze", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 2, 3, 4, 5, 6)); } } @@ -2365,13 +2609,13 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "StridedSlice expects 4 inputs, at my_strided_slice"); + "StridedSlice got 0 inputs but expected 4, at my_strided_slice"); } // Get nodedef for StridedSlice layer. 
auto get_strided_slice_nodedef = - [](int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, - int new_axis_mask = 0, int shrink_axis_mask = 0) -> NodeDef { + [](int64 begin_mask = 0, int64 end_mask = 0, int64 ellipsis_mask = 0, + int64 new_axis_mask = 0, int64 shrink_axis_mask = 0) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); @@ -2396,9 +2640,9 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { AddTestWeights("begin", {4}, {0, 0, 0, 0}); AddTestWeights("end", {4}, {1, 1, 2, 3}); AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "StridedSlice is only implemented for tensors, at my_strided_slice"); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"input\" for StridedSlice must be a " + "tensor, at my_strided_slice"); } { // Begin, end, strides are tensors, should fail. @@ -2409,8 +2653,8 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { AddTestTensor("end", {4}); AddTestTensor("strides", {4}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "StridedSlice expects weights for begin, end, and strides, at " + node_def, error::UNIMPLEMENTED, + "The input \"begin\" for StridedSlice must be a constant, at " "my_strided_slice"); } { @@ -2438,46 +2682,62 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { AddTestWeights("strides", {4}, {1, 1, 1, 1}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "StridedSlice can't modify batch dim, at my_strided_slice"); + "TensorRT does not allow modifications to the batch dimension, at " + "my_strided_slice"); } { - // Stride is not 1, should fail. + // Dynamic batch size without end_mask, should fail. 
Reset(); NodeDef node_def = get_strided_slice_nodedef(); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); AddTestWeights("begin", {4}, {0, 0, 0, 0}); AddTestWeights("end", {4}, {1, 1, 2, 3}); - AddTestWeights("strides", {4}, {1, 2, -1, 3}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "StridedSlice is only implemented for stride of " - "1, at my_strided_slice"); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_strided_slice"); + } + { + // Dynamic batch size but using end_mask, ok. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0, + /*end_mask=*/1); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 2}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion(node_def); } +// TRT 5.1+ supports strides +#if IS_TRT_VERSION_GE(5, 1, 0) { - // Begin out of bounds, should fail. + // Negative strides, should fail. Reset(); NodeDef node_def = get_strided_slice_nodedef(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("begin", {4}, {1, 2, 3, 4}); - AddTestWeights("end", {4}, {0, 1, 2, 3}); - AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "begin value of 2 for StridedSlice is invalid, must be in the range " - "[-dim_size(i), dim_size(i)], at my_strided_slice"); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, -1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Negative or zero stride values are not " + "supported for StridedSlice, at " + "my_strided_slice"); } +#else { - // End out of bounds, should fail. + // Stride is not 1, should fail. 
Reset(); NodeDef node_def = get_strided_slice_nodedef(); AddTestTensor("input", {1, 2, 3}); AddTestWeights("begin", {4}, {0, 0, 0, 0}); - AddTestWeights("end", {4}, {1, 2, 3, 4}); - AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "end value of 2 for StridedSlice is invalid, must be in the range " - "[-dim_size(i), dim_size(i)], at my_strided_slice"); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 2, 1, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Strides other than 1 are not supported with " + "this version of TRT, at my_strided_slice"); } +#endif { // Size of sliced dim is negative, should fail. Reset(); @@ -2486,126 +2746,183 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { AddTestWeights("begin", {4}, {0, 0, 2, 0}); AddTestWeights("end", {4}, {1, 1, 0, 3}); AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "New size of sliced dimension is negative, at my_strided_slice"); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "\"size\" cannot be negative or zero for " + "StridedSlice, at my_strided_slice"); } struct TestParams { - TestParams(const std::vector& input_dims, - const std::vector& expected_output_dims, - const std::vector& begin, const std::vector& end, - const std::vector& begin_mask, - const std::vector& end_mask, - const std::vector& expected_output) - : input_dims(input_dims), - expected_output_dims(expected_output_dims), - begin(begin), - end(end), - expected_output(expected_output) { - // Masks are provided in terms of vectors for readability. Convert them to - // binary here. 
- this->begin_mask = 0; - for (int i = 0; i < begin_mask.size(); i++) { - if (begin_mask[i]) this->begin_mask |= (1 << i); - } - this->end_mask = 0; - for (int i = 0; i < end_mask.size(); i++) { - if (end_mask[i]) this->end_mask |= (1 << i); - } - } - std::vector input_dims; - std::vector expected_output_dims; std::vector begin; std::vector end; + std::vector strides; int begin_mask; int end_mask; - std::vector expected_output; + std::vector expected_output_dims; + std::vector expected_output; + }; + + auto get_mask = [](const std::vector& mask) { + int result = 0; + for (int i = 0; i < mask.size(); i++) { + if (mask[i]) result += (1 << i); + } + return result; }; + // Same input is used for all tests. + const std::vector ok_input = {1, 2, 3, 4, 5, 6}; + +#if IS_TRT_VERSION_GE(5, 1, 0) + const int kStridedSliceOKCases = 23; +#else + const int kStridedSliceOKCases = 19; +#endif // Ok. - const int kStridedSliceOKCases = 18; TestParams ok_params[kStridedSliceOKCases] = { - // 2D Crop. - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, - /*expected_output=*/{5, 6}}, - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, - /*expected_output=*/{5, 6}}, - // 2D Crop, with transpose. 
- TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, - /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, - /*expected_output=*/{5, 6}}, - TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, - /*expected_output=*/{5, 6}}, - // 2D Crop, with reshape. - TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, - /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 0}, - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, - /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 1}, - /*expected_output=*/{5, 6}}, - // 1D Crop. - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 0}, - /*expected_output=*/{1, 2, 4, 5}}, - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3}, - /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, - /*expected_output=*/{4, 5, 6}}, - // 1D Crop, with transpose. 
- TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 1, 1}, - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, - /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, - /*expected_output=*/{4, 5, 6}}, - // 1D Crop, with reshape. - TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3}, - /*begin=*/{0, 0}, /*end=*/{0, 3}, - /*begin_mask=*/{0, 0}, /*end_mask=*/{1, 0}, - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3}, - /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 0}, - /*expected_output=*/{3, 4, 5}}, - TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, - /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, - /*expected_output=*/{3, 4, 5}}, - // Negative axis. - TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, - /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1}, - /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, - /*expected_output=*/{1, 2, 3, 4, 5}}, + // 2D Crop. 
+ TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, 0, 0}, + /*end=*/{0, 0, 1, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0, 0}), + /*expected_output_dims=*/{1, 1, 2}, /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with transpose. + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{5, 6}}, + TestParams{ + /*input_dims=*/{2, 1, 3}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{2, 1, 3}, + /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with reshape. 
+ TestParams{/*input_dims=*/{2, 3}, + /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0}), + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3}, + /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1}), + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{5, 6}}, + // 1D Crop. + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 0}), /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 2, 4, 5}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 3}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with transpose. + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with reshape. 
+ TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 3}, /*strides=*/{1, 1}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{1, 6}, + /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0}), + /*expected_output_dims=*/{1, 3}, + /*expected_output=*/{3, 4, 5}}, + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{3, 4, 5}}, + // Negative axis. + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{5, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}}, + // Clamp out of bounds begin and end. 
+ TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, -9999, -9}, + /*end=*/{0, 1, 1000, 4}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}}, +#if IS_TRT_VERSION_GE(5, 1, 0) + // Strides + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 5}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 6}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 1}, /*end=*/{0, 6}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{2, 4, 6}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 2}, /*end=*/{0, 6}, /*strides=*/{1, 3}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{2}, + /*expected_output=*/{3, 6}}, +#endif }; for (int i = 0; i < kStridedSliceOKCases; i++) { @@ -2618,17 +2935,166 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { ok_params[i].begin); AddTestWeights("end", {static_cast(ok_params[i].end.size())}, ok_params[i].end); - std::vector strides(ok_params[i].input_dims.size(), 1); - AddTestWeights("strides", {static_cast(strides.size())}, - strides); + AddTestWeights("strides", + {static_cast(ok_params[i].strides.size())}, + ok_params[i].strides); RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); - std::vector output_data(ok_params[i].expected_output.size()); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_strided_slice", - &output_data); - EXPECT_THAT(output_data, 
ElementsAreArray(ok_params[i].expected_output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{{"input", test::AsTensor(ok_input)}}; + DataVec output_data{ + {"my_strided_slice", + ConstructTensor(ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertSlice) { + // Get nodedef for Slice layer. + auto get_slice_nodedef = []() -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto size = ops::Placeholder(s.WithOpName("size"), DT_INT32); + auto slice = ops::Slice(s.WithOpName("my_slice"), input, begin, size); + return slice.operation.node()->def(); + }; + + { + // Begin is below bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, -1, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" for dimension 2 in Slice is out of range, at my_slice"); + } + { + // Begin is above bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 3, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" for dimension 2 in Slice is out of range, at my_slice"); + } + { + // Size is below bounds, should fail. 
+ Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 2, -2}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" + \"size\" for dimension 3 in Slice is out of range, at " + "my_slice"); + } + { + // Size is above bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 3, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" + \"size\" for dimension 2 in Slice is out of range, at " + "my_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_slice"); + } + { + // Dynamic batch size with size[0] not -1, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_slice"); + } + { + // Dynamic batch size but using size[0] of -1, ok. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {-1, 1, 2, 2}); + RunValidationAndConversion(node_def); + } + + struct TestParams { + std::vector input_dims; + std::vector begin; + std::vector size; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Ok. 
+ const int kSliceOKCases = 5; + TestParams ok_params[kSliceOKCases] = { + TestParams{{1, 2, 3}, + {0, 0, 0, 0}, + {-1, -1, -1, -1}, + {1, 2, 3}, + {1, 2, 3, 4, 5, 6}}, + TestParams{ + {1, 2, 3}, {0, 0, 0, 0}, {1, 1, 2, 3}, {1, 2, 3}, {1, 2, 3, 4, 5, 6}}, + TestParams{ + {1, 2, 3}, {0, 0, 0, 0}, {1, -1, 2, 2}, {1, 2, 2}, {1, 2, 4, 5}}, + TestParams{{6}, {0, 1}, {1, 5}, {5}, {2, 3, 4, 5, 6}}, + TestParams{{6}, {0, 1}, {-1, 3}, {3}, {2, 3, 4}}, + }; + + for (int i = 0; i < kSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", + {static_cast(ok_params[i].begin.size())}, + ok_params[i].begin); + AddTestWeights("size", {static_cast(ok_params[i].size.size())}, + ok_params[i].size); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_slice", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_slice", ConstructTensor( + ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); } } @@ -2638,22 +3104,34 @@ TEST_F(OpConverterTest, ConvertConv2D) { NodeDef node_def = MakeNodeDef("my_conv2d", "Conv2D", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Two inputs are expected for Conv2D, at my_conv2d"); + "Conv2D got 0 inputs but expected 2, at my_conv2d"); } // Get nodedef for Conv2D layer. 
auto get_conv2d_nodedef = [](std::vector strides = {1, 1, 1, 1}, string padding = "SAME", - string data_format = "NCHW", - std::vector dilations = {1, 1, 1, 1}) -> NodeDef { + string data_format = "NCHW", std::vector dilations = {1, 1, 1, 1}, + bool is_conv2d_backprop_input = false) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); - ops::Conv2D::Attrs attrs = - ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); - auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, - padding, attrs); - return conv2d.operation.node()->def(); + if (is_conv2d_backprop_input) { + auto input_sizes = + ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32); + ops::Conv2DBackpropInput::Attrs attrs = ops::Conv2DBackpropInput::Attrs() + .DataFormat(data_format) + .Dilations(dilations); + auto conv2d = + ops::Conv2DBackpropInput(s.WithOpName("my_conv2d"), input_sizes, + filter, input, strides, padding, attrs); + return conv2d.operation.node()->def(); + } else { + ops::Conv2D::Attrs attrs = + ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); + auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, + strides, padding, attrs); + return conv2d.operation.node()->def(); + } }; { @@ -2664,7 +3142,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Conv2D is only implemented for tensors, not weights, at my_conv2d"); + "The input \"input\" for Conv2D must be a tensor, at my_conv2d"); } { // Filter is tensor, should fail. 
@@ -2674,7 +3152,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { AddTestTensor("weights", {3, 3, 1, 1}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Kernel for Conv2D must be constant weights, at my_conv2d"); + "The input \"filter\" for Conv2D must be a constant, at my_conv2d"); } { // Filter is not 4D, should fail. @@ -2719,6 +3197,19 @@ TEST_F(OpConverterTest, ConvertConv2D) { "Dilation rate must be 1 for batch and channel " "dimensions, at my_conv2d"); } + { + // Dilation + Conv2DBackpropInput, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 2, 1}, true); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + AddTestWeights("input_sizes", {4}, {1, 2, 3, 1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation with Conv2DBackpropInput " + "(conv2d_transpose) is not supported, " + "at my_conv2d"); + } { // Strides is not 4D, should fail. Reset(); @@ -2743,25 +3234,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { } struct TestParams { - TestParams(const std::vector& input_dims, - const std::vector& input, - const std::vector& filter_dims, - const std::vector& filter, - const std::vector& strides, const string& padding, - const string& data_format, const std::vector& dilations, - const std::vector& expected_output_dims, - const std::vector& expected_output) - : input_dims(input_dims), - input(input), - filter_dims(filter_dims), - filter(filter), - strides(strides), - padding(padding), - data_format(data_format), - dilations(dilations), - expected_output_dims(expected_output_dims), - expected_output(expected_output) {} - std::vector input_dims; std::vector input; std::vector filter_dims; @@ -2770,12 +3242,13 @@ TEST_F(OpConverterTest, ConvertConv2D) { string padding; string data_format; std::vector dilations; + bool is_conv2d_backprop_input; std::vector expected_output_dims; std::vector expected_output; }; // Ok. 
- const int kConv2DOKCases = 6; + const int kConv2DOKCases = 7; TestParams ok_params[kConv2DOKCases] = { // Basic TestParams{/*input_dims=*/{1, 2, 3}, @@ -2786,6 +3259,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 2}, /*expected_output=*/{1, 1, 0, 1}}, // SAME padding (Asymmetric) @@ -2797,6 +3271,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 3}, /*expected_output=*/{1, 1, -2, 0, 1, -4}}, // SAME padding (Symmetric) @@ -2808,6 +3283,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 3}, /*expected_output=*/{1, 2, -1, 3, 1, -3}}, // NHWC @@ -2819,6 +3295,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NHWC", /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{2, 2, 1}, /*expected_output=*/{1, 1, 0, 1}}, // Dilated @@ -2830,6 +3307,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 2}, + /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 1}, /*expected_output=*/{2, 1}}, // Strided @@ -2841,28 +3319,400 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 2}, /*expected_output=*/{1, 0, 1, 3}}, + // Transpose Strided + TestParams{/*input_dims=*/{1, 2, 2}, + /*input=*/{0, 1, 2, 3}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 2}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + 
/*is_conv2d_backprop_input=*/true, + /*expected_output_dims=*/{1, 2, 4}, + /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}}, }; for (int i = 0; i < kConv2DOKCases; i++) { Reset(); - NodeDef node_def = - get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, - ok_params[i].data_format, ok_params[i].dilations); + NodeDef node_def = get_conv2d_nodedef( + ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, + ok_params[i].dilations, ok_params[i].is_conv2d_backprop_input); AddTestTensor("input", ok_params[i].input_dims); AddTestWeights("weights", ok_params[i].filter_dims, ok_params[i].filter); + if (ok_params[i].is_conv2d_backprop_input) { + AddTestWeights( + "input_sizes", + {static_cast(ok_params[i].expected_output.size())}, + ok_params[i].expected_output); + } RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - std::vector output_data(ok_params[i].expected_output.size()); - BuildAndRun({{"input", ok_params[i].input}}, "my_conv2d", - &output_data); - EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output)); + + const DataVec input_data{ + {"input", test::AsTensor(ok_params[i].input)}}; + DataVec output_data{ + {"my_conv2d", + ConstructTensor(ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertTopK) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_topk", "TopKV2", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "TopKV2 got 0 inputs but expected 2, at my_topk"); + } + + for (const auto dtype : {DT_FLOAT, DT_INT32}) { + // Get the NodeDef for TopKV2. 
+ Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto topk = ops::TopK(s.WithOpName("my_topk"), input, weights); + const NodeDef& node_def = topk.operation.node()->def(); + { + // K is a tensor, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1, + /*trt_dtype=*/TfDataTypeToTrt(dtype)); + AddTestTensor("weights", {2}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"k\" for TopKV2 must be a constant, at my_topk"); + } + { + // Ok. + Reset(); + AddTestTensor("input", {1, 2, 5}); + AddTestWeights("weights", {1}, {2}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights outputs[2]; + TF_EXPECT_OK(GetTensorOrWeights("my_topk", &outputs[0])); + TF_EXPECT_OK(GetTensorOrWeights("my_topk:1", &outputs[1])); + for (auto& output : outputs) { + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 2, 2}, output.tensor()->getDimensions()); + } + + const DataVec input_data{ + {"input", test::AsTensor({-9, 3, 5, 1, 6, -5, 7, 1, 0, -1})}}; + DataVec output_data{{"my_topk", ConstructTensor(4)}, + {"my_topk:1", ConstructTensor(4)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(6, 5, 7, 1)); + EXPECT_THAT(GetSpanForData(output_data[1]), + ElementsAre(4, 2, 1, 2)); + } + } +} + +template +void TestConvertGather(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + // Get the NodeDef for GatherV2. 
+ Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), dtype); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + + struct TestParams { + std::vector params_dims; + std::vector indices_dims; + std::vector indices; + int axis; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Input is the same {1, 2, 3, 4, 5, 6} for all cases. + const int kGatherOKCases = 5; + const std::vector params_input = {CType(1), CType(2), CType(3), + CType(4), CType(5), CType(6)}; + TestParams ok_params[kGatherOKCases] = { + // Indices are always of rank>1, and output rank is + // rank(params) + rank(indices) - 1. + // TODO(laigd): do we support 0-rank ITensor as indices? + TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}}, + TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}}, + TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}}, + TestParams{ + {1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}}, + TestParams{{3, 2}, + {2, 2}, + {0, 0, 1, 0}, + 2, + {3, 1, 2, 2}, + {1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}}, + }; + + // Ok. + for (int i = 0; i < kGatherOKCases; i++) { + test->Reset(); + test->AddTestTensor("params", ok_params[i].params_dims, 1, + TfDataTypeToTrt(dtype)); + test->AddTestTensor("indices", ok_params[i].indices_dims, 1, + nvinfer1::DataType::kINT32); + test->AddTestWeights("axis", {1}, {ok_params[i].axis}); + test->RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + // Create input in CType and convert expected output to CType. 
+ std::vector converted_expected_output( + ok_params[i].expected_output.begin(), + ok_params[i].expected_output.end()); + + const DataVec input_data{ + {"params", test::AsTensor(params_input)}, + {"indices", test::AsTensor(ok_params[i].indices)}}; + DataVec output_data{ + {"my_gather", + ConstructTensor(ok_params[i].expected_output.size())}}; + test->BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(converted_expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertGather) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_gather", "GatherV2", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "GatherV2 got 0 inputs but expected 3, at my_gather"); + } + + // Get the NodeDef for GatherV2. + Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), DT_FLOAT); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + { + // Axis is a tensor, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestTensor("axis", {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"axis\" for GatherV2 must be a constant, at my_gather"); + } + { + // Axis is out of bounds, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {4}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in " + "range [-4, 4), at my_gather"); + } + { + // Axis is batch dimension, should fail. 
+ Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {0}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_gather"); + } + + Reset(); + TestConvertGather(this); + TestConvertGather(this); + TestConvertGather(this); +} + +TEST_F(OpConverterTest, ConvertUnary) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_unary", "Neg", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Neg got 0 inputs but expected 1, at my_unary"); + } + { + // Input is weights, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto neg = ops::Neg(s.WithOpName("my_unary"), input); + const NodeDef& node_def = neg.operation.node()->def(); + AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"x\" for Neg must be a tensor, at my_unary"); + } + + // Get nodedef for unary layer. 
+ auto get_unary_nodedef = [](string op_name) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + if (op_name == "Abs") { + auto unary = ops::Abs(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Acos") { + auto unary = ops::Acos(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Acosh") { + auto unary = ops::Acosh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Asin") { + auto unary = ops::Asin(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Asinh") { + auto unary = ops::Asinh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Atan") { + auto unary = ops::Atan(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Atanh") { + auto unary = ops::Atanh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Ceil") { + auto unary = ops::Ceil(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Cos") { + auto unary = ops::Cos(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Cosh") { + auto unary = ops::Cosh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Exp") { + auto unary = ops::Exp(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Floor") { + auto unary = ops::Floor(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Log") { + auto unary = ops::Log(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Neg") { + auto unary = ops::Neg(s.WithOpName("my_unary"), input); + return 
unary.operation.node()->def(); + } else if (op_name == "Reciprocal") { + auto unary = ops::Reciprocal(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Rsqrt") { + auto unary = ops::Rsqrt(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Sin") { + auto unary = ops::Sin(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Sinh") { + auto unary = ops::Sinh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Sqrt") { + auto unary = ops::Sqrt(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Tan") { + auto unary = ops::Tan(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } + EXPECT_TRUE(false); + return NodeDef(); + }; + // Get expected output for unary layer. + auto get_unary_output = [](string op_name, float input) -> float { + if (op_name == "Abs") { + return std::abs(input); + } else if (op_name == "Acos") { + return std::acos(input); + } else if (op_name == "Acosh") { + return std::acosh(input); + } else if (op_name == "Asin") { + return std::asin(input); + } else if (op_name == "Asinh") { + return std::asinh(input); + } else if (op_name == "Atan") { + return std::atan(input); + } else if (op_name == "Atanh") { + return std::atanh(input); + } else if (op_name == "Ceil") { + return std::ceil(input); + } else if (op_name == "Cos") { + return std::cos(input); + } else if (op_name == "Cosh") { + return std::cosh(input); + } else if (op_name == "Exp") { + return std::exp(input); + } else if (op_name == "Floor") { + return std::floor(input); + } else if (op_name == "Log") { + return std::log(input); + } else if (op_name == "Neg") { + return -input; + } else if (op_name == "Reciprocal") { + return 1.0 / input; + } else if (op_name == "Rsqrt") { + return 1.0 / std::sqrt(input); + } else if (op_name == "Sin") { 
+ return std::sin(input); + } else if (op_name == "Sinh") { + return std::sinh(input); + } else if (op_name == "Sqrt") { + return std::sqrt(input); + } else if (op_name == "Tan") { + return std::tan(input); + } + EXPECT_TRUE(false); + return 0; + }; + + // Get list of ops to test. + std::vector ops_to_test; + // Add all ops supported by ConvertUnary. + auto* map = UnaryOperationMap(); + ops_to_test.reserve(map->size()); + for (auto& pair : *map) { + ops_to_test.push_back(pair.first); + } + // Add other unary ops to test. + ops_to_test.push_back("Rsqrt"); + // Ok. + for (string op_name : ops_to_test) { + Reset(); + NodeDef node_def = get_unary_nodedef(op_name); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + + const std::vector input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_unary", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + for (int i = 0; i < input.size(); ++i) { + const float expected_output = get_unary_output(op_name, input[i]); + EXPECT_THAT(GetSpanForData(output_data[0])[i], + NanSensitiveFloatNear(expected_output, 0.0001)); + } } } diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc similarity index 82% rename from tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc rename to tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index d57f2300f8e6e6ce79c538133da6bc5cf5ead2f5..d325d11dfff54cc4e4d282bc513ee056fdf97271 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -12,9 +12,11 @@ See the License for the specific language 
governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h" -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" @@ -30,15 +32,15 @@ namespace tensorflow { namespace tensorrt { namespace convert { // TODO(sami): Remove VLOG messages once the code matures -using tensorflow::str_util::Uppercase; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; +using str_util::Uppercase; -tensorflow::Status TRTOptimizationPass::Init( - const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { +Status TRTOptimizationPass::Init( + const RewriterConfig_CustomGraphOptimizer* config) { VLOG(1) << "Called INIT for " << name_ << " with config = " << config; if (config == nullptr) { - return tensorflow::Status::OK(); + return Status::OK(); } const auto params = config->parameter_map(); if (params.count("minimum_segment_size")) { @@ -64,18 +66,20 @@ tensorflow::Status TRTOptimizationPass::Init( max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i(); } if (params.count("precision_mode")) { - TF_RETURN_IF_ERROR(GetPrecisionMode( + TF_RETURN_IF_ERROR(TrtPrecisionModeFromName( Uppercase(params.at("precision_mode").s()), &precision_mode_)); } if (params.count("use_calibration")) { use_calibration_ = params.at("use_calibration").b(); } - return tensorflow::Status::OK(); + 
if (params.count("use_function_backup")) { + use_function_backup_ = params.at("use_function_backup").b(); + } + return Status::OK(); } -void TRTOptimizationPass::PrintDebugInfo( - tensorflow::grappler::Cluster* cluster, - const tensorflow::grappler::GrapplerItem& item) { +void TRTOptimizationPass::PrintDebugInfo(grappler::Cluster* cluster, + const grappler::GrapplerItem& item) { LOG(INFO) << "Cluster = " << cluster; string offset(" "); string offset2 = StrCat(offset, offset); @@ -85,7 +89,7 @@ void TRTOptimizationPass::PrintDebugInfo( LOG(INFO) << offset << "type = " << cluster->type(); LOG(INFO) << offset << "num warmup steps = " << cluster->NumWarmupSteps(); const auto dev_names = cluster->GetDeviceNames(); - if (dev_names.size()) { + if (!dev_names.empty()) { LOG(INFO) << offset << " Device names:"; for (const auto s : dev_names) { LOG(INFO) << offset2 << s; @@ -93,7 +97,7 @@ void TRTOptimizationPass::PrintDebugInfo( } std::unordered_map peak_mem; auto status = cluster->GetPeakMemoryUsage(&peak_mem); - if (status == tensorflow::Status::OK()) { + if (status == Status::OK()) { LOG(INFO) << offset << "Peak Memory Usage :"; for (auto s : peak_mem) { LOG(INFO) << offset2 << s.first << " = " << s.second; @@ -101,7 +105,7 @@ void TRTOptimizationPass::PrintDebugInfo( } const auto dev_props = cluster->GetDevices(); - if (dev_props.size()) { + if (!dev_props.empty()) { LOG(INFO) << offset << "Device properties:"; for (auto k : dev_props) { LOG(INFO) << offset2 << k.first; @@ -129,7 +133,7 @@ void TRTOptimizationPass::PrintDebugInfo( } } LOG(INFO) << "item: " << item.id; - if (item.feed.size()) { + if (!item.feed.empty()) { LOG(INFO) << offset << "Feeds :"; for (const auto& f : item.feed) { const auto& shape = f.second.shape(); @@ -138,7 +142,7 @@ void TRTOptimizationPass::PrintDebugInfo( } else { LOG(INFO) << offset << "No Feeds"; } - if (item.fetch.size()) { + if (!item.fetch.empty()) { LOG(INFO) << offset << "Fetches :"; for (const auto& f : item.fetch) { LOG(INFO) << 
offset2 << f; @@ -147,7 +151,7 @@ void TRTOptimizationPass::PrintDebugInfo( LOG(INFO) << offset << "No Fetches"; } - if (item.init_ops.size()) { + if (!item.init_ops.empty()) { LOG(INFO) << offset << "init ops :"; for (const auto& f : item.init_ops) { LOG(INFO) << offset2 << f; @@ -158,7 +162,7 @@ void TRTOptimizationPass::PrintDebugInfo( LOG(INFO) << "Save Op = " << item.save_op; LOG(INFO) << "Restore Op = " << item.restore_op; LOG(INFO) << "save_restore_loc_tensor = " << item.save_restore_loc_tensor; - if (item.keep_ops.size()) { + if (!item.keep_ops.empty()) { LOG(INFO) << offset << "keep ops :"; for (const auto& f : item.keep_ops) { LOG(INFO) << offset2 << f; @@ -175,9 +179,9 @@ void TRTOptimizationPass::PrintDebugInfo( } } -tensorflow::Status TRTOptimizationPass::Optimize( - tensorflow::grappler::Cluster* cluster, - const tensorflow::grappler::GrapplerItem& item, GraphDef* optimized_graph) { +Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* optimized_graph) { VLOG(1) << "Called TRTOptimization Pass " << name_; // This is a hack to workaround optimizer issue. MetaOptimizer calls // optimization passes on function objects as well, we should not modify @@ -188,14 +192,14 @@ tensorflow::Status TRTOptimizationPass::Optimize( << " is probably called on funcdef! 
This optimizer must *NOT* " "be called on function objects."; *optimized_graph = item.graph; - return tensorflow::Status::OK(); + return Status::OK(); } if (VLOG_IS_ON(3)) { LOG(INFO) << CurrentStackTrace(); PrintDebugInfo(cluster, item); } int max_dim = -1; - if (item.feed.size()) { + if (!item.feed.empty()) { for (const auto& f : item.feed) { const auto& shape = f.second.shape(); if (shape.dims() > 0) { @@ -221,11 +225,11 @@ tensorflow::Status TRTOptimizationPass::Optimize( << " adjusting maximum batch size to match input batch size"; } } - tensorflow::grappler::GraphProperties static_graph_properties(item); + grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - tensorflow::tensorrt::convert::ConversionParams cp; + ConversionParams cp; - if (use_calibration_ && precision_mode_ != INT8MODE) { + if (use_calibration_ && precision_mode_ != TrtPrecisionMode::INT8) { VLOG(1) << "Calibration with FP32 or FP16 is not implemented. " << "Falling back to use_calibration = False." << "Note that the default value of use_calibration is True."; @@ -243,7 +247,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( // If the last token is not an integer, it must be part of the name. // Otherwise it is port number. 
if (tokens.size() > 1 && - !strings::safe_strto32(tokens.back(), &dumm_port)) { + !strings::safe_strto32(tokens.back(), &dumm_port)) { // non-absl ok StrAppend(&s, ":", tokens.back()); } nodes_to_preserve.push_back(s); @@ -261,27 +265,24 @@ tensorflow::Status TRTOptimizationPass::Optimize( cp.cached_engine_batches = batches_; cp.max_cached_engines = max_cached_batches_; cp.use_calibration = use_calibration_; - auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp); + cp.use_function_backup = use_function_backup_; + auto status = ConvertAfterShapes(cp); VLOG(1) << "Returning from " << name_; return status; } -void TRTOptimizationPass::Feedback( - tensorflow::grappler::Cluster* cluster, - const tensorflow::grappler::GrapplerItem& item, - const GraphDef& optimized_graph, double result) {} - -} // namespace convert -} // namespace tensorrt -} // namespace tensorflow +void TRTOptimizationPass::Feedback(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + const GraphDef& optimized_graph, + double result) {} class VerboseCustomGraphOptimizerRegistrar - : public tensorflow::grappler::CustomGraphOptimizerRegistrar { + : public grappler::CustomGraphOptimizerRegistrar { public: VerboseCustomGraphOptimizerRegistrar( - const tensorflow::grappler::CustomGraphOptimizerRegistry::Creator& cr, - const tensorflow::string& name) - : tensorflow::grappler::CustomGraphOptimizerRegistrar(cr, name) { + const grappler::CustomGraphOptimizerRegistry::Creator& cr, + const string& name) + : grappler::CustomGraphOptimizerRegistrar(cr, name) { VLOG(1) << "Constructing a CustomOptimizationPass registration object for " << name; } @@ -291,10 +292,13 @@ static VerboseCustomGraphOptimizerRegistrar TRTOptimizationPass_Registrar( []() { VLOG(1) << "Instantiating CustomOptimizationPass object TensorRTOptimizer"; - return new tensorflow::tensorrt::convert::TRTOptimizationPass( - "TensorRTOptimizer"); + return new TRTOptimizationPass("TensorRTOptimizer"); }, 
("TensorRTOptimizer")); +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + #endif #endif diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h similarity index 63% rename from tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h rename to tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index 3e8dc0978e43e2e9ba07aaa09f74acfe8e59b9a7..d3fd914b30210e10b211cfc6281964af620c1427 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ #include +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" #include "tensorflow/core/platform/logging.h" @@ -29,46 +30,49 @@ namespace tensorflow { namespace tensorrt { namespace convert { -class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { +class TRTOptimizationPass : public grappler::CustomGraphOptimizer { public: TRTOptimizationPass(const string& name = "TRTOptimizationPass") : name_(name), minimum_segment_size_(3), - precision_mode_(0), + precision_mode_(TrtPrecisionMode::FP32), maximum_batch_size_(-1), is_dynamic_op_(false), max_cached_batches_(1), max_workspace_size_bytes_(256LL << 20), - use_calibration_(true) { + use_calibration_(true), + use_function_backup_(true) { VLOG(1) << 
"Constructing " << name_; } string name() const override { return name_; }; - tensorflow::Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* - config = nullptr) override; + Status Init( + const RewriterConfig_CustomGraphOptimizer* config = nullptr) override; - tensorflow::Status Optimize(tensorflow::grappler::Cluster* cluster, - const tensorflow::grappler::GrapplerItem& item, - GraphDef* optimized_graph) override; + Status Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* optimized_graph) override; - void Feedback(tensorflow::grappler::Cluster* cluster, - const tensorflow::grappler::GrapplerItem& item, + void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item, const GraphDef& optimized_graph, double result) override; - void PrintDebugInfo(tensorflow::grappler::Cluster* cluster, - const tensorflow::grappler::GrapplerItem& item); + void PrintDebugInfo(grappler::Cluster* cluster, + const grappler::GrapplerItem& item); private: const string name_; int minimum_segment_size_; - int precision_mode_; + TrtPrecisionMode precision_mode_; int maximum_batch_size_; bool is_dynamic_op_; std::vector batches_; int max_cached_batches_; int64_t max_workspace_size_bytes_; bool use_calibration_; + + // Whether to allow TF function fallback path in TRTEngineOp. 
+ bool use_function_backup_; }; } // namespace convert @@ -77,4 +81,4 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { #endif // GOOGLE_CUDA #endif // GOOGLE_TENSORRT -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.cc similarity index 50% rename from tensorflow/contrib/tensorrt/test/utils.h rename to tensorflow/compiler/tf2tensorrt/convert/utils.cc index 4bb4120206cfaae70107e55d1818e3af2f02717a..ca21c193d6313ec1883788486f05d49f889a7145 100644 --- a/tensorflow/contrib/tensorrt/test/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -13,32 +13,43 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace tensorrt { -namespace test { - -// Helper methods to inject values used by testing tools. 
-void EnableTestValue(); -void ClearTestValues(const string& pattern); -void AddTestValue(const string& label, const string& value); -string GetTestValue(const string& label); - -#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return) \ - do { \ - if (::tensorflow::tensorrt::test::GetTestValue(label) == \ - value_to_return) { \ - return errors::Internal("Injected manually"); \ - } \ - } while (0) - -} // namespace test + +Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name) { + switch (mode) { + case TrtPrecisionMode::FP32: + *name = "FP32"; + break; + case TrtPrecisionMode::FP16: + *name = "FP16"; + break; + case TrtPrecisionMode::INT8: + *name = "INT8"; + break; + default: + return errors::OutOfRange("Unknown precision mode"); + } + return Status::OK(); +} + +Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) { + if (name == "FP32") { + *mode = TrtPrecisionMode::FP32; + } else if (name == "FP16") { + *mode = TrtPrecisionMode::FP16; + } else if (name == "INT8") { + *mode = TrtPrecisionMode::INT8; + } else { + return errors::InvalidArgument("Invalid precision mode name: ", name); + } + return Status::OK(); +} + } // namespace tensorrt } // namespace tensorflow - -#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h similarity index 69% rename from tensorflow/contrib/tensorrt/convert/utils.h rename to tensorflow/compiler/tf2tensorrt/convert/utils.h index 0592f31462af2b20f3a13fe5119e89c2ba42dd8a..91c8c660f85dcea9ad4d6b33a7c0fd979be0f819 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ #include @@ -33,18 +33,13 @@ struct TrtDestroyer { template using TrtUniquePtrType = std::unique_ptr>; -bool IsGoogleTensorRTEnabled(); +enum class TrtPrecisionMode { FP32, FP16, INT8 }; -// TODO(aaroey): use an enum instead. -const int FP32MODE = 0; -const int FP16MODE = 1; -const int INT8MODE = 2; +Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name); -Status GetPrecisionModeName(const int precision_mode, string* name); - -Status GetPrecisionMode(const string& name, int* precision_mode); +Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode); } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/kernels/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc similarity index 87% rename from tensorflow/contrib/tensorrt/kernels/get_serialized_resource_op.cc rename to tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc index f68bc2b48583904a2b5b7ef7139505b3c141c165..e252f9111d61dce0b0821f72b3c56f2516fc20f3 100644 --- a/tensorflow/contrib/tensorrt/kernels/get_serialized_resource_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ - #include #include -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -49,7 +46,7 @@ class GetSerializedResourceOp : public OpKernel { SerializableResourceBase* resource = nullptr; OP_REQUIRES_OK(context, context->resource_manager()->Lookup( container, resource_name, &resource)); - ::tensorflow::core::ScopedUnref sc(resource); + core::ScopedUnref sc(resource); // Serialize the resource as output. string serialized_resource; @@ -70,4 +67,3 @@ REGISTER_KERNEL_BUILDER(Name("GetSerializedResourceOp").Device(DEVICE_GPU), #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ diff --git a/tensorflow/contrib/tensorrt/kernels/get_serialized_resource_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc similarity index 97% rename from tensorflow/contrib/tensorrt/kernels/get_serialized_resource_op_test.cc rename to tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc index a91228e4c420ad241ab673a254364203e324a282..ec038ebda073c8050321d5668b15a2c6faa72a4b 100644 --- a/tensorflow/contrib/tensorrt/kernels/get_serialized_resource_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include #include -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc similarity index 70% rename from tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc rename to tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 7548c8ccda4571843f4b4792c9d97f7972963274..30f29902d73487de1a59d70518a7bab06448b578 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -12,40 +12,47 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" - #include +#include +#include #include "absl/memory/memory.h" -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" 
+#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { static Logger logger; +using absl::StrAppend; +using absl::StrCat; using ::nvinfer1::IRuntime; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; // A helper class to call done() when destructed for asynchronous execution. // Helps simultaneous execution of native and TRT engines. -class AsyncHelper : public tensorflow::core::RefCounted { +class AsyncHelper : public core::RefCounted { public: AsyncHelper(AsyncOpKernel::DoneCallback done) { done_ = done; } ~AsyncHelper() override { done_(); } @@ -54,37 +61,115 @@ class AsyncHelper : public tensorflow::core::RefCounted { AsyncOpKernel::DoneCallback done_; }; -#define TYPECASE(dt, X, Y) \ - case dt: { \ - return (void*)X->flat::Type>().data(); \ +// This OP can construct TRTEngine on the fly and if construction of engine +// fails, executes equivalent subgraph as a TensorFlow function. 
+class TRTEngineOp : public AsyncOpKernel { + public: + explicit TRTEngineOp(OpKernelConstruction* context); + + void ComputeAsync(OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + + private: + // Execute calibration + void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); + + // Construct a function handle for executing native funcdef graph + Status ConstructFunctionHandle(OpKernelContext* ctx); + + // Execute replaced native segment as function Op. + void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); + + // Execute the tensorrt engine. Returns whether we need to retry by running + // the native segment. + bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context); + + // Allocate necessary resources for calibration + Status AllocateCalibrationResources(OpKernelContext* ctx, + SerializableResourceBase** cr); + + // Get engine for the input shape + EngineContext* GetEngine(const std::vector& input_shapes, + OpKernelContext* ctx); + + // Return engine batch in cached_engne_batch_sizes_ which is closest to input + // batch. + bool GetCompatibleCachedEngine( + const std::vector& actual_input_shapes, + std::vector* engine_input_shapes); + + std::vector input_nodes_; + std::vector output_nodes_; + + // serialized protobuf segment or trt engine depending on static_engine_ flag. + string serialized_segment_; + + // Name of the function for TF native execution of the segment. If empty, it + // means TF native execution is not allowed, and if TRT engine fails to run + // an error will be returned. + string funcdef_name_; + + // GraphDef representation of the segment. + GraphDef segment_graph_; + + // Engine Precision mode. + TrtPrecisionMode precision_mode_; + + // Whether engine is constructed during the conversion or needs to be + // constructed from protobuf segment. + bool static_engine_; + + // Whether to calibrate INT8 engine. 
+ bool calibration_mode_; + + // Batches of the cached engines + std::vector cached_engine_batches_; + + // Maximum number of cached engines + int max_cached_engines_; + + int64 workspace_size_; + mutex engine_mutex_; + FunctionLibraryRuntime::Handle native_func_; + + // The finalized calibrator for inference. + std::unique_ptr calibrator_; + + // If true, create calibration graph for INT8 mode. Otherwise, we are using + // user-provided quantization ranges. + bool use_calibration_; +}; + +#define TYPECASE(dt, X, Y) \ + case dt: { \ + return (void*)X->flat::Type>().data(); \ } void* GetTensorAddress(const Tensor* tensor_ptr) { auto tensor_type = tensor_ptr->dtype(); switch (tensor_type) { - TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); + TYPECASE(DT_FLOAT, tensor_ptr, dest_ptr); + TYPECASE(DT_HALF, tensor_ptr, dest_ptr); + TYPECASE(DT_INT8, tensor_ptr, dest_ptr); default: { - LOG(ERROR) << "Unsupported Data type " - << tensorflow::DataTypeString(tensor_type); + LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type); return nullptr; } } } -tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) { +Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) { VLOG(1) << "Constructing function handle"; auto lib = ctx->function_library(); if (lib == nullptr) { - return tensorflow::errors::Internal("Context function library is null"); + return errors::Internal("Context function library is null"); } auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_); if (fdef == nullptr) { - return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_, - " can't be found in function library"); + return errors::Internal("Native FunctionDef ", funcdef_name_, + " can't be found in function library"); } - tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops; + 
FunctionLibraryRuntime::InstantiateOptions inst_ops; inst_ops.overlay_lib = nullptr; inst_ops.state_handle = ""; inst_ops.target = ctx->device()->name(); @@ -109,11 +194,15 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) if (!static_engine_) { if (!segment_graph_.ParseFromString(serialized_segment_)) { LOG(ERROR) << "Parsing segment graph failed!"; - context->SetStatus(tensorflow::errors::InvalidArgument( - "Failed to parse segment graphdef!")); + context->SetStatus( + errors::InvalidArgument("Failed to parse segment graphdef!")); return; } - serialized_segment_.resize(0); + VLOG(1) << "Size of serialized GraphDef: " + << serialized_segment_.capacity(); + string tmp; + // Swap with temporary empty string to deallocate the CPU memory. + serialized_segment_.swap(tmp); } VLOG(1) << "Constructing " << name(); string precision_string; @@ -124,16 +213,18 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) context->GetAttr("calibration_data", &calibration_data)); OP_REQUIRES_OK(context, context->GetAttr("segment_funcdef_name", &funcdef_name_)); - OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_)); + OP_REQUIRES_OK(context, + TrtPrecisionModeFromName(precision_string, &precision_mode_)); OP_REQUIRES_OK(context, context->GetAttr("use_calibration", &use_calibration_)); - calibration_mode_ = (use_calibration_ && precision_mode_ == INT8MODE && - calibration_data.size() == 0); - if (calibration_data.size()) { + calibration_mode_ = + (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 && + calibration_data.empty()); + if (!calibration_data.empty()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); } - native_func_ = tensorflow::kInvalidHandle; + native_func_ = kInvalidHandle; OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", &max_cached_engines_)); OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches", @@ -150,9 +241,15 @@ 
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper) { + if (funcdef_name_.empty()) { + const string err_msg = StrCat("Fallback path is disabled, for ", name()); + LOG(WARNING) << err_msg; + ctx->SetStatus(errors::Internal(err_msg)); + return; + } std::vector inputs; std::vector* outputs = new std::vector(); - if (native_func_ == tensorflow::kInvalidHandle) { + if (native_func_ == kInvalidHandle) { auto status = ConstructFunctionHandle(ctx); if (!status.ok()) { LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_; @@ -161,19 +258,20 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, } } auto lib = ctx->function_library(); - tensorflow::FunctionLibraryRuntime::Options opts; + FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); opts.rendezvous = ctx->rendezvous(); opts.cancellation_manager = ctx->cancellation_manager(); opts.runner = ctx->runner(); + inputs.reserve(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); i++) { inputs.push_back(ctx->input(i)); } helper->Ref(); // Increment count for calculating native graph VLOG(1) << "Executing native segment: " << name(); lib->Run(opts, native_func_, inputs, outputs, - [this, ctx, outputs, helper](const tensorflow::Status& s) { - tensorflow::core::ScopedUnref sc(helper); + [this, ctx, outputs, helper](const Status& s) { + core::ScopedUnref sc(helper); if (!s.ok()) { LOG(ERROR) << "Failed to execute native segment " << this->name() << ": " << s; @@ -184,8 +282,6 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, for (size_t t = 0; t < outputs->size(); ++t) { ctx->set_output(t, outputs->at(t)); } - test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"), - "done"); delete outputs; }); } @@ -194,20 +290,17 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper) { VLOG(1) << "Executing TRT calibration: " << name(); helper->Ref(); - 
tensorflow::core::ScopedUnref sc(helper); - // TODO(aaroey): remove the ResourceMgr singleton. - auto trt_rm = TRTResourceManager::instance(); - auto res_mgr = trt_rm->getManager("TRTCalibration"); + core::ScopedUnref sc(helper); + auto res_mgr = ctx->resource_manager(); TRTCalibrationResource* calib_res = nullptr; - auto status = res_mgr->LookupOrCreate( - funcdef_name_, "Calibrator", &calib_res, - {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status { - return this->AllocateCalibrationResources(ctx, cr); - }}); - if (!status.ok()) { - ctx->SetStatus(status); - return; - } + OP_REQUIRES_OK(ctx, + res_mgr->LookupOrCreate( + "TF_TRT_Calibration", name(), + reinterpret_cast(&calib_res), + {[ctx, this](SerializableResourceBase** cr) -> Status { + return this->AllocateCalibrationResources(ctx, cr); + }})); + core::ScopedUnref calib_sc(calib_res); int num_inputs = ctx->num_inputs(); // Pass input data to calibrator std::unordered_map input_data; @@ -215,7 +308,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, const Tensor& t = ctx->input(i); void* data_address = GetTensorAddress(&t); if (data_address == nullptr) { - ctx->SetStatus(tensorflow::errors::InvalidArgument( + ctx->SetStatus(errors::InvalidArgument( "Unsupported data type encountered in input ", i)); return; } @@ -233,7 +326,6 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, ->implementation() ->GpuStreamMemberHack())); calib_res->calibrator_->setBatch(input_data, *stream); - test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done"); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); } @@ -271,21 +363,22 @@ bool TRTEngineOp::GetCompatibleCachedEngine( void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, AsyncOpKernel::DoneCallback done) { auto helper = new AsyncHelper(done); - tensorflow::core::ScopedUnref sc(helper); + core::ScopedUnref sc(helper); if (calibration_mode_) { ExecuteCalibration(ctx, helper); return; } // Get shapes of inputs to 
engine. - std::vector input_shapes; + std::vector input_shapes; + input_shapes.reserve(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); ++i) { - input_shapes.emplace_back(ctx->input(i).shape()); + input_shapes.push_back(ctx->input(i).shape()); } EngineContext* engine_context = GetEngine(input_shapes, ctx); if (!engine_context->cuda_engine) { - LOG(WARNING) << "Engine retrieval for input shapes: " - << TensorShapeUtils::ShapeListString(input_shapes) - << " failed. Running native segment for " << name(); + VLOG(1) << "Engine retrieval for input shapes: " + << TensorShapeUtils::ShapeListString(input_shapes) + << " failed. Running native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } @@ -312,8 +405,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, const string input_name = StrCat(kInputPHName, i); const int binding_index = cuda_engine->getBindingIndex(input_name.c_str()); if (binding_index == -1) { - LOG(ERROR) << "Input node not found, at " << input_name; - return kRetry; + const string msg = + StrCat("Input node ", input_name, " not found, at ", name()); + LOG(ERROR) << msg; + ctx->SetStatus(errors::NotFound(msg)); + return !kRetry; } const Tensor& input_tensor = ctx->input(i); @@ -326,7 +422,8 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, auto dtype = cuda_engine->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: - buffers[binding_index] = (void*)(input_tensor.flat().data()); + buffers[binding_index] = + const_cast(input_tensor.flat().data()); break; case nvinfer1::DataType::kHALF: LOG(ERROR) << "FP16 inputs are not supported yet!"; @@ -335,10 +432,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, LOG(ERROR) << "INT8 inputs are not supported yet!"; return kRetry; case nvinfer1::DataType::kINT32: - buffers[binding_index] = (void*)(input_tensor.flat().data()); + buffers[binding_index] = + const_cast(input_tensor.flat().data()); break; default: - LOG(ERROR) << 
"Unknown TRT data type: " << int(dtype); + LOG(ERROR) << "Unknown TRT data type: " << static_cast(dtype); return kRetry; } } @@ -362,8 +460,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, return kRetry; } } else { - LOG(ERROR) << "Output node not found, at " << output_name; - return kRetry; + const string msg = + StrCat("Ouput node ", output_name, " not found, at ", name()); + LOG(ERROR) << msg; + ctx->SetStatus(errors::NotFound(msg)); + return !kRetry; } auto status = ctx->allocate_output(i, output_shape, &output_tensor); if (!status.ok()) { @@ -377,7 +478,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = - reinterpret_cast(output_tensor->flat().data()); + const_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: LOG(WARNING) << "half size is not supported yet!"; @@ -387,7 +488,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, return kRetry; case nvinfer1::DataType::kINT32: buffers[binding_index] = - reinterpret_cast(output_tensor->flat().data()); + const_cast(output_tensor->flat().data()); break; default: LOG(WARNING) << "Unknown TRT data type: " << static_cast(dtype); @@ -403,7 +504,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex // for it. - tensorflow::mutex_lock lock(engine_context->mu); + mutex_lock lock(engine_context->mu); // TODO(jie): trt enqueue does not return error auto ret = engine_context->execution_context->enqueue(num_batch, &buffers[0], *stream, nullptr); @@ -411,7 +512,6 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); return kRetry; } - test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done"); // Synchronization will be done by TF. 
return !kRetry; } @@ -419,15 +519,15 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* TRTEngineOp::GetEngine( const std::vector& input_shapes, OpKernelContext* ctx) { static EngineContext empty_context; - tensorflow::mutex_lock lock(engine_mutex_); + mutex_lock lock(engine_mutex_); // TODO(tmorris): using first input to get batch size - is this reliable? const int batch_size = input_shapes[0].dim_size(0); // Get engine cache TRTEngineCacheResource* cache_res = nullptr; auto status = ctx->resource_manager()->LookupOrCreate( - "TRTEngineCache", funcdef_name_, &cache_res, - {[this, ctx](TRTEngineCacheResource** cr) -> tensorflow::Status { + "TRTEngineCache", name(), &cache_res, + {[this, ctx](TRTEngineCacheResource** cr) -> Status { *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_); return Status::OK(); }}); @@ -435,7 +535,7 @@ EngineContext* TRTEngineOp::GetEngine( ctx->SetStatus(status); return &empty_context; } - tensorflow::core::ScopedUnref sc(cache_res); + core::ScopedUnref sc(cache_res); auto& cache = cache_res->cache_; auto allocator = cache_res->allocator_.get(); if (allocator == nullptr) { @@ -477,7 +577,11 @@ EngineContext* TRTEngineOp::GetEngine( TrtUniquePtrType( raw_static_engine->createExecutionContext()))); // Runtime is safe to delete after engine creation - serialized_segment_.clear(); + VLOG(1) << "Size of serialized TRT engine: " + << serialized_segment_.capacity(); + string tmp; + // Swap with temporary empty string to deallocate the CPU memory. + serialized_segment_.swap(tmp); if (max_batch_size < batch_size) { return &empty_context; } @@ -487,7 +591,7 @@ EngineContext* TRTEngineOp::GetEngine( // Handle the dynamic engine case. // See if there is a compatible engine cached. The batch size should be <= the // cached batch size. 
- std::vector engine_input_shapes; + std::vector engine_input_shapes; const bool matched_successfully = GetCompatibleCachedEngine(input_shapes, &engine_input_shapes); // If matched, use that engine. Otherwise, we will look in cache for that @@ -509,11 +613,11 @@ EngineContext* TRTEngineOp::GetEngine( LOG(INFO) << "Building a new TensorRT engine for " << name() << " input shapes: " << TensorShapeUtils::ShapeListString(engine_input_shapes); + // Convert to partial shapes - std::vector partial_shapes; - for (int i = 0; i < engine_input_shapes.size(); i++) { - partial_shapes.emplace_back(engine_input_shapes[i]); - } + std::vector partial_shapes(engine_input_shapes.begin(), + engine_input_shapes.end()); + // Up to this point, calibrator_ can never be empty, since otherwise it // means calibration_mode_ is true and this path won't get executed. auto status = convert::ConvertGraphDefToEngine( @@ -521,14 +625,12 @@ EngineContext* TRTEngineOp::GetEngine( partial_shapes, &logger, allocator, calibrator_.get(), &engine, use_calibration_, &convert_successfully); if (!status.ok()) { - if (convert_successfully) { - // This means it fail to build the engine even when the network is built - // successfully, probably due to internal issues. In this case we don't - // retry in the future. - cache.emplace(engine_input_shapes, absl::make_unique()); - } - LOG(WARNING) << "Engine creation for batch size " << batch_size - << " failed " << status; + LOG(WARNING) << "Engine creation for " << name() << " failed. " + << "The native segment will be used instead. " + << "Reason: " << status; + // Store an empty engine in the cache for these input shapes so we don't + // try to build the same failing engine again. 
+ cache.emplace(engine_input_shapes, absl::make_unique()); return &empty_context; } VLOG(1) << "Conversion is done"; @@ -541,12 +643,12 @@ EngineContext* TRTEngineOp::GetEngine( return cache.at(engine_input_shapes).get(); } -tensorflow::Status TRTEngineOp::AllocateCalibrationResources( - OpKernelContext* ctx, TRTCalibrationResource** cr) { +Status TRTEngineOp::AllocateCalibrationResources( + OpKernelContext* ctx, SerializableResourceBase** cr) { auto cres = new TRTCalibrationResource(); *cr = cres; // Get the allocator. - auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + auto alloc = ctx->device()->GetAllocator(AllocatorAttributes()); if (!alloc) { LOG(WARNING) << "Can't get device allocator will not be able to " "allocate memory from TensorFlow memory pool"; @@ -557,12 +659,12 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( // Get the input shapes. const int batch_size = ctx->input(0).dim_size(0); const int num_inputs = ctx->num_inputs(); - std::vector shapes; + std::vector shapes; cres->device_tensors_.resize(num_inputs); VLOG(1) << " Constructing calibrator"; for (int i = 0; i < num_inputs; i++) { // allocate workspace on device for inputs - const tensorflow::Tensor& t = ctx->input(i); + const Tensor& t = ctx->input(i); shapes.emplace_back(t.shape()); Tensor* device_tensor; TF_RETURN_IF_ERROR(ctx->allocate_persistent( @@ -570,7 +672,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); void* device_address = GetTensorAddress(device_tensor); if (device_address == nullptr) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Unsupported data type encountered in input ", i); } cres->device_buffers_.emplace( @@ -585,7 +687,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( ctx->device()->tensorflow_gpu_device_info()->gpu_id; if (platform_gpu_id < 0) { LOG(ERROR) << "Can't get gpu_device_info from 
context->device()"; - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Context->device doesn't contain device info!"); } const int64 workspace_size_bytes = workspace_size_; @@ -607,9 +709,10 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( // TODO(aaroey): maybe setting the max batch size using the python // calibration wrapper class. auto s = convert::ConvertGraphDefToEngine( - *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(), - workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), - cres->calibrator_.get(), &cres->engine_, + *segment_graph, TrtPrecisionMode::INT8, + cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes, + &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(), + &cres->engine_, /*use_calibration=*/true, /*convert_successfully=*/nullptr); if (!s.ok()) { @@ -619,7 +722,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( VLOG(1) << "Calibration loop terminated " << label; })); VLOG(1) << "initialized calibrator resource"; - return tensorflow::Status::OK(); + return Status::OK(); } REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); diff --git a/tensorflow/contrib/tensorrt/ops/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc similarity index 100% rename from tensorflow/contrib/tensorrt/ops/get_serialized_resource_op.cc rename to tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc similarity index 86% rename from tensorflow/contrib/tensorrt/ops/trt_engine_op.cc rename to tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index b84d2fe0b8cef3475f2a7d0f5383d5e11cde099a..791ddc41b4ff7cadb80618a1f017d8af32c862df 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -24,12 +24,9 
@@ limitations under the License. namespace tensorflow { -namespace shape_inference { -extern Status TRTEngineOpShapeInference(InferenceContext* c); -} - -// NOTE: please try NOT to add/modify/remove attributes or inputs/outputs to the -// list below, this will break backward compatibility! +// NOTE: when making changes please follow +// https://www.tensorflow.org/guide/extend/op#backwards_compatibility to not +// break backward compatibility. // // TODO(laigd): consider making this op stateful. The only problem is it uses TF // function which has to be stateless, but we can use function library as the @@ -41,8 +38,6 @@ REGISTER_OP("TRTEngineOp") .Attr("segment_funcdef_name: string") .Attr("InT: list({int8,float16,float32,int32})") .Attr("OutT: list({int8,float16,float32,int32})") - .Attr("static_engine: bool = true") - .Attr("fixed_input_size: bool = true") .Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("max_cached_engines_count: int = 1") .Attr("workspace_size_bytes: int") @@ -57,8 +52,10 @@ REGISTER_OP("TRTEngineOp") // implementation, we do require all input tensor to carry the same batch // size, but this could change in the future). Hence we disable shape // inference function as a workaround. - // .SetShapeFn(shape_inference::TRTEngineOpShapeInference); - .SetShapeFn(shape_inference::UnknownShape); + .SetShapeFn(shape_inference::UnknownShape) + // Deprecated attributes. 
+ .Attr("fixed_input_size: bool = true") + .Attr("static_engine: bool = true"); } // namespace tensorflow #endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc index 062f86e8bb4dc753925e4e2baf0bc80a5312a94f..a4341c530fffca88c82813cc2ace2c0ae1df5345 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" + #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h similarity index 92% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h index 754920b60ca7439513a91ad0354833a2482b29c1..f495d857037c79a1783f8eb232fb57c20e229169 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ #include #include @@ -71,4 +71,4 @@ class PluginTensorRT : public nvinfer1::IPlugin { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc similarity index 89% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc index cccc91226265ed139fb8db0b71c40b868f729562..dd73d15029d6fe5515c823223ffe743e52dde6e9 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -33,7 +33,7 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, return nullptr; } - tensorflow::mutex_lock lock(instance_m_); + mutex_lock lock(instance_m_); auto plugin_ptr = plugin_registry_[encoded_op_name].first(serial_data, serial_length); owned_plugins_.emplace_back(plugin_ptr); @@ -44,7 +44,7 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string& op_name) { if (!IsPlugin(op_name)) return nullptr; - tensorflow::mutex_lock lock(instance_m_); + mutex_lock lock(instance_m_); auto plugin_ptr = plugin_registry_[op_name].second(); owned_plugins_.emplace_back(plugin_ptr); @@ -56,7 +56,7 @@ bool PluginFactoryTensorRT::RegisterPlugin( PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; - tensorflow::mutex_lock lock(instance_m_); + mutex_lock lock(instance_m_); auto ret = plugin_registry_.emplace( op_name, std::make_pair(deserialize_func, construct_func)); @@ -64,7 +64,7 @@ bool PluginFactoryTensorRT::RegisterPlugin( } void PluginFactoryTensorRT::DestroyPlugins() { - tensorflow::mutex_lock lock(instance_m_); + mutex_lock lock(instance_m_); owned_plugins_.clear(); } diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h similarity index 85% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h index bbae9fb65c22cf69d2e7954436fd04dd16f7f6c8..cce4f52d9f1080fe0174b5fcb5dd0afdaf6e7769 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h @@ 
-13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -69,7 +69,7 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { // TODO(jie): Owned plugin should be associated with different sessions; // should really hand ownership of plugins to resource management; std::vector> owned_plugins_; - tensorflow::mutex instance_m_; + mutex instance_m_; }; class TrtPluginRegistrar { @@ -89,9 +89,8 @@ class TrtPluginRegistrar { construct_func) \ REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) #define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \ - static ::tensorflow::tensorrt::TrtPluginRegistrar trt_plugin_registrar##ctr \ - TF_ATTRIBUTE_UNUSED = ::tensorflow::tensorrt::TrtPluginRegistrar( \ - name, deserialize_func, construct_func) + static TrtPluginRegistrar trt_plugin_registrar##ctr TF_ATTRIBUTE_UNUSED = \ + TrtPluginRegistrar(name, deserialize_func, construct_func) } // namespace tensorrt } // namespace tensorflow @@ -99,4 +98,4 @@ class TrtPluginRegistrar { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#endif // 
TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc index 129bdcdbc2f8d9d5215f45f381bcadf35e4fa75e..7d9c465c22beed0e252cbc26d6c533a0789d4f49 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc similarity index 94% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc index a8f60886c03c174a612e7a135b6eb7bb7cb9997a..f3d6b4ff476139693a5251ddf58a3200d8af8efc 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #include #if GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h similarity index 82% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h index 274ce42fec9283c643004d45fba461879fc5f2dc..e5eff15c19694093c7a5ea933a41375e8e01c8b9 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA @@ -43,4 +43,4 @@ string ExtractOpName(const void* serial_data, size_t serial_length, #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ diff --git a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..92aae7bb6b4520be97ca70fbe99586aab912e598 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py @@ -0,0 +1,71 @@ +# 
Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Exposes the Python wrapper of TRTEngineOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +import platform +from tensorflow.python.framework import errors + +_tf_trt_so = None +_module_lock = threading.Lock() + + +def load_trt_ops(): + """Load TF-TRT op libraries so if it hasn't been loaded already.""" + global _tf_trt_so + + if platform.system() == "Windows": + raise RuntimeError("Windows platforms are not supported") + + with _module_lock: + if _tf_trt_so: + return + + try: + # pylint: disable=g-import-not-at-top,unused-variable + # This will call register_op_list() in + # tensorflow/python/framework/op_def_registry.py, but it doesn't register + # the op or the op kernel in C++ runtime. + from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op + # pylint: enable=g-import-not-at-top,unused-variable + except ImportError as e: + print("**** Failed to import TF-TRT ops. This is because the binary was " + "not built with CUDA or TensorRT enabled. 
****") + raise e + + try: + # pylint: disable=g-import-not-at-top + from tensorflow.python.framework import load_library + from tensorflow.python.platform import resource_loader + # pylint: enable=g-import-not-at-top + + # Loading the shared object will cause registration of the op and the op + # kernel if we link TF-TRT dynamically. + _tf_trt_so = load_library.load_op_library( + resource_loader.get_path_to_datafile("libtftrt.so")) + except errors.NotFoundError as e: + no_trt_message = ( + "**** Failed to initialize TensorRT. This is either because the " + "TensorRT installation path is not in LD_LIBRARY_PATH, or because " + "you do not have it installed. If not installed, please go to " + "https://developer.nvidia.com/tensorrt to download and install " + "TensorRT ****") + print(no_trt_message) + raise e diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc similarity index 89% rename from tensorflow/contrib/tensorrt/segment/segment.cc rename to tensorflow/compiler/tf2tensorrt/segment/segment.cc index ecaffa3023bc8f317d956181b44639bc80efda29..593b991d09494fac503022f5c016274175f2b250 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include #include #include #include -#include "tensorflow/contrib/tensorrt/segment/union_find.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -29,13 +30,16 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + namespace tensorflow { namespace tensorrt { namespace segment { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; -// A simple graph representation to mirror tensorflow::Graph. This structure +// A simple graph representation to mirror Graph. This structure // helps saving memory since segmenter modifies the graph in place, preventing // the need to create a copy of the graph. It is composed of edges and nodes. // Nodes keep pointers to original TF nodes. @@ -71,7 +75,7 @@ class SimpleEdge { class SimpleNode { public: - SimpleNode(const tensorflow::Node* node, const int id); + SimpleNode(const Node* node, const int id); const std::vector& in_edges() const { return in_edges_; } const std::vector& out_edges() const { return out_edges_; } @@ -95,11 +99,11 @@ class SimpleNode { } const string& name() const { return node_->name(); } - const tensorflow::Node* tf_node() const { return node_; } + const Node* tf_node() const { return node_; } int id() const { return id_; } private: - const tensorflow::Node* node_; + const Node* node_; std::vector in_edges_; std::vector out_edges_; int id_; @@ -109,7 +113,7 @@ class SimpleNode { class SimpleGraph { public: - explicit SimpleGraph(const tensorflow::Graph* g); + explicit SimpleGraph(const Graph* g); ~SimpleGraph(); void AddControlEdge(SimpleNode* src, SimpleNode* dst); @@ -122,15 +126,11 @@ class SimpleGraph { return nodes_[node_id]; } int num_node_ids() const { return nodes_.size(); } - const SimpleNode* source_node() const { - return nodes_[tensorflow::Graph::kSourceId]; - } - const SimpleNode* sink_node() const { - return nodes_[tensorflow::Graph::kSinkId]; - } + const SimpleNode* source_node() const { return nodes_[Graph::kSourceId]; } + const SimpleNode* sink_node() const { return nodes_[Graph::kSinkId]; } private: - 
const tensorflow::Graph* g_; + const Graph* g_; std::vector nodes_; std::vector edges_; // free_edge_ids_ and free_node_ids_ contain freed indices. @@ -138,15 +138,14 @@ class SimpleGraph { std::set free_node_ids_; }; -SimpleNode::SimpleNode(const tensorflow::Node* node, const int id) - : node_(node), id_(id) { +SimpleNode::SimpleNode(const Node* node, const int id) : node_(node), id_(id) { if (node_) { in_edges_.reserve(node_->in_edges().size()); out_edges_.reserve(node_->out_edges().size()); } } -SimpleGraph::SimpleGraph(const tensorflow::Graph* g) : g_(g) { +SimpleGraph::SimpleGraph(const Graph* g) : g_(g) { int n_nodes = g_->num_node_ids(); nodes_.resize(n_nodes, nullptr); nodes_[g->kSourceId] = new SimpleNode(g->source_node(), g->kSourceId); @@ -190,8 +189,8 @@ void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, } else { edges_.push_back(nullptr); } - bool is_control = (out_port == tensorflow::Graph::kControlSlot); - is_control |= (in_port == tensorflow::Graph::kControlSlot); + bool is_control = (out_port == Graph::kControlSlot); + is_control |= (in_port == Graph::kControlSlot); auto edge = new SimpleEdge(i, src, out_port, dst, in_port, is_control); edges_[i] = edge; src->out_edges_.push_back(edge); @@ -199,8 +198,7 @@ void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, } void SimpleGraph::AddControlEdge(SimpleNode* src, SimpleNode* dst) { - AddEdge(src, tensorflow::Graph::kControlSlot, dst, - tensorflow::Graph::kControlSlot); + AddEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot); } void SimpleGraph::RemoveEdge(const SimpleEdge* edge) { @@ -237,15 +235,14 @@ struct SimpleEdgePtrCompare { }; struct NodePtrCompare { - bool operator()(const tensorflow::Node* lhs, - const tensorflow::Node* rhs) const { + bool operator()(const Node* lhs, const Node* rhs) const { return lhs->name() < rhs->name(); } }; namespace { -// Copied from TF ReverseDFS, which only works for tensorflow::Graph. 
+// Copied from TF ReverseDFS, which only works for Graph. void StableDFS(const SimpleGraph& g, bool reverse, const std::vector& start, const std::function& enter, @@ -367,8 +364,7 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, if (in_edge->src() != src) { SimpleEdge* e = const_cast(in_edge); if (e->src() == graph->source_node()) { - graph->AddEdge(e->src(), e->src_output(), src, - tensorflow::Graph::kControlSlot); + graph->AddEdge(e->src(), e->src_output(), src, Graph::kControlSlot); } else { graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */); } @@ -387,8 +383,7 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, if (e->dst() == graph->sink_node()) { VLOG(1) << " edge to sink node " << src->name() << " -> " << e->dst()->name(); - graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(), - e->dst_input()); + graph->AddEdge(src, Graph::kControlSlot, e->dst(), e->dst_input()); } else { graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input()); } @@ -406,12 +401,12 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, } } -tensorflow::Status SegmentGraph( - const tensorflow::Graph* tf_graph, - const std::function& candidate_fn, - const std::function& input_candidate_fn, - const std::function& output_candidate_fn, - const SegmentOptions& options, SegmentNodesVector* segments) { +Status SegmentGraph(const Graph* tf_graph, + const std::function& candidate_fn, + const std::function& input_candidate_fn, + const std::function& output_candidate_fn, + const SegmentOptions& options, + SegmentNodesVector* segments) { // Steps: // 1. run the segmentation algorithm to find all the segments, which uses // candidate_fn to determine the candidates segment nodes; @@ -548,7 +543,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. 
- std::map> sg_map; + std::map> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -574,7 +569,7 @@ tensorflow::Status SegmentGraph( device_maps[u.ParentValue()->name()].insert( tf_node->requested_device()); } else { - VLOG(1) << "Node " << tf_node->name() + VLOG(2) << "Node " << tf_node->name() << " has no device assigned requested device is: " << tf_node->requested_device(); } @@ -584,17 +579,16 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 2 --------------------------------- // Remove ineligible input/output nodes. for (auto& itr : sg_map) { - std::set& segment_nodes = - itr.second; + std::set& segment_nodes = itr.second; VLOG(1) << "Segment original size: " << segment_nodes.size(); while (true) { - std::deque in_nodes_que, out_nodes_que; + std::deque in_nodes_que, out_nodes_que; // Find an input node that is not eligible and add it to the queue. // Nodes that has no incoming edges should not be treated as "input", // as there are really no inputs to them. Similar for output nodes. for (auto node : segment_nodes) { bool added = false; - for (const tensorflow::Edge* edge : node->in_edges()) { + for (const Edge* edge : node->in_edges()) { if (!edge->IsControlEdge() && !edge->src()->IsSource() && !segment_nodes.count(edge->src())) { // 'node' is an input node. if (!input_candidate_fn(edge)) { @@ -605,7 +599,7 @@ tensorflow::Status SegmentGraph( } } if (added) continue; // Only adding the node once to either queue. - for (const tensorflow::Edge* edge : node->out_edges()) { + for (const Edge* edge : node->out_edges()) { if (!edge->dst()->IsSink() && !edge->IsControlEdge() && !segment_nodes.count(edge->dst())) { // 'node' is an output node. if (!output_candidate_fn(edge)) { @@ -633,13 +627,11 @@ tensorflow::Status SegmentGraph( // remove all their inputs, and for non-const output nodes remove all // their outputs. 
In this way, for common cases the number of removed // nodes should be minimum. - auto remove_nodes = [&segment_nodes]( - bool is_input_nodes, - std::deque* que) { + auto remove_nodes = [&segment_nodes](bool is_input_nodes, + std::deque* que) { // Run a BFS on the queue to find all the input/output nodes. - std::set visited; - std::set logged(que->begin(), - que->end()); + std::set visited; + std::set logged(que->begin(), que->end()); while (!que->empty()) { auto node = que->front(); que->pop_front(); @@ -676,10 +668,13 @@ tensorflow::Status SegmentGraph( const string& segment_root = itr.first; // Return format does not require set comparator. std::set segment_nodes(itr.second.begin(), itr.second.end()); - if (VLOG_IS_ON(1)) { - string s = "parent=" + segment_root + ":"; - for (auto node : segment_nodes) s += " " + node->name(); - VLOG(1) << "Segment " << segments->size() << ": " << s; + if (VLOG_IS_ON(1) && !segment_nodes.empty()) { + string s; + for (auto node : segment_nodes) { + StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name()); + } + VLOG(1) << "Nodes in segment " << segments->size() + << " with parent=" << segment_root << ":" << s; } // Don't use small segments. 
@@ -718,9 +713,12 @@ tensorflow::Status SegmentGraph( VLOG(1) << "Devices " << s; } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace segment } // namespace tensorrt } // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h similarity index 72% rename from tensorflow/contrib/tensorrt/segment/segment.h rename to tensorflow/compiler/tf2tensorrt/segment/segment.h index 6cc92cdb5df396a6bca26119f152487bc3685a6d..e31f1a989d9d9f203554811093e830ee8b139a6e 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ #include #include @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" -namespace tensorflow { +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +namespace tensorflow { namespace tensorrt { namespace segment { @@ -42,22 +44,25 @@ struct SegmentOptions { // Get the subgraphs of a graph that can be handled by TensorRT. // -// @param graph tensorflow::Graph of the network +// @param graph Graph of the network // @param candidate_fn A function that returns OK for a Node* if // that node can be handled by TensorRT. // @param segments Returns the TensorRT segments/subgraphs. Each entry // in the vector describes a subgraph by giving a set of the names of // all the NodeDefs in that subgraph. // @return the status. 
-tensorflow::Status SegmentGraph( - const tensorflow::Graph* tf_graph, - const std::function& candidate_fn, - const std::function& input_candidate_fn, - const std::function& output_candidate_fn, - const SegmentOptions& options, SegmentNodesVector* segments); +Status SegmentGraph(const Graph* tf_graph, + const std::function& candidate_fn, + const std::function& input_candidate_fn, + const std::function& output_candidate_fn, + const SegmentOptions& options, + SegmentNodesVector* segments); } // namespace segment } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc similarity index 91% rename from tensorflow/contrib/tensorrt/segment/segment_test.cc rename to tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index 4ac02327ae68069278066b6e7e931bb9449c2603..84b690ecba6fcb9718a1008ee61383a84a381a46 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -26,17 +26,19 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + namespace tensorflow { namespace tensorrt { namespace segment { namespace test { -namespace ops = ::tensorflow::ops; class SegmentTest : public ::testing::Test { protected: - std::function MakeCandidateFn( + std::function MakeCandidateFn( const std::set& node_names) { - return [node_names](const tensorflow::Node* node) -> Status { + return [node_names](const Node* node) -> Status { if (node_names.find(node->name()) != node_names.end()) { return Status::OK(); } @@ -44,22 +46,21 @@ class SegmentTest : public ::testing::Test { }; } - std::function MakeInputEdgeCandidateFn( + std::function MakeInputEdgeCandidateFn( const std::set& node_names) { - return [node_names](const tensorflow::Edge* in_edge) -> bool { + return [node_names](const Edge* in_edge) -> bool { return node_names.find(in_edge->dst()->name()) != node_names.end(); }; } - std::function MakeOutputEdgeCandidateFn( + std::function MakeOutputEdgeCandidateFn( const std::set& node_names) { - return [node_names](const tensorflow::Edge* out_edge) -> bool { + return [node_names](const Edge* out_edge) -> bool { return node_names.find(out_edge->src()->name()) != node_names.end(); }; } - void RunTest(const tensorflow::Graph* graph, - const std::set& candidates, + void RunTest(const Graph* graph, const std::set& candidates, const std::set& input_candidates, const std::set& output_candidates, const std::vector>& expected_segments) { @@ -103,7 +104,7 @@ std::set operator-(const std::set& lhs, const string& rhs) { TEST_F(SegmentTest, Empty) { Scope s = Scope::NewRootScope(); - tensorflow::Graph g(OpRegistry::Global()); + Graph g(OpRegistry::Global()); TF_EXPECT_OK(s.ToGraph(&g)); // Expect no segments/subgraphs. 
RunTest(&g, {}, {}, {}, {}); @@ -126,7 +127,7 @@ TEST_F(SegmentTest, Simple) { auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); auto add4 = ops::Add(s.WithOpName("add4"), add2, add2); - tensorflow::Graph g(OpRegistry::Global()); + Graph g(OpRegistry::Global()); TF_EXPECT_OK(s.ToGraph(&g)); // All Add operations are candidates, and we expect all of them to be @@ -173,7 +174,7 @@ TEST_F(SegmentTest, AvoidCycle) { auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); auto add4 = ops::Add(s.WithOpName("add4"), add2, add2); - tensorflow::Graph g(OpRegistry::Global()); + Graph g(OpRegistry::Global()); TF_EXPECT_OK(s.ToGraph(&g)); // add2 is not a TRT candidate so there should be no segments generated. @@ -204,7 +205,7 @@ TEST_F(SegmentTest, Multiple) { auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); auto add4 = ops::Add(s.WithOpName("add4"), add2, add5); auto add6 = ops::Add(s.WithOpName("add6"), add5, add8); - tensorflow::Graph g(OpRegistry::Global()); + Graph g(OpRegistry::Global()); TF_EXPECT_OK(s.ToGraph(&g)); const std::set all_adds = {"add0", "add1", "add2", "add3", "add4", @@ -251,7 +252,7 @@ TEST_F(SegmentTest, BigIfElse) { auto add5 = ops::Add(s.WithOpName("add5"), add4, add4); auto add6 = ops::Add(s.WithOpName("add6"), add5, add5); auto add7 = ops::Add(s.WithOpName("add7"), add3, add6); - tensorflow::Graph g(OpRegistry::Global()); + Graph g(OpRegistry::Global()); TF_EXPECT_OK(s.ToGraph(&g)); // Make add2 not a TRT candidate, and we expect 2 segments. 
@@ -265,3 +266,6 @@ TEST_F(SegmentTest, BigIfElse) { } // namespace segment } // namespace tensorrt } // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h similarity index 92% rename from tensorflow/contrib/tensorrt/segment/union_find.h rename to tensorflow/compiler/tf2tensorrt/segment/union_find.h index 1c64ebbb0ae532a4776ab8963515d19fd3b23b4c..6458ae692fd7c922b5fc3bea2e55b613447dbde0 100644 --- a/tensorflow/contrib/tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ namespace tensorflow { namespace tensorrt { @@ -76,4 +76,4 @@ UnionFind* UnionFind::FindRoot() { } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc similarity index 100% rename from tensorflow/contrib/tensorrt/tensorrt_test.cc rename to tensorflow/compiler/tf2tensorrt/tensorrt_test.cc diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc similarity index 53% rename from tensorflow/contrib/tensorrt/convert/utils.cc rename to tensorflow/compiler/tf2tensorrt/utils/py_utils.cc index e7a1febb8c076891596741fe30721e7acca15a73..1b8ab1e2720e2eba3654a51beb972303fd55e029 100644 --- 
a/tensorflow/contrib/tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" +#endif namespace tensorflow { namespace tensorrt { @@ -34,35 +35,30 @@ bool IsGoogleTensorRTEnabled() { #endif } -Status GetPrecisionModeName(const int precision_mode, string* name) { - switch (precision_mode) { - case FP32MODE: - *name = "FP32"; - break; - case FP16MODE: - *name = "FP16"; - break; - case INT8MODE: - *name = "INT8"; - break; - default: - return tensorflow::errors::OutOfRange("Unknown precision mode"); - } - return Status::OK(); +void GetLinkedTensorRTVersion(int* major, int* minor, int* patch) { +#if GOOGLE_CUDA && GOOGLE_TENSORRT + *major = NV_TENSORRT_MAJOR; + *minor = NV_TENSORRT_MINOR; + *patch = NV_TENSORRT_PATCH; +#else + *major = 0; + *minor = 0; + *patch = 0; +#endif } -Status GetPrecisionMode(const string& name, int* precision_mode) { - if (name == "FP32") { - *precision_mode = FP32MODE; - } else if (name == "FP16") { - *precision_mode = FP16MODE; - } else if (name == "INT8") { - *precision_mode = INT8MODE; - } else { - return tensorflow::errors::InvalidArgument("Invalid precision mode name: ", - name); - } - return Status::OK(); +void GetLoadedTensorRTVersion(int* major, int* minor, int* patch) 
{ +#if GOOGLE_CUDA && GOOGLE_TENSORRT + int ver = getInferLibVersion(); + *major = ver / 1000; + ver = ver - *major * 1000; + *minor = ver / 100; + *patch = ver - *minor * 100; +#else + *major = 0; + *minor = 0; + *patch = 0; +#endif } } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.h b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f52bb6f1badfa44f35878d788c85b998cb99b472 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_ + +namespace tensorflow { +namespace tensorrt { + +bool IsGoogleTensorRTEnabled(); + +// Return compile time TensorRT library version information {Maj, Min, Patch}. +void GetLinkedTensorRTVersion(int* major, int* minor, int* patch); + +// Return runtime TensorRT library version information {Maj, Min, Patch}. 
+void GetLoadedTensorRTVersion(int* major, int* minor, int* patch); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc similarity index 95% rename from tensorflow/contrib/tensorrt/resources/trt_allocator.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc index 7a2e93414aed56525eaeac876cdac20404bcf6ab..a18f758a5512141ef180844dd4fabe960cbed4f2 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/core/platform/logging.h" @@ -72,7 +72,7 @@ void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment, uint32_t flags) { if (size == 0) return nullptr; // WAR for allocator alignment requirement. Certain cuda API calls require GPU - // memory with alignemtn to cudaDeviceProp::textureAlignment. + // memory with alignment to cudaDeviceProp::textureAlignment. // See issue #20856 alignment = 512; assert((alignment & (alignment - 1)) == 0); // zero or a power of 2. 
@@ -94,7 +94,7 @@ void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment, return mem; } -TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator) +TRTDeviceAllocator::TRTDeviceAllocator(Allocator* allocator) : allocator_(allocator) { VLOG(1) << "Using " << allocator->Name() << " allocator from TensorFlow"; } diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h similarity index 89% rename from tensorflow/contrib/tensorrt/resources/trt_allocator.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h index f857a9de055ee7668f0bf9bc97e030354505081b..8ec06d7456c28505fe45859e42d83cc569d90dc5 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ #include @@ -59,7 +59,7 @@ class TRTCudaAllocator : public TRTBaseAllocator { class TRTDeviceAllocator : public TRTBaseAllocator { // Allocator implementation wrapping TF device allocators. public: - TRTDeviceAllocator(tensorflow::Allocator* allocator); + TRTDeviceAllocator(Allocator* allocator); // TODO(aaroey): base class doesn't have a virtual destructor, work with // Nvidia to fix it. 
@@ -70,7 +70,7 @@ class TRTDeviceAllocator : public TRTBaseAllocator { void free(void* memory) override; private: - tensorflow::Allocator* allocator_; + Allocator* allocator_; // supporting alignment from allocation request requires a map to free; std::unordered_map mem_map_; @@ -81,4 +81,4 @@ class TRTDeviceAllocator : public TRTBaseAllocator { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc similarity index 98% rename from tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc index beb1284208e4c10ffe1d36ef411cf08f11dbcb78..e457c64928e5df84c7e2726ba3621420f013dbc9 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc similarity index 94% rename from tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc index dab1dd9343be7d5b033a3e04bf0b49fbbf37e9e5..33a5c719ba9d750fc5ab173435512ef73ff3fce8 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" #include #include @@ -50,7 +50,7 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, const cudaStream_t stream) { - tensorflow::mutex_lock lock(cond_mtx_); + mutex_lock lock(cond_mtx_); // Wait while the queue is full or calibration is running. while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); @@ -87,7 +87,7 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, int num_bindings) { - tensorflow::mutex_lock lock(cond_mtx_); + mutex_lock lock(cond_mtx_); // Notify finish of last round of calibration. 
calib_running_ = false; cond_.notify_all(); @@ -111,7 +111,7 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, } void TRTInt8Calibrator::waitAndSetDone() { - tensorflow::mutex_lock lock(cond_mtx_); + mutex_lock lock(cond_mtx_); // Wait while the queue is full or calibration is running, so we don't miss // the last batch. while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); @@ -128,14 +128,14 @@ const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { } void TRTInt8Calibrator::setDone() { - tensorflow::mutex_lock lock(cond_mtx_); + mutex_lock lock(cond_mtx_); done_ = true; cond_.notify_all(); } void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, std::size_t length) { - calibration_table_ = string((const char*)ptr, length); + calibration_table_ = string(static_cast(ptr), length); VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr << " length=" << length; } diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h similarity index 86% rename from tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h index 65466c9741989fda5f82fc27d813d026f35fe386..d34e244f6c7fe201915cb4b52808d3e0e3c57fa0 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ #include #include @@ -34,7 +34,12 @@ namespace tensorrt { // TRTs pull model for calibration. When TRT implements a means for // a push calibration This class should be updated accordingly +// IInt8EntropyCalibrator2 is preferred for TRT 5.1+. +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 { +#else struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { +#endif public: // Construct a calibrator for future calibration. TRTInt8Calibrator( @@ -73,10 +78,10 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { const int batch_size_; // mutex for condition_variable - tensorflow::mutex cond_mtx_; + mutex cond_mtx_; // condition variable to implement producer-consumer queue for calibration - tensorflow::condition_variable cond_; + condition_variable cond_; // Is calibration finished? 
bool done_; @@ -96,4 +101,4 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { #endif #endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc similarity index 90% rename from tensorflow/contrib/tensorrt/log/trt_logger.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc index dda0dc9e712eb726800abfb6084f4f708d04825b..6bc842ed5ca7e03018157060a332338cdc926f14 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -26,6 +26,9 @@ namespace tensorrt { void Logger::log(Severity severity, const char* msg) { // Suppress info-level messages switch (severity) { +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + case Severity::kVERBOSE: +#endif case Severity::kINFO: { // Mark TRT info messages as debug! 
VLOG(2) << name_ << " " << msg; break; diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h similarity index 86% rename from tensorflow/contrib/tensorrt/log/trt_logger.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_logger.h index 96ccacb791e40143c5c4d9d691bb353702f9a28b..22f4de970a80765b0e1e7e8816134d83aaec7c73 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ #include "tensorflow/core/platform/types.h" @@ -41,4 +41,4 @@ class Logger : public nvinfer1::ILogger { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h similarity index 87% rename from tensorflow/contrib/tensorrt/resources/trt_lru_cache.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h index afd1b83e744907debc6df0c5acb219369ff89bdb..8ece326446d9f3cb20d5ea02406e71e6e346446e 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_lru_cache.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_LRU_CACHE_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_LRU_CACHE_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ #include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/errors.h" @@ -100,26 +100,24 @@ class LRUCache { } // Creates n free positions in cache - tensorflow::Status DiscardOld(size_t n = 0) { + Status DiscardOld(size_t n = 0) { if (n > capacity_) { - return tensorflow::errors::Internal( - "Insufficient capacity in cache (capacity = ", capacity_, - ", requested ", n, ")"); + return errors::Internal("Insufficient capacity in cache (capacity = ", + capacity_, ", requested ", n, ")"); } while (objects_.size() > (capacity_ - n)) { key_type discard_key = keys_.back(); keys_.pop_back(); objects_.erase(discard_key); } - return tensorflow::Status::OK(); + return Status::OK(); } }; // Define a hash function for vector because it is used as the key // for the engine cache. 
struct VectorTensorShapeHasher { - std::size_t operator()( - const std::vector& key) const { + std::size_t operator()(const std::vector& key) const { return std::hash()(TensorShapeUtils::ShapeListString(key)); } }; @@ -141,12 +139,12 @@ struct EngineContext { GUARDED_BY(mu); }; -class TRTEngineCacheResource : public tensorflow::ResourceBase { +class TRTEngineCacheResource : public ResourceBase { public: TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity) : cache_(capacity) { auto device = ctx->device(); - auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); + auto alloc = device->GetAllocator(AllocatorAttributes()); if (!alloc) { LOG(ERROR) << "Can't find device allocator for gpu device " << device->name(); @@ -189,4 +187,4 @@ class TRTEngineCacheResource : public tensorflow::ResourceBase { } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_LRU_CACHE_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_lru_cache_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc similarity index 96% rename from tensorflow/contrib/tensorrt/resources/trt_lru_cache_test.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc index a0959479ceebf3b05908cd8f6af896821b635672..0aa5eb8f7d4ad062c2d8622fa5aa55f823f80dd5 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_lru_cache_test.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_lru_cache.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc similarity index 92% rename from tensorflow/contrib/tensorrt/resources/trt_resources.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc index c19eb34dab8d25fc6b32dba2a7da667aafd8c205..534e59f06b7d8f6768d1fc58e6a96cfe692fa14f 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -48,8 +48,8 @@ Status TRTCalibrationResource::SerializeToString(string* serialized) { calibrator_->waitAndSetDone(); thr_->join(); *serialized = calibrator_->getCalibrationTableAsString(); - if (!serialized->size()) { - return tensorflow::errors::Unknown("Calibration table is empty."); + if (serialized->empty()) { + return errors::Unknown("Calibration table is empty."); } return Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h similarity index 78% rename from tensorflow/contrib/tensorrt/resources/trt_resources.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_resources.h index d0a87f2c3136cb98e86f37c8d248d7e66c9d2ebe..abfed2c1816732a6e7d7ef396d1923edf0d90f32 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ 
b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ #include #include @@ -23,10 +23,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -37,7 +37,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class SerializableResourceBase : public tensorflow::ResourceBase { +class SerializableResourceBase : public ResourceBase { public: virtual Status SerializeToString(string* serialized) = 0; }; @@ -60,7 +60,7 @@ class TRTCalibrationResource : public SerializableResourceBase { TrtUniquePtrType builder_; TrtUniquePtrType engine_; std::unique_ptr allocator_; - tensorflow::tensorrt::Logger logger_; + Logger logger_; // TODO(sami): Use threadpool threads! 
std::unique_ptr thr_; }; @@ -70,4 +70,4 @@ class TRTCalibrationResource : public SerializableResourceBase { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 0366ec45fb75a21b98ebfc4bdaa903bfa908de7a..e1df032ba937f8c19b4a3cbfa16cddc08165a3a8 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test") package_group( name = "internal", @@ -24,7 +24,7 @@ package( ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library") cc_library( name = "tf2xla_supported_ops_lib", @@ -60,6 +60,14 @@ xla_proto_library( ], ) +xla_py_proto_library( + name = "tf2xla_py", + has_services = False, + api_version = 2, + visibility = ["//visibility:public"], + deps = [":tf2xla_proto"], +) + xla_proto_library( name = "host_compute_metadata_proto", srcs = ["host_compute_metadata.proto"], @@ -76,7 +84,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":common", - ":dump_graph", ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", @@ -191,7 +198,6 @@ cc_library( visibility = [":friends"], deps = [ ":common", - ":dump_graph", ":host_compute_metadata_proto", ":sharding_util", ":side_effect_util", @@ -204,6 +210,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", 
"//tensorflow/compiler/xla/client:client_library", @@ -224,6 +231,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], alwayslink = 1, ) @@ -281,6 +289,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -315,11 +324,13 @@ tf_cc_test( ":tf2xla_util", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -436,22 +447,6 @@ tf_cc_test( ], ) -cc_library( - name = "dump_graph", - srcs = [ - "dump_graph.cc", - ], - hdrs = [ - "dump_graph.h", - ], - deps = [ - "//tensorflow/compiler/jit:flags", - "//tensorflow/core:framework", - "//tensorflow/core:graph", - "//tensorflow/core:protos_all_cc", - ], -) - cc_library( name = "functionalize_control_flow_util", srcs = [ @@ -483,7 +478,6 @@ cc_library( ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", @@ -511,7 +505,6 @@ cc_library( ":functionalize_while", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", @@ -548,7 +541,6 @@ cc_library( ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", 
"//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", @@ -675,3 +667,25 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +tf_cuda_cc_test( + name = "fused_batchnorm_reserve_space_test", + size = "medium", + srcs = ["fused_batchnorm_reserve_space_test.cc"], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/compiler/jit", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/algorithm:container", + ], +) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index a57095f91e43f6b31b58e5a5f36331241451b545..6aff436da4f613a399c006b922b8aba3ce65a2e5 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -20,15 +20,26 @@ limitations under the License. #include "absl/algorithm/container.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { + +Status GetCompileTimeConstInputs(const Node* node, + std::vector* const_input_idxs, + FunctionLibraryRuntime* flib_runtime); + // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. 
Status BackwardsConstAnalysis(const Graph& g, std::vector* compile_time_const_arg_indices, std::vector* compile_time_const_nodes, + FunctionLibraryRuntime* flib_runtime, std::function edge_filter) { std::vector compile_time_const_nodes_impl; if (compile_time_const_nodes) { @@ -61,7 +72,18 @@ Status BackwardsConstAnalysis(const Graph& g, } for (const Edge* pred : node->in_edges()) { if (!pred->IsControlEdge() && edge_filter(*pred)) { - (*compile_time_const_nodes)[pred->src()->id()] = true; + // If the src node of the `pred` is an IdentityN do not mark it as a + // compile-time const. Only mark the corresponding input to the + // IdentityN node as a const. + // Note: XLA IdentityN op simply forwards its inputs so this is safe. + while (edge_filter(*pred) && + pred->src()->type_string() == "IdentityN") { + status = pred->src()->input_edge(pred->src_output(), &pred); + if (!status.ok()) return; + } + if (edge_filter(*pred)) { + (*compile_time_const_nodes)[pred->src()->id()] = true; + } } } return; @@ -69,17 +91,29 @@ Status BackwardsConstAnalysis(const Graph& g, // Mark any compile-time constant operator arguments as const. std::vector const_input_idxs; - status = XlaOpRegistry::CompileTimeConstantInputs( - node->def(), node->op_def(), &const_input_idxs); + status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime); if (!status.ok()) { return; } for (Edge const* edge : node->in_edges()) { - if (absl::c_binary_search(const_input_idxs, edge->dst_input()) && + if (!edge->IsControlEdge() && + absl::c_binary_search(const_input_idxs, edge->dst_input()) && edge_filter(*edge)) { - (*compile_time_const_nodes)[edge->src()->id()] = true; + // Do not mark IdentityN nodes as compile-time const. + // If the src node of the `edge` is an IdentityN do not mark it as a + // compile-time const. Only mark the corresponding input to the + // IdentityN node as a const. + // Note: XLA IdentityN op simply forwards its inputs so this is safe. 
+ while (edge_filter(*edge) && + edge->src()->type_string() == "IdentityN") { + status = edge->src()->input_edge(edge->src_output(), &edge); + if (!status.ok()) return; + } + if (edge_filter(*edge)) { + (*compile_time_const_nodes)[edge->src()->id()] = true; + } } } }; @@ -91,4 +125,61 @@ Status BackwardsConstAnalysis(const Graph& g, return status; } +Status GetCompileTimeConstInputs(const Node* node, + std::vector* const_input_idxs, + FunctionLibraryRuntime* flib_runtime) { + if (node->type_string() != "While") { + return XlaOpRegistry::CompileTimeConstantInputs(node->def(), node->op_def(), + const_input_idxs); + } + // For While nodes, recurse into the body and cond graphs. + // TODO(b/124403063): Implement similar functionality for cond nodes and other + // functional ops. + NameAttrList cond_function; + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "cond", &cond_function)); + NameAttrList body_function; + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "body", &body_function)); + FunctionLibraryRuntime::Handle cond_handle; + FunctionLibraryRuntime::Handle body_handle; + TF_RETURN_IF_ERROR(flib_runtime->Instantiate( + cond_function.name(), AttrSlice(&cond_function.attr()), &cond_handle)); + TF_RETURN_IF_ERROR(flib_runtime->Instantiate( + body_function.name(), AttrSlice(&body_function.attr()), &body_handle)); + const FunctionBody* fcond = flib_runtime->GetFunctionBody(cond_handle); + const FunctionBody* fbody = flib_runtime->GetFunctionBody(body_handle); + TF_RET_CHECK(fcond); + TF_RET_CHECK(fbody); + int num_inputs = fbody->fdef.signature().input_arg_size(); + + // Stores which of the loop inputs are expected to be compile time constants. 
+ std::vector compile_time_const_arg_indices(num_inputs); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *(fcond->graph), &compile_time_const_arg_indices, + /*compile_time_const_nodes=*/nullptr, flib_runtime)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *(fbody->graph), &compile_time_const_arg_indices, + /*compile_time_const_nodes=*/nullptr, flib_runtime)); + for (int i = 0; i < num_inputs; i++) { + if (compile_time_const_arg_indices[i]) { + // Check that this input is actually a loop invariant. + // NOTE(srbs): Ideally this should raise an error if the loop body + // requires the input at this index to be a compile time const but it is + // not a loop invariant. However, that causes problems because const + // analysis is performed for the entire graph (in the + // MarkForCompilationPass for example) and not just for the ops + // that will actually be run using XLA kernels. So we silently return here + // and let the error be raised during the actual compilation of the + // XLA graph. + Node* arg_i = fbody->arg_nodes[i]; + Node* ret_i = fbody->ret_nodes[i]; + const Node* ret_i_input_0; + TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0)); + if (ret_i_input_0->id() == arg_i->id()) { + const_input_idxs->push_back(i); + } + } + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index 49b3c6d413c6b637fa825bf182be7cc36e49b6c8..1663cbff41c3e10ba586c60eca475b760dee4896 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -34,11 +34,13 @@ namespace tensorflow { // `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. // // Only propagate const-ness along edges for which `edge_filter` returns true. 
-Status BackwardsConstAnalysis(const Graph& g, - std::vector* compile_time_const_arg_indices, - std::vector* compile_time_const_nodes, - std::function edge_filter = - [](const Edge& e) { return true; }); +Status BackwardsConstAnalysis( + const Graph& g, std::vector* compile_time_const_arg_indices, + std::vector* compile_time_const_nodes, + FunctionLibraryRuntime* flib_runtime, + std::function edge_filter = [](const Edge& e) { + return true; + }); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 40c6d0e01701d9104a200d9ea27706a0a7c12146..ed5f004550f0cb57e1545436c90bb6a9e8c19652 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -44,8 +44,8 @@ TEST(ConstAnalysisTest, Basics) { std::vector const_args(4, false); std::vector const_nodes(root.graph()->num_node_ids(), false); - TF_ASSERT_OK( - BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes)); + TF_ASSERT_OK(BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes, + /*flib_runtime=*/nullptr)); // Arg 0 doesn't need to be constant since the graph only uses its shape. // Arg 1 must be constant because it flows to the shape argument of a Reshape. 
@@ -82,7 +82,8 @@ TEST(ConstAnalysisTest, TopologicalOrder) { std::vector const_args(3, false); TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, - /*compile_time_const_nodes=*/nullptr)); + /*compile_time_const_nodes=*/nullptr, + /*flib_runtime=*/nullptr)); EXPECT_EQ(const_args, std::vector({true, true, false})); } @@ -103,7 +104,8 @@ TEST(ConstAnalysisTest, DontFollowControlDependencies) { std::vector const_args(2, false); TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, - /*compile_time_const_nodes=*/nullptr)); + /*compile_time_const_nodes=*/nullptr, + /*flib_runtime=*/nullptr)); EXPECT_EQ(const_args, std::vector({false, true})); } @@ -128,7 +130,8 @@ TEST(ConstAnalysisTest, RespectExplicitAttr_0) { std::vector const_args(2, false); TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, - /*compile_time_const_nodes=*/nullptr)); + /*compile_time_const_nodes=*/nullptr, + /*flib_runtime=*/nullptr)); EXPECT_EQ(const_args, std::vector({false, false})); } @@ -152,7 +155,8 @@ TEST(ConstAnalysisTest, RespectExplicitAttr_1) { std::vector const_args(1, false); TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, - /*compile_time_const_nodes=*/nullptr)); + /*compile_time_const_nodes=*/nullptr, + /*flib_runtime=*/nullptr)); EXPECT_EQ(const_args, std::vector({true})); } diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc deleted file mode 100644 index 64fdbbebc65bff4ed0b965fcdd534cc9696472b6..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Helper functions for dumping Graphs, GraphDefs, and FunctionDefs to files for -// debugging. - -#include "tensorflow/compiler/tf2xla/dump_graph.h" - -#include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/core/util/dump_graph.h" - -namespace tensorflow { -namespace dump_graph { - -string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - return tensorflow::DumpGraphDefToFile( - name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix); -} - -string DumpGraphToFile(const string& name, Graph const& graph, - const FunctionLibraryDefinition* flib_def) { - return tensorflow::DumpGraphToFile(name, graph, flib_def, - GetDumpGraphFlags()->tf_dump_graph_prefix); -} - -string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - return tensorflow::DumpFunctionDefToFile( - name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix); -} - -} // namespace dump_graph -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.h b/tensorflow/compiler/tf2xla/dump_graph.h deleted file mode 100644 index bbf01eb90dbd0478e873da785c3dcb6145096f17..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/dump_graph.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Helper functions for dumping Graphs, GraphDefs, and FunctionDefs to files for -// debugging. - -#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_ -#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_ - -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/graph/graph.h" - -namespace tensorflow { -namespace dump_graph { - -// Dumps 'graph_def' to a file, as a GraphDef text proto. Returns the file name -// chosen. -// -// Automatically picks a file name. Prefixes 'name' with the value of the -// --tf_dump_graph_prefix flag and suffixes it with ".pbtxt" to form a name. -// If a graph has already been dumped by this process with the same name, -// suffixes with "_n.pbtxt", where 'n' is a sequence number. -string DumpGraphDefToFile(const string& name, GraphDef const& graph_def); - -// Similar to DumpGraphDefToFile, but builds the GraphDef to dump from a 'graph' -// and an optional function library 'flib_def'. Returns the file name chosen. -string DumpGraphToFile(const string& name, Graph const& graph, - const FunctionLibraryDefinition* flib_def = nullptr); - -// Similar to DumpGraphDefToFile, but dumps a function as a FunctionDef text -// proto. Returns the file name chosen. 
-string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef); - -} // namespace dump_graph -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 7ae96e1d484900e28e8c23c3bb2232401144ad82..6e093400e475ec0316ac072f4ddf1de5bdeec876 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -25,7 +25,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" @@ -34,13 +33,53 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/dump_graph.h" using xla::StatusOr; namespace tensorflow { namespace functionalize_cond { +bool AncestorNode::operator<(const AncestorNode& other) const { + return (output_tensor.node->id() < other.output_tensor.node->id()) || + (output_tensor.node->id() == other.output_tensor.node->id() && + output_tensor.index < other.output_tensor.index) || + (output_tensor.node->id() == other.output_tensor.node->id() && + output_tensor.index == other.output_tensor.index && + type < other.type); +} + +bool AncestorNode::operator==(const AncestorNode& other) const { + return output_tensor.node->id() == other.output_tensor.node->id() && + output_tensor.index == other.output_tensor.index && type == other.type; +} + +size_t 
AncestorNode::Hash::operator()(const AncestorNode& ancestor) const { + size_t h = std::hash()(ancestor.output_tensor.node->id()); + h = Hash64Combine(h, std::hash()(ancestor.output_tensor.index)); + return Hash64Combine(h, std::hash()(static_cast(ancestor.type))); +} + +typedef std::tuple + ClusterTuple; + +struct ClusterTupleLessThan { + bool operator()(const ClusterTuple& a, const ClusterTuple& b) const { + if (std::tie(std::get<0>(a), std::get<1>(a)) < + std::tie(std::get<0>(b), std::get<1>(b))) { + return true; + } else if (std::tie(std::get<0>(a), std::get<1>(a)) == + std::tie(std::get<0>(b), std::get<1>(b))) { + return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b)); + } else { + return false; + } + } +}; + // TODO(jpienaar): Move to OutputTensor. string DebugString(const OutputTensor& tensor) { return absl::StrCat(tensor.node->name(), ":", tensor.index); @@ -145,10 +184,10 @@ size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const { if (map.empty()) return 0; // Compute hash of the front element. auto it = map.begin(); - size_t h = hash()(*it); + size_t h = AncestorNode::Hash()(*it); for (++it; it != map.end(); ++it) { // Combine the has with the different elements in the map. - h = Hash64Combine(h, hash()(*it)); + h = Hash64Combine(h, AncestorNode::Hash()(*it)); } return h; } @@ -229,7 +268,17 @@ string StateMap::CondStateToString(StateMap::CondId id) const { } string StateMap::AncestorStateToString(const Node* node) const { - if (auto id = LookupAncestorId(node)) return NodesToString(*id); + if (auto id = LookupAncestorId(node)) { + return absl::StrCat( + "{", + absl::StrJoin(*id, ",", + [](string* output, const AncestorNode& ancestor) { + absl::StrAppend(output, + ancestor.output_tensor.node->name(), + ":", ancestor.output_tensor.index); + }), + "}"); + } return "{}"; } @@ -247,7 +296,9 @@ class Conditional { Status AddMerge(Node* m); // Constructs an If node from the merge nodes. 
- Status BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library); + Status BuildAndReplace( + Graph* graph, FunctionLibraryDefinition* library, + std::unordered_map* merge_to_replacement); private: // Extracts the then/else bodies: creates new graphs with the nodes @@ -262,10 +313,15 @@ class Conditional { Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library); // Adds input edges to If node. - Status AddInputEdges(Graph* graph); + Status AddInputEdges( + Graph* graph, + const std::unordered_map& merge_to_replacement); // Adds output edges from If node. - Status AddOutputEdges(Graph* graph); + // Record new output tensor for all Merge nodes in 'merge_to_replacement'. + Status AddOutputEdges( + Graph* graph, + std::unordered_map* merge_to_replacement); // Adds switch node that is part of this conditional. Status AddSwitch(Node* s); @@ -564,7 +620,32 @@ Status Conditional::ExtractBodies(Graph* graph) { stack.push_back(src); } } else if (e->IsControlEdge()) { - external_control_inputs_.push_back(src); + // Here we have a control flow edge between src and dst that are not + // in the same context. This is an external control dependency except + // for one case: where the only difference between CondId of e->src() + // and CondId of e->dst() is that e->src() has {PRED, kNeither} and + // e->dst() has {PRED, kThenBranch/kElseBranch}. This happens in + // gradients code for tf.cond(), where e->src() is a control pivot + // node for a branch and e->dst() is a data node in that branch. 
+ bool is_external_control_input = true; + if (!state_map_->IsEmpty(src_id) && !state_map_->IsEmpty(dst_id)) { + std::vector diff; + std::set_symmetric_difference( + src_id->begin(), src_id->end(), dst_id->begin(), dst_id->end(), + std::back_inserter(diff), CondStateLess()); + if (diff.size() == 2 && diff[0].first == diff[1].first && + (diff[0].second == BranchType::kNeither || + diff[1].second == BranchType::kNeither)) { + auto src_branch = src_id->find(diff[0].first); + if (src_branch != src_id->end() && + src_branch->second == BranchType::kNeither) { + is_external_control_input = false; + } + } + } + if (is_external_control_input) { + external_control_inputs_.push_back(src); + } } else { // This shouldn't happen, this means we have an external data input // not entering via a switch node. Work around this by for @@ -654,7 +735,7 @@ Status Conditional::BuildIfNode(Graph* graph, VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index] << "): " - << dump_graph::DumpGraphToFile( + << DumpGraphToFile( "functionalize_cond_body_" + branch_name[branch_index], *bodies_[branch_index], nullptr); @@ -705,9 +786,9 @@ Status Conditional::BuildIfNode(Graph* graph, } builder.Device(predicate_.node->assigned_device_name()); // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(), - predicate_.index, - predicate_.node->output_type(0))); + builder.Input( + NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index, + predicate_.node->output_type(predicate_.index))); // ... followed by the other inputs. builder.Input(inputs); @@ -720,12 +801,29 @@ Status Conditional::BuildIfNode(Graph* graph, return Status::OK(); } -Status Conditional::AddInputEdges(Graph* graph) { +Status Conditional::AddInputEdges( + Graph* graph, + const std::unordered_map& merge_to_replacement) { VLOG(2) << "AddInputEdges for " << if_node_->name(); int index = 0; // Add predicate input. 
- graph->AddEdge(const_cast(predicate_.node), predicate_.index, if_node_, - index++); + if (predicate_.node->IsMerge()) { + // If the predicate is a Merge node, we should not use Merge output as + // predicate. Instead, we should use the corresponding If output in + // 'merge_to_replacement'. Otherwise, this Conditional's If node is still + // connected to the predicate Merge node; and when we call + // DeleteReachableAndDeadNodes(), the predicate Merge node and this + // Conditional's If node will be removed. + auto iter = merge_to_replacement.find(predicate_.node); + if (iter == merge_to_replacement.end()) { + return errors::Internal("Cannot find replacement for Merge node ", + predicate_.node->name()); + } + graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++); + } else { + graph->AddEdge(const_cast(predicate_.node), predicate_.index, + if_node_, index++); + } // Add function body inputs. for (auto& arg : cond_arg_nodes_) { if (arg.src_output == Graph::kControlSlot) { @@ -740,7 +838,9 @@ Status Conditional::AddInputEdges(Graph* graph) { return Status::OK(); } -Status Conditional::AddOutputEdges(Graph* graph) { +Status Conditional::AddOutputEdges( + Graph* graph, + std::unordered_map* merge_to_replacement) { VLOG(2) << "AddOutputEdges for " << if_node_->name(); int i = 0; for (Node* node : merges_) { @@ -764,6 +864,10 @@ Status Conditional::AddOutputEdges(Graph* graph) { graph->AddEdge(if_node_, i, dst, dst_input); } } + + // Record corresponding output tensor in 'merge_to_replacement'. 
+ (*merge_to_replacement)[node] = OutputTensor{if_node_, i}; + ++i; } for (Node* n : external_control_outputs_) { @@ -773,8 +877,9 @@ Status Conditional::AddOutputEdges(Graph* graph) { return Status::OK(); } -Status Conditional::BuildAndReplace(Graph* graph, - FunctionLibraryDefinition* library) { +Status Conditional::BuildAndReplace( + Graph* graph, FunctionLibraryDefinition* library, + std::unordered_map* merge_to_replacement) { VLOG(1) << "Build If and replace merge nodes " << NodesToString(this->merges_); if (replaced_) return Status::OK(); @@ -793,8 +898,8 @@ Status Conditional::BuildAndReplace(Graph* graph, } TF_RETURN_IF_ERROR(BuildIfNode(graph, library)); - TF_RETURN_IF_ERROR(AddInputEdges(graph)); - TF_RETURN_IF_ERROR(AddOutputEdges(graph)); + TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement)); + TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement)); TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); // Check that the if_node doesn't feed into itself. @@ -913,10 +1018,18 @@ StatusOr FunctionalizeCond::JoinCondStatesNonMerge( both.insert(kv); } else { if (it->second != kv.second) { - return errors::InvalidArgument( - "Graph contains node with inputs predicated on incompatible " - "predicates: ", - DebugString(src), " and ", DebugString(dst)); + if (it->second == BranchType::kNeither) { + // BranchType for 'src' is kNeither. Use the BranchType in 'dst'. + it->second = kv.second; + } else if (kv.second == BranchType::kNeither) { + // BranchType for 'dst' is kNeither. Use the BranchType in 'src'. + // No need to change it->second. 
+ } else { + return errors::InvalidArgument( + "Graph contains node with inputs predicated on incompatible " + "predicates: ", + DebugString(src), " and ", DebugString(dst)); + } } } } @@ -936,6 +1049,10 @@ StatusOr FunctionalizeCond::JoinCondStatesMerge( VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " << DebugString(dst); if (state_map_.IsEmpty(dst)) return src; + if (state_map_.IsEmpty(src)) { + return errors::Internal("Merge node ", merge->name(), + " has input that's not in any CondContext."); + } if (state_map_.IsDead(src)) return src; if (state_map_.IsDead(dst)) return dst; @@ -981,7 +1098,17 @@ StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { if (id != nullptr) state = *id; OutputTensor predicate; TF_CHECK_OK(GetSwitchPredicate(*src, &predicate)); - if (!e->IsControlEdge()) { + if (e->IsControlEdge()) { + // In gradients of tf.cond(), in each branch, we have a NoOp node as + // control pivot. These NoOp nodes have control dependency from Switch + // node. If we don't record this into CondState, branches might have + // incorrect CondState (e.g. if the branch only has a Const data node). + // We set it to kNeither because there is no way to tell whether it's + // for true branch or false branch. This node's desendents might have + // other incoming edges with defined BranchType, and we correctly handle + // merging kNeither with other defined BranchType in StateAlongEdge(). 
+ state[predicate] = BranchType::kNeither; + } else { state[predicate] = BranchType(e->src_output()); } return state_map_.GetCondId(state); @@ -1170,8 +1297,17 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) { if (other_id != id && other_id != nullptr) { state.insert(other_id->begin(), other_id->end()); } - if (IsSwitch(src) || IsMerge(src)) { - state.insert(src); + if (IsMerge(src)) { + state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge}); + } else if (IsSwitch(src)) { + OutputTensor pred; + // For dead switch nodes, GetSwitchPredicate() will fail, and we use + // the switch node directly as ancestor. + if (GetSwitchPredicate(*src, &pred).ok()) { + state.insert({pred, AncestorNode::AncestorNodeType::kPred}); + } else { + state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch}); + } } return state_map_.GetAncestorId(state); }; @@ -1317,16 +1453,30 @@ Status FunctionalizeCond::FunctionalizeInternal() { // Sort the merge nodes from innermost outwards. SortMergeNodes(&merge_order); - // Cluster merge nodes by CondId and AncestorId in order of nesting. - using ClusterPair = std::pair; + // Cluster merge nodes by (CondId, AncestorId, predicate) in order of + // nesting. (CondId, AncestorId) is not enough, e.g. + // pred1 = array_ops.placeholder(dtypes.bool, name='pred1') + // pred2 = array_ops.placeholder(dtypes.bool, name='pred2') + // cond1 = control_flow_ops.cond(pred1, ...) + // cond2 = control_flow_ops.cond(pred2, ...) + // cond3 = control_flow_ops.cond(pred1, use cond1 and cond2) + // cond4 = control_flow_ops.cond(pred2, use cond1 and cond2) + // cond3 and cond4 have the same (CondId, AncestorId), but they should not + // be merged into one "If" node (because they have different predicates). 
std::deque> merge_clusters; - std::map merge_cluster_index; + std::map merge_cluster_index; for (Node* merge : merge_order) { auto cond_id = state_map_.LookupCondId(merge); if (state_map_.IsDead(cond_id)) continue; - ClusterPair key = - std::make_pair(cond_id, state_map_.LookupAncestorId(merge)); + auto predicate = merge_to_predicate_.find(merge); + if (predicate == merge_to_predicate_.end()) { + return errors::Internal("Cannot find predicate for Merge node ", + merge->name()); + } + + ClusterTuple key = std::make_tuple( + cond_id, state_map_.LookupAncestorId(merge), predicate->second); auto idx = merge_cluster_index.find(key); if (idx == merge_cluster_index.end()) { merge_cluster_index[key] = merge_clusters.size(); @@ -1345,7 +1495,8 @@ Status FunctionalizeCond::FunctionalizeInternal() { Conditional cond(merge_to_predicate_.at(cluster.front()), this, &state_map_); for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge)); - TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); + TF_RETURN_IF_ERROR( + cond.BuildAndReplace(graph_, library_, &merge_to_replacement_)); if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); } @@ -1365,9 +1516,8 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { state_map_.AncestorStateToString(n))); } LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " - << dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_cond_", name), *graph_, - library_); + << DumpGraphToFile(absl::StrCat("functionalize_cond_", name), + *graph_, library_); } void FunctionalizeCond::AddSwitchId(int switch_id) { diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 8525d7af61b4471e53a9ae16b081060bfd234c9c..d85800fb8ee65a354716bf6601c6bc40eca9a10d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -43,6 +43,33 @@ enum class BranchType { kNeither = 3, }; +// When we keep track of 
which switch/merge node's feed into a node, we record +// 1) predicate for non-dead switch node, +// 2) the switch node itself for dead switch node, +// 3) the merge node itself for merge node. +// Case 1) is an optimization. With this optimization, if there are nodes from +// different switch nodes but those switch nodes have the same predicate, the +// nodes will still have same AncestorState, and they will be clustered into a +// single "If". +struct AncestorNode { + enum class AncestorNodeType { + kPred = 0, + kSwitch = 1, + kMerge = 2, + }; + + OutputTensor output_tensor; + AncestorNodeType type; + + // Compare two AncestorNodes by (node id, index, type). + bool operator<(const AncestorNode& other) const; + bool operator==(const AncestorNode& other) const; + + struct Hash { + size_t operator()(const AncestorNode&) const; + }; +}; + // StateMap is responsible for mapping from each graph Node to // * a CondState, where each CondState is a map from predicate to branch (i,e., // what predicates have to hold or not hold). @@ -68,7 +95,7 @@ class StateMap { using CondId = const CondState*; // Keep track of which switch/merge node's feed into a node's values. - using AncestorState = std::set; + using AncestorState = std::set; // Every unique ID is mapped to a AncestorState. using AncestorId = const AncestorState*; @@ -232,6 +259,9 @@ class FunctionalizeCond { // Mapping from merge nodes to predicate. std::unordered_map merge_to_predicate_; + // Mapping from merge nodes to corresponding If node outputs. 
+ std::unordered_map merge_to_replacement_; + FunctionLibraryDefinition* library_; Graph* graph_; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index b0aabd63bbda784b3b7103a438ce025eea0cd93b..05fa1ee92dc172bd11cec9f99e3884996e00791f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -101,6 +101,17 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { TF_EXPECT_OK(t.status()); } +TEST_F(FunctionalizeCondTest, JoinCondStatesMergeWithInputNotInCondContext) { + Tensor val_tensor(DT_INT32, TensorShape()); + val_tensor.flat().setZero(); + Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); + Node* m = test::graph::Merge(graph_.get(), val, val); + + StateMap::CondState cond_state; + auto joined_or = JoinCondStatesMerge(m, /*src=*/nullptr, &cond_state); + EXPECT_FALSE(joined_or.ok()); +} + } // namespace } // namespace functionalize_cond } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 3dfd3f854c8646ebbf06d3378201d22e8741b7eb..9fe25dfe3e7d3cf6970851fb9a6b0c56c044da94 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/functionalize_while.h" @@ -43,6 +42,7 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -50,8 +50,7 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " - << dump_graph::DumpGraphToFile("functionalize_initial", *graph, - library); + << DumpGraphToFile("functionalize_initial", *graph, library); // Functionalize and remove while loops from graph. TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library)); @@ -62,8 +61,7 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library)); VLOG(2) << "FunctionalizeControlFlow (final): " - << dump_graph::DumpGraphToFile("functionalize_final", *graph, - library); + << DumpGraphToFile("functionalize_final", *graph, library); return Status::OK(); } @@ -200,13 +198,13 @@ Status FunctionalizeControlFlowForFunction( // Functionalize the function body. 
if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("functionalize_control_flow_before_fdef_", func_name), *g, fld); } TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, fld); } @@ -234,8 +232,8 @@ Status FunctionalizeControlFlowPass::Run( const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, - options.flib_def); + DumpGraphToFile("functionalize_control_flow_before", *graph, + options.flib_def); } std::unique_ptr pflr( new ProcessFunctionLibraryRuntime( @@ -279,8 +277,8 @@ Status FunctionalizeControlFlowPass::Run( } if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, - options.flib_def); + DumpGraphToFile("functionalize_control_flow_after", *graph, + options.flib_def); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index d87436a7b4ac37c74d0f0df921779c8716290013..517924bfc71d977bdf4222dbfbbfcb692f544f5b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -36,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { @@ -293,8 +293,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " - << dump_graph::DumpGraphToFile("functionalize_before", *graph, - library); + << DumpGraphToFile("functionalize_before", *graph, library); // Split loop-varying Enter nodes with multiple successors. If the same // Tensor is fed as input to multiple loop arguments, we may end up with a @@ -490,8 +489,8 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); VLOG(2) << "Frame " << frame->name << " condition: " - << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) - << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); + << DumpGraphToFile("loop_condition", *cond_graph, library) + << " body: " << DumpGraphToFile("loop_body", *body_graph); static std::atomic sequence_num(0LL); int64 id = ++sequence_num; @@ -585,8 +584,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, frame->parent->nodes.insert(while_node); VLOG(2) << "Frame " << frame->name << " after: " - << dump_graph::DumpGraphToFile("functionalize_after", *graph, - library); + << DumpGraphToFile("functionalize_after", *graph, library); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4535ece374ceb801e450af98a21d5a4c5e8f2a29 --- /dev/null +++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2019 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace { +Status GetTestDevice(Session* session, string* test_device) { + std::vector devices; + TF_RETURN_IF_ERROR(session->ListDevices(&devices)); + + bool found_cpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) { + return device.device_type() == "CPU"; + }); + + bool found_gpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) { + return device.device_type() == "GPU"; + }); + + if (!found_gpu && !found_cpu) { + return errors::Internal("Expected at least one CPU or GPU!"); + } + + *test_device = found_gpu ? 
"GPU" : "CPU"; + VLOG(2) << "Using test device " << *test_device; + return Status::OK(); +} + +void FillZeros(Tensor* tensor) { + auto flat = tensor->flat(); + for (int i = 0; i < flat.size(); i++) { + flat.data()[i] = 0.0f; + } +} + +// This test checks that the implementation outputs from FusedBatchnorm +// training, reserve_space_{1|2}, are what we assume them to be in the TF/XLA +// lowering. +// +// If this test starts failing then it doesn't indicate that TF/cudnn have +// violated their contract, but it indicates that we need to update the TF/XLA +// lowering for FusedBatchnorm training to match the new implementation defined +// behavior. +TEST(FusedBatchnormReserveSpaceTest, Test) { + using ::tensorflow::ops::Const; + using ::tensorflow::ops::FusedBatchNorm; + + std::unique_ptr session( + tensorflow::NewSession(tensorflow::SessionOptions{})); + + string test_device; + TF_ASSERT_OK(GetTestDevice(session.get(), &test_device)); + + Scope root = tensorflow::Scope::NewRootScope(); + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + + Tensor scale_data(DT_FLOAT, TensorShape({10})); + FillZeros(&scale_data); + Output scale = + Const(root.WithOpName("scale"), Input::Initializer(scale_data)); + + Tensor offset_data(DT_FLOAT, TensorShape({10})); + FillZeros(&offset_data); + Output offset = + Const(root.WithOpName("offset"), Input::Initializer(offset_data)); + + Tensor mean_data(DT_FLOAT, TensorShape({0})); + Output mean = Const(root.WithOpName("offset"), Input::Initializer(mean_data)); + + Tensor variance_data(DT_FLOAT, TensorShape({0})); + Output variance = + Const(root.WithOpName("variance"), Input::Initializer(variance_data)); + + string tf_device = absl::StrCat("/device:", test_device, ":0"); + string xla_device = absl::StrCat("/device:XLA_", test_device, ":0"); + + FusedBatchNorm fused_batch_norm_tf( + root.WithOpName("fused_batch_norm_tf").WithDevice(tf_device), input, + scale, offset, mean, variance,
FusedBatchNorm::Attrs{}.IsTraining(true)); + FusedBatchNorm fused_batch_norm_xla( + root.WithOpName("fused_batch_norm_xla").WithDevice(xla_device), input, + scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true)); + + tensorflow::GraphDef graph; + TF_ASSERT_OK(root.ToGraphDef(&graph)); + + TF_ASSERT_OK(session->Create(graph)); + + Tensor input_data(DT_FLOAT, TensorShape({10, 10, 10, 10})); + auto flat_input = input_data.flat(); + for (int i = 0; i < flat_input.size(); i++) { + flat_input.data()[i] = (i - 5) / 1000.0f; + } + + std::vector results; + TF_ASSERT_OK(session->Run({{"input", input_data}}, + {fused_batch_norm_tf.reserve_space_1.name(), + fused_batch_norm_xla.reserve_space_1.name(), + fused_batch_norm_tf.reserve_space_2.name(), + fused_batch_norm_xla.reserve_space_2.name()}, + {}, &results)); + + test::ExpectClose(results[0], results[1], /*atol=*/1e-4); + test::ExpectClose(results[2], results[3], /*atol=*/1e-4); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 0c2bb0223905b22613a64ad54f07151f7f8590b2..e80b6f50ac37f71e7cb15289b471ccc1310b922b 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include "tensorflow/compiler/tf2xla/const_analysis.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" @@ -34,7 +33,9 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" @@ -46,6 +47,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -56,9 +58,9 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, auto client = ctx->compiler()->client(); std::vector arg_must_be_compile_time_constant(expressions.size()); - TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant, - /*compile_time_const_nodes=*/nullptr)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *graph, &arg_must_be_compile_time_constant, + /*compile_time_const_nodes=*/nullptr, ctx->function_library())); args->resize(expressions.size()); for (int i = 0; i < args->size(); ++i) { @@ -87,8 +89,13 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, } break; case XlaExpression::Kind::kResource: + // TODO(b/126601755): This is a fairly common use case in TF 2.0 that + // we can hit when inlining is disabled or fails. 
return errors::Unimplemented( "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kTensorList: + return errors::Unimplemented( + "TensorList as function argument is not yet implemented."); case XlaExpression::Kind::kInvalid: return errors::InvalidArgument("Invalid function argument"); } @@ -121,6 +128,8 @@ Status GraphCompiler::Compile() { for (Node* n : topo_sorted_nodes) { OpKernel* op_kernel_raw = nullptr; + // The kernel is not actually run for functional ops, we just need it + // for metadata. Status s = flib_->CreateKernel(n->def(), &op_kernel_raw); // Transfer ownership of the kernel to a local smart pointer. std::unique_ptr op_kernel(op_kernel_raw); @@ -154,7 +163,7 @@ Status GraphCompiler::Compile() { OpKernelContext op_context(&params, n->num_outputs()); VLOG(3) << "Translating " << params.op_kernel->name(); - if (IsFunctional(n)) { + if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) { TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); } else { device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context); @@ -179,15 +188,37 @@ Status GraphCompiler::Compile() { return Status::OK(); } -bool GraphCompiler::IsFunctional(Node* n) { - return n->type_string() == FunctionLibraryDefinition::kGradientOp || - (flib_->GetFunctionLibraryDefinition()->Find(n->def().op()) != - nullptr); +namespace { + +Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib, + const Node& node, NameAttrList* func) { + if (node.IsPartitionedCall()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value)); + if (!attr_value->has_func()) { + return errors::InvalidArgument( + "The attribute value for attribute 'f' in node ", node.DebugString(), + " does not have 'func' field set"); + } + *func = attr_value->func(); + return Status::OK(); + } + + if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) { + func->set_name(node.type_string()); + } else { +
func->set_name(FunctionLibraryDefinition::kGradientOp); + } + *func->mutable_attr() = node.def().attr(); + return Status::OK(); } +} // namespace + Status GraphCompiler::CompileFunctionalNode(Node* n, OpKernelContext* op_context) { - TF_RET_CHECK(IsFunctional(n)); + TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)); // For functional nodes, compile them using compiler from the context and call // into the functions. XlaOpKernelContext xla_op_context(op_context); @@ -198,12 +229,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, XlaCompiler* compiler = xla_op_context.compiler(); NameAttrList func; - if (flib_->GetFunctionLibraryDefinition()->Find(n->def().op())) { - func.set_name(n->def().op()); - } else { - func.set_name(FunctionLibraryDefinition::kGradientOp); - } - *func.mutable_attr() = n->def().attr(); + TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func)); std::vector expressions; @@ -224,7 +250,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); bool add_token_input_output = - HasNodeAttr(n->def(), kXlaTokenInputNodesAttrName); + func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end(); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = false; @@ -244,8 +270,9 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, } if (add_token_input_output) { std::vector token_input_nodes; - TF_RETURN_IF_ERROR( - GetNodeAttr(n->def(), kXlaTokenInputNodesAttrName, &token_input_nodes)); + TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()), + kXlaTokenInputNodesAttrName, + &token_input_nodes)); std::vector token_inputs; for (const string& node_name : token_input_nodes) { auto token_or = compiler->GetNodeToken(node_name); @@ -281,6 +308,7 @@ void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) { params->inputs = &tensor_inputs_; params->step_container = step_container_; 
params->resource_manager = device_->resource_manager(); + params->function_library = flib_; } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index e9f02201cf6bed5495dff7dff76c5bafe7771516..eb02534e7fb42cb5c4f1df710debcafd76594c07 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -73,10 +73,6 @@ class GraphCompiler { // across multiple nodes visit. void PartiallySetupParams(OpKernelContext::Params* params); - // Tests if a node is a functional node. A functional node represents a - // defined computation and should be compiled using `compiler_`. - bool IsFunctional(Node* n); - // Compiles a functional node and writes result to OpkernelContext. A // functional node represents a defined computation and should be compiled // using `compiler_`. diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 47209d285f1a077fd80f779a406e6980892f1646..cf29778688887e9b787f4865b88db4438a1ac88d 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -33,7 +33,9 @@ tf_kernel_library( "diag_op.cc", "dynamic_slice_ops.cc", "dynamic_stitch_op.cc", + "einsum_op.cc", "elu_op.cc", + "empty_op.cc", "extract_image_patches_op.cc", "fake_param_op.cc", "fake_quantize_ops.cc", @@ -63,6 +65,7 @@ tf_kernel_library( "qr_op.cc", "quantize_and_dequantize_op.cc", "random_ops.cc", + "random_ops_util.h", "reduce_window_op.cc", "reduction_ops.cc", "reduction_ops.h", @@ -88,6 +91,7 @@ tf_kernel_library( "sparse_to_dense_op.cc", "split_op.cc", "stack_ops.cc", + "stateful_random_ops.cc", "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", @@ -106,18 +110,24 @@ tf_kernel_library( "xla_pad_op.cc", "xla_reduce_op.cc", "xla_select_and_scatter_op.cc", + "xla_self_adjoint_eig_op.cc", + "xla_svd_op.cc", ], hdrs = [ "index_ops.h", "shape_util.h", ], + tags = 
["optonly"], deps = [ + ":case_op", ":conv_op_helpers", ":if_op", + ":tensor_list_utils", ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:broadcast", + "//tensorflow/compiler/tf2xla/lib:data_format", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:util", @@ -133,7 +143,6 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", @@ -142,32 +151,46 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:quantize", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", + "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/compiler/xla/client/lib:sorting", - "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/compiler/xla/client/lib:svd", + "//tensorflow/core:bitwise_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", + "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:framework_bounds_check", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:linalg_ops_op_lib", + "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:random_ops_op_lib", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:sparse_ops_op_lib", "//tensorflow/core:spectral_ops_op_lib", + 
"//tensorflow/core:state_ops_op_lib", + "//tensorflow/core:stateful_random_ops_op_lib", "//tensorflow/core:stateless_random_ops_op_lib", - "//tensorflow/core/kernels:bounds_check", - "//tensorflow/core/kernels:concat_lib", + "//tensorflow/core:training_ops_op_lib", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", - "//tensorflow/core/kernels:conv_ops", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:list_kernels", - "//tensorflow/core/kernels:no_op", - "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:partitioned_function_ops", "//tensorflow/core/kernels:pooling_ops", "//tensorflow/core/kernels:random_op", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:stack_ops", + "//tensorflow/core/kernels:stateful_random_ops", "//tensorflow/core/kernels:training_ops", - "//tensorflow/core/kernels:transpose_op", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -212,28 +235,46 @@ cc_library( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:framework", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core:framework_bounds_check", "//tensorflow/core/kernels:conv_ops", - "//tensorflow/core/kernels:ops_util", "@com_google_absl//absl/types:span", ], ) +cc_library( + name = "tensor_list_utils", + srcs = ["tensor_list_utils.cc"], + hdrs = ["tensor_list_utils.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + 
tf_kernel_library( name = "while_op", srcs = ["while_op.cc"], hdrs = ["while_op.h"], deps = [ + ":tensor_list_utils", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -254,6 +295,23 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "case_op", + srcs = ["case_op.cc"], + hdrs = ["case_op.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + # Kernels that have a dummy (no-op) implementation. 
tf_kernel_library( name = "xla_dummy_ops", @@ -287,9 +345,8 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:framework", + "//tensorflow/core:framework_bounds_check", "//tensorflow/core:lib", - "//tensorflow/core/kernels:argmax_op", - "//tensorflow/core/kernels:bounds_check", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 41a453da80dec6b6f57a4d222e2c33ef6b786a10..b8d853bc36bf1682326c9728f9262522a6606a32 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { namespace { @@ -30,9 +32,39 @@ class AddNOp : public XlaOpKernel { OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("AddN requires at least one argument")); - xla::XlaOp sum = ctx->Input(0); - for (int i = 1; i < ctx->num_inputs(); ++i) { - sum = xla::Add(sum, ctx->Input(i)); + XlaExpression::Kind kind = ctx->InputExpression(0).kind(); + xla::XlaOp sum; + switch (kind) { + case XlaExpression::Kind::kTensorList: { + OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &sum)); + TensorShape sum_shape; + OP_REQUIRES_OK(ctx, + GetTensorListBufferShape(ctx->Input(0), &sum_shape)); + for (int i = 1; i < ctx->num_inputs(); ++i) { + xla::XlaOp operand; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(i), &operand)); + // Check that the shapes match. 
+ TensorShape operand_shape; + OP_REQUIRES_OK( + ctx, GetTensorListBufferShape(ctx->Input(i), &operand_shape)); + OP_REQUIRES( + ctx, sum_shape.dim_sizes() == operand_shape.dim_sizes(), + errors::InvalidArgument( + "TensorList arguments to AddN must all have the same ", + "shape.\n", "Expected: ", sum_shape.DebugString(), "\n", + "Found: ", operand_shape.DebugString())); + sum = xla::Add(sum, operand); + } + xla::XlaOp push_index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &push_index)); + OP_REQUIRES_OK(ctx, BuildTensorList(sum, push_index, &sum)); + break; + } + default: + sum = ctx->Input(0); + for (int i = 1; i < ctx->num_inputs(); ++i) { + sum = xla::Add(sum, ctx->Input(i)); + } } ctx->SetOutput(0, sum); @@ -42,7 +74,7 @@ class AddNOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(AddNOp); }; -REGISTER_XLA_OP(Name("AddN"), AddNOp); +REGISTER_XLA_OP(Name("AddN").AllowVariantTypes(), AddNOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 795ea09831e183a26fb3498b9bbaf9c3adaef9ed..5554d7a377d38554058aa731770ee10e400bc535 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -53,7 +53,11 @@ class XlaArgOp : public XlaOpKernel { const XlaExpression& arg = ctx->xla_context()->args()[index_]; OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, errors::InvalidArgument("Invalid/missing argument expression")); - ctx->SetOutputExpression(0, arg); + if (ctx->expected_output_dtype(0) == DT_VARIANT) { + ctx->SetTensorListOutput(0, arg.handle()); + } else { + ctx->SetOutputExpression(0, arg); + } } private: @@ -63,6 +67,8 @@ class XlaArgOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp); }; -REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp); +REGISTER_XLA_OP( + Name("_Arg").AllowResourceTypes().AllowVariantTypes().CompilationOnly(), + 
XlaArgOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 0e2f335f3354e3ae6008bdc0ac0b80683fe479c1..f1d78c87527eb5f818dcf92209feabe33653a625 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" @@ -34,6 +36,7 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); + is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT; } void Compile(XlaOpKernelContext* ctx) override { @@ -71,7 +74,18 @@ class FusedBatchNormOp : public XlaOpKernel { // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. ctx->SetOutput(3, xla::GetTupleElement(output, 1)); - ctx->SetOutput(4, xla::GetTupleElement(output, 2)); + if (is_on_gpu_) { + // The last two outputs from the FusedBatchNorm training TensorFlow GPU + // op are implementation defined. 
For now we rely on the in-practice + // behavior of the op: + // output 3 is the mean + // output 4 is rsqrt(variance + epsilon) + xla::XlaOp variance = xla::GetTupleElement(output, 2); + ctx->SetOutput(4, xla::Rsqrt(xla::Add( + variance, xla::ScalarLike(variance, epsilon_)))); + } else { + ctx->SetOutput(4, xla::GetTupleElement(output, 2)); + } } else { xla::XlaOp output = xla::BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), @@ -89,6 +103,7 @@ class FusedBatchNormOp : public XlaOpKernel { float epsilon_; TensorFormat data_format_; bool is_training_; + bool is_on_gpu_; }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); @@ -104,6 +119,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); + is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT; } void Compile(XlaOpKernelContext* ctx) override { @@ -130,6 +146,22 @@ class FusedBatchNormGradOp : public XlaOpKernel { xla::XlaOp scale_backprop; xla::XlaOp offset_backprop; if (is_training_) { + if (is_on_gpu_) { + // The last two inputs to the FusedBatchNormGrad training TensorFlow GPU + // op are implementation defined. For now we rely on the in-practice + // behavior of the op: input 3 is the mean and input 4 is + // rsqrt(variance + epsilon). + // + // The XLA op expects: + // input 3 is the mean + // input 4 is the variance + // + // so we adjust input 4 here.
+ xla::XlaOp one = xla::ScalarLike(var, 1.0f); + xla::XlaOp epsilon = xla::ScalarLike(var, epsilon_); + var = xla::Sub(one / (var * var), epsilon); + } + xla::XlaOp output = xla::BatchNormGrad(activations, scale, mean, var, grad_backprop, epsilon_, feature_index); @@ -158,9 +190,8 @@ class FusedBatchNormGradOp : public XlaOpKernel { offset_backprop = XlaHelpers::ConvertElementType(reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) - auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); - auto scratch1 = xla::Pow( - xla::Add(var, xla::ConstantR0(b, epsilon_)), neg_half); + auto epsilon = XlaHelpers::FloatLiteral(b, scale_dtype, epsilon_); + auto scratch1 = xla::Rsqrt(xla::Add(var, epsilon)); // scratch2 = sum(y_backprop * (x - mean)) auto mul = @@ -187,6 +218,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { TensorFormat data_format_; float epsilon_; bool is_training_; + bool is_on_gpu_; }; REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index e7f369b761f36a717ea5fb536780af91a8955b1e..33bdf9aec3167b0277f3c1db18c9e247ed9bb5d1 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -48,8 +48,11 @@ class BiasOp : public XlaOpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias_shape), errors::InvalidArgument("Biases must be 1D: ", bias_shape.DebugString())); - int feature_dim = (data_format_ == FORMAT_NHWC) ? input_shape.dims() - 1 - : input_shape.dims() - 3; + + // feature_dim is the channel (C) dimension of the data. + int feature_dim = (data_format_ == FORMAT_NHWC) + ? 
input_shape.dims() - 1 + : /*data_format == FORMAT_NCHW*/ 1; OP_REQUIRES( ctx, feature_dim >= 0, errors::InvalidArgument("Input tensor does not have enough dimensions " @@ -91,9 +94,10 @@ class BiasAddGradOp : public XlaOpKernel { errors::InvalidArgument("Input tensor must be at least 2D: ", out_backprop_shape.DebugString())); + // feature_dim is the channel (C) dimension of the data. int feature_dim = (data_format_ == FORMAT_NHWC) ? out_backprop_shape.dims() - 1 - : out_backprop_shape.dims() - 3; + : /*data_format == FORMAT_NCHW*/ 1; OP_REQUIRES( ctx, feature_dim >= 0, errors::InvalidArgument("Input tensor does not have enough dimensions " diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 5e9280c1fe692037b0a842a92ef5a8c28b854a54..f69b5dc022287825f05353349cddc6cd0f497a0e 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -20,7 +20,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -77,7 +79,28 @@ static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, XLA_MAKE_BINARY(DivNoNan, DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -// Implementation of FloorDiv. Pseudo-code: +// Implementation of MulNoNan. 
Pseudo-code: +// if (y == 0) { +// return 0 +// } else { +// return x * y; +// } +static xla::XlaOp MulNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto y_equals_0 = xla::Eq(y, zero); + auto zeros = xla::ZerosLike(x); + auto result = xla::Select(y_equals_0, zeros, xla::Mul(x, y)); + return result; +} +XLA_MAKE_BINARY(MulNoNan, + MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + +// Implementation of FloorDiv. +// +// For floating-point values, simply returns floor(x / y). For integers, does: +// // if ((x < 0) != (y < 0)) { // T abs_x = std::abs(x); // T abs_y = std::abs(y); @@ -88,6 +111,9 @@ XLA_MAKE_BINARY(DivNoNan, static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + if (DataTypeIsFloating(dtype)) { + return xla::Floor(xla::Div(x, y)); + } if (DataTypeIsUnsigned(dtype)) { return xla::Div(x, y); } @@ -97,11 +123,7 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, auto abs_x = xla::Abs(x); auto abs_y = xla::Abs(y); auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one)); - auto result = xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y)); - if (DataTypeIsFloating(dtype)) { - result = xla::Floor(result); - } - return result; + return xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y)); } XLA_MAKE_BINARY(FloorDiv, FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); @@ -157,7 +179,7 @@ XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs)))); XLA_MAKE_BINARY( RsqrtGrad, - xla::Mul(xla::Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), + xla::Mul((lhs * lhs) * lhs, xla::Div(rhs, 
XlaHelpers::IntegerLiteral(b, input_type(0), -2)), extend_dimensions)); XLA_MAKE_BINARY( @@ -165,12 +187,8 @@ XLA_MAKE_BINARY( xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), lhs, extend_dimensions)); -static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { - return xla::Mul(x, x); -} - XLA_MAKE_BINARY(SquaredDifference, - Square(b, xla::Sub(lhs, rhs, extend_dimensions))); + xla::Square(xla::Sub(lhs, rhs, extend_dimensions))); XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); @@ -195,8 +213,8 @@ XLA_MAKE_BINARY(SoftplusGrad, // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, xla::Div(lhs, - Square(b, xla::Add(XlaHelpers::One(b, input_type(0)), - xla::Abs(rhs))))); + xla::Square(xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Abs(rhs))))); XLA_MAKE_BINARY(TanhGrad, xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)), @@ -204,6 +222,8 @@ XLA_MAKE_BINARY(TanhGrad, XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(NextAfter, xla::NextAfter(lhs, rhs)); + #undef XLA_MAKE_BINARY class ApproximateEqualOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..24623768f3897179575fe4cec6190a9a877a5202 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -0,0 +1,297 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/case_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branches_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } +} + +// TODO(b/35949885): There is duplication here with the handling of the +// while_op. Refactor the common code out/rework. 
+void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { + xla::XlaBuilder* b = ctx->builder(); + int num_branches = branches_.size(); + OP_REQUIRES(ctx, num_branches >= 1, + errors::InvalidArgument("Must provide at least one case branch")); + OP_REQUIRES(ctx, input_type(0) == DT_INT32, + errors::InvalidArgument( + "branch_index argument must be a int32 for XLA compilation")); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)), + errors::InvalidArgument( + "branch_index argument must be scalar for XLA compilation")); + + VLOG(1) << "Building Case: " << input_types_.size() << " inputs"; + + std::vector arguments(input_types_.size()); + int num_resource_args = 0; + for (int i = 0; i < input_types_.size(); ++i) { + XlaCompiler::Argument& arg = arguments[i]; + DataType type = ctx->input_type(i + 1); + + if (type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); + + arg.initialized = resource->initialized(); + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = resource->kind(); + + arg.type = resource->type(); + arg.shape = resource->shape(); + OP_REQUIRES(ctx, arg.initialized, + errors::Unimplemented("Uninitialized arguments: ", arg.name)); + arg.max_array_size = resource->max_array_size(); + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + arg.name = resource->name(); + VLOG(2) << "Resource " << resource->name() + << " type: " << DataTypeString(arg.type) + << " shape: " << arg.HumanString() + << " initialized: " << arg.initialized; + + num_resource_args++; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = input_types_[i]; + arg.shape = ctx->InputShape(i + 1); + VLOG(2) << "Arg type: " << DataTypeString(arg.type) + << " shape: " << arg.HumanString(); + } + } + + // Compile each branch of the conditional. 
+ XlaCompiler::CompileOptions options; + options.use_tuple_arg = true; + options.resolve_compile_time_constants = false; + options.return_updated_values_for_all_resources = true; + options.is_entry_computation = false; + options.add_token_input_output = has_token_input_output_; + XlaCompiler* compiler = ctx->compiler(); + + std::vector branch_results(num_branches); + std::vector branch_results_p(num_branches); + for (int j = 0; j < num_branches; ++j) { + OP_REQUIRES_OK(ctx, + compiler->CompileFunction(options, branches_[j], arguments, + &branch_results[j])); + branch_results_p[j] = &branch_results[j]; + } + + bool has_tensor_array_gradients = false; + for (XlaCompiler::CompilationResult* result : branch_results_p) { + for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + XlaCompiler::Argument& arg = arguments[update.input_index]; + + // Add any TensorArray gradients touched by the then/else computation to + // the enclosing graph. + for (const string& grad_source : update.tensor_array_gradients_accessed) { + VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " + << grad_source; + XlaResource* gradient; + OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( + grad_source, b, &gradient)); + } + // Add all of the TensorArray gradients to the argument. For simplicity, + // we always pass all known gradients. + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + if (!resource->tensor_array_gradients().empty()) { + has_tensor_array_gradients = true; + } + } + } + + // Recompile the functions to update the argument shapes for tensor arrays. 
+ if (has_tensor_array_gradients) { + for (int j = 0; j < num_branches; ++j) { + branch_results[j] = {}; + OP_REQUIRES_OK(ctx, + compiler->CompileFunction(options, branches_[j], arguments, + &branch_results[j])); + } + } + + xla::Shape branch0_input_shape; + std::vector result_computations(num_branches); + for (int j = 0; j < num_branches; ++j) { + // Check that all branches have identical input shapes. + OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape branch_input_shape = branch_results[j].xla_input_shapes[0]; + if (j == 0) { + branch0_input_shape = branch_input_shape; + } + OP_REQUIRES(ctx, branch_input_shape.IsTuple(), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + OP_REQUIRES( + ctx, + xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape), + errors::InvalidArgument( + "Input shapes of 0 and ", j, " branches do not match: ", + xla::ShapeUtil::HumanString(branch0_input_shape), " vs. ", + xla::ShapeUtil::HumanString(branch_input_shape))); + + // Check that all branches have identical output shapes. + OP_REQUIRES( + ctx, + xla::ShapeUtil::Compatible(branch_results[0].xla_output_shape, + branch_results[j].xla_output_shape), + errors::InvalidArgument( + "Output shapes of 0 and ", j, " branches do not match: ", + xla::ShapeUtil::HumanString(branch_results[0].xla_output_shape), + " vs. ", + xla::ShapeUtil::HumanString(branch_results[j].xla_output_shape))); + + if (j == 0) { + VLOG(2) << "Input shape: " + << xla::ShapeUtil::HumanString(branch0_input_shape); + VLOG(2) << "Output shape: " + << xla::ShapeUtil::HumanString( + branch_results[0].xla_output_shape); + } + + // We set return_updated_values_for_all_resources=true and we pass the same + // arguments to both computations, so the resource update count must match. 
+ OP_REQUIRES(ctx, + branch_results[0].resource_updates.size() == + branch_results[j].resource_updates.size(), + errors::FailedPrecondition( + "Different number of resources in 0 and ", j, " branch")); + for (int i = 0; i < branch_results[0].resource_updates.size(); ++i) { + const auto& lhs = branch_results[0].resource_updates[i]; + const auto& rhs = branch_results[j].resource_updates[i]; + bool equal = lhs.input_index == rhs.input_index && + lhs.shape == rhs.shape && + lhs.tensor_array_gradients_accessed == + rhs.tensor_array_gradients_accessed; + OP_REQUIRES(ctx, equal, + errors::FailedPrecondition("Mismatch in resource of 0 and ", + j, " branch for resource ", i)); + } + result_computations[j] = branch_results[j].computation.get(); + } + + // Prepare the input arg Tuple. + int num_inputs = branch_results[0].input_mapping.size(); + std::vector inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = branch_results[0].input_mapping[i] + 1; + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "case" op. + std::vector token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(b, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + } else { + inputs[i] = ctx->Input(i + 1); + } + } + auto input_tuple = xla::Tuple(b, inputs); + + xla::XlaOp outputs = + xla::Conditional(ctx->Input(0), absl::MakeSpan(result_computations), + std::vector(num_branches, input_tuple)); + // Sets non-variable outputs. 
+ for (int i = 0; i < output_types_.size(); ++i) { + xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Setting output " << i; + auto shape_or = b->GetShape(output_handle); + if (shape_or.ok()) { + LOG(INFO) << "Shape for output " << i << ": " + << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); + } else { + LOG(INFO) << "Shape unknown for output " << i; + } + } + ctx->SetOutput(i, output_handle); + } + if (has_token_input_output_) { + // Set token output for this "Case" op. Token output is the last output of + // XLA computation, which comes after all "normal" TF outputs and resource + // updates. For "Case" node, num of resource updates equals to number of + // resource args because we set `return_updated_values_for_all_resources` + // to true in XlaCompiler option. + xla::XlaOp token_output = + xla::GetTupleElement(outputs, output_types_.size() + num_resource_args); + auto shape_or = b->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } + + // Updates the values of any resource variables modified by the conditional + // bodies. 
+ for (const XlaCompiler::CompilationResult& result : branch_results) { + for (int i = 0; i < result.resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = result.resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + if (update.modified) { + int pos = static_cast(result.outputs.size()) + i; + OP_REQUIRES_OK(ctx, + resource->SetFromPack( + arguments[update.input_index].tensor_array_gradients, + xla::GetTupleElement(outputs, pos), b)); + } + VLOG(2) << "Case variable: pos: " << update.input_index + << " name: " << resource->name() + << " modified: " << update.modified + << " type: " << DataTypeString(update.type) + << " shape: " << update.shape.DebugString(); + } + } + VLOG(1) << "Done building Case"; +} + +REGISTER_XLA_OP(Name("Case").AllowResourceTypes(), XlaCaseOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.h b/tensorflow/compiler/tf2xla/kernels/case_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ea14b18149cb5bc9162d42b384eb3a5e943ad8be --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/case_op.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional switch/case primitive. +// +// The outputs of the branches must agree on the number, types, and +// shapes of the Tensors carried around the two bodies. +// +// Computations in branch bodies may read from and write to resource variables. +// Resource variables may be passed as arguments to the branch function's +// bodies. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_variables` +// we ensure that all variables that appear in the input also appear at the +// end of the branch bodies output. This ensures the branch bodies output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. 
+class XlaCaseOp : public XlaOpKernel { + public: + explicit XlaCaseOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaCaseOp); + + std::vector branches_; + DataTypeVector input_types_; + DataTypeVector output_types_; + bool has_token_input_output_; + std::vector token_input_nodes_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index db58c2e651f0acf0dd4330575b5ce63fd0a1fb69..ca2152d6c103e05c06809d85d9529720ff112217 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -36,19 +36,6 @@ class CastOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_)); } - xla::PrimitiveType GetUnsignedIntTypeOfSameWidth(int64 src_bitwidth) { - switch (src_bitwidth) { - case 16: - return xla::U16; - case 32: - return xla::U32; - case 64: - return xla::U64; - default: - return xla::PRIMITIVE_TYPE_INVALID; - } - } - void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* builder = ctx->builder(); xla::XlaOp input = ctx->Input(0); @@ -85,7 +72,7 @@ class CastOp : public XlaOpKernel { // source datatype. 
int64 mask = ~((1L << mantissa_difference) - 1); xla::PrimitiveType same_width_int = - GetUnsignedIntTypeOfSameWidth(src_bitwidth); + xla::primitive_util::UnsignedIntegralTypeForBitWidth(src_bitwidth); OP_REQUIRES(ctx, same_width_int != xla::PRIMITIVE_TYPE_INVALID, errors::Unimplemented("Unexpected type bitwidth")); input = xla::BitcastConvertType( diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index c2b4c28d1566f5429c5d8109db94af0c3762b131..a99c6ee4431852166eec0a71bb7ad74fd5c135d9 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -112,9 +113,12 @@ class CategoricalOp : public XlaOpKernel { xla::PrimitiveType type, XlaOpKernelContext* ctx) { xla::XlaBuilder* builder = ctx->builder(); - auto uniforms = - xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + // We want a number in (0, 1) rather than [0, 1) or (0, 1]: + // * log(-log(0)) is ∞. + // * log(-log(1)) is -∞. + auto uniforms = xla::RngUniform( + xla::MinPositiveNormalValue(builder, type), + xla::One(builder, uniform_shape.element_type()), uniform_shape); return xla::Log(-xla::Log(uniforms)); } @@ -143,9 +147,13 @@ class StatelessCategoricalOp : public CategoricalOp { if (uniform_shape.element_type() == xla::BF16) { uniform_shape.set_element_type(xla::F32); } + // We want a number in (0, 1) rather than [0, 1) or (0, 1]: + // * log(-log(0)) is ∞. 
+ // * log(-log(1)) is -∞. auto uniforms = xla::StatelessRngUniform( - {seed0, seed1}, uniform_shape, XlaHelpers::Zero(builder, DT_FLOAT), - XlaHelpers::One(builder, DT_FLOAT)); + {seed0, seed1}, uniform_shape, + xla::MinPositiveNormalValue(builder, uniform_shape.element_type()), + xla::One(builder, uniform_shape.element_type())); return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); } diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 0ed3044efa5b1060d2b0ad2d5563b0e02ebf66ec..e6b30a38e0379fc09af07af686f4c5f3a737ecda 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/cholesky.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { @@ -24,7 +25,9 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, xla::Cholesky(ctx->Input(0))); + ctx->SetOutput(0, + xla::Triangle(xla::Cholesky(ctx->Input(0), /*lower=*/true), + /*lower=*/true)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index cd7c7f4a82df7a65829787efcb1fd2f77870e945..09c97de13eb2ed951ca705cda89b7f293808cdf0 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -24,14 +24,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 5b4f863f7418ecda0db502ce25fed2d0042bf3ca..6512ba25ce63a80b89f6148fce5444a7c0fee925 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -26,15 +26,15 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" -#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -203,7 +203,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, absl::Span dilations, const std::vector& strides, - Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) { + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims, + absl::Span explicit_paddings) { TensorShape input_tensor_shape, filter_tensor_shape, out_backprop_tensor_shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); @@ -212,8 +213,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); return ConvBackpropComputeDimensionsV2( label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, - out_backprop_tensor_shape, dilations, strides, padding, - /*explicit_paddings=*/{}, data_format, dims); + out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings, + data_format, dims); } } // anonymous namespace @@ -227,10 +228,9 @@ xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, 
TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); - // TODO(reedwm): Support explicit padding. if (attrs.padding == EXPLICIT) { - return errors::Unimplemented( - "XLA does not yet support Conv2D with explicit padding."); + TF_RETURN_IF_ERROR( + ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings)); } string data_format; @@ -303,6 +303,11 @@ xla::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, window_strides[i] = attrs.strides.at(dim); rhs_dilation[i] = attrs.dilations.at(dim); + if (attrs.padding == EXPLICIT) { + padding[i] = {attrs.explicit_paddings.at(dim * 2), + attrs.explicit_paddings.at(dim * 2 + 1)}; + } + int64 unused_output_size; TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( input_shape.dimensions(dim), filter_shape.dimensions(i), @@ -337,7 +342,7 @@ xla::StatusOr MakeXlaBackpropInputConvOp( TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding, - attrs.data_format, &dims)); + attrs.data_format, &dims, attrs.explicit_paddings)); // The input gradients are computed by a convolution of the output // gradients and the filter, with some appropriate padding. See the @@ -420,7 +425,7 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( type_string, attrs.num_spatial_dims, activations_shape, expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, - attrs.padding, attrs.data_format, &dims)); + attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings)); // The activations (inputs) form the LHS of the convolution. 
// Activations have shape: [batch, in_rows, in_cols, ..., in_depth] @@ -441,12 +446,6 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( std::vector window_strides(attrs.num_spatial_dims); std::vector ones(attrs.num_spatial_dims, 1); - // The activations (inputs) form the LHS of the convolution. - // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] - // For the gradient computation, we flip the roles of the batch and - // feature dimensions. - // Each spatial entry has size in_depth * batch - // Swap n_dim and c_dim in the activations. dnums.set_input_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); @@ -475,12 +474,14 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); + rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = attrs.dilations[dim]; // We will also need to pad the input with zeros such that after the // convolution, we get the right size for the filter. // The padded_in_rows should be such that when we convolve this with the // expanded_out_rows as a filter, we should get filter_rows back. - // + const int64 padded_in_size = dims.spatial_dims[i].expanded_output_size + (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim]; @@ -501,6 +502,8 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // We apply negative padding in this case. const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; + // + For the EXPLICIT padding, we pad the top/left side with the explicit + // padding and pad the bottom/right side with the remaining space. // + For the VALID padding, we don't pad anything on the top/left side // and pad the bottom/right side with the remaining space. 
// + For the SAME padding, we pad top/left side the same as bottom/right @@ -509,12 +512,12 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // In addition, if the padded input size is smaller than the input size, // we need to ignore some training elements of the input. We do this by // applying negative padding on the right/bottom. - const int64 pad_before = - attrs.padding == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - + const int64 pad_before = attrs.padding == Padding::EXPLICIT + ? attrs.explicit_paddings[2 * dim] + : attrs.padding == Padding::SAME + ? std::max(pad_total / 2, 0) + : 0; padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = attrs.dilations[dim]; } // Besides padding the input, we will also expand output_rows to diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index 6e1b70a47850ae5c05939f8dfb7ec129c031df21..d893eca7f9ba07dded76eb215af4779080fa66b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -47,6 +47,7 @@ struct ConvOpAttrs { std::vector dilations; std::vector strides; Padding padding; + std::vector explicit_paddings; TensorFormat data_format; }; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index eafdba876ae9e2c38694f065cf83bb3725b8460e..e74ada47517d1ee98cad5d8523872bbb4eab6e7e 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -25,15 +25,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/conv_grad_ops.h" -#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index e96a1adce43c750314715107b4a1954d4a5b4e40..9fe91d16d77d601f94fe35b48cbe97452d7a6a72 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -30,11 +31,6 @@ class DepthToSpaceOp : public XlaOpKernel { OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); - OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, - errors::InvalidArgument("Unsupported data format ", - ToString(data_format_), - "; expected formats NHWC or NCHW")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -42,19 +38,36 @@ class DepthToSpaceOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_tensor_shape = ctx->InputShape(0); - int input_rank = input_tensor_shape.dims(); + xla::XlaOp input = ctx->Input(0); + + TensorFormat data_format = data_format_; + // If the data is in a vectorized format, reformat it into a non-vectorized + // version first. We'll undo the transformation later. 
+ if (data_format == FORMAT_NCHW_VECT_C) { + data_format = FORMAT_NCHW; + auto input_reshaped = NCHW_VECT_CToNCHW(input); + OP_REQUIRES_OK(ctx, input_reshaped.status()); + input = input_reshaped.ValueOrDie(); + } + + OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_))); + + xla::XlaBuilder* builder = input.builder(); + auto input_xla_shape = builder->GetShape(input); + OP_REQUIRES_OK(ctx, input_xla_shape.status()); + const std::vector& input_shape = + input_xla_shape.ValueOrDie().dimensions(); + int input_rank = input_shape.size(); + static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got: ", input_rank)); - const absl::InlinedVector input_shape = - input_tensor_shape.dim_sizes(); - - xla::XlaOp input = ctx->Input(0); - int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); - int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format); std::vector reshaped_shape; std::vector transpose_order; @@ -62,7 +75,7 @@ class DepthToSpaceOp : public XlaOpKernel { reshaped_shape.reserve(input_rank); transpose_order.reserve(input_rank); output_shape.reserve(input_rank); - if (data_format_ == FORMAT_NHWC) { + if (data_format == FORMAT_NHWC) { reshaped_shape.push_back(input_shape[0]); for (int i = 0; i < num_spatial_dims; ++i) { reshaped_shape.push_back(input_shape[1 + i]); @@ -153,6 +166,14 @@ class DepthToSpaceOp : public XlaOpKernel { // xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); + // If this used to be a vectorized format turn it back now. 
+ if (data_format != data_format_) { + DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C); + auto output_reshaped = NCHWToNCHW_VECT_C(output); + OP_REQUIRES_OK(ctx, output_reshaped.status()); + output = output_reshaped.ValueOrDie(); + } + ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 6e6ba21daf5bf3eab5bfc15378e77b6dd253da7c..b119997cf39e210ed8e0ae730a08829e72b238b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b3334dc1de826d4946eb362223d4428858b23f0 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +constexpr std::array kEinsumTypes = {{DT_BFLOAT16, DT_FLOAT}}; + +class EinsumOp : public XlaOpKernel { + public: + explicit EinsumOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("equation", &equation_)); + } + + ~EinsumOp() override = default; + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp lhs = ctx->Input(0); + xla::XlaOp rhs = ctx->Input(1); + const TensorShape a_shape = ctx->InputShape(0); + const TensorShape b_shape = ctx->InputShape(1); + ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_)); + } + + private: + string equation_; + TF_DISALLOW_COPY_AND_ASSIGN(EinsumOp); +}; + +REGISTER_XLA_OP(Name("XlaEinsum").TypeConstraint("T", kEinsumTypes), EinsumOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 5fdb1d972c55efb876972d3f472b53a1f7cde1c2..87bb9d49c0c97181bac33da01ec7e0b10cf5d6fc 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/no_op.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/empty_op.cc b/tensorflow/compiler/tf2xla/kernels/empty_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..00d2ce7c12fdc96483612059d1c792c847df04f3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/empty_op.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Empty Op. 
+ +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class EmptyOp : public XlaOpKernel { + public: + explicit EmptyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("init", &init_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // The output of this Op is a tensor of shape 'shape' with each + // element set to the default value of 'dtype'. If 'init' is false then + // the result values may be left undefined, though we don't do that here. 
+ const TensorShape shape_shape = ctx->InputShape("shape"); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(shape_shape), + errors::InvalidArgument("shape must be a vector of int32, got shape ", + shape_shape.DebugString())); + + std::vector shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("shape", &shape)); + + auto default_value = xla::Zero(ctx->builder(), type_); + auto result = xla::Broadcast(default_value, shape); + ctx->SetOutput(0, result); + } + + private: + DataType dtype_; + xla::PrimitiveType type_; + bool init_; +}; + +REGISTER_XLA_OP(Name("Empty").CompileTimeConstantInput("shape"), EmptyOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 6df8b5367d2390e65995beb1583b225755e6ee9f..5ac288d8a346b0119892d804941608a286d7b721 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -21,14 +21,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/conv_grad_ops.h" -#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -68,6 +67,13 @@ class GenericFftOp : public XlaOpKernel { } for (int i = 0; i < fft_rank_; i++) { int index = input_shape.dims() - fft_rank_ + i; + OP_REQUIRES( + ctx, + input_shape.dim_size(index) == 0 || + input_shape.dim_size(index) >= expected_sizes[i], + errors::InvalidArgument( + "Input dimension ", index, " must have length of at least ", + expected_sizes[i], " but got: ", input_shape.dim_size(index))); if (input_shape.dim_size(index) > expected_sizes[i]) { slice_sizes[index] = expected_sizes[i]; } else { diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index af1085d5b35077b7ebd144bfb2473485e3b3de6b..a4a786e8a0608f3deeee7a0bd0bca5d631d461ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/kernels/partitioned_function_ops.h" namespace tensorflow { namespace { @@ -107,6 +108,10 @@ class SymbolicGradientOp : public AsyncOpKernel { }; REGISTER_XLA_OP(Name(kGradientOp), SymbolicGradientOp); +REGISTER_XLA_OP(Name("PartitionedCall").AllowResourceTypes(), + PartitionedCallOp); +REGISTER_XLA_OP(Name("StatefulPartitionedCall").AllowResourceTypes(), + PartitionedCallOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 41c31d0ed58fe9bc9bbde0bd58993c975f04fd60..6472045265e4d930a5da770a68f5c502192201ae 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -167,13 +167,13 @@ class GatherOp : public XlaOpKernel { OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); const auto params_dims = input_shape.dims(); - if (axis < 0) { - axis += params_dims; - } OP_REQUIRES( - context, 0 <= axis && axis < params_dims, + context, -params_dims <= axis && axis < params_dims, errors::InvalidArgument("Expected axis in the range [", -params_dims, ", ", params_dims, "), but got ", axis)); + if (axis < 0) { + axis += params_dims; + } } DataType index_type = input_type(1); diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 19dd38c46ef154ea74bcbb6721dd04924702efcc..8b27e8e85a37bd5aa757b0cdd7e00e9fa3c0cf6e 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -38,9 +38,13 @@ class IdentityOp : public XlaOpKernel { // XLA_* devices also register a "real" Identity operator so we suppress the // dummy operator using CompilationOnly(). 
-REGISTER_XLA_OP(Name("Identity").AllowResourceTypes().CompilationOnly(), - IdentityOp); -REGISTER_XLA_OP(Name("IdentityN").AllowResourceTypes().CompilationOnly(), +REGISTER_XLA_OP( + Name("Identity").AllowResourceTypes().AllowVariantTypes().CompilationOnly(), + IdentityOp); +REGISTER_XLA_OP(Name("IdentityN") + .AllowResourceTypes() + .AllowVariantTypes() + .CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 954ae0b596f33243fad1374473c689adb580f6a4..aa5637e2669555da17af8bb05ab08beeba6a89c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -80,7 +80,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.name = resource->name(); VLOG(2) << "Resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString() + << " shape: " << arg.HumanString() << " initialized: " << arg.initialized; num_resource_args++; @@ -89,7 +89,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.type = input_types_[i]; arg.shape = ctx->InputShape(i + 1); VLOG(2) << "Arg type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString(); + << " shape: " << arg.HumanString(); } } diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index b96d45316f626e678a64392a4315979eeeb6e83c..d19d48e5dd95962fe4a4e4026eaf6b06b7898564 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -134,14 +135,15 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. -xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, int64 n) { +xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, int64 n) { std::vector kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; kernel[i] = v; kernel[n * 2 - 2 - i] = v; } - return xla::ConstantR1(builder, kernel); + return xla::ConvertElementType(xla::ConstantR1(builder, kernel), type); } // Unlike the bilinear kernel, which is triangular, the nearest neighbor @@ -153,11 +155,12 @@ xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, int64 n) { // to the right (because an existing non TPU kernel // for nearest neighbor resize already chose to default to the right, // so we want to be consistent). 
-xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, int64 n) { +xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, int64 n) { std::vector kernel(n * 2 - 1, 0.0f); std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f); - return xla::ConstantR1(builder, kernel); + return xla::ConvertElementType(xla::ConstantR1(builder, kernel), type); } // Kernels with more than 16 spatial elements are considered intense and the @@ -165,42 +168,66 @@ xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, int64 n) { const int64 kMax2DKernelSize = 16; xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, absl::Span kernel_size, int64 channels, bool is_kernel_bilinear) { auto make_kernel_func = is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; - auto depthwise_kernel = xla::Broadcast( - xla::Zero(builder, xla::F32), - {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); + std::vector depthwise_kernel_sizes = { + (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}; + auto depthwise_kernel = + xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]), + depthwise_kernel_sizes, /*broadcast_dimensions=*/{1}); - return xla::Mul( - xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[1]), - /*broadcast_dimensions=*/{1}), - make_kernel_func(builder, kernel_size[0]), - /*broadcast_dimensions=*/{0}); + return xla::Mul(depthwise_kernel, + make_kernel_func(builder, type, kernel_size[0]), + /*broadcast_dimensions=*/{0}); } xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder, + xla::PrimitiveType type, absl::Span kernel_size, int64 channels, int64 dim, bool is_kernel_bilinear) { auto make_kernel_func = is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; - auto depthwise_kernel = - xla::Broadcast(xla::Zero(builder, xla::F32), - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? 
(2 * kernel_size[1] - 1) : 1, channels, 1}); - return xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[dim]), - /*broadcast_dimensions=*/{dim}); + std::vector depthwise_kernel_sizes = { + dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1}; + return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]), + depthwise_kernel_sizes, + /*broadcast_dimensions=*/{dim}); +} + +xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder, + const xla::XlaOp& input, + int32 spatial_dimensions_offset, + absl::Span in_size, + absl::Span out_size) { + // Add broadcasts to handle expanding from a size == 1 dimension to a + // size > 1 dimension. + auto broadcast_shape_or_status = builder->GetShape(input); + if (!broadcast_shape_or_status.ok()) { + return builder->ReportError(broadcast_shape_or_status.status()); + } + xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie(); + for (int32 i = 0; i < in_size.size(); ++i) { + if (in_size[i] == 1 && out_size[i] > 1) { + broadcast_shape.set_dimensions(spatial_dimensions_offset + i, + out_size[i]); + } + } + return xla::BroadcastInDim(input, broadcast_shape.dimensions(), + /*broadcast_dimensions=*/{0, 1, 2, 3}); } xla::XlaOp ResizeUsingDilationAndConvolution( - xla::XlaBuilder* builder, const xla::XlaOp& input, - const int num_spatial_dims, std::vector in_size, - std::vector out_size, const int64 channels, const bool align_corners, - bool is_kernel_bilinear) { + xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type, + const int num_spatial_dims, absl::Span in_size, + absl::Span out_size, const int64 channels, + const bool align_corners, bool is_kernel_bilinear) { // Picture for a 1x3 to 1x4 bilinear resize: // stride = 2, kernel size = 3 // Input: @@ -287,7 +314,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( // Split convolutions into independent dimensions if they would be a very // large kernel. 
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size, channels, is_kernel_bilinear); output = xla::ConvGeneralDilated(input_data, kernel, dims.stride, @@ -299,7 +326,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 0, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear); output = xla::ConvGeneralDilated( input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ @@ -308,7 +335,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 1, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear); output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ @@ -320,19 +347,14 @@ xla::XlaOp ResizeUsingDilationAndConvolution( // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. 
- for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && out_size[i] > 1) { - output = xla::Add(output, xla::ConstantR1(builder, out_size[i], 0), - /*broadcast_dimensions=*/{1 + i}); - } - } - return output; + return BroadcastSpatialDimensions( + builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size); } xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( - xla::XlaBuilder* builder, const xla::XlaOp& grad, - const int num_spatial_dims, std::vector in_size, - std::vector grad_size, const int64 channels, + xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type, + const int num_spatial_dims, absl::Span in_size, + absl::Span grad_size, const int64 channels, const bool align_corners, bool is_kernel_bilinear) { ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); @@ -353,19 +375,14 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size, channels, is_kernel_bilinear); // Broadcast the input kernel where the forward op expanded from a size == 1 // dimension to a size > 1 dimension. This has the effect of summing the // gradient contributions in that dimension. 
- for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = - xla::Add(kernel, xla::ConstantR1(builder, grad_size[i], 0), - /*broadcast_dimensions=*/{i}); - } - } + kernel = BroadcastSpatialDimensions( + builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size); output = xla::ConvGeneralDilated( grad, kernel, /*window_strides=*/dims.kernel_size, @@ -377,22 +394,22 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 0, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear); xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 1, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear); // Broadcast the input kernel where the forward op expanded from a // size == 1 dimension to a size > 1 dimension. This has the effect of // summing the gradient contributions in that dimension. 
if (in_size[0] == 1 && grad_size[0] > 1) { - kernel0 = - xla::Add(kernel0, xla::ConstantR1(builder, grad_size[0], 0), - /*broadcast_dimensions=*/{0}); + kernel0 = BroadcastSpatialDimensions(builder, kernel0, + /*spatial_dimensions_offset=*/0, {1}, + {grad_size[0]}); } if (in_size[1] == 1 && grad_size[1] > 1) { - kernel1 = - xla::Add(kernel0, xla::ConstantR1(builder, grad_size[1], 0), - /*broadcast_dimensions=*/{1}); + kernel1 = BroadcastSpatialDimensions(builder, kernel0, + /*spatial_dimensions_offset=*/0, + in_size, grad_size); } output = xla::ConvGeneralDilated( @@ -423,7 +440,7 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( } } if (pad_output) { - output = xla::Pad(output, xla::ConstantR0(builder, 0.0f), padding); + output = xla::Pad(output, xla::Zero(builder, type), padding); } return output; } @@ -458,6 +475,7 @@ void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, const int num_spatial_dims = 2; xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. @@ -475,8 +493,11 @@ void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); } - // Output is always type float. - input = xla::ConvertElementType(input, xla::F32); + // Output is always type float if 'is_kernel_bilinear' is true. 
+ if (is_kernel_bilinear) { + input = xla::ConvertElementType(input, xla::F32); + input_type = xla::F32; + } // Special Case: // Instead of doing a ResizeUsingDilationAndConvolution directly, @@ -504,19 +525,19 @@ void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, next_out_size, channels, - align_corners_, is_kernel_bilinear); + b, input, input_type, num_spatial_dims, in_size, next_out_size, + channels, align_corners_, is_kernel_bilinear); input = output; in_size = next_out_size; } else { output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels, + b, input, input_type, num_spatial_dims, in_size, out_size, channels, align_corners_, is_kernel_bilinear); in_size = out_size; } } else { output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels, + b, input, input_type, num_spatial_dims, in_size, out_size, channels, align_corners_, is_kernel_bilinear); in_size = out_size; } @@ -631,19 +652,19 @@ class ResizeBilinearGradOp : public XlaOpKernel { std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, next_grad_size, channels, - align_corners_, true); + b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size, + channels, align_corners_, true); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels, + b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels, align_corners_, true); in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels, + b, grad, xla::F32, num_spatial_dims, in_size, grad_size, 
channels, align_corners_, true); in_size = grad_size; } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 978e9480eac5b522d1ee2d51a61841c6f1bbba0c..c1539f48d4f729510b2d930de91666a7c31f1ef0 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -22,12 +22,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min) diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 30b993045c86c6d01f8eabe55986f132f8938643..e4bbdef6480104a1051acfc647644deb65c80171 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -20,12 +20,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 47cf8c6675bc120653c2a5ab6d4b07376dc382ee..39d96e748b3a2a852c03c0dd53ec175f0c66a43a 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -25,9 +25,6 @@ limitations under the License. namespace tensorflow { EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*)); - float* input = static_cast(data[0]); int64 input_size = *static_cast(data[1]); diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 93f029731c34e84000a3dc00df8af05654cccf2d..7f25d34c3ef82e5360fd2d7c1cd12dd8c6f40507 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/no_op.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 90c0ebefb24ec2c4378782e9b15d3f57c33032a4..5a6569c8954d1686dc9d7577a66feb720241ea13 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { namespace { @@ -31,7 +32,10 @@ class MatrixTriangularSolveOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { auto result = xla::TriangularSolve( ctx->Input(0), ctx->Input(1), /*left_side=*/true, - /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); + /*lower=*/lower_, /*unit_diagonal=*/false, + /*transpose_a=*/ + adjoint_ ? xla::TriangularSolveOptions::ADJOINT + : xla::TriangularSolveOptions::NO_TRANSPOSE); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index 65ab9da8d7ca0509a4a69c43727a0e6c0435908a..da50b75251beb2f97400cc7d2ffb5f4d05a3fb6e 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -13,12 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/kernels/no_op.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { +namespace { + +class NoOp : public OpKernel { + public: + explicit NoOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} + bool IsExpensive() override { return false; } +}; + +} // namespace + // XLA_* devices also register a "real" NoOp operator so we suppress the // dummy operator using CompilationOnly(). REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp); diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index a9b519d8928cc2807831fd6b4f12e60b7d58ea55..6ca100a2f2bf90e1d61829aa45a44cbc97090ed1 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -24,13 +24,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 06c6cc37ec90192486ba15010bfeb763a9ffb987..85223795aa8da93964efc0252eb34df3ebb6df3f 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -26,11 +26,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/pooling_ops_common.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 01b047f732f0e9fb3b45b272e7886e2f8cf4fff4..d6c70d4af1c2e921b70b0869f0163c8481017c7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -279,9 +280,9 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + xla::XlaOp one = xla::One(b, xla_shape.element_type()); xla::XlaOp min_positive = - XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits::min()); + xla::MinPositiveNormalValue(b, xla_shape.element_type()); auto uniform = xla::RngUniform(min_positive, one, xla_shape); ctx->SetOutput(0, TruncatedNormal(uniform)); } diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h new file mode 100644 index 0000000000000000000000000000000000000000..d107be6f13c48e26a4ba67fefa641c6ce811aa80 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise. +// It masks the last 16 bit. With normal rounding, values near "maxval" would be +// converted to "maxval" which is out of range ["minval", "maxval"). In +// addition, the distribution near the limit is not uniform. +xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index e4046c795577983bff1a8053743bf4d3a258e583..1f417037284c87753b219ea5ce1d4edce0ce6336 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -37,10 +37,14 @@ class RetvalOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const Tensor& input = ctx->op_kernel_context()->input(0); - OP_REQUIRES(ctx, input.dtype() == dtype_, - errors::InvalidArgument( - "Type mismatch: actual ", DataTypeString(input.dtype()), - " vs. expect ", DataTypeString(dtype_))); + // DT_VARIANT types represent Tensor Lists and are wrapped in a DT_UINT8 + // tensor so we skip the check here. + if (dtype_ != DT_VARIANT) { + OP_REQUIRES(ctx, input.dtype() == dtype_, + errors::InvalidArgument( + "Type mismatch: actual ", DataTypeString(input.dtype()), + " vs. 
expect ", DataTypeString(dtype_))); + } auto frame = ctx->call_frame(); if (frame) { // If 'frame' is non-null, this is an inner function call inside a JIT @@ -59,8 +63,9 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(), - RetvalOp); +REGISTER_XLA_OP( + Name("_Retval").AllowResourceTypes().AllowVariantTypes().CompilationOnly(), + RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index d7b38e86cc985d608116488f9e76756a8e904f9c..4d73469fb1858a252906bca190402ab8743e4cfb 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -67,84 +67,59 @@ class ReverseSequenceOp : public XlaOpKernel { return; } - // Given the input - // - // 012345 - // 6789AB - // - // and sequence lens {2, 3} we: - // - // 1. Reverse and pad each row to get - // - // 543210XXXXXX - // BA9876XXXXXX - // - // 2. Gather out the suffix from each row to get - // - // 10XXXX - // 876XXX - // - // 3. Select from the input and the array created by (2) to get the result. - // - // 102345 - // 8769AB - const xla::PrimitiveType input_type = context->input_xla_type(0); const xla::PrimitiveType seq_lens_type = context->input_xla_type(1); const int64 max_seq_len = input_shape.dim_size(seq_dim_); - xla::XlaOp rev = xla::Rev(input, {seq_dim_}); - - auto padding_config = xla::MakeNoPaddingConfig(input_shape.dims()); - padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( - max_seq_len); - xla::XlaOp padded = - xla::Pad(rev, xla::Zero(builder, input_type), padding_config); - - // Form a start indices tensor with shape [2, batch_size]. For each batch - // entry we have a (batch offset, seq offset) pair. 
- xla::XlaOp start_indices = xla::ConcatInDim( + // Create [batch, sequence, 2] tensor that contains the indices where the + // real data belongs + xla::XlaOp back = xla::Sub(seq_lens, xla::ScalarLike(seq_lens, 1)); + xla::XlaOp batch_idx = xla::Iota( builder, - { - xla::Iota(builder, - xla::ShapeUtil::MakeShape(seq_lens_type, {1, batch_size}), - /*iota_dimension=*/1), - xla::Reshape(xla::ScalarLike(seq_lens, max_seq_len) - seq_lens, - {1, batch_size}), - }, - /*dimension=*/0); + xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}), + /*iota_dimension=*/0); + xla::XlaOp forward_idx = xla::Iota( + builder, + xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}), + /*iota_dimension=*/1); + xla::XlaOp reverse_idx = xla::Sub(back, forward_idx, {0}); + reverse_idx = xla::Select(xla::Lt(reverse_idx, xla::ZerosLike(reverse_idx)), + forward_idx, reverse_idx); + if (batch_dim_ > seq_dim_) { + // The output of the XLA gather op keeps indices dimensions in the same + // order as they appear in the input. If the batch_dim_ needs to be after + // the seq_dim_ in the output, it also needs to be that way in the input + // so we transpose. + batch_idx = xla::Transpose(batch_idx, {1, 0, 2}); + forward_idx = xla::Transpose(forward_idx, {1, 0, 2}); + reverse_idx = xla::Transpose(reverse_idx, {1, 0, 2}); + } + xla::XlaOp start_indices = + xla::ConcatInDim(builder, {batch_idx, reverse_idx}, + /*dimension=*/2); xla::GatherDimensionNumbers dnums; - // The first dimension of start_indices contains the batch/seq dim choice. - dnums.set_index_vector_dim(0); + dnums.set_index_vector_dim(2); + // The first and second element in the third dimension of reverse_idx are + // the batch_dim_ offset and the seq_dim_ offset respectively. dnums.add_start_index_map(batch_dim_); dnums.add_start_index_map(seq_dim_); - // All other dimensions other than the batch dim are offset dimensions. 
+ // batch_dim_ and seq_dim_ are collapsed and the other dimensions are kept + // in the gather. for (int i = 0; i < input_shape.dims(); ++i) { - if (i != batch_dim_) { + if (i != batch_dim_ && i != seq_dim_) { dnums.add_offset_dims(i); + } else { + dnums.add_collapsed_slice_dims(i); } } - dnums.add_collapsed_slice_dims(batch_dim_); auto slice_sizes = input_shape.dim_sizes(); slice_sizes[batch_dim_] = 1; + slice_sizes[seq_dim_] = 1; - xla::XlaOp output = xla::Gather(padded, start_indices, dnums, slice_sizes); - - // Mask out elements after the sequence length, and copy the corresponding - // elements from the input. - xla::XlaOp iota = xla::Iota(builder, seq_lens_type, max_seq_len); - std::vector dims(input_shape.dims(), 1); - dims[batch_dim_] = batch_size; - auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); - - // Broadcast the mask up to the input shape. - mask = xla::Or(mask, xla::Broadcast(xla::ConstantR0(builder, false), - input_shape.dim_sizes())); - - output = xla::Select(mask, output, input); - context->SetOutput(0, output); + context->SetOutput(0, + xla::Gather(input, start_indices, dnums, slice_sizes)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 4b9e1a578be2445091228953df7e5c5e82b42c28..8431724f438f67c07740212e1e31926777fef3ae 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -23,14 +23,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index a95e7adacf194ba6eb33cbeb56abe1a5a2479337..a1c18bed3f94008af8038f32324c79aa5b2abded 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -110,10 +110,16 @@ class ScatterNdOp : public XlaOpKernel { auto updates = context->Input(1); auto result = XlaScatter(buffer, updates, indices, - /*indices_are_vectors=*/true, /*combiner=*/{}, builder); + /*indices_are_vectors=*/true, /*combiner=*/Combine, builder); OP_REQUIRES_OK(context, result.status()); context->SetOutput(0, result.ValueOrDie()); } + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); + } }; REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"), diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 9e4c57c9bf73369662274f6b783418e18ff860c2..aaf8c6075dd292e33e70683774a6c1bf374183e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -20,8 +20,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index b1fa2915d59e4e5e2f2523e20e9a37898d087117..7a620d2a6518f8686ef570b33aac971d1dccb6c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -157,9 +157,11 @@ class LinSpaceOp : public XlaOpKernel { flat(0) = start; } else { const float step = (stop - start) / (num - 1); - for (int64 i = 0; i < num; ++i) { + for (int64 i = 0; i < num - 1; ++i) { flat(i) = start + step * i; } + // The last value in the sequence must be equal to stop. + flat(num - 1) = stop; } break; } @@ -171,9 +173,11 @@ class LinSpaceOp : public XlaOpKernel { flat(0) = start; } else { const double step = (stop - start) / (num - 1); - for (int64 i = 0; i < num; ++i) { + for (int64 i = 0; i < num - 1; ++i) { flat(i) = start + step * i; } + // The last value in the sequence must be equal to stop. + flat(num - 1) = stop; } break; } diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 85b0367f73cf31b95e2cd2297e9c1476cfac9d50..280b68383c28d1b9d88f7b2ac0f8fab47244c05d 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -22,9 +22,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { @@ -104,7 +104,7 @@ class SizeOp : public XlaOpKernel { for (int64 i = 0; i < rank; ++i) { size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i)); } - size = xla::ConvertElementType(size, xla::S32); + size = xla::ConvertElementType(size, ctx->output_xla_type(0)); ctx->SetOutput(0, size); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index 76ea5f525598f511f295eb5a30f3cf603fbf57aa..b18e3f965c427aec456ce2b188dad79485df23cc 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/framework/bounds_check.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 88da64e5a217a0c026106f03cb26958f6738446c..1be651da4704d5be1ce0a33312b6a67158a60285 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -22,10 +22,10 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mem.h" diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 3293c13b21bc4825c83f494b7f2d48a9b3000f9e..96863d6d1bae77dedfd02fe6469e53b311b4269a 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -30,11 +31,6 @@ class SpaceToDepthOp : public XlaOpKernel { OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); - OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, - errors::InvalidArgument("Unsupported data format ", - ToString(data_format_), - "; expected formats NHWC or NCHW")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -42,19 +38,36 @@ class SpaceToDepthOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_tensor_shape = ctx->InputShape(0); - int input_rank = input_tensor_shape.dims(); + xla::XlaOp input = ctx->Input(0); + + 
TensorFormat data_format = data_format_; + // If the data is in a vectorized format, reformat it into a non-vectorized + // version first. We'll undo the transformation later. + if (data_format == FORMAT_NCHW_VECT_C) { + data_format = FORMAT_NCHW; + auto input_reshaped = NCHW_VECT_CToNCHW(input); + OP_REQUIRES_OK(ctx, input_reshaped.status()); + input = input_reshaped.ValueOrDie(); + } + + OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_))); + + xla::XlaBuilder* builder = input.builder(); + auto input_xla_shape = builder->GetShape(input); + OP_REQUIRES_OK(ctx, input_xla_shape.status()); + const std::vector& input_shape = + input_xla_shape.ValueOrDie().dimensions(); + int input_rank = input_shape.size(); + static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got ", input_rank)); - const absl::InlinedVector input_shape = - input_tensor_shape.dim_sizes(); - - xla::XlaOp input = ctx->Input(0); - int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); - int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format); std::vector reshaped_shape; std::vector transpose_order; @@ -62,7 +75,7 @@ class SpaceToDepthOp : public XlaOpKernel { reshaped_shape.reserve(input_rank); transpose_order.reserve(input_rank); output_shape.reserve(input_rank); - if (data_format_ == FORMAT_NHWC) { + if (data_format == FORMAT_NHWC) { int64 block_elems = 1; for (int i = 0; i < num_spatial_dims; ++i) { OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, @@ -157,6 +170,14 @@ class SpaceToDepthOp : public XlaOpKernel { // xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); + // If this used to be a vectorized format turn 
it back now. + if (data_format != data_format_) { + DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C); + auto output_reshaped = NCHWToNCHW_VECT_C(output); + OP_REQUIRES_OK(ctx, output_reshaped.status()); + output = output_reshaped.ValueOrDie(); + } + ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index d0c5231e843aefa68490e29475ee96bd92859aac..a93d137e96519837ae289f08ff4d32960970aad9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -24,14 +24,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1d68835e1271b9e83d98eff6c3973d2a2593e5e --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -0,0 +1,362 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" +#include "tensorflow/compiler/tf2xla/lib/random.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/stateful_random_ops.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +std::pair GetInputsFromCounter( + xla::XlaOp counter, const int64 size) { + auto builder = counter.builder(); + auto input_u64 = Iota(builder, xla::U64, size); + input_u64 = input_u64 + counter; + counter = counter + xla::ConstantR0(builder, size); + return std::make_pair(xla::Uint64ToUint32s(input_u64), counter); +} + +// `StatelessRngUniformU32` uses ThreeFry2x32’s counter space too +// wastefully, only able to generate 2^32*2 int32 numbers for each key, while +// the real capacity is 2^64*2. 
Counter-space efficiency is important for +// stateful ops, hence the following 2 new functions. +std::pair StatefulRngUniformU32( + xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { + auto builder = key.builder(); + const int64 size = xla::ShapeUtil::ElementsIn(shape); + const int64 half_size = xla::CeilOfRatio(size, 2); + const bool size_is_odd = (half_size * 2 != size); + auto inputs_counter = GetInputsFromCounter(counter, half_size); + auto inputs = inputs_counter.first; + counter = inputs_counter.second; + auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key)); + if (size_is_odd) { + outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + auto result = ConcatInDim(builder, outputs, 0); + return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())), + counter); +} + +std::pair StatefulRngUniformU64( + xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { + const int64 size = xla::ShapeUtil::ElementsIn(shape); + auto inputs_counter = GetInputsFromCounter(counter, size); + auto inputs = inputs_counter.first; + counter = inputs_counter.second; + auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); + auto result = Uint32sToUint64(outputs); + return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())), + counter); +} + +std::pair StatefulRngUniform(xla::XlaOp key, + xla::XlaOp counter, + const xla::Shape& shape, + xla::XlaOp minval, + xla::XlaOp maxval) { + auto builder = key.builder(); + xla::PrimitiveType type = shape.element_type(); + switch (type) { + case xla::F32: { + auto bits_counter = StatefulRngUniformU32(key, counter, shape); + auto bits = bits_counter.first; + counter = bits_counter.second; + return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval), + counter); + } + case xla::U32: // fall through + case xla::S32: { + auto bits_counter = StatefulRngUniformU32(key, counter, shape); + auto bits = bits_counter.first; + counter = bits_counter.second; + return 
std::make_pair( + xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32), + counter); + } + case xla::U64: // fall through + case xla::S64: { + auto bits_counter = StatefulRngUniformU64(key, counter, shape); + auto bits = bits_counter.first; + counter = bits_counter.second; + return std::make_pair( + xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64), + counter); + } + default: + return std::make_pair(builder->ReportError(xla::Unimplemented( + "Types other than F32, U32, S32, U64 and S64 " + "are not implemented by " + "StatefulRngUniform.")), + counter); + } +} + +template +std::pair map_first(std::function f, std::pair p) { + return std::make_pair(f(p.first), p.second); +} + +std::pair StatefulRngUniformFullInt( + xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { + xla::PrimitiveType type = shape.element_type(); + switch (type) { + case xla::U32: + return StatefulRngUniformU32(key, counter, shape); + case xla::S32: { + // Needs explicit function type because of type-inference failure. 
+ std::function f = [](xla::XlaOp x) { + return BitcastConvertType(x, xla::S32); + }; + return map_first(f, StatefulRngUniformU32(key, counter, shape)); + } + case xla::U64: + return StatefulRngUniformU64(key, counter, shape); + case xla::S64: { + std::function f = [](xla::XlaOp x) { + return BitcastConvertType(x, xla::S64); + }; + return map_first(f, StatefulRngUniformU64(key, counter, shape)); + } + default: + auto builder = key.builder(); + return std::make_pair( + builder->ReportError(xla::Unimplemented( + "Types other than U32, S32, U64 and S64 are not implemented by " + "StatefulRngUniformFullInt; got: %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + counter); + } +} + +template +ListB Map(F f, ListA const& list_a) { + ListB list_b; + for (auto a : list_a) { + list_b.push_back(f(a)); + } + return list_b; +} + +xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, + absl::Span scalars) { + return ConcatInDim( + builder, + Map>( + [](xla::XlaOp x) { return xla::Reshape(x, {1}); }, scalars), + 0); +} + +using sampler_return_type = xla::StatusOr>; + +// A helper function containing the common part of several kernels below. +// Precondition: 'algorithm' and 'shape' are compile-time constants. 
+Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx, + int alg_input_idx, int shape_input_idx, + std::function const& + sample_with_threefry) { + auto alg_shape = ctx->InputShape(alg_input_idx); + if (alg_shape.dims() != 0) { + return errors::InvalidArgument("algorithm must be of shape [], not ", + alg_shape.DebugString()); + } + xla::Literal alg_literal; + TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal)); + auto alg = alg_literal.Get({}); + + if (alg == RNG_ALG_THREEFRY) { + xla::XlaOp var; + TensorShape var_shape; + TF_RETURN_IF_ERROR(ctx->ReadVariableInput( + state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var)); + if (var_shape.dims() != 1) { + return errors::InvalidArgument( + "RNG state must have one and only one dimension, not ", + var_shape.dims()); + } + auto state_size = var_shape.dim_size(0); + if (state_size < THREEFRY_MIN_STATE_SIZE) { + return errors::InvalidArgument( + "For the ThreeFry algorithm, the size of state" + " must be at least ", + THREEFRY_MIN_STATE_SIZE, "; got ", state_size); + } + TensorShape shape; + TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape)); + + static constexpr int COUNTER_SIZE = 1; + auto counter = BitcastConvertType( + xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64); + auto key = BitcastConvertType( + xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}), + {}), + xla::U64); + + auto status_or_value = sample_with_threefry(counter, key, shape); + if (!status_or_value.ok()) { + return status_or_value.status(); + } + auto output_counter = status_or_value.ConsumeValueOrDie(); + auto output = output_counter.first; + counter = output_counter.second; + ctx->SetOutput(0, output); + auto builder = ctx->builder(); + var = ConcatScalars(builder, {counter, key}); + xla::PrimitiveType state_element_type; + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type)); + var = BitcastConvertType(var, 
state_element_type); + TF_RETURN_IF_ERROR( + ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var)); + return Status::OK(); + } else { + return errors::InvalidArgument("Unsupported algorithm id: ", alg); + } +} + +class StatefulStandardNormalOp : public XlaOpKernel { + public: + explicit StatefulStandardNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + auto sample_with_threefry = + // Needs explicit lambda return type because it fails to be inferred. + [builder, this](xla::XlaOp counter, xla::XlaOp key, + TensorShape shape) -> sampler_return_type { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto uniform_counter = StatefulRngUniform( + key, counter, xla_shape, + xla::ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), + xla::ConstantR0(builder, 1.0)); + auto uniform = uniform_counter.first; + counter = uniform_counter.second; + // Convert uniform distribution to normal distribution by computing + // sqrt(2) * erfinv(x) + auto normal = + xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); + normal = MaybeConvertF32ToBF16(normal, dtype_); + return {{normal, counter}}; + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sample_with_threefry)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulStandardNormalOp); +}; + +// TODO(wangpeng): Support plain float16 and float64 to get rid of the +// `TypeConstraint`. 
+REGISTER_XLA_OP(Name("StatefulStandardNormalV2") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}), + StatefulStandardNormalOp); + +class StatefulUniformIntOp : public XlaOpKernel { + public: + explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp minval = ctx->Input(3); + xla::XlaOp maxval = ctx->Input(4); + auto sample_with_threefry = [minval, maxval, this]( + xla::XlaOp counter, xla::XlaOp key, + TensorShape shape) -> sampler_return_type { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + return StatefulRngUniform(key, counter, xla_shape, minval, maxval); + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sample_with_threefry)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformIntOp); +}; + +REGISTER_XLA_OP(Name("StatefulUniformInt") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", + {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}), + StatefulUniformIntOp); + +class StatefulUniformFullIntOp : public XlaOpKernel { + public: + explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + auto sample_with_threefry = [this]( + xla::XlaOp counter, xla::XlaOp key, + TensorShape shape) -> sampler_return_type { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + return StatefulRngUniformFullInt(key, counter, xla_shape); + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sample_with_threefry)); + } + + private: 
+ DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformFullIntOp); +}; + +REGISTER_XLA_OP(Name("StatefulUniformFullInt") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", + {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}), + StatefulUniformFullIntOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 50653d7b3973b73d580cdeec5d71943b575d7cc9..e143a711730720c0566f079e00965d876c869a99 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -31,12 +32,8 @@ limitations under the License. #include "tensorflow/core/lib/math/math_util.h" namespace tensorflow { -namespace { xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { - // Mask the last 16 bit. With normal rounding, values near "maxval" would be - // converted to "maxval" which is out of range ["minval", "maxval"). In - // addition, the distribution near the limit is not uniform. if (dtype == DT_BFLOAT16) { xla::XlaBuilder* builder = input.builder(); auto output = xla::BitcastConvertType(input, xla::U32) & @@ -48,6 +45,26 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { } } +xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) { + // Convert uniform distribution to normal distribution by computing + // sqrt(2) * erfinv(x) + return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); +} + +// A wrapper of xla::StatelessRngUniform. 
Returns an op that produces random +// values with uniform distribution in the range [minval, maxval) for the given +// shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and +// S64 are implemented. +xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType dtype, + xla::XlaOp seed, xla::XlaOp minval, + xla::XlaOp maxval) { + xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval); +} + +namespace { + class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) @@ -69,12 +86,8 @@ class StatelessRandomUniformOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - - auto uniform = xla::StatelessRngUniform( - {seed0, seed1}, xla_shape, xla::ConstantR0(builder, 0.0), + xla::XlaOp uniform = StatelessRandomUniformImpl( + xla_shape, dtype_, seed, xla::ConstantR0(builder, 0.0), xla::ConstantR0(builder, 1.0)); uniform = MaybeConvertF32ToBF16(uniform, dtype_); ctx->SetOutput(0, uniform); @@ -123,12 +136,8 @@ class StatelessRandomUniformIntOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); - - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - - auto uniform = - xla::StatelessRngUniform({seed0, seed1}, xla_shape, minval, maxval); + xla::XlaOp uniform = + StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval); ctx->SetOutput(0, uniform); } @@ -164,18 +173,11 @@ class StatelessRandomNormalOp : public XlaOpKernel { xla::XlaBuilder* builder = ctx->builder(); xla::Shape 
xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - - auto uniform = xla::StatelessRngUniform( - {seed0, seed1}, xla_shape, + xla::XlaOp uniform = StatelessRandomUniformImpl( + xla_shape, dtype_, seed, xla::ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), xla::ConstantR0(builder, 1.0)); - // Convert uniform distribution to normal distribution by computing - // sqrt(2) * erfinv(x) - auto normal = - xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); + xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform); normal = MaybeConvertF32ToBF16(normal, dtype_); ctx->SetOutput(0, normal); } @@ -211,18 +213,15 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - auto uniform = xla::StatelessRngUniform( - {seed0, seed1}, xla_shape, - xla::ConstantR0(builder, std::numeric_limits::min()), - xla::ConstantR0(builder, 1.0)); - auto output = TruncatedNormal(uniform); - output = MaybeConvertF32ToBF16(output, dtype_); - ctx->SetOutput(0, output); + xla::XlaOp uniform = StatelessRandomUniformImpl( + xla_shape, dtype_, seed, + xla::MinPositiveNormalValue(builder, xla_shape.element_type()), + xla::One(builder, xla_shape.element_type())); + xla::XlaOp truncated_normal = TruncatedNormal(uniform); + truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_); + ctx->SetOutput(0, truncated_normal); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 
2273b592466431f59abcc43fcac4c37eecd53bff..9da1504bff12b54c9ae10cb5c2fa00214642b551 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mem.h" @@ -291,7 +291,7 @@ class StridedSliceAssignOp : public XlaOpKernel { absl::InlinedVector slice_begin; absl::InlinedVector slice_dims; for (int i = 0; i < begin.size(); ++i) { - // TODO(phawkins): implement strides != 1 + // TODO(b/121179231): implement strides != 1 OP_REQUIRES( ctx, strides[i] == 1 || strides[i] == -1, errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 22e64789a3a0db104834baa8a2a914d2cd1742d1..b98b98ce50af2cb811297989899b06d33296bf13 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -27,14 +27,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 5dda62fd492a2e3bf9caf502c04f8ba8674a6510..9bc565e5d638bca682fadfc93044adb550c12893 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -18,36 +18,81 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { + namespace { -Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, - TensorShape* tensor_list_shape) { - auto shape_or_status = builder->GetShape(op); - if (!shape_or_status.ok()) { - return shape_or_status.status(); +class TensorListLengthOp : public XlaOpKernel { + public: + explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &index)); + ctx->SetOutput(0, index); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp); +}; + +REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp); + +// Creates an empty list with size (leading_dim, *element_shape) if +// element_shape is known at compile time. 
Otherwise creates one with size +// (leading_dim, 0) which gets initialized later in `GetInitializedList`. +Status CreateZerosList(XlaOpKernelContext* ctx, int element_shape_index, + int64 leading_dim, DataType dtype, xla::XlaOp* list) { + TensorShape list_shape; + list_shape.AddDim(leading_dim); + xla::XlaOp element_shape_handle = ctx->Input(element_shape_index); + TF_ASSIGN_OR_RETURN( + bool is_element_shape_compile_time_const, + element_shape_handle.builder()->IsConstant(element_shape_handle)); + PartialTensorShape partial_element_shape; + if (is_element_shape_compile_time_const) { + TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape( + element_shape_index, &partial_element_shape)); + } + if (is_element_shape_compile_time_const && + partial_element_shape.IsFullyDefined()) { + TensorShape element_shape; + partial_element_shape.AsTensorShape(&element_shape); + list_shape.AppendShape(element_shape); + } else { + // If element_shape is not a compile time constant or if it is not fully + // defined we will have to wait for the first write call to fully allocate + // the array. + // TODO(srbs): We are using element_shape of [0] as a proxy to denote an + // uninitialized list. A better implementation may be to represent the + // list as a 3-tuple containing an explicit "initialized" flag. However, + // we would still need to create a dummy tensor for the first tuple + // element.
+ list_shape.AddDim(0); } - xla::Shape shape = shape_or_status.ValueOrDie(); - TF_RET_CHECK(shape.IsTuple()); - return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), - tensor_list_shape); + *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), + list_shape.dim_sizes()); + return Status::OK(); } class TensorListReserveOp : public XlaOpKernel { @@ -57,19 +102,18 @@ class TensorListReserveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); int64 num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); - TensorShape tensor_shape; - tensor_shape.AddDim(num_elements); - tensor_shape.AppendShape(element_shape); + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &buffer)); - xla::XlaBuilder* b = ctx->builder(); - ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, 0)})); + xla::XlaOp output_list; + OP_REQUIRES_OK( + ctx, BuildTensorList( + buffer, xla::ConstantR0(ctx->builder(), num_elements), + &output_list)); + ctx->SetTensorListOutput(0, output_list); } private: @@ -90,8 +134,6 @@ class EmptyTensorListOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); int64 max_num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); OP_REQUIRES( @@ -99,14 +141,15 @@ class EmptyTensorListOp : public XlaOpKernel { errors::InvalidArgument("XLA compilation requires a fixed tensor list " "size. 
Set the max number of elements.")); - TensorShape tensor_shape; - tensor_shape.AddDim(max_num_elements); - tensor_shape.AppendShape(element_shape); + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, + CreateZerosList(ctx, 0, max_num_elements, dtype_, &buffer)); - xla::XlaBuilder* b = ctx->builder(); - ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, 0)})); + xla::XlaOp output_list; + OP_REQUIRES_OK( + ctx, BuildTensorList(buffer, xla::ConstantR0(ctx->builder(), 0), + &output_list)); + ctx->SetTensorListOutput(0, output_list); } private: @@ -130,7 +173,7 @@ class TensorListElementShapeOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); TensorShape shape; - OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape)); + OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape)); shape.RemoveDim(0); switch (shape_type_) { @@ -160,6 +203,185 @@ class TensorListElementShapeOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); +class TensorListGetItemOp : public XlaOpKernel { + public: + explicit TensorListGetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp state = ctx->Input(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape)); + + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &buffer)); + xla::XlaOp index = ctx->Input(1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
+ std::vector start_indices(shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; + auto slice_shape = shape.dim_sizes(); + slice_shape[0] = 1LL; + + xla::XlaOp read = xla::DynamicSlice(buffer, start_indices, slice_shape); + // Remove the leading '1' dimension. + std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + + ctx->SetOutput(0, xla::Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListGetItemOp); +}; + +REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp); + +class TensorListStackOp : public XlaOpKernel { + public: + explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer)); + ctx->SetOutput(0, buffer); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp); +}; + +REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp); + +class TensorListFromTensorOp : public XlaOpKernel { + public: + explicit TensorListFromTensorOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + PartialTensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsPartialShape(1, &element_shape)); + + const TensorShape tensor_shape = ctx->InputShape(0); + // Ensure that tensor_shape is compatible with element_shape. 
+ PartialTensorShape unused; + OP_REQUIRES_OK( + ctx, + element_shape.MergeWith( + PartialTensorShape( + absl::Span(tensor_shape.dim_sizes()).subspan(1)), + &unused)); + OP_REQUIRES(ctx, tensor_shape.dims() > 0, + errors::InvalidArgument("Input value must be at least a " + "vector but received shape: ", + tensor_shape.DebugString())); + const int num_elements = tensor_shape.dim_size(0); + + xla::XlaBuilder* b = ctx->builder(); + const xla::XlaOp tensor = ctx->Input(0); + + xla::XlaOp output_list; + OP_REQUIRES_OK( + ctx, BuildTensorList(tensor, xla::ConstantR0(b, num_elements), + &output_list)); + ctx->SetTensorListOutput(0, output_list); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp); +}; + +REGISTER_XLA_OP( + Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"), + TensorListFromTensorOp); + +// Returns the 0'th element of `tuple` containing the list tensor if it has been +// initialized already else creates one lazily. This allows lazy initialization +// of the list on the first call to SetItem or PushBack. 
+Status GetInitializedList(const xla::XlaOp& input_list, + const TensorShape& element_shape, DataType dtype, + xla::XlaOp* output_list_buffer) { + bool is_already_initialized; + TF_RETURN_IF_ERROR( + IsTensorListInitialized(input_list, &is_already_initialized)); + TensorShape input_list_shape; + TF_RETURN_IF_ERROR(GetTensorListBufferShape(input_list, &input_list_shape)); + TensorShape input_list_element_shape = input_list_shape; + input_list_element_shape.RemoveDim(0); + + if (is_already_initialized) { + TF_RET_CHECK(element_shape == input_list_element_shape); + TF_RETURN_IF_ERROR(GetTensorListBuffer(input_list, output_list_buffer)); + return Status::OK(); + } + + int64 leading_dim = input_list_shape.dim_size(0); + TensorShape output_list_shape = element_shape; + output_list_shape.InsertDim(0, leading_dim); + + xla::XlaOp output_list; + TF_RETURN_IF_ERROR( + InitializeTensorList(input_list, output_list_shape, &output_list)); + TF_RETURN_IF_ERROR(GetTensorListBuffer(output_list, output_list_buffer)); + return Status::OK(); +} + +class TensorListSetItemOp : public XlaOpKernel { + public: + explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp tl = ctx->Input(0); + TensorShape elem_shape = ctx->InputShape(2); + + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, GetInitializedList(tl, elem_shape, dtype_, &buffer)); + xla::XlaOp push_index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(tl, &push_index)); + + xla::XlaOp index = ctx->Input(1); + xla::XlaOp value = ctx->Input(2); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
+ std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); + + xla::XlaOp output_list; + OP_REQUIRES_OK(ctx, BuildTensorList(xla::DynamicUpdateSlice(buffer, update, + start_indices), + push_index, &output_list)); + ctx->SetTensorListOutput(0, output_list); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp); +}; + +REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp); + class TensorListPushBackOp : public XlaOpKernel { public: explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -168,11 +390,15 @@ class TensorListPushBackOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp list = ctx->Input(0); + xla::XlaOp list_tuple = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(1); - xla::XlaOp ta = xla::GetTupleElement(list, 0); - xla::XlaOp index = xla::GetTupleElement(list, 1); + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, + GetInitializedList(list_tuple, elem_shape, dtype_, &buffer)); + + xla::XlaOp index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list_tuple, &index)); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. @@ -184,11 +410,12 @@ class TensorListPushBackOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - // TODO(phawkins): We don't check the index is in bounds --- there is no - // error mechanism in XLA. 
- ctx->SetOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), - index + xla::ConstantR0(b, 1)})); + xla::XlaOp output_list; + OP_REQUIRES_OK( + ctx, + BuildTensorList(xla::DynamicUpdateSlice(buffer, update, start_indices), + index + xla::ConstantR0(b, 1), &output_list)); + ctx->SetTensorListOutput(0, output_list); } private: @@ -210,10 +437,12 @@ class TensorListPopBackOp : public XlaOpKernel { xla::XlaOp state = ctx->Input(0); TensorShape shape; - OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape)); - xla::XlaOp ta = xla::GetTupleElement(state, 0); - xla::XlaOp index = xla::GetTupleElement(state, 1); + xla::XlaOp ta; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &ta)); + xla::XlaOp index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(state, &index)); index = index - xla::ConstantR0(b, 1); @@ -224,13 +453,13 @@ class TensorListPopBackOp : public XlaOpKernel { auto slice_shape = shape.dim_sizes(); slice_shape[0] = 1LL; - // TODO(phawkins): We don't check the index is in bounds --- there is no - // error mechanism in XLA. xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, xla::Tuple(b, {ta, index})); + xla::XlaOp output_list; + OP_REQUIRES_OK(ctx, BuildTensorList(ta, index, &output_list)); + ctx->SetTensorListOutput(0, output_list); ctx->SetOutput(1, xla::Reshape(read, value_shape)); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa6ee2ac35e3584ce5580d06fc02c5fb97f54edd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +bool IsTensorListInput(XlaOpKernelContext* ctx, int index) { + return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList; +} + +Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index, + xla::XlaOp* output_list) { + TF_RET_CHECK(buffer.builder()); + *output_list = xla::Tuple(buffer.builder(), {buffer, push_index}); + return Status::OK(); +} + +Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) { + TF_RET_CHECK(op.builder()); + *buffer = xla::GetTupleElement(op, 0); + return Status::OK(); +} + +Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index) { + TF_RET_CHECK(op.builder()); + *push_index = xla::GetTupleElement(op, 1); + return Status::OK(); +} + +Status GetTensorListBufferShape(const xla::XlaOp& op, + TensorShape* buffer_shape) { + TF_RET_CHECK(op.builder()); + TensorShape shape; + TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape, + 
op.builder()->GetShape(op)); + return GetTensorListBufferShape(list_tuple_shape, buffer_shape); +} + +Status GetTensorListBufferShape(const xla::Shape& list_shape, + TensorShape* buffer_shape) { + TF_RET_CHECK(list_shape.IsTuple()); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape( + xla::ShapeUtil::GetTupleElementShape(list_shape, 0), buffer_shape)); + return Status::OK(); +} + +Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized) { + TensorShape list_shape; + TF_RETURN_IF_ERROR(GetTensorListBufferShape(op, &list_shape)); + *is_initialized = !(list_shape.dims() == 2 && list_shape.dim_size(1) == 0); + return Status::OK(); +} + +Status InitializeTensorList(const xla::XlaOp& uninitialized_list, + const TensorShape& buffer_shape, + xla::XlaOp* output_list) { + TensorShape input_buffer_shape; + TF_RETURN_IF_ERROR( + GetTensorListBufferShape(uninitialized_list, &input_buffer_shape)); + if (input_buffer_shape.dim_size(0) != buffer_shape.dim_size(0)) { + return errors::InvalidArgument( + "Number of elements in input list does not match buffer size. 
", + "input list size: ", input_buffer_shape.dim_size(0), + "buffer size: ", buffer_shape.dim_size(0)); + } + xla::XlaBuilder* builder = uninitialized_list.builder(); + xla::XlaOp input_buffer; + TF_RETURN_IF_ERROR(GetTensorListBuffer(uninitialized_list, &input_buffer)); + TF_ASSIGN_OR_RETURN(const xla::Shape& input_buffer_xla_shape, + builder->GetShape(input_buffer)); + auto new_buffer = xla::Broadcast( + xla::ConstantLiteral(builder, xla::LiteralUtil::Zero( + input_buffer_xla_shape.element_type())), + buffer_shape.dim_sizes()); + xla::XlaOp push_index; + TF_RETURN_IF_ERROR(GetTensorListPushIndex(uninitialized_list, &push_index)); + return BuildTensorList(new_buffer, push_index, output_list); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..937af6f8d77499248b46069822cac291aae0d60b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ + +// TensorList utilities. 
+// +// Tensor lists are represented as tuple consisting of a pre-allocated buffer +// consisting of the tensors (and where dim 0 is the list index), along with a +// scalar telling us the next index to push a value at. + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +// Whether the input expression at `index` corresponds to a TensorList. +bool IsTensorListInput(XlaOpKernelContext* ctx, int index); + +// Builds a TensorList from its constituents, `buffer` and `push_index`. +Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index, + xla::XlaOp* output_list); + +// Returns the buffer for the TensorList. +Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer); + +// Returns the push_index for the TensorList. +Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index); + +// Returns the shape of the TensorList buffer. +Status GetTensorListBufferShape(const xla::XlaOp& op, + TensorShape* buffer_shape); + +// Inputs the TensorList shape and returns the buffer shape. +Status GetTensorListBufferShape(const xla::Shape& list_shape, + TensorShape* buffer_shape); + +// Returns whether the TensorList has been initialized. +// +// A TensorList is considered initialized if its element_shape is completely +// known. +Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized); + +// Inputs an uninitialized list and a buffer_shape and returns an initialized +// list. The initialized list uses the dtype and push index of the uninitialized +// list and is filled with zeros. 
+Status InitializeTensorList(const xla::XlaOp& uninitialized_list, + const TensorShape& buffer_shape, + xla::XlaOp* output_list); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index ee3bdf3394e37c757f31724e73e95417becaa534..22cfd16008899c1ad3c73453bec34a0b0d2e8c78 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/no_op.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 26d4214099d1d07c1b2e275d783654d9cd948e28..247db8d5d172b04e414b1ff0e53f12b533f36944 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/no_op.h" namespace tensorflow { namespace { @@ -856,15 +855,12 @@ class ResourceApplyAdadelta : public XlaOpKernel { xla::XlaOp grad = ctx->Input(6); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5); - xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - accum = rho * accum + (one - rho) * xla::Pow(grad, two); - xla::XlaOp update = xla::Pow(accum_update + epsilon, half) * - xla::Pow(accum + epsilon, neg_half) * grad; - accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two); + accum = rho * accum + (one - rho) * xla::Square(grad); + xla::XlaOp update = + xla::Sqrt(accum_update + epsilon) * xla::Rsqrt(accum + epsilon) * grad; + accum_update = rho * accum_update + (one - rho) * xla::Square(update); var = var - update * lr; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index c9b324a243e4cc3ec64daa3ca0d285336a0d0154..65569576d4146dfdd1464251c9b42a7621d3ebc8 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -18,15 +18,15 @@ limitations under the License. // handles all transposes, while Eigen needs a restricted DoTranspose // helper. 
-#include "tensorflow/core/kernels/transpose_op.h" +#include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { @@ -128,29 +128,46 @@ class InvertPermutationOp : public XlaOpKernel { errors::InvalidArgument("permutation of nonnegative int32s " "must have <= int32 max elements")); - std::vector perm; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm)); - - int size = perm.size(); + auto e = ctx->InputExpression(0); + auto tensor_or_status = e.ResolveConstant(ctx->compiler()->client()); + OP_REQUIRES_OK(ctx, tensor_or_status.status()); + // If the input is a constant, we also want the output to be a constant. + // Some models rely on the result of InvertPermutation being a constant. + // TODO(b/32495713): Remove this when we can check whether Scatter is + // constant. Right now, we always assume it is non-constant because we don't + // check the embedded computation. 
+ if (tensor_or_status.ValueOrDie().has_value()) { + std::vector perm; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm)); + + int size = perm.size(); + + std::vector output(size); + std::fill_n(output.data(), size, -1); + for (int i = 0; i < size; ++i) { + const int64 d = perm[i]; + OP_REQUIRES(ctx, FastBoundsCheck(d, size), + errors::InvalidArgument(d, " is not between 0 and ", size)); + OP_REQUIRES(ctx, output[d] == -1, + errors::InvalidArgument(d, " is duplicated in the input.")); + output[d] = i; + } - std::vector output(size); - std::fill_n(output.data(), size, -1); - for (int i = 0; i < size; ++i) { - const int64 d = perm[i]; - OP_REQUIRES(ctx, FastBoundsCheck(d, size), - errors::InvalidArgument(d, " is not between 0 and ", size)); - OP_REQUIRES(ctx, output[d] == -1, - errors::InvalidArgument(d, " is duplicated in the input.")); - output[d] = i; + ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); + } else { + auto indices = ctx->Input(0); + int size = ctx->InputShape(0).num_elements(); + auto iota = xla::Iota(ctx->builder(), xla::S32, size); + auto result = XlaScatter(iota, iota, indices, + /*indices_are_vectors=*/false, /*combiner=*/{}, + ctx->builder()); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } - - ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); } }; -REGISTER_XLA_OP(Name("InvertPermutation") - .TypeConstraint("T", DT_INT32) - .CompileTimeConstantInput("x"), +REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32), InvertPermutationOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index a0ea6422d732b00fc1b8cf855d9c9ad603b87c82..7c4176eb839f85e6d68565d22e04f982354a7282 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -65,11 +65,8 @@ XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); 
XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); -XLAJIT_MAKE_UNARY( - IsInf, - xla::Eq(xla::Abs(x), - xla::ScalarLike(x, std::numeric_limits::infinity()))); -XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); +XLAJIT_MAKE_UNARY(IsInf, xla::IsInf(x)); +XLAJIT_MAKE_UNARY(IsNan, xla::IsNan(x)); // Return 1/x XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x); XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x); @@ -92,8 +89,9 @@ xla::XlaOp Sigmoid(xla::XlaOp x) { } XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); -// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. -XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); +// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. +XLAJIT_MAKE_UNARY(Sign, + xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x))); XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x)); // softplus(x) = log(1 + exp(x)) @@ -116,82 +114,10 @@ XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); XLAJIT_MAKE_UNARY(Real, xla::Real(x)); XLAJIT_MAKE_UNARY(Imag, xla::Imag(x)); - -#undef XLAJIT_MAKE_UNARY - -// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial -// is used outside of this range. 
-class ErfOp : public XlaOpKernel { - public: - explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp x = ctx->Input(0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); - auto y = - xla::Select(xla::Gt(xla::Abs(x), one), one - xla::Erfc(x), xla::Erf(x)); - ctx->SetOutput(0, y); - } -}; -REGISTER_XLA_OP(Name("Erf"), ErfOp); - -class ErfcOp : public XlaOpKernel { - public: - explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp x = ctx->Input(0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); - auto y = - xla::Select(xla::Lt(xla::Abs(x), one), one - xla::Erf(x), xla::Erfc(x)); - ctx->SetOutput(0, y); - } -}; -REGISTER_XLA_OP(Name("Erfc"), ErfcOp); - -class LgammaOp : public XlaOpKernel { - public: - explicit LgammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - // Calculate lgamma using the Lanczos approximation - // (https://en.wikipedia.org/wiki/Lanczos_approximation). - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp input = ctx->Input(0); - xla::PrimitiveType input_type = ctx->input_xla_type(0); - - if (input_type == xla::F16 || input_type == xla::BF16) { - // The approximation works better with at least 32-bits of accuracy. - xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32); - xla::XlaOp result_f32 = xla::Lgamma(input_f32); - xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type); - ctx->SetOutput(0, result_x16); - } else { - xla::XlaOp result = xla::Lgamma(input); - ctx->SetOutput(0, result); - } - } -}; // namespace -REGISTER_XLA_OP(Name("Lgamma"), LgammaOp); - -class DigammaOp : public XlaOpKernel { - public: - explicit DigammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - // Calculate lgamma using the Lanczos approximation - // (https://en.wikipedia.org/wiki/Lanczos_approximation). 
- void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp input = ctx->Input(0); - xla::PrimitiveType input_type = ctx->input_xla_type(0); - - if (input_type == xla::F16 || input_type == xla::BF16) { - // The approximation works better with at least 32-bits of accuracy. - xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32); - xla::XlaOp result_f32 = xla::Digamma(input_f32); - xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type); - ctx->SetOutput(0, result_x16); - } else { - xla::XlaOp result = xla::Digamma(input); - ctx->SetOutput(0, result); - } - } -}; // namespace -REGISTER_XLA_OP(Name("Digamma"), DigammaOp); +XLAJIT_MAKE_UNARY(Erf, xla::Erf(x)); +XLAJIT_MAKE_UNARY(Erfc, xla::Erfc(x)); +XLAJIT_MAKE_UNARY(Lgamma, xla::Lgamma(x)); +XLAJIT_MAKE_UNARY(Digamma, xla::Digamma(x)); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index 8671632976023fded04c26a9780c1a67638b0916..2d95f2f30a86f3a9c95e528858c53ab48d7a02e8 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -24,13 +24,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 941b04363f8386a7bdbe8c91ea34c9754592a52d..885031ca0b8a57731f8020937307bd37624d41f1 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/while_op.h" +#include "absl/strings/str_split.h" +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -25,21 +27,27 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +const char kPropagateCompileTimeConsts[] = "_xla_propagate_compile_time_consts"; + namespace { // Builds XlaCompiler argument descriptions `args` from `ctx`. 
Status MakeXlaCompilerArgumentsFromInputs( XlaOpKernelContext* ctx, std::vector* args, - bool* has_uninitialized_vars, bool* has_tensor_arrays) { + bool* has_uninitialized_vars, bool* has_tensor_arrays, + bool* has_uninitialized_tensor_lists) { VLOG(2) << "Num inputs " << ctx->num_inputs(); args->resize(ctx->num_inputs()); *has_uninitialized_vars = false; *has_tensor_arrays = false; + *has_uninitialized_tensor_lists = false; for (int i = 0; i < ctx->num_inputs(); ++i) { VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) << " shape: " << ctx->InputShape(i).DebugString(); @@ -70,18 +78,177 @@ Status MakeXlaCompilerArgumentsFromInputs( arg.name = resource->name(); VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString() + << " shape: " << arg.ShapeHumanString() << " initialized: " << arg.initialized; } else { arg.kind = XlaCompiler::Argument::kParameter; - arg.type = ctx->input_type(i); - arg.shape = ctx->InputShape(i); + arg.type = type; + TF_ASSIGN_OR_RETURN(arg.shape, ctx->builder()->GetShape(ctx->Input(i))); + if (IsTensorListInput(ctx, i)) { + // arg.initialized == false means that the element_shape of the list + // was not available at the time of building the list so an empty list + // was created instead. If so, the body function of While is run once + // to infer the shape of the list before actually building the While op. + TF_RETURN_IF_ERROR( + IsTensorListInitialized(ctx->Input(i), &arg.initialized)); + if (!arg.initialized) { + *has_uninitialized_tensor_lists = true; + } + } } } return Status::OK(); } +// Populates loop invariant indices to true in `loop_invariants`. 
+void GetLoopInvariants(XlaOpKernelContext* ctx, + const NameAttrList& body_name_attr, + std::vector* const loop_invariants) { + const FunctionBody* body; + OP_REQUIRES_OK(ctx, ctx->compiler()->FindFunctionBody(body_name_attr, &body)); + for (int i = 0; i < body->ret_nodes.size(); i++) { + const Node* arg = body->arg_nodes[i]; + const Node* ret = body->ret_nodes[i]; + const Node* ret_input_0; + OP_REQUIRES_OK(ctx, ret->input_node(0, &ret_input_0)); + (*loop_invariants)[i] = ret_input_0->id() == arg->id(); + } +} + +// Converts entries in `args` which are loop invariants and have compile +// time constant inputs to constants so that they can be propagated in the loop +// body. +Status ConvertLoopInvariantsToConst( + XlaOpKernelContext* ctx, const NameAttrList& body_name_attr, + std::vector* args, + std::vector* compile_time_const_arg_indices, + int* num_compile_time_const_args, xla::Client* client) { + std::vector loop_invariants(ctx->num_inputs()); + GetLoopInvariants(ctx, body_name_attr, &loop_invariants); + for (int i = 0; i < ctx->num_inputs(); i++) { + XlaCompiler::Argument& arg = (*args)[i]; + const XlaExpression& expression = ctx->InputExpression(i); + // If this is a loop invariant and the input tensor is a compile time + // constant build a kConstant type argument. + if (arg.kind != XlaCompiler::Argument::kResource && loop_invariants[i]) { + // NOTE: We can not simple check that this is Kind::kConstant because + // this could be the output of a MetadataOnly op e.g. Size. 
+ xla::StatusOr> maybe_constant = + expression.ResolveConstant(client); + if (maybe_constant.ok() && maybe_constant.ValueOrDie().has_value()) { + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = expression.dtype(); + arg.constant_value = std::move(maybe_constant.ValueOrDie().value()); + arg.shape = expression.GetShape().ValueOrDie(); + compile_time_const_arg_indices->at(i) = true; + (*num_compile_time_const_args)++; + } + } + } + return Status::OK(); +} + +Status VerifyBodyInputAndOutputShapeMatch( + XlaOpKernelContext* ctx, + const std::vector& compile_time_const_arg_indices, + const XlaCompiler::CompilationResult& body, bool has_token_input_output) { + xla::Shape body_input_shape = body.xla_input_shapes[0]; + xla::Shape body_output_shape; + body_output_shape.set_element_type(xla::TUPLE); + for (int i = 0; i < ctx->num_outputs(); i++) { + if (!compile_time_const_arg_indices[i]) { + *(body_output_shape.add_tuple_shapes()) = + body.xla_output_shape.tuple_shapes(i); + } + } + // If `body` has a token output, append its shape to `body_output_shape`. + if (has_token_input_output) { + *(body_output_shape.add_tuple_shapes()) = + body.xla_output_shape.tuple_shapes(ctx->num_inputs()); + } + if (!xla::ShapeUtil::Compatible(body_input_shape, body_output_shape)) { + return errors::InvalidArgument( + "Input and output shapes of loop body do not match: ", + xla::ShapeUtil::HumanString(body_input_shape), " vs. 
", + xla::ShapeUtil::HumanString(body_output_shape)); + } + return Status::OK(); +} + +xla::StatusOr BuildWrappedCond( + XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& cond) { + xla::Shape cond_input_shape = cond.xla_input_shapes[0]; + std::unique_ptr cb = + ctx->builder()->CreateSubBuilder("cond_wrapper"); + auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs"); + auto outputs = xla::Call(cb.get(), *cond.computation, {inputs}); + xla::GetTupleElement(outputs, 0); + return cb->Build(); +} + +xla::StatusOr BuildWrappedBody( + XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& body, + const std::vector& compile_time_const_arg_indices, + int num_compile_time_const_args, bool has_token_input_output) { + if (num_compile_time_const_args <= 0) { + return xla::XlaComputation(body.computation->proto()); + } + xla::XlaComputation body_wrapper; + std::unique_ptr cb = + ctx->builder()->CreateSubBuilder("body_wrapper"); + xla::Shape body_input_shape = body.xla_input_shapes[0]; + auto inputs = xla::Parameter(cb.get(), 0, body_input_shape, "inputs"); + // Call the original body function which has mismatched inputs and outputs + // and strip the compile time consts from the list of outputs. While requires + // the inputs and outputs of its body function to match. + auto outputs = xla::Call(cb.get(), *body.computation, {inputs}); + std::vector non_compile_time_const_outputs; + for (int i = 0; i < compile_time_const_arg_indices.size(); i++) { + if (!compile_time_const_arg_indices[i]) { + non_compile_time_const_outputs.push_back( + xla::GetTupleElement(outputs, i)); + } + } + // If `body` has a token output, append it to + // `non_compile_time_const_outputs`. 
+ if (has_token_input_output) { + non_compile_time_const_outputs.push_back( + xla::GetTupleElement(outputs, ctx->num_outputs())); + } + xla::Tuple(cb.get(), non_compile_time_const_outputs); + return cb->Build(); +} + +xla::XlaOp BuildWhile(XlaOpKernelContext* ctx, + const xla::XlaComputation& wrapped_cond, + const xla::XlaComputation& wrapped_body, + const xla::XlaOp& initial_values, + const std::vector& input_mapping, + const std::vector& compile_time_const_arg_indices, + int num_compile_time_const_args, + bool has_token_input_output) { + xla::XlaOp while_result = + xla::While(wrapped_cond, wrapped_body, initial_values); + std::vector padded_while_outputs(ctx->num_outputs()); + int while_result_index = 0; + for (int i = 0; i < ctx->num_inputs(); i++) { + if (!compile_time_const_arg_indices[i]) { + padded_while_outputs[input_mapping[while_result_index]] = + xla::GetTupleElement(while_result, while_result_index); + while_result_index++; + } else { + padded_while_outputs[i] = ctx->Input(i); + } + } + // If `body` has a token output, append it to `padded_while_outputs`. 
+ if (has_token_input_output) { + padded_while_outputs.push_back(xla::GetTupleElement( + while_result, ctx->num_inputs() - num_compile_time_const_args)); + } + return xla::Tuple(ctx->builder(), padded_while_outputs); +} + } // anonymous namespace XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -95,6 +262,10 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } else { has_token_input_output_ = !token_input_nodes_.empty(); } + if (ctx->HasAttr(kPropagateCompileTimeConsts)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts, + &propagate_compile_time_consts_)); + } } void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { @@ -103,13 +274,33 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { std::vector arguments; bool has_uninitialized_vars; bool has_tensor_arrays; - OP_REQUIRES_OK( - ctx, MakeXlaCompilerArgumentsFromInputs( - ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays)); + bool has_uninitialized_tensor_lists; + OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs( + ctx, &arguments, &has_uninitialized_vars, + &has_tensor_arrays, &has_uninitialized_tensor_lists)); xla::XlaBuilder* builder = ctx->builder(); XlaCompiler* compiler = ctx->compiler(); + // Indices of loop vars which satisfy the following conditions: + // 1. They are loop invariants. + // 2. The op inputs at these indices are compile time constants. + // + // These compile time consts do not appear as _Args in the cond/body functions + // and are replaced by kConstant nodes instead. As as result, the compiled + // body function does not have matching input and output shape. We fix this + // by rewriting the body computation (see body_wrapper below) to output + // just the non compile-time-const values and later pad up the while output + // with the const args. 
+ std::vector compile_time_const_arg_indices(ctx->num_inputs()); + int num_compile_time_const_args = 0; + if (propagate_compile_time_consts_) { + OP_REQUIRES_OK(ctx, ConvertLoopInvariantsToConst( + ctx, body_name_attr_, &arguments, + &compile_time_const_arg_indices, + &num_compile_time_const_args, compiler->client())); + } + VLOG(1) << "Compiling body"; // All resource that are inputs to the loop's body must also be @@ -145,10 +336,13 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Hence we can use the output shapes and TensorArray gradients of each // resource as the "true" shapes. // 2) again with the "correct" resource information determined by (1). - if (has_uninitialized_vars || has_tensor_arrays) { + if (has_uninitialized_vars || has_tensor_arrays || + has_uninitialized_tensor_lists) { VLOG(2) << "Recompiling loop body: has_uninitialized_vars: " << has_uninitialized_vars - << " has_tensor_arrays: " << has_tensor_arrays; + << " has_tensor_arrays: " << has_tensor_arrays + << " has_uninitialized_tensor_lists: " + << has_uninitialized_tensor_lists; // Initializes any uninitialized resource with zero values of the // shape determined by the first compilation. for (int i = 0; i < body.resource_updates.size(); ++i) { @@ -185,6 +379,23 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { arg.tensor_array_gradients.insert(gradient.first); } } + + // Set the shape of any uninitialized TensorLists to the shape determined by + // the first compilation. Note that, unlike resources, we do not initialize + // the input list with zeros here, that is done later. 
+ xla::Shape body_output_shape = body.xla_output_shape; + OP_REQUIRES(ctx, body_output_shape.IsTuple(), + errors::FailedPrecondition( + "xla_output_shape of while body must be a tuple.")); + for (int i = 0; i < arguments.size(); i++) { + XlaCompiler::Argument& arg = arguments[i]; + if (arg.initialized || !IsTensorListInput(ctx, i)) { + continue; + } + arg.shape = body_output_shape.tuple_shapes(i); + arg.initialized = true; + } + // Recompile the body with the "correct" resource shapes. VLOG(1) << "Recompiling body with corrected resource shapes"; body = {}; @@ -225,12 +436,13 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { "Input shapes of loop body and condition do not match: ", xla::ShapeUtil::HumanString(body_input_shape), " vs. ", xla::ShapeUtil::HumanString(cond_input_shape))); - OP_REQUIRES( - ctx, xla::ShapeUtil::Compatible(body_input_shape, body.xla_output_shape), - errors::InvalidArgument( - "Input and output shapes of loop body do not match: ", - xla::ShapeUtil::HumanString(body_input_shape), " vs. ", - xla::ShapeUtil::HumanString(body.xla_output_shape))); + + // Check that the shape of the body outputs excluding the compile time const + // args (which are pruned from the body outputs in body_wapper) matches the + // shape of the inputs. + OP_REQUIRES_OK(ctx, VerifyBodyInputAndOutputShapeMatch( + ctx, compile_time_const_arg_indices, body, + has_token_input_output_)); xla::Shape expected_cond_output_shape_without_side_effect = xla::ShapeUtil::MakeTupleShape( @@ -267,8 +479,28 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder)); + } else if (IsTensorListInput(ctx, input_num)) { + // If the list received as input is uninitialized but its shape was + // inferred in the first compilation pass we create a new list filled + // with zeros and used that as the input to the while op. 
+ TensorShape input_list_shape; + OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(input_num), + &input_list_shape)); + TensorShape body_arg_shape; + OP_REQUIRES_OK(ctx, + GetTensorListBufferShape(body_input_shape.tuple_shapes(i), + &body_arg_shape)); + // Shape of the input list may differ from the shape of the body/cond + // input if the list's shape was inferred after the first compilation and + // the body/cond was recompiled with the updated shape of the list. + if (input_list_shape != body_arg_shape) { + OP_REQUIRES_OK(ctx, InitializeTensorList(ctx->Input(input_num), + body_arg_shape, &inputs[i])); + } else { + inputs[i] = ctx->Input(input_num); + } } else { - inputs[i] = ctx->Input(i); + inputs[i] = ctx->Input(input_num); } } @@ -277,26 +509,28 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building while loop"; // Wraps the condition in a computation that unpacks the output tuple. - xla::XlaComputation cond_wrapper; - { - std::unique_ptr cb = - builder->CreateSubBuilder("cond_wrapper"); - auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs"); - auto outputs = xla::Call(cb.get(), *cond.computation, {inputs}); - xla::GetTupleElement(outputs, 0); - xla::StatusOr result = cb->Build(); - OP_REQUIRES_OK(ctx, result.status()); - cond_wrapper = std::move(result.ValueOrDie()); - } - - xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); + xla::StatusOr cond_result = BuildWrappedCond(ctx, cond); + OP_REQUIRES_OK(ctx, cond_result.status()); + xla::XlaComputation wrapped_cond = std::move(cond_result.ValueOrDie()); + + // Remove compile time const args from the list of body outputs. 
+ xla::StatusOr body_result = + BuildWrappedBody(ctx, body, compile_time_const_arg_indices, + num_compile_time_const_args, has_token_input_output_); + OP_REQUIRES_OK(ctx, body_result.status()); + xla::XlaComputation wrapped_body = std::move(body_result.ValueOrDie()); + + // Builds the While op and pads its output with the compile time const args. + xla::XlaOp while_result = + BuildWhile(ctx, wrapped_cond, wrapped_body, init, body.input_mapping, + compile_time_const_arg_indices, num_compile_time_const_args, + has_token_input_output_); // Sets non-variable outputs and determine when resource variables start. int resource_index = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { - ctx->SetOutput(body.input_mapping[i], - xla::GetTupleElement(while_result, i)); + ctx->SetOutput(i, xla::GetTupleElement(while_result, i)); ++resource_index; } else { break; @@ -341,8 +575,11 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building while loop"; } -REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); -REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp); -REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); +REGISTER_XLA_OP(Name("While").AllowResourceTypes().AllowVariantTypes(), + XlaWhileOp); +REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes().AllowVariantTypes(), + XlaWhileOp); +REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes().AllowVariantTypes(), + XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index aeeff40e68f8b778628b9e85bd9b4ddcb73883a5..16ec8d0e520b5a282318f8e5225bcec65818e3e8 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -21,6 +21,8 @@ limitations under the License. 
namespace tensorflow { +extern const char kPropagateCompileTimeConsts[]; + // This TensorFlow op provides a functional iteration primitive. // // The inputs and outputs of the loop body must agree on the number, types, and @@ -58,6 +60,10 @@ class XlaWhileOp : public XlaOpKernel { NameAttrList body_name_attr_; bool has_token_input_output_; std::vector token_input_nodes_; + // Whether to propagate compile time consts into the loop body. + // This is not supported by default now since it may cause HBM memory + // overheads. + bool propagate_compile_time_consts_ = false; TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index b20adc592a0d3d2129c897218ddbfc891b4cd40a..0b5b66ae52f9b9e8251813a1af2fc35eff4a42ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -35,10 +35,9 @@ class XlaConvOp : public XlaOpKernel { string precision_config_attr; OP_REQUIRES_OK( context, context->GetAttr("precision_config", &precision_config_attr)); - OP_REQUIRES( - context, - precision_config_.ParsePartialFromString(precision_config_attr), - errors::InvalidArgument("Error parsing convolution dimension numbers")); + OP_REQUIRES(context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing precison config.")); } void Compile(XlaOpKernelContext* context) override { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..233ac8e7b455403f8ee65b95b1403ecefdb92c6b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/core/lib/core/bits.h" + +namespace tensorflow { +namespace { + +class XlaSelfAdjointEigOp : public XlaOpKernel { + public: + explicit XlaSelfAdjointEigOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); + } + void Compile(XlaOpKernelContext* ctx) override { + auto result = + xla::SelfAdjointEig(ctx->Input(0), lower_, max_iter_, epsilon_); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } + + private: + bool lower_; + int32 max_iter_; + float epsilon_; +}; + +class SelfAdjointEigV2Op : public XlaOpKernel { + public: + explicit SelfAdjointEigV2Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + int n = input_shape.dim_size(input_shape.dims() - 1); + // This is based on heuristics that approx log(n) sweep updates are needed. + // Note: the heuristics provides no theoretical guarantee, max_iter=100 and + // epsilon should be used to determine exit condition. 
+ int max_iter = 2 * tensorflow::Log2Ceiling(n); + auto result = xla::SelfAdjointEig(ctx->Input(0), true, max_iter, 1e-6); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } +}; + +REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes), + XlaSelfAdjointEigOp); +REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes), + SelfAdjointEigV2Op); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..de627b8d1b32b1af0ef834daab9550a80f1c3fa0 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc @@ -0,0 +1,95 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/lib/svd.h" + +namespace tensorflow { +namespace { + +class XlaSvdOp : public XlaOpKernel { + public: + explicit XlaSvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); + string precision_config_attr; + OP_REQUIRES_OK(ctx, + ctx->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES(ctx, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing precison config.")); + if (precision_config_.operand_precision_size() == 0) { + precision_config_.mutable_operand_precision()->Add( + xla::PrecisionConfig::HIGHEST); + } + } + void Compile(XlaOpKernelContext* ctx) override { + auto result = xla::SVD(ctx->Input(0), max_iter_, epsilon_, + precision_config_.operand_precision(0)); + ctx->SetOutput(0, result.d); + ctx->SetOutput(1, result.u); + ctx->SetOutput(2, result.v); + } + + private: + int32 max_iter_; + float epsilon_; + xla::PrecisionConfig precision_config_; +}; + +class SvdOp : public XlaOpKernel { + public: + explicit SvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("compute_uv", &compute_uv_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); + } + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + int m = input_shape.dim_size(input_shape.dims() - 2); + int n = input_shape.dim_size(input_shape.dims() - 1); + // This is based on heuristics that approx log(n) sweep updates are needed. 
+ // Note: the heuristics provides no theoretical guarantee, max_iter=100 and + // epsilon should be used to determine exit condition. + int max_iter = 2 * tensorflow::Log2Ceiling(std::max(m, n)); + auto result = xla::SVD(ctx->Input(0), max_iter, 1e-6); + ctx->SetOutput(0, result.d); + if (compute_uv_) { + int p = std::min(m, n); + if (!full_matrices_) { + if (p < m) { + result.u = xla::SliceInMinorDims(result.u, {0, 0}, {m, p}); + } + if (p < n) { + result.v = xla::SliceInMinorDims(result.v, {0, 0}, {n, p}); + } + } + ctx->SetOutput(1, result.u); + ctx->SetOutput(2, result.v); + } + } + + private: + bool compute_uv_; + bool full_matrices_; +}; + +REGISTER_XLA_OP(Name("XlaSvd").TypeConstraint("T", kFloatTypes), XlaSvdOp); +REGISTER_XLA_OP(Name("Svd").TypeConstraint("T", kFloatTypes), SvdOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 3d7b0bc959f9dbf3c1b9749379e2ea0d285b302b..f9ce50be6e343144ab0691872381fa6c5aa45f6c 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -82,3 +82,15 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "data_format", + srcs = ["data_format.cc"], + hdrs = ["data_format.h"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc new file mode 100644 index 0000000000000000000000000000000000000000..0253bcdc5f922a970e24782aba67fbd1a907e091 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -0,0 +1,87 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/data_format.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +xla::StatusOr Contract(xla::XlaOp input, int64 dim) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + + if (input_shape.dimensions().back() != 4) { + return errors::InvalidArgument("Expected last dimension to be 4; got ", + input_shape.dimensions().back()); + } + + // Transpose the input so C is directly followed by VECT_C. + std::vector permutation; + for (int64 i = 0; i != input_shape.rank() - 1; ++i) { + permutation.push_back(i); + if (i == dim) { + permutation.push_back(input_shape.rank() - 1); + } + } + + // Now merge the adjacent dimensions with a reshape. + std::vector contracted_shape(input_shape.dimensions().begin(), + input_shape.dimensions().end() - 1); + contracted_shape[dim] *= 4; + + return xla::Reshape(xla::Transpose(input, permutation), contracted_shape); +} + +xla::StatusOr Expand(xla::XlaOp input, int64 dim) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + + if (input_shape.dimensions(dim) % 4 != 0) { + return errors::InvalidArgument( + "Expected vectorized dimension to be evenly divisible by 4; got ", + input_shape.dimensions(dim)); + } + + // Split the `dim` into two dimensions with a reshape. The size of the new + // dimension is always 4. 
+ std::vector expanded_shape(input_shape.dimensions()); + expanded_shape[dim] /= 4; + expanded_shape.insert(expanded_shape.begin() + dim, 4); + + // Move the newly created dimension to the end with a transpose. + std::vector permutation; + for (int64 i = 0; i != expanded_shape.size(); ++i) { + permutation.push_back(i); + if (i == dim) { + ++i; + } + } + permutation.push_back(dim + 1); + + return xla::Transpose(xla::Reshape(input, expanded_shape), permutation); +} + +} // namespace + +xla::StatusOr NCHW_VECT_CToNCHW(xla::XlaOp input) { + return Contract(input, 1); +} + +xla::StatusOr NCHWToNCHW_VECT_C(xla::XlaOp input) { + return Expand(input, 1); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/data_format.h b/tensorflow/compiler/tf2xla/lib/data_format.h new file mode 100644 index 0000000000000000000000000000000000000000..839723b0ea87ffff67b3d79e362759a42b8f88fe --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/data_format.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +// Reformat from NCHW_VECT_C to NCHW. 
+// +// Prerequisites: the last dimension of the input must be of size 4. +xla::StatusOr NCHW_VECT_CToNCHW(xla::XlaOp input); + +// Reformat from NCHW to NCHW_VECT_C. +// +// Prerequisites: the vectorized dimension `C` must be a multiple of 4. +xla::StatusOr NCHWToNCHW_VECT_C(xla::XlaOp input); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 7140b6a1227a53290c3747892a55886a7f48513b..cb6e0fbef4b5e0a43514a7e42135d5a1c5cc9f17 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -17,6 +17,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index af641131ed76a8d6a7291c360302fa17c94af014..b82c04bd5a8795314465f67ceb3485648eeaf453 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -56,6 +58,75 @@ lhs_output: the broadcasted LHS tensor rhs_output: the broadcasted RHS tensor )doc"); +REGISTER_OP("XlaSelfAdjointEig") + .Input("a: T") + .Attr("lower: bool") + .Attr("max_iter: int") + .Attr("epsilon: float") + .Output("w: T") + .Output("v: T") + .SetShapeFn(shape_inference::UnknownShape) + .Attr("T: numbertype") + .Doc(R"doc( +Computes the eigen decomposition of a batch of self-adjoint matrices +(Note: Only real inputs are supported). 
+ +Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in +tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for +i=0...N-1. + +a: the input tensor. + +lower: a boolean specifies whether the calculation is done with the lower + triangular part or the upper triangular part. + +max_iter: maximum number of sweep update, i.e., the whole lower triangular + part or upper triangular part based on parameter lower. Heuristically, it has + been argued that approximatly logN sweeps are needed in practice (Ref: Golub & + van Loan "Matrix Computation"). + +epsilon: the tolerance ratio. + +w: The eigenvalues in ascending order, each repeated according to its + multiplicity. +v: The column v[..., :, i] is the normalized eigenvector corresponding to the + eigenvalue w[..., i]. +)doc"); + +REGISTER_OP("XlaSvd") + .Input("a: T") + .Attr("max_iter: int") + .Attr("epsilon: float") + .Attr("precision_config: string") + .Output("s: T") + .Output("u: T") + .Output("v: T") + .SetShapeFn(shape_inference::UnknownShape) + .Attr("T: numbertype") + .Doc(R"doc( +Computes the eigen decomposition of a batch of self-adjoint matrices +(Note: Only real inputs are supported). + +Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in +tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]). + +a: the input tensor. + +max_iter: maximum number of sweep update, i.e., the whole lower triangular + part or upper triangular part based on parameter lower. Heuristically, it has + been argued that approximatly log(min (M, N)) sweeps are needed in practice + (Ref: Golub & van Loan "Matrix Computation"). + +epsilon: the tolerance ratio. + +precision_config: a serialized xla::PrecisionConfig proto. + +s: Singular values. The values are sorted in reverse order of magnitude, so + s[..., 0] is the largest value, s[..., 1] is the second largest, etc. +u: Left singular vectors. +v: Right singular vectors. 
+)doc"); + REGISTER_OP("XlaConv") .Input("lhs: T") .Input("rhs: T") @@ -437,5 +508,86 @@ transpose_output: Boolean to determine if output is transposed. transpose_output is faster when input is large and rank of input is higher than 1. )doc"); +REGISTER_OP("XlaEinsum") + .Input("a: T") + .Input("b: T") + .Output("product: T") + .Attr("equation: string") + .Attr("T: {bfloat16, float}") + .SetShapeFn([](shape_inference::InferenceContext* context) { + shape_inference::ShapeHandle input_a = context->input(0); + shape_inference::ShapeHandle input_b = context->input(1); + + int64 rank_a, rank_b; + if (context->RankKnown(input_a)) { + rank_a = context->Rank(input_a); + } else { + return errors::InvalidArgument("input 0's rank is unknown."); + } + if (context->RankKnown(input_b)) { + rank_b = context->Rank(input_b); + } else { + return errors::InvalidArgument("input 1's rank is unknown."); + } + string equation; + TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation)); + + std::map left_map; + std::map right_map; + std::vector dims; + + std::vector equation_split = absl::StrSplit(equation, "->"); + + if (equation_split.size() != 2) { + return errors::InvalidArgument("Expected one \"->\" in equation. Got: ", + equation); + } + + std::vector lhs_rhs_split = + absl::StrSplit(equation_split[0], ','); + if (lhs_rhs_split.size() != 2) { + return errors::InvalidArgument("Expected one \",\" in equation. 
Got: ", + equation); + } + + if (rank_a != lhs_rhs_split[0].size()) { + return errors::InvalidArgument(absl::StrCat( + "Expected equation[0] with size: ", rank_a, " Got '", + lhs_rhs_split[0], "'", " with size: ", lhs_rhs_split[0].size())); + } + + if (rank_b != lhs_rhs_split[1].size()) { + return errors::InvalidArgument(absl::StrCat( + "Expected equation[1] with size: ", rank_b, " Got '", + lhs_rhs_split[1], "'", " with size: ", lhs_rhs_split[1].size())); + } + + for (const char& c : lhs_rhs_split[0]) { + left_map[c] = context->Dim(input_a, left_map.size()); + } + for (const char& c : lhs_rhs_split[1]) { + right_map[c] = context->Dim(input_b, right_map.size()); + } + + for (const char& c : equation_split[1]) { + if (left_map.count(c)) { + dims.push_back(left_map[c]); + } else if (right_map.count(c)) { + dims.push_back(right_map[c]); + } else { + return errors::InvalidArgument("Invalid equation: ", equation); + } + } + + context->set_output(0, context->MakeShape(dims)); + return Status::OK(); + }) + .Doc(R"doc( +An op which supports basic einsum op with 2 inputs and 1 output. + +This op has better TPU performnce since it doesn't have explicitly reshape and +transpose operations as tf.einsum does. 
+)doc"); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 345193c936a885e5a9e468979c4b73b5b0c9e5c2..8732ee04d7577f83ff427548e271a90fa827c5f4 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -291,8 +291,40 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): name=name) +def self_adjoint_eig(a, lower, max_iter, epsilon): + return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon) + + +def svd(a, max_iter, epsilon, precision_config=None): + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto) + + dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice +einsum = gen_xla_ops.xla_einsum + + +@ops.RegisterGradient('XlaEinsum') +def _einsum_grad(op, grad): + equation = op.get_attr('equation') + inputs, output = equation.split('->') + left, right = inputs.split(',') + + return [ + gen_xla_ops.xla_einsum( + grad, + op.inputs[1], + equation='{},{}->{}'.format(output, right, left), + name=None), + gen_xla_ops.xla_einsum( + grad, + op.inputs[0], + equation='{},{}->{}'.format(output, left, right), + name=None) + ] # TODO(phawkins): generalize tf.pad to support interior padding, and then remove # the XLA-specific pad operator. 
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index c20d6a5fd1f3bd7dad30cb3359d13ed4609a2250..29ebf46e4bf72c650d3768acd997de1fe3c1322c 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -82,6 +82,9 @@ CreateResourceOpInfoMap() { add("ResourceScatterSub" , kReadWrite, kVariable); add("ResourceScatterUpdate" , kReadWrite, kVariable); add("ResourceStridedSliceAssign" , kReadWrite, kVariable); + add("StatefulStandardNormalV2" , kReadWrite, kVariable); + add("StatefulUniformFullInt" , kReadWrite, kVariable); + add("StatefulUniformInt" , kReadWrite, kVariable); add("VarIsInitializedOp" , kRead, kVariable); add("VariableShape" , kRead, kVariable); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 9fac16a9700419b189bf5393c2b8bd7d76c6c1cc..fb44ae0391a3e02b0584436d6e95dcac3777b320 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -45,6 +44,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -164,12 +164,10 @@ Status RewriteAndPruneGraph( std::unordered_set retval_nodes; TF_RETURN_IF_ERROR( AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); - VLOG(2) << "Post rewrite: " - << dump_graph::DumpGraphToFile("tf2xla_post_rewrite", *graph); + VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph); PruneForReverseReachability(graph, retval_nodes); FixupSourceAndSinkEdges(graph); - VLOG(2) << "Post prune: " - << dump_graph::DumpGraphToFile("tfcompile_post_prune", *graph); + VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph); // Sanity-check, to make sure the feeds and fetches still exist post-pruning. std::set missing_feeds, missing_fetches; for (const tf2xla::Feed& feed : config.feed()) { @@ -243,7 +241,9 @@ Status CreateXlaArgs(const Graph& graph, XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); + TensorShape shape; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); + arg.shape = shape; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } @@ -252,7 +252,8 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. 
-Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, +Status ConvertGraphToXla(std::unique_ptr graph, + const tf2xla::Config& config, xla::Client* client, xla::XlaComputation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { @@ -262,6 +263,29 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); + std::vector xla_aliases; + // Populate arguments with resource variables from the config. The variables + // get turned into inputs and outputs. + int64 input_num = xla_args.size(); + int64 output_num = config.fetch_size(); + for (const tf2xla::Variable& variable : config.variable()) { + XlaCompiler::Argument arg; + arg.type = variable.type(); + arg.kind = XlaCompiler::Argument::kResource; + arg.shape = variable.shape(); + arg.name = variable.node_name(); + arg.resource_kind = XlaResource::kVariable; + arg.initialized = true; + xla_args.push_back(std::move(arg)); + + // We want to alias the input and output of the variable, so the updates are + // carried out in-place. + xla_aliases.push_back({/*output_index=*/{output_num}, + /*param_number=*/input_num, /*param_index=*/{}}); + ++input_num; + ++output_num; + } + // Compile the graph into an XLA computation. 
XlaCompiler::Options compiler_options; compiler_options.client = client; @@ -274,7 +298,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, XlaCompiler::CompilationResult result; TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), "tfcompile", std::move(graph), - xla_args, &result)); + xla_args, xla_aliases, &result)); *computation = std::move(*result.computation); int num_const_results = 0; @@ -359,7 +383,8 @@ Status ConvertGraphDefToXla(const GraphDef& graph_def, xla::XlaComputation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); - TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); + TF_RETURN_IF_ERROR( + ConvertGraphToXla(std::move(graph), config, client, computation)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.proto b/tensorflow/compiler/tf2xla/tf2xla.proto index 18c9089f5fa0e9792a4763d9bfac4c4e826eb5b2..5627af7452b99da594c1c214d0b556d8d70544d5 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.proto +++ b/tensorflow/compiler/tf2xla/tf2xla.proto @@ -39,6 +39,15 @@ message Fetch { string name = 2; // Optional name for generated code. }; +// Variable represents a resource variable with the given name, shape and type. +message Variable { + string node_name = 1; + string name = + 2; // Optional name for generated code. If empty, node_name will be used. + TensorShapeProto shape = 3; + DataType type = 4; +} + // Config represents configuration information for tf2xla conversion. message Config { // Each feed is a positional input argument for the generated computation. @@ -47,4 +56,6 @@ message Config { // Each fetch is a positional output argument for the generated computation. // The order of each entry matches the order of each output argument. repeated Fetch fetch = 2; + // Each variable is a named input and output of the generated computation. 
+ repeated Variable variable = 3; }; diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 18d87727c500619bf386be7d8c7085724f44aba3..e5e4bf8bcf43e80d0579672bb172643654e38343 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include #include #include #include @@ -113,7 +114,7 @@ Status ReplaceArgUsageWithConstNode( // Collect all _Arg nodes. std::unordered_map arg_nodes; for (Node* n : g->op_nodes()) { - if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + if (n->IsArg()) { int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); arg_nodes[index] = n; @@ -122,7 +123,12 @@ Status ReplaceArgUsageWithConstNode( for (const auto& iter : const_input_index_to_node) { int arg_index = iter.first; - Node* const_node = g->CopyNode(iter.second); + NodeDef const_def = iter.second->def(); + const_def.set_name(g->NewName(const_def.name())); + Status s; + Node* const_node = g->AddNode(const_def, &s); + TF_RETURN_IF_ERROR(s); + Node* arg_node = arg_nodes[arg_index]; // Collect all usages of the _Arg node. @@ -265,6 +271,13 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node, } // Check if i-th retval's input comes from i-th arg directly. + // For resource variable input of While nodes, TF2XLA convention is to place + // them at the end of all inputs (after all data inputs), and *not* return + // them. So number of While node inputs might be larger than number of its + // outputs. + if (i >= body_func->signature().output_arg_size()) { + continue; + } const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i); auto output_arg_input = body_func->ret().find(output_arg.name()); if (output_arg_input == body_func->ret().end()) { @@ -543,7 +556,9 @@ uint32 GetXLARandomSeed() { // after an overflow. 
When seeded with zero, some XLA backends // can return all zeros instead of random numbers. static std::atomic counter(InitialRandomSeed()); - return counter.fetch_add(2); + uint32 seed = counter.fetch_add(2); + std::srand(seed); + return std::rand() | 1; } // TODO(b/77601805): add tests for associated function related stuff. diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 202e929315cacd4d6cdfc69d50639d8a427ec6c2..28b4744470e7d28863b5f7275f829b9bd59641e1 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -21,11 +21,13 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" @@ -329,5 +331,90 @@ TEST(CachedFunctionHandles, Basic) { TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles()); } +TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) { + FunctionLibraryDefinition fld(OpRegistry::Global(), {}); + { + // Cond graph & body graph. 
+ Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = ops::_Arg(scope.WithOpName("pred"), DT_BOOL, 0); + auto input = ops::_Arg(scope.WithOpName("input"), DT_RESOURCE, 1); + auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef cond_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef)); + FunctionDef body_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(body_fdef)); + } + Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({})); + auto input = ops::Const(scope.WithOpName("input"), 0, TensorShape({})); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("cond"); + body_fn.set_name("body"); + auto while_op = + ops::While(scope.WithOpName("while"), + std::initializer_list{pred, input}, cond_fn, body_fn); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + + TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld)); +} + +TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) { + FunctionLibraryDefinition fld(OpRegistry::Global(), {}); + { + // Cond graph & body graph. 
+ Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); + auto input = ops::_Arg(scope.WithOpName("arg1"), DT_BOOL, 1); + auto duplicate_name = ops::NoOp(scope.WithOpName("duplicate_name")); + auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef cond_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef)); + FunctionDef body_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(body_fdef)); + } + Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = + ops::Const(scope.WithOpName("duplicate_name"), false, TensorShape({})); + auto input = ops::Const(scope.WithOpName("input"), false, TensorShape({})); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("cond"); + body_fn.set_name("body"); + auto while_op = + ops::While(scope.WithOpName("while"), + std::initializer_list{pred, input}, cond_fn, body_fn); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + + TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld)); + + // Check that in rewritten body function, the NoOp node still has name + // "duplicate_name", and the copied Const node has name "duplicate_name/_0". 
+ auto node_name_index = graph.BuildNodeNameIndex(); + Node* while_node = node_name_index["while"]; + ASSERT_NE(while_node, nullptr); + TF_ASSERT_OK(GetNodeAttr(while_node->def(), "body", &body_fn)); + const FunctionDef* rewritten_body_fn = fld.Find(body_fn.name()); + ASSERT_NE(rewritten_body_fn, nullptr); + std::unordered_map nodes; + for (const NodeDef& node_def : rewritten_body_fn->node_def()) { + nodes[node_def.name()] = node_def; + } + auto noop_def = nodes.find("duplicate_name"); + ASSERT_NE(noop_def, nodes.end()); + EXPECT_EQ(noop_def->second.op(), "NoOp"); + auto const_def = nodes.find("duplicate_name/_0"); + ASSERT_NE(const_def, nodes.end()); + EXPECT_EQ(const_def->second.op(), "Const"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index ddb284966eeb97cc7c9d3ed77fb313e567975e59..f98d07d196ea8551f1a5b53fa2e88e7bc43639de 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -42,7 +42,7 @@ class XlaCompilationAllocator : public Allocator { void* AllocateRaw(size_t alignment, size_t num_bytes) override { // Regardless of the size requested, always allocates an XlaExpression. - // Respects the aligment request because there is alignment checking even + // Respects the alignment request because there is alignment checking even // for Tensors whose data is never accessed. void* p = port::AlignedMalloc(sizeof(XlaExpression), alignment); XlaExpression* expression = reinterpret_cast(p); @@ -60,8 +60,6 @@ class XlaCompilationAllocator : public Allocator { // buffers, so they get ids to track. 
bool ShouldAllocateEmptyTensors() override { return true; } - void GetStats(AllocatorStats* stats) override { stats->Clear(); } - private: // Don't run any constructors or destructors for complex objects, // since there is no backing store for the tensor to run them diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 15fd2656862e43532e33066414c5eac73593a5c0..86a25177d271c49d185de4f37a892297ba8b22c9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" @@ -31,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -42,9 +42,12 @@ limitations under the License. 
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { @@ -57,7 +60,11 @@ Status CheckSignature(const DataTypeVector& types, " elements while function has ", types.size()); } for (int i = 0; i < types.size(); ++i) { - if (types[i] != args[i].type && types[i] != DT_RESOURCE) { + // Don't perform type checks on resource variables and tensor + // lists (DT_VARIANT) as we have to trick the type system in order to + // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor. + if (types[i] != args[i].type && types[i] != DT_RESOURCE && + types[i] != DT_VARIANT) { return errors::Internal( "Argument ", i, " has declared type ", DataTypeString(args[i].type), " but function parameter has type ", DataTypeString(types[i])); @@ -85,14 +92,14 @@ ComputeArgAndRetvalCores(const Graph& graph) { std::map arg_cores; std::map retval_cores; for (const Node* n : graph.nodes()) { - if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + if (n->IsArg()) { TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); if (core < 0) continue; int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); TF_RET_CHECK(index >= 0) << "Negative _Arg index"; arg_cores[index] = core; - } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) { + } else if (n->IsRetval()) { TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); if (core < 0) continue; int index; @@ -178,9 +185,10 @@ Status BuildComputation( std::vector elems; elems.reserve(retvals.size()); - // Keeps track of which retvals have layout to update. 
The first element is - // the output index, second element is the new layout. - std::vector> retval_to_update_layout; + // Keeps track of the layout of each retval. If a retval is not in this list, + // a descending layout is used. The first element is the output index, second + // element is the new layout. + std::vector> retval_index_and_layout; for (int i = 0; i < retvals.size(); ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; @@ -192,6 +200,8 @@ Status BuildComputation( output.shape = output.constant_value.shape(); break; + case XlaExpression::Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case XlaExpression::Kind::kXlaOp: { output.is_constant = false; TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); @@ -207,7 +217,7 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( output.shape, output.type)); value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); - retval_to_update_layout.emplace_back(elems.size(), shape.layout()); + retval_index_and_layout.emplace_back(elems.size(), shape.layout()); } else if (it != retval_cores.end()) { // Apply the sharding to the output, if there is a core assignment. value = identity_op(value); @@ -280,6 +290,11 @@ Status BuildComputation( // Ensures the correct sharding is applied to the output. handle = identity_op(handle); + // Set layout of the retval to device representation layout. + if (resource->representation_shape().has_value()) { + retval_index_and_layout.emplace_back( + elems.size(), resource->representation_shape()->layout()); + } elems.push_back(handle); } } @@ -309,15 +324,15 @@ Status BuildComputation( computation->GetProgramShape()); *output_shape = program_shape.result(); // Update the output layout to the layout of retval. 
- for (auto& update : retval_to_update_layout) { + for (auto& index_and_layout : retval_index_and_layout) { if (!always_return_tuple && elems.size() == 1) { - *output_shape->mutable_layout() = update.second; + *output_shape->mutable_layout() = index_and_layout.second; continue; } - xla::Shape* output_sub_shape = - xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); - *output_sub_shape->mutable_layout() = update.second; + xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape( + output_shape, {index_and_layout.first}); + *output_sub_shape->mutable_layout() = index_and_layout.second; } return Status::OK(); } @@ -333,8 +348,21 @@ bool XlaCompiler::Argument::operator==( other.tensor_array_gradients)) { return false; } - if (shape != other.shape) { - return false; + if (absl::holds_alternative(shape)) { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (!xla::Shape::Equal()(absl::get(shape), + absl::get(other.shape))) { + return false; + } + } else { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (absl::get(shape) != absl::get(other.shape)) { + return false; + } } if (constant_value.shape() != other.constant_value.shape()) { return false; @@ -348,7 +376,7 @@ string XlaCompiler::Argument::HumanString() const { common = absl::StrCat(" name=", name); } absl::StrAppend(&common, " type=", DataTypeString(type), - " shape=", shape.DebugString()); + " shape=", ShapeHumanString()); switch (kind) { case kInvalid: return "invalid"; @@ -375,6 +403,23 @@ string XlaCompiler::Argument::HumanString() const { } } +std::vector XlaCompiler::Argument::DimensionSizes() const { + if (absl::holds_alternative(shape)) { + return xla::InlinedVectorToVector( + absl::get(shape).dim_sizes()); + } else { + return absl::get(shape).dimensions(); + } +} + +string XlaCompiler::Argument::ShapeHumanString() const { + if (absl::holds_alternative(shape)) { + return absl::get(shape).DebugString(); + } else { + return 
absl::get(shape).DebugString(); + } +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), @@ -536,29 +581,27 @@ Status XlaCompiler::CompileFunction( // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == - FunctionLibraryDefinition::kArgOp) { + if (n->IsArg()) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == - FunctionLibraryDefinition::kRetOp) { + if (n->IsRetval()) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileFunction: " - << dump_graph::DumpGraphToFile( + << DumpGraphToFile( absl::StrCat("xla_compile_function_", function_id), *graph); } VLOG(1) << "===================================================="; TF_RETURN_IF_ERROR( - CompileGraph(options, function_id, std::move(graph), args, result)); + CompileGraph(options, function_id, std::move(graph), args, {}, result)); VLOG(1) << "===================================================="; cache_[{function_id, arg_vector}] = *result; @@ -574,11 +617,22 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { if (is_entry_computation) { - TF_ASSIGN_OR_RETURN( - *xla_shape, options_.shape_representation_fn(arg.shape, arg.type)); + TensorShape shape; + if (absl::holds_alternative(arg.shape)) { + shape = absl::get(arg.shape); + } else { + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(absl::get(arg.shape), &shape)); + } + TF_ASSIGN_OR_RETURN(*xla_shape, + options_.shape_representation_fn(shape, 
arg.type)); } else { - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, arg.shape, xla_shape)); + if (absl::holds_alternative(arg.shape)) { + *xla_shape = absl::get(arg.shape); + } else { + TF_RETURN_IF_ERROR(TensorShapeToXLAShape( + arg.type, absl::get(arg.shape), xla_shape)); + } } return Status::OK(); } @@ -587,8 +641,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, switch (arg.resource_kind) { case XlaResource::kVariable: { - TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( - arg.shape, arg.type)); + TF_RET_CHECK(absl::holds_alternative(arg.shape)); + TF_ASSIGN_OR_RETURN(*xla_shape, + options_.shape_representation_fn( + absl::get(arg.shape), arg.type)); return Status::OK(); } @@ -597,9 +653,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return errors::InvalidArgument( "Negative max_array_size in XLAShapeForArgument"); } + TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; shape.AddDim(arg.max_array_size); - shape.AppendShape(arg.shape); + shape.AppendShape(absl::get(arg.shape)); TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); if (!arg.tensor_array_gradients.empty()) { @@ -614,9 +671,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return errors::InvalidArgument( "Negative max_array_size in XLAShapeForArgument"); } + TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; shape.AddDim(arg.max_array_size); - shape.AppendShape(arg.shape); + shape.AppendShape(absl::get(arg.shape)); xla::Shape buffer_shape; TF_RETURN_IF_ERROR( TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); @@ -646,14 +704,15 @@ Status XlaCompiler::BuildArguments( bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, const std::map& arg_cores, std::vector* arg_expressions, - std::vector* input_mapping, std::vector* input_shapes, + std::vector* input_to_args, std::vector* input_shapes, bool is_entry_computation) 
{ arg_expressions->resize(args.size()); // Argument numbers of arguments and resources that are to be passed to the - // XLA computation as runtime parameters. - input_mapping->clear(); - input_mapping->reserve(args.size()); + // XLA computation as runtime parameters. `input_to_args[a] = b` means that + // the a'th XLA input corresponds to the b'th original arg index. + input_to_args->clear(); + input_to_args->reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. for (std::vector::size_type i = 0; i < args.size(); @@ -663,24 +722,25 @@ Status XlaCompiler::BuildArguments( switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid); + TF_RET_CHECK(absl::holds_alternative(arg.shape)); // TODO(phawkins): this code assumes that resource arguments do not // alias. XlaResource* resource = context->AddResource(absl::make_unique( - arg.resource_kind, i, arg.name, arg.type, arg.shape, - xla::XlaOp(), + arg.resource_kind, i, arg.name, arg.type, + absl::get(arg.shape), xla::XlaOp(), /*max_array_size=*/arg.max_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, /*tensor_array_multiple_writes_aggregate=*/true)); arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { - input_mapping->push_back(i); + input_to_args->push_back(i); } break; } case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kToken: { - input_mapping->push_back(i); + input_to_args->push_back(i); break; } case XlaCompiler::Argument::kConstant: @@ -692,15 +752,23 @@ Status XlaCompiler::BuildArguments( } } - if (input_mapping->empty()) { + if (input_to_args->empty()) { return Status::OK(); } - std::vector arg_shapes(input_mapping->size()); - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds + // to the d'th XLA input.
Note that the value -1 corresponds to constants, or + // other args that don't correspond to an input. + std::vector arg_to_inputs(args.size(), -1); + for (int i = 0; i < input_to_args->size(); i++) { + arg_to_inputs[input_to_args->at(i)] = i; + } + + std::vector arg_shapes(input_to_args->size()); + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { // Computes the shapes of non-constant arguments. TF_RETURN_IF_ERROR(XLAShapeForArgument( - args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); + args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -717,13 +785,13 @@ Status XlaCompiler::BuildArguments( builder->SetOpMetadata(arg_metadata); // Build parameter handles for non-constant arguments. - std::vector arg_handles(input_mapping->size()); + std::vector arg_handles(input_to_args->size()); if (use_tuple_arg) { xla::XlaOp tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); - for (int64 parameter : *input_mapping) { + for (int64 parameter : *input_to_args) { auto it = arg_cores.find(parameter); const int core = it == arg_cores.end() ? 
0 : it->second; *tuple_sharding.add_tuple_shardings() = @@ -735,7 +803,19 @@ Status XlaCompiler::BuildArguments( } else { tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + + for (int i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + TF_RETURN_IF_ERROR(builder->SetDynamicBinding( + /*dynamic_size_param_num=*/0, {dynamic_size_param_index}, + /*target_param_num=*/0, /*target_param_index=*/{i}, + dim_and_arg_num.first)); + } + } + + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_cores.find(i); const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( @@ -744,7 +824,7 @@ Status XlaCompiler::BuildArguments( arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_cores.find(i); const int core = it == arg_cores.end() ? 
-1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( @@ -753,6 +833,17 @@ Status XlaCompiler::BuildArguments( arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], absl::StrCat("arg", i)); } + + for (int i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + TF_RETURN_IF_ERROR(builder->SetDynamicBinding( + /*dynamic_size_param_num=*/dynamic_size_param_index, {}, + /*target_param_num=*/i, /*target_param_index=*/{}, + dim_and_arg_num.first)); + } + } } builder->ClearOpMetadata(); @@ -760,12 +851,12 @@ Status XlaCompiler::BuildArguments( // Fill in the handles in non-constant arguments, and reshape parameters // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) - << " name: " << arg.name << " TF arg " << input_mapping->at(i); - XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)]; + << " name: " << arg.name << " TF arg " << input_to_args->at(i); + XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); @@ -782,7 +873,7 @@ Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. 
if (is_entry_computation) { arg_expression = XlaExpression::XlaOp( - xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type); + xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type); } else { arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } @@ -843,7 +934,8 @@ Status XlaCompiler::CompileSingleOp( } FixupSourceAndSinkEdges(graph.get()); - return CompileGraph(options, node_def.name(), std::move(graph), args, result); + return CompileGraph(options, node_def.name(), std::move(graph), args, {}, + result); } namespace { @@ -862,6 +954,28 @@ Status ValidateFunctionDef(const FunctionDef* fdef, return Status::OK(); } +// If node is PartitionedCall or StatefulPartitionedCall, returns the +// name from the "f" attr, else returns node.def().op(). +// Returned pointer points to the internal string either in node's attributes +// or in its NodeDef. This pointer is valid as long as the node has not been +// modified. +Status GetPotentialFunctionName(const Node& node, const string** name) { + if (node.IsPartitionedCall()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value)); + if (!attr_value->has_func()) { + return errors::InvalidArgument( + "The attribute value for attribute 'f' in node ", node.DebugString(), + " does not have 'func' field set"); + } + *name = &attr_value->func().name(); + return Status::OK(); + } + *name = &node.type_string(); + return Status::OK(); +} + // Check that the graph doesn't have any invalid nodes (e.g. incompatible with // given device_type, invalid data type, missing attributes...) 
Status ValidateGraph(const Graph* graph, @@ -881,7 +995,9 @@ Status ValidateGraph(const Graph* graph, if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { continue; } - const FunctionDef* fdef = flib_def.Find(node->def().op()); + const string* function_name; + TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name)); + const FunctionDef* fdef = flib_def.Find(*function_name); Status s; if (fdef) { s = ValidateFunctionDef(fdef, flib_def); @@ -926,20 +1042,19 @@ void ConvertConstantsToExpressions(xla::XlaBuilder* builder, } // namespace -Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, - string const& name, - std::unique_ptr graph, - absl::Span args, - CompilationResult* result) { +Status XlaCompiler::CompileGraph( + const XlaCompiler::CompileOptions& options, string const& name, + std::unique_ptr graph, absl::Span args, + absl::Span user_aliases, + CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes( graph.get(), options_.flib_def, local_flib_def_.get())); if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " - << dump_graph::DumpGraphToFile( - absl::StrCat("xla_compile_graph_", name), *graph, - flib_runtime_->GetFunctionLibraryDefinition()); + << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph, + flib_runtime_->GetFunctionLibraryDefinition()); } // Report the error here if initialization failed. @@ -978,6 +1093,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options.is_entry_computation)); context->set_args(std::move(arg_expressions)); + // Propagate any aliases given to us by the user. + for (const xla::XlaBuilder::InputOutputAlias& alias : user_aliases) { + builder.SetUpAlias(alias.output_index, alias.param_number, + alias.param_index); + } + PushNodeTokenMapping(); // Use std::set instead of std::unordered_set to ensure determinism. 
std::set output_node_token_inputs; @@ -1023,8 +1144,17 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->outputs.resize(context->retvals().size()); std::vector retvals = context->retvals(); if (options.resolve_compile_time_constants) { - TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants( - client(), absl::Span(retvals))); + Status status = ResolveConstantExpressionsToConstants( + client(), absl::Span(retvals)); + + // If the HloEvaluator has not implemented an expression, just evaluate it + // at runtime. + if (status.code() == error::UNIMPLEMENTED) { + ConvertConstantsToExpressions(&builder, + absl::Span(retvals)); + } else { + TF_RETURN_IF_ERROR(status); + } } else { ConvertConstantsToExpressions(&builder, absl::Span(retvals)); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 0d801b73a8c2651305328384377751254ecaa41d..0b0908e9d6913f2664e4d976611b1218be44ff2b 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" @@ -124,7 +125,8 @@ class XlaCompiler { DataType type = DT_INVALID; // The shape of the argument. For: - // * a parameter: the shape of the parameter. + // * a parameter: the shape of the parameter. We allow setting the xla shape + // if known. This helps avoid conversions to and from TensorShape. // * a constant: ignored; the shape given by constant_value is used // instead. // * an uninitialized resource: ignored. We don't yet know the shape of an @@ -133,7 +135,7 @@ class XlaCompiler { // * an initialized TensorArray or Stack resource: the shape of an entry in // the TensorArray/Stack. 
Note this is the size of a single entry, not the // XLA data structure that represents the complete stack/array. - TensorShape shape; + absl::variant shape; // The value of the argument, if it is a compile-time constant. Must be a // host-memory tensor. @@ -157,10 +159,20 @@ class XlaCompiler { // as `tensor_array_gradients`. std::set tensor_array_gradients; + // dynamic dims to arg number map. Empty if no dynamic shapes. + std::map dynamic_dim_to_arg_num_map; + bool is_pad_arg = false; + bool operator==(const Argument& other) const; // Returns a human-readable summary of the argument. string HumanString() const; + + // Returns the dimension sizes for either TensorShape or xla::Shape. + std::vector DimensionSizes() const; + + // Returns the human-readable string for either TensorShape or xla::Shape. + string ShapeHumanString() const; }; // Options pertaining to an individual call to CompileGraph() or @@ -327,10 +339,11 @@ class XlaCompiler { // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. - Status CompileGraph(const CompileOptions& options, string const& name, - std::unique_ptr graph, - absl::Span args, - CompilationResult* result); + Status CompileGraph( + const CompileOptions& options, string const& name, + std::unique_ptr graph, absl::Span args, + absl::Span user_aliases, + CompilationResult* result); // Compiles a single Op, given by `node_def`, into an // xla::XlaComputation. Similar to CompileFunction but takes a single Op as @@ -404,11 +417,11 @@ class XlaCompiler { Status SetNodeToken(const string& node_name, const xla::XlaOp& op); xla::StatusOr GetNodeToken(const string& node_name); - private: // Sets the function body `fbody` to the one registered as `function`. Status FindFunctionBody(const NameAttrList& function, const FunctionBody** fbody); + private: // Returns the optimized graph object in this function body. 
std::unique_ptr GetGraph(const FunctionBody* fbody); @@ -420,7 +433,7 @@ class XlaCompiler { XlaContext* context, const std::map& arg_cores, std::vector* arg_expressions, - std::vector* input_mapping, + std::vector* input_to_args, std::vector* input_shapes, bool is_entry_computation); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 492010f7317d32a8a620147cd2cd9356d4f13fde..1818d4290324aa398f8f90ff11725dc48948b621 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -175,9 +175,9 @@ TEST_F(XlaCompilerTest, EmptyReturnValues) { std::unique_ptr graph(new Graph(OpRegistry::Global())); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), - /*args=*/{}, &result)); + TF_ASSERT_OK(compiler.CompileGraph( + XlaCompiler::CompileOptions(), "add", std::move(graph), + /*args=*/{}, /*user_aliases=*/{}, &result)); TF_ASSERT_OK(client_->Execute(*result.computation, {}).status()); } @@ -207,7 +207,8 @@ TEST_F(XlaCompilerTest, Simple) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); // Tests that the generated computation works. xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); @@ -258,7 +259,7 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { compile_options.always_return_tuple = false; XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), - args, &result)); + args, /*user_aliases=*/{}, &result)); // Tests that the generated computation works. 
xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); @@ -277,6 +278,99 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } +// Tests that the compiler can correctly propagate the layout assigned by +// shape_representation_fn_ to return types. +TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + options.shape_representation_fn = + [](const TensorShape& shape, DataType dt) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + // Compiles the graph. 
+ XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, + /*user_aliases=*/{}, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); + // Check that the return shapes are correctly transposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + +// The layout of resource variable shouldn't change after transpose +TEST_F(XlaCompilerTest, TransposeVariables) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto transposed_read = ops::Transpose(scope, read, {1, 0}); + auto reshape = ops::Reshape(scope, transposed_read, {2, 3}); + auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose", + std::move(graph), args, + /*user_aliases=*/{}, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0}); + // Check that the return shapes are correctly transposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + // Tests that the compiler doesn't reorder the parameters. TEST_F(XlaCompilerTest, MixedOrderArguments) { for (bool swap_order : {false, true}) { @@ -319,7 +413,7 @@ TEST_F(XlaCompilerTest, MixedOrderArguments) { compile_options.always_return_tuple = false; XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), - args, &result)); + args, /*user_aliases=*/{}, &result)); EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1)); } @@ -349,9 +443,9 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { XlaCompiler compiler(DefaultOptions()); XlaCompiler::CompilationResult result; - Status status = - compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape", - std::move(graph), args, &result); + Status status = compiler.CompileGraph(XlaCompiler::CompileOptions(), + "reshape", std::move(graph), args, + /*user_aliases=*/{}, &result); EXPECT_FALSE(status.ok()); EXPECT_TRUE( absl::StrContains(status.error_message(), "depends on a parameter")) @@ -395,7 +489,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { compile_options.resolve_compile_time_constants = true; XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", - std::move(graph_copy), args, &result)); + std::move(graph_copy), args, + /*user_aliases=*/{}, &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_TRUE(result.outputs[0].is_constant); @@ -428,7 +523,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
compile_options.resolve_compile_time_constants = false; XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", - std::move(graph_copy), args, &result)); + std::move(graph_copy), args, + /*user_aliases=*/{}, &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_FALSE(result.outputs[0].is_constant); @@ -514,7 +610,8 @@ TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) { compile_options.resolve_compile_time_constants = true; XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_TRUE(result.outputs[0].is_constant); @@ -556,7 +653,8 @@ TEST_F(XlaCompilerTest, ResourceManager) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); EXPECT_EQ(1, resource->Get()); @@ -592,7 +690,8 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { XlaCompiler compiler(options); TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy", - std::move(graph), args, &results[i])); + std::move(graph), args, + /*user_aliases=*/{}, &results[i])); } for (int64 i = 1; i < test_count; ++i) { @@ -658,7 +757,8 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); ASSERT_EQ(1, result.resource_updates.size()); const XlaCompiler::ResourceUpdate& update = result.resource_updates[0]; @@ -717,7 +817,8 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { XlaCompiler::CompilationResult result; 
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); EXPECT_EQ(0, result.resource_updates.size()); } @@ -749,7 +850,8 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); EXPECT_EQ(1, result.resource_updates.size()); } @@ -824,7 +926,8 @@ TEST_F(XlaCompilerTest, FunctionCallWithConstants) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); } // Tests CompileFunction with a local function lookup failing, fails with @@ -907,7 +1010,8 @@ TEST_F(XlaCompilerTest, Variables) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); RunAndCheckVariablesComputation(client_, result); } @@ -942,7 +1046,7 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) { auto compile_options = XlaCompiler::CompileOptions(); compile_options.always_return_tuple = false; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph), - args, &result)); + args, /*user_aliases=*/{}, &result)); EXPECT_TRUE(xla::ShapeUtil::Equal( result.xla_output_shape, xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}))); @@ -978,7 +1082,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); xla::Shape result_shape = 
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); @@ -1008,7 +1113,8 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); // Tests that the generated computation works. xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); @@ -1058,7 +1164,8 @@ TEST_F(XlaCompilerTest, ReturnResourceHandle) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + std::move(graph), args, + /*user_aliases=*/{}, &result)); RunAndCheckVariablesComputation(client_, result); } @@ -1109,7 +1216,7 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), - args, &result)); + args, /*user_aliases=*/{}, &result)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, client_->GetComputationShape(*result.computation)); @@ -1179,7 +1286,7 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), - args, &result)); + args, /*user_aliases=*/{}, &result)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, client_->GetComputationShape(*result.computation)); @@ -1260,7 +1367,8 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { std::vector args; XlaCompiler::CompilationResult result; status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", - std::move(graph), args, &result); + std::move(graph), args, /*user_aliases=*/{}, + &result); ASSERT_FALSE(status.ok()); EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp")) << status.error_message(); @@ -1285,7 +1393,8 @@ TEST_F(XlaCompilerTest, 
NodeWithInvalidDataType) { XlaCompiler::CompilationResult result; XlaCompiler compiler(DefaultOptions()); status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", - std::move(graph), args, &result); + std::move(graph), args, /*user_aliases=*/{}, + &result); ASSERT_FALSE(status.ok()); EXPECT_TRUE(absl::StrContains(status.error_message(), "is not in the list of allowed values")) @@ -1311,7 +1420,8 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result)); + std::move(graph_copy), args, + /*user_aliases=*/{}, &result)); } } @@ -1360,7 +1470,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), - args, &result)); + args, /*user_aliases=*/{}, &result)); EXPECT_EQ(result.xla_input_shapes.size(), 1); EXPECT_TRUE(result.xla_output_shape.IsTuple()); EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); @@ -1378,7 +1488,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), - args, &result)); + args, /*user_aliases=*/{}, &result)); EXPECT_EQ(result.xla_input_shapes.size(), 2); EXPECT_TRUE(result.xla_input_shapes[1].IsToken()); EXPECT_TRUE(result.xla_output_shape.IsTuple()); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 6139bf3cea0790c2697130a993e92be96c81848b..3f787fd86c9f7366a7728dcf146a3797ba672bc3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -76,7 +76,7 @@ XlaResource* XlaContext::AddResource(std::unique_ptr resource) { } const 
xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { - return LookupOrCreate(type, &max_func_, [this, type] { + return LookupOrCreate(type, &max_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Max() for " << type_string; xla::XlaBuilder b("max<" + type_string + ">"); @@ -92,7 +92,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { } const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { - return LookupOrCreate(type, &min_func_, [this, type] { + return LookupOrCreate(type, &min_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Min() for " << type_string; xla::XlaBuilder b("min<" + type_string + ">"); @@ -108,7 +108,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { } const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { - return LookupOrCreate(type, &add_func_, [this, type] { + return LookupOrCreate(type, &add_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Add() for " << type_string; xla::XlaBuilder b("add<" + type_string + ">"); @@ -124,7 +124,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { } const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { - return LookupOrCreate(type, &mul_func_, [this, type] { + return LookupOrCreate(type, &mul_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Mul() for " << type_string; xla::XlaBuilder b("mul<" + type_string + ">"); diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index ca0309166b7c73d1a5a818091e2a30fa112a4de4..3d228c92adcbe3d093a4fe70d157e57ab3e80c80 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -46,6 +46,14 @@ XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) 
{ return e; } +XlaExpression XlaExpression::TensorList(xla::XlaOp tensor_list) { + XlaExpression e; + e.kind_ = Kind::kTensorList; + e.dtype_ = DT_VARIANT; + e.handle_ = tensor_list; + return e; +} + XlaExpression XlaExpression::Resource(XlaResource* resource) { XlaExpression e; e.kind_ = Kind::kResource; @@ -64,6 +72,8 @@ string XlaExpression::HumanString() const { return "xla_op"; case Kind::kResource: return "resource"; + case Kind::kTensorList: + return "tensor_list"; } } @@ -76,6 +86,8 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { HostTensorToBorrowingLiteral(constant_value_, &literal)); return xla::ConstantLiteral(builder, literal); } + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case Kind::kXlaOp: if (builder != handle_.builder()) { return errors::InvalidArgument( @@ -96,7 +108,10 @@ xla::StatusOr> XlaExpression::ResolveConstant( return {constant_value()}; case Kind::kXlaOp: break; + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case Kind::kResource: + TF_FALLTHROUGH_INTENDED; case Kind::kInvalid: return errors::InvalidArgument( "ResolveConstant called on XlaExpression: ", HumanString()); @@ -134,6 +149,8 @@ xla::StatusOr XlaExpression::GetShape() const { TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); return shape; } + case Kind::kTensorList: + return TensorShape({}); case Kind::kResource: return TensorShape({}); case Kind::kInvalid: diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index bed6761d362a98d344003c1edea342e68c31ef07..ac0232d8924cf2c9e35ad3f0772a3a2adc18af87 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -32,11 +32,16 @@ namespace tensorflow { // * a constant tensor. // * an xla::XlaOp, representing a symbolic XLA value. // * a resource, e.g., a variable, represented as an XlaResource pointer. +// * a tensor list, represented by a tuple of tensors and the list length. 
// // Constant tensors are mostly an optimization to avoid passing large constants // to XLA, but are also sometimes used to represent tensors that have no XLA // representation, for example, DT_STRING tensors. A canonical use case might be // an error message string. +// +// Tensor lists are very similar to xla::XlaOp, however they require some +// specific logic around shape management since the tuples are not supported by +// TensorFlow. class XlaExpression { public: enum class Kind { @@ -44,6 +49,7 @@ class XlaExpression { kConstant, kXlaOp, kResource, + kTensorList, }; XlaExpression(); @@ -62,6 +68,9 @@ class XlaExpression { // be derived from the XLA type. static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + // Builds a tensor list expression. + static XlaExpression TensorList(xla::XlaOp tensor_list); + // Builds a resource expression. static XlaExpression Resource(XlaResource* resource); @@ -100,7 +109,8 @@ class XlaExpression { DataType dtype_ = DT_INVALID; - // The XLA handle of the expression's computation, if kind_ == kXlaOp. + // The XLA handle of the expression's computation, if kind_ == kXlaOp or + // a tuple expression if kind_ == kTensorList. xla::XlaOp handle_; // The value of the constant, if kind_ == kConstant. 
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 04a5d934064a9083a41cc210b48df65bbc862fff..7bb1ad27467a5b281626de4203169e575288f9ee 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -81,61 +81,27 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, return Status::OK(); } -template -static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { - Tensor linspace(DataTypeToEnum::v(), shape); - auto linspace_flat = linspace.flat(); - for (int64 i = 0; i < depth; ++i) { - linspace_flat(i) = i; - } - return linspace; -} - Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::XlaOp& indices, const xla::XlaOp& on_value, const xla::XlaOp& off_value, xla::XlaOp* one_hot) { - const int indices_dims = indices_shape.dims(); - const int output_dims = indices_dims + 1; - - TensorShape output_shape = indices_shape; - output_shape.InsertDim(axis, depth); - - // Build a Tensor populated with values 0, 1, 2, ... depth. - std::vector linspace_dims(output_dims, 1); - linspace_dims[axis] = depth; - TensorShape linspace_shape(linspace_dims); - Tensor linspace; - switch (index_type) { - case DT_UINT8: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT32: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT64: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(index_type)); - } - - xla::BorrowingLiteral linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); - // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. 
std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = xla::Eq( - indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); + + TensorShape output_shape = indices_shape; + output_shape.InsertDim(axis, depth); + xla::Shape iota_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(index_type, output_shape, &iota_shape)); // Selects the user-provided off_value and on_value values. - *one_hot = xla::Select(one_hot_bool, - xla::Broadcast(on_value, output_shape.dim_sizes()), - xla::Broadcast(off_value, output_shape.dim_sizes())); + *one_hot = xla::Select( + xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims), + xla::Broadcast(on_value, output_shape.dim_sizes()), + xla::Broadcast(off_value, output_shape.dim_sizes())); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 58bd173e61aa3263fae4b494914707833c7a624f..ee11f3a3de658c7e5108605122b84fbc3e1cd963 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -93,7 +93,7 @@ TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { } DataType XlaOpKernelContext::input_type(int index) const { - return context_->input(index).dtype(); + return context_->input_dtype(index); } DataType XlaOpKernelContext::InputType(absl::string_view name) { @@ -229,7 +229,8 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { if (literal.shape().rank() != 1) { - return errors::InvalidArgument("value is not 1D"); + return errors::InvalidArgument("value is not 1D, rank: ", + literal.shape().rank()); } int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); if (literal.shape().element_type() 
== xla::S32) { @@ -318,6 +319,27 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } +Status XlaOpKernelContext::ConstantInputAsPartialShape( + int index, PartialTensorShape* shape) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + // If `literal` is a scalar it's value must be -1. + if (literal.shape().rank() == 0) { + int64 shape_val; + TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val)); + if (shape_val != -1) { + return errors::InvalidArgument( + "Cannot convert value to PartialTensorShape: ", shape_val); + } + *shape = PartialTensorShape(); // Shape with unknown rank. + return Status::OK(); + } + std::vector dims; + TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); + *shape = PartialTensorShape(dims); + return Status::OK(); +} + Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { @@ -353,8 +375,8 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); if (!variable->initialized()) { - return errors::InvalidArgument("Read of uninitialized variable ", - variable->name()); + return errors::FailedPrecondition("Read of uninitialized variable ", + variable->name()); } if (variable->type() != type) { return errors::InvalidArgument( @@ -446,6 +468,16 @@ void XlaOpKernelContext::SetOutputExpression(int index, } } +xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { SetOutputExpression( index, @@ -456,6 +488,11 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { 
SetOutputExpression(index, XlaExpression::Constant(constant)); } +void XlaOpKernelContext::SetTensorListOutput(int index, + const xla::XlaOp& handle) { + SetOutputExpression(index, XlaExpression::TensorList(handle)); +} + void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { SetOutputExpression(index, XlaExpression::Resource(resource)); } @@ -497,6 +534,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, handle = xla::Reshape(handle, xla::AsInt64Slice(representation_shape.dimensions())); } + variable->SetRepresentationShape(representation_shape); return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 1858844bc05a6e12abbf07af83cad816590ddd03..cc2d5e8de3eb020ba41dfed7d730b48cd0534b4c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -138,6 +138,10 @@ class XlaOpKernelContext { // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); + // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1 + // into a PartialTensorShape. + Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape); + // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. @@ -155,6 +159,11 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } + // Returns the type of output `index` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType output_xla_type(int index); + // Sets output `index` to the XlaOp `handle`. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. 
@@ -168,6 +177,9 @@ class XlaOpKernelContext { // Returns an XlaExpression describing the value of 'index'. void SetOutputExpression(int index, const XlaExpression& expression); + // Sets output `index` to the Tensor List `handle`. + void SetTensorListOutput(int index, const xla::XlaOp& handle); + // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } Status status() { return context_->status(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 14237df69081016817fbd1a5332f22996e7f264d..1106c027c034938b5d454a4b26e2925f65100a3d 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -73,6 +73,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_resource_types settings."; return false; } + if (x.allow_variant_types != y.allow_variant_types) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible allow_variant_types settings."; + return false; + } if (!x.has_device_whitelist && !y.has_device_whitelist) { LOG(WARNING) << "Duplicate registrations of " << x.name << "with no device whitelists."; @@ -132,6 +137,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; static void* registration_init = [®istry]() { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + VLOG(2) << "tf_xla_cpu_global_jit = " << cpu_global_jit; mutex_lock lock(registry.mutex_); if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { @@ -142,7 +148,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; cpu_global_jit ? 
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested; - registration.compile_resource_ops = false; + registration.compile_all_resource_ops = false; } if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = @@ -150,7 +156,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.autoclustering_policy = XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally; - registration.compile_resource_ops = false; + registration.compile_all_resource_ops = false; } return nullptr; }(); @@ -289,6 +295,9 @@ void XlaOpRegistry::RegisterCompilationKernels() { if (op_registration->allow_resource_types) { allowed_values->add_type(DT_RESOURCE); } + if (op_registration->allow_variant_types) { + allowed_values->add_type(DT_VARIANT); + } // Don't build KernelDefs that have unsatisfiable type constraints. if (allowed_values->type().empty()) { unsatisfiable_type_constraint = true; @@ -485,6 +494,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowVariantTypes() { + registration_->allow_variant_types = true; + return *this; +} + XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( absl::string_view attr_name, DataType allowed) { std::set& types = diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index ce3b6b298c6dc5a08e7b794bbab3a28575967d28..bf4d2e1a9ddb2f89889d2ec15f4e685fa39b51cf 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -51,10 +51,10 @@ constexpr std::array kNumericTypes = { {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { +constexpr std::array kCpuAllTypes = 
{ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL}}; + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; constexpr std::array kGpuAllTypes = { {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, @@ -89,7 +89,7 @@ class XlaOpRegistry { AutoclusteringPolicy autoclustering_policy; // Enable compilation of operators that use DT_RESOURCE types? - bool compile_resource_ops = false; + bool compile_all_resource_ops = false; }; // Registers an XLA backend. `compilation_device_name` is the name of the @@ -212,6 +212,10 @@ class XlaOpRegistry { // allow DT_RESOURCE. bool allow_resource_types = false; + // Should we allow variant types for type attributes? Used by While to + // allow TensorList which is of type DT_VARIANT. + bool allow_variant_types = false; + // Mapping from attribute name to a list of supported types. std::unordered_map> type_constraints; @@ -233,9 +237,9 @@ class XlaOpRegistry { // Returns true if registrations x and y can both be added to the registry. // This is always the case if they refer to different ops. If they refer to - // the same op name, they must: have the same values for compilation_only and - // allow_resource_types; use a device_whitelist; and their - // whitelists must not intersect. + // the same op name, they must: have the same values for compilation_only, + // allow_resource_types and allow_variant_types; use a device_whitelist; and + // their whitelists must not intersect. static bool IsCompatible(const OpRegistration& x, const OpRegistration& y); static Status CompileTimeConstantInputs(const NodeDef& node_def, @@ -293,6 +297,9 @@ class XlaOpRegistrationBuilder { // Allow DT_RESOURCE types for type parameters. XlaOpRegistrationBuilder& AllowResourceTypes(); + // Allow DT_VARIANT types for type parameters. 
+ XlaOpRegistrationBuilder& AllowVariantTypes(); + // Mark 'input_name' as an argument whose value must be known at compile-time. XlaOpRegistrationBuilder& CompileTimeConstantInput( absl::string_view input_name); diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 736588bb8b89ba756cdce77eeebff8d1fcf4774c..ab3a5bdd9bc580c16d65d35c3be3ba8204511f83 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -86,6 +86,12 @@ class XlaResource { // variables have new values that need to be written back. const xla::XlaOp& initial_value() const { return initial_value_; } + // An xla shape that indicates how this resource variable is represented on + // device. + const absl::optional& representation_shape() const { + return representation_shape_; + } + // A variable is initialized if it has a value. bool initialized() const { return value_.valid(); } @@ -100,6 +106,11 @@ class XlaResource { // Sets the current value of the resource to an all-zero value. Status SetZeroValue(xla::XlaBuilder* builder); + // Sets the representational shape of the resource on device. + void SetRepresentationShape(const xla::Shape& shape) { + representation_shape_ = absl::make_optional(shape); + } + // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A // TensorArray can have multiple named gradients; see the operator @@ -160,6 +171,10 @@ class XlaResource { xla::XlaOp value_; xla::XlaOp initial_value_; + // An xla shape that indicates how this resource variable is represented on + // device. 
+ absl::optional representation_shape_; + int64 max_array_size_ = -1; bool tensor_array_multiple_writes_aggregate_ = false; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 636e5ef721f58c009566c10a653d09a7667619c0..ea3378085b94af466972128f7c18fd6b8b0984bc 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -57,6 +57,24 @@ xla_proto_library( ], ) +cc_library( + name = "comparison_util", + srcs = [ + "comparison_util.cc", + ], + hdrs = [ + "comparison_util.h", + ], + visibility = [":friends"], + deps = [ + ":statusor", + ":types", + ":util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + cc_library( name = "execution_options_util", srcs = [ @@ -150,8 +168,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":status", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", "//tensorflow/stream_executor/lib", ], ) @@ -194,7 +210,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", - "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", ], ) @@ -830,10 +846,10 @@ cc_library( deps = [ ":parse_flags_from_env", + ":status", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 27c075e8f13f6777af4e837501d97a33034313f5..d5ade8f626205b0c60dadd2216c9a49d1112832b 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -118,6 +118,7 @@ cc_library( "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo_proto", 
"//tensorflow/compiler/xla/service:local_service", @@ -212,6 +213,7 @@ cc_library( ":padding", ":sharding_builder", ":xla_computation", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -246,6 +248,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 1f594e551af381d7537e947892cbf7e0b5b3b861..f2d124d099b09af0612c6b23d899467c170edcd6 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -58,6 +58,12 @@ const Shape* ExecutableBuildOptions::result_layout() const { return result_layout_set_ ? 
&result_layout_ : nullptr; } +ExecutableBuildOptions& ExecutableBuildOptions::set_num_replicas( + int num_replicas) { + num_replicas_ = num_replicas; + return *this; +} + string ExecutableBuildOptions::ToString() const { string result_layout = "nullopt"; if (result_layout_set_) { @@ -65,8 +71,8 @@ string ExecutableBuildOptions::ToString() const { } return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " - "generate_hlo_graph=%s}", - device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph()); + "num_replicas=%d}", + device_ordinal_, result_layout, num_replicas_); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index a58090253bfac7779e4b61bc7231a0f0d945cc00..1d85fb34304b95d1fccdb0b0d6a7a65e739fae18 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -67,12 +67,18 @@ class ExecutableBuildOptions { // debugging. string ToString() const; + // The number of replicas of this computation that are to be executed. + // Defaults to 1. 
+ int num_replicas() const { return num_replicas_; } + ExecutableBuildOptions& set_num_replicas(int num_replicas); + private: int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; absl::optional debug_options_; DeviceMemoryAllocator* device_allocator_ = nullptr; + int num_replicas_ = 1; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 27f9c61848dfb0a2c6739081f2893713412b975a..1ddd3c2a4550605dd78a9d0079bea58f7d3a6582 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -49,48 +49,6 @@ xla_test( ], ) -cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - deps = [ - ":math", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:loops", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/compiler/xla/client/lib:triangular_solve", - "//tensorflow/core:lib", - ], -) - -xla_test( - name = "cholesky_test", - srcs = ["cholesky_test.cc"], - tags = ["optonly"], - deps = [ - ":arithmetic", - ":cholesky", - ":matrix", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "comparators", srcs = ["comparators.cc"], @@ -184,6 
+142,7 @@ cc_library( srcs = ["math.cc"], hdrs = ["math.h"], deps = [ + ":arithmetic", ":constants", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -195,8 +154,10 @@ xla_test( name = "math_test", srcs = ["math_test.cc"], deps = [ + ":constants", ":math", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -213,13 +174,18 @@ cc_library( deps = [ ":arithmetic", ":constants", + ":slicing", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -230,12 +196,16 @@ xla_test( deps = [ ":matrix", ":slicing", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", ], ) @@ -274,7 +244,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", "@com_google_absl//absl/base", ], ) @@ -326,12 +295,7 @@ cc_library( srcs = ["slicing.cc"], hdrs = ["slicing.h"], deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - 
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "@com_google_absl//absl/types:span", ], @@ -345,7 +309,6 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -357,6 +320,7 @@ cc_library( srcs = ["sorting.cc"], hdrs = ["sorting.h"], deps = [ + ":comparators", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -372,7 +336,6 @@ xla_test( ":sorting", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -433,51 +396,105 @@ cc_library( ) cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], + name = "self_adjoint_eig", + srcs = ["self_adjoint_eig.cc"], + hdrs = ["self_adjoint_eig.h"], deps = [ - "//tensorflow/compiler/xla:literal", + ":arithmetic", + ":comparators", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:math", - 
"//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/core:lib", ], ) xla_test( - name = "triangular_solve_test", - srcs = ["triangular_solve_test.cc"], - tags = [ - "enable_for_xla_interpreter", - "noasan", # sometimes times out, http://b/78650012 + name = "self_adjoint_eig_test", + srcs = ["self_adjoint_eig_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", ], + real_hardware_only = True, + shard_count = 10, + tags = ["optonly"], deps = [ + ":arithmetic", + ":constants", + ":matrix", + ":self_adjoint_eig", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "svd", + srcs = ["svd.cc"], + hdrs = ["svd.h"], + deps = [ + ":arithmetic", + ":comparators", + ":constants", + ":loops", ":math", ":matrix", - ":triangular_solve", + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "svd_test", + srcs = ["svd_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], + real_hardware_only = True, + shard_count = 10, + tags = ["optonly"], + deps = [ + ":arithmetic", + ":constants", + ":matrix", + ":slicing", + ":svd", "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", 
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/client/lib/cholesky.h b/tensorflow/compiler/xla/client/lib/cholesky.h deleted file mode 100644 index 0bae26837c0f14dd0cfab82cf426becc787ec11c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/cholesky.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// Computes the Cholesky decompositions of a batch of symmetric positive -// definite matrices. -// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the -// two minor dimensions equal. 
-// The algorithm implements a blocked Cholesky decomposition; `block_size` is -// the block size to use. -// TODO(phawkins): check for negative values on the diagonal and return an -// error, instead of silently yielding NaNs. -// TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky( - xla::XlaOp a, int64 block_size = 256, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index 1ada7b4a964ccf7ca400b937abbe425bef083468..6bd56a8df0a5d0417f747a158664ed0daa8a7b40 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -80,6 +80,24 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { } } +XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0(builder, + std::numeric_limits::min()); + case BF16: + return ConstantR0(builder, bfloat16::min_positive_normal()); + case F32: + return ConstantR0(builder, std::numeric_limits::min()); + case F64: + return ConstantR0(builder, std::numeric_limits::min()); + default: + return builder->ReportError( + InvalidArgument("Invalid type for MinPositiveNormalValue (%s).", + PrimitiveType_Name(type))); + } +} + XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) { return ConstantLiteral(builder, LiteralUtil::MaxValue(type)); } @@ -100,4 +118,28 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { } } +XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + switch (type) { + case F16: + return ConstantR0( + builder, Eigen::NumTraits::quiet_NaN()); + case BF16: + return ConstantR0( + builder, bfloat16(std::numeric_limits::quiet_NaN())); + case F32: + return ConstantR0(builder, + 
std::numeric_limits::quiet_NaN()); + case F64: + return ConstantR0(builder, + std::numeric_limits::quiet_NaN()); + default: + return InvalidArgument( + "Operand to NanValue was %s, but must be a real-valued " + "floating-point type.", + PrimitiveType_Name(type)); + } + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index a38282e8dbd7c8ac247a9c16f1e756c6e23a1360..47b8f1b44ffa12b2b15be0e865d693a709962e6e 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -90,6 +90,27 @@ XlaOp ScalarLike(XlaOp prototype, T value) { }); } +// Returns an array or scalar containing copies of `value` cast to the same +// run-type type as `prototype` and broadcast to the same dimensions as +// `prototype`. +// +// If `prototype` is not a scalar or array, returns an error. +template +XlaOp FullLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { + return Broadcast(ScalarLike(prototype, value), shape.dimensions()); + } else { + return InvalidArgument( + "Prototype shape for BroadcastConstantLike must be a scalar or " + "array, but was %s", + shape.ToString()); + } + }); +} + // Returns a scalar with value '0' of 'type'. XlaOp Zero(XlaBuilder* builder, PrimitiveType type); @@ -114,6 +135,9 @@ XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); // point type, this is equal to -MaxFiniteValue(). XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); +// Returns the minimum positive normal value for floating-point type `type`. +XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type); + // Returns the maximum representable finite or infinite value for 'type'. // Returns 'inf' for floating-point types. 
XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); @@ -121,6 +145,9 @@ XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); // Returns the maximum representable finite value for 'type'. XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); +// Returns a nan for the given type. Only valid for real-valued fp types. +XlaOp NanValue(XlaBuilder* builder, PrimitiveType type); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc index f4320f65c1f76d4d4c384110b39d6606773aaf01..180175b7495b32250af8ae77c8c7fba804703885 100644 --- a/tensorflow/compiler/xla/client/lib/constants_test.cc +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -155,5 +155,12 @@ XLA_TEST_F(ConstantsTest, MaxValueF32) { {}); } +XLA_TEST_F(ConstantsTest, NanValueF32) { + XlaBuilder builder(TestName()); + NanValue(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::quiet_NaN(), + {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 3d0e3a2b93fe7347597be7f0fb6ee8147948a3e5..20d3c0fc549d9cbb14c8d8e271ff386a06b5ecab 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -13,8 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +// This macro is required to make MSVC defines math constants in math.h +#define _USE_MATH_DEFINES +#include + #include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -22,51 +27,112 @@ limitations under the License. namespace xla { -XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); } +// Returns operation(operand), except if `operand` is one of the types in +// upcast_types, in which case first converts it to F32, and then converts the +// result down to the original type. +static XlaOp DoWithUpcastToF32(XlaOp operand, + absl::Span upcast_types, + const std::function& operation) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + PrimitiveType elem_ty = shape.element_type(); + bool needs_upcast = absl::c_linear_search(upcast_types, elem_ty); + + if (needs_upcast) { + operand = ConvertElementType(operand, F32); + } + XlaOp result = operation(operand); + if (needs_upcast) { + result = ConvertElementType(result, elem_ty); + } + return result; + }); +} -XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); } +// TODO(jlebar): Use this function in more places in this file to restrict the +// domain of other functions. 
+static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) { + auto& b = *operand.builder(); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + auto elem_ty = shape.element_type(); + if (!primitive_util::IsFloatingPointType(elem_ty)) { + return InvalidArgument( + "Operands to %s must be real-valued floating-point, but got %s", + op_name, PrimitiveType_Name(elem_ty)); + } + return Status::OK(); +} -XlaOp Square(XlaOp operand) { return operand * operand; } +XlaOp IsPosInf(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + // Note that this is only correct for floating-point types. If we wanted it + // to be correct for all types, we'd need to Gt(MaxFiniteValue). + return Eq(operand, MaxValue(&b, shape.element_type())); + }); +} -XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } +XlaOp IsNegInf(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + // Note that this is only correct for floating-point types. If we wanted it + // to be correct for all types, we'd need to Lt(MinFiniteValue). + return Eq(operand, MinValue(&b, shape.element_type())); + }); +} -namespace { +XlaOp IsInf(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand)); + return IsPosInf(Abs(operand)); + }); +} -// Polynomials for computing erf/erfc. Originally from cephes. -// Note we use float for compatibility across devices, at the cost of some -// precision for 64 bit computations. -// -// Coefficients are in descending order. 
-std::array kErfcPCoefficient = { - 2.46196981473530512524E-10, 5.64189564831068821977E-1, - 7.46321056442269912687E0, 4.86371970985681366614E1, - 1.96520832956077098242E2, 5.26445194995477358631E2, - 9.34528527171957607540E2, 1.02755188689515710272E3, - 5.57535335369399327526E2}; -std::array kErfcQCoefficient = { - 1.00000000000000000000E0, 1.32281951154744992508E1, - 8.67072140885989742329E1, 3.54937778887819891062E2, - 9.75708501743205489753E2, 1.82390916687909736289E3, - 2.24633760818710981792E3, 1.65666309194161350182E3, - 5.57535340817727675546E2}; -std::array kErfcRCoefficient = { - 5.64189583547755073984E-1, 1.27536670759978104416E0, - 5.01905042251180477414E0, 6.16021097993053585195E0, - 7.40974269950448939160E0, 2.97886665372100240670E0}; -std::array kErfcSCoefficient = { - 1.00000000000000000000E0, 2.26052863220117276590E0, - 9.39603524938001434673E0, 1.20489539808096656605E1, - 1.70814450747565897222E1, 9.60896809063285878198E0, - 3.36907645100081516050E0}; -std::array kErfTCoefficient = { - 9.60497373987051638749E0, 9.00260197203842689217E1, - 2.23200534594684319226E3, 7.00332514112805075473E3, - 5.55923013010394962768E4}; -std::array kErfUCoefficient = { - 1.00000000000000000000E0, 3.35617141647503099647E1, - 5.21357949780152679795E2, 4.59432382970980127987E3, - 2.26290000613890934246E4, 4.92673942608635921086E4}; -} // namespace +XlaOp IsNan(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand)); + return Ne(operand, operand); + }); +} + +XlaOp IsNegZero(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + + // The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0 + // (sign bit on, all other bits off). 
+ switch (shape.element_type()) { + case F64: + return Eq(BitcastConvertType(operand, U64), + ConstantR0WithType(&b, U64, uint64{1} << 63)); + case F32: + return Eq(BitcastConvertType(operand, U32), + ConstantR0WithType(&b, U32, uint32{1} << 31)); + case F16: + case BF16: + // Not all XLA backends handle U16 well, so we convert to F32/U32. + // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for + // backends that *do* support it. + return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32), + ConstantR0WithType(&b, U32, uint32{1} << 31)); + default: + LOG(FATAL) << "Expected real fp type."; + } + }); +} + +XlaOp Square(XlaOp operand) { return operand * operand; } + +XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. @@ -78,27 +144,97 @@ XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { return poly; } -// Compute an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(XlaOp x) { +// Computes an approximation of the error function complement (1 - erf(x)). +// +// Precondition: abs(x) >= 1. Otherwise, use ErfImpl. +// +// This follows Cephes's f32 implementation of erfc, and so it may have errors +// for double precision. +// +// See also these alternate implementations of erf and erfc: +// +// https://stackoverflow.com/questions/35148198 +// https://stackoverflow.com/questions/35966695 +// +static XlaOp ErfcImpl(XlaOp x) { + // Coefficients for erfc(f32), from Cephes. 
+ // + // erfc(x) = exp(-x^2) P(1/x), 1 < x < 2 + static std::array kErfcPCoefficient{ + +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1, + -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1, + +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1, + }; + // erfc(x) = exp(-x^2) 1/x P(1/x^2), 2 < x < 14 + static std::array kErfcRCoefficient{ + -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0, + +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1, + -2.820767439740514E-1, +5.641895067754075E-1, + }; + XlaOp abs_x = Abs(x); XlaOp z = Exp(-x * x); + XlaOp q = ScalarLike(x, 1) / abs_x; + XlaOp y = q * q; + XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)), + EvaluatePolynomial(y, kErfcPCoefficient), + EvaluatePolynomial(y, kErfcRCoefficient)); + y = z * q * p; + return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y, y); +} - XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); - XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); - XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); - XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); - - XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps); +// Compute a polynomial approximation of the error function. +// +// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl. +// +// This follows Cephes's f32 implementation of erf, so it may have errors for +// double precision. +static XlaOp ErfImpl(XlaOp x) { + // Coefficients for by erf(f32), from Cephes. 
+ // + // erf(x) = x P(x^2), 0 < x < 1 + static std::array kErfTCoefficient{ + +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3, + -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1, + +1.128379165726710E+0, + }; + + return x * EvaluatePolynomial(x * x, kErfTCoefficient); +} - return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y); +XlaOp Erfc(XlaOp x) { + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x)); + + // erfc(x) = + // erfc_impl(x) if x > 1 + // 1 - erf_impl(x) otherwise + // + // Erf(c)Impl don't have enough precision when run with bf16 intermediates + // (not surprising!), so upcast to f32 in this case. + return DoWithUpcastToF32(x, {BF16}, [](XlaOp x) { + return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl(x), + ScalarLike(x, 1) - ErfImpl(x)); + }); + }); } -// Compute a polynomial approximation of the error function. XlaOp Erf(XlaOp x) { - XlaOp z = x * x; - XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); - XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); - return x * pt / pu; + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x)); + // erf(x) = + // erf_impl(x) if x < 1 + // 1 - erfc_impl(x) otherwise + // + // Erf(c)Impl don't have enough precision when run with bf16 intermediates + // (not surprising!), so upcast to f32 in this case. 
+ return DoWithUpcastToF32(x, {BF16}, [](XlaOp x) { + return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl(x), + ScalarLike(x, 1) - ErfcImpl(x)); + }); + }); } // Approximation for the inverse error function from @@ -114,36 +250,40 @@ XlaOp Erf(XlaOp x) { // } // return p*x XlaOp ErfInv(XlaOp x) { - XlaBuilder* b = x.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); - constexpr int kDegree = 9; - constexpr std::array w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; - auto one = ScalarLike(x, 1.0); - auto w = -Log((one - x) * (one + x)); - - auto lt = Lt(w, ScalarLike(x, 5.0)); - auto coefficient = [&](int i) { - return Select(lt, - Broadcast(ScalarLike(x, w_less_than_5_constants[i]), - AsInt64Slice(shape.dimensions())), - Broadcast(ScalarLike(x, w_greater_than_5_constants[i]), - AsInt64Slice(shape.dimensions()))); - }; - w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = coefficient(i) + p * w; - } - return p * x; + auto one = ScalarLike(x, 1.0); + auto w = -Log((one - x) * (one + x)); + + auto lt = Lt(w, ScalarLike(x, 5.0)); + auto coefficient = [&](int i) 
{ + return Select(lt, FullLike(x, w_less_than_5_constants[i]), + FullLike(x, w_greater_than_5_constants[i])); + }; + w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = coefficient(i) + p * w; + } + + // Result modulo edge cases. + XlaOp result = p * x; + + // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is + // indeterminate, and can give nan or -/+inf.) + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x)); + return Select(Eq(Abs(x), ScalarLike(x, 1)), + x * MaxValue(&b, shape.element_type()), result); }); } @@ -171,49 +311,108 @@ static constexpr std::array kLanczosCoefficients = { // t(z) = z + kLanczosGamma + 1/2 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) XlaOp Lgamma(XlaOp input) { - XlaOp one_half = ScalarLike(input, 0.5); - XlaOp one = ScalarLike(input, 1); - - XlaOp pi = ScalarLike(input, M_PI); - XlaOp log_pi = ScalarLike(input, std::log(M_PI)); - XlaOp log_sqrt_two_pi = ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); - - XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); - XlaOp log_lanczos_gamma_plus_one_half = - ScalarLike(input, std::log(kLanczosGamma + 0.5)); - - XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); - - // If the input is less than 0.5 use Gauss's reflection formula: - // gamma(x) = pi / sin(pi * x) * gamma(1 - x) - XlaOp need_to_reflect = Lt(Real(input), one_half); - XlaOp z = Select(need_to_reflect, -input, input - one); - - XlaOp x = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { - XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); - XlaOp index = ScalarLike(input, i); - x = x + lanczos_coefficient / (z + index + one); - } + auto do_it = [](XlaOp input) { + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = 
ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + XlaOp log_pi = ScalarLike(input, std::log(M_PI)); + XlaOp log_sqrt_two_pi = + ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); + + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Euler's reflection formula: + // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + XlaOp need_to_reflect = Lt(input, one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp x = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + x = x + lanczos_coefficient / (z + index + one); + } - // To improve accuracy on platforms with less-precise log implementations, - // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on - // the device. - // log(t) = log(kLanczosGamma + 0.5 + z) - // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) - XlaOp t = lanczos_gamma_plus_one_half + z; - XlaOp log_t = - log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); - - XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); - - // If z = a + 0j, the analytic continuation of log reduces to taking the - // absolute value of the real part. - // Re(log(z)) = Re(log|z| + arg(z)j) - // = log|a| - XlaOp reflection = log_pi - Log(Abs(Sin(pi * input))) - log_y; - XlaOp result = Select(need_to_reflect, reflection, log_y); - return result; + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. 
+ // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = log_lanczos_gamma_plus_one_half + + Log1p(z / lanczos_gamma_plus_one_half); + + // Compute the final result (modulo reflection). t(z) may be large, and we + // need to be careful not to overflow to infinity in the first term of + // + // (z + 1/2) * log(t(z)) - t(z). + // + // Therefore we compute this as + // + // (z + 1/2 - t(z) / log(t(z))) * log(t(z)). + // + XlaOp log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x); + + // Compute the reflected value, used when x < 0.5: + // + // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). + // + // (The abs is because lgamma is the log of the absolute value of the gamma + // function.) + // + // We have to be careful when computing the final term above. gamma(x) goes + // to +/-inf at every integer x < 0, and this is controlled by the + // sin(pi * x) term. The slope is large, so precision is particularly + // important. + // + // Because abs(sin(pi * x)) has period 1, we can equivalently use + // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This + // is more numerically accurate: It doesn't overflow to inf like pi * x can, + // and if x is an integer, it evaluates to 0 exactly, which is significant + // because we then take the log of this value, and log(0) is inf. + // + // We don't have a frac(x) primitive in XLA and computing it is tricky, but + // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for + // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). + // + // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close + // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain + // [0, 1] is symmetric across the line Y=0.5. 
+ // + XlaOp abs_input = Abs(input); + XlaOp abs_frac_input = abs_input - Floor(abs_input); + // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve + // precision of pi * abs_frac_input for values of abs_frac_input close to 1. + XlaOp reduced_frac_input = + Select(Gt(abs_frac_input, ScalarLike(abs_frac_input, 0.5)), + ScalarLike(abs_frac_input, 1) - abs_frac_input, abs_frac_input); + XlaOp reflection_denom = Log(Sin(pi * reduced_frac_input)); + + // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, + // then it "wins" and the result is +/-inf. + XlaOp reflection = + Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y, + -reflection_denom); + XlaOp result = Select(need_to_reflect, reflection, log_y); + + // lgamma(+/-inf) = +inf. + XlaOp inf_bcast = FullLike(input, std::numeric_limits::infinity()); + return Select(IsInf(input), inf_bcast, result); + }; + + auto& b = *input.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input)); + // F16 and BF16 don't provide sufficient precision for intermediate results + // here (although it's better than you might expect!), so do the + // computations in F32. 
+ return DoWithUpcastToF32(input, {BF16, F16}, do_it); + }); } // Compute the Digamma function using Lanczos' approximation from "A Precision @@ -224,69 +423,101 @@ XlaOp Lgamma(XlaOp input) { // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) // A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) XlaOp Digamma(XlaOp input) { - XlaOp zero = ScalarLike(input, 0); - XlaOp one_half = ScalarLike(input, 0.5); - XlaOp one = ScalarLike(input, 1); - - XlaOp pi = ScalarLike(input, M_PI); - - XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); - XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); - XlaOp log_lanczos_gamma_plus_one_half = - ScalarLike(input, std::log(kLanczosGamma + 0.5)); - - XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); - - // If the input is less than 0.5 use Gauss's reflection formula: - // digamma(x) = digamma(1 - x) - pi * cot(pi * x) - XlaOp need_to_reflect = Lt(Real(input), one_half); - XlaOp z = Select(need_to_reflect, -input, input - one); - - XlaOp num = zero; - XlaOp denom = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { - XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); - XlaOp index = ScalarLike(input, i); - num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); - denom = denom + lanczos_coefficient / (z + index + one); - } + auto do_it = [](XlaOp input) { + XlaOp zero = ScalarLike(input, 0); + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + + XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Euler's reflection formula: + // 
digamma(x) = digamma(1 - x) - pi * cot(pi * x) + XlaOp need_to_reflect = Lt(input, one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp num = zero; + XlaOp denom = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); + denom = denom + lanczos_coefficient / (z + index + one); + } - // To improve accuracy on platforms with less-precise log implementations, - // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on - // the device. - // log(t) = log(kLanczosGamma + 0.5 + z) - // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) - XlaOp t = lanczos_gamma_plus_one_half + z; - XlaOp log_t = - log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); - - XlaOp y = log_t + num / denom - lanczos_gamma / t; - XlaOp reflection = y - pi * Cos(pi * input) / Sin(pi * input); - XlaOp result = Select(need_to_reflect, reflection, y); - return result; + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = log_lanczos_gamma_plus_one_half + + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp y = log_t + num / denom - lanczos_gamma / t; + + // We need to be careful how we compute cot(pi * input) below: For + // near-integral values of `input`, pi * input can lose precision. + // + // Input is already known to be less than 0.5 (otherwise we don't have to + // reflect). We shift values smaller than -0.5 into the range [-.5, .5] to + // increase precision of pi * input and the resulting cotangent. 
+ XlaOp reduced_input = input + Abs(Floor(input + ScalarLike(input, 0.5))); + XlaOp reflection = + y - pi * Cos(pi * reduced_input) / Sin(pi * reduced_input); + XlaOp real_result = Select(need_to_reflect, reflection, y); + + // Digamma has poles at negative integers and zero; return nan for those. + return Select(And(Le(input, zero), Eq(input, Floor(input))), + FullLike(input, std::numeric_limits::quiet_NaN()), + real_result); + }; + + auto& b = *input.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input)); + return DoWithUpcastToF32(input, {BF16, F16}, do_it); + }); } // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. XlaOp RoundToEven(XlaOp x) { - auto half = ScalarLike(x, 0.5); - auto one = ScalarLike(x, 1.0); - auto two = ScalarLike(x, 2.0); - - auto round_val = Floor(x); - auto fraction = x - round_val; - auto nearest_even_int = round_val - two * Floor(half * x); - auto is_odd = Eq(nearest_even_int, one); - return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), - round_val + one, round_val); + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + // Reject non-real non-fp inputs (What does it even mean to round a complex + // number? Do you round each component equally? In that case, you should + // just ask for that explicitly.) + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x)); + + auto half = ScalarLike(x, 0.5); + auto one = ScalarLike(x, 1.0); + auto two = ScalarLike(x, 2.0); + + auto round_val = Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * Floor(half * x); + auto is_odd = Eq(nearest_even_int, one); + return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), + round_val + one, round_val); + }); } // Trigonometric functions. 
-// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// pi if x == -1 XlaOp Acos(XlaOp x) { - return ScalarLike(x, 2.0) * - Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), ScalarLike(x, 1.0) + x); + return Select(Ne(x, FullLike(x, -1)), + ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), + ScalarLike(x, 1.0) + x), + FullLike(x, M_PI)); } // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) @@ -330,4 +561,82 @@ XlaOp MaybeConjugate(XlaOp x, bool conjugate) { }); } +XlaOp NextAfter(XlaOp from, XlaOp to) { + auto builder = from.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from)); + int bitwidth = primitive_util::BitWidth(shape.element_type()); + auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth); + auto from_as_int = BitcastConvertType(from, int_type); + auto to_as_int = BitcastConvertType(to, int_type); + + // The result is NaN if either "from" or "to" are NaN. + auto from_is_nan = Ne(from, from); + auto to_is_nan = Ne(to, to); + auto nan_input = Or(from_is_nan, to_is_nan); + auto result_for_nan = + Broadcast(ScalarLike(from, std::numeric_limits::quiet_NaN()), + shape.dimensions()); + result_for_nan = BitcastConvertType(result_for_nan, int_type); + + // The sign bit is the MSB. + const int64 sign_mask = int64{1} << (bitwidth - 1); + // Discard the sign bit to make the result non-negative. + auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask)); + auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask)); + + // When both "from" and "to" are equal, the result is "to". + // N.B. It would not make a difference if we chose the result to be "from". + auto from_and_to_are_equal = Eq(from_as_int, to_as_int); + auto result_for_equal = to_as_int; + + // When both "from" and "to" are both 0, the result is "to". This ensures we + // get a zero signed like "to". 
+ auto from_is_zero = Eq(from_abs, ZerosLike(from_abs)); + auto to_is_zero = Eq(to_abs, ZerosLike(to_abs)); + auto result_for_both_zero = to_as_int; + + auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask)); + auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask)); + + // If from == 0 && to != 0, we need to return the smallest subnormal number + // signed like "to". + auto result_for_from_zero_to_non_zero = + Or(to_sign, ScalarLike(from_as_int, 1)); + + // If the sign of "from" and "to" disagree: + // - we need to make the magnitude of "from" smaller so that it is closer to + // zero. + // + // Otherwise the signs agree: + // - "from" with a magnitude larger than "to" means we need to make the + // magnitude smaller. + // - "from" with a magnitude smaller than "to" means we need to make the + // magnitude larger. + // - "from" with the same magnitude and sign as "to" has already been + // handled. + auto signs_disagree = Ne(from_sign, to_sign); + auto from_magnitude_larger_than_to = Gt(from_abs, to_abs); + auto result_has_smaller_magnitude = + Or(from_magnitude_larger_than_to, signs_disagree); + auto magnitude_adjustment = + Select(result_has_smaller_magnitude, + Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()), + Broadcast(ScalarLike(from_as_int, 1), shape.dimensions())); + auto result = Add(from_as_int, magnitude_adjustment); + // Handle from == ±0. + result = Select(from_is_zero, + Select(to_is_zero, result_for_both_zero, + result_for_from_zero_to_non_zero), + result); + // Handle from == to. + result = Select(from_and_to_are_equal, result_for_equal, result); + // Handle isnan(from) || isnan(to). + result = Select(nan_input, result_for_nan, result); + + // Cast back to the original type. 
+ return BitcastConvertType(result, shape.element_type()); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 17612bf9fdc0f1eabb338671c93c025c5b268872..71a3acedcec0a8e65561d4139baeaf532ec8bf46 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -20,11 +20,22 @@ limitations under the License. namespace xla { -// Computes the square root of 'operand'. -XlaOp Sqrt(XlaOp operand); - -// Computes the reciprocal of the square root of 'operand'. -XlaOp Rsqrt(XlaOp operand); +// Determines whether operand is +/-inf or nan. +// +// Raises an error if called on integral or complex values. +XlaOp IsPosInf(XlaOp operand); +XlaOp IsNegInf(XlaOp operand); +XlaOp IsInf(XlaOp operand); +XlaOp IsNan(XlaOp operand); + +// Determines whether operand is equal to -0. +// +// Raises an error for integral or complex values. +XlaOp IsNegZero(XlaOp operand); + +// Returns the next number after 'from' in the direction of 'to' the same way +// std::nextafter(from, to) would. +XlaOp NextAfter(XlaOp from, XlaOp to); // Computes the square of 'operand'. XlaOp Square(XlaOp operand); @@ -32,7 +43,7 @@ XlaOp Square(XlaOp operand); // Computes the reciprocal of 'operand'. XlaOp Reciprocal(XlaOp operand); -// Evaluates a polynomial given coefficients and `x`. +// Evaluates a polynomial given coefficients and 'x'. // N.B. Coefficients should be supplied in decreasing order. XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients); @@ -86,7 +97,7 @@ XlaOp Cosh(XlaOp x); // Computes the hyperbolic sine of 'x'. XlaOp Sinh(XlaOp x); -// Applies a complex conjugation operation if `a` is complex and `conjugate` +// Applies a complex conjugation operation if 'a' is complex and 'conjugate' // is true, otherwise returns its argument. 
xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index ae2ea225d1aadd7b3a794eabeca866c498f34760..50613ce50255b8e211f6e64afbe0add290dfc647 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -14,8 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -30,6 +32,185 @@ class MathTest : public ClientLibraryTestBase { ErrorSpec error_spec_{0.0001}; }; +// Write TYPED_TESTs within the class definition so that we don't have to litter +// "this->" everywhere. 
+template +class MathTypedTest : public MathTest { + public: + void TestLogEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + Log(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}}), &b)); + ComputeAndCompareR1(&b, + {-std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}, + {}, error_spec_); + } + + void TestLog1pEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + Log1p(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}, T{-1.0}}), &b)); + ComputeAndCompareR1( + &b, {T{0.0}, T{-0.0}, -std::numeric_limits::infinity()}, {}, + error_spec_); + } + + void TestIsInfOrNan() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + auto x = + ConstantR1(&b, { + T{0}, + T{100}, + T{-1000}, + T{std::numeric_limits::max()}, + T{std::numeric_limits::lowest()}, + T{std::numeric_limits::infinity()}, + T{-std::numeric_limits::infinity()}, + T{std::numeric_limits::quiet_NaN()}, + T{std::numeric_limits::signaling_NaN()}, + }); + Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)}); + + auto expected = LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1( + {true, true, true, true, true, false, false, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, true, true, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, true, false, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, false, true, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, false, false, true, true})); + ComputeAndCompareLiteral(&b, expected, {}); + } + + void TestIsNegZero() { + SetFastMathDisabled(true); + XlaBuilder b(TestName()); + T inf(std::numeric_limits::infinity()); + T nan(std::numeric_limits::quiet_NaN()); + IsNegZero(AddParam( + LiteralUtil::CreateR1({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}), + &b)); + + ComputeAndCompareLiteral( + &b, + LiteralUtil::CreateR1( + {true, false, false, false, false, false, false}), 
+ {}, error_spec_); + } + + // sqrt(x) == pow(x, 0.5) except that + // + // pow(-inf, 0.5) == inf, while + // sqrt(-inf) == nan. + // + // Check that none of our backends are incorrectly assuming that sqrt(x) == + // pow(x, 0.5) without checking this edge case. + // + // For good measure, we also check pow with an exponent other than 0.5. + void TestSqrtPowInequivalence() { + SetFastMathDisabled(true); + + // Tests disable constant folding by default, but this test needs it + // enabled, otherwise we don't tickle the bug we're trying to catch. + // Specifically, without constant folding, the constants we pass to Pow + // below are hidden behind a reshape that's never folded away! + mutable_debug_options()->clear_xla_disable_hlo_passes(); + + const T inf(std::numeric_limits::infinity()); + const T nan(std::numeric_limits::quiet_NaN()); + + XlaBuilder b(TestName()); + auto x = AddParam(LiteralUtil::CreateR1({-inf}), &b); + ConcatInDim( + &b, {Sqrt(x), Pow(x, ScalarLike(x, 0.5)), Pow(x, ScalarLike(x, 0.3))}, + 0); + std::vector expected = {nan, inf, inf}; + ComputeAndCompareR1(&b, expected, {}, error_spec_); + } + + void TestErfEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + auto x = AddParam(LiteralUtil::CreateR1({T{-1}, T{1}, T{0}}), &b); + ErfInv(x); + + const T inf(std::numeric_limits::infinity()); + std::vector expected = {-inf, inf, T{0}}; + + ComputeAndCompareR1(&b, expected, {}, error_spec_); + } +}; + +// TODO(b/123355973): Add bfloat16 to TestTypes once it's working. 
+#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; +#endif + +TYPED_TEST_CASE(MathTypedTest, TestTypes); + +XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); } +XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); } +XLA_TYPED_TEST(MathTypedTest, IsInfOrNan) { this->TestIsInfOrNan(); } +XLA_TYPED_TEST(MathTypedTest, IsNegZero) { this->TestIsNegZero(); } +XLA_TYPED_TEST(MathTypedTest, SqrtPowInequivalence) { + this->TestSqrtPowInequivalence(); +} +XLA_TYPED_TEST(MathTypedTest, ErfInvEdgeCases) { this->TestErfEdgeCases(); } + +// Check that certain ops only support real, floating-point inputs. +// +// TODO(jlebar): Expand this test to cover more ops. +XLA_TEST_F(MathTest, RealFpOnlyOps) { + for (int64 i = PrimitiveType_MIN; i <= PrimitiveType_MAX; ++i) { + auto ty = static_cast(i); + SCOPED_TRACE(PrimitiveType_Name(ty)); + Shape shape; + if (primitive_util::IsArrayType(ty)) { + shape = ShapeUtil::MakeShape(ty, {42}); + } else if (ty == PrimitiveType::TUPLE) { + shape = ShapeUtil::MakeTupleShape({}); + } else if (ty == PrimitiveType::OPAQUE) { + shape = ShapeUtil::MakeOpaqueShape(); + } else if (ty == PrimitiveType::TOKEN) { + shape = ShapeUtil::MakeTokenShape(); + } else { + continue; + } + + for (const auto& test : + std::vector, string>>({ + {IsFinite, "is_finite"}, + {IsInf, "is_inf"}, + {IsPosInf, "is_pos_inf"}, + {IsNegInf, "is_neg_inf"}, + {IsNan, "is_nan"}, + {Erf, "erf"}, + {Erfc, "erfc"}, + {Lgamma, "lgamma"}, + {Digamma, "digamma"}, + {RoundToEven, "round_to_even"}, + })) { + SCOPED_TRACE(test.second); + XlaBuilder b(TestName()); + XlaOp p = Parameter(&b, 0, shape, "p0"); + test.first(p); + + EXPECT_EQ(b.first_error().ok(), primitive_util::IsFloatingPointType(ty)); + } + } +} + XLA_TEST_F(MathTest, SqrtF32) { XlaBuilder builder(TestName()); Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); @@ -106,6 +287,29 @@ 
XLA_TEST_F(MathTest, Lgamma) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +XLA_TEST_F(MathTest, LgammaF16) { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + + // These seemingly arbitrary inputs came from debugging the lgamma + // implementation against a test which tried all possible f16 values. + auto x = ConstantR1(&b, { + half(-7360.0), + half(-4066.0), + half(-5.9605e-08), + }); + Lgamma(x); + std::vector expected = { + std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + half(16.64), + }; + ComputeAndCompareR1(&b, expected, {}, ErrorSpec{0.1}); +} +#endif + XLA_TEST_F(MathTest, Digamma) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125, @@ -148,5 +352,40 @@ XLA_TEST_F(MathTest, RoundToEven) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, ErfRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Erf(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, ErfcRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Erfc(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, LgammaRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Lgamma(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, DigammaRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Digamma(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, RoundToEvenRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + RoundToEven(x); + EXPECT_FALSE(b.Build().status().ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index 
dcec2139e47fc86d81a8877b4dccc43eb2b7207f..a055a8e625c680cf5232896c95cd35b78cb172bc 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -15,17 +15,26 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -37,7 +46,7 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, return ConvertElementType(indicator, type); } -XlaOp GetMatrixDiagonal(XlaOp x) { +XlaOp GetMatrixDiagonal(XlaOp x, int k) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); @@ -45,10 +54,13 @@ XlaOp GetMatrixDiagonal(XlaOp x) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); + + auto offset = ConstantR0WithType(builder, S32, k); + absl::Span major_dims = AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); + auto a = Iota(builder, S32, n); + auto b = Iota(builder, S32, m) + offset; auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); auto mask = 
Broadcast(indicator, major_dims); @@ -58,9 +70,21 @@ XlaOp GetMatrixDiagonal(XlaOp x) { primitive_util::IsIntegralType(shape.element_type()) ? CreateScalarOrComputation(shape.element_type(), builder) : CreateScalarAddComputation(shape.element_type(), builder); - - return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), - reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + // k == 0, we can save one slice op. + if (k == 0) { + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + } else if (k > 0) { + auto result = Reduce(Select(mask, x, Zeros(builder, shape)), + ScalarLike(x, 0), reducer, {n_dims - 2}); + return SliceInMinorDims(result, {std::min(k, n)}, + {std::min(m + k, n)}); + } else { + auto result = Reduce(Select(mask, x, Zeros(builder, shape)), + ScalarLike(x, 0), reducer, {n_dims - 1}); + return SliceInMinorDims(result, {std::min(-k, m)}, + {std::min(m, n - k)}); + } }); } @@ -91,77 +115,224 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } -XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { +Status ValidateEinsumNumericDimensions(absl::Span x_config, + absl::Span y_config, + absl::Span output_config) { + for (auto dim : output_config) { + if (absl::c_linear_search(x_config, dim) || + absl::c_linear_search(y_config, dim)) { + if (absl::c_count(output_config, dim) > 1) { + return InvalidArgument("Einsum has repeated output dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has output dimension without corresponding input dimension."); + } + for (auto dim : x_config) { + if (absl::c_linear_search(y_config, dim) || + absl::c_linear_search(output_config, dim)) { + if (absl::c_count(x_config, dim) > 1) { + return InvalidArgument("Einsum has repeated lhs dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has lhs dimension without corresponding rhs or 
output " + "dimension."); + } + for (auto dim : y_config) { + if (absl::c_linear_search(x_config, dim) || + absl::c_linear_search(output_config, dim)) { + if (absl::c_count(y_config, dim) > 1) { + return InvalidArgument("Einsum has repeated rhs dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has rhs dimension without corresponding lhs or output " + "dimension."); + } + return Status::OK(); +} + +xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, + absl::Span output_config, + xla::PrecisionConfig::Precision precision) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + TF_RETURN_IF_ERROR( + ValidateEinsumNumericDimensions(x_config, y_config, output_config)); + const int64 x_rank = x_config.size(); + const int64 y_rank = y_config.size(); + const int64 output_rank = output_config.size(); + absl::flat_hash_set x_map; + absl::flat_hash_set y_map; + absl::flat_hash_set output_map; + + auto find = [&](const absl::flat_hash_set& map, int64 d) { + return map.count(d) != 0; + }; - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (x_shape.rank() != y_shape.rank()) { - return InvalidArgument( - "Arguments to BatchDot have different ranks: %s vs. %s", - ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape)); + auto insert = [&](absl::flat_hash_set& map, char d) { + CHECK(!find(map, d)); + map.insert(d); + }; + + for (auto d : x_config) { + insert(x_map, d); } - const int ndims = x_shape.rank(); - if (ndims < 2) { - return InvalidArgument( - "Arguments to BatchDot must have rank >= 2: got %d", ndims); + + for (auto d : y_config) { + insert(y_map, d); } - // The batch dimensions must be equal and the matrix dimensions must be - // valid. 
- std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { - return InvalidArgument( - "Dimension %d of inputs to BatchDot must be equal: shapes %s vs %s", - i, ShapeUtil::HumanString(x_shape), - ShapeUtil::HumanString(y_shape)); - } - batch_dimension_numbers.push_back(i); + for (auto d : output_config) { + insert(output_map, d); } - int x_inner_dim = ndims - 1; - int y_inner_dim = ndims - 2; - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return InvalidArgument( - "Dimensions %d and %d of arguments to BatchDot must be equal: " - "shapes %s vs %s", - x_inner_dim, y_inner_dim, ShapeUtil::HumanString(x_shape), - ShapeUtil::HumanString(y_shape)); + DotDimensionNumbers dnums; + std::vector lhs_outer_dims; + auto is_batch_dim = [&](int64 d) { + return find(x_map, d) && find(y_map, d) && find(output_map, d); + }; + auto is_contracting = [&](int64 d) { + return find(x_map, d) && find(y_map, d); + }; + auto rhs_dimension_number = [&](int64 d) { + return absl::c_find(y_config, d) - y_config.begin(); + }; + for (int64 i = 0; i < x_rank; ++i) { + auto dim_name = x_config[i]; + if (is_batch_dim(dim_name)) { + dnums.add_lhs_batch_dimensions(i); + dnums.add_rhs_batch_dimensions(rhs_dimension_number(dim_name)); + } else if (is_contracting(dim_name)) { + dnums.add_lhs_contracting_dimensions(i); + dnums.add_rhs_contracting_dimensions(rhs_dimension_number(dim_name)); + } else { + lhs_outer_dims.push_back(i); + } } - // Check for zero lhs/rhs dim size. 
- if (ShapeUtil::IsZeroElementArray(x_shape) || - ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + std::vector rhs_outer_dims; + for (int64 i = 0; i < y_rank; ++i) { + auto dim_name = y_config[i]; + if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) { + rhs_outer_dims.push_back(i); } - int x_outer_dim = ndims - 2; - int y_outer_dim = ndims - 1; - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return Broadcast( - ConstantLiteral(builder, LiteralUtil::Zero(x_shape.element_type())), - dimensions); + } + + auto output_dimension_number = [&](char d) { + return absl::c_find(output_config, d) - output_config.begin(); + }; + + std::vector output_dims; + output_dims.reserve(output_rank); + for (auto d : dnums.lhs_batch_dimensions()) { + output_dims.push_back(output_dimension_number(x_config[d])); + } + for (auto d : lhs_outer_dims) { + output_dims.push_back(output_dimension_number(x_config[d])); + } + for (auto d : rhs_outer_dims) { + output_dims.push_back(output_dimension_number(y_config[d])); + } + + std::vector transpose_dims(output_rank); + for (int64 i = 0; i < output_rank; ++i) { + transpose_dims[output_dims[i]] = i; } PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); + return Transpose(DotGeneral(x, y, dnums, &precision_proto), transpose_dims); + }); +} + +XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. 
+ std::vector batch_dimension_numbers; + const int ndims = x_shape.rank(); + batch_dimension_numbers.reserve(ndims - 2); + for (int i = 0; i < ndims - 2; ++i) { + batch_dimension_numbers.push_back(i); + } + std::vector x_config = batch_dimension_numbers; + x_config.push_back(ndims - 2); + x_config.push_back(ndims); + std::vector y_config = batch_dimension_numbers; + y_config.push_back(ndims); + y_config.push_back(ndims - 1); + std::vector output_config = batch_dimension_numbers; + output_config.push_back(ndims - 2); + output_config.push_back(ndims - 1); + return Einsum(x, x_config, y, y_config, output_config, precision); + }); +} + +StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config) { + std::array, 3> einsum_config_numeric; + std::vector main_split = + absl::StrSplit(einsum_config, ','); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); + if (main_split.size() != 2) { + return InvalidArgument("Expected one \",\" in einsum_config."); + } + + auto maybe_invalid_character = [](char d) { + if (absl::ascii_isalpha(d)) { + return Status::OK(); + } + if (d == '.') { + return InvalidArgument("Unsupported \"...\" or \".\" in einsum config."); } + return InvalidArgument("Unexpected character in einsum config."); + }; + + auto& x_config = einsum_config_numeric[0]; + x_config.reserve(main_split[0].size()); + for (auto d : main_split[0]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + x_config.push_back(static_cast(d)); + } + std::vector y_output_split = + absl::StrSplit(main_split[1], "->"); + if (y_output_split.size() != 2) { + return InvalidArgument("Expected one \"->\" in einsum_config."); + } + auto& y_config = einsum_config_numeric[1]; + 
y_config.reserve(y_output_split[0].size()); + for (auto d : y_output_split[0]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + y_config.push_back(static_cast(d)); + } + auto& output_config = einsum_config_numeric[2]; + output_config.reserve(y_output_split[1].size()); + for (auto d : y_output_split[1]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + output_config.push_back(static_cast(d)); + } + return einsum_config_numeric; +} - return DotGeneral(x, y, dot_dnums, &precision_proto); +XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto einsum_config_numeric, + ParseEinsumString(einsum_config)); + return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1], + einsum_config_numeric[2], precision); }); } @@ -181,4 +352,5 @@ XlaOp TransposeInMinorDims(XlaOp x) { XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) { return transpose ? TransposeInMinorDims(x) : x; } + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 916cd83748e7028c474065b86bf02d85166d2c9c..60c41ec45a086726086dac7227fc432a9c62d0c8 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -16,7 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -26,10 +30,15 @@ namespace xla { // else. 
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); -// Get the diagonals of the last two dimensions. If 'x' has shape -// [..., M, N], then the output has shape [..., min(M, N)], containing the -// diagonal elements (i.e., with indices [..., i, i]). -XlaOp GetMatrixDiagonal(XlaOp x); +// Get the diagonals of the last two dimensions. Use k>0 for diagonals above the +// main diagonal, and k<0 for diagonals below the main diagonal. +// +// If 'x' has shape [..., M, N] +// If k >= 0: then the output has shape [..., min(M, N - k)], containing the +// diagonal elements (i.e., with indices [..., i, i + k]). +// If k < 0: then the output has shape [..., min(M + k, N)], containing the +// diagonal elements (i.e., with indices [..., i - k, i]). +XlaOp GetMatrixDiagonal(XlaOp x, int k = 0); // Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal // and false above that diagonal. @@ -65,6 +74,40 @@ xla::XlaOp BatchDot( xla::XlaOp x, xla::XlaOp y, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); +// Parse an einsum string into dimension numbers: +// "ab,cb->ac" +// becomes: +// {{0, 1},{2, 1},{0, 2}} +// +// NOTE: This function is meant for testing, there is no need to call it +// directly. + +StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config); + +// Determine if each dimension label is in at least two inputs. +// +// NOTE: This function is meant for testing, there is no need to call it +// directly. +Status ValidateEinsumNumericDimensions(absl::Span x_config, + absl::Span y_config, + absl::Span output_config); + +// Supports two operand einsum notation like "ab,cb->ac". +xla::XlaOp Einsum( + xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Same as above but supporting numeric labels on dimensions. 
So "ab,cb->ac" +// becomes: +// x_config = {0, 1} +// y_config = {2, 1} +// output_config = {0, 2} +xla::XlaOp Einsum( + xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, absl::Span output_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + // Transposes a stack of matrices `x` by swapping the last two dimensions. xla::XlaOp TransposeInMinorDims(xla::XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc index 0593a7517ac125ca8dc5395cee76f6bc23232cd3..a93fc2ccb92912a10b9b6c2192b81cd73566f2a0 100644 --- a/tensorflow/compiler/xla/client/lib/matrix_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -15,13 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { @@ -51,13 +53,24 @@ void MatrixTest::TestMatrixDiagonal() { XlaBuilder builder("GetMatrixDiagonal"); Array3D input(2, 3, 4); input.FillIota(0); - - XlaOp a; - auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); - GetMatrixDiagonal(a); - Array2D expected({{0, 5, 10}, {12, 17, 22}}); - - ComputeAndCompareR2(&builder, expected, {a_data.get()}); + std::map> k_and_expected = { + {0, {{0, 5, 10}, {12, 17, 22}}}, + {1, {{1, 6, 11}, {13, 18, 23}}}, + {2, {{2, 7}, {14, 19}}}, + {3, {{3}, {15}}}, + {4, {{}, {}}}, + {-1, {{4, 9}, {16, 21}}}, + {-2, {{8}, {20}}}, + {-3, {{}, {}}}, + {-4, {{}, {}}}, + }; + for 
(const auto& kv : k_and_expected) { + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + GetMatrixDiagonal(a, kv.first); + + ComputeAndCompareR2(&builder, kv.second, {a_data.get()}); + } } XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } @@ -101,5 +114,78 @@ XLA_TEST_F(MatrixTest, RowBatchDot) { ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, {a_data.get(), row_data.get(), index_data.get()}); } + +XLA_TEST_F(MatrixTest, Einsum) { + XlaBuilder builder(TestName()); + + int n = 4; + + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + Einsum(l_index, row, "abc,adc->abd"); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} + +XLA_TEST_F(MatrixTest, ParseEinsumString) { + auto to_vec = [](absl::string_view s) { + std::vector v; + v.reserve(s.size()); + for (auto c : s) { + v.push_back(int64{c}); + } + return v; + }; + + auto to_string = [&](absl::string_view x, absl::string_view y, + absl::string_view o) { + return absl::StrCat(x, ",", y, "->", o); + }; + + std::vector> good_test_cases = {{"ab", "bc", "ac"}, + {"Bab", "Bbc", "Bac"}, + {"ab", "cd", "dcba"}, + {"abc", "abd", "cbd"}}; + for (auto test_case : good_test_cases) { + auto parse_result_or_status = + ParseEinsumString(to_string(test_case[0], test_case[1], test_case[2])); + EXPECT_TRUE(parse_result_or_status.status().ok()); + auto parse_result = parse_result_or_status.ValueOrDie(); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(parse_result[i], to_vec(test_case[i])); + } + EXPECT_TRUE(ValidateEinsumNumericDimensions( + 
parse_result[0], parse_result[1], parse_result[2]) + .ok()); + } + + std::vector einsum_strings_that_fail_parsing = { + "", "a", "ab->ba", "ab,bc,cd->ad", "a...b,bc->a...c"}; + for (auto test_case : einsum_strings_that_fail_parsing) { + auto parse_result_or_status = ParseEinsumString(test_case); + EXPECT_FALSE(parse_result_or_status.status().ok()); + } + + std::vector einsum_strings_that_fail_numeric_validation = { + "a,b->c", "ab,bc->acd", "abz,bc->ac", "ab,bcz->ac"}; + for (auto test_case : einsum_strings_that_fail_numeric_validation) { + auto parse_result_or_status = ParseEinsumString(test_case); + EXPECT_TRUE(parse_result_or_status.status().ok()); + auto parse_result = parse_result_or_status.ValueOrDie(); + EXPECT_FALSE(ValidateEinsumNumericDimensions( + parse_result[0], parse_result[1], parse_result[2]) + .ok()); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 85b9e1827dcef5ed907d893277deb5a52f8f30e9..63b3b07ddc2a64aad4c3b14853958f2bcfa08b59 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/lib/prng.h" + #include #include "absl/base/casts.h" @@ -30,11 +32,8 @@ XlaOp RotateLeftU32(XlaOp v, int distance) { ShiftRightLogical(v, ConstantR0(v.builder(), 32 - distance)); } -using ThreeFry2x32State = std::array; +} // namespace -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. 
-// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { XlaBuilder* builder = input[0].builder(); key[0] = BitcastConvertType(key[0], U32); @@ -127,15 +126,28 @@ XlaOp StatelessRngUniformU32(std::array key, const Shape& shape) { return Reshape(result, AsInt64Slice(shape.dimensions())); } +ThreeFry2x32State Uint64ToUint32s(XlaOp u64) { + auto builder = u64.builder(); + auto const32 = ConstantR0WithType(builder, U64, 32); + auto fst = ConvertElementType(u64, U32); + auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32); + return {fst, snd}; +} + +XlaOp Uint32sToUint64(ThreeFry2x32State u32s) { + auto builder = u32s[0].builder(); + return ConvertElementType(u32s[0], U64) | + ShiftLeft(ConvertElementType(u32s[1], U64), + ConstantR0WithType(builder, U64, 32)); +} + XlaOp StatelessRngUniformU64(std::array key, const Shape& shape) { XlaBuilder* builder = key[0].builder(); const int64 size = ShapeUtil::ElementsIn(shape); ThreeFry2x32State inputs = GetInputs(size, builder); ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); // low 32 bit: outputs[0], high 32 bit: outputs[1] - auto result = ConvertElementType(outputs[0], U64) | - ShiftLeft(ConvertElementType(outputs[1], U64), - ConstantR0WithType(builder, U64, 32)); + auto result = Uint32sToUint64(outputs); return Reshape(result, AsInt64Slice(shape.dimensions())); } @@ -161,10 +173,6 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) { XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, PrimitiveType type, PrimitiveType unsigned_type) { XlaBuilder* builder = bits.builder(); - // TODO(b/72573764): Generate real uniform integer distribution. - // The following algorithm is the same one that TF uses right now, but it's - // uniform only when maxval - minval is a divisor of the range that bits is - // generated from. 
auto range = BitcastConvertType(maxval, unsigned_type) - BitcastConvertType(minval, unsigned_type); auto dist = Rem(bits, range); @@ -175,8 +183,6 @@ XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, BitcastConvertType(dist - dist_div_2, type); } -} // namespace - XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, XlaOp minval, XlaOp maxval) { XlaBuilder* builder = seeds[0].builder(); diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index 2603818de26888566a533334e49b039b126db66e..7b0b4c2439e538f004c8b9d6e5eb2553e485ee72 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -23,12 +23,38 @@ limitations under the License. namespace xla { +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +using ThreeFry2x32State = std::array; +ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key); + // Returns a tensor containing 'shape' random values uniformly distributed in // the range [minval, maxval). Requires 2 32-bit integer seeds. // Currently only 'shape's of type F32, S32 and S64 are implemented. XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, XlaOp minval, XlaOp maxval); +// Converts a 32-bit (signed or unsigned) integer random number `bits` into a +// float32 in the range [minval, maxval). +XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval); + +// Converts an integer random number 'bits' of type 'type' to a random number +// in the range [minval, maxval), of the same type. 'unsigned_type' is the +// unsigned version of 'type' (could be the same) with the same bit width. +// The algorithm is the same one that TF uses right now, but it's +// uniform only when maxval - minval is a divisor of the range that bits is +// generated from. 
+// TODO(b/72573764): Generate real uniform integer distribution. +XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, + PrimitiveType type, PrimitiveType unsigned_type); + +// The following 2 functions, for converting between one uint64 and two uint32s, +// use the contract "lower 32 bits for the first uint32, higher 32 bits for the +// second". +ThreeFry2x32State Uint64ToUint32s(XlaOp u64); +XlaOp Uint32sToUint64(ThreeFry2x32State u32s); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c0680b883acdfd93290558fe324e049d458b799 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc @@ -0,0 +1,466 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// Jacobi rotation (also known as Givens rotation): +// G = [[ c, s], +// [-s, c]] +// matmul(G_T, G) = I +struct JacobiRotation { + XlaOp c; // cosine. + XlaOp s; // sine. +}; + +// JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix. +struct JacobiUpdate { + XlaOp v; + XlaOp w; +}; + +struct FrobeniusNorms { + XlaOp off_diagonal_norm; + XlaOp total_norm; +}; + +// Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n, +// it computes a rotation matrix G = [[c, s], [-s, c]], such that +// G_T * A[[p, q], [p, q]] * G +// is diagonalized. 
+// +// def sym_schur2x2(A, p, q): +// if np.abs(A[p, q]) > 1e-6: +// tau = (A[q, q] - A[p, p]) / (2 * A[p, q]) +// if tau >= 0: +// t = 1.0 / (tau + np.sqrt(1 + tau ** 2)) +// else: +// t = -1.0 / (-tau + np.sqrt(1 + tau ** 2)) +// c = 1.0 / np.sqrt(1.0 + t ** 2) +// s = t * c +// else: +// c = 1.0 +// s = 0.0 +// return c, s +StatusOr SymmetricShurDecomposition2x2(XlaOp a, XlaOp p, + XlaOp q, XlaOp tol) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + + auto zero = ScalarLike(a, 0.0); + auto one = ScalarLike(a, 1.0); + auto two = ScalarLike(a, 2.0); + + auto pqs = DynamicSliceInMinorDims(a, {p, q}, {1, 1}); + + auto ps = DynamicSliceInMinorDims(a, {p, p}, {1, 1}); + auto qs = DynamicSliceInMinorDims(a, {q, q}, {1, 1}); + + auto tau = (qs - ps) / (pqs * two); + auto t_pos = one / (tau + Sqrt(one + Square(tau))); + auto t_neg = -one / (-tau + Sqrt(one + Square(tau))); + auto t = Select(Ge(tau, zero), t_pos, t_neg); + + auto c_temp = Rsqrt(one + Square(t)); + auto s_temp = t * c_temp; + + auto c = Select(Ge(Abs(pqs), tol), c_temp, ZerosLike(c_temp) + one); + auto s = Select(Ge(Abs(pqs), tol), s_temp, ZerosLike(s_temp)); + // Renormalize c and s to compensate for low precision arithmetic, this step + // is redundant if high precision float is used, like float64. 
+ auto rnorm = Rsqrt(Square(c) + Square(s)); + + JacobiRotation schur; + + schur.c = c * rnorm; + schur.s = s * rnorm; + + return schur; +} + +StatusOr Update(JacobiUpdate jacobi_update, XlaOp p, XlaOp q, + XlaOp tol, int64 n) { + XlaBuilder* builder = jacobi_update.w.builder(); + TF_ASSIGN_OR_RETURN(JacobiRotation schur, SymmetricShurDecomposition2x2( + jacobi_update.w, p, q, tol)); + + TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(jacobi_update.w)); + const std::vector batch_dims(w_shape.dimensions().begin(), + w_shape.dimensions().end() - 2); + const int64 num_dims = w_shape.rank(); + + auto zero = ScalarLike(p, 0); + + XlaOp c = schur.c; + XlaOp s = schur.s; + + auto slice_p = DynamicSliceInMinorDims(jacobi_update.w, {p, zero}, {1, n}); + auto slice_q = DynamicSliceInMinorDims(jacobi_update.w, {q, zero}, {1, n}); + + auto slice_p_new = c * slice_p - s * slice_q; + auto slice_q_new = s * slice_p + c * slice_q; + + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {p, zero}); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {q, zero}); + + slice_p = DynamicSliceInMinorDims(jacobi_update.w, {zero, p}, {n, 1}); + slice_q = DynamicSliceInMinorDims(jacobi_update.w, {zero, q}, {n, 1}); + + slice_p_new = c * slice_p - s * slice_q; + slice_q_new = s * slice_p + c * slice_q; + + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {zero, p}); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {zero, q}); + + // Zero out a_{pq} explicitly. 
+ std::vector pq_dims(batch_dims.begin(), batch_dims.end()); + pq_dims.push_back(1); + pq_dims.push_back(1); + auto pq_zero = ScalarLike(jacobi_update.w, 0.0); + auto pq_zeros = Broadcast(pq_zero, pq_dims); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {p, q}); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {q, p}); + + slice_p = DynamicSliceInMinorDims(jacobi_update.v, {zero, p}, {n, 1}); + slice_q = DynamicSliceInMinorDims(jacobi_update.v, {zero, q}, {n, 1}); + + std::vector broadcast_dims(batch_dims.size()); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims.push_back(num_dims - 1); + + // Renormalize the p-th and q-th columns. This step is redundant if high + // precision floats are used, like 64-bit float. But for 32-bit float, it + // becomes necessary. This step will not increase the overall complexity. + slice_p_new = c * slice_p - s * slice_q; + slice_p_new = Mul( + slice_p_new, + Rsqrt(Reduce(Square(slice_p_new), pq_zero, + CreateScalarAddComputation(w_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + slice_q_new = s * slice_p + c * slice_q; + slice_q_new = Mul( + slice_q_new, + Rsqrt(Reduce(Square(slice_q_new), pq_zero, + CreateScalarAddComputation(w_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + + jacobi_update.v = + DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_p_new, {zero, p}); + jacobi_update.v = + DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_q_new, {zero, q}); + + return jacobi_update; +} + +StatusOr ComputeFrobeniusNorms(XlaOp w) { + XlaBuilder* builder = w.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); + const int64 num_dims = shape.rank(); + auto frobenius_norm = + Sqrt(Reduce(Square(w), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2, num_dims - 1})); + auto diag = GetMatrixDiagonal(w); + auto diag_square = + 
Reduce(Square(diag), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2}); + + FrobeniusNorms frobenius_norms; + + frobenius_norms.off_diagonal_norm = + Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0))); + frobenius_norms.total_norm = frobenius_norm; + + return frobenius_norms; +} + +StatusOr> WhileLoopFn( + absl::Span initial_values, // + int matrix_dimension, // + int max_sweep_updates, // + PrimitiveType index_type, // + absl::string_view name, // + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + auto k = values[0]; + auto max_sweeps = ScalarLike(k, max_sweep_updates); + auto sweep_update_cond = Gt(max_sweeps, k); + + auto norms = ComputeFrobeniusNorms(values[2]).ValueOrDie(); + auto tol = norms.total_norm * values[3]; + auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), + xla::ConstantR0(cond_builder, false), + CreateScalarOrComputation(PRED, cond_builder)); + + return And(sweep_update_cond, tol_cond); + }; + + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + auto while_cond_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_cond_builder) -> StatusOr { + auto p = values_inner[0]; + return Lt(p, ScalarLike(p, matrix_dimension - 1)); + }; + + auto while_body_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_body_builder) -> StatusOr> { + auto while_cond_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_cond_builder) -> StatusOr { + auto q = values_innermost[1]; + return Lt(q, ScalarLike(q, matrix_dimension)); + }; + auto while_body_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_body_builder) + -> StatusOr> { + auto p = values_innermost[0]; + auto q = values_innermost[1]; + + JacobiUpdate jacobi_update; + jacobi_update.v = values_innermost[2]; + jacobi_update.w = values_innermost[3]; + + auto tol = 
values_innermost[4]; + + TF_ASSIGN_OR_RETURN(jacobi_update, + Update(jacobi_update, p, q, tol, matrix_dimension)); + + std::vector updated_values_innermost; + updated_values_innermost.reserve(values_innermost.size()); + + updated_values_innermost.push_back(p); + updated_values_innermost.push_back(q + ScalarLike(q, 1)); + updated_values_innermost.push_back(jacobi_update.v); + updated_values_innermost.push_back(jacobi_update.w); + updated_values_innermost.push_back(tol); + + return updated_values_innermost; + }; + + std::vector values_innermost(5); + auto p = values_inner[0]; + auto q = p + ScalarLike(p, 1); + values_innermost[0] = p; // index p. + values_innermost[1] = q; // index q. + values_innermost[2] = values_inner[1]; // v. + values_innermost[3] = values_inner[2]; // w. + values_innermost[4] = values_inner[3]; // tol. + TF_ASSIGN_OR_RETURN( + values_innermost, + WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost, + values_innermost, absl::StrCat(name, "-Innermost"), + inner_body_builder)); + + std::vector updated_values_inner; + updated_values_inner.reserve(values_inner.size()); + + updated_values_inner.push_back(p + ScalarLike(p, 1)); + updated_values_inner.push_back(values_innermost[2]); + updated_values_inner.push_back(values_innermost[3]); + updated_values_inner.push_back(values_innermost[4]); + return updated_values_inner; + }; + // Indexes. + XlaOp k = values[0]; + + std::vector values_inner(4); + values_inner[0] = ScalarLike(k, 0); // index p. + values_inner[1] = values[1]; // v. + values_inner[2] = values[2]; // w. + values_inner[3] = values[3]; // tol. 
+ TF_ASSIGN_OR_RETURN( + values_inner, + WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner, + absl::StrCat(name, "-Inner"), body_builder)); + + std::vector updated_values; + updated_values.reserve(values_inner.size()); + + updated_values.push_back(k + ScalarLike(k, 1)); + updated_values.push_back(values_inner[1]); + updated_values.push_back(values_inner[2]); + updated_values.push_back(values_inner[3]); + + return updated_values; + }; + std::vector values; + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + initial_values, name, builder)); + + return values; +} + +StatusOr SortByEigenvalues(SelfAdjointEigResult result) { + XlaBuilder* builder = result.v.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.v)); + const int64 num_dims = shape.rank(); + auto dimensions = shape.dimensions(); + + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims[num_dims - 2] = num_dims - 1; + result.w = BroadcastInDim(result.w, dimensions, broadcast_dims); + + XlaOp sort_result = + Sort({result.w, result.v}, + CreateScalarLtComputation( + {shape.element_type(), shape.element_type()}, builder), + num_dims - 1); + result.w = GetMatrixDiagonal(GetTupleElement(sort_result, 0)); + result.v = GetTupleElement(sort_result, 1); + return result; +} + +} // namespace + +// This is the cyclic Jacobi iteration. Please note that the eigenvalues are +// possibly not ordered. 
+// +// def jacobi(A): +// n, _ = A.shape +// V = np.eye(n) +// frobenius_norm = np.linalg.norm(A) +// diag_norm = np.linalg.norm(np.diag(A)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm) +// while off_diag_norm > 1e-6 * frobenius_norm: +// for p in range(n - 1): +// for q in range(p + 1, n): +// c, s = sym_schur2x2(A, p, q) +// A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]), +// A[[p, q], :]) +// A[:, [p, q]] = np.matmul(A[:, [p, q]], +// np.array([[c, s], [-s, c]])) +// V[:, [p, q]] = np.matmul(V[:, [p, q]], +// np.array([[c, s], [-s, c]])) +// frobenius_norm = np.linalg.norm(A) +// diag_norm = np.linalg.norm(np.diag(A)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt( +// frobenius_norm + diag_norm) +// +// return A, V +// +// TODO(kuny): Implement parallel order Jacobi. +// +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, + float epsilon) { + XlaBuilder* builder = a.builder(); + auto return_error = [&](const Status& status) { + SelfAdjointEigResult result; + result.v = builder->ReportError(status); + result.w = builder->ReportError(status); + return result; + }; + auto shape_with_status = builder->GetShape(a); + if (!shape_with_status.status().ok()) { + return return_error(shape_with_status.status()); + } + Shape a_shape = shape_with_status.ValueOrDie(); + const int64 num_dims = a_shape.rank(); + if (num_dims < 2) { + return return_error(InvalidArgument( + "Arguments to Eigen decomposition must have rank >= 2: got shape %s.", + a_shape.ToString())); + } + PrimitiveType type = a_shape.element_type(); + if (!primitive_util::IsFloatingPointType(type)) { + return return_error(InvalidArgument( + "Type of the input matrix must be float: got %s.", a_shape.ToString())); + } + + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + + if (m != n) { + return return_error(InvalidArgument( + "Arguments to Eigen 
decomposition must be square matrices: got shape " + "(%d, %d).", + m, n)); + } + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + } + + auto tol = ScalarLike(a, epsilon); + + auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); + auto w_init = Triangle(a, lower); + w_init = w_init + TransposeInMinorDims(w_init) - w_init * v_init; + + auto output_with_status = WhileLoopFn( + { + Zero(builder, S32), // k + v_init, // v + w_init, // w + tol, // + }, // + n, // + max_iter, // + S32, // + "CyclicJacobi", // + builder); + if (!output_with_status.status().ok()) { + return return_error(output_with_status.status()); + } + + auto output = output_with_status.ValueOrDie(); + + SelfAdjointEigResult result; + result.v = output[1]; + result.w = GetMatrixDiagonal(output[2]); + + return SortByEigenvalues(result).ValueOrDie(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h new file mode 100644 index 0000000000000000000000000000000000000000..2a089891d6a2d80c0c265a3310539b4f1c5db4d5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h @@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// The eigenvalue decomposition of a symmetric matrix, the original matrix is +// recovered by v * w * v_t. +struct SelfAdjointEigResult { + // The i-th column is the normalized eigenvector corresponding to the + // eigenvalue w[i]. Will return a matrix object if a is a matrix object. + XlaOp v; + // The eigenvalues in ascending order, each repeated according to its + // multiplicity. + XlaOp w; +}; + +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, + int64 max_iter = 100, float epsilon = 1e-6); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..99bec8a9ab5396c3414e4b79b42169099b66ac1a --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc @@ -0,0 +1,315 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +class SelfAdjointEigTest : public ClientLibraryTestBase { + protected: + void SetUp() override { + ClientLibraryTestBase::SetUp(); + batch_3d_4x4_ = Array3D{ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 100, 6}, + {12, 48, 6, 62}, + }, + }; + matrix2d_8x8_ = Array2D{ + {14., 123., 49., 112., 115., 173., 182., 125.}, + {123., 14., 60., 118., 150., 130., 91., 72.}, + {49., 60., 138., 111., 106., 101., 115., 142.}, + {112., 118., 111., 142., 91., 130., 25., 61.}, + {115., 150., 106., 91., 116., 121., 128., 85.}, + {173., 130., 101., 130., 121., 70., 151., 132.}, + {182., 91., 115., 25., 128., 151., 66., 92.}, + {125., 72., 142., 61., 85., 132., 92., 156.}, + }; + low_rank_4x4_ = Array2D{ + // x = [[1, 2, 3, 4], [1, -1, 1, -1]] + // matmul(x.T, x) + {2, 1, 4, 3}, + {1, 5, 5, 9}, + {4, 5, 10, 11}, + {3, 9, 11, 17}, + }; + } + void TearDown() override { ClientLibraryTestBase::TearDown(); } + + Array3D GetUnitMatrix3D(const Array3D& matrix) { + Array3D 
result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0); + for (int i = 0; i < matrix.n1(); ++i) { + for (int j = 0; j < matrix.n2(); ++j) { + result({i, j, j}) = 1.0; + } + } + return result; + } + + Array3D ExtractTriangularMatrix(const Array3D& matrix, + bool lower) { + Array3D result(matrix); + for (int i = 0; i < result.n1(); ++i) { + for (int j = 0; j < result.n2(); ++j) { + if (lower) { + for (int k = j + 1; k < result.n3(); ++k) { + result({i, j, k}) = 0.0; + } + } else { + for (int k = 0; k < j; ++k) { + result({i, j, k}) = 0.0; + } + } + } + } + return result; + } + + XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { + Shape shape = builder->GetShape(result.v).ValueOrDie(); + std::vector out_dims = shape.dimensions(); + std::vector broadcast_dims(shape.rank() - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + + broadcast_dims[shape.rank() - 2] = shape.rank() - 1; + auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims)); + return BatchDot(vw, TransposeInMinorDims(result.v), + PrecisionConfig::HIGHEST); + } + + XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { + Shape shape = builder->GetShape(m1).ValueOrDie(); + int64 size = 1; + for (auto d : shape.dimensions()) { + size *= d; + } + return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0), + CreateScalarAddComputation(F32, builder)) / + ConstantR0WithType(builder, F32, size); + } + + Array2D GenerateRandomSymmetricMatrix(int size) { + Array2D result{size, size, 0.0}; + // TODO(b/128001705): This seed should not be needed but makes the test + // avoid inputs which trigger numerical instability. 
+ result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */); + for (int i = 0; i < size; ++i) { + for (int j = 0; j < i; ++j) { + result({j, i}) = result({i, j}); + } + } + return result; + } + + Array3D batch_3d_4x4_; + Array2D matrix2d_8x8_; + Array2D low_rank_4x4_; + Array2D wrong_type_4x4_; +}; + +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter( + ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter( + ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a); + auto result = SelfAdjointEig(a, false); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); + + ComputeAndCompareR3(&builder, GetUnitMatrix3D(batch_3d_4x4_), + {a_data.get()}, ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = 
CreateR2Parameter(low_rank_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR2(&builder, low_rank_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) { + XlaBuilder builder(TestName()); + + // This is computed by numpy.linalg.eigh with float32. + std::vector expected{-182.69205, -116.86245, -105.74489, -9.545369, + 37.81711, 104.732285, 120.29153, 868.00385}; + + XlaOp a; + auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + Add(result.w, ZerosLike(result.w)); + + ComputeAndCompareR1(&builder, expected, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) { + XlaBuilder builder(TestName()); + + float expected_vals = 1e-3; + + XlaOp a; + auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2 + GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8), + BatchDot(TransposeInMinorDims(result.v), result.v), + &builder); + + ComputeAndCompareR0(&builder, expected_vals, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR2Parameter(wrong_type_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + EXPECT_FALSE(result.v.valid()); + EXPECT_FALSE(result.w.valid()); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) { + XlaBuilder builder(TestName()); + int size = 8; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + 
+XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) { + XlaBuilder builder(TestName()); + int size = 16; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) { + XlaBuilder builder(TestName()); + int size = 32; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) { + XlaBuilder builder(TestName()); + int size = 256; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) { + XlaBuilder builder(TestName()); + int size = 512; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index 
77145ba7d4c72435450d3e33d57b2507eb84d2fc..0878cbeaf9ae1d85051ea3b5844f5837286c7dc2 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -134,4 +134,51 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, }); } +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); + ShapeUtil::AppendMajorDimension(1, &index_shape); + std::vector to_concat; + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + to_concat.reserve(input_shape.rank()); + for (int64 i = 0; i < input_shape.rank(); ++i) { + if (i == dim) { + to_concat.push_back(Reshape(index, index_shape.dimensions())); + } else { + to_concat.push_back(Iota(builder, index_shape, i)); + } + } + XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank()); + std::vector slice_sizes(input_shape.rank(), 1); + GatherDimensionNumbers gather_dnums; + gather_dnums.set_index_vector_dim(input_shape.rank()); + for (int64 i = 0; i < input_shape.rank(); ++i) { + gather_dnums.add_collapsed_slice_dims(i); + gather_dnums.add_start_index_map(i); + } + return Gather(input, gather_indices, gather_dnums, slice_sizes); + }); +} + +XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); + std::vector slice_sizes = input_shape.dimensions(); + slice_sizes[dim] = 1; + GatherDimensionNumbers gather_dnums; + for (int64 i = 0; i < input_shape.rank(); ++i) { + if (i != dim) { + gather_dnums.add_offset_dims(i); + } + } + gather_dnums.set_index_vector_dim(index_shape.rank()); + gather_dnums.add_collapsed_slice_dims(dim); + 
gather_dnums.add_start_index_map(dim); + return Gather(input, index, gather_dnums, slice_sizes); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index 6c482a38b5489c9fb17c3dca9ee3d2a1b8fd1890..bb6191df7c442f23a63f0d0b80c9b534c31e30fc 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -43,6 +43,28 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, absl::Span starts); +// Gathers values along an axis specified by dim. +// +// For a 3-D tensor the output is specified by: +// +// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 +// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 +// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 +// +// If `input` is an n-dimensional tensor with size +// [X0,X1,X2,..XN] and dim = i `index` must be an n-dimensional tensor with size +// [X0,X1,...Y,Xi+1,...,X[N] where y >= 1 and `out` will have the same sizes as +// `index`. +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim); + +// Returns a new tensor which indexes the input tensor along dimension dim using +// the entries in index. +// +// The returned tensor has the same number of dimensions as the original tensor +// (input). The dimth dimension has the same size as the length of index; other +// dimensions have the same size as in the original tensor. 
+XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 8d362119e01006555db0f82d02626175936e1d05..408a82ca3c6eeeae7edac8511769ec9c0d5a5f44 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -102,5 +102,56 @@ XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } +XLA_TEST_F(SlicingTest, TorchGather) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + auto input_data = + CreateR2Parameter({{1, 2}, {3, 4}}, 0, "input", &builder, &input); + auto index_data = + CreateR2Parameter({{0, 0}, {1, 0}}, 1, "index", &builder, &index); + TorchGather(input, index, 1); + + ComputeAndCompareR2(&builder, {{1, 1}, {4, 3}}, + {input_data.get(), index_data.get()}); +} + +XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + auto input_data = + CreateR2Parameter({{0.1427, 0.0231, -0.5414, -1.0009}, + {-0.4664, 0.2647, -0.1228, -1.1068}, + {-1.1734, -0.6571, 0.7230, -0.6004}}, + 0, "input", &builder, &input); + auto index_data = + CreateR1Parameter({0, 2}, 1, "index", &builder, &index); + TorchIndexSelect(input, index, 0); + + ComputeAndCompareR2( + &builder, + {{0.1427, 0.0231, -0.5414, -1.0009}, {-1.1734, -0.6571, 0.7230, -0.6004}}, + {input_data.get(), index_data.get()}); +} + +XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + auto input_data = + CreateR2Parameter({{0.1427, 0.0231, -0.5414, -1.0009}, + {-0.4664, 0.2647, -0.1228, -1.1068}, + {-1.1734, -0.6571, 0.7230, -0.6004}}, + 0, "input", &builder, &input); + auto index_data = + CreateR1Parameter({0, 2}, 1, "index", &builder, &index); + TorchIndexSelect(input, 
index, 1); + + ComputeAndCompareR2( + &builder, {{0.1427, -0.5414}, {-0.4664, -0.1228}, {-1.1734, 0.7230}}, + {input_data.get(), index_data.get()}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index e8553a08bb014e790822a14e128686b60b8d6b7c..ddc39f4d874cd3613a763b969091e7e65ff1c783 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -30,7 +31,13 @@ XlaOp TopK(XlaOp input, int64 k) { ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); auto input_dims = input_shape.dimensions(); - XlaOp sort_result = Sort(Neg(input), {iota_s32}); + // TODO(b/122298745): Get rid of Neg() and use CreateScalarGtComputation + // once the TPU backend supports the comparison computations. + XlaOp sort_result = + Sort({Neg(input), iota_s32}, + CreateScalarLtComputation({input_shape.element_type(), S32}, + iota_s32.builder()), + last_dim, /*is_stable=*/true); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/svd.cc b/tensorflow/compiler/xla/client/lib/svd.cc new file mode 100644 index 0000000000000000000000000000000000000000..dd4547dbab6e49a502a0d0e9afa67b509fe4d1e6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/svd.cc @@ -0,0 +1,884 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/client/lib/svd.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// Given a matrix A, define H, +// H = A * (I - beta * v_T * v) if v is a row vector, or +// H = (I - beta * v * v_T) if v is column vector. +// A * H or H * A zeros out trailing part of some row or column of A. 
+// +// [x0, ..., x_{k-1}, xk, x_{k+1}, ..., x_{n-1}] * H +// = [x0, ..., x_{k-1}, xnorm, 0, ..., 0] +// +// Here xnorm = norm([x_k, x_{k+1}, ..., x_{n - 1}]) +struct HouseHolderResult { + XlaOp v; + XlaOp beta; + XlaOp a; +}; + +// Jacobi rotation (also known as Givens rotation): +// G = [[ c, s], +// [-s, c]] +// matmul(G_T, G) = I +struct JacobiRotation { + XlaOp c; // cosine. + XlaOp s; // sine. +}; + +// JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix. +struct JacobiUpdate { + XlaOp v; + XlaOp w; +}; + +// OneSidedJacobiRotation holds the left and right Jacobi rotations. Refer to +// GetOneSidedJacobiRotation for the effect of applying OneSidedJacobiRotation +// to a matrix. +struct OneSidedJacobiRotation { + JacobiRotation rot_l; + JacobiRotation rot_r; +}; + +struct FrobeniusNorms { + XlaOp off_diagonal_norm; + XlaOp total_norm; +}; + +// Householder reflection on the trailing elements of a vector. +// +// H = I - beta * [1, v]' * [1, v] +// +// H * x = [..., xnorm, 0, ..., 0] +// ..., j, j + 1, ..., n +// +// def house(x, j, eps): +// sigma = np.linalg.norm(x[(j + 1):]) +// v = np.zeros_like(x) +// v[(j + 1):] = x[(j + 1):] +// if sigma < eps: +// beta = 0 +// else: +// mu = sigma * np.sqrt((x[j]/sigma)**2 + 1) +// if x[j] <= 0: +// v[j] = x[j] - mu +// else: +// v[j] = -sigma / (x[j] + mu) * sigma +// beta = 2 / ((sigma / v[j])**2 + 1) +// v = v / v[j] +// v[j] = 1 +// return v, beta +// +// Householder reflection on the trailing elements of a row of a matrix. 
After +// applying it on the matrix, all elements in [i, (j+1):] become zeros, i.e., +// +// H = I - beta * [1, v]' * [1, v], then, +// +// A[i, j:] * H = [sigma, 0, 0, ..., 0] +// +StatusOr HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int64 num_dims = a_shape.rank(); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + XlaOp zero = ScalarLike(i, 0); + XlaOp x = DynamicSliceInMinorDims(a, {i, zero}, {1, n}); + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int k = 0; k < num_batch_dims; ++k) { + batch_dims[k] = ShapeUtil::GetDimension(a_shape, k); + } + + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()), + num_dims - 1); + auto zeros = ZerosLike(x); + auto v = Select(Gt(idx, j), x, zeros); + + auto one = ScalarLike(v, 1.0); + + auto sigma = + Sqrt(Reduce(Square(v), ScalarLike(v, 0.0), + CreateScalarAddComputation(x_shape.element_type(), builder), + {num_dims - 1})); + + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + auto x_0j = DynamicSliceInMinorDims(x, {zero, j}, {1, 1}); + auto mu = Mul(sigma, Sqrt(Square(Div(x_0j, sigma, broadcast_dims)) + one), + broadcast_dims); + + auto v_0j = Select( + Le(x_0j, ScalarLike(x_0j, 0.0)), Sub(x_0j, mu), + -Mul(sigma, Div(sigma, Add(x_0j, mu), broadcast_dims), broadcast_dims)); + + auto beta = Div(ScalarLike(v_0j, 2.0), + (Square(Div(sigma, v_0j, broadcast_dims)) + one)); + + v = Select( + BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v, + v / v_0j); + v = Select(Eq(idx, j), zeros + one, v); + + beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps), + ZerosLike(beta), beta); + + HouseHolderResult result; + result.v = v; + result.beta = beta; + result.a = 
+ Sub(a, Mul(beta, BatchDot(BatchDot(a, TransposeInMinorDims(v), precision), + v, precision))); + + return result; +} + +// Householder reflection on the trailing elements of a col of a matrix. After +// applying it on the matrix, all elements in [(i+1):, j] become zeros, i.e., +// +// H = I - beta * [1; v] * [1; v]', then, +// +// H * A[i:, j] = [xnorm, 0, 0, ..., 0] +// +StatusOr HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int64 num_dims = a_shape.rank(); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + XlaOp zero = ScalarLike(i, 0); + XlaOp x = DynamicSliceInMinorDims(a, {zero, j}, {m, 1}); + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int k = 0; k < num_batch_dims; ++k) { + batch_dims[k] = ShapeUtil::GetDimension(a_shape, k); + } + + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()), + num_dims - 2); + auto zeros = ZerosLike(x); + auto v = Select(Gt(idx, i), x, zeros); + + auto one = ScalarLike(v, 1.0); + + auto sigma = + Sqrt(Reduce(Square(v), ScalarLike(v, 0.0), + CreateScalarAddComputation(x_shape.element_type(), builder), + {num_dims - 2})); + + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims[num_dims - 2] = num_dims - 1; + auto x_0i = DynamicSliceInMinorDims(x, {i, zero}, {1, 1}); + auto mu = Mul(sigma, Sqrt(Square(Div(x_0i, sigma, broadcast_dims)) + one), + broadcast_dims); + + auto v_0i = Select( + Le(x_0i, ScalarLike(x_0i, 0.0)), Sub(x_0i, mu), + -Mul(sigma, Div(sigma, Add(x_0i, mu), broadcast_dims), broadcast_dims)); + + auto beta = Div(ScalarLike(v_0i, 2.0), + (Square(Div(sigma, v_0i, broadcast_dims)) + one)); + + v = Select( + BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), 
broadcast_dims), v, + v / v_0i); + v = Select(Eq(idx, i), zeros + one, v); + + beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps), + ZerosLike(beta), beta); + + HouseHolderResult result; + result.v = v; + result.beta = beta; + result.a = Sub( + a, Mul(beta, BatchDot(v, BatchDot(TransposeInMinorDims(v), a, precision), + precision))); + + return result; +} + +// Apply column and row householder reflections for bidiagonalization. +// +// def house_bidiag(A): +// xz, yz = A.shape +// LL = np.eye(xz) +// RR = np.eye(yz) +// for i in range(yz - 1): +// v, beta = house_col(A, i, i, 1e-8) +// L = np.eye(xz) - beta * np.outer(v, v) +// LL = np.matmul(LL, L) +// A = np.matmul(L, A) +// if i < yz - 2: +// v, beta = house_row(A, i, i + 1, 1e-8) +// R = np.eye(yz) - beta * np.outer(v, v) +// RR = np.matmul(RR, R) +// A = np.matmul(A, R) +// return LL, A, RR +// +StatusOr HouseHolderBidiagonalization( + XlaOp a, XlaOp eps, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int64 num_dims = a_shape.rank(); + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + } + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + XlaOp u_init = Broadcast( + IdentityMatrix(builder, a_shape.element_type(), m, m), batch_dims); + XlaOp v_init = Broadcast( + IdentityMatrix(builder, a_shape.element_type(), n, n), batch_dims); + + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + auto i = values[0]; + return Lt(i, ScalarLike(i, n - 2)); + }; + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + auto i = values[0]; + auto one = ScalarLike(i, 1); + + auto u = values[1]; + auto v = values[2]; + auto a = values[3]; + auto eps = 
values[4]; + + TF_ASSIGN_OR_RETURN(HouseHolderResult house_col, + HouseCol(a, i, i, eps, precision)); + u = Sub(u, Mul(house_col.beta, + BatchDot(BatchDot(u, house_col.v, precision), + TransposeInMinorDims(house_col.v), precision))); + a = house_col.a; + + TF_ASSIGN_OR_RETURN(HouseHolderResult house_row, + HouseRow(a, i, i + one, eps, precision)); + v = Sub( + v, + Mul(house_row.beta, + BatchDot(BatchDot(v, TransposeInMinorDims(house_row.v), precision), + house_row.v, precision))); + a = house_row.a; + + std::vector updated_values; + updated_values.reserve(values.size()); + + updated_values.push_back(i + one); + updated_values.push_back(u); + updated_values.push_back(v); + updated_values.push_back(a); + updated_values.push_back(eps); + return updated_values; + }; + + std::vector values(5); + values[0] = Zero(builder, S32); + values[1] = u_init; + values[2] = v_init; + values[3] = a; + values[4] = eps; + + TF_ASSIGN_OR_RETURN(values, + WhileLoopHelper(while_cond_fn, while_body_fn, values, + "HouseHolderBidiagonalization", builder)); + + for (int k = 2; k > 0; --k) { + if (n - k >= 0) { + XlaOp index = ScalarLike(values[0], n - k); + TF_ASSIGN_OR_RETURN(HouseHolderResult house_col, + HouseCol(values[3], index, index, eps, precision)); + values[1] = + Sub(values[1], + Mul(house_col.beta, + BatchDot(BatchDot(values[1], house_col.v, precision), + TransposeInMinorDims(house_col.v), precision))); + values[3] = house_col.a; + } + } + + SVDResult result; + result.u = values[1]; + result.v = values[2]; + result.d = values[3]; + return result; +} + +// MakeJacobi computes a rotation matrix G = [[c, s], [-s, c]], such that +// G_T * [[ps, pqs], [pqs, qs]] * G +// is diagonalized. 
+// +// def make_jacobi(ps, qs, pqs, eps): +// if np.abs(a_pq) > eps: +// tau = (a_qq - a_pp) / (2 * a_pq) +// if tau >= 0: +// t = 1.0 / (tau + np.sqrt(1 + tau ** 2)) +// else: +// t = -1.0 / (-tau + np.sqrt(1 + tau ** 2)) +// c = 1.0 / np.sqrt(1.0 + t ** 2) +// s = t * c +// else: +// c = 1.0 +// s = 0.0 +// return c, s +// +StatusOr MakeJacobi(XlaOp ps, XlaOp qs, XlaOp pqs, XlaOp eps) { + auto zero = ScalarLike(ps, 0.0); + auto one = ScalarLike(ps, 1.0); + auto two = ScalarLike(ps, 2.0); + + auto tau = (qs - ps) / (pqs * two); + auto t_pos = one / (tau + Sqrt(one + Square(tau))); + auto t_neg = -one / (-tau + Sqrt(one + Square(tau))); + auto t = Select(Ge(tau, zero), t_pos, t_neg); + + auto c_temp = Rsqrt(one + Square(t)); + auto s_temp = t * c_temp; + + auto c = Select(Ge(Abs(pqs), eps), c_temp, ZerosLike(c_temp) + one); + auto s = Select(Ge(Abs(pqs), eps), s_temp, ZerosLike(s_temp)); + // Renormalize c and s to compensate for low precision arithmetic, this step + // is redundant if high precision float is used, like float64. + auto rnorm = Rsqrt(Square(c) + Square(s)); + + JacobiRotation rot; + + rot.c = c * rnorm; + rot.s = s * rnorm; + + return rot; +} + +// One sided Jacobi rotations. For a matrix, +// [a_pp, a_pq] +// [a_qp, a_qq] +// After applying Jacobi rotations on both sides, the matrix is diagonalized. 
+// [b_pp, 0] +// [0, b_qq] +// +// def jacobi_rot(a, p, q, eps): +// t = a[p, p] + a[q, q] +// d = a[q, p] - a[p, q] +// +// if np.abs(d) < eps: +// s = 0.0 +// c = 1.0 +// else: +// u = t / d +// tmp = np.sqrt(1.0 + u**2) +// s = -1.0 / tmp +// c = u / tmp +// +// rot = np.array([[c, s], [-s, c]]) +// m_tmp = rot.T @ a[[p, q], [p, q]] +// c_r, s_r = make_jacobi(m_tmp[0, 0], m_tmp[1, 1], m_tmp[0, 1]) +// rot_r = np.array([[c_r, s_r], [-s_r, c_r]]) +// rot_l = rot @ rot_r +// return rot_l, rot_r +// +StatusOr GetOneSidedJacobiRotation(XlaOp a, XlaOp p, + XlaOp q, XlaOp eps) { + XlaOp a_pp = DynamicSliceInMinorDims(a, {p, p}, {1, 1}); + XlaOp a_pq = DynamicSliceInMinorDims(a, {p, q}, {1, 1}); + XlaOp a_qp = DynamicSliceInMinorDims(a, {q, p}, {1, 1}); + XlaOp a_qq = DynamicSliceInMinorDims(a, {q, q}, {1, 1}); + + XlaOp one = ScalarLike(a, 1.0); + + XlaOp t = a_pp + a_qq; + XlaOp d = a_qp - a_pq; + + XlaOp u = Div(t, d); + XlaOp tmp = Rsqrt(one + Square(u)); + + JacobiRotation rot; + + XlaOp zeros = ZerosLike(tmp); + XlaOp ones = zeros + one; + + rot.s = Select(Lt(Abs(d), eps), zeros, -tmp); + rot.c = Select(Lt(Abs(d), eps), ones, Mul(u, tmp)); + + XlaOp a_pp_new = rot.c * a_pp - rot.s * a_qp; + XlaOp a_pq_new = rot.c * a_pq - rot.s * a_qq; + XlaOp a_qq_new = rot.s * a_pq + rot.c * a_qq; + + OneSidedJacobiRotation rots; + TF_ASSIGN_OR_RETURN(rots.rot_r, + MakeJacobi(a_pp_new, a_qq_new, a_pq_new, eps)); + + rots.rot_l.c = rot.c * rots.rot_r.c - rot.s * rots.rot_r.s; + rots.rot_l.s = rot.s * rots.rot_r.c + rot.c * rots.rot_r.s; + + return rots; +} + +// Apply one-sided Jacobi on elements at indices pp, pq, qp, qq. 
+StatusOr OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q, + XlaOp eps) { + XlaOp u = svd_result.u; + XlaOp v = svd_result.v; + XlaOp d = svd_result.d; + XlaBuilder* builder = d.builder(); + TF_ASSIGN_OR_RETURN(Shape d_shape, builder->GetShape(d)); + const int64 num_dims = d_shape.rank(); + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = ShapeUtil::GetDimension(d_shape, i); + } + const int64 m = ShapeUtil::GetDimension(d_shape, -2); + const int64 n = ShapeUtil::GetDimension(d_shape, -1); + + TF_ASSIGN_OR_RETURN(OneSidedJacobiRotation onesided_jacobi, + GetOneSidedJacobiRotation(d, p, q, eps)); + + auto zero = ScalarLike(p, 0); + + // Zero out a_{pq} explicitly. + std::vector pq_dims(batch_dims.begin(), batch_dims.end()); + pq_dims.push_back(1); + pq_dims.push_back(1); + auto pq_zero = ScalarLike(d, 0.0); + auto pq_zeros = Broadcast(pq_zero, pq_dims); + + std::vector broadcast_dims(batch_dims.size()); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims.push_back(num_dims - 1); + + // Apply Jacobi Rotation on the left. + auto slice_p = DynamicSliceInMinorDims(d, {p, zero}, {1, n}); + auto slice_q = DynamicSliceInMinorDims(d, {q, zero}, {1, n}); + auto slice_p_new = + onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q; + auto slice_q_new = + onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q; + d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {p, zero}); + d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {q, zero}); + + // Apply Jacobi Rotation on the right. 
+ slice_p = DynamicSliceInMinorDims(d, {zero, p}, {m, 1}); + slice_q = DynamicSliceInMinorDims(d, {zero, q}, {m, 1}); + slice_p_new = + onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q; + slice_q_new = + onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q; + d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {zero, p}); + d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {zero, q}); + + d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {p, q}); + d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {q, p}); + + // Apply left Jacobi Rotation on U. + slice_p = DynamicSliceInMinorDims(u, {zero, p}, {m, 1}); + slice_q = DynamicSliceInMinorDims(u, {zero, q}, {m, 1}); + slice_p_new = + onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q; + + slice_p_new = Mul( + slice_p_new, + Rsqrt(Reduce(Square(slice_p_new), pq_zero, + CreateScalarAddComputation(d_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + + slice_q_new = + onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q; + + slice_q_new = Mul( + slice_q_new, + Rsqrt(Reduce(Square(slice_q_new), pq_zero, + CreateScalarAddComputation(d_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + + u = DynamicUpdateSliceInMinorDims(u, slice_p_new, {zero, p}); + u = DynamicUpdateSliceInMinorDims(u, slice_q_new, {zero, q}); + + // Apply right Jacobi Rotation on V. 
+ slice_p = DynamicSliceInMinorDims(v, {zero, p}, {n, 1}); + slice_q = DynamicSliceInMinorDims(v, {zero, q}, {n, 1}); + slice_p_new = + onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q; + + slice_p_new = Mul( + slice_p_new, + Rsqrt(Reduce(Square(slice_p_new), pq_zero, + CreateScalarAddComputation(d_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + + slice_q_new = + onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q; + + slice_q_new = Mul( + slice_q_new, + Rsqrt(Reduce(Square(slice_q_new), pq_zero, + CreateScalarAddComputation(d_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + + v = DynamicUpdateSliceInMinorDims(v, slice_p_new, {zero, p}); + v = DynamicUpdateSliceInMinorDims(v, slice_q_new, {zero, q}); + + svd_result.d = d; + svd_result.u = u; + svd_result.v = v; + + return svd_result; +} + +StatusOr ComputeFrobeniusNorms(XlaOp w) { + XlaBuilder* builder = w.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); + const int64 num_dims = shape.rank(); + auto frobenius_norm = + Sqrt(Reduce(Square(w), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2, num_dims - 1})); + auto diag = GetMatrixDiagonal(w); + auto diag_square = + Reduce(Square(diag), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2}); + + FrobeniusNorms frobenius_norms; + + frobenius_norms.off_diagonal_norm = + Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0))); + frobenius_norms.total_norm = frobenius_norm; + + return frobenius_norms; +} + +// Main body of One-sided Jacobi Method.
+StatusOr> WhileLoopFn( + absl::Span initial_values, // + int matrix_dimension, // + int max_sweep_updates, // + absl::string_view name, // + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + auto k = values[0]; + auto max_sweeps = ScalarLike(k, max_sweep_updates); + auto sweep_update_cond = Gt(max_sweeps, k); + + // Propagate a failed norm computation as a Status instead of dying in + // ValueOrDie(); this lambda already returns StatusOr. + TF_ASSIGN_OR_RETURN(auto norms, ComputeFrobeniusNorms(values[3])); + auto tol = norms.total_norm * values[4]; + auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), + xla::ConstantR0(cond_builder, false), + CreateScalarOrComputation(PRED, cond_builder)); + + return And(sweep_update_cond, tol_cond); + }; + + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + auto while_cond_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_cond_builder) -> StatusOr { + auto p = values_inner[0]; + return Lt(p, ScalarLike(p, matrix_dimension - 1)); + }; + + auto while_body_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_body_builder) -> StatusOr> { + auto while_cond_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_cond_builder) -> StatusOr { + auto q = values_innermost[1]; + return Lt(q, ScalarLike(q, matrix_dimension)); + }; + auto while_body_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_body_builder) + -> StatusOr> { + auto p = values_innermost[0]; + auto q = values_innermost[1]; + + SVDResult onesided_jacobi_update; + onesided_jacobi_update.u = values_innermost[2]; + onesided_jacobi_update.v = values_innermost[3]; + onesided_jacobi_update.d = values_innermost[4]; + + auto eps = values_innermost[5]; + + TF_ASSIGN_OR_RETURN( + onesided_jacobi_update, + OneSidedJacobiUpdate(onesided_jacobi_update, p, q, eps)); + + std::vector updated_values_innermost; + updated_values_innermost.reserve(values_innermost.size()); + + updated_values_innermost.push_back(p); + updated_values_innermost.push_back(q +
ScalarLike(q, 1)); + updated_values_innermost.push_back(onesided_jacobi_update.u); + updated_values_innermost.push_back(onesided_jacobi_update.v); + updated_values_innermost.push_back(onesided_jacobi_update.d); + updated_values_innermost.push_back(eps); + + return updated_values_innermost; + }; + + std::vector values_innermost(6); + auto p = values_inner[0]; + auto q = p + ScalarLike(p, 1); + values_innermost[0] = p; // index p. + values_innermost[1] = q; // index q. + values_innermost[2] = values_inner[1]; // u. + values_innermost[3] = values_inner[2]; // v. + values_innermost[4] = values_inner[3]; // d. + values_innermost[5] = values_inner[4]; // eps. + TF_ASSIGN_OR_RETURN( + values_innermost, + WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost, + values_innermost, absl::StrCat(name, "-Innermost"), + inner_body_builder)); + + std::vector updated_values_inner; + updated_values_inner.reserve(values_inner.size()); + + updated_values_inner.push_back(p + ScalarLike(p, 1)); + updated_values_inner.push_back(values_innermost[2]); + updated_values_inner.push_back(values_innermost[3]); + updated_values_inner.push_back(values_innermost[4]); + updated_values_inner.push_back(values_innermost[5]); + return updated_values_inner; + }; + // Indexes. + XlaOp k = values[0]; + + std::vector values_inner(5); + values_inner[0] = ScalarLike(k, 0); // index p. + values_inner[1] = values[1]; // u. + values_inner[2] = values[2]; // v. + values_inner[3] = values[3]; // d. + values_inner[4] = values[4]; // eps. 
+ TF_ASSIGN_OR_RETURN( + values_inner, + WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner, + absl::StrCat(name, "-Inner"), body_builder)); + + std::vector updated_values; + updated_values.reserve(values_inner.size()); + + updated_values.push_back(k + ScalarLike(k, 1)); + updated_values.push_back(values_inner[1]); + updated_values.push_back(values_inner[2]); + updated_values.push_back(values_inner[3]); + updated_values.push_back(values_inner[4]); + + return updated_values; + }; + std::vector values; + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + initial_values, name, builder)); + + return values; +} + +// Sort singular values in descending order, and make sure they are non-negative +// by flipping the signs of negative diagonal values and transferring the signs +// to V. And for numeric stability, renormalize U and V. +StatusOr SortBySingularValuesAndPostProcessing(SVDResult result) { + XlaBuilder* builder = result.d.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.d)); + const int64 num_dims = shape.rank(); + auto dimensions = shape.dimensions(); + const int64 m = ShapeUtil::GetDimension(shape, -2); + const int64 n = ShapeUtil::GetDimension(shape, -1); + + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims[num_dims - 2] = num_dims - 1; + + auto d = GetMatrixDiagonal(result.d); + + auto zeros = ZerosLike(d); + auto one = ScalarLike(d, 1.0); + + // Make all the singular values to be non-negative by transferring the signs + // to V. + auto sign = Select(Ge(d, zeros), zeros + one, zeros - one); + d = Select(Ge(d, zeros), d, -d); + result.v = Mul(result.v, sign, broadcast_dims); + + d = BroadcastInDim(d, dimensions, broadcast_dims); + + // As m >= n, only the first n column vectors need to be permuted, and the + // rest of m - n vectors are appended after the sorting is done.
+ XlaOp sort_u_result = + Sort({-d, SliceInMinorDims(result.u, {0, 0}, {m, n})}, + CreateScalarLtComputation( + {shape.element_type(), shape.element_type()}, builder), + num_dims - 1); + + // TODO(kuny): Use CreateScalarGtComputation after b/124862300 is fixed. + XlaOp sort_v_result = + Sort({SliceInMinorDims(-d, {0, 0}, {n, n}), result.v}, + CreateScalarLtComputation( + {shape.element_type(), shape.element_type()}, builder), + num_dims - 1); + // Make sure all the singular values are non-negative. + result.d = Max(-GetMatrixDiagonal(GetTupleElement(sort_v_result, 0)), + ScalarLike(d, 0.0)); + + result.v = GetTupleElement(sort_v_result, 1); + result.v = Mul( + result.v, + Rsqrt(Reduce(Square(result.v), ScalarLike(d, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + + // Append the rest of m - n vectors. + result.u = ConcatInDim(builder, + {GetTupleElement(sort_u_result, 1), + SliceInMinorDims(result.u, {0, n}, {m, m})}, + num_dims - 1); + result.u = Mul( + result.u, + Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + + return result; +} + +} // namespace + +// def jacobi_svd(A): +// U, D, V = house_bidiag(A) +// m, n = D.shape +// iter, max_iter = 0, 100 +// frobenius_norm = np.linalg.norm(D) +// diag_norm = np.linalg.norm(np.diag(D)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm) +// while off_diag_norm > 1e-6 * frobenius_norm and iter < max_iter: +// iter += 1 +// for p in range(m - 1): +// for q in range(p + 1, n): +// rot_l, rot_r = jacobi_rot(D[p][p], D[p][q], D[q][p], D[q][q]) +// D[[p, q], :] = np.matmul(rot_l.T, D[[p, q], :]) +// D[:, [p, q]] = np.matmul(D[:, [p, q]], rot_r) +// U[:, [p, q]] = np.matmul(U[:, [p, q]], rot_l) +// V[:, [p, q]] = np.matmul(V[:, [p, q]], rot_r) +// frobenius_norm = np.linalg.norm(D) +// diag_norm =
np.linalg.norm(np.diag(D)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm) +// +// return U, np.diag(D), V +// +SVDResult SVD(XlaOp a, int64 max_iter, float epsilon, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + auto return_error = [&](const Status& status) { + SVDResult result; + result.u = builder->ReportError(status); + result.v = builder->ReportError(status); + result.d = builder->ReportError(status); + return result; + }; + auto shape_with_status = builder->GetShape(a); + if (!shape_with_status.status().ok()) { + return return_error(shape_with_status.status()); + } + Shape a_shape = shape_with_status.ValueOrDie(); + const int64 num_dims = a_shape.rank(); + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + } + int64 m = ShapeUtil::GetDimension(a_shape, -2); + int64 n = ShapeUtil::GetDimension(a_shape, -1); + bool maybe_transpose = m < n; + + if (maybe_transpose) { + a = TransposeInMinorDims(a); + std::swap(m, n); + } + + auto eps = ScalarLike(a, epsilon); + + // Report bidiagonalization failures through the builder instead of crashing. + auto bidiag_with_status = HouseHolderBidiagonalization(a, eps, precision); + if (!bidiag_with_status.status().ok()) { + return return_error(bidiag_with_status.status()); + } + SVDResult svd_result = bidiag_with_status.ValueOrDie(); + + auto output_with_status = WhileLoopFn( + { + Zero(builder, S32), // k + svd_result.u, // u + svd_result.v, // v + svd_result.d, // d + eps, // epsilon + }, // + n, // + max_iter, // + "CyclicOneSidedJacobi", // + builder); + if (!output_with_status.status().ok()) { + return return_error(output_with_status.status()); + } + + auto output = output_with_status.ValueOrDie(); + + svd_result.u = output[1]; + svd_result.v = output[2]; + svd_result.d = output[3]; + // Likewise, surface sorting/post-processing failures via return_error. + auto sorted_with_status = SortBySingularValuesAndPostProcessing(svd_result); + if (!sorted_with_status.status().ok()) { + return return_error(sorted_with_status.status()); + } + svd_result = sorted_with_status.ValueOrDie(); + if (maybe_transpose) { + std::swap(svd_result.u, svd_result.v); + } + return svd_result; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/svd.h
b/tensorflow/compiler/xla/client/lib/svd.h new file mode 100644 index 0000000000000000000000000000000000000000..1a29539d9fc635d085bd7ccc48e693fe72307213 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/svd.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SVD_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SVD_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// The singular value decomposition of a given matrix A[..., M, N]. The original +// matrix is recovered by u * diag(d) * v_t, where the first dims(A) - 2 +// dimensions are batch dimensions. +struct SVDResult { + // The columns of U are the left-singular vectors, e.g., + // U[..., :, :]_T * U[..., :, :] = I. + XlaOp u; + // Vector(s) with the singular values, within each vector sorted in descending + // order. The first dims(D) - 1 dimensions have the same size as the batch + // dimensions of A. And U[..., :, i] * D[..., i] = A[..., :, :] * V[..., :, + // i]. + XlaOp d; + // The columns of V are the right-singular vectors, e.g., + // V[..., :, :]_T * V[..., :, :] = I.
+ XlaOp v; +}; + +// TODO(kuny): Add a bool flag that supports SVD with economy (reduced) +// representation, which is more memory efficient, especially in the case of +// tall-skinny matrices. +SVDResult SVD(XlaOp a, int64 max_iter = 100, float epsilon = 1e-6, + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SVD_H_ diff --git a/tensorflow/compiler/xla/client/lib/svd_test.cc b/tensorflow/compiler/xla/client/lib/svd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3c6ae93d8107b83b2884b851fdb44b3fe77a3c4 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/svd_test.cc @@ -0,0 +1,278 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/svd.h" +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +class SVDTest : public ClientLibraryTestBase { + protected: + void SetUp() override { + ClientLibraryTestBase::SetUp(); + batch_3d_4x5_ = Array3D{ + { + {4, 6, 8, 10, 1}, + {6, 45, 54, 63, 1}, + {8, 54, 146, 166, 1}, + {10, 63, 166, 310, 1}, + }, + { + {16, 24, 8, 12, 6}, + {24, 61, 82, 48, 5}, + {8, 82, 100, 6, 4}, + {12, 48, 6, 62, 3}, + }, + }; + } + void TearDown() override { ClientLibraryTestBase::TearDown(); } + + Array3D GetUnitMatrix3D(int32 batch_dim, int32 mat_dim) { + Array3D result(batch_dim, mat_dim, mat_dim, 0.0); + for (int i = 0; i < batch_dim; ++i) { + for (int j = 0; j < mat_dim; ++j) { + result({i, j, j}) = 1.0; + } + } + return result; + } + + XlaOp ComputeMatmulUDVT(SVDResult result, XlaBuilder* builder) { + Shape u_shape = builder->GetShape(result.u).ValueOrDie(); + Shape v_shape = builder->GetShape(result.v).ValueOrDie(); + + int64 m = ShapeUtil::GetDimension(u_shape, -1); + int64 n = ShapeUtil::GetDimension(v_shape, -1); + + auto v 
= result.v; + auto u = result.u; + auto d = result.d; + + auto zero = Zero(builder, S32); + if (m > n) { + u = DynamicSliceInMinorDims(u, {zero, zero}, {m, n}); + } else if (m < n) { + v = DynamicSliceInMinorDims(v, {zero, zero}, {n, m}); + } + + int num_dims = u_shape.rank(); + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims[num_dims - 2] = num_dims - 1; + return BatchDot(Mul(u, d, broadcast_dims), TransposeInMinorDims(v), + PrecisionConfig::HIGHEST); + } + + Array3D ExtractTriangularMatrix(const Array3D& matrix, + bool lower) { + Array3D result(matrix); + for (int i = 0; i < result.n1(); ++i) { + for (int j = 0; j < result.n2(); ++j) { + if (lower) { + for (int k = j + 1; k < result.n3(); ++k) { + result({i, j, k}) = 0.0; + } + } else { + for (int k = 0; k < j; ++k) { + result({i, j, k}) = 0.0; + } + } + } + } + return result; + } + + XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { + Shape shape = builder->GetShape(m1).ValueOrDie(); + int64 size = 1; + for (auto d : shape.dimensions()) { + size *= d; + } + return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0), + CreateScalarAddComputation(F32, builder)) / + ConstantR0WithType(builder, F32, size); + } + + Array2D GenerateRandomMatrix(int xsize, int ysize) { + Array2D result{xsize, ysize, 0.0}; + result.FillRandom(10 /* stddev */, 2 /* mean */); + return result; + } + + Array3D batch_3d_4x5_; +}; + +XLA_TEST_F(SVDTest, Simple2D) { + XlaBuilder builder(TestName()); + + Array2D simple_2d_4x4_ = Array2D{ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }; + XlaOp a; + auto a_data = CreateR2Parameter(simple_2d_4x4_, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-6); + ComputeMatmulUDVT(result, &builder); + + ComputeAndCompareR2(&builder, simple_2d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) { + XlaBuilder 
builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x5_, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-8); + ComputeMatmulUDVT(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x5_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SVDTest, Test_Orthogonality_U) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x5_, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-8); + ComputeMatmulUDVT(result, &builder); + BatchDot(result.u, TransposeInMinorDims(result.u)); + + ComputeAndCompareR3(&builder, GetUnitMatrix3D(2, 4), {a_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(SVDTest, Test_Orthogonality_V) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x5_, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-8); + BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); + + ComputeAndCompareR3(&builder, GetUnitMatrix3D(2, 5), {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) { + XlaBuilder builder(TestName()); + + auto singular_values = Array2D{ + {431.05153007, 49.88334164, 20.94464584, 3.24845468}, + {179.73128591, 68.05162245, 21.77679503, 13.94319712}, + }; + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x5_, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-8); + Add(result.d, ZerosLike(result.d)); + + ComputeAndCompareR2(&builder, singular_values, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x128) { + XlaBuilder builder(TestName()); + Array2D a_val = GenerateRandomMatrix(512, 128); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-6); + GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SVDTest, 
Various_Size_Random_Matrix_128x256) { + XlaBuilder builder(TestName()); + Array2D a_val = GenerateRandomMatrix(128, 256); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-6); + GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) { + XlaBuilder builder(TestName()); + Array2D a_val = GenerateRandomMatrix(256, 128); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-6); + GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x512) { + XlaBuilder builder(TestName()); + Array2D a_val = GenerateRandomMatrix(128, 512); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-6); + GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x256) { + XlaBuilder builder(TestName()); + Array2D a_val = GenerateRandomMatrix(512, 256); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-6); + GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x512) { + XlaBuilder builder(TestName()); + Array2D a_val = GenerateRandomMatrix(512, 512); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SVD(a, 100, 1e-6); + GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); + + 
ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.h b/tensorflow/compiler/xla/client/lib/triangular_solve.h deleted file mode 100644 index 50a3b30ebd1c15eb6d2ace4e351cb41f21db7093..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// Solves systems of linear equations with lower or upper triangular coefficient -// matrices by forward- or back-substitution. Broadcasting along leading -// dimensions, this routine solves one of the matrix systems -// `op(a) * x = b`, or `x * op(a) = b`, -// for the variable `x` given `a` and `b`, where `op(a)` is either -// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. 
-// That is, the innermost matrices in the output satisfy a scalar system -// depending on the value of the value of (left_side, transpose_a, conjugate_a) -// according to: -// (F, F, F) => `output[..., i, k] a[..., k, j] = b[..., i, j]`, -// (F, F, T) => `output[..., i, k] a*[..., k, j] = b[..., i, j]`, -// (F, T, F) => `output[..., i, k] a[..., j, k] = b[..., i, j]`, -// (F, T, T) => `output[..., i, k] a*[..., j, k] = b[..., i, j]`, -// (T, F, F) => ` a[..., i, k] output[..., k, j] = b[..., i, j]`, -// (T, F, T) => `a*[..., i, k] output[..., k, j] = b[..., i, j]`, -// (T, T, F) => ` a[..., i, k] output[..., j, k] = b[..., i, j]`, -// (T, T, T) => `a*[..., i, k] output[..., j, k] = b[..., i, j]`, -// where * denotes complex conjugation and where the index `k` is summed over. -// -// `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form -// square matrices. If lower is true (false), then the strictly upper (lower) -// triangular part of each innermost matrix in `a` is assumed to be zero and is -// not accessed. -// `b` is a tensor of shape `[..., M, K]` if left_side is true, otherwise a -// tensor of shape `[..., K, M]`. -// `left_side` is a boolean, indicating whether to solve a system of the form -// op(a) * x = b (true) or x * op(a) = b (false). -// `lower` is a boolean, indicating whether the argument `a` is lower-triangular -// (true) or upper-triangular (false). -// `transpose_a` is a boolean indicating whether the matrix `a` is transposed. -// `conjugate_a` is a boolean indicating whether the entries of `a` are complex -// conjugated (independently of whether they are transposed), so that when both -// transpose_a and conjugate_a are true the effect is a Hermitian adjoint. -// -// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no -// blocking is used. 
-XlaOp TriangularSolve( - XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, - bool conjugate_a, int64 block_size = 128, - PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 48b5f94538f453785194bc434a91ee0a10c020c2..ae7d3d9f9920385476aecb2ce7c2a02b76adc93f 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/stream_pool.h" @@ -185,7 +186,7 @@ StatusOr LocalExecutable::ExecuteAndDump( executable_->ExecuteOnStream(run_options, arguments, /*hlo_execution_profile=*/nullptr)); TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot())); - TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot()); + DumpHloSnapshotIfEnabled(executable_->module(), *executable_->hlo_snapshot()); return std::move(result); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index ddb36680e8b185b053368baffa6f1d5cac50dc07..2dd8c130e26e75a7c2690d6848d4b081234eea37 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -67,10 +67,10 @@ class LocalExecutable { const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to - // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. 
+ // invoke it, and the result. Enabled by flag: --xla_dump_hlo_snapshots. // - // The given ServiceExecutableRunOptions override any values from TF_XLA_FLAGS - // environment variable. + // The given ServiceExecutableRunOptions override any values from the + // XLA_FLAGS environment variable. StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const absl::Span arguments); @@ -114,7 +114,7 @@ class LocalClient : public Client { // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. // - // The given ExecutableBuildOptions override any values from TF_XLA_FLAGS + // The given ExecutableBuildOptions overrides any values from XLA_FLAGS // environment variable. StatusOr> Compile( const XlaComputation& computation, diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 20298d175dc83abc4fe16212300eb587385ae583..2f574366694a61ee18876596df23eaacec1b1129 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -267,8 +267,8 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, for (int64 index : target_param_index) { param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index); } - param_shape_ptr->set_dynamic_dimension(target_dim_num, - /*is_dynamic=*/true); + // TODO(b/121223198): Set `is_dynamic` to the parameter shape when XLA + // backend can handle dynamic dimensions. 
*instr.mutable_shape() = param_shape.ToProto(); } } @@ -299,46 +299,51 @@ XlaComputation XlaBuilder::BuildAndNoteError() { return build_status.ConsumeValueOrDie(); } -StatusOr XlaBuilder::Build() { +Status XlaBuilder::GetCurrentStatus() const { if (!first_error_.ok()) { string backtrace; first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } - return Build(instructions_.back().id()); + return Status::OK(); +} + +StatusOr XlaBuilder::Build(bool remove_dynamic_dimensions) { + TF_RETURN_IF_ERROR(GetCurrentStatus()); + return Build(instructions_.back().id(), remove_dynamic_dimensions); } -StatusOr XlaBuilder::Build(XlaOp root) { +StatusOr XlaBuilder::Build(XlaOp root, + bool remove_dynamic_dimensions) { if (root.builder_ != this) { return InvalidArgument("Given root operation is not in this computation."); } - return Build(root.handle()); + return Build(root.handle(), remove_dynamic_dimensions); } -StatusOr XlaBuilder::Build(int64 root_id) { - if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); - } +StatusOr XlaBuilder::Build(int64 root_id, + bool remove_dynamic_dimensions) { + TF_RETURN_IF_ERROR(GetCurrentStatus()); // TODO(b/121223198): XLA backend cannot handle dynamic dimensions yet, remove // all dynamic dimensions before building xla program until we have support in // the backend. 
- std::function remove_dynamic_dimension = - [&](ShapeProto* shape) { - if (shape->tuple_shapes_size() != 0) { - for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) { - remove_dynamic_dimension(shape->mutable_tuple_shapes(i)); + if (remove_dynamic_dimensions) { + std::function remove_dynamic_dimension = + [&](ShapeProto* shape) { + if (shape->tuple_shapes_size() != 0) { + for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) { + remove_dynamic_dimension(shape->mutable_tuple_shapes(i)); + } } - } - for (int64 i = 0; i < shape->dimensions_size(); ++i) { - shape->set_is_dynamic_dimension(i, false); - } - }; + for (int64 i = 0; i < shape->dimensions_size(); ++i) { + shape->set_is_dynamic_dimension(i, false); + } + }; - for (auto& instruction : instructions_) { - remove_dynamic_dimension(instruction.mutable_shape()); + for (auto& instruction : instructions_) { + remove_dynamic_dimension(instruction.mutable_shape()); + } } HloComputationProto entry; @@ -475,7 +480,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { } XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { + absl::Span broadcast_dimensions, + absl::optional direction) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -484,6 +490,17 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); *instr.mutable_shape() = shape.ToProto(); + if (binop == HloOpcode::kCompare) { + if (!direction.has_value()) { + return InvalidArgument( + "kCompare expects a ComparisonDirection, but none provided."); + } + instr.set_comparison_direction(ComparisonDirectionToString(*direction)); + } else if (direction.has_value()) { + return InvalidArgument( + "A comparison direction is provided for a non-compare opcode: %s.", + HloOpcodeString(binop)); 
+ } const int64 lhs_rank = lhs_shape.rank(); const int64 rhs_rank = rhs_shape.rank(); @@ -497,16 +514,19 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; std::vector to_size; - for (int64 size : shape.dimensions()) { - to_size.push_back(size); + std::vector to_size_is_dynamic; + for (int i = 0; i < shape.rank(); i++) { + to_size.push_back(shape.dimensions(i)); + to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i)); } for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) { int64 to_dim = broadcast_dimensions[from_dim]; to_size[to_dim] = from_shape.dimensions(from_dim); + to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim); } - const Shape& broadcasted_shape = - ShapeUtil::MakeShape(from_shape.element_type(), to_size); + const Shape& broadcasted_shape = ShapeUtil::MakeShape( + from_shape.element_type(), to_size, to_size_is_dynamic); TF_ASSIGN_OR_RETURN( XlaOp broadcasted_operand, InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); @@ -566,16 +586,6 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, }); } -XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); -} - XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -665,8 +675,17 @@ XlaOp XlaBuilder::BroadcastInDim( TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); // Output shape, in the case of degenerate broadcast, the out_dim_size is // not necessarily the same as the dimension sizes of the output shape. 
- const auto& output_shape = + auto output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size); + for (int i = 0; i < broadcast_dimensions.size(); i++) { + if (broadcast_dimensions[i] < 0 || + broadcast_dimensions[i] > out_dim_size.size()) { + return InvalidArgument("Broadcast dimension %lld is out of bound", + broadcast_dimensions[i]); + } + output_shape.set_dynamic_dimension(broadcast_dimensions[i], + operand_shape.is_dynamic_dimension(i)); + } TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape( operand_shape, output_shape, broadcast_dimensions) @@ -988,36 +1007,6 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { }); } -XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); -} - XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1038,6 +1027,18 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, 
GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + // If one operand is a scalar, just multiply the two operands. + if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) { + if (dimension_numbers.rhs_batch_dimensions_size() != 0 || + dimension_numbers.lhs_batch_dimensions_size() != 0 || + dimension_numbers.rhs_contracting_dimensions_size() != 0 || + dimension_numbers.lhs_contracting_dimensions_size() != 0) { + return InvalidArgument( + "Dots with scalar operands must have no contracting or batch " + "dimensions"); + } + return xla::Mul(lhs, rhs); + } TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); @@ -1533,147 +1534,6 @@ XlaOp XlaBuilder::CustomCall( }); } -XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); -} - -XlaOp XlaBuilder::Conj(const XlaOp& operand) { - return Complex(Real(operand), Neg(Imag(operand))); -} - -XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, - absl::Span 
broadcast_dimensions) { - return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Not(const XlaOp& operand) { - return UnaryOp(HloOpcode::kNot, operand); -} - -XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, - broadcast_dimensions); -} - -XlaOp XlaBuilder::ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, - broadcast_dimensions); -} - -XlaOp XlaBuilder::Abs(const XlaOp& operand) { - return UnaryOp(HloOpcode::kAbs, operand); -} - -XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); -} - -XlaOp XlaBuilder::Exp(const XlaOp& operand) { - return UnaryOp(HloOpcode::kExp, operand); -} - -XlaOp XlaBuilder::Expm1(const XlaOp& operand) { - return UnaryOp(HloOpcode::kExpm1, operand); -} - -XlaOp XlaBuilder::Floor(const XlaOp& operand) { - return UnaryOp(HloOpcode::kFloor, operand); -} - -XlaOp XlaBuilder::Ceil(const XlaOp& operand) { - return UnaryOp(HloOpcode::kCeil, operand); -} - -XlaOp XlaBuilder::Round(const XlaOp& operand) { - return UnaryOp(HloOpcode::kRoundNearestAfz, operand); -} - -XlaOp XlaBuilder::Log(const XlaOp& operand) { - return UnaryOp(HloOpcode::kLog, operand); -} - -XlaOp 
XlaBuilder::Log1p(const XlaOp& operand) { - return UnaryOp(HloOpcode::kLog1p, operand); -} - -XlaOp XlaBuilder::Sign(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSign, operand); -} - -XlaOp XlaBuilder::Clz(const XlaOp& operand) { - return UnaryOp(HloOpcode::kClz, operand); -} - -XlaOp XlaBuilder::Cos(const XlaOp& operand) { - return UnaryOp(HloOpcode::kCos, operand); -} - -XlaOp XlaBuilder::Sin(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSin, operand); -} - -XlaOp XlaBuilder::Tanh(const XlaOp& operand) { - return UnaryOp(HloOpcode::kTanh, operand); -} - -XlaOp XlaBuilder::Real(const XlaOp& operand) { - return UnaryOp(HloOpcode::kReal, operand); -} - -XlaOp XlaBuilder::Imag(const XlaOp& operand) { - return UnaryOp(HloOpcode::kImag, operand); -} - -XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { - return UnaryOp(HloOpcode::kIsFinite, operand); -} - XlaOp XlaBuilder::Transpose(const XlaOp& operand, absl::Span permutation) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1704,36 +1564,146 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } +namespace { +// Switch from a floating point value to a integer value in such a way that when +// using the integer value to compare, we get the same result for normal values, +// and -Nan is treated as the smallest value, and Nan is treated as the largest +// value. +// If f is a float, and +// x = bit_cast(f); +// y = x < 0 ? numeric_limits::max() - x : x; +// then y is ordered as an int32 such that finite values have the obvious order, +// -0 is ordered before 0, and -NaN and NaN appear at the beginning and end of +// the ordering. +// Note that in order to avoid -x to overflow, we calculate +// numeric_limits::max() - x as unsigned, and then convert back to +// signed. 
+XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, + int64 bit_width) { + PrimitiveType signed_type; + PrimitiveType unsigned_type; + XlaOp max_value; + switch (bit_width) { + case 16: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S16; + unsigned_type = U16; + break; + case 32: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S32; + unsigned_type = U32; + break; + case 64: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S64; + unsigned_type = U64; + break; + default: + return value.builder()->ReportError( + InvalidArgument("Invalid bit width %lld for Comparator floating " + "point parameter.", + bit_width)); + } + auto signed_value = BitcastConvertType(value, signed_type); + auto unsigned_value = BitcastConvertType(value, unsigned_type); + auto flipped_value = + BitcastConvertType(Sub(max_value, unsigned_value), signed_type); + auto is_negative = + Lt(signed_value, + ConstantLiteral(value.builder(), LiteralUtil::Zero(signed_type))); + return Select(is_negative, flipped_value, signed_value); +} +} // namespace + XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + std::vector operands{keys}; + for (const XlaOp& value : values) { + operands.push_back(value); + } + // Build the default less-than comparator (copied from lib/comparators.cc). + // TODO(b/122298745): Remove the deprecated API method so that this code + // duplication can be deleted. 
+ auto b = this->CreateSubBuilder("comparator"); + std::vector operand_types; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(operand)); + operand_types.push_back(operand_shape.element_type()); + } + + int64 parameter_count = 0; + XlaOp first_lhs_param; + XlaOp first_rhs_param; + + for (auto operand_type : operand_types) { + auto scalar_shape = ShapeUtil::MakeShape(operand_type, {}); + auto lhs_param = + b->Parameter(parameter_count * 2, scalar_shape, + absl::StrCat("p.", parameter_count, ".lhs")); + auto rhs_param = + b->Parameter(parameter_count * 2 + 1, scalar_shape, + absl::StrCat("p.", parameter_count, ".rhs")); + if (parameter_count == 0) { + first_lhs_param = lhs_param; + first_rhs_param = rhs_param; + } + ++parameter_count; + } + if (primitive_util::IsFloatingPointType(operand_types[0])) { + PrimitiveType compare_type = operand_types[0]; + // Special-case handling for BF16. We currently do not support direct + // comparisons with BF16, so we convert to F32 and then use the F32 + // comparison logic. 
+ if (compare_type == BF16) { + compare_type = F32; + first_lhs_param = b->ConvertElementType(first_lhs_param, F32); + first_rhs_param = b->ConvertElementType(first_rhs_param, F32); + } + int64 bit_width = primitive_util::BitWidth(compare_type); + first_lhs_param = + BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width); + first_rhs_param = + BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width); + } + Lt(first_lhs_param, first_rhs_param); + + TF_ASSIGN_OR_RETURN(auto comparator, b->Build()); + return Sort(operands, comparator, dimension, /*is_stable=*/false); + }); +} + +XlaOp XlaBuilder::Sort(absl::Span operands, + const XlaComputation& comparator, int64 dimension, + bool is_stable) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; + instr.set_is_stable(is_stable); std::vector operand_shape_ptrs; - TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); - operand_shape_ptrs.push_back(&keys_shape); - TF_ASSIGN_OR_RETURN(std::vector values_shapes, - GetOperandShapes(values)); - absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs), + TF_ASSIGN_OR_RETURN(std::vector operand_shapes, + GetOperandShapes(operands)); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kSort, operand_shape_ptrs)); *instr.mutable_shape() = shape.ToProto(); if (dimension == -1) { - TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(operands[0])); dimension = keys_shape.rank() - 1; } instr.add_dimensions(dimension); - std::vector operands{keys}; - operands.insert(operands.end(), values.begin(), values.end()); + AddCalledComputation(comparator, &instr); return AddInstruction(std::move(instr), HloOpcode::kSort, operands); }); } -XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, - absl::Span 
broadcast_dimensions) { - return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); -} - XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1759,10 +1729,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, }); } -XlaOp XlaBuilder::Neg(const XlaOp& operand) { - return UnaryOp(HloOpcode::kNegate, operand); -} - XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return TernaryOp(HloOpcode::kClamp, min, operand, max); @@ -1926,32 +1892,46 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation) { + // The index of true_computation must be 0 and that of false computation + // must be 1. + return Conditional(predicate, {&true_computation, &false_computation}, + {true_operand, false_operand}); +} + +XlaOp XlaBuilder::Conditional( + const XlaOp& branch_index, + absl::Span branch_computations, + absl::Span branch_operands) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate)); - TF_ASSIGN_OR_RETURN(const Shape& true_operand_shape, - GetShape(true_operand)); - TF_ASSIGN_OR_RETURN(const ProgramShape& true_computation_shape, - true_computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN(const Shape& false_operand_shape, - GetShape(false_operand)); - TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape, - false_computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferConditionalShape( - predicate_shape, true_operand_shape, false_operand_shape, - true_computation_shape, false_computation_shape)); + TF_ASSIGN_OR_RETURN(const Shape& branch_index_shape, + GetShape(branch_index)); + std::vector branch_operand_shapes(branch_operands.size()); + std::vector 
branch_computation_shapes( + branch_computations.size()); + for (int j = 0; j < branch_operands.size(); ++j) { + TF_ASSIGN_OR_RETURN(branch_operand_shapes[j], + GetShape(branch_operands[j])); + TF_ASSIGN_OR_RETURN(branch_computation_shapes[j], + branch_computations[j]->GetProgramShape()); + } + TF_ASSIGN_OR_RETURN(const Shape shape, + ShapeInference::InferConditionalShape( + branch_index_shape, branch_computation_shapes, + branch_operand_shapes)); *instr.mutable_shape() = shape.ToProto(); - // The index of true_computation must be 0 and that of false computation - // must be 1. - AddCalledComputation(true_computation, &instr); - AddCalledComputation(false_computation, &instr); + for (const XlaComputation* branch_computation : branch_computations) { + AddCalledComputation(*branch_computation, &instr); + } + std::vector operands(1, branch_index); + for (const XlaOp branch_operand : branch_operands) { + operands.emplace_back(branch_operand); + } return AddInstruction(std::move(instr), HloOpcode::kConditional, - {predicate, true_operand, false_operand}); + absl::MakeSpan(operands)); }); } @@ -2143,8 +2123,8 @@ XlaOp XlaBuilder::CrossReplicaSum( TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); auto b = CreateSubBuilder("sum"); - b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), - b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); + Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), + b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); TF_ASSIGN_OR_RETURN(auto computation, b->Build()); return CrossReplicaSum(operand, computation, replica_groups, /*channel_id=*/absl::nullopt); @@ -2253,6 +2233,14 @@ XlaOp XlaBuilder::CollectivePermute( }); } +XlaOp XlaBuilder::ReplicaId() { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto(); + return 
AddInstruction(std::move(instr), HloOpcode::kReplicaId, {}); + }); +} + XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -2932,32 +2920,39 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq); } XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe); } XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe); } XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt); } -XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe); } -XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt); +} + +XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction) { + return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs, + broadcast_dimensions, direction); } XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, @@ -3031,6 +3026,44 @@ XlaOp Fft(const XlaOp& 
operand, FftType fft_type, return operand.builder()->Fft(operand, fft_type, fft_length); } +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(const Shape& b_shape, builder->GetShape(b)); + xla::TriangularSolveOptions& options = + *instr.mutable_triangular_solve_options(); + options.set_left_side(left_side); + options.set_lower(lower); + options.set_unit_diagonal(unit_diagonal); + options.set_transpose_a(transpose_a); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape( + a_shape, b_shape, options)); + *instr.mutable_shape() = shape.ToProto(); + + return builder->AddInstruction(std::move(instr), + HloOpcode::kTriangularSolve, {a, b}); + }); +} + +XlaOp Cholesky(XlaOp a, bool lower) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a)); + xla::CholeskyOptions& options = *instr.mutable_cholesky_options(); + options.set_lower(lower); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferCholeskyShape(a_shape)); + *instr.mutable_shape() = shape.ToProto(); + + return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a}); + }); +} + XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) { return builder->Infeed(shape, config); } @@ -3060,78 +3093,96 @@ XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, operand_shapes_with_layout); } -XlaOp Complex(const XlaOp& real, const XlaOp& imag, +XlaOp Complex(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return real.builder()->Complex(real, imag, broadcast_dimensions); + return 
lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs, + broadcast_dimensions); } -XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } +XlaOp Conj(const XlaOp& operand) { + return Complex(Real(operand), Neg(Imag(operand))); +} XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs, + broadcast_dimensions); } XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs, + broadcast_dimensions); } XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs, + broadcast_dimensions); } XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs, + broadcast_dimensions); } XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs, + broadcast_dimensions); } XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs, + broadcast_dimensions); } XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs, + broadcast_dimensions); } XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->And(lhs, 
rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs, + broadcast_dimensions); } XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs, + broadcast_dimensions); } XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs, + broadcast_dimensions); } -XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); } +XlaOp Not(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kNot, operand); +} XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, + broadcast_dimensions); } XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, + broadcast_dimensions); } XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, + broadcast_dimensions); } XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, @@ -3203,6 +3254,8 @@ XlaOp CollectivePermute( return operand.builder()->CollectivePermute(operand, source_target_pairs); } +XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); } + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, Padding padding, @@ -3224,48 +3277,73 @@ XlaOp 
SelectAndScatterWithGeneralPadding( init_value, scatter); } -XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); } +XlaOp Abs(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kAbs, operand); +} -XlaOp Atan2(const XlaOp& y, const XlaOp& x, +XlaOp Atan2(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return y.builder()->Atan2(y, x, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs, + broadcast_dimensions); } -XlaOp Exp(const XlaOp& operand) { return operand.builder()->Exp(operand); } - -XlaOp Expm1(const XlaOp& operand) { return operand.builder()->Expm1(operand); } - -XlaOp Floor(const XlaOp& operand) { return operand.builder()->Floor(operand); } - -XlaOp Ceil(const XlaOp& operand) { return operand.builder()->Ceil(operand); } - -XlaOp Round(const XlaOp& operand) { return operand.builder()->Round(operand); } - -XlaOp Log(const XlaOp& operand) { return operand.builder()->Log(operand); } - -XlaOp Log1p(const XlaOp& operand) { return operand.builder()->Log1p(operand); } - -XlaOp Sign(const XlaOp& operand) { return operand.builder()->Sign(operand); } - -XlaOp Clz(const XlaOp& operand) { return operand.builder()->Clz(operand); } - -XlaOp Cos(const XlaOp& operand) { return operand.builder()->Cos(operand); } - -XlaOp Sin(const XlaOp& operand) { return operand.builder()->Sin(operand); } - -XlaOp Tanh(const XlaOp& operand) { return operand.builder()->Tanh(operand); } - -XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } - -XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } +XlaOp Exp(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kExp, operand); +} +XlaOp Expm1(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand); +} +XlaOp Floor(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kFloor, operand); +} +XlaOp Ceil(const XlaOp& operand) { + return 
operand.builder()->UnaryOp(HloOpcode::kCeil, operand); +} +XlaOp Round(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand); +} +XlaOp Log(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kLog, operand); +} +XlaOp Log1p(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand); +} +XlaOp Sign(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kSign, operand); +} +XlaOp Clz(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kClz, operand); +} +XlaOp Cos(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kCos, operand); +} +XlaOp Sin(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kSin, operand); +} +XlaOp Tanh(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kTanh, operand); +} +XlaOp Real(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kReal, operand); +} +XlaOp Imag(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kImag, operand); +} +XlaOp Sqrt(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand); +} +XlaOp Rsqrt(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand); +} XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs, + broadcast_dimensions); } XlaOp IsFinite(const XlaOp& operand) { - return operand.builder()->IsFinite(operand); + return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand); } XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { @@ -3276,7 +3354,9 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { return operand.builder()->BitcastConvertType(operand, new_element_type); } -XlaOp Neg(const XlaOp& operand) { return 
operand.builder()->Neg(operand); } +XlaOp Neg(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kNegate, operand); +} XlaOp Transpose(const XlaOp& operand, absl::Span permutation) { return operand.builder()->Transpose(operand, permutation); @@ -3290,6 +3370,12 @@ XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension) { return keys.builder()->Sort(keys, values, dimension); } +XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64 dimension, bool is_stable) { + return operands[0].builder()->Sort(operands, comparator, dimension, + is_stable); +} + XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); } @@ -3322,6 +3408,13 @@ XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, false_computation); } +XlaOp Conditional(const XlaOp& branch_index, + absl::Span branch_computations, + absl::Span branch_operands) { + return branch_index.builder()->Conditional(branch_index, branch_computations, + branch_operands); +} + XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits) { return operand.builder()->ReducePrecision(operand, exponent_bits, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 8908d172fa89632ead48f954de12066af12411c7..80f93a8b6de98e124bf9fed3969ffcec7c4a95c4 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" @@ -56,6 +57,9 @@ class XlaOp { } ~XlaOp() = default; + XlaOp(const XlaOp& other) = default; + XlaOp& operator=(const XlaOp& other) = default; + // Precondition: !IsUninitialized(). // // It's very common to do foo.builder()->bar(). Without this precondition, if @@ -197,11 +201,19 @@ class XlaBuilder { // status. Note that all ops that have been enqueued will be moved to the // computation being returned. The root of the computation will be the last // added operation. - StatusOr Build(); + // + // `remove_dynamic_dimensions` tells the builder whether to remove the + // dyanmic dimensions information in all ops. + // + // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the + // dynamic dimensions information when XLA backend can handle dynamic + // dimensions. + StatusOr Build(bool remove_dynamic_dimensions = true); // Overload of Build which specifies a particular root instruction for the // computation. - StatusOr Build(XlaOp root); + StatusOr Build(XlaOp root, + bool remove_dynamic_dimensions = true); // Builds the computation with the requested operations, or notes an error in // the parent XlaBuilder and returns an empty computation if building failed. @@ -227,6 +239,10 @@ class XlaBuilder { // See also set_die_immediately_on_error(). Status first_error() const { return first_error_; } + // Returns the current status of the builder, complete with the stack trace + // information. + Status GetCurrentStatus() const; + // Returns the shape of the given op. 
StatusOr GetShape(const XlaOp& op) const; @@ -269,6 +285,10 @@ class XlaBuilder { // and its real dynamic size is represented by `dynamic_param_index` in // parameter `dynamic_param_num`. // + // Note that this should be called before the dynamic parameters are used to + // create other operations, otherwise created operations won't have the + // dynamic dimensions information. + // // TODO(b/119520625): Remove this API once we have more dynamic shape infra // ready. Status SetDynamicBinding(int64 dynamic_size_param_num, @@ -284,16 +304,19 @@ class XlaBuilder { input_output_aliases_.push_back({output_index, param_number, param_index}); } - private: // Describes an input/output alias as inserted by the SetUpAlias() API. struct InputOutputAlias { + // Specifies the index of the aliased buffer in the result tuple. ShapeIndex output_index; + // Specifies the parameter containing the buffer to be aliased. int64 param_number; + // Specifies the index of the aliased buffer in the parameter ShapeIndex param_index; }; + private: // Build helper which takes the id of the root operation.. - StatusOr Build(int64 root_id); + StatusOr Build(int64 root_id, bool remove_dynamic_dimensions); // Description for the methods below can be found in the corresponding public // functions section in this file. 
@@ -303,38 +326,6 @@ class XlaBuilder { XlaOp ConstantLiteral(const LiteralSlice& literal); - template - XlaOp ConstantR0(NativeT value); - template - XlaOp ConstantR1(absl::Span values); - XlaOp ConstantR1(const tensorflow::core::Bitmap& values); - template - XlaOp ConstantR2( - std::initializer_list> values); - template - XlaOp ConstantFromArrayWithLayout(const Array& values, - const Layout& layout); - template - XlaOp ConstantFromArray(const Array& values); - template - XlaOp ConstantR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout); - template - XlaOp ConstantR2FromArray2D(const Array2D& values); - template - XlaOp ConstantR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout); - template - XlaOp ConstantR3FromArray3D(const Array3D& values); - template - XlaOp ConstantR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout); - template - XlaOp ConstantR4FromArray4D(const Array4D& values); - - template - XlaOp ConstantR1(int64 length, NativeT value); - XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); @@ -382,24 +373,6 @@ class XlaBuilder { XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); - XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config = nullptr); @@ -464,50 +437,6 @@ class XlaBuilder { const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout); - XlaOp Complex(const XlaOp& 
real, const XlaOp& imag, - absl::Span broadcast_dimensions = {}); - - XlaOp Conj(const XlaOp& operand); - - XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Not(const XlaOp& operand); - - XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); @@ -550,6 +479,8 @@ class XlaBuilder { const XlaOp& operand, const std::vector>& source_target_pairs); + XlaOp ReplicaId(); + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, @@ -564,44 +495,6 @@ class XlaBuilder { absl::Span> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); - XlaOp Abs(const XlaOp& operand); - - XlaOp Atan2(const XlaOp& y, const XlaOp& x, - absl::Span 
broadcast_dimensions = {}); - - XlaOp Exp(const XlaOp& operand); - - XlaOp Expm1(const XlaOp& operand); - - XlaOp Floor(const XlaOp& operand); - - XlaOp Ceil(const XlaOp& operand); - - XlaOp Round(const XlaOp& operand); - - XlaOp Log(const XlaOp& operand); - - XlaOp Log1p(const XlaOp& operand); - - XlaOp Sign(const XlaOp& operand); - - XlaOp Clz(const XlaOp& operand); - - XlaOp Cos(const XlaOp& operand); - - XlaOp Sin(const XlaOp& operand); - - XlaOp Tanh(const XlaOp& operand); - - XlaOp Real(const XlaOp& operand); - - XlaOp Imag(const XlaOp& operand); - - XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp IsFinite(const XlaOp& operand); - XlaOp Iota(const Shape& shape, int64 iota_dimension); XlaOp Iota(PrimitiveType type, int64 size); @@ -612,14 +505,15 @@ class XlaBuilder { XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - XlaOp Neg(const XlaOp& operand); - XlaOp Transpose(const XlaOp& operand, absl::Span permutation); XlaOp Rev(const XlaOp& operand, absl::Span dimensions); + ABSL_DEPRECATED("Use form with comparator computation instead") XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); + XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64 dimension = -1, bool is_stable = false); XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -639,6 +533,10 @@ class XlaBuilder { const XlaOp& false_operand, const XlaComputation& false_computation); + XlaOp Conditional(const XlaOp& branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); @@ -699,9 +597,11 @@ class XlaBuilder { // Internal helper method that does the building for an arbitrary binary op. // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. 
+ // when the operation is between tensors of different ranks. The direction is + // only used if opcode is kCompare. XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions); + absl::Span broadcast_dimensions, + absl::optional direction = absl::nullopt); // Internal helper method that does the building for an arbitrary ternary op. XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, @@ -811,48 +711,6 @@ class XlaBuilder { const Shape& shape, const string& name); friend XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); - template - friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); - template - friend XlaOp ConstantR1(XlaBuilder* builder, - absl::Span values); - friend XlaOp ConstantR1(XlaBuilder* builder, - const tensorflow::core::Bitmap& values); - template - friend XlaOp ConstantR2( - XlaBuilder* builder, - std::initializer_list> values); - template - friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, - const Array& values, - const Layout& layout); - template - friend XlaOp ConstantFromArray(XlaBuilder* builder, - const Array& values); - template - friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, - const Array2D& values, - const Layout& layout); - template - friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder, - const Array2D& values); - template - friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, - const Array3D& values, - const Layout& layout); - template - friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder, - const Array3D& values); - template - friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, - const Array4D& values, - const Layout& layout); - template - friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder, - const Array4D& values); - - template - friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); friend XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); @@ 
-912,6 +770,9 @@ class XlaBuilder { absl::Span broadcast_dimensions); friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions); + friend XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, @@ -950,6 +811,10 @@ class XlaBuilder { const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); + friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); + friend XlaOp Cholesky(XlaOp a, bool lower); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config); friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, @@ -1028,6 +893,7 @@ class XlaBuilder { friend XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); + friend XlaOp ReplicaId(XlaBuilder* builder); friend XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -1058,6 +924,8 @@ class XlaBuilder { friend XlaOp Tanh(const XlaOp& operand); friend XlaOp Real(const XlaOp& operand); friend XlaOp Imag(const XlaOp& operand); + friend XlaOp Sqrt(const XlaOp& operand); + friend XlaOp Rsqrt(const XlaOp& operand); friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions); friend XlaOp IsFinite(const XlaOp& operand); @@ -1074,6 +942,9 @@ class XlaBuilder { friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); friend XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension); + friend XlaOp Sort(absl::Span operands, + const XlaComputation& comparator, int64 dimension, + bool is_stable); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); 
friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, @@ -1088,6 +959,10 @@ class XlaBuilder { const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation); + friend XlaOp Conditional( + const XlaOp& branch_index, + absl::Span branch_computations, + absl::Span branch_operands); friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, @@ -1410,6 +1285,11 @@ XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); +// Enqueues a comparison instruction onto the computation. +XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); + // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config = nullptr); @@ -1469,6 +1349,45 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); +// Solves systems of linear equations with lower or upper triangular coefficient +// matrices by forward- or back-substitution. Broadcasting along leading +// dimensions, this routine solves for x in one of the matrix systems +// `op(a) * x = b`, or `x * op(a) = b`, +// for the variable `x` given `a` and `b`, where `op(a)` is either +// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. +// +// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form +// square matrices. If `lower` is true (false), then the strictly upper +// (lower) triangular part of each innermost matrix in `a` is assumed to be +// zero and is not accessed. +// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a +// tensor of shape `[..., K, M]`. 
+// * `left_side` is a boolean, indicating whether to solve a system of the form +// op(a) * x = b (true) or x * op(a) = b (false). +// * `lower` is a boolean, indicating whether the argument `a` is +// lower-triangular (true) or upper-triangular (false). +// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be +// 1 and not accessed. +// * `transpose_a` indicates which function `op` we use to transform the tensor +// `a`: the identity function, transpose(a), or conjugate(transpose(a)) +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); + +// Computes the Cholesky decompositions of a batch of symmetric (Hermitian) +// positive definite matrices. +// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the +// two minor dimensions equal. +// If `lower` is true, the data from the lower triangle is used; if false, the +// upper triangle is used. The input data in the other triangle of the input +// does not affect the output. Returns the output in the same lower/uppper +// triangle. The data returned in the other output triangle is arbitrary and +// implementation-defined. +// +// The value returned if `a` is not Hermitian positive definite is +// implementation-defined. +XlaOp Cholesky(XlaOp a, bool lower); + // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1568,9 +1487,33 @@ XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); +// Overload to call And with 3 or more operands. We need the following somewhat +// convoluted overload set to disambiguate with the overload that takes the +// `broadcast_dimensions` optional param. 
+inline XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) { + return And(op1, And(op2, op3)); +} +template +XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3, + const XlaOpTs&... operands) { + return And(op1, And(op2, And(op3, operands...))); +} + XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); +// Overload to call Or with 3 or more operands. As with `And`, we need the +// following complicated overload set to handle the default arg in the `Or` +// overload above. +inline XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) { + return Or(op1, Or(op2, op3)); +} +template +XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3, + const XlaOpTs&... operands) { + return Or(op1, Or(op2, Or(op3, operands...))); +} + XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); @@ -1663,6 +1606,9 @@ XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); +// Enqueues an operation that returns the replica ID. +XlaOp ReplicaId(XlaBuilder* builder); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -1730,14 +1676,24 @@ XlaOp Real(const XlaOp& operand); // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); +// Enqueues a sqrt computation onto the computation. +XlaOp Sqrt(const XlaOp& operand); + +// Enqueues a rsqrt computation onto the computation. +XlaOp Rsqrt(const XlaOp& operand); + // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); -// Enqueues an operator that tests if the operand's values are finite, i.e., -// not Inf or NaN. Defined only for floating-point types. 
Returns an array of -// booleans with the same shape where entries are true iff the corresponding -// entry was NaN. +// Enqueues an operator that tests if the operand's values are finite, i.e., not +// +/-Inf or NaN. Returns an array of booleans with the same shape where +// entries are true iff the corresponding entry was not infinite or NaN. +// +// Defined only for real-valued (i.e. not complex) floating-point types; raises +// an error for other types. +// +// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h. XlaOp IsFinite(const XlaOp& operand); // Enqueues an iota operation onto the computation. @@ -1773,7 +1729,7 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // of keys, in ascending order. // * If the keys have higher rank, the keys are sorted along the provided // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension -// value of 0 will indepenently sort every column, and a dimension value of 1 +// value of 0 will independently sort every column, and a dimension value of 1 // will independently sort each row. If no dimension number is provided, then // the last dimension is chosen by default. // @@ -1783,9 +1739,39 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // * The result is a tuple that consists of a sorted tensor of keys (along the // provided dimension, as above) as the first element, and tensors with their // corresponding values as the other elements. +ABSL_DEPRECATED("Use form with comparator computation instead") XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); +// Enqueues a sort instruction onto the computation, using 'comparator' for +// comparisons. 'comparator' needs to define a strict weak order. 'is_stable' +// determines whether the stable sorting should be used. +// If only one operand is provided: +// * If the operand is a rank-1 tensor (an array), the result is a sorted array. 
+// The resulting sorting order has the property that for all index positions +// i, j with i < j, either +// comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or +// comparator(value[i], value[j]) = true. +// * If the operand has higher rank, the operand is sorted along the provided +// dimension. For example, for a rank-2 tensor (a matrix), a dimension value +// of 0 will independently sort every column, and a dimension value of 1 will +// independently sort each row. If no dimension number is provided, then the +// last dimension is chosen by default. For the dimension which is sorted, the +// same sorting order applies as in the rank-1 case. +// +// If more than one operand is provided: +// * All operands must be tensors with the same dimensions. The element types of +// the tensors may be different. +// * The result is a tuple that consists of the operands in sorted order (along +// the provided dimension, as above). The same permutation as implied by the +// comparison computation is applied to all operand tensors. When comparing +// two index positions, 'comparator' is called with 2 * n scalar parameters, +// where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at +// two index positions. +// Default comparator computations can be found in lib/comparators.h +XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64 dimension = -1, bool is_stable = false); + // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -1812,6 +1798,15 @@ XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaOp& false_operand, const XlaComputation& false_computation); +// Enqueues either a predicated (if/else) or indexed (switch/case/default) +// conditional node onto the computation. N >= 1 branch_computations and +// branch_operands are matched by index. branch_index selects the branch that +// will be executed. 
Out of range branch_index uses the N-1'th +// branch_computation as default. +XlaOp Conditional(const XlaOp& branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + // Enqueues a ReducePrecision node onto the computation. XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); @@ -1924,81 +1919,6 @@ XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); // Implementation details below this point. // -template -XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(LiteralUtil::CreateR0(value)); -} - -template -XlaOp XlaBuilder::ConstantR1(absl::Span values) { - return ConstantLiteral(LiteralUtil::CreateR1(values)); -} - -template -XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(literal); -} - -inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(LiteralUtil::CreateR1(values)); -} - -template -XlaOp XlaBuilder::ConstantR2( - std::initializer_list> values) { - return ConstantLiteral(LiteralUtil::CreateR2(values)); -} - -template -XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, - const Layout& layout) { - return ConstantLiteral( - LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(LiteralUtil::CreateFromArray(values)); -} - -template -XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return ConstantLiteral( - LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(LiteralUtil::CreateR2FromArray2D(values)); -} - -template -XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( - const Array3D& 
values, const Layout& layout) { - return ConstantLiteral( - LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D& values) { - return ConstantFromArray(values); -} - -template -XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); -} - -template -XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { - return ConstantFromArray(values); -} - // Free function template implementations. template diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index feee8187c7db846b04bf763c28476f0d71f3dd30..12656a89943d6cef213a714df5e29289e7b1ada3 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -39,7 +40,8 @@ using ::testing::HasSubstr; class XlaBuilderTest : public ::testing::Test { protected: StatusOr> BuildHloModule(XlaBuilder* b) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b->Build(/*remove_dynamic_dimensions=*/false)); const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( @@ -50,7 +52,8 @@ class XlaBuilderTest : public ::testing::Test { // Overload which explicitly specifies the root instruction. 
StatusOr> BuildHloModule(XlaBuilder* b, XlaOp root) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root)); + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b->Build(root, /*remove_dynamic_dimensions=*/false)); const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( @@ -132,6 +135,38 @@ TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { op::ShiftRightLogical(op::Constant(), op::Constant())); } +TEST_F(XlaBuilderTest, VariadicAnd) { + XlaBuilder b(TestName()); + Shape s = ShapeUtil::MakeShape(PRED, {}); + And(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), + Parameter(&b, 2, s, "p2")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + // Don't specify in the test whether And(x, y, z) is right- or + // left-associative; accept either one. + EXPECT_THAT( + module->entry_computation()->root_instruction(), + ::testing::AnyOf(op::And(op::Parameter(0), + op::And(op::Parameter(1), op::Parameter(2))), + op::And(op::And(op::Parameter(0), op::Parameter(1)), + op::Parameter(2)))); +} + +TEST_F(XlaBuilderTest, VariadicOr) { + XlaBuilder b(TestName()); + Shape s = ShapeUtil::MakeShape(PRED, {}); + Or(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), + Parameter(&b, 2, s, "p2")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + // Don't specify in the test whether Or(x, y, z) is right- or + // left-associative; accept either one. 
+ EXPECT_THAT( + module->entry_computation()->root_instruction(), + ::testing::AnyOf( + op::Or(op::Parameter(0), op::Or(op::Parameter(1), op::Parameter(2))), + op::Or(op::Or(op::Parameter(0), op::Parameter(1)), + op::Parameter(2)))); +} + TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { XlaBuilder b(TestName()); ConstantR0(&b, 1) >> ConstantR0(&b, 2); @@ -447,10 +482,9 @@ TEST_F(XlaBuilderTest, ProtoMatches) { } TEST_F(XlaBuilderTest, DynamicParameter) { - std::vector computations; - XlaBuilder b("builder"); + XlaBuilder b(TestName()); Shape tuple_param_shape = ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6})}); + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6}, {true})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); Parameter(&b, 1, ShapeUtil::MakeShape(U32, {}), "p1"); ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/1, @@ -463,9 +497,450 @@ TEST_F(XlaBuilderTest, DynamicParameter) { ->parameter_instruction(0) ->shape() .tuple_shapes(1); - // TODO(b/121223198): The dynamic dimension should be set once we enable - // dynamic dimensions in xla builder. 
- EXPECT_FALSE(param_shape.is_dynamic_dimension(0)); + EXPECT_TRUE(param_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicUnary) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + Neg(gte); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicBinary) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}, {true}), + ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}), + ShapeUtil::MakeShape(F32, {5}, 
{true}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1, {0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + BroadcastInDim(gte, /*out_dim_size=*/{3, 5, 4}, + /*broadcast_dimensions=*/{1, 2}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {10}, {true}), + ShapeUtil::MakeShape(F32, {1, 15}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + 
ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1, /*broadcast_dimensions=*/{0}); // f32[<=10, 15] + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {10}, {true}), + ShapeUtil::MakeShape(F32, {10}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + + Select(gte0, gte1, gte1); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicPad) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pad_val = ConstantR0(&b, -1); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + PaddingConfig padding_config; + for (int i = 0; i < 
2; i++) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(0); + dimension->set_interior_padding(0); + } + Pad(gte, pad_val, padding_config); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicConvolution) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}, {true, false, false, false}), + ShapeUtil::MakeShape(F32, {2, 2, 128, 8}, {false, false, true, false}), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/2)); + auto input = GetTupleElement(p0, 0); + auto filter = GetTupleElement(p0, 1); + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/1); + TF_ASSERT_OK_AND_ASSIGN(auto 
module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), + {true, false, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicDot) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 3, 4}, {true, true, false}), + ShapeUtil::MakeShape(F32, {2, 4, 5}, {true, false, false}), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + + auto lhs = GetTupleElement(p0, 0); + auto rhs = GetTupleElement(p0, 1); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + DotGeneral(lhs, rhs, dnums); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReduce) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4, 3}, {false, true, false}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto 
init = ConstantR0(&b, 0); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + auto gte = GetTupleElement(p0, 0); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + Reduce(gte, init, sum, {0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReduceWindow) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0.f); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4}, + /*window_strides=*/{1, 1, 1}, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectAndScatter) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + 
{ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}), + ShapeUtil::MakeShape(F32, {2, 2, 2}, {true, false, false}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0.f); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + XlaBuilder bge(TestName()); + Ge(Parameter(&bge, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bge, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto ge, bge.Build()); + + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto source = GetTupleElement(p0, 1); + SelectAndScatter(gte0, ge, {1, 2, 4}, {1, 2, 4}, Padding::kValid, source, + init, sum); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReshape) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}, + {false, false, true, true, false}), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/2)); + 
ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/3)); + auto gte = GetTupleElement(p0, 0); // f32[2, 3, <=4, <=5, 6] + Reshape(gte, /*new_sizes=*/{6, 4, 1, 5, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); + EXPECT_TRUE(result_shape.is_dynamic_dimension(3)); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), + {false, true, false, true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelect) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}), + ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/1)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Select(pred, gte0, gte1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); + EXPECT_FALSE(result_shape.is_dynamic_dimension(2)); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, 
DynamicSelectNotCompatible) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}), + ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, false, true}), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/2)); + auto gte0 = GetTupleElement(p0, 0); // f32[4,<=5,6] + auto gte1 = GetTupleElement(p0, 1); // f32[4,5,<=6] + Select(pred, gte0, gte1); + Status status = BuildHloModule(&b).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Operands to select must be the same shape; " + "got f32[4,<=5,6] and f32[4,5,<=6]")); +} + +TEST_F(XlaBuilderTest, DynamicTranspose) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 5}, {true, false}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + Transpose(gte, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {false, true})) + << result_shape; } TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { 
diff --git a/tensorflow/compiler/xla/comparison_util.cc b/tensorflow/compiler/xla/comparison_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..de34ad678e799dad8f8404ac92e5f4830d85368a --- /dev/null +++ b/tensorflow/compiler/xla/comparison_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/comparison_util.h" +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +std::string ComparisonDirectionToString(ComparisonDirection direction) { + switch (direction) { + case ComparisonDirection::kEq: + return "EQ"; + case ComparisonDirection::kNe: + return "NE"; + case ComparisonDirection::kGe: + return "GE"; + case ComparisonDirection::kGt: + return "GT"; + case ComparisonDirection::kLe: + return "LE"; + case ComparisonDirection::kLt: + return "LT"; + } +} + +StatusOr StringToComparisonDirection( + absl::string_view direction_name) { + static auto* direction_map = + new absl::flat_hash_map({ + {"EQ", ComparisonDirection::kEq}, + {"NE", ComparisonDirection::kNe}, + {"GE", ComparisonDirection::kGe}, + {"GT", ComparisonDirection::kGt}, + {"LE", ComparisonDirection::kLe}, + {"LT", ComparisonDirection::kLt}, + }); + auto it = direction_map->find(direction_name); + if (it == direction_map->end()) { + return 
InvalidArgument("Unknown comparison direction: %s", direction_name); + } + return it->second; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/comparison_util.h b/tensorflow/compiler/xla/comparison_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8b150c3cfadfa4ad22459c7f1f7514c73f393ded --- /dev/null +++ b/tensorflow/compiler/xla/comparison_util.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ + +#include "absl/base/macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// Represents different comparison operations. 
+enum class ComparisonDirection : uint8 { + kEq, + kNe, + kGe, + kGt, + kLe, + kLt, +}; + +string ComparisonDirectionToString(ComparisonDirection direction); + +StatusOr StringToComparisonDirection( + absl::string_view direction_name); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index a9a91648ac377987e7f226116e11c9c697ace103..2437bf04b0f931c577c5d122df8fa99b2d49c08c 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -33,7 +33,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); opts.set_xla_eliminate_hlo_implicit_broadcast(true); - opts.set_xla_hlo_dump_as_html(false); + opts.set_xla_dump_hlo_as_html(false); #ifdef INTEL_MKL opts.set_xla_cpu_use_mkl_dnn(true); #endif // INTEL_MKL @@ -84,6 +84,14 @@ static void AllocateFlags() { }; }; + auto string_setter_for = + [](void (DebugOptions::*member_setter)(const string& value)) { + return [member_setter](const string& value) { + (flag_values->*member_setter)(value); + return true; + }; + }; + // Custom "sub-parser" lambda for xla_disable_hlo_passes. 
auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { std::vector disabled_passes = @@ -114,56 +122,26 @@ static void AllocateFlags() { }; flag_objects = new std::vector({ - tensorflow::Flag( - "xla_generate_hlo_graph", - flag_values->mutable_xla_generate_hlo_graph(), - "HLO modules matching this regex will be dumped to a .dot file " - "throughout various stages in compilation."), - tensorflow::Flag( - "xla_hlo_graph_addresses", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), - flag_values->xla_hlo_graph_addresses(), - "With xla_generate_hlo_graph, show addresses of HLO ops in " - "graph dump."), - tensorflow::Flag( - "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), - "With xla_generate_hlo_graph, dump the graphs into this path."), - tensorflow::Flag( - "xla_hlo_dump_as_graphdef", - bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), - flag_values->xla_hlo_dump_as_graphdef(), - "Dump HLO graphs as TensorFlow GraphDefs."), - tensorflow::Flag("xla_hlo_dump_as_html", - bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_html), - flag_values->xla_hlo_dump_as_html(), - "Dump HLO graphs as an HTML (DOT rendered into SVG " - "inlined in HTML)."), - tensorflow::Flag( - "xla_hlo_graph_sharding_color", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), - flag_values->xla_hlo_graph_sharding_color(), - "Assign colors based on sharding assignments when generating the " - "HLO graphs."), - tensorflow::Flag( - "xla_hlo_tfgraph_device_scopes", - bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), - flag_values->xla_hlo_tfgraph_device_scopes(), - "When generating TensorFlow HLO graphs, if the HLO instructions " - "are assigned to a specific device, prefix the name scope with " - "\"devX\" with X being the device ordinal."), - tensorflow::Flag( - "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), - "HLO modules matching this regex will be dumped to LOG(INFO)."), - 
tensorflow::Flag( - "xla_generate_hlo_text_to", - flag_values->mutable_xla_generate_hlo_text_to(), - "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( "xla_cpu_enable_fast_math", bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), flag_values->xla_cpu_enable_fast_math(), "Enable unsafe fast-math optimizations in the CPU compiler; " "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_cpu_fast_math_honor_nans", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans), + flag_values->xla_cpu_fast_math_honor_nans(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "allow operations to produce NaNs. Ignored when " + "xla_cpu_enable_fast_math is false."), + tensorflow::Flag( + "xla_cpu_fast_math_honor_infs", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs), + flag_values->xla_cpu_fast_math_honor_infs(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "allow operations to produce infinites. 
Ignored when " + "xla_cpu_enable_fast_math is false."), tensorflow::Flag( "xla_gpu_enable_fast_min_max", bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), @@ -222,9 +200,6 @@ static void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), flag_values->xla_embed_ir_in_executable(), "Embed the compiler IR as a string in the executable."), - tensorflow::Flag( - "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(), - "Dump the compiler IR into this directory as individual files."), tensorflow::Flag( "xla_eliminate_hlo_implicit_broadcast", bool_setter_for( @@ -259,20 +234,6 @@ static void AllocateFlags() { int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), flag_values->xla_gpu_max_kernel_unroll_factor(), "Specify the maximum kernel unroll factor for the GPU backend."), - tensorflow::Flag( - "xla_dump_optimized_hlo_proto_to", - flag_values->mutable_xla_dump_optimized_hlo_proto_to(), - "Dump Hlo after all hlo passes are executed as proto binary into " - "this directory."), - tensorflow::Flag( - "xla_dump_unoptimized_hlo_proto_to", - flag_values->mutable_xla_dump_unoptimized_hlo_proto_to(), - "Dump HLO before any hlo passes are executed as proto binary into " - "this directory."), - tensorflow::Flag("xla_dump_per_pass_hlo_proto_to", - flag_values->mutable_xla_dump_per_pass_hlo_proto_to(), - "Dump HLO after each pass as an HloProto in binary file " - "format into this directory."), tensorflow::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), @@ -295,14 +256,6 @@ static void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_hlo_profile), flag_values->xla_hlo_profile(), "Instrument the computation to collect per-HLO cycle counts"), - tensorflow::Flag("xla_dump_computations_to", - flag_values->mutable_xla_dump_computations_to(), - "Dump computations that XLA executes into the provided " - "directory path"), - tensorflow::Flag("xla_dump_executions_to", 
- flag_values->mutable_xla_dump_executions_to(), - "Dump parameters and results of computations that XLA " - "executes into the provided directory path"), tensorflow::Flag("xla_backend_extra_options", setter_for_xla_backend_extra_options, "", "Extra options to pass to a backend; " @@ -355,6 +308,79 @@ static void AllocateFlags() { &DebugOptions::set_xla_gpu_disable_ptxas_optimizations), flag_values->xla_gpu_disable_ptxas_optimizations(), "In XLA:GPU run ptxas in -O0 (default is -O3)."), + + tensorflow::Flag( + "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to), + flag_values->xla_dump_to(), + "Directory into which debugging data is written. If not specified " + "but another dumping flag is passed, data will be written to stdout. " + " To explicitly write to stdout, set this to \"-\". The values " + "\"sponge\" and \"test_undeclared_outputs_dir\" have a special " + "meaning: They cause us to dump into the directory specified by the " + "environment variable TEST_UNDECLARED_OUTPUTS_DIR."), + tensorflow::Flag( + "xla_dump_hlo_as_text", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text), + flag_values->xla_dump_hlo_as_text(), + "Dumps HLO modules as text before and after optimizations. 
Results " + "are written to the --xla_dump_to dir, or, if no dir is specified, " + "to stdout."), + tensorflow::Flag( + "xla_dump_hlo_as_proto", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto), + flag_values->xla_dump_hlo_as_proto(), + "Dumps HLO modules as HloProtos to the directory specified by " + "--xla_dump_to."), + tensorflow::Flag( + "xla_dump_hlo_as_dot", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot), + flag_values->xla_dump_hlo_as_dot(), + "Dumps HLO modules rendered as dot files to the directory " + "specified by --xla_dump_to."), + tensorflow::Flag("xla_dump_hlo_as_html", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html), + flag_values->xla_dump_hlo_as_html(), + "Dumps HLO modules rendered as HTML files to the " + "directory specified by --xla_dump_to."), + tensorflow::Flag( + "xla_dump_hlo_as_url", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url), + flag_values->xla_dump_hlo_as_url(), + "Tries to dump HLO modules rendered as URLs to stdout (and also to " + "the directory specified by --xla_dump_to). This is not implemented " + "by default; you need to add a plugin which calls " + "RegisterGraphToURLRenderer()."), + tensorflow::Flag( + "xla_dump_hlo_snapshots", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots), + flag_values->xla_dump_hlo_snapshots(), + "Every time an HLO module is run, dumps an HloSnapshot to the " + "directory specified by --xla_dump_to."), + tensorflow::Flag( + "xla_dump_hlo_module_re", + string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re), + flag_values->xla_dump_hlo_module_re(), + "Limits dumping only to modules which match this regular expression. 
" + " Default is to dump all modules."), + tensorflow::Flag( + "xla_dump_hlo_pass_re", + string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re), + flag_values->xla_dump_hlo_pass_re(), + "If specified, dumps HLO before and after optimization passes which " + "match this regular expression, in addition to dumping at the very " + "beginning and end of compilation."), + tensorflow::Flag( + "xla_hlo_graph_addresses", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), + flag_values->xla_hlo_graph_addresses(), + "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays " + "the address in memory of each HloInstruction object."), + tensorflow::Flag( + "xla_hlo_graph_sharding_color", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), + flag_values->xla_hlo_graph_sharding_color(), + "Assign colors based on sharding assignments when generating the " + "HLO graphs."), }); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h index a1463aa15941b9c265db94e2eb3cc176fab6695b..4359f3b7deb8e585494cb2a9c7115eac6a312c8e 100644 --- a/tensorflow/compiler/xla/error_spec.h +++ b/tensorflow/compiler/xla/error_spec.h @@ -30,6 +30,19 @@ struct ErrorSpec { // In effect, this allows the tested operation to produce incorrect results // for inputs outside its mathematical domain. bool relaxed_nans; + + // If this is true, then we treat each +/-inf in the actual result as + // equivalent to our choice of either +/-inf or the min/max floating-point + // value. + // + // If the expected result is +/-inf, the actual result must still be +/-inf. + // + // In effect, this allows the tested operation to overflow, so long as it's + // overflowing on "large" values. + // + // (We could have a symmetric more_infs_ok flag if necessary; right now it + // appears not to be.) 
+ bool fewer_infs_ok = false; }; } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 0f9b591c70d4fd96147958d18bd5fb7dd78a7f3f..230f3b202a4b531c381665471c3856c3feba5a3a 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -77,7 +77,7 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const { } ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( - DeviceAssignment* device_assignment) { + const DeviceAssignment* device_assignment) { device_assignment_ = device_assignment; return *this; } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index ba3217f31b55bd1428f67da6154a46c8bc304053..1e744953bd3be58afba5b81c0e2a8ba26665f9c4 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,9 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ -// Pulls in the ::stream_executor -> ::xla::se namespace alias. 
-#include "tensorflow/compiler/xla/types.h" - // These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't @@ -28,12 +25,6 @@ class Stream; class Platform; } // namespace stream_executor -namespace tensorflow { -namespace thread { -class ThreadPool; -} // namespace thread -} // namespace tensorflow - namespace Eigen { struct ThreadPoolDevice; } // namespace Eigen @@ -83,7 +74,7 @@ class ExecutableRunOptions { ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); ExecutableRunOptions& set_device_assignment( - DeviceAssignment* device_assignment); + const DeviceAssignment* device_assignment); const DeviceAssignment* device_assignment() const; ExecutableRunOptions& set_rng_seed(int rng_seed); @@ -92,7 +83,7 @@ class ExecutableRunOptions { private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; - DeviceAssignment* device_assignment_ = nullptr; + const DeviceAssignment* device_assignment_ = nullptr; stream_executor::Stream* stream_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index 267701e9c0e42a21d2cda6238520f6a9692e7e76..d756cd74c98b98a6fda099690d966562bd694e2c 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -25,6 +25,8 @@ upper_tabs: path: /xla/operation_semantics - title: Shapes and layout path: /xla/shapes + - title: Tiled layout + path: /xla/tiled_layout - title: Using AOT compilation path: /xla/tfcompile - heading: Tutorials diff --git a/tensorflow/compiler/xla/g3doc/_project.yaml b/tensorflow/compiler/xla/g3doc/_project.yaml index 33d8bdb27a664d9e282d1d65c007ebf5838b196a..1cacee703dca30f9c4af6a4964839bb9fa4b0140 100644 
--- a/tensorflow/compiler/xla/g3doc/_project.yaml +++ b/tensorflow/compiler/xla/g3doc/_project.yaml @@ -8,3 +8,4 @@ use_site_branding: true hide_from_products_list: true content_license: cc3-apache2 buganizer_id: 171704 +include: /_project_included.yaml diff --git a/tensorflow/compiler/xla/g3doc/jit.md b/tensorflow/compiler/xla/g3doc/jit.md index 85fa16ccc7f48a3dce840564e79097c9e136767f..d7ce5ee1ba6abbe60507f99a82ef2839a9f8a213 100644 --- a/tensorflow/compiler/xla/g3doc/jit.md +++ b/tensorflow/compiler/xla/g3doc/jit.md @@ -144,7 +144,8 @@ Execute the python script to train the model with XLA and turn on a debugging feature of XLA via an environmental variable that outputs the XLA graph. ```shell -XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py +XLA_FLAGS="--xla_hlo_profile --xla_dump_to=/tmp/foo --xla_dump_hlo_as_text" \ +python mnist_softmax_xla.py ``` Open the timeline file created (`timeline.ctf.json`). The rendered timeline @@ -153,28 +154,10 @@ should look similar to the picture below with one long bar labeled `XlaLaunch`. -To understand what is happening in `XlaLaunch`, look at the console output for -statements similar to the following: +To understand what is happening in `XlaLaunch`, look at the console output. Each +XLA cluster that's launched will have a corresponding profile (from +`--xla_hlo_profile`) showing how long each HLO took to run. -```shell -computation cluster_0[_XlaCompiledKernel=true,_XlaNumConstantArgs=1].v82 [CPU: -pipeline start, before inline]: /tmp/hlo_graph_0.dot - -``` - -The console statements point to the location of `hlo_graph_xx.dot` files that -contain information about the graph created by XLA. The process that XLA takes -to fuse Ops is visible by starting at `hlo_graph_0.dot` and viewing each diagram -in succession.
- -To Render the .dot file into a png, install -[GraphViz](https://www.graphviz.org/download/) and run: - -```shell -dot -Tpng hlo_graph_80.dot -o hlo_graph_80.png -``` - -The result will look like the following: -
- -
+`/tmp/foo` will contain the HLO before and after optimizations for each HLO +module that's run. You can read this as-is, or you can visualize it using +`tensorflow/compiler/xla/tools:interactive_graphviz`. diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index c5f9377f98868cdf6d5c711cf80ede5d41fd8305..7d718c5301018980522e0d09835da0c6c40239c5 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -322,6 +322,37 @@ Invokes a computation with the given arguments. The arity and types of the `args` must match the parameters of the `computation`. It is allowed to have no `args`. +## Cholesky + +See also +[`XlaBuilder::Cholesky`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Computes the +[Cholesky decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition) +of a batch of symmetric (Hermitian) positive definite matrices. + + `Cholesky(a, lower)` + +Arguments | Type | Semantics +--------- | ------- | ----------------------------------------------------- +`a` | `XlaOp` | a rank >= 2 array of a complex or floating-point type. +`lower` | `bool` | whether to use the upper or lower triangle of `a`. + +If `lower` is `true`, computes lower-triangular matrices `l` such that $$ a = l +. l^T $$. If `lower` is `false`, computes upper-triangular matrices `u` such +that $$ a = u^T . u $$. + +Input data is read only from the lower/upper triangle of `a`, depending on the +value of `lower`. Values from the other triangle are ignored. Output data is +returned in the same triangle; the values in the other triangle are +implementation-defined and may be anything. + +If the rank of `a` is greater than 2, `a` is treated as a batch of matrices, +where all except the minor 2 dimensions are batch dimensions. + +If `a` is not symmetric (Hermitian) positive definite, the result is +implementation-defined. 
+ ## Clamp See also @@ -510,25 +541,49 @@ See also false_computation)` Arguments | Type | Semantics -------------------- | ---------------- | --------------------------------- +------------------- | ---------------- | -------------------------------------- `pred` | `XlaOp` | Scalar of type `PRED` -`true_operand` | `XlaOp` | Argument of type `T_0` -`true_computation` | `XlaComputation` | XlaComputation of type `T_0 -> S` -`false_operand` | `XlaOp` | Argument of type `T_1` -`false_computation` | `XlaComputation` | XlaComputation of type `T_1 -> S` +`true_operand` | `XlaOp` | Argument of type $$ T_0 $$ +`true_computation` | `XlaComputation` | XlaComputation of type $$ T_0 \to S$$ +`false_operand` | `XlaOp` | Argument of type $$ T_1 $$ +`false_computation` | `XlaComputation` | XlaComputation of type $$ T_1 \to S $$ Executes `true_computation` if `pred` is `true`, `false_computation` if `pred` is `false`, and returns the result. -The `true_computation` must take in a single argument of type `T_0` and will be -invoked with `true_operand` which must be of the same type. The -`false_computation` must take in a single argument of type `T_1` and will be +The `true_computation` must take in a single argument of type $$ T_0 $$ and will +be invoked with `true_operand` which must be of the same type. The +`false_computation` must take in a single argument of type $$ T_1 $$ and will be invoked with `false_operand` which must be of the same type. The type of the returned value of `true_computation` and `false_computation` must be the same. Note that only one of `true_computation` and `false_computation` will be executed depending on the value of `pred`. 
+ `Conditional(branch_index, branch_computations, branch_operands)` + +| Arguments | Type | Semantics | +| --------------------- | --------------------- | ---------------------------- | +| `branch_index` | `XlaOp` | Scalar of type `PRED` or | +: : : `S32` : +| `branch_computations` | sequence of N | XlaComputations of type $$ | +: : `XlaComputation` : T_0 \to S , T_1 \to S , ..., : +: : : T_{N-1} \to S $$ : +| `branch_operands` | sequence of N `XlaOp` | Arguments of type $$ T_0 , | +: : : T_1 , ..., T_{N-1} $$ : + +Executes `branch_computations[branch_index]`, and returns the result. If +`branch_index` is a `PRED`, then the `true` branch is in position 0 and the +`false` branch is in position 1. If `branch_index` is an `S32` which is < 0 +or >= N, then `branch_computations[N-1]` is executed as the default branch. + +Each `branch_computations[b]` must take in a single argument of type `T_b` and +will be invoked with `branch_operands[b]` which must be of the same type. The +type of the returned value of each `branch_computations[b]` must be the same. + +Note that only one of the `branch_computations` will be executed depending on +the value of `branch_index`. + ## Conv (convolution) See also @@ -1186,7 +1241,7 @@ if and only if the corresponding input element is finite. `Sign(operand)` Element-wise sign operation `x -> sgn(x)` where -$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$ +$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$ using the comparison operator of the element type of `operand`. @@ -1608,15 +1663,18 @@ Applies a reduction function to one or more arrays in parallel. `Reduce(operands..., init_values..., computation, dimensions)` -Arguments | Type | Semantics -------------- | --------------------- | --------------------------------------- -`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`. 
-`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`. -`computation` | `XlaComputation` | computation of type - : : `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)` -`dimensions` | `int64` array | unordered array of dimensions to reduce +| Arguments | Type | Semantics | +| ------------- | --------------------- | ------------------------------------ | +| `operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`. | +| `init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`. | +| `computation` | `XlaComputation` | computation of type `T_0, ..., T_N, | +: : : T_0, ..., T_N ->` `Collate(T_0, ..., : +: : : T_N)` : +| `dimensions` | `int64` array | unordered array of dimensions to | +: : : reduce : Where: + * N is required to be greater or equal to 1. * All input arrays must have the same dimensions. * If `N = 1`, `Collate(T)` is `T`. @@ -1626,10 +1684,10 @@ The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type `T_i`, the dimensions of which are described below. This operation reduces one or more dimensions of each input array into scalars. -The rank of each returned array is `rank(operand) - len(dimensions)`. -`init_value` is the initial value used for every reduction and may be inserted +The rank of each returned array is `rank(operand) - len(dimensions)`. The +initial value used for every reduction is `init_value`, and it may be inserted anywhere during computation by the back-end. In most cases, `init_value` is an -identity of the reduction function (for example, 0 for addition). The applied +identity of the reduction function (for example, `0` for addition). The applied `computation` is always passed the `init_value` on the left-hand side. The evaluation order of the reduction function is arbitrary and may be @@ -1640,10 +1698,10 @@ Some reduction functions like addition are not strictly associative for floats. 
However, if the range of the data is limited, floating-point addition is close enough to being associative for most practical uses. It is possible to conceive of some completely non-associative reductions, however, and these will produce -incorrect or unpredictable results in XLA reductions. +incorrect or unpredictable results in XLA. As an example, when reducing across one dimension in a single 1D array with -values [10, 11, 12, 13], with reduction function `f` (this is `computation`) +values `[10, 11, 12, 13]`, with reduction function `f` (this is `computation`) then that could be computed as `f(10, f(11, f(12, f(init_value, 13)))` @@ -1722,16 +1780,27 @@ preserved in the output, but some dimensions may get assigned new numbers (since the rank changes). We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces -the 1D array `| 20 28 36 |`. +the 1D array `[20, 28, 36]`. Reducing the 3D array over all its dimensions produces the scalar `84`. +### Variadic Reduce + When `N > 1`, reduce function application is slightly more complex, as it is -applied simultaneously to all inputs. For example, consider the following -reduction function, which can be used to compute the max and the argmax of a a -1-D array in parallel: +applied simultaneously to all inputs. The operands are supplied to the +computation in the following order: -``` +* Running reduced value for the first operand +* ... +* Running reduced value for the N'th operand +* Input value for the first operand +* ... 
+* Input value for the N'th operand + +For example, consider the following reduction function, which can be used to +compute the max and the argmax of a 1-D array in parallel: + +```python f: (Float, Int, Float, Int) -> Float, Int f(max, argmax, value, index): if value >= argmax: @@ -1743,6 +1812,7 @@ f(max, argmax, value, index): For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values `I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only input dimension is equivalent to the following recursive application: + ``` f_0 = f(I_V, I_K, V_0, K_0) f_1 = f(f_0.first, f_0.second, V_1, K_1) @@ -1873,6 +1943,20 @@ non-deterministic. Therefore, the reduction function should not be overly sensitive to reassociation. See the discussion about associativity in the context of [`Reduce`](#reduce) for more details. +## ReplicaId + +See also +[`XlaBuilder::ReplicaId`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Returns the unique ID (U32 scalar) of the replica. + + `ReplicaId()` + +The unique ID of each replica is an unsigned integer in the interval `[0, N)`, +where `N` is the number of replicas. Since all the replicas are running the same +program, a `ReplicaId()` call in the program will return a different value on +each replica. + ## Reshape See also @@ -2425,6 +2509,46 @@ Permutes the operand dimensions with the given permutation, so This is the same as Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions)). +## TriangularSolve + +See also +[`XlaBuilder::TriangularSolve`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Solves systems of linear equations with lower or upper triangular coefficient +matrices by forward- or back-substitution. 
Broadcasting along leading +dimensions, this routine solves one of the matrix systems `op(a) * x = +b`, or `x * op(a) = b`, for the variable `x`, given `a` and `b`, where `op(a)` is +either `op(a) = a`, or `op(a) = Transpose(a)`, or `op(a) = Conj(Transpose(a))`. + + `TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)` + +| Arguments | Type | Semantics | +| --------------- | ----------- | -------------------------------------------- | +| `a` | `XlaOp` | a rank >= 2 array of a complex or | +: : : floating-point type with shape `[..., M, : +: : : M]`. : +| `b` | `XlaOp` | a rank >= 2 array of the same type with shape | +: : : `[..., M, K]` if `left_side` is true, `[..., : +: : : K, M]` otherwise. : +| `left_side` | `bool` | indicates whether to solve a system of the | +: : : form `op(a) * x = b` (`true`) or `x * : +: : : op(a) = b` (`false`). : +| `lower` | `bool` | whether to use the upper or lower triangle | +: : : of `a`. : +| `unit_diagonal` | `bool` | if `true`, the diagonal elements of `a` are | +: : : assumed to be `1` and not accessed. : +| `transpose_a` | `Transpose` | whether to use `a` as is, transpose it or | +: : : take its conjugate transpose. : + +Input data is read only from the lower/upper triangle of `a`, depending on the +value of `lower`. Values from the other triangle are ignored. Output data is +returned in the same triangle; the values in the other triangle are +implementation-defined and may be anything. + +If the ranks of `a` and `b` are greater than 2, they are treated as batches of +matrices, where all except the minor 2 dimensions are batch dimensions. `a` and +`b` must have equal batch dimensions. 
+ ## Tuple See also diff --git a/tensorflow/compiler/xla/g3doc/layout_with_tiling.md b/tensorflow/compiler/xla/g3doc/tiled_layout.md similarity index 96% rename from tensorflow/compiler/xla/g3doc/layout_with_tiling.md rename to tensorflow/compiler/xla/g3doc/tiled_layout.md index 5e990851af7495ebd4417e44f1d955fcc14dadf1..21e88ceab6208cdf940826d769fd93713044d5a0 100644 --- a/tensorflow/compiler/xla/g3doc/layout_with_tiling.md +++ b/tensorflow/compiler/xla/g3doc/tiled_layout.md @@ -1,9 +1,7 @@ # Tiled layout -*Note: This doc describes how tiled layout is intended to work. Tiling is being -implemented, but this is an early effort and it is currently not even guaranteed -to get an Unimplemented error if one tries to use tiling - it may be just -silently ignored.* +Caution: Tiled layout is *pre-release* and this describes how it's intended to +work. Errors may be silently ignored.
![](images/xla_array_layout_figure1.png) diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 7e22a32e545e4155545ffcfb9582187eadec3a82..eebd8245abe759b71b3fe732943761325ea04b81 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc index e3b5fcd5274881cec31ecf906e3461685f82a1f4..000c4fdc40519214fa9fa721a8987b77b534442b 100644 --- a/tensorflow/compiler/xla/layout.cc +++ b/tensorflow/compiler/xla/layout.cc @@ -30,7 +30,19 @@ TileProto Tile::ToProto() const { } string Tile::ToString() const { - return absl::StrCat("(", absl::StrJoin(dimensions(), ","), ")"); + std::vector elements; + for (auto dim : dimensions()) { + if (dim >= 0) { + elements.push_back(std::to_string(dim)); + } else { + if (dim == kCombineDimension) { + elements.push_back("*"); + } else { + elements.push_back(absl::StrCat("Invalid value ", dim)); + } + } + } + return absl::StrCat("(", absl::StrJoin(elements, ","), ")"); } /* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) { @@ -64,23 +76,43 @@ LayoutProto Layout::ToProto() const { } string Layout::ToString() const { - // TODO(b/119839262): Emit tiles in string. if (format() == SPARSE) { + CHECK_EQ(tiles_size(), 0) << "Sparse layout should not be tiled."; return absl::StrCat("sparse{", max_sparse_elements(), "}"); } else if (format() == DENSE) { - return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), "}"); + string colon_string = tiles().empty() ? 
"" : "T"; + for (Tile tile : tiles()) { + absl::StrAppend(&colon_string, tile.ToString()); + } + if (element_size_in_bits() != 0) { + absl::StrAppend(&colon_string, "E(", element_size_in_bits(), ")"); + } + return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), + colon_string.empty() ? "" : ":", colon_string, "}"); } else { CHECK_EQ(format(), INVALID_FORMAT); return "invalid{}"; } } +bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { + if (lhs.format() != rhs.format() || + lhs.minor_to_major() != rhs.minor_to_major() || + lhs.max_sparse_elements() != rhs.max_sparse_elements()) { + return false; + } + if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) { + return false; + } + if (!ignore_element_size_ && + lhs.element_size_in_bits() != rhs.element_size_in_bits()) { + return false; + } + return true; +} + bool Layout::operator==(const Layout& other) const { - return (other.format() == format() && - other.minor_to_major() == minor_to_major() && - other.element_size_in_bits() == element_size_in_bits() && - other.max_sparse_elements() == max_sparse_elements() && - other.tiles() == tiles()); + return Equal()(*this, other); } std::ostream& operator<<(std::ostream& out, const Tile& tile) { diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index 313368c39e4c976fc481941eb17325101f2ba69a..acc449b781b503142b24ed7229e3559230bb1599 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -55,6 +55,20 @@ class Tile { // Returns the dimensions of the tile. const std::vector& dimensions() const { return dimensions_; } + Tile& add_dimensions(int64 value) { + dimensions_.push_back(value); + return *this; + } + + Tile& clear_dimensions() { + dimensions_.clear(); + return *this; + } + + // This dimension size means the corresponding dimension in the shape is + // combined with the next minor dimension before tiling is applied. 
+ static constexpr int64 kCombineDimension = std::numeric_limits::min(); + private: // The bounds of the tile. std::vector dimensions_; @@ -71,10 +85,12 @@ class Layout { // Constructs a dense tiled layout with the given minor-to-major order and // tiles. - Layout(absl::Span minor_to_major, absl::Span tiles) + Layout(absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits = 0) : format_(DENSE), minor_to_major_(minor_to_major.begin(), minor_to_major.end()), - tiles_(tiles.begin(), tiles.end()) {} + tiles_(tiles.begin(), tiles.end()), + element_size_in_bits_(element_size_in_bits) {} // Construct a shape from a LayoutProto. static Layout CreateFromProto(const LayoutProto& proto); @@ -85,6 +101,37 @@ class Layout { // Returns a human-readable string that represents this layout. string ToString() const; + // Equal is a configurable functor to check the equality of two layouts. + // + // Examples: + // + // - Comparing two layouts ignoring their difference in tiles: + // Equal().IgnoreTiles()(layout1, layout2); + // + // - Comparing two layouts ignoring their difference in tiles and element + // size: + // Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2); + class Equal { + public: + Equal() = default; + + bool operator()(const Layout& lhs, const Layout& rhs); + + Equal& IgnoreTiles() { + ignore_tiles_ = true; + return *this; + } + + Equal& IgnoreElementSize() { + ignore_element_size_ = true; + return *this; + } + + private: + bool ignore_tiles_ = false; + bool ignore_element_size_ = false; + }; + bool operator==(const Layout& other) const; bool operator!=(const Layout& other) const { return !(*this == other); } @@ -159,7 +206,7 @@ class Layout { element_size_in_bits_ = 0; } - public: + private: // The format of this layout. Format format_ = INVALID_FORMAT; @@ -172,11 +219,11 @@ class Layout { // memory. This field must be zero unless the format is SPARSE. 
int64 max_sparse_elements_ = 0; - // The number of bits used to store an individual array element. - int64 element_size_in_bits_ = 0; - // The tiles used in tiling-based layout. std::vector tiles_; + + // The number of bits used to store an individual array element. + int64 element_size_in_bits_ = 0; }; std::ostream& operator<<(std::ostream& out, const Tile& Tile); diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc index fb6abd3f6523b978e72b21ec082ae06973e86243..f5d71c553ed2e0cfd5d5945144dd476557582b5f 100644 --- a/tensorflow/compiler/xla/layout_test.cc +++ b/tensorflow/compiler/xla/layout_test.cc @@ -38,10 +38,13 @@ TEST_F(LayoutTest, ToString) { "sparse{123}"); EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(), - "{3,2,1,0}"); + "{3,2,1,0:T(42,123)(4,5)}"); EXPECT_EQ( Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(), - "{1,0}"); + "{1,0:T(2,55)E(42)}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({-2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0:T(Invalid value -2,55)E(42)}"); } TEST_F(LayoutTest, StreamOut) { @@ -84,6 +87,15 @@ TEST_F(LayoutTest, Equality) { Layout().set_format(SPARSE).set_max_sparse_elements(42)); EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), Layout().set_format(SPARSE).set_max_sparse_elements(24)); + + EXPECT_FALSE( + Layout::Equal()(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2}))); + EXPECT_TRUE(Layout::Equal().IgnoreTiles()(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}))); + EXPECT_FALSE( + Layout::Equal()(Layout({0, 1, 2}, {}, 32), Layout({0, 1, 2}, {}, 1))); + EXPECT_TRUE(Layout::Equal().IgnoreElementSize()(Layout({0, 1, 2}, {}, 32), + Layout({0, 1, 2}, {}, 1))); } TEST_F(LayoutTest, LayoutToFromProto) { diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 
2fe9b56c6bdffb931726f60ab75081361b43ebb4..62314118ca9713a04cb4e3cf6ad261b966d85f15 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -54,12 +54,24 @@ void SetDefaultLayoutToContainer(std::vector* minor_to_major) { } // namespace /* static */ Layout LayoutUtil::MakeLayout( - absl::Span minor_to_major) { + absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits) { Layout layout; layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { layout.add_minor_to_major(dimension_number); } + for (Tile tile : tiles) { + for (int64 dim : tile.dimensions()) { + if (dim < 0 && dim != Tile::kCombineDimension) { + LOG(FATAL) << "Tile dimension size needs to be mininum int64 value if " + "it's negative. Value is " + << dim; + } + } + *layout.add_tiles() = tile; + } + layout.set_element_size_in_bits(element_size_in_bits); return layout; } @@ -235,6 +247,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } dimensions_in_layout[dim] = true; } + } else { + if (layout.tiles_size() != 0) { + return InvalidArgument("Only dense layouts can be tiled."); + } } return Status::OK(); diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 609dba67bcdbcb11be0906b7d87a52a17ba0dfbd..9997aef465daa48ee77050e03d97cde0ea2425cc 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -36,7 +36,9 @@ class LayoutUtil { public: // Creates a layout with the given minor-to-major dimension order. (This is a // convenience function for protobuf construction.) - static Layout MakeLayout(absl::Span minor_to_major); + static Layout MakeLayout(absl::Span minor_to_major, + absl::Span tiles = {}, + int64 element_size_in_bits = 0); // Similar to MakeLayout, but take indices in reverse order. 
static Layout MakeLayoutFromMajorToMinor( diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 4cc94c270cd64eb19761cc1044861c7d185b7888..12da214063676717aa075e66aa54974f4cc2b31b 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -317,6 +317,81 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } +TEST_F(LayoutUtilTest, HumanStringWithTiling) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 4}, {0, 1, 2}); + Tile* tile; + + // No tiling. + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), "f32[2,3,4]{0,1,2}"); + + // 2D tile. + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(512); + tile->add_dimensions(1024); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "f32[2,3,4]{0,1,2:T(512,1024)}"); + + // 1D tile. + shape.mutable_layout()->clear_tiles(); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(512); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "f32[2,3,4]{0,1,2:T(512)}"); + + // 2 tiles. + shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 4}, {1, 2, 0}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(16); + tile->add_dimensions(256); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(2); + tile->add_dimensions(1); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "bf16[2,3,4]{1,2,0:T(16,256)(2,1)}"); + + // PRED with element size of 8 bits. + shape = ShapeUtil::MakeShapeWithLayout(PRED, {8, 8, 8}, {0, 2, 1}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(8); + tile->add_dimensions(128); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "pred[8,8,8]{0,2,1:T(8,128)}"); + + // PRED with element size of 32 bits. 
+ shape.mutable_layout()->clear_tiles(); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(8); + tile->add_dimensions(128); + shape.mutable_layout()->set_element_size_in_bits(32); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "pred[8,8,8]{0,2,1:T(8,128)E(32)}"); + + // No tile. PRED with element size of 32 bits. + shape.mutable_layout()->clear_tiles(); + shape.mutable_layout()->set_element_size_in_bits(32); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "pred[8,8,8]{0,2,1:E(32)}"); + + // Tile with negative dimension size for combining dimensions. + shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 1004}, {2, 1, 0}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(2); + tile->add_dimensions(Tile::kCombineDimension); + tile->add_dimensions(128); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "bf16[2,3,1004]{2,1,0:T(2,*,128)}"); + + // Tile with two negative dimensions. + shape = ShapeUtil::MakeShapeWithLayout(BF16, {8, 2, 3, 1004}, {3, 2, 1, 0}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(2); + tile->add_dimensions(Tile::kCombineDimension); + tile->add_dimensions(Tile::kCombineDimension); + tile->add_dimensions(128); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "bf16[8,2,3,1004]{3,2,1,0:T(2,*,*,128)}"); +} + TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); auto status = diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 8600e8752cfbe072407391559d210d0b49bea511..5cd738d0f7769ceac7eb3bdbc5abd3196d9cf99c 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -44,7 +44,6 @@ namespace xla { namespace { using absl::StrCat; -using absl::StrFormat; constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; @@ -1628,26 +1627,20 @@ bool LiteralBase::IsAllFloat(float value) const { return true; } - auto 
piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue( - piece.data(), static_cast(value)); - default: - return false; - } - }; - if (!piece_is_all()) { - return false; + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; } - return true; }); } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 041151fda1280d6ae7b35d5857ca79788d4f7203..c418be895d6c3faa6a85ca2c73c6f42b0a021104 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -963,6 +963,10 @@ void MutableLiteralBase::AppendSparseElement( CHECK(LayoutUtil::IsSparseArray(subshape)); int64 rank = subshape.rank(); CHECK_EQ(multi_index.size(), rank); + for (int64 i = 0; i < rank; ++i) { + CHECK_GE(multi_index[i], 0); + CHECK_LT(multi_index[i], subshape.dimensions(i)); + } int64 last_element = p.sparse_indices()->index_count(); CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); p.sparse_indices()->Append(multi_index); diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 69efa06d39a7f13e10004bec4470bb0937e73afd..9b3de75dd4e9d495778af86fb8fc07909ab4ba81 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -200,32 +200,26 @@ int64 RecursiveElementCount(const Shape& shape) { } } -// Returns whether the actual and expected values are mismatched with respect 
to -// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. +// Returns whether the given value is infinity. template -bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { - if (relaxed_nans) { - return !std::isnan(expected) && std::isnan(actual); - } else { - return std::isnan(expected) != std::isnan(actual); - } +bool IsInf(NativeT val) { + return std::isinf(val); } template <> -bool NanMismatch(half expected, half actual, bool relaxed_nans) { - return NanMismatch(static_cast(expected), - static_cast(actual), relaxed_nans); +bool IsInf(half val) { + return std::isinf(static_cast(val)); } -// Returns whether the given value is infinity. +// Returns whether the given value is nan. template -bool IsInf(NativeT val) { - return std::isinf(val); +float IsNan(NativeT value) { + return std::isnan(value); } template <> -bool IsInf(half val) { - return std::isinf(static_cast(val)); +float IsNan(half value) { + return IsNan(static_cast(value)); } // Converts the given floating-point value to a string. @@ -376,21 +370,39 @@ class NearComparator { // the given literal_index and keeps track of various mismatch statistics. template void CompareValues(T expected, T actual, int64 linear_index) { - const bool is_nan_mismatch = - NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; if (CompareEqual(expected, actual, {linear_index})) { abs_error = 0; rel_error = 0; - } else if (is_nan_mismatch) { - num_nan_mismatches_++; - // A nan mismatch is considered to have infinite error. rel_error is used - // for sorting a std::set of the top mismatchs, and a nan value here will - // result in undefined behavior because nan's do not satisfy the strict - // weak ordering requirement of std containers. 
- abs_error = std::numeric_limits::infinity(); - rel_error = std::numeric_limits::infinity(); + } else if (IsNan(expected) || IsNan(actual)) { + if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) || + (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) { + num_nan_mismatches_++; + // A nan mismatch is considered to have infinite error. rel_error is + // used for sorting a std::set of the top mismatchs, and a nan value + // here will result in undefined behavior because nan's do not satisfy + // the strict weak ordering requirement of std containers. + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); + } else { + abs_error = 0; + rel_error = 0; + } + } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) { + // `fewer_infs_ok` gives us the option of comparing as though `actual` + // were float_max/min rather than inf. + T actual_finite = actual > T{0} ? std::numeric_limits::max() + : std::numeric_limits::lowest(); + abs_error = FpAbsoluteValue(actual_finite - expected); + + // Avoid division by 0 even though it's well-defined because ubsan can be + // configured to treat this as a fatal error. + if (expected != T{0}) { + rel_error = abs_error / FpAbsoluteValue(expected); + } else { + rel_error = std::numeric_limits::infinity(); + } } else if (IsInf(expected) || IsInf(actual)) { // If either the expected or actual value is infinity but not both, // then both absolute and relative error are regarded as inifity. @@ -410,8 +422,7 @@ class NearComparator { } const bool is_abs_mismatch = abs_error > error_.abs; const bool is_rel_mismatch = rel_error > error_.rel; - const bool is_mismatch = - is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); + const bool is_mismatch = is_abs_mismatch && is_rel_mismatch; // Update the error of the relative bucket only if the *absolute* error // bound is exceeded and vice versa. 
@@ -725,7 +736,7 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback, const ShapeIndex& shape_index) { TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); @@ -766,30 +777,32 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { + bool use_detailed_message = detailed_message.value_or( + ShapeUtil::ElementsIn(expected.shape()) >= 64); switch (expected.shape().element_type()) { case BF16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F32: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case C64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case C128: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); 
break; default: LOG(FATAL) << "Unsupported primitive type in near comparator: " @@ -880,7 +893,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback) { VLOG(1) << "Expected literal:"; XLA_VLOG_LINES(1, expected.ToString()); diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 9e5bf7c1d062ef0f25d07a80d6ded8106df5dacc..23fff3fa348f1652eaec344da4c40ccf3ad1079a 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -55,9 +55,10 @@ using MiscompareCallback = // being compared. // // If detailed_message is true, then the error message in the assertion result -// will contain a more detailed breakdown of mismatches. +// will contain a more detailed breakdown of mismatches. By default, we display +// a detailed message only for "large" inputs. Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback); // Calling ToString on a literal with over 100 million elements takes around diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index ad1699a1ae65180d56617b069d8b2e1d7d81c38c..bad65ac32018fafcc7634b989f1b4b0867aa5c0d 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/metric_table_report.h" -#include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "tensorflow/core/platform/logging.h" @@ -249,7 +249,7 @@ string MetricTableReport::MetricString(double metric) { string output; // Copy leading non-digit characters unconditionally. // This picks up the leading sign. - while (!sp1.empty() && !isdigit(sp1[0])) { + while (!sp1.empty() && !absl::ascii_isdigit(sp1[0])) { output.push_back(sp1[0]); sp1.remove_prefix(1); } diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc index 5b568888d14f21c1330556d017eafba6c8dd2228..e1e22f784172b5f3850f0bc510322dfad9e7f1bb 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -37,7 +38,7 @@ limitations under the License. namespace xla { -static const char kWS[] = " \t\r\n"; // whitespace +static const char kWS[] = " \t\r\n"; // whitespace // The following struct represents an argv[]-style array, parsed // from data gleaned from the environment. @@ -104,7 +105,8 @@ static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { // Set e to the index just past the end of the flag. 
size_t e = b; while (e != flag_str.size() && isascii(flag_str[e]) && - (strchr("-_", flag_str[e]) != nullptr || isalnum(flag_str[e]))) { + (strchr("-_", flag_str[e]) != nullptr || + absl::ascii_isalnum(flag_str[e]))) { e++; } if (e != flag_str.size() && flag_str[e] == '=' && @@ -184,6 +186,14 @@ bool ParseFlagsFromEnvAndDieIfUnknown( tensorflow::mutex_lock lock(env_argv_mu); auto* env_argv = &EnvArgvs()[string(envvar)]; SetArgvFromEnv(envvar, env_argv); // a no-op if already initialized + + if (VLOG_IS_ON(1)) { + VLOG(1) << "For env var " << envvar << " found arguments:"; + for (int i = 0; i < env_argv->argc; i++) { + VLOG(1) << " argv[" << i << "] = " << env_argv->argv[i]; + } + } + bool result = tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 3386d2e09758192a32d981f94533b534b7399766..1eedddf72c1d393cb1b88e589881e24de02ad802 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -97,6 +97,21 @@ int BitWidth(PrimitiveType type) { } } +xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) { + switch (src_bitwidth) { + case 8: + return xla::U8; + case 16: + return xla::U16; + case 32: + return xla::U32; + case 64: + return xla::U64; + default: + return xla::PRIMITIVE_TYPE_INVALID; + } +} + PrimitiveType ComplexComponentType(PrimitiveType complex_type) { switch (complex_type) { case C64: diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index d32505335daa429c459b948d20f387713ac2a1d7..295d353003276b4c1731f7d6a378fd1ae0288d3c 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -151,6 +151,8 @@ bool IsArrayType(PrimitiveType primitive_type); // Returns the number of bits in the representation for a given type. 
int BitWidth(PrimitiveType type); +PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth); + // Returns the real, imag component type underlying the given complex type. // LOG(FATAL)'s if complex_type is not complex. PrimitiveType ComplexComponentType(PrimitiveType complex_type); diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index ac342bf40fbc0052acbb09a346b9d062561ed06b..e476015f94ffdd5225cf75fc845b1e8ba2067ce8 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -38,42 +38,14 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, return (serialized1 == serialized2); } -namespace { - -std::pair>*> -GetDirectoryExpanders() { - static auto* mutex = new tensorflow::mutex; - static auto* singleton = new std::vector>; - return {mutex, singleton}; -} - -// Runs all the directory expanders over x and returns the result. -string Expand(string x) { - auto pair = GetDirectoryExpanders(); - tensorflow::mutex_lock lock(*pair.first); - for (const auto& f : *pair.second) { - x = f(x); - } - return x; -} - -} // namespace - Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); - string expanded_dir = Expand(directory); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir)); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); string safe_file_name = SanitizeFileName(file_name) + ".pb"; - const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name); + const string path = tensorflow::io::JoinPath(directory, safe_file_name); return tensorflow::WriteBinaryProto(env, path, message); } -void RegisterDirectoryExpander(const std::function& expander) { - auto pair = GetDirectoryExpanders(); - tensorflow::mutex_lock lock(*pair.first); - pair.second->push_back(expander); -} - } // namespace protobuf_util } // namespace 
xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index f22fc8b8499dd4a5329276040331a2ed9e89bea9..e20a7e95a63e15fb375f6f6da5b3a75843bc5396 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ +#include "absl/time/time.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index ddffafa9017a565f01c3214360a958e6840e9148..f84e87d0c0a324a2194f6fe527358c6183e287be 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -3,8 +3,8 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_python_default_plugins") py_library( name = "xla_client", @@ -13,8 +13,6 @@ py_library( visibility = ["//visibility:public"], deps = [ ":pywrap_xla", - "//tensorflow/compiler/xla:xla_data_proto_py", - "//tensorflow/compiler/xla/service:hlo_proto_py", ], ) @@ -33,6 +31,7 @@ py_test( deps = [ ":custom_call_for_test", ":xla_client", + "//tensorflow/compiler/xla:xla_data_proto_py", "//tensorflow/python:platform_test", ], ) @@ -59,10 +58,6 @@ cc_library( srcs = ["local_computation_builder.cc"], hdrs = ["local_computation_builder.h"], deps = [ - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", 
"//tensorflow/compiler/xla:literal_util", @@ -74,18 +69,42 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:qr", - "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", + "//tensorflow/compiler/xla/client/lib:svd", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/core:lib", + "//third_party/python_runtime:headers", # buildcleaner: keep + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "xrt", + srcs = ["xrt.cc"], + hdrs = ["xrt.h"], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt/cc:xrt_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -93,11 +112,19 @@ cc_library( tf_py_wrap_cc( name = "pywrap_xla", - srcs = ["xla.i"], + srcs = [ + "xla.i", + ], swig_includes = [ "local_computation_builder.i", + "xla_data.i", "//tensorflow/python:platform/base.i", ], + version_script = select({ + 
"//tensorflow:macos": "pywrap_xla_exported_symbols.lds", + "//tensorflow:windows": None, + "//conditions:default": "pywrap_xla_version_script.lds", + }), deps = [ ":local_computation_builder", ":numpy_bridge", @@ -105,7 +132,29 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", - ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/service:gpu_plugin", - ]), + ] + xla_python_default_plugins(), +) + +tf_py_wrap_cc( + name = "pywrap_xrt", + srcs = [ + "xrt.i", + ], + swig_includes = [ + "xla_data.i", + "//tensorflow/python:platform/base.i", + ], + version_script = select({ + "//tensorflow:macos": "pywrap_xla_exported_symbols.lds", + "//tensorflow:windows": None, + "//conditions:default": "pywrap_xla_version_script.lds", + }), + visibility = ["//visibility:public"], + deps = [ + ":numpy_bridge", + ":xrt", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + ], ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index a1bb9d6b67e32552f6fb5d6523b5ba04b590f808..ffbfa7aaffe756ecb5580955822f3c06f1fb80f5 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -20,29 +20,23 @@ limitations under the License. 
#include #include "absl/memory/memory.h" -#include "tensorflow/cc/client/client_session.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/xla/client/lib/cholesky.h" +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/qr.h" -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/compiler/xla/client/lib/svd.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -51,72 +45,6 @@ limitations under the License. namespace xla { namespace swig { -// TODO(b/118641336): Factor out XRT parts into a small c++ library of their -// own. 
- -// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of -// device handles instead of needing to set the number of replicas at XLA -// service initialization time. -tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); -int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; -LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; - -string* GetPlatformNameString() { - static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) = - new string("Host"); - return platform_name_string; -} - -Status InitializeReplicaCount(int replica_count) { - if (replica_count < 1) { - return InvalidArgument("Replica count must be >= 1; got %d.", - replica_count); - } - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted to set the replica count to %d, but a local XLA service was " - "previously created with a replica count of %d.", - replica_count, g_replica_count); - } - g_replica_count = replica_count; - return Status::OK(); -} - -Status InitializePlatformName(const string& platform_name) { - string* g_platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted to set the platform name to %s, but a local XLA service was " - "previously created with a platform name of %s.", - platform_name, *g_platform_name); - } - TF_RETURN_IF_ERROR(PlatformUtil::GetPlatform(platform_name).status()); - *g_platform_name = platform_name; - return Status::OK(); -} - -int GetReplicaCount() { - tensorflow::mutex_lock lock(g_local_client_mutex); - return g_replica_count; -} - -StatusOr GetOrCreateLocalClient() { - string* platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return g_local_client; - } - LocalClientOptions options; - 
options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); - options.set_number_of_replicas(g_replica_count); - TF_ASSIGN_OR_RETURN(g_local_client, - ClientLibrary::GetOrCreateLocalClient(options)); - CHECK(g_local_client != nullptr); - return g_local_client; -} - Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { const char* name = "xla._CPU_CUSTOM_CALL_TARGET"; if (!PyCapsule_IsValid(capsule, name)) { @@ -131,62 +59,66 @@ Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { return Status::OK(); } -Status TransferToInfeedLocal(const Literal& literal) { - VLOG(1) << "Infeeding literal without replica number; shape: " - << literal.shape(); - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); -} +LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {} -Status TransferToInfeedLocalReplica(const Literal& literal, - int replica_number) { - VLOG(1) << "Infeeding shape " << literal.shape() - << " to replica number: " << replica_number; - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferToInfeedLocal(literal, device_ordinal); +/* static */ StatusOr LocalClient::Get( + const string& platform_name) { + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(platform_name)); + if (platform->VisibleDeviceCount() <= 0) { + return InvalidArgument("Platform %s has no visible devices.", + platform_name); + } + LocalClientOptions options; + options.set_platform(platform); + TF_ASSIGN_OR_RETURN(xla::LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(options)); + CHECK(client != nullptr); + return LocalClient(client); } -StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number) { - VLOG(1) << "Outfeeding literal 
from replica number: " << replica_number - << " shape: " << shape; - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferFromOutfeedLocal(shape, device_ordinal); +// Returns the number of devices known to the XLA client. +int LocalClient::DeviceCount() const { return client_->device_count(); } + +Status LocalClient::TransferToInfeed(const Literal& literal, + int device_ordinal) { + VLOG(1) << "Infeeding literal to device " << device_ordinal + << "; shape: " << literal.shape(); + return client_->TransferToInfeed(literal, device_ordinal); } -static StatusOr ToBuffer(LocalClient* client, - int device_ordinal, - const Literal& arg) { - return client->LiteralToShapedBuffer(arg, device_ordinal, - client->backend().memory_allocator()); +StatusOr LocalClient::TransferFromOutfeed(const Shape& shape, + int device_ordinal) { + VLOG(1) << "Outfeeding literal from device " << device_ordinal + << "; shape: " << shape; + return client_->TransferFromOutfeed(&shape, device_ordinal); } /* static */ StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number) { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " - << replica_number << "/" << device_ordinal; + const LocalClient& client, int device_ordinal) { + VLOG(1) << "Creating shaped buffer from literal on device ordinal: " + << device_ordinal; + auto literal_to_buffer = [&](const Literal& arg) { + return client.client()->LiteralToShapedBuffer( + arg, device_ordinal, client.client()->backend().memory_allocator()); + }; + StatusOr buf = [&] { if (shape_with_layout) { Literal relaid = argument.Relayout(shape_with_layout.value()); - return 
ToBuffer(client, device_ordinal, relaid); + return literal_to_buffer(relaid); } - return ToBuffer(client, device_ordinal, argument); + return literal_to_buffer(argument); }(); TF_RETURN_IF_ERROR(buf.status()); - return new LocalShapedBuffer(std::move(buf).ValueOrDie()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client()); } -LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer) - : shaped_buffer_(std::move(shaped_buffer)) {} +LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, + xla::LocalClient* client) + : shaped_buffer_(std::move(shaped_buffer)), client_(client) {} const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { return &shaped_buffer_; @@ -199,8 +131,7 @@ const Shape& LocalShapedBuffer::shape() const { } StatusOr LocalShapedBuffer::ToLiteral() const { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - return client->ShapedBufferToLiteral(*shaped_buffer()); + return client_->ShapedBufferToLiteral(*shaped_buffer()); } LocalShapedBufferTuple::LocalShapedBufferTuple( @@ -231,140 +162,94 @@ StatusOr LocalShapedBufferTuple::Release(int i) { int64 LocalShapedBufferTuple::size() const { return elements_.size(); } -XrtAllocation::XrtAllocation(int64 handle, Shape shape, - const string& session_target) - : handle_(handle), shape_(shape), session_target_(session_target) {} - -XrtAllocation::~XrtAllocation() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; - } +StatusOr LocalShapedBuffer::DestructureTuple() { + const Shape tuple_shape = shape(); - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector 
outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; + if (!tuple_shape.IsTuple()) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); } -} - -/* static */ -StatusOr XrtAllocation::FromLiteral( - const Literal& argument, const string& session_target) { - xrt::XLAAllocation alloc; - *alloc.mutable_value() = argument.ToProto(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto literal_string = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); - TF_RETURN_IF_ERROR(root.status()); - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({literal_string, alloc.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); + DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = Release(); - int64 handle = outputs[0].scalar()(); - return new XrtAllocation(handle, argument.shape(), session_target); -} - -const int64 XrtAllocation::handle() const { return handle_; } - -const Shape& XrtAllocation::shape() const { return shape_; } - -StatusOr XrtAllocation::ToLiteral() const { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); - TF_RETURN_IF_ERROR(root.status()); + // Extract some metadata we use to construct scoped buffers. 
+ const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); + ShapeTree& shape_tree = tuple_buffer.buffers(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. + const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - xla::LiteralProto response; - TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); - return Literal::CreateFromProto(response); -} + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); -XrtAllocationTuple::XrtAllocationTuple(std::vector elements) - : elements_(std::move(elements)) { - for (auto* element : elements_) { - CHECK(element != nullptr); + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_)); } + // Deallocate the root buffer. 
+ se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); } -XrtAllocationTuple::~XrtAllocationTuple() { - for (XrtAllocation* element : elements_) { - if (element != nullptr) { - delete element; - } - } -} +LocalExecutable::LocalExecutable( + std::unique_ptr executable, + xla::DeviceAssignment device_assignment, xla::LocalClient* client) + : executable_(std::move(executable)), + device_assignment_(std::move(device_assignment)), + client_(client) {} -StatusOr XrtAllocationTuple::Release(int i) { - XrtAllocation* element = elements_[i]; - if (element == nullptr) { - return InvalidArgument("Attempted to release already-released element %d.", - i); +std::vector LocalExecutable::DeviceOrdinals() const { + int num_replicas = device_assignment_.replica_count(); + std::vector device_ordinals; + device_ordinals.reserve(num_replicas); + for (int i = 0; i < num_replicas; ++i) { + device_ordinals.push_back(device_assignment_(i, 0)); } - elements_[i] = nullptr; - return element; + return device_ordinals; } -int64 XrtAllocationTuple::size() const { return elements_.size(); } - -CompiledLocalComputation::CompiledLocalComputation( - std::unique_ptr executable) - : executable_(std::move(executable)) {} - -StatusOr CompiledLocalComputation::Execute( +StatusOr LocalExecutable::Execute( absl::Span argument_handles) { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0); + if (num_replicas() != 1) { + return InvalidArgument( + "Attempted to execute computation with %d replicas using Execute()", + num_replicas()); + } StatusOr result_buffer_status; - if (!device_ordinal_status.ok()) { - result_buffer_status = device_ordinal_status.status(); - } else { - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica 0 mapped to device 
ordinal for execution: " - << device_ordinal; + const int device_ordinal = device_assignment_(0, 0); + VLOG(3) << "Replica 0 mapped to device ordinal for execution: " + << device_ordinal; - std::vector argument_buffers; - argument_buffers.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer()); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(1, /*computation_count=*/1) - .ConsumeValueOrDie(); + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + argument_buffers.push_back(handle->shaped_buffer()); + } - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client_->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); - result_buffer_status = executable_->Run(argument_buffers, options); - } + result_buffer_status = executable_->Run(argument_buffers, options); if (!result_buffer_status.ok()) { return InternalError( @@ -372,34 +257,30 @@ StatusOr CompiledLocalComputation::Execute( "%s.", result_buffer_status.status().ToString()); } - return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(), + client_); } -StatusOr CompiledLocalComputation::ExecutePerReplica( +StatusOr LocalExecutable::ExecutePerReplica( absl::Span> argument_handles) { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - const int 
num_replicas = GetReplicaCount(); + const int num_devices = client_->device_count(); - if (argument_handles.size() != num_replicas) { + if (argument_handles.size() != num_replicas()) { return InvalidArgument( "Attempted to execute with %d replicas when replica count is %d", - argument_handles.size(), num_replicas); + argument_handles.size(), num_devices); + } + if (argument_handles.size() > num_devices) { + return InvalidArgument( + "Attempted to execute with %d replicas when device count is %d", + argument_handles.size(), num_devices); } - VLOG(1) << "Executing with " << num_replicas << " replicas."; - - // Each replica populates a StatusOr result, but only the output value of - // replica zero is returned. - std::vector> results(num_replicas); - auto execute = [this, client, num_replicas, &argument_handles, - &results](int replica) { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(1) << "Executing with " << num_replicas() << " replicas."; + + std::vector> results(num_replicas()); + auto execute = [this, &argument_handles, &results](int replica) { + const int device_ordinal = device_assignment_(replica, 0); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -409,41 +290,35 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( argument_buffers.push_back(handle->shaped_buffer()); } - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(num_replicas, /*computation_count=*/1) - .ConsumeValueOrDie(); - ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); + options.set_allocator(client_->backend().memory_allocator()); options.set_intra_op_thread_pool( - 
client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); StatusOr result_buffer_status = executable_->Run(argument_buffers, options); results[replica] = std::move(result_buffer_status); }; - if (num_replicas == 1) { + if (num_replicas() == 1) { // Fast-path if there is only one replica — run the computation on the // current thread. execute(0); } else { // TODO(phawkins): don't recreate the threadpool for each execution. tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", - num_replicas - 1); + num_replicas() - 1); - for (int replica = 0; replica < num_replicas - 1; ++replica) { + for (int replica = 0; replica < num_replicas() - 1; ++replica) { pool.Schedule([&execute, replica] { execute(replica); }); } - execute(num_replicas - 1); + execute(num_replicas() - 1); } - std::vector wrapped_results(num_replicas); - for (int replica = 0; replica < num_replicas; ++replica) { + std::vector wrapped_results(num_replicas()); + for (int replica = 0; replica < num_replicas(); ++replica) { auto& statusor = results[replica]; if (!statusor.ok()) { return InternalError( @@ -452,151 +327,43 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( replica, statusor.status().ToString()); } wrapped_results[replica] = - new LocalShapedBuffer(std::move(statusor).ValueOrDie()); + new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_); } return new LocalShapedBufferTuple(std::move(wrapped_results)); } -static StatusOr GetReturnValueShape(const XlaComputation& computation) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation.GetProgramShape()); - return std::move(*program_shape.mutable_result()); -} - -CompiledXrtComputation::CompiledXrtComputation( - const ProgramShape& program_shape, int64 handle, - const string& session_target) - : program_shape_(program_shape), - handle_(handle), 
- session_target_(session_target) {} - -CompiledXrtComputation::~CompiledXrtComputation() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; - } - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({computation_handle, handle()}); - std::vector outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; - } -} - -StatusOr CompiledXrtComputation::Execute( - absl::Span argument_handles) { - const int num_expected_arguments = program_shape().parameters().size(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - std::vector arguments; - arguments.reserve(num_expected_arguments); - for (int i = 0; i < num_expected_arguments; ++i) { - arguments.push_back( - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); - } - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto execution_config = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto execute = tensorflow::ops::XRTExecute(root, computation_handle, - execution_config, arguments); - TF_RETURN_IF_ERROR(root.status()); - - TF_RET_CHECK(argument_handles.size() == arguments.size()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(false); - e.set_release_compilation_handle(false); - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - for (int i = 0; i < arguments.size(); ++i) { - inputs.insert({arguments[i], argument_handles[i]->handle()}); - } - inputs.insert({computation_handle, handle()}); - inputs.insert({execution_config, e.SerializeAsString()}); - std::vector outputs; - 
TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); - - int64 output = outputs[0].scalar()(); - return new XrtAllocation(output, program_shape().result(), session_target_); -} - -const ProgramShape& CompiledXrtComputation::program_shape() const { - return program_shape_; -} - -int64 CompiledXrtComputation::handle() const { return handle_; } - -LocalComputation::LocalComputation(XlaComputation computation) +Computation::Computation(XlaComputation computation) : computation_(std::move(computation)) {} -StatusOr LocalComputation::Compile( +StatusOr Computation::Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options) { + const ExecutableBuildOptions* build_options, const LocalClient& client) { std::vector argument_shape_pointers; argument_shape_pointers.reserve(argument_shapes.size()); for (auto& argument_shape : argument_shapes) { argument_shape_pointers.push_back(&argument_shape); } - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); ExecutableBuildOptions options; if (build_options != nullptr) { options = *build_options; } TF_ASSIGN_OR_RETURN( auto local_executable, - client->Compile(computation_, argument_shape_pointers, options)); - return new CompiledLocalComputation(std::move(local_executable)); -} - -StatusOr LocalComputation::CompileForXrt( - const std::vector& argument_shapes, const string& session_target) { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto compile = tensorflow::ops::XRTCompile(root, program); - TF_RETURN_IF_ERROR(root.status()); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - ProgramShape shapes; - for (auto& shape : argument_shapes) { - *shapes.add_parameters() = shape; - } - TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape()); - LayoutUtil::SetToDefaultLayout(&shapes); - *config->mutable_program_shape() = shapes.ToProto(); - auto snapshot = 
computation().Snapshot().ValueOrDie(); - *c.mutable_hlo_snapshot() = *snapshot; - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({program, c.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); + client.client()->Compile(computation_, argument_shape_pointers, options)); + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + client.client()->backend().computation_placer()->AssignDevices( + options.num_replicas(), /*computation_count=*/1)); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation().GetProgramShape()); - int64 handle = outputs[0].scalar()(); - return new CompiledXrtComputation(program_shape, handle, session_target); + return new LocalExecutable(std::move(local_executable), + std::move(device_assignment), client.client()); } -const XlaComputation& LocalComputation::computation() const { - return computation_; -} +const XlaComputation& Computation::computation() const { return computation_; } -string LocalComputation::GetSerializedProto() const { +string Computation::GetSerializedProto() const { string result; if (!computation_.proto().SerializeToString(&result)) { LOG(ERROR) << "Failed to serialize the HloModuleProto."; @@ -605,132 +372,173 @@ string LocalComputation::GetSerializedProto() const { return result; } -StatusOr LocalComputation::GetReturnValueShape() const { - return swig::GetReturnValueShape(computation_); +StatusOr Computation::GetHloText() const { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation_.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation_.proto(), module_config)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(false); + return hlo_module->ToString(options); +} + +StatusOr 
Computation::GetHloDotGraph() const { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation_.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation_.proto(), module_config)); + return RenderGraph(*hlo_module->entry_computation(), /*label=*/"", + hlo_module->config().debug_options(), + RenderedGraphFormat::kDot); +} + +StatusOr Computation::GetProgramShape() const { + return computation_.GetProgramShape(); +} + +StatusOr Computation::GetReturnValueShape() const { + TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape()); + return std::move(*shape.mutable_result()); } LocalOp::LocalOp(const XlaOp& op) : op_(op) {} const XlaOp& LocalOp::op() const { return op_; } -LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) +ComputationBuilder::ComputationBuilder(const string& computation_name) : builder_(computation_name) {} -void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { +void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); } -void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } +void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } -StatusOr LocalComputationBuilder::Build() { +StatusOr ComputationBuilder::Build() { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp ComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, const string& name) { return xla::Parameter(&builder_, parameter_number, shape, name); } -StatusOr LocalComputationBuilder::BuildWithRoot( - const LocalOp& root) { +StatusOr 
ComputationBuilder::BuildWithRoot(const LocalOp& root) { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { +StatusOr ComputationBuilder::GetShape(const LocalOp& operand) { return builder_.GetShape(operand.op()); } -StatusOr LocalComputationBuilder::GetReturnValueShape() { +StatusOr ComputationBuilder::GetReturnValueShape() { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); return program_shape.result(); } -LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp ComputationBuilder::ReplicaId() { return xla::ReplicaId(&builder_); } + +LocalOp ComputationBuilder::Infeed(const Shape& shape) { return xla::Infeed(&builder_, shape); } -void LocalComputationBuilder::Outfeed(const LocalOp& operand, - const Shape& shape, - const string& outfeed_config) { +void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, + const string& outfeed_config) { xla::Outfeed(operand.op(), shape, outfeed_config); } -LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { +LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) { return xla::ConstantLiteral(&builder_, literal); } -LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) { +LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) { return xla::Iota(&builder_, element_type, size); } -LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape, - int64 dimension) { +LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape, + int64 dimension) { return xla::Iota(&builder_, shape, dimension); } -LocalOp LocalComputationBuilder::Broadcast( - const LocalOp& operand, absl::Span broadcast_sizes) { +LocalOp ComputationBuilder::Broadcast(const LocalOp& operand, + absl::Span broadcast_sizes) { 
return xla::Broadcast(operand.op(), broadcast_sizes); } -LocalOp LocalComputationBuilder::BroadcastInDim( +LocalOp ComputationBuilder::BroadcastInDim( const LocalOp& operand, absl::Span out_dim_sizes, absl::Span broadcast_dimensions) { return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); } -LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, - const LocalOp& padding_value, - const PaddingConfig& padding_config) { +LocalOp ComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { return xla::Pad(operand.op(), padding_value.op(), padding_config); } -LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, - absl::Span dimensions, - absl::Span new_sizes) { +LocalOp ComputationBuilder::Reshape(const LocalOp& operand, + absl::Span dimensions, + absl::Span new_sizes) { return xla::Reshape(operand.op(), dimensions, new_sizes); } -LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Collapse(const LocalOp& operand, + absl::Span dimensions) { return xla::Collapse(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { - return xla::CrossReplicaSum(operand.op()); +LocalOp ComputationBuilder::AllToAll( + const LocalOp& operand, int64 split_dimension, int64 concat_dimension, + int64 split_count, absl::Span replica_groups) { + std::vector rg; + rg.reserve(replica_groups.size()); + for (int i = 0; i < replica_groups.size(); ++i) { + rg.push_back(replica_groups[i]); + } + return xla::AllToAll(operand.op(), split_dimension, concat_dimension, + split_count, rg); +} + +LocalOp ComputationBuilder::CrossReplicaSum( + const LocalOp& operand, absl::Span replica_groups) { + return xla::CrossReplicaSum(operand.op(), replica_groups); } -LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span 
strides) { +LocalOp ComputationBuilder::Slice(const LocalOp& operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return xla::Slice(operand.op(), start_indices, limit_indices, strides); } -LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, - int64 start_index, - int64 limit_index, int64 stride, - int64 dimno) { +LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); } -LocalOp LocalComputationBuilder::DynamicSlice( - const LocalOp& operand, const LocalOp& start_indices, - absl::Span slice_sizes) { +LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand, + const LocalOp& start_indices, + absl::Span slice_sizes) { return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -LocalOp LocalComputationBuilder::DynamicUpdateSlice( - const LocalOp& operand, const LocalOp& update, - const LocalOp& start_indices) { +LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand, + const LocalOp& update, + const LocalOp& start_indices) { return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } -LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, - int64 dimension) { +LocalOp ComputationBuilder::ConcatInDim(absl::Span operands, + int64 dimension) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -739,18 +547,18 @@ LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, return xla::ConcatInDim(&builder_, xla_ops, dimension); } -LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, +LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span> padding, const 
LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter) { + const LocalOp& init_value, const Computation& scatter) { return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } -LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { +LocalOp ComputationBuilder::Tuple(absl::Span elements) { std::vector xla_ops; xla_ops.reserve(elements.size()); for (const auto& op : elements) { @@ -760,22 +568,22 @@ LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { return xla::Tuple(&builder_, xla_ops); } -LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, - int64 index) { +LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { return xla::GetTupleElement(tuple_data.op(), index); } -LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { +LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { return xla::Dot(lhs.op(), rhs.op()); } -LocalOp LocalComputationBuilder::DotGeneral( +LocalOp ComputationBuilder::DotGeneral( const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -LocalOp LocalComputationBuilder::ConvGeneralDilated( +LocalOp ComputationBuilder::ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -787,18 +595,18 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( feature_group_count); } -LocalOp LocalComputationBuilder::ConvertElementType( - const LocalOp& operand, PrimitiveType new_element_type) { +LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::ConvertElementType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::BitcastConvertType( - const LocalOp& operand, 
PrimitiveType new_element_type) { +LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::BitcastConvertType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, - absl::Span operands) { +LocalOp ComputationBuilder::Call(const Computation& local_computation, + absl::Span operands) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -807,7 +615,7 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, return xla::Call(&builder_, local_computation.computation(), xla_ops); } -LocalOp LocalComputationBuilder::CustomCall( +LocalOp ComputationBuilder::CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const std::vector& operand_shapes_with_layout, @@ -822,19 +630,19 @@ LocalOp LocalComputationBuilder::CustomCall( operand_shapes_with_layout, opaque); } -LocalOp LocalComputationBuilder::Transpose( - const LocalOp& operand, absl::Span permutation) { +LocalOp ComputationBuilder::Transpose(const LocalOp& operand, + absl::Span permutation) { return xla::Transpose(operand.op(), permutation); } -LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Rev(const LocalOp& operand, + absl::Span dimensions) { return xla::Rev(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::Map(absl::Span operands, - const LocalComputation& local_computation, - absl::Span dimensions) { +LocalOp ComputationBuilder::Map(absl::Span operands, + const Computation& local_computation, + absl::Span dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -845,17 +653,17 @@ LocalOp LocalComputationBuilder::Map(absl::Span operands, dimensions); } -LocalOp LocalComputationBuilder::Reduce( +LocalOp ComputationBuilder::Reduce( const LocalOp& operand, 
const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce) { return xla::Reduce(operand.op(), init_value.op(), local_computation.computation(), dimensions_to_reduce); } -LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( +LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -867,51 +675,50 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( padding); } -LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, - const LocalOp& sigma, - const Shape& shape) { +LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape) { return xla::RngNormal(mu.op(), sigma.op(), shape); } -LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, - const Shape& shape) { +LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { return xla::RngUniform(a.op(), b.op(), shape); } -LocalOp LocalComputationBuilder::While(const LocalComputation& condition, - const LocalComputation& body, - const LocalOp& init) { +LocalOp ComputationBuilder::While(const Computation& condition, + const Computation& body, + const LocalOp& init) { return xla::While(condition.computation(), body.computation(), init.op()); } -LocalOp LocalComputationBuilder::Conditional( - const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, const LocalOp& false_operand, - const LocalComputation& false_computation) { +LocalOp ComputationBuilder::Conditional(const LocalOp& predicate, + const LocalOp& true_operand, + const Computation& true_computation, + const LocalOp& false_operand, + const Computation& false_computation) { return 
xla::Conditional(predicate.op(), true_operand.op(), true_computation.computation(), false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { +StatusOr ComputationBuilder::IsConstant(const LocalOp& operand) { return builder_.IsConstant(operand.op()); } -LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { +LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { return xla::Sort(operand.op(), {}, dimension); } -LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, - const LocalOp& values, - int64 dimension) { +LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys, + const LocalOp& values, int64 dimension) { return xla::Sort(keys.op(), {values.op()}, dimension); } -LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) { - return xla::Cholesky(a.op()); +LocalOp ComputationBuilder::Cholesky(const LocalOp& a, bool lower) { + return xla::Cholesky(a.op(), lower); } -LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { +LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) { XlaBuilder* builder = a.op().builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); @@ -919,16 +726,32 @@ LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { }); } -LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, - const LocalOp& b, - bool left_side, bool lower, - bool transpose_a, - bool conjugate_a) { - return xla::TriangularSolve(a.op(), b.op(), left_side, lower, transpose_a, - conjugate_a); +LocalOp ComputationBuilder::Eigh(const LocalOp& a, bool lower) { + XlaBuilder* builder = a.op().builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + auto eigh = xla::SelfAdjointEig(a.op(), lower); + return xla::Tuple(builder, {eigh.v, eigh.w}); + }); } -LocalOp 
LocalComputationBuilder::Gather( +LocalOp ComputationBuilder::SVD(const LocalOp& a) { + XlaBuilder* builder = a.op().builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + auto svd = xla::SVD(a.op()); + return xla::Tuple(builder, {svd.u, svd.d, svd.v}); + }); +} + +LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b, + bool left_side, bool lower, + bool unit_diagonal, + int transpose_a) { + return xla::TriangularSolve( + a.op(), b.op(), left_side, lower, unit_diagonal, + xla::TriangularSolveOptions::Transpose(transpose_a)); +} + +LocalOp ComputationBuilder::Gather( const LocalOp& input, const LocalOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes) { @@ -936,24 +759,24 @@ LocalOp LocalComputationBuilder::Gather( slice_sizes); } -LocalOp LocalComputationBuilder::Scatter( +LocalOp ComputationBuilder::Scatter( const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, const LocalComputation& update_computation, + const LocalOp& updates, const Computation& update_computation, const ScatterDimensionNumbers& dimension_numbers) { return xla::Scatter(input.op(), scatter_indices.op(), updates.op(), update_computation.computation(), dimension_numbers); } -StatusOr LocalComputationBuilder::BuildConstantSubGraph( +StatusOr ComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.BuildConstantSubGraph(operand.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -#define _FORWARD(method_name, return_sig, args_sig, args) \ - return_sig LocalComputationBuilder::method_name args_sig { \ - return xla::method_name args; \ +#define _FORWARD(method_name, return_sig, args_sig, args) \ + return_sig ComputationBuilder::method_name args_sig { \ + return xla::method_name args; \ } #define _FORWARD_UNOP(method_name) \ @@ -995,6 +818,7 @@ _FORWARD_BINOP(Atan2) 
_FORWARD_BINOP(Pow) _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) +_FORWARD_UNOP(Clz) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) _FORWARD_UNOP(Expm1) @@ -1040,108 +864,9 @@ void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { delete local_shaped_buffer; } -void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } - -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) { - delete computation; -} - -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation) { - delete computation; -} - -void DeleteLocalComputation(LocalComputation* computation) { - delete computation; -} - -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer) { - const Shape tuple_shape = local_shaped_buffer->shape(); +void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; } - if (!tuple_shape.IsTuple()) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - DeviceMemoryAllocator* allocator = - local_shaped_buffer->shaped_buffer()->memory_allocator(); - ShapedBuffer tuple_buffer = local_shaped_buffer->Release(); - - // Extract some metadata we use to construct scoped buffers. - const se::Platform* platform = tuple_buffer.platform(); - int device_ordinal = tuple_buffer.device_ordinal(); - - ShapeTree& shape_tree = tuple_buffer.buffers(); - std::vector results; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - // Create a shaped buffer for this destructured tuple element. 
- const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); - VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; - ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - - ShapeUtil::ForEachSubshape( - subshape, [&](const Shape& s, const ShapeIndex& index) { - ShapeIndex original(index); - original.push_front(i); - se::DeviceMemoryBase* device_memory = - shape_tree.mutable_element(original); - shaped_buffer.set_buffer(*device_memory, index); - *device_memory = se::DeviceMemoryBase(); - }); - - VLOG(3) << "Completed tuple element: " << i; - results.push_back(new LocalShapedBuffer( - ScopedShapedBuffer(std::move(shaped_buffer), allocator))); - } - // Deallocate the root buffer. - se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); - TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); - return new LocalShapedBufferTuple(std::move(results)); -} - -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target) { - const Shape& tuple_shape = allocation->shape(); - - if (!tuple_shape.IsTuple()) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); - auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); - TF_RETURN_IF_ERROR(root.status()); - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - std::vector results; - for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - inputs.clear(); - inputs.insert({base_handle, allocation->handle()}); - inputs.insert({shape_index, {i}}); - std::vector outputs; - auto status = session.Run(inputs, 
{subtuple}, &outputs); - if (!status.ok()) { - // Clean up before returning non-ok status. - for (int j = 0; j < results.size(); ++j) { - delete results[j]; - } - return status; - } - const int64 subtuple_handle = outputs[0].scalar()(); - const Shape& subtuple_shape = - ShapeUtil::GetTupleElementShape(tuple_shape, i); - results.push_back( - new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); - } - return new XrtAllocationTuple(std::move(results)); -} +void DeleteComputation(Computation* computation) { delete computation; } } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 98759cf984751d2cef8df4449d392ace786a8ebc..5046c1ec011032cb7166c281a297388a8e02c4e8 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -22,9 +22,6 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -35,42 +32,42 @@ limitations under the License. namespace xla { namespace swig { -// Initializes the number of replicas that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. -Status InitializeReplicaCount(int replica_count); - -// Initializes the platform name that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. 
-Status InitializePlatformName(const string& platform_name); - -// Returns the replica count that is currently set, regardless of whether the -// local XLA service has been instantiated yet or not. -int GetReplicaCount(); - // Registers a 'fn_capsule' as a CPU custom call target. // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name // "xla._CPU_CUSTOM_CALL_TARGET". Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); -// Wraps the local client's infeed-transfer function. -// -// The default device ordinal (0) is used. -Status TransferToInfeedLocal(const Literal& literal); +// Wrapper around an xla::LocalClient. +class LocalClient { + public: + // Initializes a local XLA client for `platform_name`. Returns an error if no + // such platform exists, or if the platform has no visible devices. + static StatusOr Get(const string& platform_name); + + // Copyable and moveable; the class is just a wrapper around a + // xla::LocalClient pointer for convenient SWIG wrapping. + + // Returns the number of devices known to the XLA client. + int DeviceCount() const; + + // Wraps the local client's infeed-transfer function. + // + // The target device is selected by `device_ordinal`. + Status TransferToInfeed(const Literal& literal, int device_ordinal); -// Transfers the given literal to the infeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal. -Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); + // Transfers a literal of the given shape from the outfeed of the device + // with the given ordinal. + StatusOr TransferFromOutfeed(const Shape& shape, int device_ordinal); -// Transfers a literal of the given shape from the outfeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal.
-StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number); + xla::LocalClient* client() const { return client_; } + + private: + LocalClient(xla::LocalClient* client); + + xla::LocalClient* client_; +}; + +class LocalShapedBufferTuple; // Represents a reference to literals that live in a device-allocated buffer via // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a @@ -79,9 +76,9 @@ class LocalShapedBuffer { public: static StatusOr FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number); + const LocalClient& client, int device_ordinal); - LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client); StatusOr ToLiteral() const; const Shape& shape() const; const ScopedShapedBuffer* shaped_buffer() const; @@ -90,8 +87,13 @@ class LocalShapedBuffer { // analogous to std::unique_ptr::release(). ShapedBuffer Release(); + // Destructures a tuple-valued LocalShapedBuffer into its constituent + // elements in LocalShapedBufferTuple form. + StatusOr DestructureTuple(); + private: ScopedShapedBuffer shaped_buffer_; + xla::LocalClient* client_; }; // Result of a tuple destructuring operation on a LocalShapedBuffer -- this @@ -117,68 +119,20 @@ class LocalShapedBufferTuple { std::vector elements_; }; -// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements -// in LocalShapedBufferTuple form. -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer); - -// Represents a reference to literals that live in a device-allocated buffer via -// XRT. Specifically, wraps an int64 handle produced by running the allocation -// graph, and an XLA shape to track the referent's shape. 
-class XrtAllocation { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which allocation and deallocation - // graphs are run. - static StatusOr FromLiteral(const Literal& argument, - const string& session_target); - - XrtAllocation(int64 handle, Shape shape, const string& session_target); - ~XrtAllocation(); - StatusOr ToLiteral() const; - const Shape& shape() const; - const int64 handle() const; - - private: - const int64 handle_; - const Shape shape_; - const string session_target_; -}; - -// Result of a tuple destructuring operation on an XrtAllocation. -class XrtAllocationTuple { - public: - // Note: any XrtAllocation elements that are not Release()'d will be - // deallocated in the destructor. - explicit XrtAllocationTuple(std::vector elements); - - ~XrtAllocationTuple(); - - // Releases the ith element to the caller. Further attempts to release the ith - // element will return an invalid argument error. - StatusOr Release(int i); - - // Returns the number of elements in the destructured tuple. - int64 size() const; - - private: - std::vector elements_; -}; - -// Destructures a tuple-valued XrtAllocation into its constitutent elements -// in XrtAllocationTuple form. -// -// Accepts a `session_target` argument, used in constructing the -// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, -// and passed along in constructing each constituent XrtAllocation. -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target); - // Represents a compiled computation that can be executed given handles to // device-allocated literals. Specifically, wraps an XLA LocalExecutable. 
-class CompiledLocalComputation { +class LocalExecutable { public: - CompiledLocalComputation(std::unique_ptr executable); + LocalExecutable(std::unique_ptr executable, + xla::DeviceAssignment device_assignment, + xla::LocalClient* client); + + int num_replicas() const { + return executable_->build_options().num_replicas(); + } + + // Returns the device ordinals to which each replica is assigned. + std::vector DeviceOrdinals() const; StatusOr Execute( absl::Span argument_handles); @@ -190,47 +144,22 @@ class CompiledLocalComputation { absl::Span > argument_handles); private: - std::unique_ptr executable_; + const std::unique_ptr executable_; + const xla::DeviceAssignment device_assignment_; + xla::LocalClient* const client_; }; -// Represents a compiled computation that can be executed given handles to -// device-allocated literals. Specifically, wraps an XRT computation handle. -class CompiledXrtComputation { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the execution graph is run. - CompiledXrtComputation(const ProgramShape& program_shape, int64 handle, - const string& session_target); - ~CompiledXrtComputation(); - - StatusOr Execute( - absl::Span argument_handles); - - const ProgramShape& program_shape() const; - int64 handle() const; - - private: - const ProgramShape program_shape_; - const int64 handle_; - const string session_target_; -}; - -// Wraps a XlaComputation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a ComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. 
-class LocalComputation { +class Computation { public: - LocalComputation(XlaComputation computation); + Computation(XlaComputation computation); - StatusOr Compile( + StatusOr Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options); - - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the compilation graph is run. - StatusOr CompileForXrt( - const std::vector& argument_shapes, const string& session_target); + const ExecutableBuildOptions* build_options, const LocalClient& client); const XlaComputation& computation() const; @@ -239,6 +168,15 @@ class LocalComputation { // string on failure. string GetSerializedProto() const; + // Returns the computation in human-readable HLO text format. + StatusOr GetHloText() const; + + // Returns the computation in graphviz dot format. + StatusOr GetHloDotGraph() const; + + // Returns the program shape for this computation. + StatusOr GetProgramShape() const; + // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; @@ -246,7 +184,7 @@ class LocalComputation { XlaComputation computation_; }; -// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// Wraps a XlaOp produced by a ComputationBuilder. This class is intended // to be made available to Python via SWIG. class LocalOp { public: @@ -263,20 +201,20 @@ class LocalOp { // Python. // - Set up the underlying builder to use the client library's // LocalClient. -// - Wrap Computations in LocalComputations for Python access. -// - Correspondingly unwrap incoming LocalComputations. -class LocalComputationBuilder { +// - Wrap Computations in Computations for Python access. +// - Correspondingly unwrap incoming Computations. 
+class ComputationBuilder { public: - LocalComputationBuilder(const string& computation_name); + ComputationBuilder(const string& computation_name); void SetOpMetadata(const OpMetadata& metadata); void ClearOpMetadata(); - // Returns an owned LocalComputation to the caller on success. - StatusOr Build(); + // Returns an owned Computation to the caller on success. + StatusOr Build(); - // Returns an owned LocalComputation to the caller on success with given root. - StatusOr BuildWithRoot(const LocalOp& root); + // Returns an owned Computation to the caller on success with given root. + StatusOr BuildWithRoot(const LocalOp& root); LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); @@ -286,6 +224,8 @@ class LocalComputationBuilder { // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); + LocalOp ReplicaId(); + LocalOp Infeed(const Shape& shape); void Outfeed(const LocalOp& operand, const Shape& shape, @@ -312,7 +252,12 @@ class LocalComputationBuilder { LocalOp Collapse(const LocalOp& operand, absl::Span dimensions); - LocalOp CrossReplicaSum(const LocalOp& operand); + LocalOp AllToAll(const LocalOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + absl::Span replica_groups); + + LocalOp CrossReplicaSum(const LocalOp& operand, + absl::Span replica_groups); LocalOp Slice(const LocalOp& operand, absl::Span start_indices, absl::Span limit_indices, @@ -330,11 +275,11 @@ class LocalComputationBuilder { LocalOp ConcatInDim(absl::Span operands, int64 dimension); LocalOp SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span > padding, const LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter); + const LocalOp& init_value, const Computation& scatter); LocalOp Tuple(absl::Span 
elements); @@ -360,7 +305,7 @@ class LocalComputationBuilder { LocalOp BitcastConvertType(const LocalOp& operand, PrimitiveType new_element_type); - LocalOp Call(const LocalComputation& local_computation, + LocalOp Call(const Computation& local_computation, absl::Span operands); LocalOp CustomCall(const string& call_target_name, @@ -375,16 +320,16 @@ class LocalComputationBuilder { LocalOp Rev(const LocalOp& operand, absl::Span dimensions); LocalOp Map(absl::Span operands, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce); LocalOp ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -396,13 +341,13 @@ class LocalComputationBuilder { LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - LocalOp While(const LocalComputation& condition, const LocalComputation& body, + LocalOp While(const Computation& condition, const Computation& body, const LocalOp& init); LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, + const Computation& true_computation, const LocalOp& false_operand, - const LocalComputation& false_computation); + const Computation& false_computation); StatusOr IsConstant(const LocalOp& operand); @@ -413,21 +358,27 @@ class LocalComputationBuilder { LocalOp QR(const LocalOp& a, bool full_matrices); - LocalOp Cholesky(const LocalOp& a); + LocalOp Cholesky(const LocalOp& a, bool lower); + + LocalOp Eigh(const LocalOp& a, bool lower); + + LocalOp SVD(const LocalOp& a); + // `transpose_a` is the integer value of a 
TriangularSolveOptions::Transpose + // enum. We use an integer here so we don't have to teach SWIG about the + // enum. LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a); + bool lower, bool unit_diagonal, int transpose_a); LocalOp Gather(const LocalOp& input, const LocalOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes); LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, - const LocalComputation& update_computation, + const LocalOp& updates, const Computation& update_computation, const ScatterDimensionNumbers& dimension_numbers); - StatusOr BuildConstantSubGraph(const LocalOp& operand); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; @@ -469,6 +420,7 @@ class LocalComputationBuilder { _FORWARD_BINOP(Pow) _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) + _FORWARD_UNOP(Clz) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) _FORWARD_UNOP(Expm1) @@ -516,10 +468,8 @@ class LocalComputationBuilder { // Functions for freeing resources from the Python side. 
void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); -void DeleteXrtAllocation(XrtAllocation* allocation); -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation); -void DeleteLocalComputation(LocalComputation* computation); +void DeleteLocalExecutable(LocalExecutable* computation); +void DeleteComputation(Computation* computation); } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 39bff3f5719fd6b2ee1856848ade7152a8e7ee3e..a1a4f007f2fc9ce730f9c7fd11dadae85250edb5 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -16,89 +16,6 @@ limitations under the License. // SWIG typemaps and declarations for building, compiling, and // executing XLA computations, wrapping most of what is declared in // local_computation_builder.h. 
-// -// The typemaps below implement/assert the following correspondences -// (with elaborations below): -// -// C++ Python -// -------------------------------------+--------------------------------------- -// Span <- sequence of int -// Span <- sequence of LocalOp -// Literal <-> (nested tuple of) numpy ndarray -// std::vector <- sequence of (nested tuple of) ndarray -// Shape -> pair holding (dtype, dimensions) -// <- object duck-typed as xla_client.Shape -// std::vector <- sequence of xla_client.Shape objects -// PrimitiveType <- int -// Span> <- sequence of int pairs -// PaddingConfig proto <- corresponding Python proto -// ConvolutionDimensionNumbers proto <- corresponding Python proto -// DotDimensionNumbers proto <- corresponding Python proto -// GatherDimensionNumbers proto <- corresponding Python proto -// ScatterDimensionNumbers proto <- corresponding Python proto -// -// Arrows indicate whether a conversion only ever occurs in one -// direction, or whether it is maintained bidirectionally. -// -// The Python objects corresponding to C++ Literals have the type: -// -// T = ndarray | (T, ...) -// -// where a terminal numpy ndarray translates to a Literal with a -// non-tuple Shape, an XLA primitive element type corresponding to the -// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates -// to a tuple-shaped Literal whose tuple components are translated -// recursively. For example, if x is a numpy ndarray in Python, with -// shape (2, 3) and dtype of dtype('float32'), then x translates to a -// Literal with rank 2, dimension 2 and 3, and XLA primitive type -// F32. Meanwhile, -// -// (x, (x, x), (x,)), -// -// translates to a tuple-shaped XLA Literal, whose component subshapes -// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. -// -// Shapes output by C++ become Python objects with the type: -// -// T = (dtype, S) -// S = DIMENSIONS | TUPLE_SHAPES -// DIMENSIONS = (int, ...) -// TUPLE_SHAPES = (T, ...) 
-// -// In the pair described by the T rule, the terminal dtype determines -// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is -// dtype('O'), numpy's object dtype, the structure represents a tuple -// shape and the expansion of the non-terminal S is -// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type -// and S expands into DIMENSIONS giving dimension sizes. For example: -// -// (dtype('float32'), (3, 5, 7)) -// -// describes a 3x5x7 array of F32s, and -// -// (dtype('O'), ((dtype('float32'), (2, 3)), -// (dtype('float64'), (4, 5)))) -// -// describes a tuple shape with two subshapes: the first a 2x3 F32, -// and the other a 4x5 F64. -// -// The Python int corresponding to a PrimitiveType enum must be valid -// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). -// -// The SWIG object wrappers generated by this file are not intended -// for end use, but rather for internal use in the Python XLA client, -// xla_client.py. -// -// One central reason for the Python-side indirection is that the -// Python-side objects produced by the typemaps in this file are -// further packaged up by xla_client before being passed on. For -// instance, the Python pair produced for a C++ Shape is further -// wrapped in a Python class (xla_client.Shape) so as not to expose -// the raw pair externally. -// -// Other SWIG object wrappers (e.g. of LocalComputation) are further -// wrapped by xla_client in order to set up a custom destructor that -// triggers memory deallocation on the C++ side. %module(threads="1") local_computation_builder @@ -106,6 +23,7 @@ limitations under the License. %nothread; %include "tensorflow/python/platform/base.i" +%include "tensorflow/compiler/xla/python/xla_data.i" %{ // Must be included first @@ -123,87 +41,6 @@ limitations under the License. 
using namespace xla; using namespace xla::swig; -namespace xla { - -namespace swig { - -bool GetIntAttr(PyObject* o, const char* field, int64* result) { - PyObject* fo = PyObject_GetAttrString(o, field); - if (!fo) { - return false; - } - const int64 value = numpy::PyIntOrPyLongToLong(fo); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(fo); - return false; - } - Py_DECREF(fo); - *result = value; - return true; -} - -// Returns "ok"; true if there is no error, false if there was an error. -bool HandleStringAttribute(PyObject* o, - const char* attr_name, - std::function f) { - if (!PyObject_HasAttrString(o, attr_name)) { - return true; // It's ok for the object to not have the attribute. - } - PyObject* attr = PyObject_GetAttrString(o, attr_name); - if (attr == nullptr) { - return false; // An error occurred getting the attribute. - } - if (attr == Py_None) { - Py_DECREF(attr); - return true; // The attribute is None, which we consider ok. - } - if (!PyString_Check(attr)) { - string message = absl::StrFormat("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr)); - PyErr_SetString(PyExc_TypeError, message.c_str()); - Py_DECREF(attr); - return false; // Type error, not ok. - } - f(PyString_AsString(attr)); - Py_DECREF(attr); - return true; // Handled string attribute, ok! 
-} - -bool HandleRepeatedInt64Attribute( - PyObject* o, const char* attr_name, - tensorflow::protobuf::RepeatedField* field) { - PyObject* seq = PyObject_GetAttrString(o, attr_name); - if (!seq) { - return false; - } - - int length = PySequence_Size(seq); - if (length == -1) { - Py_DECREF(seq); - return false; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(seq, i); - if (!item) { - Py_DECREF(seq); - return false; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(seq); - return false; - } - *field->Add() = dimension; - Py_DECREF(item); - } - Py_DECREF(seq); - return true; -} - -} // namespace swig -} // namespace xla %} // Required to use PyArray_* functions. @@ -211,57 +48,6 @@ bool HandleRepeatedInt64Attribute( tensorflow::ImportNumpy(); %} -// Basic types - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = PyBool_FromLong($1.ConsumeValueOrDie()); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) Status { - if (!$1.ok()) { - PyErr_SetString( - PyExc_RuntimeError, $1.ToString().c_str()); - SWIG_fail; - } - Py_INCREF(Py_None); - $result = Py_None; -} - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.resize(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - temps[i] = numpy::PyIntOrPyLongToLong(py_int); - if (temps[i] == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); - SWIG_fail; - } - Py_DECREF(py_int); - Py_DECREF(o); - } - $1 = temps; -} - // Computation 
builder types %typemap(in) absl::Span( @@ -286,12 +72,12 @@ tensorflow::ImportNumpy(); // Computation and buffer/allocation types -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { - auto* value = $1.ValueOrDie(); + xla::swig::LocalClient value = $1.ValueOrDie(); { - auto* $1 = value; - $typemap(out, xla::swig::CompiledLocalComputation*) + auto $1 = value; + $typemap(out, xla::swig::LocalClient) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -299,12 +85,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::CompiledXrtComputation*) + $typemap(out, xla::swig::LocalExecutable*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -338,38 +124,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::XrtAllocation*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::XrtAllocationTuple*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::LocalComputation*) + $typemap(out, xla::swig::Computation*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -429,448 +189,6 @@ tensorflow::ImportNumpy(); $1 = temps; } -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 
0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - XrtAllocation* xrta; - if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), - SWIG_POINTER_EXCEPTION)) == -1) { - SWIG_fail; - } - temps.push_back(xrta); - Py_DECREF(o); - } - $1 = temps; -} - -// Literal - -%typemap(out) StatusOr { - if ($1.ok()) { - Literal value = $1.ConsumeValueOrDie(); - $result = numpy::PyObjectFromXlaLiteral(*value); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(in) const Literal& (StatusOr literal_status) { - literal_status = numpy::XlaLiteralFromPyObject($input); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - SWIG_fail; - } - $1 = &literal_status.ValueOrDie(); -} - -%typemap(out) Literal { - $result = numpy::PyObjectFromXlaLiteral(*$1); -} - -%typemap(out) StatusOr { - if (!$1.ok()) { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } - $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - Py_DECREF(o); - SWIG_fail; - } - temps.push_back(literal_status.ConsumeValueOrDie()); - Py_DECREF(o); - } - $1 = &temps; -} - -// OpMetadata - -%typemap(in) const OpMetadata& (OpMetadata temp) { - StatusOr statusor = numpy::OpMetadataFromPyObject($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - 
$1 = &temp; -} - -// Shape - -%typemap(out) const Shape& { - $result = numpy::PyShapeInfoFromXlaShape(*$1); -} - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(in) const Shape& (Shape temp) { - StatusOr statusor = numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; -} - -%typemap(in) const absl::optional& ( - absl::optional temp) { - if ($input == Py_None) { - temp = absl::nullopt; - $1 = &temp; - } else { - StatusOr statusor = numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; - } -} - -%typemap(out) std::unique_ptr { - $result = numpy::PyShapeInfoFromXlaShape(*$1); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - $1 = &temps; -} - -%typemap(in) const std::vector >& ( - std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (o == Py_None) { - temps.push_back(absl::nullopt); - 
} else { - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - } - $1 = &temps; -} - -// PrimitiveType - -%typemap(in) PrimitiveType { - PyObject* py_int = numpy::PyNumberToPyInt($input); - if (!py_int) { - PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); - SWIG_fail; - } - const long value = numpy::PyIntOrPyLongToLong(py_int); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - SWIG_fail; - } - if (!PrimitiveType_IsValid(value)) { - PyErr_SetString( - PyExc_TypeError, "Argument not valid for PrimitiveType enum"); - Py_DECREF(py_int); - SWIG_fail; - } - $1 = static_cast(value); -} - -// Span> - -%typemap(in) absl::Span > - (std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (!o) { - SWIG_fail; - } - PyObject* first = PyTuple_GetItem(o, 0); - if (!first) { - Py_DECREF(o); - SWIG_fail; - } - PyObject* first_pyint = numpy::PyNumberToPyInt(first); - if (!first_pyint) { - PyErr_SetString( - PyExc_TypeError, - "First pair item cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - PyObject* second = PyTuple_GetItem(o, 1); - if (!second) { - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - PyObject* second_pyint = numpy::PyNumberToPyInt(second); - if (!second_pyint) { - PyErr_SetString( - PyExc_TypeError, - "Second pair item cannot be converted to int"); - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); - if (first_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); 
- SWIG_fail; - } - const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); - if (second_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); - SWIG_fail; - } - temps.push_back(std::make_pair(first_value, second_value)); - Py_DECREF(o); - } - $1 = temps; -} - -// DotDimensionNumbers - -%typemap(in) const DotDimensionNumbers& - (DotDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "lhs_contracting_dimensions", - dimension_numbers.mutable_lhs_contracting_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "rhs_contracting_dimensions", - dimension_numbers.mutable_rhs_contracting_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "lhs_batch_dimensions", - dimension_numbers.mutable_lhs_batch_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "rhs_batch_dimensions", - dimension_numbers.mutable_rhs_batch_dimensions())) { - SWIG_fail; - } - - $1 = &dimension_numbers; -} - -// PaddingConfig - -%typemap(in) const PaddingConfig& - (PaddingConfig padding_config) { - PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); - if (!dimensions) { - SWIG_fail; - } - - int length = PySequence_Size(dimensions); - if (length == -1) { - Py_DECREF(dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(dimensions, i); - if (!item) { - Py_DECREF(dimensions); - SWIG_fail; - } - int64 edge_padding_low, edge_padding_high, interior_padding; - if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) - || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) - || !GetIntAttr(item, "interior_padding", &interior_padding)) { - Py_DECREF(item); - Py_DECREF(dimensions); - SWIG_fail; - } - Py_DECREF(item); - - PaddingConfig::PaddingConfigDimension* dimension = - padding_config.add_dimensions(); - 
dimension->set_edge_padding_low(edge_padding_low); - dimension->set_edge_padding_high(edge_padding_high); - dimension->set_interior_padding(interior_padding); - } - Py_DECREF(dimensions); - - $1 = &padding_config; -} - -// ConvolutionDimensionNumbers - -%typemap(in) const ConvolutionDimensionNumbers& - (ConvolutionDimensionNumbers dimension_numbers) { - int64 value; - - if (!GetIntAttr($input, "input_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_batch_dimension(value); - - if (!GetIntAttr($input, "input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_feature_dimension(value); - - if (!GetIntAttr($input, "output_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_batch_dimension(value); - - if (!GetIntAttr($input, "output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_input_feature_dimension(value); - - if (!HandleRepeatedInt64Attribute( - $input, "input_spatial_dimensions", - dimension_numbers.mutable_input_spatial_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "kernel_spatial_dimensions", - dimension_numbers.mutable_kernel_spatial_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "output_spatial_dimensions", - dimension_numbers.mutable_output_spatial_dimensions())) { - SWIG_fail; - } - - $1 = &dimension_numbers; -} - -// GatherDimensionNumbers - -%typemap(in) const GatherDimensionNumbers& - (GatherDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "offset_dims", - dimension_numbers.mutable_offset_dims())) { - SWIG_fail; - } - if 
(!HandleRepeatedInt64Attribute( - $input, "collapsed_slice_dims", - dimension_numbers.mutable_collapsed_slice_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "start_index_map", - dimension_numbers.mutable_start_index_map())) { - SWIG_fail; - } - - int64 value; - if (!GetIntAttr($input, "index_vector_dim", &value)) { - SWIG_fail; - } - dimension_numbers.set_index_vector_dim(value); - - $1 = &dimension_numbers; -} - -// ScatterDimensionNumbers - -%typemap(in) const ScatterDimensionNumbers& - (ScatterDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "update_window_dims", - dimension_numbers.mutable_update_window_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "inserted_window_dims", - dimension_numbers.mutable_inserted_window_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "scatter_dims_to_operand_dims", - dimension_numbers.mutable_scatter_dims_to_operand_dims())) { - SWIG_fail; - } - - int64 value; - if (!GetIntAttr($input, "index_vector_dim", &value)) { - SWIG_fail; - } - dimension_numbers.set_index_vector_dim(value); - - $1 = &dimension_numbers; -} - // ExecutableBuildOptions %typemap(in) const ExecutableBuildOptions* @@ -878,41 +196,38 @@ tensorflow::ImportNumpy(); if ($input == Py_None) { $1 = NULL; } else { - if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) { - build_options.mutable_debug_options()->set_xla_generate_hlo_graph(std::move(s)); + if (!HandleStringAttribute($input, "dump_to", [&](string s) { + build_options.mutable_debug_options()->set_xla_dump_to(std::move(s)); })) { return nullptr; } - if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) { - build_options.mutable_debug_options()->set_xla_dump_optimized_hlo_proto_to(std::move(s)); + if (!HandleStringAttribute($input, "dump_hlo_pass_re", [&](string s) { + build_options.mutable_debug_options()->set_xla_dump_hlo_pass_re(std::move(s)); })) { 
return nullptr; } - if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { - build_options.mutable_debug_options()->set_xla_dump_unoptimized_hlo_proto_to(std::move(s)); + if (!HandleStringAttribute($input, "dump_hlo_module_re", [&](string s) { + build_options.mutable_debug_options()->set_xla_dump_hlo_module_re(std::move(s)); })) { return nullptr; } - if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { - build_options.mutable_debug_options()->set_xla_dump_per_pass_hlo_proto_to(std::move(s)); + if (!HandleBoolAttribute($input, "dump_hlo_as_text", [&](bool b) { + build_options.mutable_debug_options()->set_xla_dump_hlo_as_text(b); })) { return nullptr; } - - PyObject* o = PyObject_GetAttrString($input, "hlo_profile"); - if (o == NULL) { - SWIG_fail; + if (!HandleBoolAttribute($input, "dump_hlo_as_proto", [&](bool b) { + build_options.mutable_debug_options()->set_xla_dump_hlo_as_proto(b); + })) { + return nullptr; } - if (o != Py_None) { - if (!PyBool_Check(o)) { - PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); - SWIG_fail; - } - build_options.mutable_debug_options()->set_xla_hlo_profile(o == Py_True); + if (!HandleBoolAttribute($input, "hlo_profile", [&](bool b) { + build_options.mutable_debug_options()->set_xla_hlo_profile(b); + })) { + return nullptr; } - Py_DECREF(o); - o = PyObject_GetAttrString($input, "result_shape"); + PyObject* o = PyObject_GetAttrString($input, "result_shape"); if (o == nullptr) { return nullptr; } @@ -927,6 +242,12 @@ tensorflow::ImportNumpy(); } Py_DECREF(o); + int64 num_replicas; + if (!GetIntAttr($input, "num_replicas", &num_replicas)) { + SWIG_fail; + } + build_options.set_num_replicas(num_replicas); + $1 = &build_options; } } @@ -934,159 +255,154 @@ tensorflow::ImportNumpy(); %ignoreall %unignore xla; %unignore xla::swig; -%unignore xla::swig::InitializeReplicaCount; -%unignore xla::swig::InitializePlatformName; -%unignore 
xla::swig::GetReplicaCount; %unignore xla::swig::RegisterCpuCustomCallTarget; -%unignore xla::swig::TransferToInfeedLocal; -%unignore xla::swig::TransferToInfeedLocalReplica; -%unignore xla::swig::TransferFromOutfeedLocalReplica; +%unignore xla::swig::LocalClient; +%unignore xla::swig::LocalClient::Get; +%unignore xla::swig::LocalClient::DeviceCount; +%unignore xla::swig::LocalClient::TransferToInfeed; +%unignore xla::swig::LocalClient::TransferFromOutfeed; %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; %unignore xla::swig::LocalShapedBuffer::shape; +%unignore xla::swig::LocalShapedBuffer::DestructureTuple; %unignore xla::swig::LocalShapedBufferTuple; %unignore xla::swig::LocalShapedBufferTuple::Release; %unignore xla::swig::LocalShapedBufferTuple::size; -%unignore xla::swig::XrtAllocation; -%unignore xla::swig::XrtAllocation::FromLiteral; -%unignore xla::swig::XrtAllocation::ToLiteral; -%unignore xla::swig::XrtAllocation::shape; -%unignore xla::swig::XrtAllocationTuple; -%unignore xla::swig::XrtAllocationTuple::Release; -%unignore xla::swig::XrtAllocationTuple::size; -%unignore xla::swig::CompiledLocalComputation; -%unignore xla::swig::CompiledLocalComputation::Execute; -%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; -%unignore xla::swig::CompiledXrtComputation; -%unignore xla::swig::CompiledXrtComputation::Execute; -%unignore xla::swig::LocalComputation; -%unignore xla::swig::LocalComputation::Compile; -%unignore xla::swig::LocalComputation::CompileForXrt; -%unignore xla::swig::LocalComputation::GetReturnValueShape; -%unignore xla::swig::LocalComputation::GetSerializedProto; +%unignore xla::swig::LocalExecutable; +%unignore xla::swig::LocalExecutable::DeviceOrdinals; +%unignore xla::swig::LocalExecutable::Execute; +%unignore xla::swig::LocalExecutable::ExecutePerReplica; +%unignore xla::swig::Computation; +%unignore xla::swig::Computation::Compile; 
+%unignore xla::swig::Computation::GetProgramShape; +%unignore xla::swig::Computation::GetReturnValueShape; +%unignore xla::swig::Computation::GetSerializedProto; +%unignore xla::swig::Computation::GetHloText; +%unignore xla::swig::Computation::GetHloDotGraph; %unignore xla::swig::LocalOp; -%unignore xla::swig::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::Build; -%unignore xla::swig::LocalComputationBuilder::BuildWithRoot; -%unignore xla::swig::LocalComputationBuilder::SetOpMetadata; -%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; -%unignore xla::swig::LocalComputationBuilder::Parameter; -%unignore xla::swig::LocalComputationBuilder::GetShape; -%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; -%unignore xla::swig::LocalComputationBuilder::Infeed; -%unignore xla::swig::LocalComputationBuilder::Outfeed; -%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; -%unignore xla::swig::LocalComputationBuilder::ConstantR0; -%unignore xla::swig::LocalComputationBuilder::Iota; -%unignore xla::swig::LocalComputationBuilder::BroadcastedIota; -%unignore xla::swig::LocalComputationBuilder::Broadcast; -%unignore xla::swig::LocalComputationBuilder::BroadcastInDim; -%unignore xla::swig::LocalComputationBuilder::Pad; -%unignore xla::swig::LocalComputationBuilder::Reshape; -%unignore xla::swig::LocalComputationBuilder::Collapse; -%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; -%unignore xla::swig::LocalComputationBuilder::Slice; -%unignore xla::swig::LocalComputationBuilder::SliceInDim; -%unignore xla::swig::LocalComputationBuilder::DynamicSlice; -%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; -%unignore xla::swig::LocalComputationBuilder::ConcatInDim; -%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::Select; -%unignore 
xla::swig::LocalComputationBuilder::Tuple; -%unignore xla::swig::LocalComputationBuilder::GetTupleElement; -%unignore xla::swig::LocalComputationBuilder::ConvertElementType; -%unignore xla::swig::LocalComputationBuilder::BitcastConvertType; -%unignore xla::swig::LocalComputationBuilder::Call; -%unignore xla::swig::LocalComputationBuilder::Transpose; -%unignore xla::swig::LocalComputationBuilder::Rev; -%unignore xla::swig::LocalComputationBuilder::Clamp; -%unignore xla::swig::LocalComputationBuilder::Map; -%unignore xla::swig::LocalComputationBuilder::Reduce; -%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::RngNormal; -%unignore xla::swig::LocalComputationBuilder::RngUniform; -%unignore xla::swig::LocalComputationBuilder::RngBernoulli; -%unignore xla::swig::LocalComputationBuilder::While; -%unignore xla::swig::LocalComputationBuilder::Conditional; -%unignore xla::swig::LocalComputationBuilder::IsConstant; -%unignore xla::swig::LocalComputationBuilder::Eq; -%unignore xla::swig::LocalComputationBuilder::Ne; -%unignore xla::swig::LocalComputationBuilder::Ge; -%unignore xla::swig::LocalComputationBuilder::Gt; -%unignore xla::swig::LocalComputationBuilder::Lt; -%unignore xla::swig::LocalComputationBuilder::Le; -%unignore xla::swig::LocalComputationBuilder::Dot; -%unignore xla::swig::LocalComputationBuilder::DotGeneral; -%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; -%unignore xla::swig::LocalComputationBuilder::Add; -%unignore xla::swig::LocalComputationBuilder::Sub; -%unignore xla::swig::LocalComputationBuilder::Mul; -%unignore xla::swig::LocalComputationBuilder::Div; -%unignore xla::swig::LocalComputationBuilder::Rem; -%unignore xla::swig::LocalComputationBuilder::Max; -%unignore xla::swig::LocalComputationBuilder::Min; -%unignore xla::swig::LocalComputationBuilder::And; -%unignore xla::swig::LocalComputationBuilder::Or; -%unignore xla::swig::LocalComputationBuilder::Xor; 
-%unignore xla::swig::LocalComputationBuilder::ShiftLeft; -%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic; -%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical; -%unignore xla::swig::LocalComputationBuilder::Not; -%unignore xla::swig::LocalComputationBuilder::Abs; -%unignore xla::swig::LocalComputationBuilder::Exp; -%unignore xla::swig::LocalComputationBuilder::Expm1; -%unignore xla::swig::LocalComputationBuilder::Floor; -%unignore xla::swig::LocalComputationBuilder::Ceil; -%unignore xla::swig::LocalComputationBuilder::Round; -%unignore xla::swig::LocalComputationBuilder::Log; -%unignore xla::swig::LocalComputationBuilder::Log1p; -%unignore xla::swig::LocalComputationBuilder::Sign; -%unignore xla::swig::LocalComputationBuilder::Cos; -%unignore xla::swig::LocalComputationBuilder::Sin; -%unignore xla::swig::LocalComputationBuilder::Tanh; -%unignore xla::swig::LocalComputationBuilder::Atan2; -%unignore xla::swig::LocalComputationBuilder::IsFinite; -%unignore xla::swig::LocalComputationBuilder::Pow; -%unignore xla::swig::LocalComputationBuilder::Neg; -%unignore xla::swig::LocalComputationBuilder::Sort; -%unignore xla::swig::LocalComputationBuilder::SortKeyVal; -%unignore xla::swig::LocalComputationBuilder::Sqrt; -%unignore xla::swig::LocalComputationBuilder::Rsqrt; -%unignore xla::swig::LocalComputationBuilder::Square; -%unignore xla::swig::LocalComputationBuilder::Reciprocal; -%unignore xla::swig::LocalComputationBuilder::Erfc; -%unignore xla::swig::LocalComputationBuilder::Erf; -%unignore xla::swig::LocalComputationBuilder::ErfInv; -%unignore xla::swig::LocalComputationBuilder::Lgamma; -%unignore xla::swig::LocalComputationBuilder::Digamma; -%unignore xla::swig::LocalComputationBuilder::Acos; -%unignore xla::swig::LocalComputationBuilder::Asin; -%unignore xla::swig::LocalComputationBuilder::Atan; -%unignore xla::swig::LocalComputationBuilder::Tan; -%unignore xla::swig::LocalComputationBuilder::Acosh; -%unignore 
xla::swig::LocalComputationBuilder::Asinh; -%unignore xla::swig::LocalComputationBuilder::Atanh; -%unignore xla::swig::LocalComputationBuilder::Cosh; -%unignore xla::swig::LocalComputationBuilder::Sinh; -%unignore xla::swig::LocalComputationBuilder::Real; -%unignore xla::swig::LocalComputationBuilder::Imag; -%unignore xla::swig::LocalComputationBuilder::Conj; -%unignore xla::swig::LocalComputationBuilder::Complex; -%unignore xla::swig::LocalComputationBuilder::Cholesky; -%unignore xla::swig::LocalComputationBuilder::QR; -%unignore xla::swig::LocalComputationBuilder::TriangularSolve; -%unignore xla::swig::LocalComputationBuilder::CustomCall; -%unignore xla::swig::LocalComputationBuilder::Gather; -%unignore xla::swig::LocalComputationBuilder::Scatter; -%unignore xla::swig::DeleteLocalComputation; -%unignore xla::swig::DestructureLocalShapedBufferTuple; -%unignore xla::swig::DestructureXrtAllocationTuple; +%unignore xla::swig::ComputationBuilder; +%unignore xla::swig::ComputationBuilder::ComputationBuilder; +%unignore xla::swig::ComputationBuilder::Build; +%unignore xla::swig::ComputationBuilder::BuildWithRoot; +%unignore xla::swig::ComputationBuilder::SetOpMetadata; +%unignore xla::swig::ComputationBuilder::ClearOpMetadata; +%unignore xla::swig::ComputationBuilder::Parameter; +%unignore xla::swig::ComputationBuilder::GetShape; +%unignore xla::swig::ComputationBuilder::GetReturnValueShape; +%unignore xla::swig::ComputationBuilder::ReplicaId; +%unignore xla::swig::ComputationBuilder::Infeed; +%unignore xla::swig::ComputationBuilder::Outfeed; +%unignore xla::swig::ComputationBuilder::ConstantLiteral; +%unignore xla::swig::ComputationBuilder::ConstantR0; +%unignore xla::swig::ComputationBuilder::Iota; +%unignore xla::swig::ComputationBuilder::BroadcastedIota; +%unignore xla::swig::ComputationBuilder::Broadcast; +%unignore xla::swig::ComputationBuilder::BroadcastInDim; +%unignore xla::swig::ComputationBuilder::Pad; +%unignore xla::swig::ComputationBuilder::Reshape; 
+%unignore xla::swig::ComputationBuilder::Collapse; +%unignore xla::swig::ComputationBuilder::AllToAll; +%unignore xla::swig::ComputationBuilder::CrossReplicaSum; +%unignore xla::swig::ComputationBuilder::Slice; +%unignore xla::swig::ComputationBuilder::SliceInDim; +%unignore xla::swig::ComputationBuilder::DynamicSlice; +%unignore xla::swig::ComputationBuilder::DynamicUpdateSlice; +%unignore xla::swig::ComputationBuilder::ConcatInDim; +%unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::Select; +%unignore xla::swig::ComputationBuilder::Tuple; +%unignore xla::swig::ComputationBuilder::GetTupleElement; +%unignore xla::swig::ComputationBuilder::ConvertElementType; +%unignore xla::swig::ComputationBuilder::BitcastConvertType; +%unignore xla::swig::ComputationBuilder::Call; +%unignore xla::swig::ComputationBuilder::Transpose; +%unignore xla::swig::ComputationBuilder::Rev; +%unignore xla::swig::ComputationBuilder::Clamp; +%unignore xla::swig::ComputationBuilder::Map; +%unignore xla::swig::ComputationBuilder::Reduce; +%unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::RngNormal; +%unignore xla::swig::ComputationBuilder::RngUniform; +%unignore xla::swig::ComputationBuilder::RngBernoulli; +%unignore xla::swig::ComputationBuilder::While; +%unignore xla::swig::ComputationBuilder::Conditional; +%unignore xla::swig::ComputationBuilder::IsConstant; +%unignore xla::swig::ComputationBuilder::Eq; +%unignore xla::swig::ComputationBuilder::Ne; +%unignore xla::swig::ComputationBuilder::Ge; +%unignore xla::swig::ComputationBuilder::Gt; +%unignore xla::swig::ComputationBuilder::Lt; +%unignore xla::swig::ComputationBuilder::Le; +%unignore xla::swig::ComputationBuilder::Dot; +%unignore xla::swig::ComputationBuilder::DotGeneral; +%unignore xla::swig::ComputationBuilder::ConvGeneralDilated; +%unignore xla::swig::ComputationBuilder::Add; +%unignore 
xla::swig::ComputationBuilder::Sub; +%unignore xla::swig::ComputationBuilder::Mul; +%unignore xla::swig::ComputationBuilder::Div; +%unignore xla::swig::ComputationBuilder::Rem; +%unignore xla::swig::ComputationBuilder::Max; +%unignore xla::swig::ComputationBuilder::Min; +%unignore xla::swig::ComputationBuilder::And; +%unignore xla::swig::ComputationBuilder::Or; +%unignore xla::swig::ComputationBuilder::Xor; +%unignore xla::swig::ComputationBuilder::ShiftLeft; +%unignore xla::swig::ComputationBuilder::ShiftRightArithmetic; +%unignore xla::swig::ComputationBuilder::ShiftRightLogical; +%unignore xla::swig::ComputationBuilder::Not; +%unignore xla::swig::ComputationBuilder::Clz; +%unignore xla::swig::ComputationBuilder::Abs; +%unignore xla::swig::ComputationBuilder::Exp; +%unignore xla::swig::ComputationBuilder::Expm1; +%unignore xla::swig::ComputationBuilder::Floor; +%unignore xla::swig::ComputationBuilder::Ceil; +%unignore xla::swig::ComputationBuilder::Round; +%unignore xla::swig::ComputationBuilder::Log; +%unignore xla::swig::ComputationBuilder::Log1p; +%unignore xla::swig::ComputationBuilder::Sign; +%unignore xla::swig::ComputationBuilder::Cos; +%unignore xla::swig::ComputationBuilder::Sin; +%unignore xla::swig::ComputationBuilder::Tanh; +%unignore xla::swig::ComputationBuilder::Atan2; +%unignore xla::swig::ComputationBuilder::IsFinite; +%unignore xla::swig::ComputationBuilder::Pow; +%unignore xla::swig::ComputationBuilder::Neg; +%unignore xla::swig::ComputationBuilder::Sort; +%unignore xla::swig::ComputationBuilder::SortKeyVal; +%unignore xla::swig::ComputationBuilder::Sqrt; +%unignore xla::swig::ComputationBuilder::Rsqrt; +%unignore xla::swig::ComputationBuilder::Square; +%unignore xla::swig::ComputationBuilder::Reciprocal; +%unignore xla::swig::ComputationBuilder::Erfc; +%unignore xla::swig::ComputationBuilder::Erf; +%unignore xla::swig::ComputationBuilder::ErfInv; +%unignore xla::swig::ComputationBuilder::Lgamma; +%unignore 
xla::swig::ComputationBuilder::Digamma; +%unignore xla::swig::ComputationBuilder::Acos; +%unignore xla::swig::ComputationBuilder::Asin; +%unignore xla::swig::ComputationBuilder::Atan; +%unignore xla::swig::ComputationBuilder::Tan; +%unignore xla::swig::ComputationBuilder::Acosh; +%unignore xla::swig::ComputationBuilder::Asinh; +%unignore xla::swig::ComputationBuilder::Atanh; +%unignore xla::swig::ComputationBuilder::Cosh; +%unignore xla::swig::ComputationBuilder::Sinh; +%unignore xla::swig::ComputationBuilder::Real; +%unignore xla::swig::ComputationBuilder::Imag; +%unignore xla::swig::ComputationBuilder::Conj; +%unignore xla::swig::ComputationBuilder::Complex; +%unignore xla::swig::ComputationBuilder::Cholesky; +%unignore xla::swig::ComputationBuilder::QR; +%unignore xla::swig::ComputationBuilder::Eigh; +%unignore xla::swig::ComputationBuilder::SVD; +%unignore xla::swig::ComputationBuilder::TriangularSolve; +%unignore xla::swig::ComputationBuilder::CustomCall; +%unignore xla::swig::ComputationBuilder::Gather; +%unignore xla::swig::ComputationBuilder::Scatter; +%unignore xla::swig::DeleteComputation; %unignore xla::swig::DeleteLocalShapedBuffer; -%unignore xla::swig::DeleteXrtAllocation; -%unignore xla::swig::DeleteCompiledLocalComputation; -%unignore xla::swig::DeleteCompiledXrtComputation; +%unignore xla::swig::DeleteLocalExecutable; %thread; %include "tensorflow/compiler/xla/python/local_computation_builder.h" diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 52c5c621f7294c5da341879d15b77559fe870551..de7b1e48a8dbb8f2cdb1709258a218f4a07e7688 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -26,6 +26,10 @@ namespace swig { namespace numpy { +Safe_PyObjectPtr make_safe(PyObject* object) { + return Safe_PyObjectPtr(object); +} + int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { switch (primitive_type) { case PRED: @@ -123,28 
+127,42 @@ bool NumpyTypeIsValid(int np_type) { } } -PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { +Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) { int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); - PyObject* dimensions; + Safe_PyObjectPtr dimensions; if (shape.IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(shape); - dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); + dimensions = make_safe(PyTuple_New(ShapeUtil::TupleElementCount(shape))); for (int i = 0; i < num_elements; ++i) { PyTuple_SET_ITEM( - dimensions, i, - PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); + dimensions.get(), i, + PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)) + .release()); } } else { int rank = shape.rank(); - dimensions = PyTuple_New(rank); + dimensions = make_safe(PyTuple_New(rank)); for (int i = 0; i < rank; ++i) { - PyTuple_SET_ITEM(dimensions, i, + PyTuple_SET_ITEM(dimensions.get(), i, LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); } } - return PyTuple_Pack(2, np_dtype, dimensions); + return make_safe(PyTuple_Pack(2, np_dtype, dimensions.release())); +} + +Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( + const ProgramShape& shape) { + Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size())); + for (int i = 0; i < shape.parameters_size(); ++i) { + PyTuple_SET_ITEM(arg_shapes.get(), i, + PyShapeInfoFromXlaShape(shape.parameters(i)).release()); + } + + Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result()); + return make_safe( + PyTuple_Pack(2, arg_shapes.release(), result_shape.release())); } // Precondition: o->ob_type == &PyArrayDescr_Type @@ -349,13 +367,17 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { +StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal) { if 
(literal.shape().IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); - PyObject* tuple = PyTuple_New(num_elements); + std::vector elems(num_elements); + for (int i = 0; i < num_elements; i++) { + TF_ASSIGN_OR_RETURN(elems[i], + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); + } + Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements)); for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM(tuple, i, - PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); + PyTuple_SET_ITEM(tuple.get(), i, elems[i].release()); } return tuple; } else { @@ -365,10 +387,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); } int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); - PyObject* array = - PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0); - CopyLiteralToNumpyArray(np_type, literal, - reinterpret_cast(array)); + Safe_PyObjectPtr array = make_safe( + PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0)); + TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray( + np_type, literal, reinterpret_cast(array.get()))); return array; } } @@ -408,6 +430,12 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_BOOL: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_INT8: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_INT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; case NPY_INT32: CopyNumpyArrayToLiteral(py_array, literal); break; @@ -417,6 +445,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_UINT8: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_UINT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; case NPY_UINT32: CopyNumpyArrayToLiteral(py_array, literal); break; @@ -445,12 +476,18 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void CopyLiteralToNumpyArray(int 
np_type, const LiteralSlice& literal, - PyArrayObject* py_array) { +Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, + PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_INT8: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT16: + CopyLiteralToNumpyArray(literal, py_array); + break; case NPY_INT32: CopyLiteralToNumpyArray(literal, py_array); break; @@ -460,6 +497,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_UINT8: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_UINT16: + CopyLiteralToNumpyArray(literal, py_array); + break; case NPY_UINT32: CopyLiteralToNumpyArray(literal, py_array); break; @@ -482,8 +522,10 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, CopyLiteralToNumpyArray(literal, py_array); break; default: - LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; + return InvalidArgument( + "No XLA literal container for Numpy type number: %d", np_type); } + return Status::OK(); } PyObject* LongToPyIntOrPyLong(long x) { // NOLINT @@ -525,6 +567,118 @@ PyObject* PyNumberToPyInt(PyObject* o) { } // namespace numpy +bool GetIntAttr(PyObject* o, const char* field, int64* result) { + PyObject* fo = PyObject_GetAttrString(o, field); + if (!fo) { + return false; + } + const int64 value = numpy::PyIntOrPyLongToLong(fo); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(fo); + return false; + } + Py_DECREF(fo); + *result = value; + return true; +} + +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, const char* attr_name, + std::function f) { + if (!PyObject_HasAttrString(o, attr_name)) { + return true; // It's ok for the object to not have the attribute. + } + PyObject* attr = PyObject_GetAttrString(o, attr_name); + if (attr == nullptr) { + return false; // An error occurred getting the attribute. 
+ } + if (attr == Py_None) { + Py_DECREF(attr); + return true; // The attribute is None, which we consider ok. + } +#if PY_MAJOR_VERSION < 3 + if (!PyString_Check(attr)) { + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyString_AsString(attr)); +#else + if (!PyBytes_Check(attr)) { + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyBytes_AsString(attr)); +#endif + + Py_DECREF(attr); + return true; // Handled string attribute, ok! +} + +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleBoolAttribute(PyObject* o, const char* attr_name, + std::function f) { + if (!PyObject_HasAttrString(o, attr_name)) { + return true; // It's ok for the object to not have the attribute. + } + PyObject* attr = PyObject_GetAttrString(o, attr_name); + if (attr == nullptr) { + return false; // An error occurred getting the attribute. + } + if (attr == Py_None) { + Py_DECREF(attr); + return true; // The attribute is None, which we consider ok. + } + if (!PyBool_Check(attr)) { + string message = absl::StrFormat("%s must be a boolean or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyObject_IsTrue(attr)); + Py_DECREF(attr); + return true; // Handled boolean attribute, ok! 
+} + +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field) { + PyObject* seq = PyObject_GetAttrString(o, attr_name); + if (!seq) { + return false; + } + + int length = PySequence_Size(seq); + if (length == -1) { + Py_DECREF(seq); + return false; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(seq, i); + if (!item) { + Py_DECREF(seq); + return false; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(seq); + return false; + } + *field->Add() = dimension; + Py_DECREF(item); + } + Py_DECREF(seq); + return true; +} + } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 40ff2d9ad214cc4dcad42234fa296834cbc92882..d7a611d7fd5c708b313db04cce8e05f1a72c5e47 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -36,6 +36,16 @@ namespace swig { namespace numpy { +struct PyDecrefDeleter { + void operator()(PyObject* p) const { Py_DECREF(p); } +}; + +// Safe container for an owned PyObject. On destruction, the reference count of +// the contained object will be decremented. +using Safe_PyObjectPtr = std::unique_ptr; + +Safe_PyObjectPtr make_safe(PyObject* object); + // Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy // dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and // vice versa. @@ -54,7 +64,13 @@ bool NumpyTypeIsValid(int np_type); // providing the array dimensions. // // The return value is a new reference. -PyObject* PyShapeInfoFromXlaShape(const Shape& shape); +Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape); + +// Returns a pair of (arg_shapes, result_shape), where arg_shapes is a tuple +// of argument shapes and result_shape is the result shape. 
Each shape is as +// described in in PyShapeInfoFromXlaShape's comment. +Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( + const ProgramShape& shape); // Converts a Python object with a method interface mathing that of // xla_client.Shape into an XLA Shape object. @@ -74,7 +90,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. -PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); +StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. @@ -90,8 +106,8 @@ StatusOr XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, - PyArrayObject* py_array); +Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, + PyArrayObject* py_array); template void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { @@ -120,6 +136,20 @@ PyObject* PyNumberToPyInt(PyObject* o); } // namespace numpy +// Miscellaneous swig helpers that don't have a better home. + +bool GetIntAttr(PyObject* o, const char* field, int64* result); + +// Returns "ok"; true if there is no error, false if there was an error. 
+bool HandleStringAttribute(PyObject* o, const char* attr_name, + std::function f); +bool HandleBoolAttribute(PyObject* o, const char* attr_name, + std::function f); + +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field); + } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds new file mode 100644 index 0000000000000000000000000000000000000000..ef77ed3d95850fdfc7145e6fe1df4833d20bb7df --- /dev/null +++ b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds @@ -0,0 +1,2 @@ +_PyInit__pywrap_xla +_init_pywrap_xla diff --git a/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds b/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..d31cfce7be7b6accf05ef77f3485904099965afc --- /dev/null +++ b/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds @@ -0,0 +1,6 @@ +xla { + global: + PyInit_*; + local: + *; +}; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 1684cb20e6d2de8bf4a2545092f82cc42394790a..cb7d19d39b1d50a5bec564b59779c6f93018ed81 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""An in-process, local XLA client in Python, supporting AOT compilation.""" +"""An XLA client in Python, supporting AOT compilation.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import enum # pylint: disable=g-bad-import-order import inspect @@ -29,17 +30,39 @@ import numpy as np import six from six.moves import xrange -from tensorflow.compiler.xla import xla_data_pb2 +# Note this module does *not* depend on any Python protocol buffers. The XLA +# Python bindings are currently packaged both as part of jaxlib and as part +# of TensorFlow. If we use protocol buffers here, then importing both jaxlib +# and TensorFlow may fail with duplicate protocol buffer message definitions. + from tensorflow.compiler.xla.python import pywrap_xla as c_api -from tensorflow.compiler.xla.service import hlo_pb2 + +# Import the XRT backend, if available. +try: + # pylint: disable=g-import-not-at-top + from tensorflow.compiler.xla.python import pywrap_xrt as xrt_api +except ImportError: + xrt_api = None # Most functions are snake_case for consistency with other modules, whereas -# method names of ComputationBuilder and LocalComputation are CamelCase for +# method names of ComputationBuilder and Computation are CamelCase for # consistency with XLA. # pylint: disable=invalid-name +# Version of the XLA Python client. +# +# JAX packages the XLA python plugin as a binary pip module (jaxlib) that is +# packaged separately from the Python code that consumes it (jax). +# +# We occasionally need to make backwards-incompatible changes to jaxlib, in +# which case we need to be able to detect when incompatible versions are +# installed. 
+def version(): + return (0, 1, 8) + + _OP_METADATA_FIELDS = [ 'op_type', 'op_name', @@ -49,22 +72,163 @@ _OP_METADATA_FIELDS = [ OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) +@six.add_metaclass(abc.ABCMeta) +class Backend(object): + """Abstract base class for XLA backends.""" + + @abc.abstractmethod + def device_count(self): + """Returns the number of devices known to the backend.""" + + @abc.abstractmethod + def buffer_from_pyval(self, pyval, device=0): + """Allocates a fresh buffer and populates it with `pyval`.""" + + @abc.abstractmethod + def delete_buffer(self, c_buffer): + """Deletes buffer `c_buffer`.""" + + @abc.abstractmethod + def destructure_tuple(self, c_buffer): + """Destructures a tuple buffer into a sequence of buffers.""" + + @abc.abstractmethod + def compile(self, computation, argument_shapes, result_shape, + compile_options): + """Compiles a computation. Returns an executable.""" + + @abc.abstractmethod + def delete_executable(self, executable): + """Deletes an executable.""" + + @abc.abstractmethod + def execute(self, executable, args): + """Runs an executable without replication.""" + + @abc.abstractmethod + def execute_replicated(self, executable, per_replica_args): + """Runs an executable in a replicated manner.""" + + +def _maybe_encode_string(s): + if six.PY3: + return s.encode('utf-8') + else: + return s + + +class XlaLocalBackend(Backend): + """XLA backend implemented using the in-process xla::LocalClient API.""" + + def __init__(self, platform=None): + platform = platform or _get_default_platform_name() + self.client = c_api.LocalClient.Get(_maybe_encode_string(platform)) + self._delete_buffer = c_api.DeleteLocalShapedBuffer + self._delete_executable = c_api.DeleteLocalExecutable + + def device_count(self): + return self.client.DeviceCount() + + def buffer_from_pyval(self, pyval, device=0): + return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device) + + def delete_buffer(self, c_buffer): + 
self._delete_buffer(c_buffer) + + def destructure_tuple(self, c_buffer): + result = c_buffer.DestructureTuple() + return [result.Release(i) for i in xrange(result.size())] + + def compile(self, c_computation, argument_shapes, result_shape, + compile_options): + return c_computation.Compile(argument_shapes, compile_options, self.client) + + def delete_executable(self, executable): + self._delete_executable(executable) + + def execute(self, executable, args): + return executable.Execute(args) + + def execute_replicated(self, executable, per_replica_args): + output_buffer_tup = executable.ExecutePerReplica(per_replica_args) + size = output_buffer_tup.size() + return [output_buffer_tup.Release(i) for i in xrange(size)] + + +class XrtBackend(Backend): + """XLA backend implemented using XRT.""" + + def __init__(self, target): + self.target = target + self._delete_buffer = xrt_api.DeleteXrtAllocation + self._delete_executable = xrt_api.DeleteXrtExecutable + + def device_count(self): + return 1 # Multidevice execution not implemented. 
+ + def buffer_from_pyval(self, pyval, device=0): + if device != 0: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + return xrt_api.XrtAllocation.FromLiteral(pyval, + _maybe_encode_string(self.target)) + + def delete_buffer(self, c_buffer): + self._delete_buffer(c_buffer) + + def destructure_tuple(self, c_buffer): + result = xrt_api.DestructureXrtAllocationTuple( + c_buffer, _maybe_encode_string(self.target)) + return [result.Release(i) for i in xrange(result.size())] + + def compile(self, c_computation, argument_shapes, result_shape, + compile_options): + return xrt_api.XrtExecutable.CompileForXrt( + c_computation.GetSerializedProto(), argument_shapes, result_shape, + _maybe_encode_string(self.target)) + + def delete_executable(self, executable): + self._delete_executable(executable) + + def execute(self, executable, args): + return executable.Execute(args) + + def execute_replicated(self, executable, per_replica_args): + if len(per_replica_args) != 1: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + return [executable.Execute(per_replica_args[0])] + + +_default_platform_name = 'Host' +_default_backend = None + + +def _get_default_platform_name(): + return _default_platform_name + + +def _get_default_local_backend(): + global _default_backend + global _default_platform_name + if _default_backend is None: + _default_backend = XlaLocalBackend(_default_platform_name) + return _default_backend + + class BackendType(enum.Enum): XLA_LOCAL = 1 XRT = 2 -BackendSpec = collections.namedtuple('Backend', ('backend_type', 'target')) -XLA_LOCAL_BACKEND = BackendSpec(BackendType.XLA_LOCAL, 'local') - - -def OpMetadataToProto(pyobj): - proto = xla_data_pb2.OpMetadata() - for field in _OP_METADATA_FIELDS: - attr = getattr(pyobj, field) - if attr is not None: - setattr(proto, field, attr) - return proto +def BackendSpec(backend, target): + """Compatibility wrapper to support 
older clients. Do not use in new code.""" + if backend == BackendType.XLA_LOCAL: + return _get_default_local_backend() + elif backend == BackendType.XRT: + return XrtBackend(target) + else: + raise ValueError('Unknown backend {}'.format(backend)) def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): @@ -78,13 +242,6 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): source_line=lineno) -def _maybe_encode_string(s): - if six.PY3: - return s.encode('utf-8') - else: - return s - - class PaddingType(enum.Enum): VALID = 1 SAME = 2 @@ -113,8 +270,7 @@ def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0) for out_size, stride, filter_size, in_size in zip(out_shape, window_strides, rhs_dims, lhs_dims)] - return [(pad_size // 2, pad_size - pad_size // 2) - for pad_size in pad_sizes] + return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] else: msg = 'Unexpected PaddingType value: {}' raise ValueError(msg.format(padding_type)) @@ -122,6 +278,7 @@ def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, _UNARY_OPS = [ 'Not', + 'Clz', 'Abs', 'Exp', 'Expm1', @@ -185,29 +342,56 @@ _BINARY_OPS = [ ] +class PrimitiveType(enum.IntEnum): + """Python copy of the XLA PrimitiveType enum. + + Must match the corresponding protocol buffer. 
+ """ + PRIMITIVE_TYPE_INVALID = 0 + PRED = 1 + S8 = 2 + S16 = 3 + S32 = 4 + S64 = 5 + U8 = 6 + U16 = 7 + U32 = 8 + U64 = 9 + BF16 = 16 + F16 = 10 + F32 = 11 + F64 = 12 + C64 = 15 + C128 = 18 + TUPLE = 13 + OPAQUE = 14 + TOKEN = 17 + + XLA_ELEMENT_TYPE_TO_DTYPE = { - xla_data_pb2.PRED: np.dtype('bool'), - xla_data_pb2.S8: np.dtype('int8'), - xla_data_pb2.S16: np.dtype('int16'), - xla_data_pb2.S32: np.dtype('int32'), - xla_data_pb2.S64: np.dtype('int64'), - xla_data_pb2.U8: np.dtype('uint8'), - xla_data_pb2.U16: np.dtype('uint16'), - xla_data_pb2.U32: np.dtype('uint32'), - xla_data_pb2.U64: np.dtype('uint64'), - xla_data_pb2.F16: np.dtype('float16'), - xla_data_pb2.F32: np.dtype('float32'), - xla_data_pb2.F64: np.dtype('float64'), - xla_data_pb2.C64: np.dtype('complex64'), - xla_data_pb2.C128: np.dtype('complex128'), - xla_data_pb2.TUPLE: np.dtype(np.object), + PrimitiveType.PRED: np.dtype('bool'), + PrimitiveType.S8: np.dtype('int8'), + PrimitiveType.S16: np.dtype('int16'), + PrimitiveType.S32: np.dtype('int32'), + PrimitiveType.S64: np.dtype('int64'), + PrimitiveType.U8: np.dtype('uint8'), + PrimitiveType.U16: np.dtype('uint16'), + PrimitiveType.U32: np.dtype('uint32'), + PrimitiveType.U64: np.dtype('uint64'), + PrimitiveType.F16: np.dtype('float16'), + PrimitiveType.F32: np.dtype('float32'), + PrimitiveType.F64: np.dtype('float64'), + PrimitiveType.C64: np.dtype('complex64'), + PrimitiveType.C128: np.dtype('complex128'), + PrimitiveType.TUPLE: np.dtype(np.object), } # Note the conversion on the key. Numpy has a known issue wherein dtype hashing # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, # when keying by dtype in this dict, we use the string form of dtypes. 
-DTYPE_TO_XLA_ELEMENT_TYPE = {str(dt): et - for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()} +DTYPE_TO_XLA_ELEMENT_TYPE = { + str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items() +} def dtype_to_etype(dtype): @@ -223,33 +407,18 @@ class LocalBuffer(object): means the referent is in device memory. """ - def __init__(self, c_buffer, backend, replica): + def __init__(self, c_buffer, backend, device): self.c_buffer = c_buffer self._backend = backend - self._replica = replica - if backend.backend_type == BackendType.XRT: - self._delete = c_api.DeleteXrtAllocation - else: - self._delete = c_api.DeleteLocalShapedBuffer + self._device = device @staticmethod - def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): + def from_pyval(pyval, device=0, backend=None): """Allocate and copy to XLA the given python value.""" + backend = backend or _get_default_local_backend() pyval = require_numpy_array_layout(pyval) - num_replicas = get_replica_count() - if not 0 <= replica < num_replicas: - raise ValueError( - 'Attempt to place buffer on replica {} when the replica count is {}' - .format(replica, num_replicas)) - if backend.backend_type == BackendType.XRT: - if replica != 0: - raise NotImplementedError( - 'Multi-replica execution is not yet supported via the XRT backend.') - cbuf = c_api.XrtAllocation.FromLiteral( - pyval, _maybe_encode_string(backend.target)) - else: - cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None, replica) - return LocalBuffer(cbuf, backend, replica) + cbuf = backend.buffer_from_pyval(pyval, device) + return LocalBuffer(cbuf, backend, device) def to_py(self): return self.c_buffer.ToLiteral() @@ -257,29 +426,24 @@ class LocalBuffer(object): def shape(self): return _wrap_shape(self.c_buffer.shape()) - def replica(self): - return self._replica + def device(self): + return self._device def delete(self): if self.c_buffer is not None: - self._delete(self.c_buffer) + # Python may have freed c_api first. 
+ if c_api: + self._backend.delete_buffer(self.c_buffer) self.c_buffer = None def destructure(self): """Assuming a tuple buffer, unpack it into constituent tuple elements.""" assert self.c_buffer is not None - if self._backend.backend_type == BackendType.XRT: - result = c_api.DestructureXrtAllocationTuple( - self.c_buffer, _maybe_encode_string(self._backend.target)) - else: - result = c_api.DestructureLocalShapedBufferTuple(self.c_buffer) + result = self._backend.destructure_tuple(self.c_buffer) self.delete() - size = result.size() - destructured = tuple( - LocalBuffer( - result.Release(i), replica=self._replica, backend=self._backend) - for i in xrange(size)) - return destructured + return tuple( + LocalBuffer(sub_buffer, device=self._device, backend=self._backend) + for sub_buffer in result) def is_deleted(self): return self.c_buffer is None @@ -288,6 +452,13 @@ class LocalBuffer(object): self.delete() +class Format(enum.IntEnum): + """Python copy of the Format protocol buffer enum.""" + INVALID_FORMAT = 0 + DENSE = 1 + SPARSE = 2 + + class Shape(object): """Represents an XLA shape. @@ -317,8 +488,8 @@ class Shape(object): if (not isinstance(dimensions, tuple) or not all(isinstance(i, int) for i in dimensions)): dimensions = tuple(int(i) for i in dimensions) - return Shape(dimensions, np.dtype(element_type), - minor_to_major=minor_to_major) + return Shape( + dimensions, np.dtype(element_type), minor_to_major=minor_to_major) @staticmethod def from_pyval(pyval): @@ -397,8 +568,8 @@ class Shape(object): """Map f over each leaf-level array subshape. Args: - f: The function to apply. Whenever f returns None, the identity is - applied instead. + f: The function to apply. Whenever f returns None, the identity is applied + instead. Returns: A new Shape with the mapped leaves. 
@@ -423,22 +594,56 @@ class Shape(object): raise ValueError('not an array shape') if not isinstance(minor_to_major, tuple): raise TypeError('minor_to_major must be a tuple') - updated = Shape.array_shape( - self.element_type(), self.dimensions(), minor_to_major) + updated = Shape.array_shape(self.element_type(), self.dimensions(), + minor_to_major) updated._check_minor_to_major() # pylint: disable=protected-access return updated + def with_major_to_minor_layout_if_absent(self): + """Returns a copy of a shape with missing layouts set to major-to-minor.""" + + def f(a): + if a.minor_to_major(): + return None + return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1))) + + return self.map_leaves(f) + + def serialize(self, proto): + """Serializes 'shape' into proto.""" + if self.is_tuple(): + proto.element_type = PrimitiveType.TUPLE + for shape in self.tuple_shapes(): + shape.serialize(proto.tuple_shapes.add()) + else: + proto.element_type = dtype_to_etype(self.element_type()) + proto.dimensions.extend(self.dimensions()) + proto.is_dynamic_dimension.extend([False for _ in self.dimensions()]) + if self.minor_to_major(): + proto.layout.format = Format.DENSE + proto.layout.minor_to_major.extend(self.minor_to_major()) + + +ProgramShape = collections.namedtuple('ProgramShape', + ('parameter_shapes', 'result_shape')) + def _wrap_shape(shape_info): dtype, dims = shape_info element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] - if element_type == xla_data_pb2.TUPLE: + if element_type == PrimitiveType.TUPLE: shapes = tuple(_wrap_shape(subshape_info) for subshape_info in dims) return Shape.tuple_shape(shapes) else: return Shape.array_shape(dtype, dims) +def _wrap_program_shape(shape_info): + arg_shapes, result_shape = shape_info + return ProgramShape([_wrap_shape(arg) for arg in arg_shapes], + _wrap_shape(result_shape)) + + def require_numpy_array_layout(value): if isinstance(value, tuple): return tuple(require_numpy_array_layout(x) for x in value) @@ -454,14 +659,16 @@ 
class CompileOptions(object): """ def __init__(self): - self.generate_hlo_graph = None - self.dump_optimized_hlo_proto_to = None - self.dump_unoptimized_hlo_proto_to = None - self.dump_per_pass_hlo_proto_to = None - self.hlo_profile = False + self.xla_dump_to = None + self.dump_hlo_pass_re = None + self.dump_hlo_module_re = None + self.dump_hlo_as_text = None + self.dump_hlo_as_proto = None + self.hlo_profile = None + self.num_replicas = get_replica_count() -def transfer_to_infeed(value, replica_number=None): +def transfer_to_infeed(value, device_ordinal=0): """Transfers the given value into the XLA infeed queue. XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with @@ -471,96 +678,92 @@ def transfer_to_infeed(value, replica_number=None): Args: value: the value that the caller would like to enqueue into the XLA infeed queue - replica_number: the replica number to infeed the value to -- if not - provided, then the default replica (trivially replica 0) is used. + device_ordinal: the device to infeed the value to. Each device has a + distinct infeed queue. """ - if replica_number is None: - c_api.TransferToInfeedLocal(require_numpy_array_layout(value)) - else: - c_api.TransferToInfeedLocalReplica( - require_numpy_array_layout(value), replica_number) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + backend.client.TransferToInfeed( + require_numpy_array_layout(value), device_ordinal) -def transfer_from_outfeed(shape, replica_number=None): - """Transfers a literal of the given shape from replica_number's outfeed. +def transfer_from_outfeed(shape, device_ordinal=0): + """Transfers a literal of the given shape from `device_ordinal`'s outfeed. Args: shape: The shape of the value to transfer from outfeed. - replica_number: The replica number ordinal to transfer the outfeed value - from. (Each replica has a distinct outfeed queue.) + device_ordinal: The device ordinal to transfer the outfeed value from. 
Each + device has a distinct outfeed queue. Returns: The literal value that is produced from the outfeed queue. """ - return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + return backend.client.TransferFromOutfeed(shape, device_ordinal) -class LocalComputation(object): - """Python wrapper for a local XLA Computation. +class Computation(object): + """Python wrapper for an XLA Computation. - A LocalComputation can be executed if it is compiled. Otherwise, it - can still be used as a Computation where required by the - ComputationBuilder methods. + A Computation can be compiled to form an Executable, or used as a + subcomputation in ComputationBuilder methods. """ - def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND): + def __init__(self, c_computation, backend=None): self._c_computation = c_computation + # The backend argument is deprecated. Pass a backend to Compile() instead. self._backend = backend - self._is_compiled = is_compiled - - # Ensure a reference to C-based destructor for use in __del__. - if is_compiled: - if backend.backend_type == BackendType.XRT: - assert isinstance(c_computation, c_api.CompiledXrtComputation) - self._delete = c_api.DeleteCompiledXrtComputation - else: - assert isinstance(c_computation, c_api.CompiledLocalComputation) - self._delete = c_api.DeleteCompiledLocalComputation - else: - assert isinstance(c_computation, c_api.LocalComputation) - self._delete = c_api.DeleteLocalComputation + self._delete_computation = c_api.DeleteComputation @property def computation(self): - if self._is_compiled: - raise ValueError( - 'Attempt to read the XLA computation of a compiled LocalComputation.') return self._c_computation - def GetProto(self): - """Get the HloModuleProto proto object in this local computation. + def GetSerializedProto(self): + """Gets the serialized HloModuleProto proto object in this computation.
Returns: - An HloModuleProto proto object that has the whole-graph information. + A string containing a serialized HloModuleProto proto containing the + computation and its dependencies. """ - serialized = self.computation.GetSerializedProto() - proto = hlo_pb2.HloModuleProto.FromString(serialized) - return proto + return self.computation.GetSerializedProto() + + def GetHloText(self): + """Get the textual HLO representation of this computation. - def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): - """Compiles an un-compiled local computation. + Returns: + A string containing the textual HLO. + """ + return self.computation.GetHloText() + + def GetHloDotGraph(self): + """Get a Graphviz Dot representation of this computation. + + Returns: + A string containing the graphviz dot graph. + """ + return self.computation.GetHloDotGraph() - Local computations are the result of a "LocalComputationBuild'ing" process - -- they start in uncompiled form, and via a call to Compile() turn into a - compiled local computation. + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None, + backend=None): + """Compiles a computation. - Raises: - ValueError: if this is already a compiled local computation. + Computations are the result of a "ComputationBuild'ing" process. Arguments: argument_shapes: parameter shapes -- they are first laid out by layout_fn if layout_fn is provided. Otherwise, the default layout for those shapes will be used. - compile_options: options to use for compilation, includes an optional - laid out result shape for the computation. + compile_options: options to use for compilation, includes an optional laid + out result shape for the computation. layout_fn: lambda that is used to lay out the argument/result shapes. + backend: a `Backend` for which an executable should be generated. Returns: - A newly *compiled* local computation instance. + An Executable instance.
""" - if self._is_compiled: - raise ValueError('Attempt to compile a compiled local XLA computation.') - + backend = backend or self._backend or _get_default_local_backend() result_shape = _wrap_shape(self.computation.GetReturnValueShape()) if layout_fn: @@ -573,32 +776,52 @@ class LocalComputation(object): compile_options = compile_options or CompileOptions() compile_options.result_shape = result_shape - if self._backend.backend_type == BackendType.XRT: - c = self.computation.CompileForXrt( - argument_shapes, _maybe_encode_string(self._backend.target)) - else: - c = self.computation.Compile(argument_shapes, compile_options) - return LocalComputation(c, is_compiled=True, backend=self._backend) + c = backend.compile(self.computation, argument_shapes, result_shape, + compile_options) + return Executable(c, backend=backend) def CompileWithExampleArguments(self, arguments=(), compile_options=None, - layout_fn=None): + layout_fn=None, + backend=None): return self.Compile( argument_shapes=[Shape.from_pyval(arg) for arg in arguments], compile_options=compile_options, - layout_fn=layout_fn) + layout_fn=layout_fn, + backend=backend) + + def GetProgramShape(self): + return _wrap_program_shape(self._c_computation.GetProgramShape()) def GetReturnValueShape(self): return _wrap_shape(self._c_computation.GetReturnValueShape()) + def __del__(self): + if self._c_computation: + self._delete_computation(self._c_computation) + + +class Executable(object): + """Python wrapper for an XLA Executable.""" + + def __init__(self, c_executable, backend=None): + self._c_executable = c_executable + self._device_ordinals = c_executable.DeviceOrdinals() + self._backend = backend + + def DeviceOrdinals(self): + """Returns a list containing the device ordinals for each replica.""" + return self._device_ordinals + def Execute(self, arguments=(), check_for_deleted_args=True): """Execute on one replica with LocalBuffer arguments and return value.""" if check_for_deleted_args and any(arg.is_deleted() 
for arg in arguments): raise ValueError('Executing with deleted local buffer argument') raw_args = [arg.c_buffer for arg in arguments] - output_buffer = self._c_computation.Execute(raw_args) - return LocalBuffer(output_buffer, backend=self._backend, replica=0) + output_buffer = self._backend.execute(self._c_executable, raw_args) + return LocalBuffer( + output_buffer, backend=self._backend, device=self._device_ordinals[0]) def ExecutePerReplica(self, arguments=None): """Execute on many replicas with LocalBuffer arguments and return value. @@ -608,14 +831,12 @@ class LocalComputation(object): sequence comprises the arguments for execution on the i'th replica. Returns: - A list of the computation's outputs on each replica, as a LocalBuffer. If + A list of the computation's outputs for each replica, as a LocalBuffer. If a shallow sequence of arguments was passed in for `arguments`, then the sole, zero'th replica's output is returned instead, as a LocalBuffer. """ - if not self._is_compiled: - raise ValueError('Cannot execute an uncompiled local XLA computation.') if arguments is None: - arguments = ((),) * get_replica_count() + arguments = ((),) * len(self._device_ordinals) else: arguments = [list(replica_args) for replica_args in arguments] @@ -624,37 +845,35 @@ class LocalComputation(object): for arg in replica_args: if arg.is_deleted(): raise ValueError('Executing with deleted local buffer argument') - if arg.replica() != replica: + if arg.device() != self._device_ordinals[replica]: raise ValueError( - 'Executing on replica {} with argument from replica {}'.format( - replica, arg.replica())) + 'Executing on device {} with argument from device {}'.format( + self._device_ordinals[replica], arg.device())) # Pull out argument buffer handles + # pylint: disable=g-complex-comprehension stripped_args = [ [arg.c_buffer for arg in replica_args] for replica_args in arguments ] # Execute - if self._backend.backend_type == BackendType.XRT: - if len(stripped_args) > 1: - raise 
NotImplementedError( - 'Multi-replica execution is not yet supported via the XRT backend.') - output_buffers = [self._c_computation.Execute(stripped_args[0])] - else: - output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args) - size = output_buffer_tup.size() - output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)] + output_buffers = self._backend.execute_replicated(self._c_executable, + stripped_args) # Wrap output handles in LocalBuffer instances return tuple( - LocalBuffer(output_buffer, backend=self._backend, replica=replica) + LocalBuffer( + output_buffer, + backend=self._backend, + device=self._device_ordinals[replica]) for replica, output_buffer in enumerate(output_buffers)) def ExecuteWithPythonValues(self, arguments=()): """Execute on one replica with Python values as arguments and output.""" def put(arg): - return LocalBuffer.from_pyval(arg, backend=self._backend) + return LocalBuffer.from_pyval( + arg, device=self._device_ordinals[0], backend=self._backend) arguments = [put(arg) for arg in arguments] return self.Execute(arguments).to_py() @@ -662,24 +881,27 @@ class LocalComputation(object): def ExecuteWithPythonValuesPerReplica(self, arguments): """Execute on many replicas with Python values as arguments and output.""" - def put(arg, replica): - return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + def put(arg, device): + return LocalBuffer.from_pyval(arg, device, backend=self._backend) - arguments = [[put(arg, replica) - for arg in replica_args] - for replica, replica_args in enumerate(arguments)] + # pylint: disable=g-complex-comprehension + arguments = [[ + put(arg, self._device_ordinals[replica]) for arg in replica_args + ] for replica, replica_args in enumerate(arguments)] return [out.to_py() for out in self.ExecutePerReplica(arguments)] def __del__(self): - self._delete(self._c_computation) + # Python may have freed c_api first. 
+ if c_api and self._c_executable: + self._backend.delete_executable(self._c_executable) class ComputationBuilder(object): """XLA computation builder. Enqueues XLA ops in sequence and in order to build a - LocalComputation, which in turn can be compiled into a - CompiledLocalComputation, which in turn can be locally executed. + Computation, which in turn can be compiled into a + LocalExecutable, which in turn can be locally executed. """ # The methods of this class map 1-to-1 onto the XLA C++ @@ -690,16 +912,24 @@ class ComputationBuilder(object): # pylint: disable=g-doc-args def __init__(self, name): - self._client = c_api.LocalComputationBuilder(name.encode('utf8')) + self._client = c_api.ComputationBuilder(name.encode('utf8')) self._parameter_numbering = itertools.count() - def Build(self, root=None, backend=XLA_LOCAL_BACKEND): + def Build(self, root=None, backend=None): + """Builds a `Computation` from the contents of the builder. + + Args: + root: if not None, the operator containing the return value of the + computation. + backend: deprecated. Pass a `backend` to `Computation.Compile` instead. + + Returns: + A `Computation`. + """ if root is not None: - return LocalComputation( - self._client.BuildWithRoot(root), is_compiled=False, backend=backend) + return Computation(self._client.BuildWithRoot(root), backend=backend) else: - return LocalComputation( - self._client.Build(), is_compiled=False, backend=backend) + return Computation(self._client.Build(), backend=backend) def SetOpMetadata(self, op_metadata): """Set metadata for operations that are about to be enqueued.""" @@ -732,8 +962,8 @@ class ComputationBuilder(object): """Enqueues a constant op onto the computation. Args: - value: value for the constant, as a np.array with an explicit dtype set - to one of the supported types. + value: value for the constant, as a np.array with an explicit dtype set to + one of the supported types. Returns: A LocalOp. 
@@ -802,9 +1032,9 @@ class ComputationBuilder(object): Args: shape: the parameter's shape as a Shape object. name: optional string name for the parameter. - parameter_num: parameter number in the computation function. If None, - the next linear parameter number is used. The default value capability - can be used for auto-numbering. If you're using auto-numbering for some + parameter_num: parameter number in the computation function. If None, the + next linear parameter number is used. The default value capability can + be used for auto-numbering. If you're using auto-numbering for some parameters, use it for *all* parameters to avoid clashes. Returns: @@ -821,8 +1051,8 @@ class ComputationBuilder(object): """Enqueues a Parameter op onto the computation. Args: - value: a Numpy array, or a nested tuple thereof, from which the - shape is inferred. + value: a Numpy array, or a nested tuple thereof, from which the shape is + inferred. name: as in ParameterWithShape. parameter_num: as in ParameterWithShape. @@ -877,8 +1107,8 @@ class ComputationBuilder(object): Args: operand: the operand LocalOp to broadcast. shape: tuple of integers, the expected output shape. - broadcast_dimensions: tuple of integers identifying which dimensions - of the output are to be broadcast into. + broadcast_dimensions: tuple of integers identifying which dimensions of + the output are to be broadcast into. Returns: A LocalOp representing the added broadcast-in-dimensions op. @@ -930,20 +1160,28 @@ class ComputationBuilder(object): def GetComputationStats(self): raise NotImplementedError() + def ReplicaId(self): + """Enqueues a ReplicaId operation onto the computation. + + Returns: + A LocalOp representing the replica id. + """ + return self._client.ReplicaId() + def Pad(self, operand, padding_value, padding_config): """Enqueues a Pad operation onto the computation. Args: operand: LocalOp representing the array to pad. padding_value: LocalOp representing the scalar pad value. 
- padding_config: either an xla_data_pb2.PaddingConfig or a list of integer - triples (edge_padding_low, edge_padding_high, interior_padding) - representing the configuration of the padding operation. + padding_config: either a PaddingConfig or a list of integer triples + (edge_padding_low, edge_padding_high, interior_padding) representing the + configuration of the padding operation. Returns: A LocalOp representing the added Pad op. """ - if not isinstance(padding_config, xla_data_pb2.PaddingConfig): + if isinstance(padding_config, tuple) or isinstance(padding_config, list): padding_config = GetPaddingConfigFromTriples(padding_config) return self._client.Pad(operand, padding_value, padding_config) @@ -964,16 +1202,62 @@ class ComputationBuilder(object): dimensions = tuple(range(ndim)) return self._client.Reshape(operand, dimensions, new_sizes) - def CrossReplicaSum(self, operand): + def AllToAll(self, + operand, + split_dimension, + concat_dimension, + replica_groups=None): + """AllToAll op. + + Args: + operand: LocalOp representing the input array + split_dimension: the dimension along which the operand is split + concat_dimension: the dimension along which the split blocks are + concatenated + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the all-to-all is performed. If not supplied or None (the + default), all replicas belong to the same group. + + Returns: + A LocalOp that represents the all-to-all concatenation. 
+ """ + if replica_groups is None: + replica_groups_protos = [] # special value for XLA API + else: + replica_groups = list(replica_groups) + replica_groups_protos = [ + _make_replica_group_proto(group) for group in replica_groups + ] + if not replica_groups: + split_count = get_replica_count() + else: + split_count = len(replica_groups[0]) + if not all(split_count == len(g) for g in replica_groups): + raise ValueError('Replica groups must be equally sized') + return self._client.AllToAll(operand, split_dimension, concat_dimension, + split_count, replica_groups_protos) + + def CrossReplicaSum(self, operand, replica_groups=None): """CrossReplicaSum op. Args: operand: the operand to sum across replica instances. + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the cross-replica sum is performed. If not supplied or None + (the default), all replicas belong to the same group. Returns: - A LocalOp that has the sum of the value among all replicas. + A LocalOp that represents on each replica the sum of its group's values. """ - return self._client.CrossReplicaSum(operand) + if replica_groups is None: + replica_groups = [] # special value for XLA API + else: + replica_groups = [ + _make_replica_group_proto(group) for group in replica_groups + ] + return self._client.CrossReplicaSum(operand, replica_groups) def Collapse(self, operand, dimensions): """Collapse op.""" @@ -1000,8 +1284,8 @@ class ComputationBuilder(object): """Select and scatter op, used by the gradient of ReduceWindow. Args: - operand: LocalOp for array of dimension N and type T over - which the windows slide. + operand: LocalOp for array of dimension N and type T over which the + windows slide. select: Computation of type (T, T) -> Pred to apply to the elements of each window to indicate which element is selected. window_dimensions: sequence of N integers for dimensions of the window. 
@@ -1016,8 +1300,8 @@ class ComputationBuilder(object): A LocalOp representing the added SelectAndScatter op. """ pads = _convert_padding_type_to_pad_values( - padding, self.GetShape(operand).dimensions(), - window_dimensions, window_strides) + padding, self.GetShape(operand).dimensions(), window_dimensions, + window_strides) return self._client.SelectAndScatterWithGeneralPadding( operand, select.computation, window_dimensions, window_strides, pads, source, init_value, scatter.computation) @@ -1071,8 +1355,8 @@ class ComputationBuilder(object): Args: operand: LocalOp for the N dimensional array to be sliced. - start_indices: LocalOp for the 1D array of N integers - containing the starting indices of the slice. + start_indices: LocalOp for the 1D array of N integers containing the + starting indices of the slice. slice_sizes: iterable of N integers containing the slice sizes in each dimension. @@ -1089,6 +1373,7 @@ class ComputationBuilder(object): update: N dimensional array comprising the slice update. start_indices: Rank-1 array of N integers comprising the starting indices of the slice along each dimension. + Returns: A LocalOp representing the added DynamicUpdateSlice op. """ @@ -1122,8 +1407,8 @@ class ComputationBuilder(object): Args: computation_to_apply: a Computation object. - operands: an iterable of LocalOp. The number and types of - operands must match the arity of computation_to_apply. + operands: an iterable of LocalOp. The number and types of operands must + match the arity of computation_to_apply. Returns: A LocalOp representing the added call op. @@ -1200,8 +1485,8 @@ class ComputationBuilder(object): A LocalOp representing the added ReduceWindow op. 
""" pads = _convert_padding_type_to_pad_values( - padding, self.GetShape(operand).dimensions(), window_dimensions, - window_strides) + padding, + self.GetShape(operand).dimensions(), window_dimensions, window_strides) return self._client.ReduceWindowWithGeneralPadding( operand, init_value, computation_to_apply.computation, window_dimensions, window_strides, (), (), pads) @@ -1234,10 +1519,8 @@ class ComputationBuilder(object): Args: mu: A LocalOp to an F32 scalar specifying the mean. - sigma: A LocalOp to an F32 scalar specifying the standard - deviation. + sigma: A LocalOp to an F32 scalar specifying the standard deviation. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a LocalOp to the generated array of F32 values. """ shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) @@ -1247,16 +1530,15 @@ class ComputationBuilder(object): """Enqueues an RngUniform operation onto the computation. Args: - a: a LocalOp to an F32, S32, or U32 scalar (consistent with - the type of b) specifying the low end of the interval [a, b) over which - values are generated. - b: a LocalOp to an F32, S32, or U32 scalar (consistent with - the type of a) specifying the high end of the interval [a, b) over which - values are generated. + a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) + specifying the low end of the interval [a, b) over which values are + generated. + b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) + specifying the high end of the interval [a, b) over which values are + generated. dims: A 1D array-like of nonnegative integers specifying the dimensions. - - Returns: a LocalOp to the generated array of values with the - same numeric type (F32, S32, or U32) as the arguments a and b. + Returns: a LocalOp to the generated array of values with the same numeric + type (F32, S32, or U32) as the arguments a and b. 
""" shape = Shape.array_shape(self.GetShape(a).element_type(), dims) return self._client.RngUniform(a, b, shape) @@ -1268,7 +1550,6 @@ class ComputationBuilder(object): cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T init: a LocalOp for the initial parameter, which has type T - Returns: a LocalOp representing the While operation. """ return self._client.While(cond.computation, body.computation, init) @@ -1283,19 +1564,17 @@ class ComputationBuilder(object): true_computation: a Computation to apply to true_operand, type T_0 -> S false_operand: a ComputationDatahandle of type T_1 false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a LocalOp representing the Conditional operation. """ - return self._client.Conditional( - pred, true_operand, true_computation.computation, false_operand, - false_computation.computation) + return self._client.Conditional(pred, true_operand, + true_computation.computation, false_operand, + false_computation.computation) def IsConstant(self, operand): """Checks whether the given operand is a compile-time constant. Args: operand: a ComputationDataHandle to test. - Returns: bool indicating whether `operand` is a compile-time constant, meaning its value does not depend on any parametersor, or on stateful operators such as `RngNormal` or `Infeed`. @@ -1307,7 +1586,7 @@ class ComputationBuilder(object): Args: operand: a LocalOp to test. - Returns: a LocalComputation that is rooted on the given `operand` which is a + Returns: a Computation that is rooted on the given `operand` which is a compile-time constant. """ return self._client.BuildConstantSubGraph(operand) @@ -1318,7 +1597,6 @@ class ComputationBuilder(object): Args: lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a LocalOp representing the Dot operation. 
""" return self._client.Dot(lhs, rhs) @@ -1329,14 +1607,13 @@ class ComputationBuilder(object): Args: lhs: LocalOp for the left-hand-side array. rhs: LocalOp for the right-hand-side array. - dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested - tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of + dimension_numbers: either a DotDimensionNumbers or a nested tuple + ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of integers representing the dimensions to treat as contracting dimensions and batch dimensions on each input operand. - Returns: a LocalOp representing the DotGeneral operation. """ - if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): + if isinstance(dimension_numbers, tuple): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) return self._client.DotGeneral(lhs, rhs, dimension_numbers) @@ -1349,15 +1626,15 @@ class ComputationBuilder(object): window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. feature_group_count: number of feature groups for grouped convolution. - Returns: a LocalOp representing the Conv operation. 
""" pads = _convert_padding_type_to_pad_values( - padding, self.GetShape(lhs).dimensions()[2:], + padding, + self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) return self.ConvGeneralDilated( - lhs, rhs, window_strides, pads, (), (), - dimension_numbers=None, feature_group_count=feature_group_count) + lhs, rhs, window_strides, pads, (), (), dimension_numbers=None, + feature_group_count=feature_group_count) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count=1): @@ -1382,7 +1659,7 @@ class ComputationBuilder(object): def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" nd = num_spatial_dims - dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + dimension_numbers = ConvolutionDimensionNumbers() dimension_numbers.input_batch_dimension = 0 dimension_numbers.input_feature_dimension = 1 dimension_numbers.output_batch_dimension = 0 @@ -1406,35 +1683,33 @@ class ComputationBuilder(object): padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of integer dilation factors. rhs_dilation: length-N array-like of integer dilation factors. - dimension_numbers: optional, either an - xla_data_pb2.ConvolutionDimensionNumbers proto instance or a tuple - (lhs_spec, rhs_spec, out_spec) where each element is a string of length - N+2 identifying by position (1) batch dimensions in lhs, rhs, and the - output with the character 'N', (2) feature dimensions in lhs and the - output with the character 'C', (3) input and output feature dimensions - in rhs with the characters 'I' and 'O' respectively, and (4) spatial - dimension correspondences between lhs, rhs, and the output using any - distinct characters. For example, to indicate dimension numbers - consistent with the Conv operation with two spatial dimensions, one - could use ('NCHW', 'OIHW', 'NCHW'). 
As another example, to indicate - dimension numbers consistent with the TensorFlow Conv2D operation, one - could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of - convolution dimension specification, window strides are associated with - spatial dimension character labels according to the order in which the - labels appear in the rhs_spec string, so that window_strides[0] is - matched with the dimension corresponding to the first character - appearing in rhs_spec that is not 'I' or 'O'. By default, use the same - dimension numbering as Conv and ConvWithGeneralPadding. + dimension_numbers: optional, either a ConvolutionDimensionNumbers object + or a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of + length N+2 identifying by position: (1) batch dimensions in lhs, rhs, + and the output with the character 'N', (2) feature dimensions in lhs + and the output with the character 'C', (3) input and output feature + dimensions in rhs with the characters 'I' and 'O' respectively, and + (4) spatial dimension correspondences between lhs, rhs, and the output + using any distinct characters. For example, to indicate dimension + numbers consistent with the Conv operation with two spatial + dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another + example, to indicate dimension numbers consistent with the TensorFlow + Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). When using + the latter form of convolution dimension specification, window strides + are associated with spatial dimension character labels according to + the order in which the labels appear in the rhs_spec string, so that + window_strides[0] is matched with the dimension corresponding to the + first character appearing in rhs_spec that is not 'I' or 'O'. By + default, use the same dimension numbering as Conv and + ConvWithGeneralPadding. feature_group_count: number of feature groups for grouped convolution. - Returns: a LocalOp representing the ConvGenralDilated operation. 
""" if dimension_numbers is None: dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - elif not isinstance(dimension_numbers, - xla_data_pb2.ConvolutionDimensionNumbers): + elif isinstance(dimension_numbers, tuple): lhs_spec, rhs_spec, out_spec = dimension_numbers - dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + dimension_numbers = ConvolutionDimensionNumbers() dimension_numbers.input_batch_dimension = lhs_spec.index('N') dimension_numbers.input_feature_dimension = lhs_spec.index('C') @@ -1451,10 +1726,9 @@ class ComputationBuilder(object): dimension_numbers.output_spatial_dimensions.extend( sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), key=lambda i: rhs_spec.index(out_spec[i]))) - return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, - dimension_numbers, - feature_group_count) + return self._client.ConvGeneralDilated( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count) def Sort(self, operand, dimension=-1): """Enqueues a sort operation onto the computation.""" @@ -1464,31 +1738,50 @@ class ComputationBuilder(object): """Enqueues a key-value sort operation onto the computation.""" return self._client.SortKeyVal(keys, values, dimension) - def Cholesky(self, a): + def Cholesky(self, a, lower=True): """Enqueues a Cholesky decomposition onto the computation.""" - return self._client.Cholesky(a) + return self._client.Cholesky(a, lower) def QR(self, a, full_matrices=True): """Enqueues a QR decomposition onto the computation.""" return self._client.QR(a, full_matrices) - def TriangularSolve(self, a, b, left_side=False, lower=False, - transpose_a=False, conjugate_a=False): + def TriangularSolve(self, + a, + b, + left_side=False, + lower=False, + transpose_a=False, + conjugate_a=False, + unit_diagonal=False): """Enqueues a triangular-solve operation onto the computation.""" - return self._client.TriangularSolve( 
- a, b, left_side, lower, transpose_a, conjugate_a) + if not transpose_a: + transpose = 1 + if conjugate_a: + a = self.Conj(a) + else: + transpose = 3 if conjugate_a else 2 + return self._client.TriangularSolve(a, b, left_side, lower, unit_diagonal, + transpose) + + def Eigh(self, a, full_matrices=True): + """Enqueues a symmetric/Hermitian eigendecomposition.""" + return self._client.Eigh(a, full_matrices) + + def SVD(self, a): + """Enqueues a singular value decomposition.""" + return self._client.SVD(a) def Gather(self, a, start_indices, dimension_numbers, slice_sizes): """Enqueues a Gather operation onto the computation.""" - return self._client.Gather(a, start_indices, dimension_numbers, - slice_sizes) + return self._client.Gather(a, start_indices, dimension_numbers, slice_sizes) def Scatter(self, a, scatter_indices, updates, update_computation, dimension_numbers): """Enqueues a Scatter operation onto the computation.""" return self._client.Scatter( a, scatter_indices, updates, update_computation.computation, - dimension_numbers,) + dimension_numbers) def _forward_methods_to_local_builder(): @@ -1496,7 +1789,7 @@ def _forward_methods_to_local_builder(): Set up methods, corresponding to unary and binary XLA operations, whose calls are forwarded in a boilerplate manner to the underlying - LocalComputationBuilder C-extension API. + ComputationBuilder C-extension API. 
""" def forward_to_local_builder_with_handles(target_method, is_binop=False): @@ -1516,22 +1809,27 @@ def _forward_methods_to_local_builder(): for method_name in _UNARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name)) + getattr(c_api.ComputationBuilder, method_name)) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) for method_name in _BINARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) + getattr(c_api.ComputationBuilder, method_name), is_binop=True) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) _forward_methods_to_local_builder() +_default_replica_count = 1 + def initialize_replica_count(replica_count): - """Initializes the desired replica count to use on XLA service init. + """Initializes the default replica count to use. + + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. Args: replica_count: number of replicas that are desired for set up during XLA @@ -1540,29 +1838,30 @@ def initialize_replica_count(replica_count): Raises: A runtime exception if the XLA service has already been initialized. """ - c_api.InitializeReplicaCount(replica_count) - + global _default_replica_count + _default_replica_count = replica_count -def initialize_platform_name(platform_name): - """Initializes the desired platform name to use on XLA service init. - Args: - platform_name: string name of platform. +def get_replica_count(): + """Returns the default replica count. - Raises: - A runtime exception if the XLA service has already been initialized. + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. 
""" - platform_name = _maybe_encode_string(platform_name) - c_api.InitializePlatformName(platform_name) + return _default_replica_count -def get_replica_count(): - """Returns the current replica count used for the XLA service. +def initialize_platform_name(platform_name): + """Initializes the default platform name to use for XLA. - Note: this will return a value whether the XLA service has been initialized - yet or not. + Args: + platform_name: string name of platform. """ - return c_api.GetReplicaCount() + global _default_platform_name + _default_platform_name = platform_name + + # Make sure the platform is valid by trying to instantiate it. + _get_default_local_backend() def register_cpu_custom_call_target(name, fn): @@ -1575,22 +1874,111 @@ def register_cpu_custom_call_target(name, fn): c_api.RegisterCpuCustomCallTarget(name, fn) +class PaddingConfigDimension(object): + """Python representation of a xla.PaddingConfigDimension protobuf.""" + __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') + + def __init__(self): + self.edge_padding_low = [] + self.edge_padding_high = [] + self.interior_padding = [] + + +class PaddingConfig(object): + """Python representation of a xla.PaddingConfig protobuf.""" + __slots__ = ('dimensions',) + + def __init__(self): + self.dimensions = [] + + def GetPaddingConfigFromTriples(triples): """Create PaddingConfig proto from list of triples of integers.""" - padding_config = xla_data_pb2.PaddingConfig() + padding_config = PaddingConfig() for lo, hi, interior in triples: - dimension = padding_config.dimensions.add() + dimension = PaddingConfigDimension() dimension.edge_padding_low = lo dimension.edge_padding_high = hi dimension.interior_padding = interior + padding_config.dimensions.append(dimension) return padding_config +class DotDimensionNumbers(object): + """Python representation of a xla.DotDimensionNumbers protobuf.""" + __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions', + 
'lhs_batch_dimensions', 'rhs_batch_dimensions') + + def __init__(self): + self.lhs_contracting_dimensions = [] + self.rhs_contracting_dimensions = [] + self.lhs_batch_dimensions = [] + self.rhs_batch_dimensions = [] + + def GetDotDimensionsFromLists(dimension_numbers): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers - dot_dims_proto = xla_data_pb2.DotDimensionNumbers() + dot_dims_proto = DotDimensionNumbers() dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) return dot_dims_proto + + +class ConvolutionDimensionNumbers(object): + """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" + __slots__ = ('input_batch_dimension', 'input_feature_dimension', + 'input_spatial_dimensions', 'kernel_input_feature_dimension', + 'kernel_output_feature_dimension', 'kernel_spatial_dimensions', + 'output_batch_dimension', 'output_feature_dimension', + 'output_spatial_dimensions') + + def __init__(self): + self.input_batch_dimension = 0 + self.input_feature_dimension = 0 + self.input_spatial_dimensions = [] + self.kernel_input_feature_dimension = 0 + self.kernel_output_feature_dimension = 0 + self.kernel_spatial_dimensions = [] + self.output_batch_dimension = 0 + self.output_feature_dimension = 0 + self.output_spatial_dimensions = [] + + +class GatherDimensionNumbers(object): + """Python representation of a xla.GatherDimensionNumbers protobuf.""" + __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map', + 'index_vector_dim') + + def __init__(self): + self.offset_dims = [] + self.collapsed_slice_dims = [] + self.start_index_map = [] + self.index_vector_dim = 0 + + +class ScatterDimensionNumbers(object): + """Python representation of a xla.ScatterDimensionNumbers protobuf.""" + __slots__ = ('update_window_dims', 'inserted_window_dims', + 
'scatter_dims_to_operand_dims', 'index_vector_dim') + + def __init__(self): + self.update_window_dims = [] + self.inserted_window_dims = [] + self.scatter_dims_to_operand_dims = [] + self.index_vector_dim = 0 + + +class ReplicaGroup(object): + """Python representation of a xla.ReplicaGroup protobuf.""" + __slots__ = ('replica_ids',) + + def __init__(self): + self.replica_ids = [] + + +def _make_replica_group_proto(replica_group): + replica_group_proto = ReplicaGroup() + replica_group_proto.replica_ids.extend(replica_group) + return replica_group_proto diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 874e087eb6d4b785066edae21b1d11ebb024cd3e..1aedc43c02f1c127de5f1f971637203c3434cc91 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -24,12 +24,25 @@ import threading import numpy as np +from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import custom_call_for_test from tensorflow.compiler.xla.python import xla_client import unittest -class LocalComputationTest(unittest.TestCase): +class EnumTest(unittest.TestCase): + """Verifies Python enumerations match their protocol buffer equivalents.""" + + def testPrimitiveType(self): + for name, value in xla_client.PrimitiveType.__members__.items(): + self.assertEqual(value, getattr(xla_data_pb2, name)) + + def testFormat(self): + for name, value in xla_client.Format.__members__.items(): + self.assertEqual(value, getattr(xla_data_pb2, name)) + + +class ComputationTest(unittest.TestCase): """Base class for running an XLA Computation through the local client.""" def _NewComputation(self, name=None): @@ -85,9 +98,35 @@ def NumpyArrayBool(*args, **kwargs): return np.array(*args, dtype=np.bool, **kwargs) -class ComputationsWithConstantsTest(LocalComputationTest): +class ComputationPrinting(unittest.TestCase): + + def ExampleComputation(self): + builder = 
xla_client.ComputationBuilder("acomputation") + p0 = builder.ParameterFromNumpy(np.float32(0)) + p1 = builder.ParameterFromNumpy(np.zeros((4,), np.float32)) + builder.Mul(p0, p1) + return builder.Build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.GetHloText() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.GetHloDotGraph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + +class ComputationsWithConstantsTest(ComputationTest): """Tests focusing on Constant ops.""" + def testConstantScalarSumS8(self): + c = self._NewComputation() + root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) + self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) + self._ExecuteAndCompareExact(c, expected=np.int8(3)) + def testConstantScalarSumF32(self): c = self._NewComputation() root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) @@ -204,16 +243,6 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayS32([1]))) self._ExecuteAndCompareClose(c, expected=[2**31 - 1]) - def testGetProto(self): - c = self._NewComputation() - c.Add( - c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), - c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) - built = c.Build() - proto = built.GetProto() # HloModuleProto - self.assertTrue(len(proto.computations) == 1) - self.assertTrue(len(proto.computations[0].instructions) == 3) - def testSum2DF64(self): c = self._NewComputation() c.Add( @@ -298,7 +327,7 @@ class ComputationsWithConstantsTest(LocalComputationTest): self._ExecuteAndCompareClose(c, expected=0.75) -class ParametersTest(LocalComputationTest): +class ParametersTest(ComputationTest): """Tests focusing on Parameter ops and argument-passing.""" def setUp(self): @@ -378,7 +407,7 @@ class ParametersTest(LocalComputationTest): expected=[-4.3, 1.3, 
-6.3, 3.3]) -class LocalBufferTest(LocalComputationTest): +class LocalBufferTest(ComputationTest): """Tests focusing on execution with LocalBuffers.""" def _Execute(self, c, arguments): @@ -476,7 +505,7 @@ class LocalBufferTest(LocalComputationTest): self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) -class SingleOpTest(LocalComputationTest): +class SingleOpTest(ComputationTest): """Tests for single ops. The goal here is smoke testing - to exercise the most basic functionality of @@ -502,11 +531,11 @@ class SingleOpTest(LocalComputationTest): def testConvertElementType(self): xla_types = { - np.bool: xla_client.xla_data_pb2.PRED, - np.int32: xla_client.xla_data_pb2.S32, - np.int64: xla_client.xla_data_pb2.S64, - np.float32: xla_client.xla_data_pb2.F32, - np.float64: xla_client.xla_data_pb2.F64, + np.bool: xla_client.PrimitiveType.PRED, + np.int32: xla_client.PrimitiveType.S32, + np.int64: xla_client.PrimitiveType.S64, + np.float32: xla_client.PrimitiveType.F32, + np.float64: xla_client.PrimitiveType.F64, } def _ConvertAndTest(template, src_dtype, dst_dtype): @@ -527,13 +556,13 @@ class SingleOpTest(LocalComputationTest): def testBitcastConvertType(self): xla_x32_types = { - np.int32: xla_client.xla_data_pb2.S32, - np.float32: xla_client.xla_data_pb2.F32, + np.int32: xla_client.PrimitiveType.S32, + np.float32: xla_client.PrimitiveType.F32, } xla_x64_types = { - np.int64: xla_client.xla_data_pb2.S64, - np.float64: xla_client.xla_data_pb2.F64, + np.int64: xla_client.PrimitiveType.S64, + np.float64: xla_client.PrimitiveType.F64, } def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype): @@ -553,6 +582,18 @@ class SingleOpTest(LocalComputationTest): for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype]) + # TODO(b/123523486) implement AllToAll on CPU + def DISABLED_testAllToAllOneReplica(self): + samples = [ + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), 
+ NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples[:1]: + c = self._NewComputation() + c.AllToAll(c.Constant(lhs), 0, 0) + self._ExecuteAndCompareExact(c, expected=lhs) + def testCrossReplicaSumOneReplica(self): samples = [ NumpyArrayF32(42.0), @@ -565,6 +606,23 @@ class SingleOpTest(LocalComputationTest): c.CrossReplicaSum(c.Constant(lhs)) self._ExecuteAndCompareExact(c, expected=lhs) + def testReplicaId(self): + c = self._NewComputation() + _ = c.ReplicaId() + self._ExecuteAndCompareExact(c, expected=0) + + def testCrossReplicaSumOneReplicaWithSingletonGroup(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + c.CrossReplicaSum(c.Constant(lhs), [[0]]) + self._ExecuteAndCompareExact(c, expected=lhs) + def testDotMatrixVectorF32(self): c = self._NewComputation() lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) @@ -608,7 +666,7 @@ class SingleOpTest(LocalComputationTest): lhs = NumpyArrayF32(rng.randn(10, 3, 4)) rhs = NumpyArrayF32(rng.randn(10, 4, 5)) - dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers() + dimension_numbers = xla_client.DotDimensionNumbers() dimension_numbers.lhs_contracting_dimensions.append(2) dimension_numbers.rhs_contracting_dimensions.append(1) dimension_numbers.lhs_batch_dimensions.append(0) @@ -727,6 +785,12 @@ class SingleOpTest(LocalComputationTest): c.Not(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=~arr) + def testCountLeadingZeros(self): + c = self._NewComputation() + arr = NumpyArrayS32([0x7FFF, 0x12345678]) + c.Clz(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=[17, 3]) + def testExp(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -914,12 +978,13 @@ class SingleOpTest(LocalComputationTest): def testPadWithPaddingConfig(self): c = self._NewComputation() - padding_config = xla_client.xla_data_pb2.PaddingConfig() + 
padding_config = xla_client.PaddingConfig() for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: - dimension = padding_config.dimensions.add() + dimension = xla_client.PaddingConfigDimension() dimension.edge_padding_low = lo dimension.edge_padding_high = hi dimension.interior_padding = interior + padding_config.dimensions.append(dimension) c.Pad( c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), c.Constant(NumpyArrayF32(0.0)), @@ -962,14 +1027,13 @@ class SingleOpTest(LocalComputationTest): c.Constant(NumpyArrayF32(2))) self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) - # TODO(b/72689392): re-enable when bug S32 resolved - def DISABLED_testClampS32(self): + def testClampS32(self): c = self._NewComputation() c.Clamp( c.Constant(NumpyArrayS32(-1)), c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), c.Constant(NumpyArrayS32(2))) - self._ExecuteAndCompareExact(c, expected=[-1, 0, 1, 2, 2]) + self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) def testSelect(self): c = self._NewComputation() @@ -1102,6 +1166,26 @@ class SingleOpTest(LocalComputationTest): q, r = self._Execute(c, ()) np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + def testEigh(self): + a = np.array( + [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], + dtype=np.float32) + a = (a + a.T) / 2 + + c = self._NewComputation() + c.Eigh(c.Constant(a), full_matrices=True) + v, w = self._Execute(c, ()) + self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) + + def testSVD(self): + a = np.array( + [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + c.SVD(c.Constant(a)) + u, d, v = self._Execute(c, ()) + self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) + def testTriangularSolve(self): a_vals = np.array( [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], @@ -1132,7 +1216,7 @@ class SingleOpTest(LocalComputationTest): def testGather(self): a = 
np.arange(9).astype(np.int32).reshape((3, 3)) indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) - dnums = xla_client.xla_data_pb2.GatherDimensionNumbers() + dnums = xla_client.GatherDimensionNumbers() dnums.offset_dims.append(1) dnums.offset_dims.append(2) dnums.start_index_map.append(0) @@ -1145,7 +1229,7 @@ class SingleOpTest(LocalComputationTest): np.testing.assert_allclose(g, expected, rtol=1e-4) -class EmbeddedComputationsTest(LocalComputationTest): +class EmbeddedComputationsTest(ComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" def _CreateConstantS32Computation(self): @@ -1596,7 +1680,7 @@ class EmbeddedComputationsTest(LocalComputationTest): scatter_indices = np.array([0, 2], dtype=np.int32) updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) - dnums = xla_client.xla_data_pb2.ScatterDimensionNumbers() + dnums = xla_client.ScatterDimensionNumbers() dnums.update_window_dims.append(1) dnums.inserted_window_dims.append(0) dnums.scatter_dims_to_operand_dims.append(0) @@ -1609,7 +1693,7 @@ class EmbeddedComputationsTest(LocalComputationTest): self._ExecuteAndCompareClose(c, expected=expected) -class ErrorTest(LocalComputationTest): +class ErrorTest(ComputationTest): def setUp(self): self.f32_scalar_2 = NumpyArrayF32(2.0) @@ -1626,7 +1710,7 @@ class ErrorTest(LocalComputationTest): lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) -class ComputationRootTest(LocalComputationTest): +class ComputationRootTest(ComputationTest): """Tests related to setting the root of the computation.""" def testComputationRootDifferentFromLastOp(self): diff --git a/tensorflow/compiler/xla/python/xla_data.i b/tensorflow/compiler/xla/python/xla_data.i new file mode 100644 index 0000000000000000000000000000000000000000..b18583c64d400bdb7b3bc50b3548df23f4a8c469 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_data.i @@ -0,0 +1,654 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// SWIG typemaps for building, compiling, and executing XLA computations. +// +// The typemaps below implement/assert the following correspondences +// (with elaborations below): +// +// C++ Python +// -------------------------------------+--------------------------------------- +// Span <- sequence of int +// vector -> sequence of int +// Span <- sequence of LocalOp +// Literal <-> (nested tuple of) numpy ndarray +// std::vector <- sequence of (nested tuple of) ndarray +// Shape -> pair holding (dtype, dimensions) +// <- object duck-typed as xla_client.Shape +// ProgramShape -> pair of ([arg_shapes], ret_shape) +// std::vector <- sequence of xla_client.Shape objects +// PrimitiveType <- int +// Span> <- sequence of int pairs +// PaddingConfig proto <- ducktyped Python proto +// ConvolutionDimensionNumbers proto <- ducktyped Python proto +// DotDimensionNumbers proto <- ducktyped Python proto +// GatherDimensionNumbers proto <- ducktyped Python proto +// ScatterDimensionNumbers proto <- ducktyped Python proto +// Span <- sequence of ReplicaGroup Python proto +// +// Arrows indicate whether a conversion only ever occurs in one +// direction, or whether it is maintained bidirectionally. +// +// The Python objects corresponding to C++ Literals have the type: +// +// T = ndarray | (T, ...) 
+// +// where a terminal numpy ndarray translates to a Literal with a +// non-tuple Shape, an XLA primitive element type corresponding to the +// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates +// to a tuple-shaped Literal whose tuple components are translated +// recursively. For example, if x is a numpy ndarray in Python, with +// shape (2, 3) and dtype of dtype('float32'), then x translates to a +// Literal with rank 2, dimension 2 and 3, and XLA primitive type +// F32. Meanwhile, +// +// (x, (x, x), (x,)), +// +// translates to a tuple-shaped XLA Literal, whose component subshapes +// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. +// +// Shapes output by C++ become Python objects with the type: +// +// T = (dtype, S) +// S = DIMENSIONS | TUPLE_SHAPES +// DIMENSIONS = (int, ...) +// TUPLE_SHAPES = (T, ...) +// +// In the pair described by the T rule, the terminal dtype determines +// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is +// dtype('O'), numpy's object dtype, the structure represents a tuple +// shape and the expansion of the non-terminal S is +// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type +// and S expands into DIMENSIONS giving dimension sizes. For example: +// +// (dtype('float32'), (3, 5, 7)) +// +// describes a 3x5x7 array of F32s, and +// +// (dtype('O'), ((dtype('float32'), (2, 3)), +// (dtype('float64'), (4, 5)))) +// +// describes a tuple shape with two subshapes: the first a 2x3 F32, +// and the other a 4x5 F64. +// +// The Python int corresponding to a PrimitiveType enum must be valid +// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). +// +// The SWIG object wrappers generated by this file are not intended +// for end use, but rather for internal use in the Python XLA client, +// xla_client.py. 
+// +// One central reason for the Python-side indirection is that the +// Python-side objects produced by the typemaps in this file are +// further packaged up by xla_client before being passed on. For +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. +// +// Other SWIG object wrappers (e.g. of Computation) are further +// wrapped by xla_client in order to set up a custom destructor that +// triggers memory deallocation on the C++ side. +// + + +%module(threads="1") xla_data + +// Keep the GIL except where explicitly specified. +%nothread; + +%include "tensorflow/python/platform/base.i" + +%{ +// Must be included first +#include "tensorflow/python/lib/core/numpy.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/numpy_bridge.h" + +using namespace xla; +using namespace xla::swig; + +%} + +// Basic types + + +%typemap(out) std::vector { + PyObject* out = PyList_New($1.size()); + for (int i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM(out, i, PyInt_FromLong($1[i])); + } + $result = out; +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyBool_FromLong($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyString_FromString($1.ConsumeValueOrDie().c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) Status { + if (!$1.ok()) { + PyErr_SetString( + PyExc_RuntimeError, $1.ToString().c_str()); + SWIG_fail; + } + Py_INCREF(Py_None); + $result = Py_None; +} + +%typemap(in) absl::Span + (std::vector temps) { + 
if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + Py_DECREF(o); + SWIG_fail; + } + temps[i] = numpy::PyIntOrPyLongToLong(py_int); + if (temps[i] == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + SWIG_fail; + } + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// Literal + +%typemap(in) const Literal& (StatusOr literal_status) { + literal_status = numpy::XlaLiteralFromPyObject($input); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + SWIG_fail; + } + $1 = &literal_status.ValueOrDie(); +} + +%typemap(out) Literal (StatusOr obj_status) { + obj_status = numpy::PyObjectFromXlaLiteral(*$1); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); +} + +%typemap(out) StatusOr (StatusOr obj_status) { + if (!$1.ok()) { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } + obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); + if 
(!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + Py_DECREF(o); + SWIG_fail; + } + temps.push_back(literal_status.ConsumeValueOrDie()); + Py_DECREF(o); + } + $1 = &temps; +} + +// OpMetadata + +%typemap(in) const OpMetadata& (OpMetadata temp) { + StatusOr statusor = numpy::OpMetadataFromPyObject($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +// Shape + +%typemap(out) const Shape& { + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release(); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyProgramShapeInfoFromXlaProgramShape( + $1.ConsumeValueOrDie()).release(); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(in) const Shape& (Shape temp) { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +%typemap(in) const absl::optional& ( + absl::optional temp) { + if ($input == Py_None) { + temp = absl::nullopt; + $1 = &temp; + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; + } +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + 
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + $1 = &temps; +} + +%typemap(in) const std::vector >& ( + std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (o == Py_None) { + temps.push_back(absl::nullopt); + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + } + $1 = &temps; +} + +// PrimitiveType + +%typemap(in) PrimitiveType { + PyObject* py_int = numpy::PyNumberToPyInt($input); + if (!py_int) { + PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); + SWIG_fail; + } + const long value = numpy::PyIntOrPyLongToLong(py_int); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + SWIG_fail; + } + if (!PrimitiveType_IsValid(value)) { + PyErr_SetString( + PyExc_TypeError, "Argument not valid for PrimitiveType enum"); + Py_DECREF(py_int); + SWIG_fail; + } + $1 = static_cast(value); +} + +// Span> + +%typemap(in) absl::Span > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (!o) { + 
SWIG_fail; + } + PyObject* first = PyTuple_GetItem(o, 0); + if (!first) { + Py_DECREF(o); + SWIG_fail; + } + PyObject* first_pyint = numpy::PyNumberToPyInt(first); + if (!first_pyint) { + PyErr_SetString( + PyExc_TypeError, + "First pair item cannot be converted to int"); + Py_DECREF(o); + SWIG_fail; + } + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + Py_DECREF(o); + Py_DECREF(first_pyint); + SWIG_fail; + } + PyObject* second_pyint = numpy::PyNumberToPyInt(second); + if (!second_pyint) { + PyErr_SetString( + PyExc_TypeError, + "Second pair item cannot be converted to int"); + Py_DECREF(o); + Py_DECREF(first_pyint); + SWIG_fail; + } + const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); + if (first_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + SWIG_fail; + } + const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); + if (second_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + SWIG_fail; + } + temps.push_back(std::make_pair(first_value, second_value)); + Py_DECREF(o); + } + $1 = temps; +} + +// DotDimensionNumbers + +%typemap(in) const DotDimensionNumbers& + (DotDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_contracting_dimensions", + dimension_numbers.mutable_lhs_contracting_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "rhs_contracting_dimensions", + dimension_numbers.mutable_rhs_contracting_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "lhs_batch_dimensions", + dimension_numbers.mutable_lhs_batch_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "rhs_batch_dimensions", + dimension_numbers.mutable_rhs_batch_dimensions())) { + SWIG_fail; + } + + $1 = &dimension_numbers; +} + +// PaddingConfig + +%typemap(in) const PaddingConfig& + (PaddingConfig padding_config) { + 
PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); + if (!dimensions) { + SWIG_fail; + } + + int length = PySequence_Size(dimensions); + if (length == -1) { + Py_DECREF(dimensions); + SWIG_fail; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(dimensions, i); + if (!item) { + Py_DECREF(dimensions); + SWIG_fail; + } + int64 edge_padding_low, edge_padding_high, interior_padding; + if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) + || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) + || !GetIntAttr(item, "interior_padding", &interior_padding)) { + Py_DECREF(item); + Py_DECREF(dimensions); + SWIG_fail; + } + Py_DECREF(item); + + PaddingConfig::PaddingConfigDimension* dimension = + padding_config.add_dimensions(); + dimension->set_edge_padding_low(edge_padding_low); + dimension->set_edge_padding_high(edge_padding_high); + dimension->set_interior_padding(interior_padding); + } + Py_DECREF(dimensions); + + $1 = &padding_config; +} + +// ConvolutionDimensionNumbers + +%typemap(in) const ConvolutionDimensionNumbers& + (ConvolutionDimensionNumbers dimension_numbers) { + int64 value; + + if (!GetIntAttr($input, "input_batch_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_input_batch_dimension(value); + + if (!GetIntAttr($input, "input_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_input_feature_dimension(value); + + if (!GetIntAttr($input, "output_batch_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_output_batch_dimension(value); + + if (!GetIntAttr($input, "output_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_kernel_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { + SWIG_fail; + } + 
dimension_numbers.set_kernel_input_feature_dimension(value); + + if (!HandleRepeatedInt64Attribute( + $input, "input_spatial_dimensions", + dimension_numbers.mutable_input_spatial_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "kernel_spatial_dimensions", + dimension_numbers.mutable_kernel_spatial_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "output_spatial_dimensions", + dimension_numbers.mutable_output_spatial_dimensions())) { + SWIG_fail; + } + + $1 = &dimension_numbers; +} + +// GatherDimensionNumbers + +%typemap(in) const GatherDimensionNumbers& + (GatherDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "offset_dims", + dimension_numbers.mutable_offset_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "collapsed_slice_dims", + dimension_numbers.mutable_collapsed_slice_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "start_index_map", + dimension_numbers.mutable_start_index_map())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; + } + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// ScatterDimensionNumbers + +%typemap(in) const ScatterDimensionNumbers& + (ScatterDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "update_window_dims", + dimension_numbers.mutable_update_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "inserted_window_dims", + dimension_numbers.mutable_inserted_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "scatter_dims_to_operand_dims", + dimension_numbers.mutable_scatter_dims_to_operand_dims())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; + } + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// Span + 
+%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + ReplicaGroup rgrp; + if (!HandleRepeatedInt64Attribute( + o, "replica_ids", + rgrp.mutable_replica_ids())) { + SWIG_fail; + } + temps.push_back(rgrp); + Py_DECREF(o); + } + $1 = temps; +} diff --git a/tensorflow/compiler/xla/python/xrt.cc b/tensorflow/compiler/xla/python/xrt.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c55abc17f87c369e3d5b2140a84014e07921a9a --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.cc @@ -0,0 +1,297 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/python/xrt.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace swig { + +XrtAllocation::XrtAllocation(int64 handle, Shape shape, + const string& session_target) + : handle_(handle), shape_(shape), session_target_(session_target) {} + +XrtAllocation::~XrtAllocation() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto release = + tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; + } + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if 
(!status.ok()) { + LOG(ERROR) << status; + return; + } +} + +/* static */ +StatusOr XrtAllocation::FromLiteral( + const Literal& argument, const string& session_target) { + xrt::XLAAllocation alloc; + *alloc.mutable_value() = argument.ToProto(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto literal_string = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({literal_string, alloc.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); + + int64 handle = outputs[0].scalar()(); + return new XrtAllocation(handle, argument.shape(), session_target); +} + +const int64 XrtAllocation::handle() const { return handle_; } + +const Shape& XrtAllocation::shape() const { return shape_; } + +StatusOr XrtAllocation::ToLiteral() const { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); + + xla::LiteralProto response; + TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); + return Literal::CreateFromProto(response); +} + +XrtAllocationTuple::XrtAllocationTuple(std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + CHECK(element != nullptr); + } +} + +XrtAllocationTuple::~XrtAllocationTuple() { + for (XrtAllocation* element : elements_) { + if (element 
!= nullptr) { + delete element; + } + } +} + +StatusOr XrtAllocationTuple::Release(int i) { + XrtAllocation* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; +} + +int64 XrtAllocationTuple::size() const { return elements_.size(); } + +StatusOr XrtExecutable::CompileForXrt( + const string& hlo_module_proto, const std::vector& argument_shapes, + const Shape& result_shape, const string& session_target) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto compile = tensorflow::ops::XRTCompile(root, program); + TF_RETURN_IF_ERROR(root.status()); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + ProgramShape program_shape; + for (auto& shape : argument_shapes) { + *program_shape.add_parameters() = shape; + } + *program_shape.mutable_result() = result_shape; + + LayoutUtil::SetToDefaultLayout(&program_shape); + *config->mutable_program_shape() = program_shape.ToProto(); + c.mutable_hlo_snapshot() + ->mutable_hlo() + ->mutable_hlo_module() + ->ParsePartialFromString(hlo_module_proto); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({program, c.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); + + int64 handle = outputs[0].scalar()(); + return new XrtExecutable(program_shape, handle, session_target); +} + +XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target) + : program_shape_(program_shape), + handle_(handle), + session_target_(session_target) {} + +XrtExecutable::~XrtExecutable() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto 
release = + tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; + } + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({computation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if (!status.ok()) { + LOG(ERROR) << status; + return; + } +} + +StatusOr XrtExecutable::Execute( + absl::Span argument_handles) { + const int num_expected_arguments = program_shape().parameters().size(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + std::vector arguments; + arguments.reserve(num_expected_arguments); + for (int i = 0; i < num_expected_arguments; ++i) { + arguments.push_back( + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); + } + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto execution_config = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto execute = tensorflow::ops::XRTExecute(root, computation_handle, + execution_config, arguments); + TF_RETURN_IF_ERROR(root.status()); + + TF_RET_CHECK(argument_handles.size() == arguments.size()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(false); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + for (int i = 0; i < arguments.size(); ++i) { + inputs.insert({arguments[i], argument_handles[i]->handle()}); + } + inputs.insert({computation_handle, handle()}); + inputs.insert({execution_config, e.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); + + int64 output = outputs[0].scalar()(); + return new XrtAllocation(output, program_shape().result(), session_target_); +} + +const ProgramShape& XrtExecutable::program_shape() const { + return 
program_shape_; +} + +int64 XrtExecutable::handle() const { return handle_; } + +void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } + +void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; } + +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target) { + const Shape& tuple_shape = allocation->shape(); + + if (!tuple_shape.IsTuple()) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); + } + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); + auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + std::vector results; + for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + inputs.clear(); + inputs.insert({base_handle, allocation->handle()}); + inputs.insert({shape_index, {i}}); + std::vector outputs; + auto status = session.Run(inputs, {subtuple}, &outputs); + if (!status.ok()) { + // Clean up before returning non-ok status. 
+ for (int j = 0; j < results.size(); ++j) { + delete results[j]; + } + return status; + } + const int64 subtuple_handle = outputs[0].scalar()(); + const Shape& subtuple_shape = + ShapeUtil::GetTupleElementShape(tuple_shape, i); + results.push_back( + new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); + } + return new XrtAllocationTuple(std::move(results)); +} + +} // namespace swig +} // namespace xla diff --git a/tensorflow/compiler/xla/python/xrt.h b/tensorflow/compiler/xla/python/xrt.h new file mode 100644 index 0000000000000000000000000000000000000000..710c3af3fa6b407127643797dbabad201cf076d4 --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.h @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" + +namespace xla { +namespace swig { + +// Represents a reference to literals that live in a device-allocated buffer via +// XRT. Specifically, wraps an int64 handle produced by running the allocation +// graph, and an XLA shape to track the referent's shape. 
+class XrtAllocation { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which allocation and deallocation + // graphs are run. + static StatusOr FromLiteral(const Literal& argument, + const string& session_target); + + XrtAllocation(int64 handle, Shape shape, const string& session_target); + ~XrtAllocation(); + StatusOr ToLiteral() const; + const Shape& shape() const; + const int64 handle() const; + + private: + const int64 handle_; + const Shape shape_; + const string session_target_; +}; + +// Result of a tuple destructuring operation on an XrtAllocation. +class XrtAllocationTuple { + public: + // Note: any XrtAllocation elements that are not Release()'d will be + // deallocated in the destructor. + explicit XrtAllocationTuple(std::vector elements); + + ~XrtAllocationTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int64 size() const; + + private: + std::vector elements_; +}; + +// Destructures a tuple-valued XrtAllocation into its constituent elements +// in XrtAllocationTuple form. +// +// Accepts a `session_target` argument, used in constructing the +// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, +// and passed along in constructing each constituent XrtAllocation. +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target); + +// Represents a compiled computation that can be executed given handles to +// device-allocated literals. Specifically, wraps an XRT computation handle. +class XrtExecutable { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the compilation graph is run. 
+ static StatusOr CompileForXrt( + const string& hlo_module_proto, const std::vector& argument_shapes, + const Shape& result_shape, const string& session_target); + + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the execution graph is run. + XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target); + ~XrtExecutable(); + + std::vector DeviceOrdinals() const { return {0}; } + + StatusOr Execute( + absl::Span argument_handles); + + const ProgramShape& program_shape() const; + int64 handle() const; + + private: + const ProgramShape program_shape_; + const int64 handle_; + const string session_target_; +}; + +// Functions for freeing resources from the Python side. +void DeleteXrtAllocation(XrtAllocation* allocation); +void DeleteXrtExecutable(XrtExecutable* computation); + +} // namespace swig +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ diff --git a/tensorflow/compiler/xla/python/xrt.i b/tensorflow/compiler/xla/python/xrt.i new file mode 100644 index 0000000000000000000000000000000000000000..456dd7be86e479b46815fc16b51a10431fe2060d --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.i @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Wrappers for XRT ops. 
+ +%module(threads="1") xrt + +// Keep the GIL except where explicitly specified. +%nothread; + +%include "tensorflow/python/platform/base.i" +%include "tensorflow/compiler/xla/python/xla_data.i" + +%{ +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/xrt.h" + +using namespace xla; +using namespace xla::swig; + +%} + +// Computation and buffer/allocation types + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtExecutable*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtAllocation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtAllocationTuple*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + XrtAllocation* xrta; + if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), + SWIG_POINTER_EXCEPTION)) == -1) { + SWIG_fail; + } + temps.push_back(xrta); + Py_DECREF(o); + } + $1 = temps; +} + + +%ignoreall +%unignore xla; +%unignore xla::swig; +%unignore xla::swig::XrtAllocation; +%unignore xla::swig::XrtAllocation::FromLiteral; +%unignore 
xla::swig::XrtAllocation::ToLiteral; +%unignore xla::swig::XrtAllocation::shape; +%unignore xla::swig::XrtAllocationTuple; +%unignore xla::swig::XrtAllocationTuple::Release; +%unignore xla::swig::XrtAllocationTuple::size; +%unignore xla::swig::XrtExecutable; +%unignore xla::swig::XrtExecutable::CompileForXrt; +%unignore xla::swig::XrtExecutable::DeviceOrdinals; +%unignore xla::swig::XrtExecutable::Execute; +%unignore xla::swig::DestructureXrtAllocationTuple; +%unignore xla::swig::DeleteXrtAllocation; +%unignore xla::swig::DeleteXrtExecutable; + +%thread; +%include "tensorflow/compiler/xla/python/xrt.h" +%nothread; + +%unignoreall diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py index 757e41a78ad2b57d2ef6e1f3055160be22c7b3ed..19bd685ab2260485d2a86f0a682d0cdd36712fdb 100644 --- a/tensorflow/compiler/xla/python_api/xla_literal.py +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -69,7 +69,7 @@ def _ConvertNumpyArrayToLiteral(ndarray): if ndarray.ndim == 0: getattr(literal, type_record.literal_field_name).append( - _np.asscalar(ndarray.astype(type_record.literal_field_type))) + ndarray.astype(type_record.literal_field_type).item()) else: # Ndarrays with boolean dtypes need special type conversion with protobufs if ndarray.dtype in {_np.bool_, _np.dtype('bool')}: diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index a1b0f4045ff071454451f9fe3942ac974f4f47ac..4d4500d840500f6a521a20fe8c94919af65ce31b 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -94,7 +94,7 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { } TEST_F(ReferenceUtilTest, MapArray2D) { - auto identity = [](float value) { return log(exp(value)); }; + auto identity = [](float value) { return log(std::exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); auto actual_literal = 
LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e39e17c11068047a9b0bb97d296504281bdf5c23..64d2da499db04b4eed32da6335f142e5cff3495a 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -114,6 +114,7 @@ tf_cc_test( ":bfloat16_normalization", ":bfloat16_support", ":hlo", + ":hlo_creation_utils", ":hlo_verifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -166,6 +167,23 @@ tf_cc_test( ], ) +cc_library( + name = "dump", + srcs = ["dump.cc"], + hdrs = ["dump.h"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + ":hlo_proto_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], @@ -272,6 +290,7 @@ tf_cc_test( srcs = ["hlo_evaluator_test.cc"], deps = [ ":hlo", + ":hlo_element_type_converter", ":hlo_evaluator", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", @@ -284,7 +303,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -333,6 +351,7 @@ cc_library( ":hlo_proto", ":name_uniquer", "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -679,7 +698,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", 
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", @@ -701,7 +719,9 @@ cc_library( ":compilation_cache", ":compiler", ":computation_layout", + ":computation_placer", ":device_memory_allocator", + ":dump", ":dynamic_dimension_inference", ":executable", ":execution_tracker", @@ -781,6 +801,7 @@ cc_library( ":backend", ":compiler", ":computation_layout", + ":dump", ":platform_util", ":service", "//tensorflow/compiler/xla:debug_options_flags", @@ -810,8 +831,8 @@ cc_library( name = "gpu_plugin", deps = [ ":service", - "//tensorflow/compiler/xla/service/gpu:gpu_compiler", "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", + "//tensorflow/compiler/xla/service/gpu:nvptx_compiler", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", ], @@ -880,6 +901,7 @@ cc_library( deps = [ ":computation_layout", ":device_memory_allocator", + ":dump", ":hlo", ":hlo_execution_profile", ":hlo_graph_dumper", @@ -912,6 +934,7 @@ cc_library( hdrs = ["compiler.h"], deps = [ ":buffer_value", + ":computation_placer", ":executable", ":hlo", ":hlo_module_config", @@ -1203,7 +1226,6 @@ cc_library( ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -1351,6 +1373,7 @@ cc_library( deps = [ ":heap_simulator", ":hlo", + ":hlo_casting_utils", ":hlo_ordering", ":hlo_pass", ":logical_buffer", @@ -1461,11 +1484,15 @@ cc_library( hdrs = ["hlo_creation_utils.h"], deps = [ ":hlo", + ":hlo_module_config", ":shape_inference", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + 
"//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:comparators", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1510,6 +1537,20 @@ cc_library( ], ) +cc_library( + name = "op_expander_pass", + srcs = ["op_expander_pass.cc"], + hdrs = ["op_expander_pass.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( name = "gather_expander", srcs = ["gather_expander.cc"], @@ -1518,6 +1559,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", + ":op_expander_pass", ":while_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", @@ -1541,6 +1583,51 @@ cc_library( ], ) +cc_library( + name = "triangular_solve_expander", + srcs = ["triangular_solve_expander.cc"], + hdrs = ["triangular_solve_expander.h"], + deps = [ + ":op_expander_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +cc_library( + name = "cholesky_expander", + srcs = ["cholesky_expander.cc"], + hdrs = ["cholesky_expander.h"], + deps = [ + ":op_expander_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + 
"//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", @@ -1602,7 +1689,7 @@ tf_cc_test( ":algebraic_simplifier", ":hlo", ":hlo_casting_utils", - ":hlo_matchers", + ":hlo_creation_utils", ":hlo_parser", ":hlo_pass", ":pattern_matcher", @@ -1748,6 +1835,8 @@ cc_library( deps = [ ":hlo", ":hlo_evaluator", + ":pattern_matcher", + "@com_google_absl//absl/base", "@com_google_absl//absl/types:optional", ], ) @@ -1808,55 +1897,57 @@ tf_cc_test( ) cc_library( - name = "defuser", - srcs = ["defuser.cc"], - hdrs = ["defuser.h"], + name = "while_loop_trip_count_annotator", + srcs = ["while_loop_trip_count_annotator.cc"], + hdrs = ["while_loop_trip_count_annotator.h"], deps = [ - ":call_graph", ":hlo", ":hlo_pass", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", + ":while_loop_analysis", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", ], ) tf_cc_test( - name = "defuser_test", - srcs = ["defuser_test.cc"], + name = "while_loop_trip_count_annotator_test", + srcs = ["while_loop_trip_count_annotator_test.cc"], deps = [ - ":defuser", - ":hlo_matchers", + ":pattern_matcher", + ":while_loop_simplifier", + ":while_loop_trip_count_annotator", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", ], ) cc_library( - name = "implicit_broadcast_remover", - srcs = ["implicit_broadcast_remover.cc"], - hdrs = ["implicit_broadcast_remover.h"], + name = "defuser", + srcs = ["defuser.cc"], + hdrs = ["defuser.h"], deps = [ + ":call_graph", ":hlo", - ":hlo_dce", ":hlo_pass", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) tf_cc_test( - name = "implicit_broadcast_remover_test", - srcs = ["implicit_broadcast_remover_test.cc"], + name = "defuser_test", + srcs = ["defuser_test.cc"], deps = [ + ":defuser", ":hlo_matchers", - ":implicit_broadcast_remover", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1930,9 +2021,11 @@ cc_library( hdrs = ["dynamic_dimension_inference.h"], deps = [ ":hlo", + ":while_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", @@ -1964,16 +2057,18 @@ tf_cc_test( srcs = ["dynamic_padder_test.cc"], deps = [ ":dynamic_padder", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + ":hlo_runner", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_matchers", - 
"//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], @@ -1984,6 +2079,9 @@ tf_cc_test( srcs = ["dynamic_dimension_inference_test.cc"], deps = [ ":dynamic_dimension_inference", + ":hlo", + ":hlo_matchers", + ":hlo_runner", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1991,9 +2089,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], @@ -2105,6 +2200,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -2161,30 +2257,15 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ], ) -tf_cc_binary( - name = "graphviz_example", - srcs = ["graphviz_example.cc"], - deps = [ - ":hlo", - ":hlo_graph_dumper", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - 
"//tensorflow/core:lib", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - tf_cc_test( name = "hlo_module_test", srcs = ["hlo_module_test.cc"], @@ -2282,6 +2363,7 @@ cc_library( deps = [ ":call_graph", ":hlo", + ":hlo_casting_utils", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -2303,6 +2385,7 @@ tf_cc_test( srcs = ["hlo_dataflow_analysis_test.cc"], deps = [ ":hlo", + ":hlo_creation_utils", ":hlo_dataflow_analysis", ":hlo_graph_dumper", ":hlo_matchers", @@ -2473,6 +2556,7 @@ tf_cc_test( srcs = ["tuple_points_to_analysis_test.cc"], deps = [ ":hlo", + ":hlo_creation_utils", ":hlo_matchers", ":instruction_fusion", ":tuple_points_to_analysis", @@ -2545,6 +2629,7 @@ cc_library( hdrs = ["copy_insertion.h"], deps = [ ":buffer_liveness", + ":dump", ":hlo", ":hlo_alias_analysis", ":hlo_dce", @@ -2800,6 +2885,7 @@ cc_library( "hlo_pass_pipeline.h", ], deps = [ + ":dump", ":hlo", ":hlo_graph_dumper", ":hlo_pass", @@ -2809,6 +2895,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -2848,7 +2935,6 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -3026,8 +3112,6 @@ cc_library( ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", ], ) @@ -3143,6 +3227,7 @@ cc_library( hdrs = ["hlo_module_config.h"], deps = [ ":computation_layout", + ":computation_placer", 
"//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -3189,48 +3274,19 @@ tf_cc_test( ], ) -cc_library( - name = "hlo_tfgraph_builder", - srcs = ["hlo_tfgraph_builder.cc"], - hdrs = ["hlo_tfgraph_builder.h"], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "hlo_tfgraph_builder_test", - srcs = ["hlo_tfgraph_builder_test.cc"], - deps = [ - ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:protos_all_cc", - ], -) - cc_library( name = "hlo_graph_dumper", - srcs = [ - "hlo_graph_dumper.cc", - "hlo_graph_html_renderer.cc", - ], + srcs = ["hlo_graph_dumper.cc"], hdrs = ["hlo_graph_dumper.h"], deps = [ ":hlo", ":hlo_casting_utils", ":hlo_execution_profile", - ":hlo_tfgraph_builder", ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", @@ -3253,6 +3309,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -3270,7 +3327,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) @@ -3422,10 +3478,13 @@ cc_library( srcs = 
["hlo_runner.cc"], hdrs = ["hlo_runner.h"], deps = [ + ":backend", + ":compiler", ":computation_placer", ":executable", ":hlo", ":hlo_module_group", + ":hlo_parser", ":transfer_manager", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -3433,11 +3492,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:backend", - "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", "@com_google_absl//absl/memory", @@ -3486,6 +3543,37 @@ tf_cc_test( ], ) +cc_library( + name = "stable_sort_expander", + srcs = ["stable_sort_expander.cc"], + hdrs = ["stable_sort_expander.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_pass", + ":op_expander_pass", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "stable_sort_expander_test", + srcs = ["stable_sort_expander_test.cc"], + deps = [ + ":algebraic_simplifier", + ":hlo_matchers", + ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", + ":stable_sort_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_util", srcs = ["tuple_util.cc"], @@ -3582,7 +3670,6 @@ cc_library( ":while_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", ], @@ -3612,7 +3699,6 @@ cc_library( ":hlo_memory_scheduler", ":hlo_pass", ":hlo_pass_pipeline", - ":implicit_broadcast_remover", 
"//tensorflow/compiler/xla:statusor", ], ) @@ -3638,7 +3724,6 @@ cc_library( ":hlo_evaluator", ":hlo_pass", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3655,12 +3740,13 @@ tf_cc_test( extra_copts = ["-Wno-string-plus-int"], deps = [ ":hlo_matchers", + ":hlo_parser", ":indexed_array_analysis", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -3677,11 +3763,14 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], ) @@ -3745,6 +3834,47 @@ cc_library( ], ) +cc_library( + name = "optimize_input_output_buffer_alias", + srcs = ["optimize_input_output_buffer_alias.cc"], + hdrs = ["optimize_input_output_buffer_alias.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_test( + name = "optimize_input_output_buffer_alias_test", + srcs = ["optimize_input_output_buffer_alias_test.cc"], + deps = [ + ":optimize_input_output_buffer_alias", + 
"//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + cc_library( name = "ar_crs_combiner", srcs = ["ar_crs_combiner.cc"], @@ -3758,7 +3888,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 8a02c48146ea264fdcba85325c84e27b70090170..b223fc8b1b50851f786a0a725ecc584b97b9838f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -52,6 +53,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -107,7 +109,7 @@ bool IsAllFpConstantPowerOf2(const HloInstruction* op) { int exp; double mantissa = std::frexp(*val, &exp); - // frexp returns a value in the range (-1; -0.5] U [0.5, 1). A return value + // frexp returns a value in the range (-1, -0.5] U [0.5, 1). A return value // of +/-0.5 therefore indicates that the floating point value is a power of // 2. return mantissa == 0.5 || mantissa == -0.5; @@ -219,6 +221,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandlePower(HloInstruction* power) override; + Status HandleRemainder(HloInstruction* remainder) override; + Status HandleReshape(HloInstruction* reshape) override; Status HandleReduce(HloInstruction* reduce) override; @@ -276,15 +280,51 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { hlo)); } - // Helper method to perform and add reduction in a single dimension. - HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + // Converts to primitive type if the input hlo is not that type, otherwise + // returns the original hlo. + HloInstruction* AsType(HloInstruction* hlo, + const PrimitiveType element_type) { + if (hlo->shape().element_type() == element_type) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + } + + // Transposes a dot operand such that the batch dimensions are the most major, + // and the contracting dimensions are most minor. 
+ StatusOr NormalizeDotOperandToBatchMajorAndContractingMinor( + HloInstruction* dot_operand, absl::Span batch_dimensions, + absl::Span contracting_dimensions) { + std::vector transpose_dimensions(batch_dimensions.begin(), + batch_dimensions.end()); + for (int64 i = 0; i < dot_operand->shape().rank(); ++i) { + if (!(absl::c_linear_search(batch_dimensions, i) || + absl::c_linear_search(contracting_dimensions, i))) { + transpose_dimensions.push_back(i); + } + } + transpose_dimensions.insert(transpose_dimensions.end(), + contracting_dimensions.begin(), + contracting_dimensions.end()); + return MakeTransposeHlo(dot_operand, transpose_dimensions); + } + + // Helper method to perform and add reduction on a list of dimensions. + HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); - Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + Shape shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { return !absl::c_linear_search(dims, dim); }, + hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( - shape, hlo, zero, {dim}, AddReduce_computation)); + shape, hlo, zero, dims, AddReduce_computation)); + } + + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + return AddReduce(hlo, std::vector{dim}); } // Convenience method for replacing an instruction with a bitcast. 
If operand @@ -812,10 +852,82 @@ Status InvertConstant(const HloInstruction& constant, Literal* result) { return T{1.0} / constant.literal().Get(indices); }); } + +template +std::unique_ptr TryDivideToShift(HloInstruction* divide, + HloComputation* computation) { + HloInstruction *a, *b, *c; + CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); + + if (ShapeUtil::ElementIsIntegral(divide->shape()) && + !Match(b, m::ConstantEffectiveScalar(&c)) && + !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { + return nullptr; + } + + if (ShapeUtil::ElementIsSigned(divide->shape())) { + int64 b_value = c->literal().GetFirstElement(); + if (b_value > 0 && IsPowerOfTwo(static_cast(b_value))) { + // Handle negative dividends by negating the result of the division. + HloInstruction* zero_like_a = BroadcastZeros( + computation, a->shape().element_type(), a->shape().dimensions()); + + auto* dividend_is_negative = + computation->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a, + ComparisonDirection::kLt)); + + auto* negated_dividend = computation->AddInstruction( + HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); + + auto* abs_dividend = + computation->AddInstruction(HloInstruction::CreateTernary( + a->shape(), HloOpcode::kSelect, dividend_is_negative, + negated_dividend, a)); + + int log2_abs_b_value = tensorflow::Log2Floor64(b_value); + + auto* shift_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(log2_abs_b_value))); + if (!ShapeUtil::IsScalar(b->shape())) { + shift_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); + } + + auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend, + shift_amount)); + + auto* neqated_quotient = + computation->AddInstruction(HloInstruction::CreateUnary( + quotient->shape(), 
HloOpcode::kNegate, quotient)); + + return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect, + dividend_is_negative, + neqated_quotient, quotient); + } + } else { + uint64 b_value = c->literal().GetFirstElement(); + if (IsPowerOfTwo(b_value)) { + int log2_abs_b_value = tensorflow::Log2Floor64(b_value); + HloInstruction* shift_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(log2_abs_b_value))); + if (!ShapeUtil::IsScalar(b->shape())) { + shift_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); + } + return HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kShiftRightLogical, a, shift_amount); + } + } + + return nullptr; +} } // namespace Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { - Shape* shape; HloInstruction *a, *b, *c, *d; CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); // A/1 => A @@ -824,6 +936,61 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } + // A / B => A >> log2(B) if B is a power of 2. 
+ switch (divide->shape().element_type()) { + case S8: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S16: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S32: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S64: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U8: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U16: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U32: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U64: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + default: + break; + } + + Shape* shape; // exp(A)/exp(B) => exp(A-B) if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) .WithShape(m::Shape(&shape)))) { @@ -864,6 +1031,24 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { divide->shape(), HloOpcode::kMultiply, a, new_power)); } + // A/sqrt(B) => A*rsqrt(B).
+ if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) { + auto* rsqrt = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary(rsqrt->shape(), + HloOpcode::kMultiply, a, rsqrt)); + } + + // A/rsqrt(B) => A*sqrt(B). + if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) { + auto* sqrt = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary(sqrt->shape(), + HloOpcode::kMultiply, a, sqrt)); + } + // Simplifying integral division would produce unexpected results. if (ShapeUtil::ElementIsIntegral(divide->shape())) { return Status::OK(); @@ -874,8 +1059,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (Backends can do this transformation, but generally only if the constant is // a scalar.) if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { - Literal new_literal(b->shape()); - switch (b->shape().element_type()) { + Shape result_shape = b->literal().shape(); + Literal new_literal(result_shape); + switch (result_shape.element_type()) { case F16: TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); break; @@ -958,7 +1144,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( const int64 rhs_rank = rhs->shape().rank(); const int64 lhs_rank = lhs->shape().rank(); const auto& dnums = dot->dot_dimension_numbers(); - if (dnums.rhs_contracting_dimensions_size() > 1) { + if (dnums.rhs_contracting_dimensions_size() != 1) { return false; } if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { @@ -988,16 +1174,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( std::swap(rhs_collapsing_dim, rhs_kept_dim); } - auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { - if (hlo->shape().element_type() == 
element_type) { - return hlo; - } - return computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); - }; - auto reshape_if_necessary = [&](HloInstruction* hlo) { - hlo = as_type(hlo, dot->shape().element_type()); + hlo = AsType(hlo, dot->shape().element_type()); if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { hlo = computation_->AddInstruction( HloInstruction::CreateReshape(dot->shape(), hlo)); @@ -1006,7 +1184,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( }; auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { - return AddReduce(as_type(hlo, F32), dim); + return AddReduce(AsType(hlo, F32), dim); }; auto broadcast = [&](HloInstruction* hlo, const Shape& shape, @@ -1041,7 +1219,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return true; } - // Simplify outer product into multiply with implicit broadcasting. + // Simplify outer product into multiply with broadcasting. // // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) { @@ -1115,8 +1293,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return dims; }; - // If the contracting dimension is 1, remove the degnerate dimnesions from the - // lhs and rhs, broadcast each to the result shape and multiply. + // If the contracting dimension is 1, remove the degenerate dimensions from + // the lhs and rhs, broadcast each to the result shape and multiply.
if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && (rhs_kept_dim == rhs_rank - 1 || (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { @@ -1453,7 +1631,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - + if (options_.is_layout_sensitive()) { + return Status::OK(); + } // Replace a zero element dot with a broadcast of the constant 0. if (ShapeUtil::IsZeroElementArray(dot->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) || @@ -1470,6 +1650,117 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot->shape().element_type() != BF16) { return Status::OK(); } + + // If there are no contracting dimensions, a dot can be rewritten as + // mul(broadcast(transpose(x)),broadcast(transpose(y))) + if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); + if (dot->shape().rank() != lhs->shape().rank()) { + std::vector lhs_broadcast_dims(lhs->shape().rank()); + absl::c_iota(lhs_broadcast_dims, 0); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + dot->shape(), new_lhs, lhs_broadcast_dims)); + } + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + if (dot->shape().rank() != rhs->shape().rank()) { + std::vector rhs_broadcast_dims( + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + absl::c_iota(rhs_broadcast_dims, 0); + for (int64 i = 
lhs->shape().rank(); i < dot->shape().rank(); ++i) { + rhs_broadcast_dims.push_back(i); + } + new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + dot->shape(), new_rhs, rhs_broadcast_dims)); + } + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, + new_lhs, new_rhs)); + } + + // If the lhs or rhs have only batch and contracting dimensions, a dot can be + // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) + if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == + lhs->shape().rank()) || + (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size() == + rhs->shape().rank())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + + int64 lhs_outer_dims = + lhs->shape().rank() - + (dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + int64 rhs_outer_dims = + rhs->shape().rank() - + (dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + dot->dot_dimension_numbers().rhs_contracting_dimensions_size()); + CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0); + if (rhs_outer_dims > 0) { + std::vector lhs_broadcast_dims( + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + absl::c_iota(lhs_broadcast_dims, 0); + lhs_broadcast_dims.resize(lhs->shape().rank()); + 
std::iota(lhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().lhs_batch_dimensions_size(), + lhs_broadcast_dims.end(), + dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + rhs_outer_dims); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_rhs->shape(), new_lhs, lhs_broadcast_dims)); + } else if (lhs_outer_dims > 0) { + std::vector rhs_broadcast_dims( + dot->dot_dimension_numbers().rhs_batch_dimensions_size()); + absl::c_iota(rhs_broadcast_dims, 0); + rhs_broadcast_dims.resize(rhs->shape().rank()); + std::iota(rhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size(), + rhs_broadcast_dims.end(), + dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + lhs_outer_dims); + new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_lhs->shape(), new_rhs, rhs_broadcast_dims)); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs)); + std::vector reduce_dims( + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + new_dot = AsType(new_dot, F32); + const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims); + absl::c_iota( + reduce_dims, + outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + new_dot = AddReduce(new_dot, reduce_dims); + new_dot = AsType(new_dot, dot->shape().element_type()); + return ReplaceInstruction(dot, new_dot); + } + if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || dot->shape().rank() > 2) { if (options_.enable_dot_strength_reduction() && @@ -1508,7 +1799,11 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). 
- if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { + if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 && + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 && + dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 && + dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 && + lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { DotDimensionNumbers dot_dimension_numbers; dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); @@ -2159,14 +2454,151 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( return changed; } +namespace { +template +std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, + HloComputation* computation) { + HloInstruction *a, *b, *c; + CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); + + if (ShapeUtil::ElementIsIntegral(remainder->shape()) && + !Match(b, m::ConstantEffectiveScalar(&c)) && + !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { + return nullptr; + } + + if (ShapeUtil::ElementIsSigned(remainder->shape())) { + int64 b_value = c->literal().GetFirstElement(); + if (b_value > 0 && IsPowerOfTwo(static_cast(b_value))) { + // Handle negative dividends by negating the result of the remainder.
+ HloInstruction* zero_like_a = BroadcastZeros( + computation, a->shape().element_type(), a->shape().dimensions()); + + auto* dividend_is_negative = + computation->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a, + ComparisonDirection::kLt)); + + auto* negated_dividend = computation->AddInstruction( + HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); + + auto* abs_dividend = + computation->AddInstruction(HloInstruction::CreateTernary( + a->shape(), HloOpcode::kSelect, dividend_is_negative, + negated_dividend, a)); + + auto* mask_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(b_value - 1))); + if (!ShapeUtil::IsScalar(b->shape())) { + mask_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); + } + + auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( + remainder->shape(), HloOpcode::kAnd, abs_dividend, mask_amount)); + + auto* neqated_quotient = + computation->AddInstruction(HloInstruction::CreateUnary( + quotient->shape(), HloOpcode::kNegate, quotient)); + + return HloInstruction::CreateTernary( + remainder->shape(), HloOpcode::kSelect, dividend_is_negative, + neqated_quotient, quotient); + } + } else { + uint64 b_value = c->literal().GetFirstElement(); + if (IsPowerOfTwo(b_value)) { + HloInstruction* mask_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(b_value - 1))); + if (!ShapeUtil::IsScalar(b->shape())) { + mask_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); + } + return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd, + a, mask_amount); + } + } + return nullptr; +} +} // namespace + +Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { + HloInstruction *a, *b; + CHECK(Match(remainder, m::Remainder(m::Op(&a), 
m::Op(&b)))); + + // A % B => A & (B - 1) if B is a power of 2. + switch (remainder->shape().element_type()) { + case S8: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S16: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S32: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S64: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U8: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U16: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U32: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U64: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + default: + break; + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { auto operand = reshape->mutable_operand(0); // Reshape directly to empty constant if the shape contains zero-element // dimension. if (ShapeUtil::IsZeroElementArray(reshape->shape())) { + // If the instruction doesn't have a layout, use a default layout for + // the literal result. 
+ Shape reshaped_shape = reshape->shape(); + if (!LayoutUtil::HasLayout(reshaped_shape)) { + LayoutUtil::SetToDefaultLayout(&reshaped_shape); + } auto empty_constant = HloInstruction::CreateConstant( - Literal::CreateFromShape(reshape->shape())); + Literal::CreateFromShape(reshaped_shape)); return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); } @@ -2261,11 +2693,11 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( int64 start = slice->slice_starts(i); int64 low = padding_config.dimensions(i).edge_padding_low(); int64 data = pad->operand(0)->shape().dimensions(i); - if (start >= low && start < low + data) { - return false; + if (start < low || start >= low + data) { + return true; } } - return true; + return false; }(); if (in_padding) { @@ -2406,6 +2838,27 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { new_slice_starts, new_slice_limits, slice->slice_strides())); } + auto only_broadcast_dims_sliced = [&] { + if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) { + return false; + } + for (int64 dim : slice->operand(0)->dimensions()) { + if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 || + slice->slice_limits(dim) != + slice->operand(0)->shape().dimensions(dim)) { + return false; + } + } + return true; + }; + if (only_broadcast_dims_sliced()) { + return ReplaceWithNewInstruction( + slice, + HloInstruction::CreateBroadcast( + slice->shape(), slice->mutable_operand(0)->mutable_operand(0), + slice->mutable_operand(0)->dimensions())); + } + TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice)); if (replaced) { return Status::OK(); @@ -2452,28 +2905,72 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Most of those optimizations can be done for multi-output - // reduces. 
- if (reduce->shape().IsTuple()) { - return Status::OK(); - } +Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { + HloReduceInstruction* reduce = Cast(hlo); + bool multi_output_reduce = reduce->shape().IsTuple(); + + // For tuple reduce, we require all reduce shapes to be the same, up to the + // element types, so we can just use the first operand and the first result as a + // representative. + auto arg = reduce->inputs()[0]; + auto init_value = reduce->init_values()[0]; + const Shape& reduce_result_shape = + multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape(); - auto arg = reduce->mutable_operand(0); - auto init_value = reduce->mutable_operand(1); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (ShapeUtil::IsZeroElementArray(arg->shape()) || - ShapeUtil::IsZeroElementArray(reduce->shape())) { - return ReplaceWithNewInstruction( - reduce, - HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); + ShapeUtil::IsZeroElementArray(reduce_result_shape)) { + if (multi_output_reduce) { + std::vector broadcast_inits; + int64 inputs = reduce->input_count(); + for (int64 i = 0; i < inputs; ++i) { + broadcast_inits.push_back(computation_->AddInstruction( + HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i), + reduce->init_values()[i], {}))); + } + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateTuple(broadcast_inits)); + } else { + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {})); + } + } + + // If the reduction results in the same number of elements, then the only + // possible side effect would be a reshape. Since the init_value is an + // identity of the reduction function, we can therefore replace the reduce + // with a simple reshape, ignoring the reduction function completely.
+ if (ShapeUtil::ElementsIn(reduce_result_shape) == + ShapeUtil::ElementsIn(arg->shape())) { + if (multi_output_reduce) { + std::vector reshaped_args; + int64 inputs = reduce->input_count(); + for (int64 i = 0; i < inputs; ++i) { + reshaped_args.push_back( + computation_->AddInstruction(HloInstruction::CreateReshape( + reduce->shape().tuple_shapes(i), reduce->inputs()[i]))); + } + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateTuple(reshaped_args)); + } else { + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReshape(reduce_result_shape, arg)); + } + } + + // TODO(b/112040122): Most of those optimizations below can be done for + // multi-output reduces. + if (multi_output_reduce) { + return Status::OK(); } // A Transpose feeding a reduce can simply permute the reduction dimensions // field if the output of the reduce is a vector or scalar. Higher ranked // result may require a transpose of the output. - if (reduce->shape().rank() <= 1 && arg->opcode() == HloOpcode::kTranspose) { + if (reduce_result_shape.rank() <= 1 && + arg->opcode() == HloOpcode::kTranspose) { auto transpose_dimensions = arg->dimensions(); std::vector new_reduce_dimensions; for (auto dim : dimensions) { @@ -2481,20 +2978,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( - reduce->shape(), arg->mutable_operand(0), init_value, + reduce_result_shape, arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); } - // If the reduction results in the same number of elements, then the only - // possible side effect would be a reshape. Since the init_value is an - // identity of the reduction function, we can therefore replace the reduce - // with a simple reshape, ignoring the reduction function completely. 
- if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape())) { - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); - } - // If a reduce feeds a reduce with the same computation and initial value, // they can be combined into a single reduce. if (arg->opcode() == HloOpcode::kReduce && @@ -2520,9 +3007,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), reduce_dims.end(), std::back_inserter(new_dimensions)); return ReplaceWithNewInstruction( - reduce, - HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), - init_value, new_dimensions, function)); + reduce, HloInstruction::CreateReduce( + reduce_result_shape, arg->mutable_operand(0), init_value, + new_dimensions, function)); } // A reshape that collapses multiple dimensions into a dimension being @@ -2565,7 +3052,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( - reduce->shape(), arg->mutable_operand(0), init_value, + reduce_result_shape, arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); } } @@ -2580,11 +3067,11 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { HloInstruction* old_reduce = nullptr; for (HloInstruction* operand : arg->operands()) { HloInstruction* new_reduce = computation_->AddInstruction( - HloInstruction::CreateReduce(reduce->shape(), operand, init_value, + HloInstruction::CreateReduce(reduce_result_shape, operand, init_value, reduce->dimensions(), function)); if (old_reduce != nullptr) { new_reduce = computation_->AddInstruction(HloInstruction::CreateMap( - reduce->shape(), {old_reduce, new_reduce}, function)); + reduce_result_shape, {old_reduce, new_reduce}, function)); } old_reduce = new_reduce; } @@ -2784,7 +3271,7 @@ Status 
AlgebraicSimplifierVisitor::HandleReduceWindow( } if (is_effective_broadcast()) { - VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast."; + VLOG(10) << "Replacing pad/reduce-window with broadcast."; auto fadd = [this](std::unique_ptr x) { return computation_->AddInstruction(std::move(x)); }; @@ -2846,109 +3333,6 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return ReplaceWithNewInstruction( sort, HloInstruction::CreateTuple(sort->operands())); } - - if (!options_.enable_permutation_sort_replacement()) { - return Status::OK(); - } - // Check if we are sorting a permutation. In that case, we know that the keys - // will be sorted to the identity permutation, and we can represent the - // changes to the 'values' parameter as a scatter. - if (sort->operand_count() == 2 && - operand->opcode() == HloOpcode::kGetTupleElement) { - const HloInstruction* other_sort = operand->operand(0); - // Check whether the 'values' parameter is the result of another sort with - // the same sort dimension. - if (other_sort->opcode() == HloOpcode::kSort && - other_sort->operand_count() >= 2 && - other_sort->dimensions(0) == dimension_to_sort && - other_sort->operand(operand->tuple_index())->opcode() == - HloOpcode::kIota) { - auto* iota = - Cast(other_sort->operand(operand->tuple_index())); - // The sort operand needs to be an integral iota, and the iota dimension - // needs to be the dimension that was sorted. - if (iota->iota_dimension() == dimension_to_sort && - ShapeUtil::ElementIsIntegral(iota->shape())) { - // We use the following construction method for a Scatter that applies - // the permutation from 'keys' to the 'values' parameter. - // - Take the "keys" parameter of the second sort and reshape it to have - // another "1" dimension at the end. - // - Concatenate it with iotas of the same extended shape with all - // different iota_dimensions except the dimension_to_sort in the order - // of iota_dimensions/dimension_to_sort, so e.g. 
with rank 3 and - // dimension_to_sort = 1, we would have concatenate of (iota with - // iota_dimension=0, keys, iota with iota_dimension = 2) - // - Use this as the indices parameter of scatter, and set updates - // of the scatter to be a reshaped 'values' parameter of sort (adding - // 'rank' many 1 dimensions at the end). - int64 rank = operand->shape().rank(); - Shape extended_shape = operand->shape(); - extended_shape.add_dimensions(1); - extended_shape.mutable_layout()->add_minor_to_major(rank); - auto reshaped_permutation = computation_->AddInstruction( - HloInstruction::CreateReshape(extended_shape, operand)); - std::vector concat_operands; - for (int64 i = 0; i < rank; ++i) { - if (i == dimension_to_sort) { - concat_operands.push_back(reshaped_permutation); - } else { - concat_operands.push_back(computation_->AddInstruction( - HloInstruction::CreateIota(extended_shape, i))); - } - } - Shape concat_shape = operand->shape(); - concat_shape.add_dimensions(rank); - concat_shape.mutable_layout()->add_minor_to_major(rank); - auto scatter_indices = - rank > 1 ? computation_->AddInstruction( - HloInstruction::CreateConcatenate( - concat_shape, concat_operands, rank)) - : reshaped_permutation; - - // We don't care about the operand, it will be completely overridden by - // the updates. - auto scatter_operand = computation_->AddInstruction( - HloInstruction::CreateIota(sort->operand(1)->shape(), 0)); - - // Construct the updates operand of scatter. - Shape update_shape = sort->operand(1)->shape(); - for (int64 i = 0; i < rank; ++i) { - update_shape.add_dimensions(1); - update_shape.mutable_layout()->add_minor_to_major(rank + i); - } - auto scatter_updates = - computation_->AddInstruction(HloInstruction::CreateReshape( - update_shape, sort->mutable_operand(1))); - - // Construct the updates computation, which simply replaces the operand - // values with the update values. 
- HloComputation::Builder b("update_replace_computation"); - Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); - b.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "scalar_rhs")); - auto update_replace_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_rhs)); - - ScatterDimensionNumbers dim_numbers; - dim_numbers.set_index_vector_dim(rank); - for (int64 i = 0; i < rank; ++i) { - dim_numbers.add_update_window_dims(rank + i); - dim_numbers.add_scatter_dims_to_operand_dims(i); - } - auto scatter = - computation_->AddInstruction(HloInstruction::CreateScatter( - sort->operand(1)->shape(), scatter_operand, scatter_indices, - scatter_updates, update_replace_computation, dim_numbers)); - return ReplaceWithNewInstruction( - sort, HloInstruction::CreateTuple( - {computation_->AddInstruction(HloInstruction::CreateIota( - operand->shape(), dimension_to_sort)), - scatter})); - } - } - } return Status::OK(); } @@ -3325,6 +3709,11 @@ Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { HloInstruction::CreateBroadcast( map->shape(), computation_->AddInstruction(std::move(clone)), {})); } + // Inline the map if the map computation only contains an elementwise + // operation that can accept arbitrary shapes. 
+ if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) { + return Status::OK(); + } std::vector new_operands; for (auto* root_operand : map_root->operands()) { if (root_operand->opcode() != HloOpcode::kParameter) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index ff3f638b22e290f6f6237a5a72a257aa23ecd78b..df5a8c2ec141458a95fafb76b1e99e4b04a61b28 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -71,17 +71,6 @@ class AlgebraicSimplifierOptions { return enable_conv_simplification_; } - // If enable_permutation_sort_replacement is true, a sort op that is known to - // sort a permutation will be replaced with a scatter op. - void set_enable_permutation_sort_replacement( - bool enable_permutation_sort_replacement) { - enable_permutation_sort_replacement_ = enable_permutation_sort_replacement; - } - - bool enable_permutation_sort_replacement() const { - return enable_permutation_sort_replacement_; - } - // If enable_window_reduce_replacement is true, the kReduceWindow instruction // can be optimized by replacement with simpler operations. void set_enable_window_reduce_to_reduce_replacement( @@ -99,7 +88,6 @@ class AlgebraicSimplifierOptions { bool is_layout_sensitive_{false}; bool enable_dot_strength_reduction_{true}; bool enable_conv_simplification_{true}; - bool enable_permutation_sort_replacement_{false}; bool enable_window_reduce_to_reduce_replacement_{true}; }; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 0dea498456eda3a6dc070493312e9825e6787c32..7f399ce0f112822bd476054cc848b5c3fe26b389 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -194,6 +195,86 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) { m::Broadcast(m::ConstantScalar(0.125))))); } +TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = u32[4] parameter(0) + c = u32[] constant(8) + b = u32[4] broadcast(c), dimensions={} + ROOT d = u32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::ShiftRightLogical( + m::Parameter(0), m::Broadcast(m::ConstantScalar(3))))); +} + +TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = s32[4] parameter(0) + c = s32[] constant(8) + b = s32[4] broadcast(c), dimensions={} + ROOT d = s32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto match_dividend_is_negative = + m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0))); + auto match_abs = m::Select(match_dividend_is_negative, + m::Negate(m::Parameter(0)), m::Parameter(0)); + auto match_shift = + m::ShiftRightLogical(match_abs, m::Broadcast(m::ConstantScalar(3))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Select(match_dividend_is_negative, + m::Negate(match_shift), match_shift))); +} + +TEST_F(AlgebraicSimplifierTest, 
UnsignedRemainderByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = u32[4] parameter(0) + c = u32[] constant(8) + b = u32[4] broadcast(c), dimensions={} + ROOT r = u32[4] remainder(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::AndAnyOrder(m::Parameter(0), + m::Broadcast(m::ConstantScalar(7))))); +} + +TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = s32[4] parameter(0) + c = s32[] constant(8) + b = s32[4] broadcast(c), dimensions={} + ROOT r = s32[4] remainder(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto match_dividend_is_negative = + m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0))); + auto match_abs = m::Select(match_dividend_is_negative, + m::Negate(m::Parameter(0)), m::Parameter(0)); + auto match_and = + m::AndAnyOrder(match_abs, m::Broadcast(m::ConstantScalar(7))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Select(match_dividend_is_negative, + m::Negate(match_and), match_and))); +} + // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { auto m = CreateNewVerifiedModule(); @@ -424,6 +505,30 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { m::Broadcast(m::Op().Is(zero))))); } +TEST_F(AlgebraicSimplifierTest, KeepNontrivialMap) { + const char* kModuleStr = R"( + HloModule m + fusion { + x = f32[] parameter(0) + c = f32[] constant(42) + m = f32[] multiply(x, x) + ROOT a = f32[] add(m, c) + } + + map { + x = f32[] parameter(0) + ROOT f = f32[] fusion(x), kind=kLoop, calls=fusion + } + + ENTRY test { + p = f32[2,2] parameter(0) + ROOT map = f32[2,2] map(p), 
dimensions={0,1}, to_apply=map + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); @@ -2147,9 +2252,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) { ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed) } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloPassFix simplifier(default_options_); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2614,6 +2718,33 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); } +TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastToBroadcast) { + HloComputation::Builder builder(TestName()); + const int64 dim0 = 11; + const int64 dim1 = 12; + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {dim0}), "param")); + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {dim0, dim1}), param, {0})); + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0, dim1 - 9}), broadcast, + /*start_indices=*/{0, 3}, + /*limit_indices=*/{dim0, dim1 - 6}, /*strides=*/{1, 1})); + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Broadcast(m::Parameter(0))))); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + 
GmockMatch(m::Broadcast(m::Parameter(0)))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) { HloComputation::Builder builder(TestName()); const int64 dim0 = 11; @@ -2668,93 +2799,23 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto builder = HloComputation::Builder(TestName()); + auto module = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); - auto module = CreateNewVerifiedModule(); + TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, &builder, + module.get()) + .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), keys); } -TEST_F(AlgebraicSimplifierTest, ReplacePermutationSortWithScatter) { - const char* hlo_string = R"( - HloModule permutation_sort - - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} iota(), iota_dimension=1 - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} - gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 - ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={1} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AlgebraicSimplifierOptions options; - options.set_enable_permutation_sort_replacement(true); - AlgebraicSimplifier simplifier(options); - EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, - GmockMatch(m::Tuple( - m::Iota(), - m::Scatter(m::Iota(), m::Concatenate(m::Iota(), m::Reshape()), - m::Reshape())))); 
-} - -TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortIfNonIntegral) { - // Same as ReplacePermutationSortWithScatter except that the iota has F32 - // type. - const char* hlo_string = R"( - HloModule permutation_sort - - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = f32[64,8732]{1,0} iota(), iota_dimension=1 - sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), dimensions={1} - gte = f32[64,8732]{1,0} get-tuple-element(sort), index=1 - ROOT sort2 = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(gte, values), dimensions={1} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AlgebraicSimplifierOptions options; - options.set_enable_permutation_sort_replacement(true); - AlgebraicSimplifier simplifier(options); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); -} - -TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortWrongDimensions) { - // Same as ReplacePermutationSortWithScatter except that the sort dimensions - // don't match. 
- const char* hlo_string = R"( - HloModule permutation_sort - - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} iota(), iota_dimension=1 - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} - gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 - ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={0} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AlgebraicSimplifierOptions options; - options.set_enable_permutation_sort_replacement(true); - AlgebraicSimplifier simplifier(options); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); -} - TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { auto builder = HloComputation::Builder(TestName()); + auto module = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0}); Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); @@ -2764,10 +2825,11 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { HloInstruction::CreateParameter(1, values_shape, "values0")); auto values1 = builder.AddInstruction( HloInstruction::CreateParameter(2, values_shape, "values1")); - builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0, - keys, {values0, values1})); - auto module = CreateNewVerifiedModule(); + TF_ASSERT_OK(MakeSortHlo(ShapeUtil::MakeTupleShape( + {keys_shape, values_shape, values_shape}), + {keys, values0, values1}, 0, /*is_stable=*/false, + &builder, module.get()) + .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -3703,8 +3765,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, 
"y")); DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(0); builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, DefaultPrecisionConfig(2))); std::unique_ptr dot_computation(builder.Build()); @@ -3949,7 +4011,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { param = f32[3,4] parameter(0) constant = f32[] constant(0.0) pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 - ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]} } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -3960,6 +4022,27 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { const char* hlo_string = R"( HloModule module @@ -3981,6 +4064,29 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { EXPECT_THAT(root, GmockMatch(m::Parameter())); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) { + const char* hlo_string = R"( + HloModule module + + ENTRY entry () -> f32[1]{0} { + constant.val = f32[] constant(4) 
+ constant.pad = f32[] constant(-7) + reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val) + pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0 + slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]} + ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0)))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { const char* hlo_string = R"( HloModule module @@ -4211,10 +4317,24 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { int m, k, n; PrimitiveType element_type; std::tie(m, k, n, element_type) = GetParam(); - - Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); - Shape lhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}); - Shape rhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}); + std::vector lhs_dims = {1, 3, 5}; + std::vector rhs_dims = lhs_dims; + std::vector output_dims = lhs_dims; + if (m > 0) { + lhs_dims.push_back(m); + output_dims.push_back(m); + } + if (k > 0) { + lhs_dims.push_back(k); + rhs_dims.push_back(k); + } + if (n > 0) { + rhs_dims.push_back(n); + output_dims.push_back(n); + } + Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction( @@ -4228,16 +4348,18 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { dot_dnums.add_rhs_batch_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); 
dot_dnums.add_rhs_batch_dimensions(2); - dot_dnums.add_lhs_contracting_dimensions(4); - dot_dnums.add_rhs_contracting_dimensions(3); + if (k > 0) { + dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3); + dot_dnums.add_rhs_contracting_dimensions(3); + } builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); - const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; - const bool computation_should_be_modified = dot_should_be_transformed; - EXPECT_EQ(changed, computation_should_be_modified); + const bool dot_should_be_transformed = + m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1; + EXPECT_EQ(changed, dot_should_be_transformed); bool has_no_dot = true; for (const auto& hlo : computation->instructions()) { if (hlo->opcode() == HloOpcode::kDot) { @@ -4248,10 +4370,12 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { EXPECT_EQ(has_no_dot, dot_should_be_transformed); } -INSTANTIATE_TEST_SUITE_P( - BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, - ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), - ::testing::Values(1, 2), ::testing::Values(F32, BF16))); +INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation, + BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(F32, BF16))); class DotStrengthReductionTest : public AlgebraicSimplifierTest, @@ -4691,5 +4815,156 @@ INSTANTIATE_TEST_SUITE_P( DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); +TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) { + const char* hlo_string = R"( +HloModule module + 
+reducer { + parameter.1 = f32[] parameter(0) + parameter.3 = f32[] parameter(2) + add.2 = f32[] add(parameter.1, parameter.3) + parameter.0 = f32[] parameter(1) + parameter.2 = f32[] parameter(3) + add.3 = f32[] add(parameter.0, parameter.2) + ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3) +} + +ENTRY entry { + parameter.6 = (f32[], f32[]) parameter(0) + get-tuple-element.10 = f32[] get-tuple-element(parameter.6), index=0 + get-tuple-element.11 = f32[] get-tuple-element(parameter.6), index=1 + constant = f32[] constant(0) + ROOT reduce = (f32[], f32[]) reduce(get-tuple-element.10, get-tuple-element.11, constant, constant), dimensions={}, to_apply=reducer +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Tuple( + m::Reshape(m::GetTupleElement(m::Parameter(), 0)), + m::Reshape(m::GetTupleElement(m::Parameter(), 1))))); +} + +TEST_F(AlgebraicSimplifierTest, TupleReduceBroadcast) { + const char* hlo_string = R"( +HloModule module + +reducer { + parameter.1 = f32[] parameter(0) + parameter.3 = f32[] parameter(2) + mul.2 = f32[] add(parameter.1, parameter.3) + parameter.0 = f32[] parameter(1) + parameter.2 = f32[] parameter(3) + add.3 = f32[] add(parameter.0, parameter.2) + ROOT tuple.4 = (f32[], f32[]) tuple(mul.2, add.3) +} + +ENTRY entry { + parameter.6 = (f32[0, 10, 10], f32[0, 10, 10]) parameter(0) + get-tuple-element.10 = f32[0, 10, 10] get-tuple-element(parameter.6), index=0 + get-tuple-element.11 = f32[0, 10, 10] get-tuple-element(parameter.6), index=1 + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT reduce = (f32[10, 10], f32[10, 10]) reduce(get-tuple-element.10, get-tuple-element.11, constant.0, constant.1), dimensions={0}, to_apply=reducer +} +)"; + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::ConstantScalar(0)), + m::Broadcast(m::ConstantScalar(1))))); +} + +TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1}), "param")); + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {0, 1}), param, {1})); + + // Create a reshape with zero sized result and without layout. + Shape reshaped_shape = ShapeUtil::MakeShape(F32, {0}); + reshaped_shape.clear_layout(); + builder.AddInstruction( + HloInstruction::CreateReshape(reshaped_shape, broadcast)); + + std::unique_ptr module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) { + Shape shape = ShapeUtil::MakeShape(F32, {}); + shape.clear_layout(); + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + + HloInstruction* const_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(20.0f))); + builder.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kDivide, + param, const_value)); + + std::unique_ptr module = 
CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Multiply())); +} + +// Test that 1/sqrt(X) is simplified to rsqrt(X). +TEST_F(AlgebraicSimplifierTest, RecipSqrt) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + sqrt = f32[] sqrt(p0) + ROOT div = f32[] divide(p1, sqrt) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder(m::Parameter(1), + m::Rsqrt(m::Parameter(0))))); +} + +// Test that 1/rsqrt(X) is simplified to sqrt(X). +TEST_F(AlgebraicSimplifierTest, RecipRsqrt) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + rsqrt = f32[] rsqrt(p0) + ROOT div = f32[] divide(p1, rsqrt) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder(m::Parameter(1), + m::Sqrt(m::Parameter(0))))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 8e327ee84b4c1cdab12a837aa684f14dd0c172c4..52d6982c70f7962ea9f54db0a4b1f2089a122c1c 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -29,18 +29,17 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace { - namespace m = match; -// Returns true iff the argument instruction is an AllReduce, followed by a -// certain sequence of instructions and then a CRS. It must be possible to move -// the AR past each instruction in the sequence. -bool MatchesArCrsPattern(HloInstruction* instruction) { +// Checks if the argument instruction is an AllReduce, followed by a certain +// sequence of instructions and then a CRS. It must be possible to move +// the AR past each instruction in the sequence. Returns the CRS, which is the +// last instruction in the sequence. +absl::optional ArCrsCombiner::MatchesArCrsPattern( + HloInstruction* instruction) { auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { if (instruction->user_count() != 1) { return false; @@ -73,21 +72,26 @@ bool MatchesArCrsPattern(HloInstruction* instruction) { if (!instruction->IsCrossModuleAllReduce() || !computation_is_addition(instruction->called_computations()[0]) || instruction->user_count() != 1) { - return false; + return absl::nullopt; } auto next = instruction->users()[0]; + int64 distance = 1; while (!next->IsCrossReplicaAllReduce()) { if (can_ar_move_past_instruction(next)) { next = next->users()[0]; } else { - return false; + return absl::nullopt; } + ++distance; + } + if (!Cast(next)->IsNoop() && + computation_is_addition(next->called_computations()[0])) { + return absl::optional(ArCrsPair(instruction, next, distance)); + } else { + return absl::nullopt; } - return computation_is_addition(next->called_computations()[0]); } -} // namespace - absl::optional ArCrsCombiner::WhileFromBodyParameter( HloInstruction* instruction) { CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); @@ -99,7 +103,7 @@ absl::optional ArCrsCombiner::WhileFromBodyParameter( return 
caller_instruction; } } - return absl::optional(); + return absl::nullopt; } std::vector ArCrsCombiner::GetAllTuples( @@ -229,10 +233,56 @@ bool ArCrsCombiner::InstructionsComputeSameValue( } void ArCrsCombiner::GroupAllReducesById(HloModule* module) { + // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS), + // ... , (ARn, CRS). + // If as we traverse the HLO graph we start tracking the pair (AR2, CRS), + // and later find that AR1's distance from the CRS is longer, we discard + // AR2 and start tracking AR1. We put the discarded ids in this set, in order + // to skip processing of short paths when we encounter the other ARs that + // have the same id as AR2. + absl::flat_hash_set discarded_ar_ids; for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - if (MatchesArCrsPattern(instruction)) { - all_reduce_map_[*(instruction->all_reduce_id())].push_back(instruction); + auto maybe_pair = MatchesArCrsPattern(instruction); + if (maybe_pair) { + auto pair = *maybe_pair; + int64 ar_id = *(instruction->all_reduce_id()); + if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) { + continue; + } + auto it = crs_reserved_map_.find(pair.crs); + if (it != crs_reserved_map_.end()) { + auto prev_ar_id = it->second; + // Since there is another AR paired with CRS, + // all_reduce_map_[prev_ar_id] should exist, but + // all_reduce_map_[ar_id] shouldn't. + CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end()); + CHECK_NE(prev_ar_id, ar_id); + auto prev_pair = all_reduce_map_[prev_ar_id].back(); + int64 prev_distance = prev_pair.distance; + if (prev_distance < pair.distance) { + // The current AR's distance to CRS is longer than the previously + // tracked AR, so we discard the previous AR. 
+ all_reduce_map_.erase(prev_ar_id); + discarded_ar_ids.insert(prev_ar_id); + all_reduce_map_[ar_id].push_back(pair); + crs_reserved_map_[pair.crs] = ar_id; + } else { + // Discard the current AR id because we are keeping the previously + // tracked AR. + discarded_ar_ids.insert(ar_id); + } + } else { + if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) { + int64 prev_distance = all_reduce_map_[ar_id].back().distance; + CHECK_EQ(prev_distance, pair.distance) + << "All ARs with the same AR ID must have the same distance " + "from the corresponding CRSs. Found: " + << prev_distance << " and " << pair.distance; + } + all_reduce_map_[ar_id].push_back(pair); + crs_reserved_map_[pair.crs] = ar_id; + } } } } @@ -241,11 +291,11 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { auto all_reduce_id = it.first; - auto instruction_vec = it.second; - CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); - auto instr_0 = instruction_vec[0]; - for (int i = 1; i < instruction_vec.size(); ++i) { - auto instr_i = instruction_vec[i]; + auto pairs_vec = it.second; + CHECK_EQ(pairs_vec.size(), num_spatial_partitions_); + auto instr_0 = pairs_vec[0].ar; + for (int i = 1; i < pairs_vec.size(); ++i) { + auto instr_i = pairs_vec[i].ar; auto next_0 = instr_0->users()[0]; auto next_i = instr_i->users()[0]; absl::flat_hash_map visited_pairs; @@ -269,8 +319,9 @@ StatusOr ArCrsCombiner::RewriteGraph() { return false; } for (auto it : all_reduce_map_) { - auto instruction_vec = it.second; - for (auto all_reduce : instruction_vec) { + auto pairs_vec = it.second; + for (auto pair : pairs_vec) { + auto all_reduce = pair.ar; auto parent_computation = all_reduce->parent(); auto all_reduce_id = all_reduce->all_reduce_id(); auto prev = all_reduce->mutable_operand(0); @@ -291,16 +342,23 @@ StatusOr ArCrsCombiner::RewriteGraph() { ? 
next->operands()[1] : next->operands()[0]; // To move the AR past the addition/subtraction, we need to divide - // other_operand by the number of spatial partitions. - auto shape = other_operand->shape(); - Literal lit(shape); - lit.PopulateWithValue(num_spatial_partitions_); - auto divisor = parent_computation->AddInstruction( - HloInstruction::CreateConstant(lit.Clone())); - auto division = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDivide, other_operand, divisor)); - TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + // other_operand by the number of spatial partitions, except if + // other_operand is a cross-module AR, which can be eliminated. + if (other_operand->IsCrossModuleAllReduce() && + other_operand->user_count() == 1) { + TF_CHECK_OK(other_operand->ReplaceAllUsesWith( + other_operand->mutable_operand(0))); + } else { + auto shape = other_operand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = parent_computation->AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDivide, + other_operand, divisor)); + TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + } break; } default: diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index 6f54b97615b270bc6b180dd47d9aff6473752b47..f503e1d5f2b519687e40818a61f0c0be9dfd3ab0 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -26,11 +26,47 @@ limitations under the License. 
namespace xla { // When the HLO graph contains a cross-module AllReduce, followed by some simple -// linear operations, followed by a cross-replica AllReduce, we can combine the -// CMAR and the CRAR, to use an efficient AllReduce implementation that fully -// utilizes the interconnect bandwidth. +// linear operations, followed by a cross-replica AllReduce (also known as +// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an +// efficient AllReduce implementation that fully utilizes the interconnect +// bandwidth. // Such sequences appear in spatially partitioned models. -// This pass must run right after spatial partitioning. +// This pass must run right after spatial partitioning, when the code is still +// in a single HLO module. +// +// The steps are: +// 1) Find CMARs followed by simple ops followed by CRARs. +// 2) Group CMARs by all_reduce_id. They must all be rewritten. +// 3) Prove that the CMAR patterns in each core produce the same result. +// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the +// other operand by the number of spatial partitions. +// 5) Turn the CRAR into an all-core AllReduce. +// +// The pass also handles the case where multiple CMARs lead to the same CRAR, +// and eliminates all CMARs. This graph: +// +// Y +// | +// X CMAR_2 Z +// | \ / +// CMAR_1 + +// \ / +// + +// | +// CRAR +// +// gets rewritten to: +// +// Z num_partitions +// \ / +// Y div +// \ / +// X + +// \ / +// + +// | +// all-core AR +// class ArCrsCombiner : public HloModulePass { public: ArCrsCombiner(int num_spatial_partitions) @@ -43,6 +79,28 @@ class ArCrsCombiner : public HloModulePass { HloInstruction* i2); private: + // We used this struct because multiple ARs could be paired with the same CRS. + // In this case, we want to select the AR that is furthest from the CRS, + // because it makes it easier to eliminate all ARs during RewriteGraph. 
+ struct ArCrsPair { + HloInstruction* ar; + HloInstruction* crs; + // The length of the path from AR to CRS in the HLO graph. + int64 distance; + + ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum, + int64 dist) + : ar(all_reduce), crs(cross_replica_sum), distance(dist) {} + + string ToString() { + return absl::StrCat("(AR: ", ar->name(), ", CRS: ", crs->name(), + ", distance: ", distance, ")"); + } + }; + + absl::optional MatchesArCrsPattern( + HloInstruction* instruction); + // If the passed instruction is a while parameter, and the while body is only // called by a single while instruction, return the while instruction. absl::optional WhileFromBodyParameter( @@ -80,8 +138,13 @@ class ArCrsCombiner : public HloModulePass { int num_spatial_partitions_; - // Map from all-reduce ids to the all reduce instructions. - absl::flat_hash_map> all_reduce_map_; + // Map from all-reduce ids to the AR/CRS pairs. + absl::flat_hash_map> all_reduce_map_; + + // Map from a CRS instruction to the all-reduce ID of the AR paired with the + // CRS. Sometimes, several ARs in the code could be paired with the same CRS. + // We use this map to pick a single AR/CRS path to rewrite. 
+ absl::flat_hash_map crs_reserved_map_; std::unique_ptr call_graph_; }; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 08eb77f0320ff47099873bcdeebff7844aaac125..b972b1289b92a8f29818ff74512f679b9f44a131 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -221,7 +221,7 @@ HloModule foobar %x = (f32[2,2], f32[2,2]) parameter(0) %constant.0 = s32[] constant(0) %constant.1 = s32[] constant(1) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT } %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { @@ -258,7 +258,7 @@ HloModule foobar %x = (f32[2,2], f32[2,2]) parameter(0) %constant.0 = s32[] constant(0) %constant.1 = s32[] constant(1) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT } %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { @@ -296,7 +296,7 @@ HloModule foobar %x = (f32[2,2], f32[2,2]) parameter(0) %constant.0 = s32[] constant(0) %constant.1 = s32[] constant(1) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT } %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { @@ -773,5 +773,405 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteMultipleAdds) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.1 = f32[] 
constant(1) + %constant.2 = f32[] constant(2) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add.11 = f32[] + add(%constant.1, %all-reduce.ar.1), + sharding={maximal device=0} + %add.12 = f32[] + add(%constant.2, %add.11), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%add.12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add.21 = f32[] + add(%constant.1, %all-reduce.ar.2), + sharding={maximal device=0} + %add.22 = f32[] + add(%constant.2, %add.21), + sharding={maximal device=0} + %all-reduce.2 = f32[] + all-reduce(%add.22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))), + op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, 
RewriteArSubtractCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=0} + %sub.1 = f32[] + subtract(%constant.f32, %all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%sub.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=1} + %sub.2 = f32[] + subtract(%constant.f32, %all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%sub.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()), + op::Parameter())), + op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()), + op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, 
replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add11 = f32[] + add(%ar11, %const1), + sharding={maximal device=0} + %ar12 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=0} + %add12 = f32[] + add(%add11, %ar12), + sharding={maximal device=0} + %crs1 = f32[] + all-reduce(%add12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %ar21 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=1} + %add21 = f32[] + add(%ar21, %const1), + sharding={maximal device=1} + %ar22 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=1} + %add22 = f32[] + add(%add21, %ar22), + sharding={maximal device=1} + %crs2 = f32[] + all-reduce(%add22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%crs1, %crs2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + 
op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())), + op::Parameter())), + op::AllReduce(op::Add( + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())), + op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %ar12 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=0} + %add11 = f32[] + add(%ar12, %const1), + sharding={maximal device=0} + %add12 = f32[] + add(%ar11, %add11), + sharding={maximal device=0} + %crs1 = f32[] + all-reduce(%add12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %ar21 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=1} + %ar22 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=1} + %add21 = f32[] + add(%ar22, %const1), + sharding={maximal device=1} + %add22 = f32[] + add(%ar21, %add21), + sharding={maximal device=1} + %crs2 = f32[] + all-reduce(%add22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%crs1, %crs2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + 
ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))), + op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))))); + + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) + + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%convert.2), + replica_groups={{0}}, + 
to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 215e8ced4bb3f98a26ac4eb9912a7fd4d917852f..d016d3e03d5e994841b81cda6214b6ff7cb550be 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/byte_order.h" @@ -67,18 +66,38 @@ const absl::optional>& BackendOptions::allowed_devices() const { return allowed_devices_; } +namespace { + +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + explicit EigenThreadPoolWrapper(tensorflow::thread::ThreadPool* pool) + : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} + + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + tensorflow::thread::ThreadPool* pool_ = nullptr; +}; + +} // namespace + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. 
-struct Backend::EigenThreadPoolWrapper { - explicit EigenThreadPoolWrapper(const int num_threads) +struct Backend::IntraOpThreadPool { + explicit IntraOpThreadPool(const int num_threads) : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(), "XLAEigen", num_threads)), - wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), + wrapper(new EigenThreadPoolWrapper(pool.get())), device(new Eigen::ThreadPoolDevice(wrapper.get(), wrapper->NumThreads())) {} std::unique_ptr pool; - std::unique_ptr wrapper; + std::unique_ptr wrapper; std::unique_ptr device; }; @@ -146,8 +165,7 @@ Backend::Backend(se::Platform* platform, Compiler* compiler, const int num_threads = intra_op_parallelism_threads > 0 ? intra_op_parallelism_threads : tensorflow::port::NumSchedulableCPUs(); - intra_op_thread_pool_wrapper_.reset( - new EigenThreadPoolWrapper(num_threads)); + intra_op_thread_pool_.reset(new IntraOpThreadPool(num_threads)); } } @@ -159,17 +177,17 @@ int Backend::default_device_ordinal() const { const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { - if (intra_op_thread_pool_wrapper_ == nullptr) { + if (intra_op_thread_pool_ == nullptr) { return nullptr; } - return intra_op_thread_pool_wrapper_->device.get(); + return intra_op_thread_pool_->device.get(); } tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { - if (intra_op_thread_pool_wrapper_ == nullptr) { + if (intra_op_thread_pool_ == nullptr) { return nullptr; } - return intra_op_thread_pool_wrapper_->pool.get(); + return intra_op_thread_pool_->pool.get(); } StatusOr Backend::stream_executor( diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index c35f033dc0180409ae3888c2050021da83f5c72a..e7f29a044b95015aa7e547373c24971646833280 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -156,7 +156,6 @@ class Backend { Status ResetDevices(); private: - struct 
EigenThreadPoolWrapper; Backend(se::Platform* platform, Compiler* compiler, absl::Span stream_executors, TransferManager* transfer_manager, @@ -183,7 +182,8 @@ class Backend { std::unique_ptr memory_allocator_; // For the CPU backend, an Eigen threadpool device for use by Eigen code. - std::unique_ptr intra_op_thread_pool_wrapper_; + struct IntraOpThreadPool; + std::unique_ptr intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index eda026ac5685dc469a6230094eb28b3618e36400..dbabd82dd55465dd4c85a56aea849a3e3702d6bf 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -28,6 +28,13 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( *rhs = batch_dot->mutable_operand(1); const Shape& lhs_shape = lhs->shape(); + // A dot with no contracting dims will be rewritten into a multiply by + // AlgebraicSimplifier. Dots with multiple contracting dims are currently + // unsupported. 
+ if (dim_numbers.lhs_contracting_dimensions_size() != 1) { + return false; + } + std::vector degenerate_dims; for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { if (lhs_shape.dimensions(batch_dim) == 1) { diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 52ec1a794c5e9f4452a4bf2b648f453d8acfe976..a81f394a38f091b89b7f1e4d26653ff549f35b75 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -169,5 +169,47 @@ main { /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); } +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsNonContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,101] parameter(0) + b = f32[1,101] parameter(1) + ROOT dot = f32[1,101,101] dot(a,b), lhs_batch_dims={0}, + lhs_contracting_dims={}, + rhs_batch_dims={0}, + rhs_contracting_dims={} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsMultipleContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + lhs = f32[1,5,17,10,13] parameter(0) + rhs = f32[1,9,10,13,6,5] parameter(1) + ROOT dot = f32[10,1,17,9,6] dot(lhs,rhs), lhs_batch_dims={3,0}, + rhs_batch_dims={2,0}, + lhs_contracting_dims={1,4}, + rhs_contracting_dims={5,3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 
e5f5c3edb2ac0c217317fbf809463aa31af9af59..d14e803be6ad6d0b7a7e22442de7e6da77f93577 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -95,15 +95,8 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { HloInstruction* operand, const std::function)>& add_instruction) { - HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast( - operand->shape(), - add_instruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(operand->shape().element_type(), {}), - add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(-0.5f))))), - {})); - return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, - operand, exponent); + return HloInstruction::CreateUnary(operand->shape(), HloOpcode::kRsqrt, + operand); } std::unique_ptr Mean( @@ -524,7 +517,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted); // Grad[Y] * (X - E[X]). - auto grad_output_times_activiation_minus_mean = + auto grad_output_times_activation_minus_mean = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, activation_minus_mean); @@ -532,9 +525,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( GetOrCreateScalarAddComputation(ptype); // sum(Grad[Y] * (X - E[X])). - auto sum_grad_output_times_activiation_minus_mean = + auto sum_grad_output_times_activation_minus_mean = add(HloInstruction::CreateReduce( - feature_shape, grad_output_times_activiation_minus_mean, zero, + feature_shape, grad_output_times_activation_minus_mean, zero, dimensions_without_feature, add_reduce_computation)); // Grad[beta] = Sum(Grad[Y]). @@ -544,7 +537,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). 
auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply, - sum_grad_output_times_activiation_minus_mean, + sum_grad_output_times_activation_minus_mean, rsqrt_var_add_epsilon); // I2 = Sum(Grad[Y]) @@ -553,7 +546,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( // I3 = Sum(Grad[Y] * (X - E[X])) auto i3 = add(HloInstruction::CreateBroadcast( - activation_shape, sum_grad_output_times_activiation_minus_mean, + activation_shape, sum_grad_output_times_activation_minus_mean, {feature_index})); // I4 = (X - E[X]) * I3 diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 8e8fbbd935b154e5a77d68e60d861601d740bf03..34b516184fa861bd71f99f70a32782d242f11914 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -60,7 +60,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, input_shape, "activiation")); + HloInstruction::CreateParameter(0, input_shape, "activation")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, scale_shape, "scale")); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index d1b14d604f0559b6b18f7d1fba127669c241c8a3..72459961485f77b690eed6b8bde2cd03ebe770f1 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -84,7 +84,12 @@ Status BFloat16NormalizationVisitor::InsertConvertAfterOutput( auto convert = computation->AddInstruction( HloInstruction::CreateConvert(hlo->shape(), hlo)); for (auto* user : materialized_users) { - TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert)); + if (user->opcode() == HloOpcode::kConvert && + 
user->shape().element_type() == F32) { + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); + } else { + TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert)); + } } if (is_root) { computation->set_root_instruction(convert); @@ -205,6 +210,28 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( return Status::OK(); } + std::vector bf16_called_comps; + for (auto* comp : hlo->called_computations()) { + bool comp_has_bf16 = false; + if (comp->root_instruction()->shape().element_type() == F32) { + f32_count += 1; + } else if (comp->root_instruction()->shape().element_type() == BF16) { + bf16_count += 1; + comp_has_bf16 = true; + } + for (auto* param : comp->parameter_instructions()) { + if (param->shape().element_type() == F32) { + f32_count += 1; + } else if (param->shape().element_type() == BF16) { + bf16_count += 1; + comp_has_bf16 = true; + } + } + if (comp_has_bf16) { + bf16_called_comps.push_back(comp); + } + } + std::vector materialized_users = hlo->users(); std::vector output_elements(hlo->operand_count()); auto original_shape = hlo->shape(); @@ -236,7 +263,7 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( computation_->set_root_instruction(tuple); } *tuple->mutable_shape() = original_shape; - return Status::OK(); + return ConvertCalledComputations(hlo, bf16_called_comps); } Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 551ac4be73a7630d213a53ca3606aa7f890cd794..7dd46ca4e048210843e227c79f639be1bd34fe30 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/bfloat16_support.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -282,8 +283,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { HloInstruction* value = builder.AddInstruction( HloInstruction::CreateParameter(1, s32_shape, "value")); - HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, {value})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); @@ -308,8 +312,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { HloInstruction* value = builder.AddInstruction( HloInstruction::CreateParameter(1, bf16_shape, "value")); - HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), 0, key, {value})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); auto computation = module->AddEntryComputation(builder.Build()); @@ -319,6 +326,14 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); EXPECT_NE(computation->root_instruction(), sort); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple); + 
EXPECT_EQ(sort->to_apply()->parameter_instruction(1)->shape().element_type(), + F32); + // Make sure that no convert to BF16 was added to the 'to_apply' comparison + // computation. + auto users = sort->to_apply()->parameter_instruction(1)->users(); + for (auto user : users) { + EXPECT_NE(user->opcode(), HloOpcode::kConvert); + } } // Tests that the normalization should not cause unsupported mixed precision due diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index a9b5d9916e400b39039248098c22a715e44ccfd2..357d38a5548b2aaa120f06eed26fe54c9f3f46ac 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -109,8 +109,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); - HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b)); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq)); HloInstruction* sel = builder.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); HloInstruction* xpose = @@ -574,8 +574,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { HloInstruction::CreateParameter(0, shape, "cond_param")); auto cond_dot = builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param)); - auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), 
builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction( @@ -583,9 +583,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { cond_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -631,8 +632,8 @@ TEST_F(BFloat16PropagationTest, auto builder_cond = HloComputation::Builder("cond"); auto cond_param = builder_cond.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); - builder_cond.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction(HloInstruction::CreateSlice( @@ -642,7 +643,8 @@ TEST_F(BFloat16PropagationTest, ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2}, - {1, 1})))))); + {1, 1})))), + ComparisonDirection::kGt)); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -705,8 +707,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); auto cond_dot = builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs)); - 
builder_cond.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction( @@ -714,9 +716,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { cond_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -800,8 +803,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); auto cond0_dot = builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs)); - builder_cond0.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond0.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond0.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond0.AddInstruction( @@ -809,9 +812,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { cond0_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond0.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond0.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond0_dot, 
{1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); // Condition computation for the second while. @@ -828,8 +832,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); auto cond1_dot = builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs)); - builder_cond1.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond1.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond1.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond1.AddInstruction( @@ -837,9 +841,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { cond1_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond1.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond1.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond1_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); // Body computation shared by both whiles. diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index d07615b828990f80e2f905837c46f5f2e15d5a63..cb682f49a5c8097b2fa5ce15ea9fdbbcf46668b4 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -191,6 +191,7 @@ Status GatherComputationsByAllocationType( case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: case HloOpcode::kFusion: // Map/reduce etc computations are always thread-local. 
worklist.push_back(std::make_pair(subcomputation, @@ -752,7 +753,8 @@ namespace { bool MayInterfereAcrossSubcomputations(BufferAssignment* assignment, const LogicalBuffer& a_buffer, const LogicalBuffer& b_buffer) { - auto call_graph = assignment->liveness().hlo_ordering().call_graph(); + const CallGraph& call_graph = + assignment->liveness().hlo_ordering().call_graph(); const HloInstruction* a_ancestor; const HloInstruction* b_ancestor; std::tie(a_ancestor, b_ancestor) = @@ -1011,10 +1013,14 @@ Status BufferAssigner::AssignBuffersForComputation( // callers. BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size); + bool parameter_has_alias = + assignment->module().input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), buffer->index()); allocation->set_entry_computation_parameter( - instruction->parameter_number(), buffer->index()); - VLOG(3) << "New allocation #" << allocation->index() - << " for entry computation parameter: " << *buffer; + instruction->parameter_number(), buffer->index(), + parameter_has_alias); + VLOG(3) << "Mark allocation #" << allocation->index() + << " as entry computation parameter: " << *buffer; continue; } @@ -1416,12 +1422,14 @@ BufferAssigner::MergeColocatedBufferSets( << colocated_buffer_sets.size(); // Returns true if the given buffer is for the entry parameter. 
- auto is_entry_parameter = [](const LogicalBuffer& buffer) { + auto is_readonly_entry_parameter = [](const LogicalBuffer& buffer) { auto* instruction = buffer.instruction(); auto* computation = instruction->parent(); auto* module = computation->parent(); return instruction->opcode() == HloOpcode::kParameter && - computation == module->entry_computation(); + computation == module->entry_computation() && + !module->input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), buffer.index()); }; std::vector set_can_be_merged(colocated_buffer_sets.size(), true); @@ -1443,7 +1451,7 @@ BufferAssigner::MergeColocatedBufferSets( for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { for (auto& buffer : colocated_buffer_sets[i]) { if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer) || + is_readonly_entry_parameter(*buffer) || buffer->instruction()->opcode() == HloOpcode::kConstant) { set_can_be_merged[i] = false; break; @@ -1612,62 +1620,46 @@ void BufferAssigner::BuildColocatedBufferSets( AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); } else if (opcode == HloOpcode::kConditional) { - const HloInstruction* conditional_hlo = instruction; + const HloInstruction* conditional = instruction; ShapeUtil::ForEachSubshape( - conditional_hlo->shape(), - [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + conditional->shape(), + [this, conditional, &points_to_analysis, colocated_buffer_sets]( const Shape& /*subshape*/, const ShapeIndex& index) { std::vector colocated_set; - // Add conditional.result. - AddBufferToColocatedSet(conditional_hlo, index, - points_to_analysis, &colocated_set); - // Add conditional.true_computation.root. - AddBufferToColocatedSet( - conditional_hlo->true_computation()->root_instruction(), - index, points_to_analysis, &colocated_set); - // Add conditional.false_computation.root. 
- AddBufferToColocatedSet( - conditional_hlo->false_computation()->root_instruction(), - index, points_to_analysis, &colocated_set); + // Add cond.result. + AddBufferToColocatedSet(conditional, index, points_to_analysis, + &colocated_set); + for (int j = 0; j < conditional->branch_count(); ++j) { + // Add each cond.branch_computation[j].root. + AddBufferToColocatedSet( + conditional->branch_computation(j)->root_instruction(), + index, points_to_analysis, &colocated_set); + } AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); - // Add true_operand and conditional.true_computation.parameter(0) as a - // colocated buffer set. Note that this has to be done for each subshape - // in the true_operand of the conditional. - ShapeUtil::ForEachSubshape( - conditional_hlo->operand(1)->shape(), - [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( - const Shape& /*subshape*/, const ShapeIndex& index) { - std::vector true_set; - // Add conditional.true_operand. - AddBufferToColocatedSet(conditional_hlo->operand(1), index, - points_to_analysis, &true_set); - // Add conditional.true_computation.parameter_instruction(0). - AddBufferToColocatedSet( - conditional_hlo->true_computation()->parameter_instruction(0), - index, points_to_analysis, &true_set); - AddSetToColocatedBufferSets(true_set, colocated_buffer_sets); - }); - - // Add false_operand and conditional.false_computation.parameter(0) as a - // colocated buffer set. Note that this has to be done for each subshape - // in the false_operand of the conditional. - ShapeUtil::ForEachSubshape( - conditional_hlo->operand(2)->shape(), - [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( - const Shape& /*subshape*/, const ShapeIndex& index) { - std::vector false_set; - // Add conditional.false_operand. - AddBufferToColocatedSet(conditional_hlo->operand(2), index, - points_to_analysis, &false_set); - // Add conditional.false_computation.parameter_instruction(0). 
- AddBufferToColocatedSet( - conditional_hlo->false_computation()->parameter_instruction( - 0), - index, points_to_analysis, &false_set); - AddSetToColocatedBufferSets(false_set, colocated_buffer_sets); - }); + for (int j = 0; j < conditional->branch_count(); ++j) { + // Add branch_operand[j] (which is operand[j+1]) and + // cond.branch_computation[j].parameter(0) as a colocated + // buffer set. Note that this has to be done for each subshape in the + // branch_operand of the case. + ShapeUtil::ForEachSubshape( + conditional->operand(j + 1)->shape(), + [this, j, conditional, &points_to_analysis, + colocated_buffer_sets](const Shape& /*subshape*/, + const ShapeIndex& index) { + std::vector branch_set; + // Add cond.operand[j+1]. + AddBufferToColocatedSet(conditional->operand(j + 1), index, + points_to_analysis, &branch_set); + // Add cond.branch_computation[j].parameter_instruction(0). + AddBufferToColocatedSet( + conditional->branch_computation(j)->parameter_instruction( + 0), + index, points_to_analysis, &branch_set); + AddSetToColocatedBufferSets(branch_set, colocated_buffer_sets); + }); + } } } } @@ -1733,10 +1725,6 @@ void BufferAssigner::AssignColocatedBufferSets( // module-level scope, we can allow buffers to be shared across // computations (in some cases). allocation = assignment->NewAllocation(*buffer, buffer_size); - if (entry_parameter_number >= 0) { - allocation->set_entry_computation_parameter( - entry_parameter_number, *entry_parameter_shape_idx); - } if (is_constant) { allocation->set_constant(true); } @@ -1750,6 +1738,16 @@ void BufferAssigner::AssignColocatedBufferSets( } colocated_buffers->insert(buffer); } + + // If an allocation contains a parameter, set corresponding fields. 
+ if (entry_parameter_number >= 0) { + bool parameter_has_alias = + assignment->module().input_output_alias_config().ParameterHasAlias( + entry_parameter_number, *entry_parameter_shape_idx); + allocation->set_entry_computation_parameter(entry_parameter_number, + *entry_parameter_shape_idx, + parameter_has_alias); + } } } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 4baab9b6ad71293d48d5ed70c2922fdf40ef119a..448dec3b1aa0c0f85e1060a70e965fcf3952c320 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -96,7 +96,11 @@ class BufferAllocation { // Whether this allocation is readonly i.e. backed by memory we cannot write // to. bool is_readonly() const { - return is_entry_computation_parameter() || is_constant(); + // Entry parameters are generally readonly, except when they are aliased + // with any output. + return (is_entry_computation_parameter() && + !is_parameter_aliased_with_output_) || + is_constant(); } bool is_tuple() const { return is_tuple_; } @@ -274,8 +278,10 @@ class BufferAllocation { void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); void set_entry_computation_parameter(int64 parameter_number, - ShapeIndex param_shape_index) { + ShapeIndex param_shape_index, + bool parameter_aliased_with_output) { is_entry_computation_parameter_ = true; + is_parameter_aliased_with_output_ = parameter_aliased_with_output; parameter_number_ = parameter_number; param_shape_index_ = std::move(param_shape_index); } @@ -305,6 +311,9 @@ class BufferAllocation { // outlast the computation. bool is_entry_computation_parameter_ = false; + // Whether this entry computation parameter is aliased with output. + bool is_parameter_aliased_with_output_ = false; + // If this allocation holds an entry computation parameter, this field // indicates the index (starting from 0) of the parameter. 
int64 parameter_number_ = 0; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 1b4e93a2f303e5aad3e4081f36e2417277f62c71..704585033f076972dbd359fcd832834374566fa8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -190,8 +190,9 @@ class BufferAssignmentTest : public HloTestBase { HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index, + const4, ComparisonDirection::kLt)); return builder.Build(); } @@ -465,6 +466,40 @@ TEST_F(BufferAssignmentTest, Basic) { GetAssignedOutputAllocation(*buffers, sub); } +TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) { + // If an input buffer and output buffer aliases, the input buffer can be + // reused for other intermediate results. 
+ // + // param0[100] ----- (neg1) -- (neg2) + // | | + // + -------- Aliased ---------+ + + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "p0")); + auto neg_1 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param)); + auto neg_2 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1)); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias( + {}, 0, {}, HloInputOutputAliasConfig::kUserAlias)); + + auto buffers = RunBufferAssignment(module.get()); + + BufferAllocation param_buffer = GetAssignedInputAllocation(*buffers, param); + BufferAllocation neg_1_buffer = GetAllocation(*buffers, neg_1, {}); + BufferAllocation neg_2_buffer = GetAllocation(*buffers, neg_2, {}); + + // Everything use one buffer. + EXPECT_EQ(param_buffer.index(), neg_1_buffer.index()); + EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index()); +} + TEST_F(BufferAssignmentTest, AddCannotReuse) { // Pass in a special rule to indicate that "add" cannot reuse any buffer. 
// @@ -1829,8 +1864,8 @@ class WhileBufferAssignmentTest : public HloTestBase { HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto ten = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt)); return builder.Build(); } @@ -2101,8 +2136,9 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param, + const4, ComparisonDirection::kLt)); return builder.Build(); }; @@ -2496,7 +2532,7 @@ while_condition { state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) get-tuple-element = s32[] get-tuple-element(state), index=0 get-tuple-element.1 = s32[] constant(3) - ROOT less-than.339.338 = pred[] less-than(get-tuple-element, get-tuple-element.1) + ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT } ENTRY entry_computation { diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 94af788c54f6c722997311bec50da3ed93aa3cee..98304757cae91d22466ed25f8c6e36ce90a848db 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -64,6 +64,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: case HloOpcode::kFusion: return CallContext::kParallel; default: diff --git 
a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index f41367914c0cb5fe66b1dbbc5ec6f8b7a67d592c..57a636fd740995d6cce933fe19d5592a64bde5cf 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -30,7 +30,7 @@ namespace xla { // The context in which a computation is called by another computation. enum class CallContext { - // In a parallel contex the computation is applied to each element of the + // In a parallel context the computation is applied to each element of the // array argument(s). kMap and kReduce instructions call computations in // parallel context. kParallel, @@ -256,6 +256,10 @@ class CallGraph { private: CallGraph(const HloModule* module); + // Not copyable. + CallGraph(const CallGraph&) = delete; + CallGraph& operator=(const CallGraph&) = delete; + // Sets the call contexts for every node in the graph. void SetCallContexts(); diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 5de724f8924b78008ba4c56603b61bf93fbc5e7c..458aef1499954c5852486b015cb1474a5e70457f 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -83,8 +83,9 @@ class CallGraphTest : public HloTestBase { HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + zero, ComparisonDirection::kGt)); return builder.Build(); } diff --git a/tensorflow/compiler/xla/client/lib/cholesky.cc b/tensorflow/compiler/xla/service/cholesky_expander.cc similarity index 74% rename from tensorflow/compiler/xla/client/lib/cholesky.cc rename to 
tensorflow/compiler/xla/service/cholesky_expander.cc index 414bd1494cd32f32a5c37e84119de930678a776b..1c39cf9bc0a093ec54715d4180b49094ca6266a0 100644 --- a/tensorflow/compiler/xla/client/lib/cholesky.cc +++ b/tensorflow/compiler/xla/service/cholesky_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/cholesky.h" +#include "tensorflow/compiler/xla/service/cholesky_expander.h" #include #include @@ -23,13 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -135,10 +135,8 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { }); } -} // namespace - -XlaOp Cholesky(XlaOp a, int64 block_size, - PrecisionConfig::Precision precision) { +XlaOp BuildCholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -194,12 +192,12 @@ XlaOp Cholesky(XlaOp a, int64 block_size, // l[i+k:, i:i+k] = // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); - auto update = TriangularSolve(factorized, panel, - /*left_side=*/false, - 
/*lower=*/true, - /*transpose_a=*/true, - /*conjugate_a=*/false, - /*block_size=*/block_size); + auto update = + TriangularSolve(factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); l = UpdateSliceInMinorDims(l, update, {i + k, i}); } } @@ -207,4 +205,55 @@ XlaOp Cholesky(XlaOp a, int64 block_size, }); } +} // namespace + +bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCholesky; +} + +StatusOr CholeskyExpander::ExpandInstruction( + HloInstruction* instruction) { + const CholeskyOptions& options = instruction->cholesky_options(); + const string name = absl::StrFormat( + "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(), + options.lower() ? "lower" : "upper"); + + HloModule* module = instruction->parent()->parent(); + + HloComputation*& computation = + computation_cache_.emplace(name, nullptr).first->second; + if (!computation) { + // Builds a new expansion. + // + // TODO(b/62327888): We do something unusual here: we build the computation + // using the XlaBuilder API, which is nominally an XLA client API. We do + // this because the external APIs for building complicated computations + // (XlaBuilder) are much more ergonomic than the internal ones. As it turns + // out, XlaBuilder isn't really a client API—what it does is build a + // HloModuleProto protocol buffer, that we can then deserialize and clone + // into our HloModule. Ideally we would avoid the protocol buffer step; + // that is left as an exercise for future work. 
+ XlaBuilder builder(name); + XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a"); + XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()), + /*block_size=*/128, + /*precision=*/PrecisionConfig::HIGHEST); + MaybeTransposeInMinorDims(l, !options.lower()); + + TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); + + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( + xla_computation.proto(), config)); + HloCloneContext context(module); + computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + } + + return instruction->parent()->AddInstruction(HloInstruction::CreateCall( + instruction->shape(), instruction->operands(), computation)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/cholesky_expander.h b/tensorflow/compiler/xla/service/cholesky_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..d2958db1b8ca676f3872016ac6a62b872a6b6649 --- /dev/null +++ b/tensorflow/compiler/xla/service/cholesky_expander.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +class CholeskyExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "cholesky_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + private: + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 1965925fa7f6d50b1d7af918bc3468d4b4d5d0a2..a4758c2b9dbba8a1c560c8f2dc7a182e456f5e69 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -70,26 +71,14 @@ CompileOnlyService::CompileAheadOfTime( TF_RET_CHECK(instance.computation.has_host_program_shape()); const DebugOptions& debug_options = options.debug_options(); - - // Dump computation proto if flag is set. 
- const string& directory_path = debug_options.xla_dump_computations_to(); - if (!directory_path.empty()) { - HloSnapshot hlo_snapshot; - *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; - string filename = - absl::StrCat("computation_", instance.computation.id(), "__", - instance.computation.entry_computation_name()); - const string& per_host_path = tensorflow::io::JoinPath( - directory_path, tensorflow::port::Hostname()); - - TF_RETURN_IF_ERROR( - Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); - } - ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; *execution_options.mutable_shape_with_output_layout() = instance.result_layout->ToProto(); + if (options.has_static_device_assignment()) { + TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize( + execution_options.mutable_device_assignment())); + } TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig( @@ -99,7 +88,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module)); + DumpHloModuleIfEnabled(*hlo_module, "before_optimizations"); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index d4db95da8eb901af8a6675f2991def73ccfe8ee6..9b483bd97e91720ded089abca593541ae532dedd 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -82,12 +83,24 @@ class AotCompilationOptions { const DebugOptions& debug_options() const { return debug_options_; } DebugOptions* mutable_debug_options() { return &debug_options_; } + bool has_static_device_assignment() const { + return static_device_assignment_.has_value(); + } + const DeviceAssignment& static_device_assignment() const { + CHECK(static_device_assignment_.has_value()); + return *static_device_assignment_; + } + void set_static_device_assignment(const DeviceAssignment& device_assignment) { + static_device_assignment_ = device_assignment; + } + protected: AotCompilationOptions(); private: DeviceMemoryAllocator* device_allocator_ = nullptr; DebugOptions debug_options_; + absl::optional static_device_assignment_; }; // Abstract superclass describing metadata produced during ahead-of-time diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index 4ea3a13f2835c5fef99c274f14d7d683c9ff5fc8..f1d0ca44f08688ccda5b4385d65eabc0fc2fc5e6 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -33,8 +33,8 @@ limitations under the License. namespace xla { // Tries to replace a conditional with a call operation of the corresponding -// computation. If the given conditional has a constant predicate, tries to -// replace it with a call to its true/false computation as appropriate and then +// computation. If the given conditional has a constant branch_index, tries to +// replace it with a call to its corresponding branch computation and then // inline that computation. 
// // Returns true if it made a change to the graph. @@ -50,24 +50,30 @@ static StatusOr TryRemoveConditional(HloInstruction* conditional) { return false; } - if (conditional->operand(0)->opcode() != HloOpcode::kConstant) { - VLOG(2) << "Not attempting to remove conditional as its predicate is not a " - "compile-time constant: " - << conditional->ToShortString(); - return false; - } + // We can always inline a 1-branch conditional due to default branch fallback. + int branch_index = 0; + if (conditional->branch_count() > 1) { + if (conditional->operand(0)->opcode() != HloOpcode::kConstant) { + VLOG(2) << "Not attempting to remove conditional as its branch_index is " + "not a compile-time constant: " + << conditional->ToShortString(); + return false; + } + if (conditional->operand(0)->shape().element_type() == PRED) { + branch_index = conditional->operand(0)->literal().Get({}) ? 0 : 1; + } else { + branch_index = conditional->operand(0)->literal().Get({}); + if (branch_index < 0 || branch_index >= conditional->branch_count()) { + branch_index = conditional->branch_count() - 1; + } + } + } auto computation = conditional->parent(); HloInstruction* call_op; - if (conditional->operand(0)->literal().Get({})) { - call_op = computation->AddInstruction(HloInstruction::CreateCall( - conditional->shape(), {conditional->mutable_operand(1)}, - conditional->true_computation())); - } else { - call_op = computation->AddInstruction(HloInstruction::CreateCall( - conditional->shape(), {conditional->mutable_operand(2)}, - conditional->false_computation())); - } + call_op = computation->AddInstruction(HloInstruction::CreateCall( + conditional->shape(), {conditional->mutable_operand(branch_index + 1)}, + conditional->branch_computation(branch_index))); conditional->SetupDerivedInstruction(call_op); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); diff --git 
a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index b3ed27d9a849eced006eb3b01977ad2fe7ed7367..434bbe9ffd5da58901a65d1c51b77c33b9afa81c 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -191,8 +191,9 @@ HloInstruction* GetExpandedFilterMask( // linspace to create a diagonal predicate. Shape predicate_shape = ShapeUtil::MakeShape( PRED, AsInt64Slice(expanded_filter_shape.dimensions())); - return add_instruction(HloInstruction::CreateBinary( - predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2)); + return add_instruction(HloInstruction::CreateCompare( + predicate_shape, broadcasted_mask1, broadcasted_mask2, + ComparisonDirection::kEq)); } // This function handles batch_group_counts which are relevant only for @@ -215,81 +216,15 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { }; int64 input_batch_dimension = dim_numbers.input_batch_dimension(); - int64 input_feature_dimension = dim_numbers.input_feature_dimension(); int64 output_batch_dimension = dim_numbers.output_batch_dimension(); int64 output_feature_dimension = dim_numbers.output_feature_dimension(); - int64 kernel_input_feature_dimension = - dim_numbers.kernel_input_feature_dimension(); int64 input_batch = activation->shape().dimensions(input_batch_dimension); // We are not yet supporting batch_group of sizes greater than 1. TF_RET_CHECK(input_batch == batch_group_count); - if (is_cost_viable_(convolution)) { - // Add a dimension to the activation, and reshape. - Shape reshaped_activation_shape = activation->shape(); - ShapeUtil::AppendMajorDimension(1, &reshaped_activation_shape); - - activation = add( - HloInstruction::CreateReshape(reshaped_activation_shape, activation)); - - // Add a dimension to the filter, and reshape. 
- Shape reshaped_filter_shape = filter->shape(); - ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape); - - filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); - - int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1; - - Shape new_output_shape = convolution->shape(); - ShapeUtil::AppendMajorDimension(1, &new_output_shape); - - int64 input_feature = - activation->shape().dimensions(input_feature_dimension); - - // The code below edits convolution dimension numbers. Please refer to - // conv_op_helpers.cc to find how the dimensions were set up originally. - - // Effectively, the new input batch becomes 1, and so does the kernel - // input feature. The original input batch now becomes a spatial dimension. - // The output batch (remember that the output is the new kernel for in - // backprop) becomes a spatial dimension too. - - dim_numbers.set_input_batch_dimension(new_spatial_dim); - dim_numbers.set_input_feature_dimension(input_batch_dimension); - dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim); - - dim_numbers.add_input_spatial_dimensions(input_feature_dimension); - dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dimension); - - dim_numbers.add_output_spatial_dimensions(output_batch_dimension); - dim_numbers.set_output_batch_dimension(new_spatial_dim); - - // Add window for the new spatial dimension. - Window new_window = convolution->window(); - auto* dim = new_window.add_dimensions(); - dim->set_window_dilation(1); - dim->set_base_dilation(1); - dim->set_stride(1); - dim->set_size(input_feature); - - auto new_convolution = add(HloInstruction::CreateConvolve( - new_output_shape, activation, filter, - /*feature_group_count=*/batch_group_count, /*batch_group_count=*/1, - new_window, dim_numbers, convolution->precision_config())); - - // Delete the extra spatial dimension, and reshape. 
- Shape reshaped_convolution_shape = ShapeUtil::DeleteDimension( - new_spatial_dim - 1, new_convolution->shape()); - auto reshaped_convolution = HloInstruction::CreateReshape( - reshaped_convolution_shape, new_convolution); - - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(reshaped_convolution))); - - changed_ = true; - } else { + if (!is_cost_viable_(convolution) || filter_expansion_) { // We first obtain the expanded the filter (which is the convolution // output). The batch dimension is the expanded one (which originally // represents kernel input feature dimension). We mask the filter to zero @@ -316,14 +251,27 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution, zero_filter)); - auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(F32)); + PrimitiveType reduce_type = new_filter->shape().element_type(); + auto reduce_window_shape = new_convolution->shape(); + reduce_window_shape.set_dimensions(output_batch_dimension, 1); + + // Ensure that data input to reduce window uses at least 32 bits. 
+ if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) { + reduce_type = F32; + reduce_window_shape.set_element_type(F32); + Shape convert_shape = new_filter->shape(); + convert_shape.set_element_type(F32); + new_filter = + add(HloInstruction::CreateConvert(convert_shape, new_filter)); + } + + auto zero_literal = LiteralUtil::Zero(reduce_type); auto zero_scalar = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto reduce_function = [&]() -> HloComputation* { HloComputation::Builder b("add_computation"); - Shape shape = ShapeUtil::MakeShape(F32, {}); + Shape shape = ShapeUtil::MakeShape(reduce_type, {}); auto lhs = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); auto rhs = @@ -333,19 +281,6 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); }; - auto reduce_window_shape = new_convolution->shape(); - reduce_window_shape.set_dimensions(output_batch_dimension, 1); - - // Ensure that data input to reduce window is of type F32. - if (primitive_util::BitWidth(new_filter->shape().element_type()) < - primitive_util::BitWidth(F32)) { - reduce_window_shape.set_element_type(F32); - Shape convert_shape = new_filter->shape(); - convert_shape.set_element_type(F32); - new_filter = - add(HloInstruction::CreateConvert(convert_shape, new_filter)); - } - // Create the reduce window. 
Window window; for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) { @@ -375,6 +310,7 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(reduce_window_converted))); + changed_ = true; } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc index 585b81a5db632901be863893bf723fcba19388ea..9cee3eda95252d6c7d725fbb03030bd58f52e71f 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc @@ -109,16 +109,16 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16 auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - auto cost_model = [](HloInstruction* conv) { return true; }; + auto cost_model = [](HloInstruction* conv) { return false; }; ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/ true); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); - // Make sure the convolution is converted to one with batch_group_count = 1. - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConvolution); - EXPECT_EQ(root->operand(0)->batch_group_count(), 1); - // Verify that the convolution is replaced by a reshape. - EXPECT_EQ(root->opcode(), HloOpcode::kReshape); + + // Verify that the convolution is replaced by a convert. + EXPECT_EQ(root->opcode(), HloOpcode::kConvert); + // Make sure the convert is being fed by a reduce window. 
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kReduceWindow); } } // namespace diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 5e26a63cebfa9b2e50f4b13335c10c246999d4df..8cb64a335301cd8b340ead9ad7d6ec868d9b2065 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -193,8 +194,8 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, // Add kCopy instructions around the given kWhile instruction to eliminate any // possible live range interference of HLO values assuming a dependency-based // ordering (HloDependencyOrdering). Copies are added conservatively. There -// likely are copies which are not strictly necessary, but there are removed -// later in the pass via CopyRemover. +// likely are copies which are not strictly necessary, but they are removed +// later in the pass via RemoveUnnecessaryCopies. // // // Elements (each ShapeIndex) in the loop state are considered independently. A @@ -310,17 +311,16 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } -// We add copies for all the indices of the true and false computation roots, -// in order to resolve interference. We later rely on the CopyRemover to drop -// the unnecessary ones. +// We add copies for all the indices of the true and false computation roots, in +// order to resolve interference. We later rely on RemoveUnnecessaryCopies to +// drop the unnecessary ones. 
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) { VLOG(2) << "Adding copies for kConditional instruction " << conditional->name(); TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); - for (HloComputation* computation : - {conditional->true_computation(), conditional->false_computation()}) { + for (HloComputation* computation : conditional->branch_computations()) { HloInstruction* root = computation->root_instruction(); std::vector users = root->users(); TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, @@ -335,7 +335,8 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, // Conservatively adds copies before root instruction of entry computation and // each aliased parameter to resolve interference of aliased input and output -// buffer. We later rely on the CopyRemover to drop the unnecessary ones. +// buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary +// ones. Status AddCopiesForAliasedInputOutputs(HloModule* module) { HloComputation* entry = module->entry_computation(); HloInstruction* root = entry->root_instruction(); @@ -433,587 +434,528 @@ Status StripControlDependenciesFrom(HloInstruction* instruction) { return Status::OK(); } -// Class for removing unnecessary copies from the module. +// Class which tracks the HLO values within each HLO buffer in the module +// during copy removal. // -// kCopy instructions are added conservatively to guarantee no live range -// interference between HLO values. This class uses a more fine-grained analysis -// to remove some of these added copies which are not strictly necessary. +// The values are held in a linked list where there is one list for each +// buffer. Removing a copy instruction merges together the values in the +// source buffer of the copy to the destination buffer of the copy. This class +// tracks these value lists as copies are removed from the graph (and value +// lists are merged). 
+// +// The CopyRemover object is initialized to match the state of +// HloAliasAnalysis. However, as copies are removed this state diverges. The +// values-to-buffer mapping is maintained outside of HloAliasAnalysis because +// a fully updatable alias analysis is very slow. class CopyRemover { public: - CopyRemover(const HloAliasAnalysis& alias_analysis, - const HloOrdering& ordering, HloModule* module) - : module_(module), - alias_analysis_(alias_analysis), - buffer_value_tracker_(*module, alias_analysis, ordering) {} - - // Try to elide the given copy. The copy is elided if the instruction is not - // necessary to prevent live-range interference of HLO values. Returns true if - // copy was elided. + // The values held in a single HLO buffer are represented using a linked + // list. An element type in this list is ValueNode. // - // The copy instruction is not actually removed here. Instead it is left for - // dead in the graph. Later calls to DCE will remove the instruction. - StatusOr TryElideCopy(HloInstruction* copy) { - if (buffer_value_tracker_.TryElideCopy(copy)) { - TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy)); - TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0))); - return true; - } - return false; - } - - string ToString() const { - string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n"); - StrAppend(&out, " Buffer values, in dependency order:\n"); - for (const HloBuffer& buffer : alias_analysis_.buffers()) { - StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); - } - return out; - } - - private: - // Class which tracks the HLO values within each HLO buffer in the module - // during copy removal. - // - // The values are held in a linked list where there is one list for each - // buffer. Removing a copy instruction merges together the values in the - // source buffer of the copy to the destination buffer of the copy. 
This class - // tracks these value lists as copies are removed from the graph (and value - // lists are merged). - // - // The BufferValueTracker object is initialized to match the state of - // HloAliasAnalysis. However, as copies are removed this state diverges. The - // values-to-buffer mapping is maintained outside of HloAliasAnalysis because - // a fully updatable alias analysis is very slow. - class BufferValueTracker { - public: - // The values held in a single HLO buffer are represented using a linked - // list. An element type in this list is ValueNode. - // - // This linked list is hand-rolled to enable efficient splicing of lists - // using only references to list elements without knowing which lists are - // being spliced. std::list requires a reference to the list object to - // splice. - struct ValueNode { - explicit ValueNode(const HloValue* v) : value(v) {} - - const HloValue* value; - - // The uses are maintained outside of HloValue::uses() because - // HloValue::uses() is not updatable (a fully updatable dataflow analysis - // is slow). - std::vector uses; - - // next/prev elements in the linked list. The list is circularly linked so - // these values are never null for elements in the list. - ValueNode* prev = nullptr; - ValueNode* next = nullptr; - }; + // This linked list is hand-rolled to enable efficient splicing of lists + // using only references to list elements without knowing which lists are + // being spliced. std::list requires a reference to the list object to + // splice. + struct ValueNode { + explicit ValueNode(const HloValue* v) : value(v) {} + + const HloValue* value; + + // The uses are maintained outside of HloValue::uses() because + // HloValue::uses() is not updatable (a fully updatable dataflow analysis + // is slow). + std::vector uses; + + // next/prev elements in the linked list. The list is circularly linked so + // these values are never null for elements in the list. 
+ ValueNode* prev = nullptr; + ValueNode* next = nullptr; + }; - BufferValueTracker(const HloModule& module, - const HloAliasAnalysis& alias_analysis, - const HloOrdering& ordering) - : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { - // Construct a list for each HLO buffer in the alias analysis. Maintain a - // map from HloValue to the respective list element representing that - // value. The map is used to construct the copy info map below. - absl::flat_hash_map value_to_node; - for (const HloBuffer& buffer : alias_analysis.buffers()) { - // Verify values contained in the buffer are strictly ordered. This - // should always be the case after adding copies to eliminate - // interference. Specifically, the addition of the control flow edges - // between copies added around aliased operations (kWhile) guarantees - // this strict order. - for (const HloValue* value_a : buffer.values()) { - if (value_a->shape().IsToken()) { - // Token values have no representation and cannot interfere. - continue; - } - for (const HloValue* value_b : buffer.values()) { - if (value_a != value_b) { - DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, - dataflow_) || - ordering_.LiveRangeStrictlyBefore(*value_b, *value_a, - dataflow_)) - << value_a->ToShortString() << " and " - << value_b->ToShortString() << " are not ordered"; - } + CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis, + const HloOrdering& ordering) + : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { + // Construct a list for each HLO buffer in the alias analysis. Maintain a + // map from HloValue to the respective list element representing that + // value. The map is used to construct the copy info map below. + absl::flat_hash_map value_to_node; + for (const HloBuffer& buffer : alias_analysis.buffers()) { + // Verify values contained in the buffer are strictly ordered. 
This + // should always be the case after adding copies to eliminate + // interference. Specifically, the addition of the control flow edges + // between copies added around aliased operations (kWhile) guarantees + // this strict order. + for (const HloValue* value_a : buffer.values()) { + if (value_a->shape().IsToken()) { + // Token values have no representation and cannot interfere. + continue; + } + for (const HloValue* value_b : buffer.values()) { + if (value_a != value_b) { + DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, + dataflow_) || + ordering_.LiveRangeStrictlyBefore(*value_b, *value_a, + dataflow_)) + << value_a->ToShortString() << " and " + << value_b->ToShortString() << " are not ordered"; } } - - std::vector values = buffer.values(); - absl::c_sort(values, [this](const HloValue* a, const HloValue* b) { - return ordering_.IsDefinedBefore(*a, *b); - }); - - // Create a list containing all of the values in the buffer. - AddValueList(values, &value_to_node); } - // Create copy_map_ which contains the source and destination values - // of all copies. - CreateCopyMap(module, value_to_node); + std::vector values = buffer.values(); + absl::c_sort(values, [this](const HloValue* a, const HloValue* b) { + return ordering_.IsDefinedBefore(*a, *b); + }); - XLA_VLOG_LINES(3, ToString()); - TF_DCHECK_OK(Verify()); + // Create a list containing all of the values in the buffer. + AddValueList(values, &value_to_node); } - // Add a list containing the given values to BufferValueTracker. This - // represents the values contained in a single buffer. For each value in - // 'values' an entry is created in value_to_node which indicates the - // respective ValueNode representing that value. 
- void AddValueList( - absl::Span values, - absl::flat_hash_map* value_to_node) { - ValueNode* tail = nullptr; - ValueNode* head = nullptr; - for (const HloValue* value : values) { - auto new_node = new ValueNode(value); - (*value_to_node)[value] = new_node; - - // Copy the HLO values's uses into the ValueNode for the value. These - // uses in ValueNode are updated as copies are removed. - new_node->uses.reserve(value->uses().size()); - for (const HloUse& use : value->uses()) { - new_node->uses.push_back(&use); - } + // Create copy_map_ which contains the source and destination values + // of all copies. + CreateCopyMap(module, value_to_node); - // Connect the new node into the linked list. - if (tail == nullptr) { - head = new_node; - } else { - tail->next = new_node; - new_node->prev = tail; - } - tail = new_node; + XLA_VLOG_LINES(3, ToString()); + TF_DCHECK_OK(Verify()); + } + + // Add a list containing the given values to CopyRemover. This + // represents the values contained in a single buffer. For each value in + // 'values' an entry is created in value_to_node which indicates the + // respective ValueNode representing that value. + void AddValueList( + absl::Span values, + absl::flat_hash_map* value_to_node) { + ValueNode* tail = nullptr; + ValueNode* head = nullptr; + for (const HloValue* value : values) { + auto new_node = new ValueNode(value); + (*value_to_node)[value] = new_node; + + // Copy the HLO values's uses into the ValueNode for the value. These + // uses in ValueNode are updated as copies are removed. + new_node->uses.reserve(value->uses().size()); + for (const HloUse& use : value->uses()) { + new_node->uses.push_back(&use); } - // The linked list is circular so connect the head and tail. - tail->next = head; - head->prev = tail; - value_lists_.insert(head); + // Connect the new node into the linked list. 
+ if (tail == nullptr) { + head = new_node; + } else { + tail->next = new_node; + new_node->prev = tail; + } + tail = new_node; } - // This method also fills in copy_map_ which indicates which nodes - // in the value lists corresponding to the source and destination values of - // kCopy instructions. value_to_node should map each HloValue to its - // respective ValueNode. - void CreateCopyMap( - const HloModule& module, - const absl::flat_hash_map& value_to_node) { - for (HloComputation* computation : module.computations()) { - for (HloInstruction* instruction : computation->instructions()) { - // Add copies with unambiguous source values to the map. Copies with - // ambiguous sources are not removable. - if (instruction->opcode() == HloOpcode::kCopy) { - const HloValueSet& src_value_set = - dataflow_.GetValueSet(instruction->operand(0)); - if (src_value_set.values().size() == 1) { - CopyNodes& copy_node = copy_map_[instruction]; - copy_node.dest = - value_to_node.at(&dataflow_.GetUniqueValueAt(instruction)); - copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue()); - } + // The linked list is circular so connect the head and tail. + tail->next = head; + head->prev = tail; + value_lists_.insert(head); + } + + // This method also fills in copy_map_ which indicates which nodes + // in the value lists corresponding to the source and destination values of + // kCopy instructions. value_to_node should map each HloValue to its + // respective ValueNode. + void CreateCopyMap( + const HloModule& module, + const absl::flat_hash_map& value_to_node) { + for (HloComputation* computation : module.computations()) { + for (HloInstruction* instruction : computation->instructions()) { + // Add copies with unambiguous source values to the map. Copies with + // ambiguous sources are not removable. 
+ if (instruction->opcode() == HloOpcode::kCopy) { + const HloValueSet& src_value_set = + dataflow_.GetValueSet(instruction->operand(0)); + if (src_value_set.values().size() == 1) { + CopyNodes& copy_node = copy_map_[instruction]; + copy_node.dest = + value_to_node.at(&dataflow_.GetUniqueValueAt(instruction)); + copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue()); } } } } + } - ~BufferValueTracker() { - for (const ValueNode* head : value_lists_) { - const ValueNode* p = head; - do { - const ValueNode* tmp = p->next; - delete p; - p = tmp; - } while (p != head); - } + ~CopyRemover() { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + const ValueNode* tmp = p->next; + delete p; + p = tmp; + } while (p != head); } + } - // Verify invariants within the linked lists. - Status Verify() const { - for (const ValueNode* head : value_lists_) { - const ValueNode* p = head; - do { - // Verify links between elements are consistent. - TF_RET_CHECK(p->prev->next == p); - TF_RET_CHECK(p->next->prev == p); - - const HloInstruction* def = p->value->defining_instruction(); - if (def->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, def)) { - TF_RET_CHECK(copy_map_.at(def).dest == p); - } - for (const HloUse* use : p->uses) { - if (use->instruction->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, use->instruction)) { - TF_RET_CHECK(copy_map_.at(use->instruction).src == p); - } + // Verify invariants within the linked lists. + Status Verify() const { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + // Verify links between elements are consistent. 
+ TF_RET_CHECK(p->prev->next == p); + TF_RET_CHECK(p->next->prev == p); + + const HloInstruction* def = p->value->defining_instruction(); + if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) { + TF_RET_CHECK(copy_map_.at(def).dest == p); + } + for (const HloUse* use : p->uses) { + if (use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, use->instruction)) { + TF_RET_CHECK(copy_map_.at(use->instruction).src == p); } + } - p = p->next; - } while (p != head); - } - return Status::OK(); + p = p->next; + } while (p != head); } + return Status::OK(); + } - // Try to elide the given copy. Elision of a copy is possible only if no - // live range interference is introduced by the copy's elimination. If - // elision is possible, then the internal state (value lists) are updated, - // and true is returned. Returns false otherwise. - bool TryElideCopy(const HloInstruction* copy) { - VLOG(2) << "Trying to remove " << copy->name(); + // Try to elide the given copy. Elision of a copy is possible only if no + // live range interference is introduced by the copy's elimination. If + // elision is possible, then the internal state (value lists) are updated, + // and true is returned. Returns false otherwise. 
+ bool TryElideCopy(const HloInstruction* copy) { + VLOG(2) << "Trying to remove " << copy->name(); - if (!ContainsKey(copy_map_, copy)) { - VLOG(2) << copy->name() << " is not removable"; - return false; - } - if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) { - VLOG(2) << copy->name() << " is not removable (shape mismatch)"; + if (!ContainsKey(copy_map_, copy)) { + VLOG(2) << copy->name() << " is not removable"; + return false; + } + if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) { + VLOG(2) << copy->name() << " is not removable (shape mismatch)"; + return false; + } + const CopyNodes& copy_node = copy_map_.at(copy); + ValueNode* src = copy_node.src; + ValueNode* dest = copy_node.dest; + DCHECK(src != nullptr); + DCHECK(dest != nullptr); + + auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) { + VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; + if (LiveRangeBefore(a, b)) { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is before " << b.value->ToShortString(); + return true; + } else { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is not before " << b.value->ToShortString(); return false; } - const CopyNodes& copy_node = copy_map_.at(copy); - ValueNode* src = copy_node.src; - ValueNode* dest = copy_node.dest; - DCHECK(src != nullptr); - DCHECK(dest != nullptr); - - auto is_live_range_before = [this](const ValueNode& a, - const ValueNode& b) { - VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; - if (LiveRangeBefore(a, b)) { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is before " << b.value->ToShortString(); - return true; - } else { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is not before " << b.value->ToShortString(); - return false; - } - }; + }; - VLOG(3) << copy->name() << " copies value " - << src->value->ToShortString(); - VLOG(3) << "Source buffer values: " << 
ValueListToString(src); - VLOG(3) << "Dest buffer values: " << ValueListToString(dest); + VLOG(3) << copy->name() << " copies value " << src->value->ToShortString(); + VLOG(3) << "Source buffer values: " << ValueListToString(src); + VLOG(3) << "Dest buffer values: " << ValueListToString(dest); - // A kCopy instruction copies an HLO value from a source buffer and - // defines an HLO value in a destination buffer. Most generally, the - // source and destination buffers may each hold more than one value at - // different points in the computation so we define the following: - // - // Values in source buffer: {s_0, ..., s_n} - // Values in destination buffer: {d_0, ..., d_m} - // - // A kCopy instruction between these buffers copies a value s_x in the - // source buffer and defines a value d_y in the destination buffer. The - // elision of a copy merges the source and destination buffers together, - // so the list of values for the source and destination buffers are - // merged. - // - // We handle two different cases for copy elision: - // - // (1) the kCopy defines the first value in the destination buffer (d_0). + // A kCopy instruction copies an HLO value from a source buffer and + // defines an HLO value in a destination buffer. Most generally, the + // source and destination buffers may each hold more than one value at + // different points in the computation so we define the following: + // + // Values in source buffer: {s_0, ..., s_n} + // Values in destination buffer: {d_0, ..., d_m} + // + // A kCopy instruction between these buffers copies a value s_x in the + // source buffer and defines a value d_y in the destination buffer. The + // elision of a copy merges the source and destination buffers together, + // so the list of values for the source and destination buffers are + // merged. + // + // We handle two different cases for copy elision: + // + // (1) the kCopy defines the first value in the destination buffer (d_0). 
+ // + // (2) the kCopy copies the last value in the source buffer (s_n). + // + // For the remaining case where the kCopy copies a not-last value from the + // source buffer to a not-first value of the destination buffer, the kCopy + // instruction cannot be removed. This case is generated, for example, if + // the kCopy copies a while body parameter of the loop state at one tuple + // index to a different tuple index in the while body root. Removal of the + // copy necessarily results in live range interference of values in the + // loop state at the two different tuple indices. + // + // We can only perform copy elision if the resulting merged values have + // totally ordered live ranges; otherwise the merged buffer would have + // live range interference. + if (src->next == dest) { + // In the process of eliding copies, its possible for a copy to have the + // same source and destination buffer. In this case, the copy can be + // safely removed. + VLOG(2) << copy->name() << " source and destination buffers are same."; + } else if (IsHead(*dest)) { + // The copy copies an arbitrary value in the source buffer (call it s_x) + // and defines d_0, the first value in the destination buffer. After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: // - // (2) the kCopy copies the last value in the source buffer (s_n). + // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n} // - // For the remaining case where the kCopy copies a not-last value from the - // source buffer to a not-first value of the destination buffer, the kCopy - // instruction cannot be removed. This case is generated, for example, if - // the kCopy copies a while body parameter of the loop state at one tuple - // index to a different tuple index in the while body root. Removal of the - // copy necessarily results in live range interference of values in the - // loop state at the two different tuple indices. 
+ // Removing the copy eliminates d_0, and uses of d_0 become uses of + // s_x. In the above ordering, the live range of d_m must be ordered + // before the live range of s_{x+1} and the definition and all uses of + // s_x must be ordered before the definition of d_1. These conditions + // are checked below prior to elision. // - // We can only perform copy elision if the resulting merged values have - // totally ordered live ranges; otherwise the merged buffer would have - // live range interference. - if (src->next == dest) { - // In the process of eliding copies, its possible for a copy to have the - // same source and destination buffer. In this case, the copy can be - // safely removed. - VLOG(2) << copy->name() << " source and destination buffers are same."; - } else if (IsHead(*dest)) { - // The copy copies an arbitrary value in the source buffer (call it s_x) - // and defines d_0, the first value in the destination buffer. After - // merging, the values in the combined buffer must be strictly ordered - // as follows** to elide the copy: - // - // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n} - // - // Removing the copy eliminates d_0, and uses of d_0 become uses of - // s_x. In the above ordering, the live range of d_m must be ordered - // before the live range of s_{x+1} and the definition and all uses of - // s_x must be ordered before the definition of d_1. These conditions - // are checked below prior to elision. - // - // ** Technically it might be possible to have a non-interfering - // non-trivial interleaving of the values of the source and - // destination buffers in the resulting order. However, this case is - // slow and complicated to check and likely not worth it. So instead - // we simply check for the case where *all* values of the destination - // buffer (d_1 through d_m) are spliced into the point where the copy - // used to be. 
- VLOG(2) << copy->name() << " defines the first value in its buffer"; - ValueNode* next_dest = Next(*dest); - if (next_dest != nullptr) { - // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); - if (!is_live_range_before(*src, *next_dest)) { - return false; - } - } - ValueNode* next_src = Next(*src); - - if (next_src != nullptr) { - // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. - ValueNode* last_dest = dest->prev; - DCHECK(IsTail(*last_dest)); - if (!is_live_range_before(*last_dest, *next_src)) { - return false; - } - } - - // Splice in destination buffer values list right after 'src'. - SpliceAfter(dest, src); - } else if (IsTail(*src)) { - // The copy copies the last value in the source buffer, s_n, and defines - // an arbitrary value in the destination buffer, d_y. After - // merging, the values in the combined buffer must be strictly ordered - // as follows** to elide the copy: - // - // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m} - // - // Removing the copy eliminates d_y, and uses of d_y become uses of - // s_n. To enforce the above order, the live range of d_{y-1} must be - // before the live range of s_0, and the live range of s_n must be - // before the live range of d_{y+1}. - // - // ** See comment above in the code handling Case (1). - VLOG(2) << copy->name() << " copies the last value (" - << src->value->ToShortString() << ") in its buffer"; - - ValueNode* prev_dest = Prev(*dest); - // nullptr condition handled above in the first 'if' case. - DCHECK(prev_dest != nullptr); - ValueNode* first_src = src->next; - DCHECK(IsHead(*first_src)); - if (!is_live_range_before(*prev_dest, *first_src)) { - // Live range of value d_{y-1} is not before s_0. + // ** Technically it might be possible to have a non-interfering + // non-trivial interleaving of the values of the source and + // destination buffers in the resulting order. However, this case is + // slow and complicated to check and likely not worth it. 
So instead + // we simply check for the case where *all* values of the destination + // buffer (d_1 through d_m) are spliced into the point where the copy + // used to be. + VLOG(2) << copy->name() << " defines the first value in its buffer"; + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); + if (!is_live_range_before(*src, *next_dest)) { return false; } - ValueNode* next_dest = Next(*dest); - if (next_dest != nullptr) { - if (!is_live_range_before(*src, *next_dest)) { - // Live range of value s_n is not before d_{y+1}. - return false; - } - } - - // Splice source buffer values list right after 'prev_dest'. - SpliceAfter(first_src, prev_dest); - } else { - VLOG(2) - << copy->name() - << " copies value in middle of source buffer to value in middle " - "of destination buffer"; - return false; } + ValueNode* next_src = Next(*src); - RemoveCopyValue(dest); - - XLA_VLOG_LINES(4, ToString()); - TF_DCHECK_OK(Verify()); - - return true; - } - - // Delete the given ValueNode associated with a elided kCopy - // instruction. This should be called after splicing the value lists of the - // source and destination buffers together. - void RemoveCopyValue(ValueNode* copy_value_node) { - CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(), - HloOpcode::kCopy); - ValueNode* operand_node = copy_value_node->prev; - CHECK(operand_node != copy_value_node); - - VLOG(2) << "Removing copy " << operand_node->value->ToShortString() - << " => " << copy_value_node->value->ToShortString(); - - // Splice out the copy value node. - operand_node->next = copy_value_node->next; - copy_value_node->next->prev = operand_node; - - // Patch up uses. Remove use of copy from operand_node uses. 
- auto it = absl::c_find_if( - operand_node->uses, [copy_value_node](const HloUse* use) { - return use->instruction == - copy_value_node->value->defining_instruction(); - }); - CHECK(it != operand_node->uses.end()); - operand_node->uses.erase(it); - - // If the elided copy has any uses which are themselves kCopy instructions - // then patch up the copy info to reflect the that this kCopy instruction - // has a different operand (the operand of the elided copy). - for (const HloUse* copy_use : copy_value_node->uses) { - operand_node->uses.push_back(copy_use); - if (copy_use->instruction->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, copy_use->instruction)) { - copy_map_.at(copy_use->instruction).src = operand_node; + if (next_src != nullptr) { + // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. + ValueNode* last_dest = dest->prev; + DCHECK(IsTail(*last_dest)); + if (!is_live_range_before(*last_dest, *next_src)) { + return false; } } - // Delete the copy info and the value node. - copy_map_.erase(copy_value_node->value->defining_instruction()); - delete copy_value_node; - } - - // Returns true if the live range of given value 'a' is before the live - // range of 'b'. - // - // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not - // updated as copies are removed. - bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { - if (a.uses.empty()) { - VLOG(2) << "Empty uses for " << *a.value; - return ordering_.IsDefinedBefore(*a.value, *b.value); + // Splice in destination buffer values list right after 'src'. + SpliceAfter(dest, src); + } else if (IsTail(*src)) { + // The copy copies the last value in the source buffer, s_n, and defines + // an arbitrary value in the destination buffer, d_y. 
After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: + // + // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m} + // + // Removing the copy eliminates d_y, and uses of d_y become uses of + // s_n. To enforce the above order, the live range of d_{y-1} must be + // before the live range of s_0, and the live range of s_n must be + // before the live range of d_{y+1}. + // + // ** See comment above in the code handling Case (1). + VLOG(2) << copy->name() << " copies the last value (" + << src->value->ToShortString() << ") in its buffer"; + + ValueNode* prev_dest = Prev(*dest); + // nullptr condition handled above in the first 'if' case. + DCHECK(prev_dest != nullptr); + ValueNode* first_src = src->next; + DCHECK(IsHead(*first_src)); + if (!is_live_range_before(*prev_dest, *first_src)) { + // Live range of value d_{y-1} is not before s_0. + return false; } - for (const HloUse* use : a.uses) { - VLOG(2) << "Checking use " << *use << " against " << *b.value; - if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Use " << *use << " is NOT before " << *b.value; + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + if (!is_live_range_before(*src, *next_dest)) { + // Live range of value s_n is not before d_{y+1}. return false; } - VLOG(2) << "Use " << *use << " is before " << *b.value; } - return true; - } - // Returns whether 'node' is the last node in its list. - bool IsTail(const ValueNode& node) const { - return ContainsKey(value_lists_, node.next); + // Splice source buffer values list right after 'prev_dest'. + SpliceAfter(first_src, prev_dest); + } else { + VLOG(2) << copy->name() + << " copies value in middle of source buffer to value in middle " + "of destination buffer"; + return false; } - // Returns whether 'node' is the first node in its list. 
- bool IsHead(const ValueNode& node) const { - return ContainsKey(value_lists_, &node); - } + RemoveCopyValue(dest); - // Returns the next node in the list after 'node'. If 'node' is the - // tail, then nullptr is returned. - ValueNode* Next(const ValueNode& node) const { - if (IsTail(node)) { - return nullptr; - } else { - return node.next; + XLA_VLOG_LINES(4, ToString()); + TF_DCHECK_OK(Verify()); + + return true; + } + + // Delete the given ValueNode associated with a elided kCopy + // instruction. This should be called after splicing the value lists of the + // source and destination buffers together. + void RemoveCopyValue(ValueNode* copy_value_node) { + CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(), + HloOpcode::kCopy); + ValueNode* operand_node = copy_value_node->prev; + CHECK(operand_node != copy_value_node); + + VLOG(2) << "Removing copy " << operand_node->value->ToShortString() + << " => " << copy_value_node->value->ToShortString(); + + // Splice out the copy value node. + operand_node->next = copy_value_node->next; + copy_value_node->next->prev = operand_node; + + // Patch up uses. Remove use of copy from operand_node uses. + auto it = absl::c_find_if(operand_node->uses, [copy_value_node]( + const HloUse* use) { + return use->instruction == copy_value_node->value->defining_instruction(); + }); + CHECK(it != operand_node->uses.end()); + operand_node->uses.erase(it); + + // If the elided copy has any uses which are themselves kCopy instructions + // then patch up the copy info to reflect the that this kCopy instruction + // has a different operand (the operand of the elided copy). + for (const HloUse* copy_use : copy_value_node->uses) { + operand_node->uses.push_back(copy_use); + if (copy_use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, copy_use->instruction)) { + copy_map_.at(copy_use->instruction).src = operand_node; } } - // Returns the previous node in the list before 'node'. 
If 'node' - // is the head, then nullptr is returned. - ValueNode* Prev(const ValueNode& node) const { - if (IsHead(node)) { - return nullptr; - } else { - return node.prev; + // Delete the copy info and the value node. + copy_map_.erase(copy_value_node->value->defining_instruction()); + delete copy_value_node; + } + + // Returns true if the live range of given value 'a' is before the live + // range of 'b'. + // + // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not + // updated as copies are removed. + bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { + if (a.uses.empty()) { + VLOG(2) << "Empty uses for " << *a.value; + return ordering_.IsDefinedBefore(*a.value, *b.value); + } + for (const HloUse* use : a.uses) { + VLOG(2) << "Checking use " << *use << " against " << *b.value; + if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { + VLOG(2) << "Use " << *use << " is NOT before " << *b.value; + return false; } + VLOG(2) << "Use " << *use << " is before " << *b.value; } + return true; + } - // Splices the entire linked list with 'head' as its head right after the - // node 'insert_after' in another linked list. - void SpliceAfter(ValueNode* head, ValueNode* insert_after) { - DCHECK(IsHead(*head)); - value_lists_.erase(head); + // Returns whether 'node' is the last node in its list. + bool IsTail(const ValueNode& node) const { + return ContainsKey(value_lists_, node.next); + } - ValueNode* tail = head->prev; - tail->next = insert_after->next; - insert_after->next->prev = tail; + // Returns whether 'node' is the first node in its list. + bool IsHead(const ValueNode& node) const { + return ContainsKey(value_lists_, &node); + } - insert_after->next = head; - head->prev = insert_after; + // Returns the next node in the list after 'node'. If 'node' is the + // tail, then nullptr is returned. 
+ ValueNode* Next(const ValueNode& node) const { + if (IsTail(node)) { + return nullptr; + } else { + return node.next; } + } - string ValueListToString(const ValueNode* element) { - const ValueNode* head = element; - while (!IsHead(*head)) { - head = Prev(*head); - } - std::vector values; - for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { - values.push_back(p->value); - } - return absl::StrCat("{", - absl::StrJoin(values, ", ", - [](string* s, const HloValue* value) { - StrAppend(s, value->ToShortString()); - }), - "}"); + // Returns the previous node in the list before 'node'. If 'node' + // is the head, then nullptr is returned. + ValueNode* Prev(const ValueNode& node) const { + if (IsHead(node)) { + return nullptr; + } else { + return node.prev; } + } - string ToString() const { - string out = absl::StrCat("BufferValueTracker:\n"); - StrAppend(&out, " Def-use chains in each buffer:\n"); - for (const ValueNode* head : value_lists_) { - StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), - ":\n"); - const ValueNode* p = head; - do { - StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", - absl::StrJoin(p->uses, "; ", - [](string* s, const HloUse* use) { - StrAppend(s, use->ToString()); - }), - "\n"); - - p = p->next; - } while (p != head); - } - StrAppend(&out, " Potentially removable copies:\n"); - for (const auto& pair : copy_map_) { - const HloInstruction* copy = pair.first; - const CopyNodes& copy_info = pair.second; - - StrAppend(&out, " ", copy->name(), " : ", - copy_info.src->value->ToShortString(), " => ", - copy_info.dest->value->ToShortString(), "\n"); - } - return out; - } + // Splices the entire linked list with 'head' as its head right after the + // node 'insert_after' in another linked list. 
+ void SpliceAfter(ValueNode* head, ValueNode* insert_after) { + DCHECK(IsHead(*head)); + value_lists_.erase(head); - private: - const HloDataflowAnalysis& dataflow_; - const HloOrdering& ordering_; - - // The heads of all the value lists. Each value list represents the HLO - // values contained in a particular HLO buffer. The values in the list are - // in dependency order. - absl::flat_hash_set value_lists_; - - // Copy removal requires fast access to the value list elements - // corresponding to the source and destination values of the kCopy - // instruction. This data structure holds pointers to these elements for - // each kCopy instruction in the graph. - struct CopyNodes { - // The source and destinations values of the kCopy instruction. - ValueNode* src = nullptr; - ValueNode* dest = nullptr; - }; - absl::flat_hash_map copy_map_; - }; + ValueNode* tail = head->prev; + tail->next = insert_after->next; + insert_after->next->prev = tail; - HloModule* module_; - const HloAliasAnalysis& alias_analysis_; + insert_after->next = head; + head->prev = insert_after; + } - // Object tracking the HLO values contained in each HLO buffer. 
- BufferValueTracker buffer_value_tracker_; -}; + string ValueListToString(const ValueNode* element) { + const ValueNode* head = element; + while (!IsHead(*head)) { + head = Prev(*head); + } + std::vector values; + for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { + values.push_back(p->value); + } + return absl::StrCat("{", + absl::StrJoin(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); + } -void MaybeDumpModule(const string& message, const HloModule& module) { - if (VLOG_IS_ON(3)) { - VLOG(3) << message; - XLA_VLOG_LINES(3, module.ToString()); - hlo_graph_dumper::MaybeDumpHloModule(module, message); + string ToString() const { + string out = absl::StrCat("CopyRemover:\n"); + StrAppend(&out, " Def-use chains in each buffer:\n"); + for (const ValueNode* head : value_lists_) { + StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), + ":\n"); + const ValueNode* p = head; + do { + StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", + absl::StrJoin(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), + "\n"); + + p = p->next; + } while (p != head); + } + StrAppend(&out, " Potentially removable copies:\n"); + for (const auto& pair : copy_map_) { + const HloInstruction* copy = pair.first; + const CopyNodes& copy_info = pair.second; + + StrAppend(&out, " ", copy->name(), " : ", + copy_info.src->value->ToShortString(), " => ", + copy_info.dest->value->ToShortString(), "\n"); + } + return out; } -} + + private: + const HloDataflowAnalysis& dataflow_; + const HloOrdering& ordering_; + + // The heads of all the value lists. Each value list represents the HLO + // values contained in a particular HLO buffer. The values in the list are + // in dependency order. 
+ absl::flat_hash_set value_lists_; + + // Copy removal requires fast access to the value list elements + // corresponding to the source and destination values of the kCopy + // instruction. This data structure holds pointers to these elements for + // each kCopy instruction in the graph. + struct CopyNodes { + // The source and destinations values of the kCopy instruction. + ValueNode* src = nullptr; + ValueNode* dest = nullptr; + }; + absl::flat_hash_map copy_map_; +}; } // namespace @@ -1156,23 +1098,29 @@ Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering, Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module) { - MaybeDumpModule("after adding copies to resolve interference", *module); - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); - CopyRemover copy_remover(*alias_analysis, ordering, module); - XLA_VLOG_LINES(3, copy_remover.ToString()); + + CopyRemover copy_remover(*module, *alias_analysis, ordering); + if (VLOG_IS_ON(3)) { + LOG(INFO) << "Removing unnecessary copies in " << module->name(); + LOG(INFO) << "Buffer values, in dependency order: "; + for (const HloBuffer& buffer : alias_analysis->buffers()) { + LOG(INFO) << " HloBuffer " << buffer.id(); + } + } std::unique_ptr call_graph = CallGraph::Build(module); for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + if (instruction->opcode() == HloOpcode::kCopy && + copy_remover.TryElideCopy(instruction)) { + TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction)); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); } } } - MaybeDumpModule("after removing unnecessary copies", *module); - return Status::OK(); } @@ -1201,8 +1149,6 @@ StatusOr 
CopyInsertion::Run(HloModule* module) { // interference. If all copies were added in step (1) then copy removal would // also have to reason about things like constants and parameters live out of // the computation. - MaybeDumpModule("before copy insertion", *module); - std::unique_ptr call_graph = CallGraph::Build(module); if (!call_graph->IsFlattened()) { return FailedPrecondition( @@ -1231,23 +1177,25 @@ StatusOr CopyInsertion::Run(HloModule* module) { HloDCE dce; TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); + DumpHloModuleDuringPassIfEnabled( + name(), "after adding copies to resolve interference", *module); DependencyHloOrdering dep_ordering(module); TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module)); TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module)); + DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies", + *module); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); - - MaybeDumpModule("after adding special-case copies", *module); + DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies", + *module); TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); TF_DCHECK_OK( VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module)); - MaybeDumpModule("after copy insertion", *module); - if (VLOG_IS_ON(1)) { int64 num_total_copies = 0; for (HloComputation* computation : module->computations()) { diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 4d4074943e3bf9f6f2a37abc63f037c2dab06e0f..6fa3161e57814ffa2c0030f607a05bda2a4e121f 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -420,9 +420,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto induction_variable = 
builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, 0)); - builder.AddInstruction( - HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt, - induction_variable, limit_const)); + builder.AddInstruction(HloInstruction::CreateCompare( + condition_result_shape_, induction_variable, limit_const, + ComparisonDirection::kLt)); return builder.Build(); } @@ -1842,7 +1842,7 @@ HloModule TokensShouldNotBeCopied %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %TokensShouldNotBeCopied () -> s32[] { @@ -1855,8 +1855,7 @@ ENTRY %TokensShouldNotBeCopied () -> s32[] { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - HloRunner::CreateModuleFromString( - module_string, GetDebugOptionsForTest())); + ParseAndReturnVerifiedModule(module_string)); InsertCopies(module.get()); // There should be no copies added because tokens should not be copied. 
@@ -2061,7 +2060,7 @@ if-condition.v4 { p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 constant.4 = s32[] constant(0) - ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) + ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ } _functionalize_body_1__.v28 { @@ -2071,7 +2070,7 @@ _functionalize_body_1__.v28 { add.4 = s32[] add(get-tuple-element.68, constant.7) get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1 get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2 - less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70) + less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE constant.8 = s32[] constant(0) select = s32[] select(less-than-or-equal-to, constant.8, constant.7) get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3 @@ -2088,7 +2087,7 @@ cond_wrapper.v3.1 { inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0 constant.11 = s32[] constant(7) - ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11) + ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT } _functionalize_body_2__.v25 { @@ -2111,7 +2110,7 @@ cond_wrapper.v3.2 { inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1 constant.13 = s32[] constant(5) - ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13) + ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT } ENTRY TestComputation { @@ -2119,8 +2118,7 @@ ENTRY TestComputation { ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 } )"; - auto module_or_status = - 
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); auto module = module_or_status.ConsumeValueOrDie(); InsertCopies(module.get()); } @@ -2144,7 +2142,7 @@ if-condition.v4 { p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 constant.4 = s32[] constant(0) - ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) + ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ } if-body.v5.1 { @@ -2161,7 +2159,7 @@ if-condition.v4.1 { p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0 constant.6 = s32[] constant(1) - ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6) + ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ } _functionalize_body_1__.v28 { @@ -2171,7 +2169,7 @@ _functionalize_body_1__.v28 { add.4 = s32[] add(get-tuple-element.72, constant.7) get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1 get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2 - less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74) + less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE constant.8 = s32[] constant(0) select = s32[] select(less-than-or-equal-to, constant.8, constant.7) get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3 @@ -2189,7 +2187,7 @@ cond_wrapper.v3.1 { inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0 constant.11 = s32[] constant(7) - ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11) + ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT } _functionalize_body_2__.v25 { @@ -2212,7 +2210,7 @@ cond_wrapper.v3.2 { 
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1 constant.13 = s32[] constant(5) - ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13) + ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT } ENTRY TestComputation { @@ -2220,8 +2218,7 @@ ENTRY TestComputation { ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 } )"; - auto module_or_status = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); auto module = module_or_status.ConsumeValueOrDie(); InsertCopies(module.get()); } @@ -2238,7 +2235,7 @@ cond.inner { body.inner { param.body.inner = pred[] parameter(0) - ROOT neg = pred[] negate(param.body.inner) + ROOT not = pred[] not(param.body.inner) } cond.outer { @@ -2255,9 +2252,8 @@ ENTRY TestComputation { ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); InsertCopies(module.get()); // There should only be a single copy inserted, and it's in the entry diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index d4535b204d7f3ad8d4e24beea5d0dd79e7a15ab0..66ceb57227cf20139aa355ffd0caea48fa1ddc50 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -95,6 +95,7 @@ cc_library( "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:map_inliner", 
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:scatter_expander", @@ -111,6 +112,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:cholesky_expander", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", @@ -136,6 +138,7 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:transpose_folding", + "//tensorflow/compiler/xla/service:triangular_solve_expander", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", @@ -336,15 +339,15 @@ cc_library( srcs = ["ir_function.cc"], hdrs = ["ir_function.h"], deps = [ + ":cpu_runtime", ":ir_emission_utils", ":shape_partition", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime", + "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -375,6 +378,7 @@ cc_library( ":vector_support_library", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", @@ -746,6 +750,7 @@ cc_library( ":ir_emission_utils", 
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/build_defs.bzl b/tensorflow/compiler/xla/service/cpu/build_defs.bzl index e78330b21689fdd818cd97128bbcaaa9e0118602..ffa1cd4ec8e26e7dbe92e7b99cf65e99db5400b9 100644 --- a/tensorflow/compiler/xla/service/cpu/build_defs.bzl +++ b/tensorflow/compiler/xla/service/cpu/build_defs.bzl @@ -1,12 +1,11 @@ """build_defs for service/cpu.""" - def runtime_copts(): - """Returns copts used for CPU runtime libraries.""" - return (["-DEIGEN_AVOID_STL_ARRAY"] + select({ - "//tensorflow:android_arm": ["-mfpu=neon"], - "//conditions:default": [] - }) + select({ - "//tensorflow:android": ["-O2"], - "//conditions:default": [] - })) + """Returns copts used for CPU runtime libraries.""" + return (["-DEIGEN_AVOID_STL_ARRAY"] + select({ + "//tensorflow:android_arm": ["-mfpu=neon"], + "//conditions:default": [], + }) + select({ + "//tensorflow:android": ["-O2"], + "//conditions:default": [], + })) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 796a7cf94d02b0ad42366387a9d3f8d589b8840a..39926182894590e469c3cbb8d8620f3553fce873 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -66,9 +66,14 @@ class FilteredPassManager : public llvm::legacy::PassManager { explicit FilteredPassManager(bool disable_expensive_passes) : disable_expensive_passes_(disable_expensive_passes) {} void add(llvm::Pass* p) override { + llvm::StringRef PassName = p->getPassName(); + if (PassName.contains("Warn about non-applied transformations")) { + delete p; + return; + } if (disable_expensive_passes_) { - llvm::StringRef PassName = p->getPassName(); if (PassName.contains("Unroll loops")) { + delete p; return; } } @@ -89,7 +94,7 @@ std::unique_ptr 
CompilerFunctor::operator()( XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); if (pre_optimization_hook_) { - TF_CHECK_OK(pre_optimization_hook_(module)); + pre_optimization_hook_(module); } // Add the appropriate TargetLibraryInfo and TargetTransformInfo. @@ -133,7 +138,7 @@ std::unique_ptr CompilerFunctor::operator()( XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); if (post_optimization_hook_) { - TF_CHECK_OK(post_optimization_hook_(module)); + post_optimization_hook_(module); } // Generate code. @@ -145,17 +150,11 @@ std::unique_ptr CompilerFunctor::operator()( std::unique_ptr memory_buffer( new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); - if (VLOG_IS_ON(2)) { + if (post_codegen_hook_) { llvm::Expected> obj_file = llvm::object::ObjectFile::createObjectFile(*memory_buffer); if (obj_file) { - StatusOr disasm_result = - disassembler_->DisassembleObjectFile(*obj_file.get()); - if (disasm_result.ok()) { - XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text); - } else { - LOG(WARNING) << "Could not disassemble object file!"; - } + post_codegen_hook_(*obj_file.get()); } else { LOG(WARNING) << "Could convert memory buffer to object file!"; } diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index c38b896c5019b48fd2a16a51abd59e12ebdb29eb..edcd47e9e89748a1df814619f10870afb531bc7e 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -20,7 +20,6 @@ limitations under the License. 
#include "llvm/IR/Module.h" #include "llvm/Object/ObjectFile.h" #include "llvm/Target/TargetMachine.h" -#include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/core/platform/logging.h" @@ -32,19 +31,21 @@ namespace cpu { class CompilerFunctor { public: explicit CompilerFunctor( - llvm::TargetMachine* target_machine, const Disassembler* disassembler, - int opt_level, bool optimize_for_size, bool enable_fast_math, + llvm::TargetMachine* target_machine, int opt_level, + bool optimize_for_size, bool enable_fast_math, bool disable_expensive_passes, LLVMCompiler::ModuleHook pre_optimization_hook = nullptr, - LLVMCompiler::ModuleHook post_optimization_hook = nullptr) + LLVMCompiler::ModuleHook post_optimization_hook = nullptr, + std::function post_codegen_hook = + nullptr) : target_machine_(target_machine), - disassembler_(CHECK_NOTNULL(disassembler)), opt_level_(opt_level), optimize_for_size_(optimize_for_size), enable_fast_math_(enable_fast_math), disable_expensive_passes_(disable_expensive_passes), - pre_optimization_hook_(pre_optimization_hook), - post_optimization_hook_(post_optimization_hook) {} + pre_optimization_hook_(std::move(pre_optimization_hook)), + post_optimization_hook_(std::move(post_optimization_hook)), + post_codegen_hook_(std::move(post_codegen_hook)) {} // Compile a Module to an ObjectFile. 
std::unique_ptr operator()( @@ -61,13 +62,13 @@ class CompilerFunctor { unsigned opt_level, unsigned size_level) const; llvm::TargetMachine* target_machine_; - const Disassembler* disassembler_; const unsigned opt_level_; const bool optimize_for_size_; const bool enable_fast_math_; const bool disable_expensive_passes_; LLVMCompiler::ModuleHook pre_optimization_hook_; LLVMCompiler::ModuleHook post_optimization_hook_; + std::function post_codegen_hook_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index eafda68510d93ee54f2aead60a84f3e97b3fe1f4..7de159cf647190c18d4c02a1acbff31419d9e759 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/cholesky_expander.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" @@ -69,6 +70,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -95,6 +97,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" @@ -105,6 +108,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/dynamic_annotations.h" namespace xla { namespace cpu { @@ -255,6 +259,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); @@ -312,6 +319,11 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( }, TransposeFolding::NeverFoldTranspose); pipeline.AddPass(/*is_layout_sensitive=*/false); + + pipeline.AddPass( + module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, target_machine_features); + pipeline.AddPass(); pipeline.AddPass(); @@ -319,10 +331,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); - - pipeline.AddPass( - module->mutable_entry_computation_layout(), - LayoutAssignment::InstructionCanChangeLayout, target_machine_features); return pipeline.Run(module).status(); } @@ -342,13 +350,10 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( { auto& pass = pipeline.AddPass>( "simplification after layout assignement"); - // TODO(b/117156505): When the bug is fixed, the CPU backend should not - // produce layout changing elementwise 
operations. We will then pass - // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to - // enable stricter verification. pass.AddInvariantChecker( /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); options.set_enable_dot_strength_reduction(false); @@ -404,10 +409,20 @@ auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; }; llvm::TargetOptions CompilerTargetOptions( const HloModuleConfig& module_config) { llvm::TargetOptions target_options; - llvm_ir::SetTargetOptions( - /*fast_math_enabled=*/module_config.debug_options() - .xla_cpu_enable_fast_math(), - &target_options); + // In LLVM backend flags, UnsafeFPMath does not explicitly imply NoInfs, etc. + if (module_config.debug_options().xla_cpu_enable_fast_math()) { + target_options.UnsafeFPMath = true; + target_options.NoInfsFPMath = + module_config.debug_options().xla_cpu_fast_math_honor_infs(); + target_options.NoNaNsFPMath = + module_config.debug_options().xla_cpu_fast_math_honor_nans(); + target_options.NoSignedZerosFPMath = true; + } else { + target_options.UnsafeFPMath = false; + target_options.NoInfsFPMath = false; + target_options.NoNaNsFPMath = false; + target_options.NoSignedZerosFPMath = false; + } return target_options; } @@ -426,53 +441,32 @@ llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) { } } -Status InitializeModuleHooks( +std::pair GetIRModuleHooks( const HloModule& hlo_module, const LLVMCompiler::ModuleHook& user_pre_optimization_hook, - const LLVMCompiler::ModuleHook& user_post_optimization_hook, - LLVMCompiler::ModuleHook* pre_optimization_ir_hook, - LLVMCompiler::ModuleHook* post_optimization_ir_hook) { - const string& ir_dump_directory = - hlo_module.config().debug_options().xla_dump_ir_to(); - if (ir_dump_directory.empty()) { - *pre_optimization_ir_hook = 
user_pre_optimization_hook; - *post_optimization_ir_hook = user_post_optimization_hook; - return Status::OK(); - } - - const string& hlo_module_name = hlo_module.name(); - + const LLVMCompiler::ModuleHook& user_post_optimization_hook) { // Create the IR hooks. If applicable, each IR hook does the following: // // * Calls the user supplied module hook. // * Writes out the IR to a file in the output directory designated by - // --xla_dump_ir_to - - *pre_optimization_ir_hook = - [user_pre_optimization_hook, ir_dump_directory, - hlo_module_name](const llvm::Module& llvm_module) { - if (user_pre_optimization_hook) { - TF_RETURN_IF_ERROR(user_pre_optimization_hook(llvm_module)); - } - return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, - /*hlo_module_name=*/hlo_module_name, - llvm_module, - /*optimized=*/false); - }; - - *post_optimization_ir_hook = - [user_post_optimization_hook, ir_dump_directory, - hlo_module_name](const llvm::Module& llvm_module) { - if (user_post_optimization_hook) { - TF_RETURN_IF_ERROR(user_post_optimization_hook(llvm_module)); - } - return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, - /*hlo_module_name=*/hlo_module_name, - llvm_module, - /*optimized=*/true); - }; - - return Status::OK(); + // --xla_dump_to + const HloModule* hlo_module_ptr = &hlo_module; + auto hook = [user_pre_optimization_hook, user_post_optimization_hook, + hlo_module_ptr](bool optimized, + const llvm::Module& llvm_module) { + const auto& user_hook = + !optimized ? 
user_pre_optimization_hook : user_post_optimization_hook; + if (user_hook) { + user_hook(llvm_module); + } + llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized); + }; + return {[hook](const llvm::Module& llvm_module) { + return hook(/*optimized=*/false, llvm_module); + }, + [hook](const llvm::Module& llvm_module) { + return hook(/*optimized=*/true, llvm_module); + }}; } Status VerifyLlvmModule(const llvm::Module& llvm_module) { @@ -486,7 +480,7 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) { << "Invalid LLVM IR before optimizations:\n" << err_stream.str() << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " - "Rerun with --xla_dump_ir_to to get the IR. "; + "Rerun with --xla_dump_to to get the IR. "; return Status::OK(); } @@ -530,9 +524,6 @@ Status CreateHloProfilingArtifacts( StatusOr> CpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* /*stream_exec*/, DeviceMemoryAllocator* /*device_allocator*/) { - VLOG(2) << "Before optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - std::unique_ptr jit_target_machine = SimpleOrcJIT::InferTargetMachineForJIT( CompilerTargetOptions(module->config()), @@ -540,29 +531,72 @@ StatusOr> CpuCompiler::RunHloPasses( TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false, jit_target_machine.get())); - - VLOG(2) << "After optimization:"; - XLA_VLOG_LINES(2, module->ToString()); return std::move(module); } +namespace { + +// Post-compilation callback functor for use by SimpleOrcJIT. +// +// Dumps disassembled machine code if dumping is enabled for the module. +struct OrcJITPostCompilationHook { + // Gets an std::function that implements this hook. + static std::function Create( + const HloModule* module) { + // This struct is not copyable, but std::functions must be. So to create an + // std::function out of this struct, we have to wrap it in a shared_ptr. 
+ auto wrapped = std::make_shared(module); + return [wrapped](const llvm::object::ObjectFile& obj_file) { + (*wrapped)(obj_file); + }; + } + + // Constructor can't be private because we want to call it from + // std::make_shared, but users should call Create() instead. + explicit OrcJITPostCompilationHook(const HloModule* module) + : module(module), + target_machine(SimpleOrcJIT::InferTargetMachineForJIT( + CompilerTargetOptions(module->config()), + CodeGenOptLevel(module->config()))), + disassembler(*target_machine) {} + + private: + void operator()(const llvm::object::ObjectFile& obj_file) { + if (!DumpingEnabledForHloModule(*module)) { + return; + } + StatusOr disasm_or = + disassembler.DisassembleObjectFile(obj_file); + string text = disasm_or.ok() ? std::move(disasm_or).ValueOrDie().text + : absl::StrCat("Error disassembling: ", + disasm_or.status().ToString()); + DumpToFileInDirOrStdout(*module, /*file_suffix=*/"s", text); + } + + const HloModule* module; + // disassembler keeps references to data inside of target_machine. 
+ std::unique_ptr target_machine; + Disassembler disassembler; +}; + +} // namespace + StatusOr> CpuCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* /*device_allocator*/) { - const string timer_message = - "Compiling [" + module->name() + "] for CPU using JIT"; - XLA_SCOPED_LOGGING_TIMER(timer_message); - VLOG(1) << "Compiling: " << module->name(); + XLA_SCOPED_LOGGING_TIMER( + absl::StrFormat("Compiling [%s] for CPU using JIT", module->name())); + TF_RET_CHECK(stream_exec != nullptr); std::call_once(llvm_command_line_options_initialized, &llvm_ir::InitializeLLVMCommandLineOptions, module->config()); ModuleHook pre_optimization_ir_hook; ModuleHook post_optimization_ir_hook; - TF_RETURN_IF_ERROR(InitializeModuleHooks( - *module, user_pre_optimization_hook_, user_post_optimization_hook_, - &pre_optimization_ir_hook, &post_optimization_ir_hook)); + std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) = + GetIRModuleHooks(*module, user_pre_optimization_hook_, + user_post_optimization_hook_); // Compile must be thread-safe so create a new LLVM context for the module. auto llvm_context = absl::make_unique(); @@ -575,7 +609,8 @@ StatusOr> CpuCompiler::RunBackend( options::OptimizeForSizeRequested(module->config()), module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), - pre_optimization_ir_hook, post_optimization_ir_hook); + pre_optimization_ir_hook, post_optimization_ir_hook, + OrcJITPostCompilationHook::Create(module.get())); llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); @@ -596,8 +631,6 @@ StatusOr> CpuCompiler::RunBackend( // ownership is std::moved. 
const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - const string xla_dump_optimized_hlo_proto_to = - module->config().debug_options().xla_dump_optimized_hlo_proto_to(); // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis @@ -616,13 +649,11 @@ StatusOr> CpuCompiler::RunBackend( /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. - XLA_VLOG_LINES(2, assignment->ToString()); - - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "buffer_assignment", + assignment->ToString()); } + DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations"); // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from @@ -633,7 +664,13 @@ StatusOr> CpuCompiler::RunBackend( IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features); + &target_machine_features, +#ifdef MEMORY_SANITIZER + /*emit_code_for_msan=*/true +#else + /*emit_code_for_msan=*/false +#endif + ); TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); @@ -670,9 +707,8 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } - TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); - XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. 
jit->AddModule(std::move(llvm_module)); @@ -719,8 +755,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, } const CpuAotCompilationOptions& options = static_cast(aot_options); - llvm::StringRef target_triple = llvm_ir::AsStringRef(options.triple()); - llvm::Triple triple(llvm::Triple::normalize(target_triple)); + llvm::Triple triple(llvm::Triple::normalize(options.triple())); std::string error; const llvm::Target* target = llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); @@ -758,13 +793,12 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, pie_level = llvm::PIELevel::Large; break; } - llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); - llvm::StringRef features = llvm_ir::AsStringRef(options.features()); llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config()); - std::unique_ptr target_machine = absl::WrapUnique( - target->createTargetMachine(triple.getTriple(), cpu_name, features, - CompilerTargetOptions(modules[0]->config()), - reloc_model, llvm::None, opt_level)); + std::unique_ptr target_machine = + absl::WrapUnique(target->createTargetMachine( + triple.getTriple(), options.cpu_name(), options.features(), + CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None, + opt_level)); // Compile must be thread-safe so create a new LLVM context for the module. 
llvm::LLVMContext llvm_context; @@ -783,15 +817,9 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, HloModule* module = modules[i].get(); VLOG(1) << "Compiling ahead-of-time: " << module->name(); - VLOG(2) << "Before optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR( RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get())); - VLOG(2) << "After optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN(HloSchedule schedule, ScheduleModule(module, BufferSizeBytesFunction())); @@ -806,15 +834,11 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. - XLA_VLOG_LINES(2, assignment->ToString()); - - const string xla_dump_optimized_hlo_proto_to = - module->config().debug_options().xla_dump_optimized_hlo_proto_to(); - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "buffer_assignment", + assignment->ToString()); } + DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations"); std::unordered_map instruction_to_profile_idx; std::unordered_map computation_to_profile_idx; @@ -831,7 +855,9 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, IrEmitter ir_emitter(*module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features); + &target_machine_features, + // TODO(b/66051036): Run full msan for AOT. 
+ /*emit_code_for_msan=*/false); TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); @@ -856,33 +882,43 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, /*is_top_level_computation=*/true, schedule.sequence(computation).instructions())); - CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); + CHECK(entry_function->getName() == entry_point_name); - ModuleHook pre_optimization_ir_dump_hook; - ModuleHook post_optimization_ir_dump_hook; - TF_RETURN_IF_ERROR(InitializeModuleHooks( - *module, user_pre_optimization_hook_, user_post_optimization_hook_, - &pre_optimization_ir_dump_hook, &post_optimization_ir_dump_hook)); + ModuleHook pre_optimization_ir_hook; + ModuleHook post_optimization_ir_hook; + std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) = + GetIRModuleHooks(*module, user_pre_optimization_hook_, + user_post_optimization_hook_); // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run the // pre-optimization IR dump hook before returning. { Status verify_status = VerifyLlvmModule(llvm_module); - if (!verify_status.ok() && pre_optimization_ir_dump_hook) { - pre_optimization_ir_dump_hook(llvm_module).IgnoreError(); + if (!verify_status.ok() && pre_optimization_ir_hook) { + pre_optimization_ir_hook(llvm_module); } TF_RETURN_IF_ERROR(verify_status); } - XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(llvm_module)); + auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) { + if (!DumpingEnabledForHloModule(*module)) { + return; + } + StatusOr disasm_or = + Disassembler(*target_machine).DisassembleObjectFile(obj_file); + string text = disasm_or.ok() + ? 
std::move(disasm_or).ValueOrDie().text + : absl::StrCat("Error disassembling: ", + disasm_or.status().ToString()); + DumpToFileInDirOrStdout(*module, /*file_suffix=*/"s", text); + }; - Disassembler disassembler(*target_machine); CompilerFunctor compiler_functor( - target_machine.get(), &disassembler, opt_level, + target_machine.get(), opt_level, options::OptimizeForSizeRequested(module->config()), module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), - pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); + pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook); std::unique_ptr object_file = compiler_functor(llvm_module); ObjectFileData object_file_data(object_file->getBufferStart(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc index 7fbe0fa157c57eb0c274662a1de95cf5328ccfa8..4ac61f44d9f38425da2d1fc6b9495cb4deba5047 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 6f79ad7c1468f27c74d84770ec6358fbcd1c1f09..5793f0080b64cf8d576760e274586ac88cdab5d0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" namespace xla { namespace cpu { @@ -42,9 +43,10 @@ bool CanBeLoopFused(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } -bool IsMatrixVectorDot(const HloInstruction* hlo) { +bool IsNonComplexMatrixVectorDot(const HloInstruction* hlo) { const Shape& hlo_shape = hlo->shape(); - return hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 && + return !ShapeUtil::ElementIsComplex(hlo_shape) && + hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 && (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1); } @@ -55,7 +57,8 @@ bool HasExactlyOneUse(const HloInstruction& hlo_instr) { bool CanBeOutputFused(const HloInstruction* producer, const HloInstruction* consumer) { - return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) && + return consumer->opcode() == HloOpcode::kAdd && + IsNonComplexMatrixVectorDot(producer) && HasExactlyOneUse(*producer) == 1; } @@ -96,12 +99,16 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - // TODO(b/28644064): see if the "producer->operand_count() == 0" check is - // necessary. - if (producer->operand_count() == 0 || - !InstructionFusion::ShouldFuse(consumer, operand_index)) { - VLOG(2) - << "Not fusing: producer has no operands, or !ShouldFuse(consumer)."; + if (!InstructionFusion::ShouldFuse(consumer, operand_index)) { + VLOG(2) << "Not fusing: !ShouldFuse(consumer)."; + return false; + } + + // Fuse constants in general but avoid creating 2-instruction fusions with + // just a constant and another node. 
+ if (producer->opcode() == HloOpcode::kConstant && + consumer->opcode() != HloOpcode::kFusion) { + VLOG(2) << "Not fusing: insufficient non-constant nodes."; return false; } @@ -111,6 +118,14 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Don't fuse if fusing would cause too much code duplication because of + // inefficiencies in the fusion emitter. + // TODO(b/119692968): Remove this once the fusion emitter can handle + // arbitrary fusion nodes. + if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) { + return false; + } + if (consumer->opcode() == HloOpcode::kDot) { // In the general case we call out to optimized "black box" GEMM routines // for Dot, which precludes fusion. However, in very specific cases, we try diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index c4bde837e57e82584c2a007858ed8d55608acd3c..cdb52d1dd987a47ffcb6b5cad40821d0aca011ee 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -623,65 +623,10 @@ TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { module.get(), {HloOpcode::kDynamicSlice, HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, HloOpcode::kReshape, - HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kConstant, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); } -// Tests that we do not fuse instructions in cases where instructions in the -// fusion would reuse elements from its operand due to an implicit broadcast. 
-TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastUnary) { - Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4}); - Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4}); - - HloComputation::Builder builder(TestName()); - - HloInstruction* small_param = - builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, small_shape, "param")); - HloInstruction* small_exp = builder.AddInstruction( - HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param)); - builder.AddInstruction( - HloInstruction::CreateUnary(large_shape, HloOpcode::kExp, small_exp)); - - std::unique_ptr module = CreateNewUnverifiedModule(); - module->AddEntryComputation(builder.Build()); - - auto did_fusion = CpuInstructionFusion().Run(module.get()); - ASSERT_TRUE(did_fusion.ok()); - EXPECT_FALSE(did_fusion.ValueOrDie()); - ASSERT_THAT(module->entry_computation()->root_instruction(), - Not(op::Fusion())); -} - -// Like ReuseViaImplicitBroadcastUnary but with a binary operation. -TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { - Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4}); - Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4}); - - HloComputation::Builder builder(TestName()); - - HloInstruction* small_param = - builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, small_shape, "param")); - HloInstruction* large_param = - builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, large_shape, "param")); - HloInstruction* small_exp = builder.AddInstruction( - HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param)); - - builder.AddInstruction(HloInstruction::CreateBinary( - large_shape, HloOpcode::kAdd, small_exp, large_param)); - - std::unique_ptr module = CreateNewUnverifiedModule(); - module->AddEntryComputation(builder.Build()); - - auto did_fusion = CpuInstructionFusion().Run(module.get()); - ASSERT_TRUE(did_fusion.ok()); - EXPECT_FALSE(did_fusion.ValueOrDie()); - 
ASSERT_THAT(module->entry_computation()->root_instruction(), - Not(op::Fusion())); -} - void CreateComputationForDotAddOutputFusionTest(const string& test_name, HloModule* module, int m, int k, int n, @@ -811,7 +756,7 @@ TEST_P(GatherLoopFusionTest, GatherLoopFusion) { RunFusionAndCheckOpcodesWereFused( module.get(), {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast, - HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); + HloOpcode::kConstant, HloOpcode::kParameter, HloOpcode::kParameter}); } std::vector GetGatherLoopFusionTestSpecs() { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index a9febe891b5e9d1eb9e6b297952b50d1d26a3396..d8878e622c0500fc5328aa6c295a9e24a3a037f7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -84,31 +84,8 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; -extern const char* const kKeyValueSortPREDSymbolName = - "__xla_cpu_runtime_KeyValueSortPRED"; -extern const char* const kKeyValueSortS8SymbolName = - "__xla_cpu_runtime_KeyValueSortS8"; -extern const char* const kKeyValueSortU8SymbolName = - "__xla_cpu_runtime_KeyValueSortU8"; -extern const char* const kKeyValueSortS16SymbolName = - "__xla_cpu_runtime_KeyValueSortS16"; -extern const char* const kKeyValueSortU16SymbolName = - "__xla_cpu_runtime_KeyValueSortU16"; -extern const char* const kKeyValueSortF16SymbolName = - "__xla_cpu_runtime_KeyValueSortF16"; -extern const char* const kKeyValueSortS32SymbolName = - "__xla_cpu_runtime_KeyValueSortS32"; -extern const char* const kKeyValueSortU32SymbolName = - "__xla_cpu_runtime_KeyValueSortU32"; -extern const char* const kKeyValueSortF32SymbolName = - "__xla_cpu_runtime_KeyValueSortF32"; -extern const 
char* const kKeyValueSortS64SymbolName = - "__xla_cpu_runtime_KeyValueSortS64"; -extern const char* const kKeyValueSortU64SymbolName = - "__xla_cpu_runtime_KeyValueSortU64"; -extern const char* const kKeyValueSortF64SymbolName = - "__xla_cpu_runtime_KeyValueSortF64"; - +extern const char* const kKeyValueSortSymbolName = + "__xla_cpu_runtime_KeyValueSort"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index b2e760a224ad8eaa61dae57b0f9cece04a7e54ae..3a2b44d8c1a80128d3577c374e751e73a89e9d59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -64,18 +64,7 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; -extern const char* const kKeyValueSortPREDSymbolName; -extern const char* const kKeyValueSortS8SymbolName; -extern const char* const kKeyValueSortU8SymbolName; -extern const char* const kKeyValueSortS16SymbolName; -extern const char* const kKeyValueSortU16SymbolName; -extern const char* const kKeyValueSortF16SymbolName; -extern const char* const kKeyValueSortS32SymbolName; -extern const char* const kKeyValueSortU32SymbolName; -extern const char* const kKeyValueSortF32SymbolName; -extern const char* const kKeyValueSortS64SymbolName; -extern const char* const kKeyValueSortU64SymbolName; -extern const char* const kKeyValueSortF64SymbolName; +extern const char* const kKeyValueSortSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 3361a5973f5e8c91802b26d68477347b196d3cac..fae9670051a654f38f09856368ffb700b0c7a085 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 48510181bd01c87c9db764396b556fdf34e6c8c4..6107d40ab51bf7fb510d65194d52b8072aff9c85 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -250,11 +250,6 @@ void DotOpEmitter::EmitTiledLlvmIrGemm() { std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = GetGemmTileSize(); - const bool enable_fast_math = - hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); - const bool optimize_for_size = - options::OptimizeForSizeRequested(hlo_module_config_); - EmitSmallGemm( /*scalar_type=*/primitive_type, /*m=*/m, /*k=*/k, /*n=*/n, @@ -262,9 +257,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemm() { /*max_vector_count=*/tile_size_n_in_vector_width, /*min_vectorization_width=*/std::min(4, max_target_vector_width), /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs, - /*rhs=*/rhs, /*result=*/target, b_, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size); + /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_); } void DotOpEmitter::EmitTiledLlvmIrGemv() { @@ -323,11 +316,6 @@ void 
DotOpEmitter::EmitTiledLlvmIrGemv() { llvm::Value* rhs_op = swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); - const bool enable_fast_math = - hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); - const bool optimize_for_size = - options::OptimizeForSizeRequested(hlo_module_config_); - const int target_vector_register_element_size = target_machine_features_.vector_register_num_elements( *b_->GetInsertBlock()->getParent(), primitive_type); @@ -349,9 +337,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() { /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, - /*result=*/result_op, b_, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size); + /*result=*/result_op, b_, hlo_module_config_); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; @@ -361,9 +347,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() { /*tile_cols=*/vector_register_element_size, /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, - /*result=*/result_op, b_, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size); + /*result=*/result_op, b_, hlo_module_config_); } } @@ -445,10 +429,12 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. 
llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_); - llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( - lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); - llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( - rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); + std::vector lhs_multi_index = + loop_nest.EmitOperandArrayLoopNest( + lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); + std::vector rhs_multi_index = + loop_nest.EmitOperandArrayLoopNest( + rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); // Create the loop which does the sum of products reduction. // @@ -468,8 +454,12 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop. - lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); - rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape, + b_->getInt64Ty()); + rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape, + b_->getInt64Ty()); // For computing the sum of products we alloca a single location to store the // dot product result as we accumulate it within the reduction loop. After the @@ -532,18 +522,20 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { // the rhs and lhs indexes with the reduction dimensions removed. The terms // from the rhs index are the lower dimensions in the index so we add them // first. 
- llvm_ir::IrArray::Index target_index(lhs_index.GetType()); + std::vector target_multi_index; for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { - target_index.push_back(lhs_index[dimension]); + target_multi_index.push_back(lhs_index[dimension]); } } for (int dimension = 0; dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { - target_index.push_back(rhs_index[dimension]); + target_multi_index.push_back(rhs_index[dimension]); } } + llvm_ir::IrArray::Index target_index( + target_multi_index, target_array_.GetShape(), lhs_index.GetType()); target_array_.EmitWriteArrayElement(target_index, result, b_); // Set the IR builder insert point to the exit basic block of the outer most @@ -643,11 +635,13 @@ Status DotOpEmitter::EmitCallToRuntime() { llvm::Function* function = b_->GetInsertBlock()->getParent(); llvm::Module* module = function->getParent(); - llvm::Function* matmul_func = llvm::cast( - module->getOrInsertFunction(fn_name, matmul_type)); - matmul_func->setCallingConv(llvm::CallingConv::C); - matmul_func->setDoesNotThrow(); - matmul_func->setOnlyAccessesArgMemory(); + llvm::FunctionCallee matmul_func = + module->getOrInsertFunction(fn_name, matmul_type); + if (auto* fn = llvm::dyn_cast(matmul_func.getCallee())) { + fn->setCallingConv(llvm::CallingConv::C); + fn->setDoesNotThrow(); + fn->setOnlyAccessesArgMemory(); + } // The Eigen runtime function expects column-major layout. 
If the matrices are // row major, then use the following identity to compute the product: @@ -919,11 +913,11 @@ llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array, llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); Shape inner_shape = DropFirstDim(outer_array.GetShape()); - llvm_ir::IrArray::Index slice_index(b->getInt64Ty()); - slice_index.push_back(batch_index); - slice_index.InsertAt( - /*index=*/1, outer_array.GetShape().dimensions_size() - 1, - b->getInt64(0)); + std::vector multidim_index(inner_shape.rank() + 1, + b->getInt64(0)); + multidim_index[0] = batch_index; + llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(), + batch_index->getType()); llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b); llvm::Type* slice_ptr_type = llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo(); @@ -961,8 +955,8 @@ Status EmitBatchDotOperation( KernelSupportLibrary ksl(b); return ksl.ForWithStatus( - "bdot", /*start=*/0, /*end=*/batch_count, /*step=*/1, - [&](llvm::Value* indvar) { + llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count, + /*step=*/1, [&](llvm::Value* indvar) { DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers(); adjusted_dim_numbers.clear_lhs_batch_dimensions(); adjusted_dim_numbers.clear_rhs_batch_dimensions(); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index c8312d80bd5012e5bcb42a410db18a7fa77a2eb6..fb021f277b0c796c4126e36da5dea405c603ccd1 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -27,6 +27,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +using xla::llvm_ir::IrArray; + namespace xla { namespace cpu { @@ -51,10 +53,11 @@ StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, return Unimplemented("atan2"); } // Create a function declaration. - llvm::Function* function = - llvm::cast(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(), - rhs->getType())); + llvm::Function* function = llvm::dyn_cast( + module_ + ->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(), + rhs->getType()) + .getCallee()); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); @@ -85,9 +88,11 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, return Unimplemented("tanh"); } // Create a function declaration. - llvm::Function* function = llvm::cast( - module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name), - value->getType(), value->getType())); + llvm::Function* function = llvm::dyn_cast( + module_ + ->getOrInsertFunction(function_name, value->getType(), + value->getType()) + .getCallee()); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); @@ -102,21 +107,43 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) { - if (hlo->opcode() == HloOpcode::kMap) { - return [this, hlo, &operand_to_generator]( - const llvm_ir::IrArray::Index& index) -> StatusOr { - std::vector operands; - for (int i = 0; i < hlo->operand_count(); i++) { - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(i))( - ElementwiseSourceIndex(index, *hlo, i))); - operands.push_back(operand_value); - } - return ir_emitter_->EmitElementalMap(*Cast(hlo), - 
operands, llvm_ir::IrName(hlo)); - }; + switch (hlo->opcode()) { + case HloOpcode::kMap: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + std::vector operands; + for (int i = 0; i < hlo->operand_count(); i++) { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(i))(index)); + operands.push_back(operand_value); + } + return ir_emitter_->EmitElementalMap(*Cast(hlo), + operands, llvm_ir::IrName(hlo)); + }; + case HloOpcode::kReduceWindow: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return ir_emitter_->EmitElementalReduceWindow( + Cast(hlo), + operand_to_generator.at(hlo->operand(0)), index); + }; + case HloOpcode::kConvolution: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return ir_emitter_->EmitElementalConvolution( + Cast(hlo), + operand_to_generator.at(hlo->operand(0)), + operand_to_generator.at(hlo->operand(1)), index); + }; + case HloOpcode::kReduce: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return ir_emitter_->EmitElementalReduce( + Cast(hlo), + operand_to_generator.at(hlo->operand(0)), + operand_to_generator.at(hlo->operand(1)), index); + }; + default: + return ElementalIrEmitter::MakeElementGenerator(hlo, + operand_to_generator); } - return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); } } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index a8b139aec9e96b6bb580baf74789df7c998cebf8..2cc618e430215e26cb41c0a24a9c01b1ae33cec1 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -72,7 +72,8 @@ bool PotentiallyImplementedAsEigenConvolution( CHECK( ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape)); // TODO(b/65408531): Explore using Eigen dot 
for complex64 type. - if (ShapeUtil::ElementIsComplex(input_shape)) { + PrimitiveType primitive_type = input_shape.element_type(); + if (primitive_type != F16 && primitive_type != F32) { return false; } if (window_util::HasWindowReversal(convolution.window())) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index f8a997045a66545bca9a35b2e582bf015e659b48..19e7b13bb3ac3c08c0131510e97cb90ead1c9437 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -74,10 +74,8 @@ limitations under the License. namespace xla { namespace { -using llvm_ir::AsStringRef; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -namespace gtl = tensorflow::gtl; } // namespace namespace cpu { @@ -87,7 +85,8 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine_features) + const TargetMachineFeatures* target_machine_features, + bool emit_code_for_msan) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), @@ -97,10 +96,9 @@ IrEmitter::IrEmitter( alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(*target_machine_features) { - b_.setFastMathFlags(llvm_ir::GetFastMathFlags( - /*fast_math_enabled=*/hlo_module_config_.debug_options() - .xla_cpu_enable_fast_math())); + target_machine_features_(*target_machine_features), + emit_code_for_msan_(emit_code_for_msan) { + b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_)); Status s = GatherComputationsByAllocationType( &hlo_module, &thread_local_computations_, &global_computations_); absl::c_sort(thread_local_computations_); @@ -159,11 +157,9 @@ void 
IrEmitter::InitializeIrFunction(const string& function_name) { is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; // Create and initialize new IrFunction. - compute_function_.reset(new IrFunction( - function_name, linkage, - options::OptimizeForSizeRequested(hlo_module_config_), - hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_, - &b_, num_dynamic_loop_bounds_)); + compute_function_.reset(new IrFunction(function_name, linkage, + hlo_module_config_, module_, &b_, + num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -172,8 +168,7 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); emitted_value_[bitcast] = BitCast(GetEmittedValueFor(bitcast->operand(0)), - IrShapeType(bitcast->shape())->getPointerTo(), - AsStringRef(IrName(bitcast))); + IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast)); return Status::OK(); } @@ -188,6 +183,7 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { /*Initializer=*/initializer, /*Name=*/""); result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global); return llvm::ConstantExpr::getBitCast( result_global, IrShapeType(literal.shape())->getPointerTo()); } @@ -302,7 +298,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Shape& shape = get_tuple_element->shape(); emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement( shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape), - GetEmittedValueFor(operand), &b_, module_); + GetEmittedValueFor(operand), &b_); return Status::OK(); } @@ -322,7 +318,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select)); llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred), 
GetEmittedValueFor(on_true), - GetEmittedValueFor(on_false), &b_, module_); + GetEmittedValueFor(on_false), &b_); return Status::OK(); } @@ -345,8 +341,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { assignment_.GetUniqueSlice(infeed, {1})); llvm::Value* token_address = EmitBufferPointer( token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); - llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, - module_); + llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_); if (data_shape.IsTuple()) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); @@ -377,7 +372,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { } llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape), - tuple_element_addresses, &b_, module_); + tuple_element_addresses, &b_); } else { TF_RETURN_IF_ERROR( EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address)); @@ -412,11 +407,18 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Function* acquire_func; if (kind == XfeedKind::kInfeed) { - acquire_func = llvm::cast(module_->getOrInsertFunction( - runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)); + acquire_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type) + .getCallee()); } else { - acquire_func = llvm::cast(module_->getOrInsertFunction( - runtime::kAcquireOutfeedBufferForPopulationSymbolName, acquire_type)); + acquire_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kAcquireOutfeedBufferForPopulationSymbolName, + acquire_type) + .getCallee()); } acquire_func->setCallingConv(llvm::CallingConv::C); @@ -429,11 +431,19 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Function* release_func; if (kind == XfeedKind::kInfeed) { - release_func = llvm::cast(module_->getOrInsertFunction( - 
runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); + release_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kReleaseInfeedBufferAfterDequeueSymbolName, + release_type) + .getCallee()); } else { - release_func = llvm::cast(module_->getOrInsertFunction( - runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, release_type)); + release_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, + release_type) + .getCallee()); } release_func->setCallingConv(llvm::CallingConv::C); @@ -483,7 +493,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { ShapeUtil::GetTupleElementShape(operand_shape, i); llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement( tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape), - value, &b_, module_); + value, &b_); TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed, tuple_element_shape, tuple_element)); } @@ -495,6 +505,27 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { const HloSortInstruction* sort = Cast(hlo); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); Shape keys_shape = sort->keys()->shape(); + PrimitiveType keys_type = keys_shape.element_type(); + switch (keys_type) { + case PRED: + case S8: + case U8: + case S16: + case U16: + case BF16: + case F16: + case S32: + case U32: + case F32: + case S64: + case U64: + case F64: + break; + default: + return Unimplemented( + "Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); + } std::vector destination_addresses(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { ShapeIndex shape_index = @@ -542,109 +573,50 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { lower_dimensions *= normalized_keys_shape.dimensions(i); } - PrimitiveType keys_type = keys_shape.element_type(); - const char* fn_name = nullptr; - llvm::Type* keys_native_type = nullptr; - switch (keys_type) { - case PRED: - 
fn_name = runtime::kKeyValueSortPREDSymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case S8: - fn_name = runtime::kKeyValueSortS8SymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case U8: - fn_name = runtime::kKeyValueSortU8SymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case S16: - fn_name = runtime::kKeyValueSortS16SymbolName; - keys_native_type = b_.getInt16Ty()->getPointerTo(); - break; - case U16: - fn_name = runtime::kKeyValueSortU16SymbolName; - keys_native_type = b_.getInt16Ty()->getPointerTo(); - break; - case F16: - fn_name = runtime::kKeyValueSortF16SymbolName; - keys_native_type = b_.getHalfTy()->getPointerTo(); - break; - case S32: - fn_name = runtime::kKeyValueSortS32SymbolName; - keys_native_type = b_.getInt32Ty()->getPointerTo(); - break; - case U32: - fn_name = runtime::kKeyValueSortU32SymbolName; - keys_native_type = b_.getInt32Ty()->getPointerTo(); - break; - case F32: - fn_name = runtime::kKeyValueSortF32SymbolName; - keys_native_type = b_.getFloatTy()->getPointerTo(); - break; - case S64: - fn_name = runtime::kKeyValueSortS64SymbolName; - keys_native_type = b_.getInt64Ty()->getPointerTo(); - break; - case U64: - fn_name = runtime::kKeyValueSortU64SymbolName; - keys_native_type = b_.getInt64Ty()->getPointerTo(); - break; - case F64: - fn_name = runtime::kKeyValueSortF64SymbolName; - keys_native_type = b_.getDoubleTy()->getPointerTo(); - break; - default: - return Unimplemented( - "Element type %s not supported in the Sort op on CPU.", - PrimitiveType_Name(keys_type)); - } - + auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply()); + CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply())); llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( b_.getVoidTy(), - {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), + {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), - 
b_.getInt32Ty()->getPointerTo()}, + b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(), + b_.getInt64Ty()->getPointerTo(), less_than_function->getType()}, /*isVarArg=*/false); - auto* key_value_sort_func = llvm::cast( - module_->getOrInsertFunction(fn_name, key_value_sort_type)); + auto* key_value_sort_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction(runtime::kKeyValueSortSymbolName, + key_value_sort_type) + .getCallee()); key_value_sort_func->setCallingConv(llvm::CallingConv::C); key_value_sort_func->setDoesNotThrow(); - llvm::Value* values; - llvm::Value* sizes; - if (sort->values_count() == 0) { - values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()); - sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo()); - } else { - values = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt8PtrTy(), b_.getInt32(sort->values_count()), - "cc_values_alloca", &b_); - sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca", - &b_); - for (int64 i = 0; i < sort->values_count(); ++i) { - llvm::Value* value_as_i8ptr = - PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy()); - llvm::Value* slot_in_values_alloca = - ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i); - Store(value_as_i8ptr, slot_in_values_alloca); - llvm::Value* slot_in_sizes_alloca = - ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i); - llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType( - sort->operand(i + 1)->shape().element_type())); - Store(size, slot_in_sizes_alloca); - } + llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca", + &b_); + llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca", + &b_); + for (int64 i = 0; i < sort->operand_count(); ++i) { + llvm::Value* value_as_i8ptr = + 
PointerCast(destination_addresses[i], b_.getInt8PtrTy()); + llvm::Value* slot_in_values_alloca = + ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i); + Store(value_as_i8ptr, slot_in_values_alloca); + llvm::Value* slot_in_sizes_alloca = + ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i); + llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i)->shape().element_type())); + Store(size, slot_in_sizes_alloca); } Call(key_value_sort_func, - {PointerCast(destination_addresses[0], keys_native_type), - b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), + {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), b_.getInt64(lower_dimensions), values, - b_.getInt32(sort->values_count()), sizes}); + b_.getInt32(sort->operand_count()), sizes, + b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(), + GetProfileCountersArgument(), less_than_function}); if (sort->values_count() > 0) { - llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, - module_); + llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_); } return Status::OK(); } @@ -655,7 +627,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { for (auto operand : tuple->operands()) { base_ptrs.push_back(GetEmittedValueFor(operand)); } - llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_); return Status::OK(); } @@ -665,8 +637,9 @@ llvm::Value* IrEmitter::EmitElementalMap( return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } -StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( - HloReduceWindowInstruction* reduce_window, +StatusOr IrEmitter::EmitElementalReduceWindow( + const HloReduceWindowInstruction* reduce_window, + const llvm_ir::ElementGenerator& input_generator, const llvm_ir::IrArray::Index& index) { const HloInstruction* operand = reduce_window->operand(0); const Window& window = 
reduce_window->window(); @@ -692,21 +665,22 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), index.size()); + std::vector input_multi_index(index.size()); llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = NSWSub( + input_multi_index[i] = NSWSub( NSWAdd(strided_index, NSWMul(window_index[i], b_.getInt64(window.dimensions(i).window_dilation()))), b_.getInt64(window.dimensions(i).padding_low())); // We need to verify that we are not in the dilated base area. - llvm::Value* dilation_condition = ICmpEQ( - SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())), - b_.getInt64(0)); + llvm::Value* dilation_condition = + ICmpEQ(SRem(input_multi_index[i], + b_.getInt64(window.dimensions(i).base_dilation())), + b_.getInt64(0)); if (in_bounds_condition == nullptr) { in_bounds_condition = dilation_condition; } else { @@ -714,15 +688,16 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( } // Apply base dilation to the index. - input_index[i] = - SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())); - - // We need to check if 0 <= input_index[i] < bound, as otherwise we are in - // the padding so that we can skip the computation. That is equivalent to - // input_index[i] < bound as an *unsigned* comparison, since a negative - // value will wrap to a large positive value. + input_multi_index[i] = + SDiv(input_multi_index[i], + b_.getInt64(window.dimensions(i).base_dilation())); + + // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we + // are in the padding so that we can skip the computation. 
That is + // equivalent to input_multi_index[i] < bound as an *unsigned* comparison, + // since a negative value will wrap to a large positive value. llvm::Value* index_condition = - ICmpULT(input_index[i], + ICmpULT(input_multi_index[i], b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); if (in_bounds_condition == nullptr) { in_bounds_condition = index_condition; @@ -737,8 +712,10 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( SetToFirstInsertPoint(if_data.true_block, &b_); // We are not in the padding, so carry out the computation. - llvm_ir::IrArray input_array(GetIrArrayFor(operand)); - llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); + llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(), + b_.getInt64Ty()); + TF_ASSIGN_OR_RETURN(llvm::Value* const input_value, + input_generator(input_index)); llvm::Value* result = EmitThreadLocalCall( *reduce_window->to_apply(), {Load(accumulator_address), input_value}, "reducer_function"); @@ -749,11 +726,6 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { - TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*reduce_window, - /*operands=*/{reduce_window->operand(0)}, - /*supported_types=*/{F32, BF16, S32, F16})); - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -767,11 +739,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // // This is completely un-optimized and just here to have something // that works. 
- return EmitTargetElementLoop( - reduce_window, [&](const llvm_ir::IrArray::Index& index) { - return EmitTargetElementLoopBodyForReduceWindow( - Cast(reduce_window), index); - }); + return DefaultAction(reduce_window); } Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { @@ -851,15 +819,16 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size()); + std::vector operand_multi_index(source_index.size()); llvm::Value* in_bounds_condition = b_.getTrue(); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride())); - operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + operand_multi_index[i] = + NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); llvm::Value* index_condition = - ICmpULT(operand_index[i], + ICmpULT(operand_multi_index[i], b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); in_bounds_condition = And(in_bounds_condition, index_condition); } @@ -885,6 +854,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { } }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); + llvm_ir::IrArray::Index operand_index( + operand_multi_index, operand_array.GetShape(), b_.getInt64Ty()); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); Store(operand_data, selected_value_address); @@ -918,16 +889,18 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // location is computed by calling the `scatter` function with the source // value and the current output value. 
SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_); - llvm_ir::IrArray::Index selected_index(source_index.GetType()); + std::vector selected_multi_index; for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = InBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(Load(selected_index_address_slot)); + selected_multi_index.push_back(Load(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value = source_array.EmitReadArrayElement(source_index, &b_); llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); + llvm_ir::IrArray::Index selected_index( + selected_multi_index, output_array.GetShape(), source_index.GetType()); llvm::Value* output_value = output_array.EmitReadArrayElement(selected_index, &b_); llvm::Value* scatter_value = @@ -974,8 +947,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { hlo_module_config_, target_machine_features_); } -StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( - HloConvolutionInstruction* convolution, +StatusOr IrEmitter::EmitElementalConvolution( + const HloConvolutionInstruction* convolution, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& kernel_generator, const llvm_ir::IrArray::Index& index) { const HloInstruction* lhs = convolution->operand(0); const HloInstruction* rhs = convolution->operand(1); @@ -1080,30 +1055,34 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( // We are not in the padding, so carry out the computation. 
int num_dims = num_spatial_dims + 2; - llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), num_dims); + std::vector input_multi_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; + input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; } - input_index[dnums.input_feature_dimension()] = input_feature; - input_index[dnums.input_batch_dimension()] = batch; + input_multi_index[dnums.input_feature_dimension()] = input_feature; + input_multi_index[dnums.input_batch_dimension()] = batch; - llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); - llvm_ir::IrArray::Index kernel_index(b_.getInt64Ty(), num_dims); + std::vector kernel_multi_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - kernel_index[dnums.kernel_spatial_dimensions(i)] = + kernel_multi_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1), kernel_spatial[i]) : kernel_spatial[i]; } - kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; - kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; + kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature; + kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature; - llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); - llvm::Value* product = - FMul(input_array.EmitReadArrayElement(input_index, &b_), - kernel_array.EmitReadArrayElement(kernel_index, &b_)); + llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(), + b_.getInt64Ty()); + TF_ASSIGN_OR_RETURN(llvm::Value* const input_value, + input_generator(input_index)); + llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(), + b_.getInt64Ty()); + TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value, + kernel_generator(kernel_index)); + llvm::Value* product = FMul(input_value, kernel_value); llvm::Value* sum = 
FAdd(Load(sum_address), product); Store(sum, sum_address); @@ -1116,7 +1095,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { auto rhs = convolution->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, C64, C128})); + /*supported_types=*/{F16, F32, F64, C64, C128})); // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. @@ -1229,8 +1208,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded " "conv2d function."; } - llvm::Function* conv_func = llvm::cast( - module_->getOrInsertFunction(fn_name, conv_type)); + llvm::Function* conv_func = llvm::dyn_cast( + module_->getOrInsertFunction(fn_name, conv_type).getCallee()); conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); @@ -1271,11 +1250,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // // See the description of convolution in the XLA documentation for the pseudo // code for convolution. - return EmitTargetElementLoop( - convolution, [&](const llvm_ir::IrArray::Index& index) { - return EmitTargetElementLoopBodyForConvolution( - Cast(convolution), index); - }); + return DefaultAction(convolution); } Status IrEmitter::HandleFft(HloInstruction* fft) { @@ -1313,8 +1288,8 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { ? 
runtime::kEigenFftSymbolName : runtime::kEigenSingleThreadedFftSymbolName; - llvm::Function* fft_func = llvm::cast( - module_->getOrInsertFunction(fn_name, fft_type)); + llvm::Function* fft_func = llvm::dyn_cast( + module_->getOrInsertFunction(fn_name, fft_type).getCallee()); fft_func->setCallingConv(llvm::CallingConv::C); fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); @@ -1368,7 +1343,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); } - llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_); return Status::OK(); } @@ -1617,22 +1592,23 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), &b_); - llvm_ir::IrArray::Index reduced_dims_index = + std::vector input_multi_index = reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, "reduction_dim"); SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_); llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); - llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = output_index.begin(); - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = *it++; + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = *it++; } } CHECK(output_index.end() == it); + llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), + b_.getInt64Ty()); llvm::Value* input_address = BitCast( arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); @@ -1744,8 +1720,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( // } llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_); - llvm_ir::IrArray::Index array_index(b_.getInt64Ty(), - reduce->shape().dimensions_size()); + std::vector array_multi_index( 
+ reduce->shape().dimensions_size()); for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; --i) { int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); @@ -1753,7 +1729,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 end_index = reduce->shape().dimensions(dimension); std::unique_ptr loop = loop_nest.AddLoop( start_index, end_index, absl::StrFormat("dim.%d", dimension)); - array_index[dimension] = loop->GetIndVarValue(); + array_multi_index[dimension] = loop->GetIndVarValue(); } int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0); @@ -1774,12 +1750,14 @@ StatusOr IrEmitter::EmitVectorizedReduce( std::unique_ptr loop = loop_nest.AddLoop(start_index, end_index, vectorization_factor, absl::StrFormat("dim.%d", innermost_dimension)); - array_index[innermost_dimension] = loop->GetIndVarValue(); + array_multi_index[innermost_dimension] = loop->GetIndVarValue(); SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); ShardedVectorType vector_type = CreateShardedVectorType( reduce->shape().element_type(), vectorization_factor); + llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(), + b_.getInt64Ty()); TF_ASSIGN_OR_RETURN(std::vector accumulator, EmitInnerLoopForVectorizedReduction( reduction_generator, array_index, vector_type, @@ -1805,13 +1783,15 @@ StatusOr IrEmitter::EmitVectorizedReduce( // in the following case: if (innermost_dimension_size % vectorization_factor) { // TODO(b/63775531): Consider using a scalar loop here to save on code size. 
- array_index[innermost_dimension] = + array_multi_index[innermost_dimension] = b_.getInt64(innermost_dimension_size - (innermost_dimension_size % vectorization_factor)); ShardedVectorType vector_type = CreateShardedVectorType( reduce->shape().element_type(), innermost_dimension_size % vectorization_factor); + llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(), + b_.getInt64Ty()); TF_ASSIGN_OR_RETURN(std::vector accumulator, EmitInnerLoopForVectorizedReduction( reduction_generator, array_index, vector_type, @@ -1831,10 +1811,12 @@ StatusOr IrEmitter::EmitVectorizedReduce( return true; } -StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( - HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) { - const HloInstruction* arg = reduce->mutable_operand(0); - const HloInstruction* init_value = reduce->mutable_operand(1); +StatusOr IrEmitter::EmitElementalReduce( + const HloReduceInstruction* reduce, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& initial_value_generator, + const llvm_ir::IrArray::Index& index) { + const HloInstruction* arg = reduce->operand(0); absl::Span dimensions(reduce->dimensions()); // Initialize an accumulator with init_value. @@ -1842,9 +1824,10 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); - llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - llvm::Value* load_init_value = Load(init_value_addr); - Store(load_init_value, accumulator_addr); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_value, + initial_value_generator(llvm_ir::IrArray::Index(index.GetType()))); + Store(init_value, accumulator_addr); // The enclosing loops go over all the target elements. Now we have to compute // the actual target element. 
For this, we build a new loop nest to iterate @@ -1852,7 +1835,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( // AddLoopsForShapeOnDimensions will return an Index where induction Value*s // are placed for each dimension in dimensions, and all the rest are nullptrs. llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); - const llvm_ir::IrArray::Index reduced_dims_index = + std::vector input_multi_index = loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, "reduction_dim"); @@ -1863,19 +1846,20 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( // fill in the rest of the dimensions with induction Value*s taken from // 'index' which iterates over the target array. See the high-level // description in the XLA documentation for details. - llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); - llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = *it++; + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = *it++; } } CHECK(index.end() == it); + llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), + b_.getInt64Ty()); // Apply the reduction function to the loaded value. - llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); + TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, + input_generator(input_index)); llvm::Value* result = EmitThreadLocalCall( *reduce->to_apply(), {Load(accumulator_addr), input_element}, "reduce_function"); @@ -1886,7 +1870,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( } Status IrEmitter::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Support variadic reduce. + // TODO(b/118333695): Support variadic reduce. 
if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on CPU"); } @@ -1910,11 +1894,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { } } - return EmitTargetElementLoop(reduce, - [&](const llvm_ir::IrArray::Index& index) { - return EmitTargetElementLoopBodyForReduce( - Cast(reduce), index); - }); + return DefaultAction(reduce); +} + +Status IrEmitter::HandleAllToAll(HloInstruction*) { + return Unimplemented("AllToAll is not implemented on CPU."); } Status IrEmitter::HandleSend(HloInstruction* send) { @@ -2021,15 +2005,17 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { const int64 num_outer_loops = outer_dims.size(); llvm_ir::ForLoopNest loops(IrName(slice), &b_); - llvm_ir::IrArray::Index target_index = + std::vector target_multi_index = loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice"); // Only the indices for the outer dimensions have been initialized in // target_index. The rest of the indices should get initialized to 0, since // for the rest of the dimensions the copy writes to the full dimension. 
- std::replace(target_index.begin(), target_index.end(), + std::replace(target_multi_index.begin(), target_multi_index.end(), static_cast(nullptr), static_cast(b_.getInt64(0))); + llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(), + b_.getInt64Ty()); if (num_outer_loops > 0) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); @@ -2037,7 +2023,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { llvm_ir::IrArray source_array = GetIrArrayFor(operand); const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( - /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), + /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(), /*strides=*/slice->slice_strides(), /*builder=*/&b_); llvm::Value* memcpy_dest = @@ -2140,18 +2126,20 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { // Compute the output index the operand element should be assigned to. // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); - llvm_ir::IrArray::Index output_index(operand_index.GetType()); + std::vector output_multi_index; for (size_t i = 0; i < operand_index.size(); ++i) { llvm::Value* offset = Mul(operand_index[i], b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); llvm::Value* index = Add( offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low())); - output_index.push_back(index); + output_multi_index.push_back(index); } // Store the operand element to the computed output location. 
llvm_ir::IrArray output_array(GetIrArrayFor(pad)); + llvm_ir::IrArray::Index output_index( + output_multi_index, output_array.GetShape(), operand_index.GetType()); output_array.EmitWriteArrayElement(output_index, operand_data, &b_); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); @@ -2240,7 +2228,6 @@ Status IrEmitter::HandleCall(HloInstruction* call) { Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { absl::Span operands(custom_call->operands()); - absl::string_view custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -2253,13 +2240,34 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { InBoundsGEP(operands_alloca, {b_.getInt64(i)}); Store(operand_as_i8ptr, slot_in_operands_alloca); } - auto* custom_call_ir_function = - llvm::cast(module_->getOrInsertFunction( - AsStringRef(custom_call_target), - llvm::FunctionType::get( - /*Result=*/b_.getVoidTy(), - /*Params=*/{i8_ptr_type, operands_alloca->getType()}, - /*isVarArg=*/false))); + if (emit_code_for_msan_) { + // Mark the alloca as initialized for msan. The buffer gets read by the + // custom callee, which might be msan-instrumented. + // TODO(b/66051036): Run the msan instrumentation pass instead. 
+ const llvm::DataLayout& dl = module_->getDataLayout(); + llvm::Type* intptr_type = b_.getIntPtrTy(dl); + auto* msan_unpoison_ir_function = llvm::cast( + module_ + ->getOrInsertFunction( + "__msan_unpoison", + llvm::FunctionType::get( + /*Result=*/b_.getVoidTy(), + /*Params=*/{i8_ptr_type, intptr_type}, /*isVarArg=*/false)) + .getCallee()); + Call(msan_unpoison_ir_function, + {PointerCast(operands_alloca, i8_ptr_type), + llvm::ConstantInt::get( + intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)}); + } + auto* custom_call_ir_function = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + custom_call->custom_call_target(), + llvm::FunctionType::get( + /*Result=*/b_.getVoidTy(), + /*Params=*/{i8_ptr_type, operands_alloca->getType()}, + /*isVarArg=*/false)) + .getCallee()); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); // Write the tuple table if the output is a tuple. @@ -2275,7 +2283,7 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { llvm::Value* addr = EmitBufferPointer(slice, elem_shape); base_ptrs.push_back(addr); } - llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_); } auto* output_address_arg = PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); @@ -2337,7 +2345,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Terminates the current block with a branch to a while header. llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( - module_->getContext(), AsStringRef(IrName(xla_while, "header")), + module_->getContext(), IrName(xla_while, "header"), compute_function_->function()); Br(header_bb); b_.SetInsertPoint(header_bb); @@ -2350,11 +2358,11 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. 
- llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( - module_->getContext(), AsStringRef(IrName(xla_while, "body")), - compute_function_->function()); + llvm::BasicBlock* body_bb = + llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"), + compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( - module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); + module_->getContext(), IrName(xla_while, "exit")); CondBr(while_predicate, body_bb, exit_bb); // Calls the body function from the body block. @@ -2409,11 +2417,13 @@ StatusOr IrEmitter::EmitFastConcatenate( llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); llvm_ir::ForLoopNest loops(IrName(concatenate), &b_); - llvm_ir::IrArray::Index outer_dims_index = + std::vector target_multi_index = loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat"); - std::replace(outer_dims_index.begin(), outer_dims_index.end(), + std::replace(target_multi_index.begin(), target_multi_index.end(), static_cast(nullptr), static_cast(b_.getInt64(0))); + llvm_ir::IrArray::Index target_index(target_multi_index, output_shape, + b_.getInt64Ty()); if (!outer_dims.empty()) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); @@ -2425,10 +2435,9 @@ StatusOr IrEmitter::EmitFastConcatenate( // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. 
- llvm::Value* target_region_begin = - BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_, - "target_region"), - i8_ptr_type); + llvm::Value* target_region_begin = BitCast( + target_array.EmitArrayElementAddress(target_index, &b_, "target_region"), + i8_ptr_type); int64 byte_offset_into_target_region = 0; int64 inner_dims_product = @@ -2443,7 +2452,7 @@ StatusOr IrEmitter::EmitFastConcatenate( const Shape& input_shape = operand->shape(); llvm_ir::IrArray source_array = GetIrArrayFor(operand); llvm::Value* copy_source_address = BitCast( - source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"), + source_array.EmitArrayElementAddress(target_index, &b_, "src_addr"), i8_ptr_type); llvm::Value* copy_target_address = @@ -2520,53 +2529,109 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { } Status IrEmitter::HandleConditional(HloInstruction* conditional) { - auto pred = conditional->operand(0); - TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && - pred->shape().element_type() == PRED) - << "Predicate on a Conditional must be bool; got: " - << ShapeUtil::HumanString(pred->shape()); - - HloComputation* true_computation = conditional->true_computation(); - HloComputation* false_computation = conditional->false_computation(); - TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), - true_computation->root_instruction()->shape())) - << "Shape of conditional should be same as the shape of the true " - << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) - << " and " - << ShapeUtil::HumanString(true_computation->root_instruction()->shape()); - - TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), - false_computation->root_instruction()->shape())) - << "Shape of conditional should be same as the shape of the false " - << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) - << " and " - << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); + auto branch_index = 
conditional->operand(0); + int num_branches = conditional->branch_count(); + TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) && + (branch_index->shape().element_type() == PRED || + branch_index->shape().element_type() == S32)) + << "Branch index on a conditional must be scalar bool or int32; got: " + << ShapeUtil::HumanString(branch_index->shape()); + + for (int b = 0; b < num_branches; ++b) { + HloComputation* br_computation = conditional->branch_computation(b); + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + br_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the " << b + << "th branch computation; got: " + << ShapeUtil::HumanString(conditional->shape()) << " and " + << ShapeUtil::HumanString(br_computation->root_instruction()->shape()); + } TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); - // Generating: - // if (pred) - // cond_result = true_computation(true_operand) - // else - // cond_result = false_computation(false_operand) - llvm::LoadInst* pred_value = - Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = ICmpNE( - pred_value, - llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), - "boolean_predicate"); - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); - - SetToFirstInsertPoint(if_data.true_block, &b_); - EmitGlobalCall(*conditional->true_computation(), - IrName(conditional, "_true")); - - SetToFirstInsertPoint(if_data.false_block, &b_); - EmitGlobalCall(*conditional->false_computation(), - IrName(conditional, "_false")); - - SetToFirstInsertPoint(if_data.after_block, &b_); + if (branch_index->shape().element_type() == PRED) { + // Emit an if-else to LLVM: + // if (pred) + // cond_result = true_computation(true_operand) + // else + // cond_result = false_computation(false_operand) + llvm::LoadInst* pred_value = Load( + GetIrArrayFor(branch_index).GetBasePointer(), 
"load_predicate_value"); + llvm::Value* pred_cond = + ICmpNE(pred_value, + llvm::ConstantInt::get( + llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + "boolean_predicate"); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); + + SetToFirstInsertPoint(if_data.true_block, &b_); + EmitGlobalCall(*conditional->branch_computation(0), + IrName(conditional, "_true")); + + SetToFirstInsertPoint(if_data.false_block, &b_); + EmitGlobalCall(*conditional->branch_computation(1), + IrName(conditional, "_false")); + + SetToFirstInsertPoint(if_data.after_block, &b_); + return Status::OK(); + } + // We emit a switch statement to LLVM: + // switch (branch_index) { + // default: + // result = branch_computations[num_branches-1](operands[num_branches-1]); + // break; + // case 0: + // result = branch_computations[0](operands[0]); break; + // case 1: + // result = branch_computations[1](operands[1]); break; + // ... + // case [[num_branches-2]]: + // result = branch_computations[num_branches-2](operands[num_branches-2]); + // break; + // } + llvm::LoadInst* branch_index_value = Load( + GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value"); + + auto case_block = b_.GetInsertBlock(); + llvm::BasicBlock* after_block; + // Add a terminator to the case block, if necessary. + if (case_block->getTerminator() == nullptr) { + after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_); + b_.SetInsertPoint(case_block); + b_.CreateBr(after_block); + } else { + after_block = + case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after"); + } + // Our basic block should now end with an unconditional branch. Remove it; + // we're going to replace it with a switch based branch. + case_block->getTerminator()->eraseFromParent(); + + // Lower the default branch computation. 
+ auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_); + b_.SetInsertPoint(default_block); + EmitGlobalCall(*conditional->branch_computation(num_branches - 1), + IrName(conditional, "_default")); + b_.CreateBr(after_block); + + // Prepare the switch (branch_index) { ... } instruction. + b_.SetInsertPoint(case_block); + llvm::SwitchInst* case_inst = + b_.CreateSwitch(branch_index_value, default_block, num_branches - 1); + // Lower each branch's computation. + for (int b = 0; b < num_branches - 1; ++b) { // last branch is default + // Lower the case b: { ... ; break; } computation. + auto branch_block = + llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_); + b_.SetInsertPoint(branch_block); + EmitGlobalCall(*conditional->branch_computation(b), + IrName(conditional, absl::StrCat("_branch", b))); + b_.CreateBr(after_block); + case_inst->addCase(b_.getInt32(b), branch_block); + } + + SetToFirstInsertPoint(after_block, &b_); return Status::OK(); } @@ -2641,7 +2706,7 @@ llvm::Value* IrEmitter::GetProfileCounterCommon( int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx), - AsStringRef(counter_name)); + counter_name); } void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b, @@ -2685,7 +2750,7 @@ llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) { void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo) { auto* cycle_start = ReadCycleCounter(b); - cycle_start->setName(AsStringRef(IrName(hlo, "cycle_start"))); + cycle_start->setName(IrName(hlo, "cycle_start")); cycle_starts_[hlo] = cycle_start; if (first_read_cycle_start_ == nullptr) { first_read_cycle_start_ = cycle_start; @@ -2696,7 +2761,7 @@ void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo, llvm::Value* prof_counter) { auto* cycle_end = 
ReadCycleCounter(b); - cycle_end->setName(AsStringRef(IrName(hlo, "cycle_end"))); + cycle_end->setName(IrName(hlo, "cycle_end")); auto* cycle_start = cycle_starts_[hlo]; UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start); last_read_cycle_end_ = cycle_end; @@ -2869,7 +2934,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueTopLevelSlice(op)); llvm::Value* addr = EmitBufferPointer(slice, target_shape); - addr->setName(AsStringRef(IrName(op))); + addr->setName(IrName(op)); emitted_value_[op] = addr; return Status::OK(); } @@ -2909,7 +2974,7 @@ Status IrEmitter::EmitTargetElementLoop( for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_); } else { if (ShouldEmitParallelLoopFor(*target_op)) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index a6fb11dcbf9bb201ba8837866e2f509c48bfd061..e183ae01070e7d42701a3a32d5ddb8667e163663 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -72,13 +72,15 @@ class IrEmitter : public DfsHloVisitorWithDefault, // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. + // emit_code_for_msan: whether emitted code should be compatible with msan. 
IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine); + const TargetMachineFeatures* target_machine, + bool emit_code_for_msan); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -116,6 +118,23 @@ class IrEmitter : public DfsHloVisitorWithDefault, const HloMapInstruction& map_instr, absl::Span elemental_operands, absl::string_view name); + // Emit code to emit the element at `index` for a reduce window instruction. + StatusOr EmitElementalReduceWindow( + const HloReduceWindowInstruction* reduce_window, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::IrArray::Index& index); + // Emit code to emit the element at `index` for a convolution instruction. + StatusOr EmitElementalConvolution( + const HloConvolutionInstruction* convolution, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& kernel_generator, + const llvm_ir::IrArray::Index& index); + // Emit code to emit the element at `index` for a reduce instruction. + StatusOr EmitElementalReduce( + const HloReduceInstruction* reduce, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& initial_value_generator, + const llvm_ir::IrArray::Index& index); protected: // @@ -125,6 +144,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // special in some way are handled explicitly in HandleFoo methods. 
Status DefaultAction(HloInstruction* hlo) override; + Status HandleAllToAll(HloInstruction* instruction) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant) override; Status HandleCopy(HloInstruction* copy) override; @@ -250,14 +270,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); - // Emits a function into the current module. This can be used for - // computations embedded inside other computations, such as the - // function that a map operation applies. - StatusOr EmitFunction( - HloComputation* function, // The function to emit. - absl::string_view - function_name_suffix); // Used for LLVM IR register names. - // Emits a call to a thread local function (e.g. to the computation nested // within a reduce or a map). Thread local callees (by definition) only write // to and read from thread local allocations. @@ -532,17 +544,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Returns the number of bytes within the shape. 
int64 ByteSizeOf(const Shape& shape) const; - StatusOr EmitTargetElementLoopBodyForMap( - HloMapInstruction* map, const llvm_ir::IrArray::Index& index); - StatusOr EmitTargetElementLoopBodyForReduceWindow( - HloReduceWindowInstruction* reduce_window, - const llvm_ir::IrArray::Index& index); - StatusOr EmitTargetElementLoopBodyForConvolution( - HloConvolutionInstruction* convolution, - const llvm_ir::IrArray::Index& index); - StatusOr EmitTargetElementLoopBodyForReduce( - HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index); - enum class XfeedKind { kInfeed, kOutfeed, @@ -582,6 +583,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, std::vector thread_local_computations_; std::vector global_computations_; + bool emit_code_for_msan_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index adfb8392bf6fa356f0a5cdab3ff74036eca8918e..42acd72f9661df8e6687cf544957dce112954dc5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -24,11 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" namespace xla { - -namespace { -using llvm_ir::AsStringRef; -} // namespace - namespace cpu { static std::vector GetComputeFunctionParams( @@ -48,15 +43,14 @@ static std::vector GetComputeFunctionParams( IrFunction::IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, - const bool optimize_for_size_requested, - const bool enable_fast_math, llvm::Module* llvm_module, - llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds) + const HloModuleConfig& module_config, + llvm::Module* llvm_module, llvm::IRBuilder<>* b, + int64 num_dynamic_loop_bounds) : b_(b), llvm_module_(llvm_module), caller_insert_point_guard_(*b), num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { - Initialize(function_name, linkage, optimize_for_size_requested, - enable_fast_math); + Initialize(function_name, linkage, module_config); } IrFunction::~IrFunction() { @@ -75,8 +69,7 @@ DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { void IrFunction::Initialize(const string& function_name, llvm::Function::LinkageTypes linkage, - const bool optimize_for_size_requested, - const bool enable_fast_math) { + const HloModuleConfig& module_config) { // The function signature is: // void function(i8* retval, i8* run_options, i8** params, i8** // buffer_table, @@ -147,11 +140,8 @@ void IrFunction::Initialize(const string& function_name, // Functions with local linkage get an inlining bonus. Because we know // a-priori that embedded functions (non-entry functions) will not have its // name resolved, give it local linkage. - function_ = - llvm_ir::CreateFunction(function_type, linkage, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size_requested, - function_name, llvm_module_); + function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config, + function_name, llvm_module_); // Set meaningful names for the function's arguments: useful for debugging. 
llvm::Function::arg_iterator arg_iter = function_->arg_begin(); @@ -193,7 +183,7 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); string name = absl::StrCat("dynamic_loop_bound_", offset); return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), - b_->getInt64(offset), AsStringRef(name))); + b_->getInt64(offset), name)); } // Emits code to allocate an array of parameter address pointers, and store @@ -216,10 +206,9 @@ std::vector GetArrayFunctionCallArguments( absl::StrCat(name, "_parameter_addresses"), b); for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = - b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(absl::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); + llvm::Value* parameter_as_i8ptr = b->CreateBitCast( + parameter_addresses[i], b->getInt8PtrTy(), + absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr")); llvm::Value* slot_in_param_addresses = b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); @@ -266,9 +255,11 @@ Status EmitCallToParallelForkJoin( /*Params=*/compute_function_params, /*isVarArg=*/false); - llvm::Function* fork_join_func = - llvm::cast(module->getOrInsertFunction( - runtime::kParallelForkJoinSymbolName, fork_join_type)); + llvm::Function* fork_join_func = llvm::dyn_cast( + module + ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName, + fork_join_type) + .getCallee()); fork_join_func->setCallingConv(llvm::CallingConv::C); fork_join_func->setDoesNotThrow(); @@ -322,7 +313,7 @@ Status EmitCallToParallelForkJoin( /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/partitions_array, /*Name=*/ - AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions"))); + absl::StrCat(name, "_parallel_dimension_partitions")); // Add argument specifying parallel dimension partitions. 
fork_join_arguments.push_back( diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index 623a5f185fa1fd0526bc8664e2ba11c9dde79b1d..02bcec9dfc783d4eea653d6d74e903909b666b98 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,8 +53,7 @@ namespace cpu { class IrFunction { public: IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, - const bool optimize_for_size_requested, - const bool enable_fast_math, llvm::Module* llvm_module, + const HloModuleConfig& module_config, llvm::Module* llvm_module, llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds); ~IrFunction(); @@ -92,7 +92,7 @@ class IrFunction { // Initialize an llvm::Function with standard signature based on arguments. void Initialize(const string& function_name, llvm::Function::LinkageTypes linkage, - bool optimize_for_size_requested, bool enable_fast_math); + const HloModuleConfig& module_config); // Emit ir to read and return the ir value for the dynamic loop bound at // 'offset' from the "dynamic_loop_bounds" argument of this function. 
diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index f9722ffadac801521ddcbb568dd4435fd02e951b..93ef51754d21ad3ff4e24298c89649ef4c2742fb 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -36,57 +36,88 @@ const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX"; const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX"; namespace { -llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, - llvm::StringRef function_name, - int vector_width, - bool enable_fast_math) { - llvm::Function* vector_tanh_function = module->getFunction(function_name); - if (vector_tanh_function == nullptr) { + +// Replaces calls to the function `fn_name` with the code generated by +// fn_body_generator. +// +// We assume that fn_name accepts either a scalar f32 or a vector of +// vector_width f32s, and that fn_body_generator generates a function body with +// the same inputs/outputs as fn_name. +void RewriteCalls( + llvm::Module* module, const char* fn_name, + std::function* b, llvm::Value* input, + int32 vector_width)> + fn_body_generator, + int32 vector_width, bool enable_fast_math) { + llvm::Function* fn = module->getFunction(fn_name); + if (fn == nullptr) { // If the function declaration is not present in the module, there can't be // any calls to resolve. Don't emit the function in this case. - return nullptr; + return; } - llvm::LLVMContext* context = &module->getContext(); + // Our task is to generate a function body for `fn`, but we can't generate a + // function body for an LLVM intrinsic. So if fn is an intrinsic, replace it + // with a new function. 
+ if (fn->isIntrinsic()) { + llvm::Function* new_fn = llvm::Function::Create( + fn->getFunctionType(), llvm::GlobalValue::InternalLinkage, + llvm::Twine("xla_impl.") + fn_name, module); + fn->replaceAllUsesWith(new_fn); + fn->eraseFromParent(); + fn = new_fn; + } - llvm::BasicBlock* vector_tanh_body = - llvm::BasicBlock::Create(*context, "body", vector_tanh_function); + llvm::LLVMContext* context = &module->getContext(); - llvm::IRBuilder<> b(vector_tanh_body); + llvm::BasicBlock* fn_body = llvm::BasicBlock::Create(*context, "body", fn); + llvm::IRBuilder<> b(fn_body); llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(enable_fast_math); b.setFastMathFlags(fast_math_flags); - llvm::Value* input = &*vector_tanh_function->arg_begin(); - CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); - b.CreateRet(llvm_ir::EmitFastTanh(&b, input)); - - DCHECK(!llvm::verifyFunction(*vector_tanh_function)); - return vector_tanh_function; -} + llvm::Value* input = &*fn->arg_begin(); -llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, - llvm::StringRef function_name, - int vector_width, - bool enable_fast_math) { - llvm::Function* vector_exp_function = module->getFunction(function_name); - if (vector_exp_function == nullptr) { - // If the function declaration is not present in the module, there can't be - // any calls to resolve. Don't emit the function in this case. - return nullptr; + // Upcast to vector type if input is a scalar. + if (vector_width == 1) { + llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1); + input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input, + uint64_t{0}); } - llvm::LLVMContext* context = &module->getContext(); + // Generate the vectorized code. + CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); + llvm::Value* result = fn_body_generator(&b, input, vector_width); + + // Downcast result to scalar type if necessary. 
+ if (vector_width == 1) { + result = b.CreateExtractElement(result, uint64_t{0}); + } + b.CreateRet(result); + DCHECK(!llvm::verifyFunction(*fn)); - llvm::BasicBlock* vector_exp_body = - llvm::BasicBlock::Create(*context, "body", vector_exp_function); + // Force-inline `fn` into all of its callers and then delete `fn`. + // + // TODO(b/73081976): Should we avoid inlining these in some cases? + std::vector calls_to_inline; + for (auto* user : fn->users()) { + calls_to_inline.push_back(llvm::cast(user)); + } + for (auto* call_to_inline : calls_to_inline) { + llvm::InlineFunctionInfo inline_function_info; + CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); + } + fn->eraseFromParent(); +} - llvm::IRBuilder<> b(vector_exp_body); - llvm::FastMathFlags fast_math_flags; - fast_math_flags.setFast(); - b.setFastMathFlags(fast_math_flags); +llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input, + int32 /*vector_width*/) { + return llvm_ir::EmitFastTanh(b, input); +} - VectorSupportLibrary vsl(F32, vector_width, &b, "exp_f32"); +llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, + int32 vector_width) { + VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32"); // This implements the same polynomial approximation as implemented in Eigen3. @@ -107,7 +138,6 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1); const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1); - llvm::Value* input = &*vector_exp_function->arg_begin(); llvm::Value* input_clamped = vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half)); @@ -128,49 +158,24 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. 
llvm::Value* vector_constant_0x7f = - b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); + b->CreateVectorSplat(vector_width, b->getInt32(0x7f)); llvm::Value* vector_constant_23 = - b.CreateVectorSplat(vector_width, b.getInt32(23)); + b->CreateVectorSplat(vector_width, b->getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(b.getInt32Ty(), vector_width); + llvm::VectorType::get(b->getInt32Ty(), vector_width); // fx is clamped so we don't have to worry about it being out of range for // i32. - llvm::Value* emm0 = b.CreateFPToSI(fx, i32_vector_type); - emm0 = b.CreateAdd(emm0, vector_constant_0x7f); - emm0 = b.CreateShl(emm0, vector_constant_23); - llvm::Value* emm0_f32 = b.CreateBitCast(emm0, vsl.vector_type()); - - llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input); + llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type); + emm0 = b->CreateAdd(emm0, vector_constant_0x7f); + emm0 = b->CreateShl(emm0, vector_constant_23); + llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type()); - b.CreateRet(result); - - DCHECK(!llvm::verifyFunction(*vector_exp_function)); - return vector_exp_function; + return vsl.Max(vsl.Mul(y, emm0_f32), input); } -llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, - llvm::StringRef function_name, - int vector_width, - bool enable_fast_math) { - llvm::Function* vector_log_function = module->getFunction(function_name); - if (vector_log_function == nullptr) { - // If the function declaration is not present in the module, there can't be - // any calls to resolve. Don't emit the function in this case. 
- return nullptr; - } - - llvm::LLVMContext* context = &module->getContext(); - - llvm::BasicBlock* vector_log_body = - llvm::BasicBlock::Create(*context, "body", vector_log_function); - - llvm::IRBuilder<> b(vector_log_body); - llvm::FastMathFlags fast_math_flags; - fast_math_flags.setFast(); - b.setFastMathFlags(fast_math_flags); - - llvm::Value* input = &*vector_log_function->arg_begin(); - VectorSupportLibrary vsl(F32, vector_width, &b, "log_f32"); +llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input, + int32 vector_width) { + VectorSupportLibrary vsl(F32, vector_width, b, "log_f32"); const llvm::APFloat half = GetIeeeF32(0.5); const llvm::APFloat one = GetIeeeF32(1.0); @@ -193,129 +198,107 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, // The smallest non denormalized float number. const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000); const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000); + const llvm::APFloat pos_inf = GetIeeeF32FromBitwiseRep(0x7f800000); const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000); // invalid_mask is set if x is negative or NaN (and therefore output // must be NaN). llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector()); - llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); + llvm::Value* is_zero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); + llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf); // Cut off denormalized stuff. - input = vsl.Max(min_norm_pos, input); + llvm::Value* tmp0 = vsl.Max(min_norm_pos, input); // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. 
llvm::Value* vector_constant_0x7f = - b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); + b->CreateVectorSplat(vector_width, b->getInt32(0x7f)); llvm::Value* vector_constant_23 = - b.CreateVectorSplat(vector_width, b.getInt32(23)); + b->CreateVectorSplat(vector_width, b->getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(b.getInt32Ty(), vector_width); + llvm::VectorType::get(b->getInt32Ty(), vector_width); - llvm::Value* emm0 = - b.CreateLShr(b.CreateBitCast(input, i32_vector_type), vector_constant_23); + llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type), + vector_constant_23); // Keep only the fractional part. - input = vsl.FloatAnd(input, inv_mant_mask); - input = vsl.FloatOr(input, half); + tmp0 = vsl.FloatAnd(tmp0, inv_mant_mask); + tmp0 = vsl.FloatOr(tmp0, half); - emm0 = b.CreateSub(emm0, vector_constant_0x7f); - llvm::Value* e = vsl.Add(one, b.CreateSIToFP(emm0, vsl.vector_type())); + emm0 = b->CreateSub(emm0, vector_constant_0x7f); + llvm::Value* e = vsl.Add(one, b->CreateSIToFP(emm0, vsl.vector_type())); // part2: // if( x < SQRTHF ) { // e -= 1; // x = x + x - 1.0; // } else { x = x - 1.0; } - llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF); - llvm::Value* tmp = vsl.FloatAnd(input, mask); - input = vsl.Sub(input, one); + llvm::Value* mask = vsl.FCmpOLTMask(tmp0, cephes_SQRTHF); + llvm::Value* tmp1 = vsl.FloatAnd(tmp0, mask); + tmp0 = vsl.Sub(tmp0, one); e = vsl.Sub(e, vsl.FloatAnd(mask, one)); - input = vsl.Add(input, tmp); + tmp0 = vsl.Add(tmp0, tmp1); - llvm::Value* x2 = vsl.Mul(input, input); - llvm::Value* x3 = vsl.Mul(x2, input); + llvm::Value* x2 = vsl.Mul(tmp0, tmp0); + llvm::Value* x3 = vsl.Mul(x2, tmp0); llvm::Value *y, *y1, *y2; - y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1); - y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4); - y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7); - y = vsl.MulAdd(y, input, cephes_log_p2); - y1 = vsl.MulAdd(y1, input, cephes_log_p5); - y2 = 
vsl.MulAdd(y2, input, cephes_log_p8); + y = vsl.MulAdd(tmp0, cephes_log_p0, cephes_log_p1); + y1 = vsl.MulAdd(tmp0, cephes_log_p3, cephes_log_p4); + y2 = vsl.MulAdd(tmp0, cephes_log_p6, cephes_log_p7); + y = vsl.MulAdd(y, tmp0, cephes_log_p2); + y1 = vsl.MulAdd(y1, tmp0, cephes_log_p5); + y2 = vsl.MulAdd(y2, tmp0, cephes_log_p8); y = vsl.MulAdd(y, x3, y1); y = vsl.MulAdd(y, x3, y2); y = vsl.Mul(y, x3); y1 = vsl.Mul(cephes_log_q1, e); - tmp = vsl.Mul(half, x2); + llvm::Value* tmp2 = vsl.Mul(half, x2); y = vsl.Add(y, y1); - input = vsl.Sub(input, tmp); + tmp0 = vsl.Sub(tmp0, tmp2); y2 = vsl.Mul(cephes_log_q2, e); - input = vsl.Add(input, y); - input = vsl.Add(input, y2); + tmp0 = vsl.Add(tmp0, y); + tmp0 = vsl.Add(tmp0, y2); - // Negative arg will be NAN, 0 will be -INF. - llvm::Value* or_lhs = - vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask)); - llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf); - llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs); + // Contains +/-inf where +/-inf is the correct answer, otherwise 0. + llvm::Value* result_inf = vsl.FloatOr(vsl.FloatAnd(is_zero_mask, minus_inf), + vsl.FloatAnd(is_pos_inf_mask, pos_inf)); - b.CreateRet(result); + // Contains a finite result or nan. This is the correct answer only if both + // result_minus_inf and result_pos_inf are both 0. + // + // (This implementation works because 0xffffffff is a nan.) + llvm::Value* result_finite_or_nan = vsl.FloatOr(tmp0, invalid_mask); - DCHECK(!llvm::verifyFunction(*vector_log_function)); - return vector_log_function; + // Combine the above into a final result. 
+ return vsl.FloatOr(result_inf, + vsl.FloatAndNot(vsl.FloatOr(is_zero_mask, is_pos_inf_mask), + result_finite_or_nan)); } } // namespace void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { - auto* tanh_v4f32 = - EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName, - /*vector_width=*/4, enable_fast_math); - auto* tanh_v8f32 = - EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName, - /*vector_width=*/8, enable_fast_math); - - auto* exp_v4f32 = - EmitVectorF32ExpIfNeeded(module, kExpV4F32SymbolName, - /*vector_width=*/4, enable_fast_math); - auto* exp_v8f32 = - EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName, - /*vector_width=*/8, enable_fast_math); - - auto* log_v4f32 = - EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName, - /*vector_width=*/4, enable_fast_math); - auto* log_v8f32 = - EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName, - /*vector_width=*/8, enable_fast_math); - - // Gather all the call sites, force inline them and then delete the vector - // function bodies. - // - // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases? - - std::vector calls_to_inline; - for (auto* function : - {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { - if (function != nullptr) { - for (auto* user : function->users()) { - calls_to_inline.push_back(llvm::cast(user)); - } - } - } - - for (auto* call_to_inline : calls_to_inline) { - llvm::InlineFunctionInfo inline_function_info; - CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); - } - - for (auto* function : - {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { - if (function != nullptr) { - function->eraseFromParent(); - } - } + // Curry some params to RewriteCalls. 
+ auto rewrite_calls = + std::bind(RewriteCalls, module, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, enable_fast_math); + + rewrite_calls("tanhf", GenerateVF32Tanh, /*vector_width=*/1); + rewrite_calls("llvm.tanh.f32", GenerateVF32Tanh, /*vector_width=*/1); + rewrite_calls(kTanhV4F32SymbolName, GenerateVF32Tanh, /*vector_width=*/4); + rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8); + + rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1); + rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1); + rewrite_calls(kExpV4F32SymbolName, GenerateVF32Exp, /*vector_width=*/4); + rewrite_calls(kExpV8F32SymbolName, GenerateVF32Exp, /*vector_width=*/8); + + rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1); + rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1); + rewrite_calls(kLogV4F32SymbolName, GenerateVF32Log, /*vector_width=*/4); + rewrite_calls(kLogV8F32SymbolName, GenerateVF32Log, /*vector_width=*/8); } } // namespace runtime diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index a6f4273a5a70aab0bc88383283d2a55b1ecb1681..ffbd0d68ce9b5677d96761a5b10caed8335ef56a 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -39,7 +39,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm_ir::ForLoopNest loop_nest(loop_name, b_); const int64 num_dims = shape_.dimensions_size(); - llvm_ir::IrArray::Index array_index(index_type, num_dims); + std::vector array_multi_index(num_dims); // Add loops from outer-most to inner-most dimensions. 
for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) { @@ -54,14 +54,14 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, std::unique_ptr loop = loop_nest.AddLoop( /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, end_index); - array_index[dimension] = loop->GetIndVarValue(); + array_multi_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), /*suffix=*/absl::StrFormat("dim.%d", dimension)); - array_index[dimension] = loop->GetIndVarValue(); + array_multi_index[dimension] = loop->GetIndVarValue(); } } // Point IR builder at inner loop BB. @@ -71,6 +71,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); CHECK(exit_bb_ != nullptr); + llvm_ir::IrArray::Index array_index(array_multi_index, shape_, index_type); return {array_index}; } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index 2d9492eacfea34bec3b0f1115e171a5328b7cdc3..6f72ddadf94d4c5b9add2ee66e0f4ac9a8ae9099 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -69,8 +69,13 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); CHECK_GT(num_partitioned_dims, 0); + CHECK_NE(function_ptr, nullptr); + CHECK_NE(partitions, nullptr); const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); + CHECK_NE(run_options, nullptr); + CHECK_NE(run_options->intra_op_thread_pool(), nullptr); + ComputeFunctionType function = reinterpret_cast(function_ptr); // Compute partition stride in 'partitions' array. 
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 722aa3120ef4d8c957873ac58c361f19632dde1f..70a6d0af02c0c2db7208db561cf29e35a74707b2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -15,12 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include -#include #include -#include #include +#include #include -#include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/platform/dynamic_annotations.h" @@ -28,80 +26,15 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace { -using tensorflow::int16; using tensorflow::int32; using tensorflow::int64; -using tensorflow::int8; -using tensorflow::uint16; -using tensorflow::uint32; -using tensorflow::uint64; -using tensorflow::uint8; - -template -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements); -} - -// We would like a total order of floating point numbers so that the -// sort has a predictable behavior in the presence of NaNs. Rather -// than using floating point comparison, we use the following trick: -// If f is a float, and -// x = bit_cast(f); -// y = x < 0 ? 0x7FFFFFFF - x : x; -// then y is ordered as an int32 such that finite values have the -// obvious order, -0 is ordered before 0, and -NaN and NaN appear at -// the beginning and end of the ordering. 
-template -CastType Convert(KeyType value) { - CastType casted_value; - memcpy(&casted_value, &value, sizeof(CastType)); - if (casted_value < 0) { - return static_cast(std::numeric_limits::max()) - - casted_value; - } - return casted_value; -} - -template -bool LessThan(KeyType lhs, KeyType rhs) { - return Convert(lhs) < - Convert(rhs); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, rhs.first); - }); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, rhs.first); - }); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, - int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan( - Eigen::half_impl::half_to_float(lhs.first), - Eigen::half_impl::half_to_float(rhs.first)); - }); -} +} // namespace -template -void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, - int32 values_count, - int32* values_primitive_type_size_in_bytes) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( + int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, int64* prof_counters, + void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)) { // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT // code, so msan can't tell they are initialized. 
TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*)); @@ -121,8 +54,9 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, int64 num_iteration_elements = a * c; int64 sort_dimension_offset = c; - std::unique_ptr[]> row_to_sort( - new std::pair[sort_dimension_elements]); + std::unique_ptr indices(new int64[sort_dimension_elements]); + std::unique_ptr comparison_values(new char*[2 * values_count]); + std::iota(indices.get(), indices.get() + sort_dimension_elements, 0); std::unique_ptr reordered_values( new std::string[sort_dimension_elements]); for (int64 index = 0; index < num_iteration_elements; ++index) { @@ -135,24 +69,33 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, int64 base_offset = index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; - // TODO(b/26783907): We could define a custom iterator class that references - // all arrays. Then we could avoid the intermediate copy. However this - // would become more complicated, and it is not clear if the benefit is high - // enough. - for (int64 i = 0; i < sort_dimension_elements; ++i) { - row_to_sort[i] = - std::make_pair(keys[base_offset + i * sort_dimension_offset], i); - } - KeyValueSort(row_to_sort.get(), sort_dimension_elements); - for (int64 i = 0; i < sort_dimension_elements; ++i) { - keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; + auto compare_function = [&](int64 a, int64 b) -> bool { + int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + for (int32 i = 0; i < values_count; ++i) { + comparison_values[i * 2] = values[i] + memory_index_lhs; + comparison_values[i * 2 + 1] = values[i] + memory_index_rhs; + } + char result = 0; // Overwritten by less_than. 
+ less_than(&result, run_options, comparison_values.get(), nullptr, + prof_counters); + return result != 0u; + }; + if (is_stable) { + std::stable_sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); + } else { + std::sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); } - // Reorder the values according to the order defined by the keys. + // Reorder the values according to the order defined by 'indices'. for (int32 idx = 0; idx < values_count; ++idx) { for (int64 i = 0; i < sort_dimension_elements; ++i) { int64 memory_index = - (base_offset + row_to_sort[i].second * sort_dimension_offset) * + (base_offset + indices[i] * sort_dimension_offset) * values_primitive_type_size_in_bytes[idx]; reordered_values[i] = @@ -168,88 +111,3 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, } } } -} // namespace - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( - bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( - int8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( - uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( - int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - 
KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( - uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( - Eigen::half* keys, int64 a, int64 b, int64 c, char** values, - int32 values_count, int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( - int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( - uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( - float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( - int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( - uint64* keys, int64 a, 
int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( - double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h index 7821099386969e855ea1737cf53ef49c15c6e93b..50c2911c3bd392b6df12717c34d250ce86ad26e0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -21,76 +21,26 @@ limitations under the License. extern "C" { -// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b' -// dimension of 'keys' is sorted into ascending order. If 'values_count' is <= -// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr. -// If 'values_count' > 0, they contain exactly 'values_count' many elements. -// Each element of 'values' also represents a 3-dimensional shape with -// dimensions [a, b, c], and the size of the primitive type of the i-th shape -// has exactly 'values_primitive_type_size_in_bytes[i]' bytes. The elements in -// each 'values' shape are reordered in such a way that if the element at index -// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values' -// shape is also moved to index 'j' (which means that the same elements -// correspond to each other as before). -extern void __xla_cpu_runtime_KeyValueSortPRED( - bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, +// Each entry in 'values' represents a 3-dimensional shape with dimensions +// [a, b, c]. 
The 'b' dimension of each shape is sorted into ascending order +// according to the results of comparisons using the provided 'less_than' +// function. 'values_count' must be > 0 and specifies the number of entries in +// 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive +// type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]' +// bytes. 'is_stable' specifies whether the sorting should be stable. +// 'run_options' and 'prof_counters' are passed through to the less-than +// function, which expects the following arguments: +// - pointer to the return value buffer (char*) +// - xla::ExecutableRunOptions = 'run_options' (char*) +// - pointers to the parameter buffers (char**) +// - pointers to the buffer tables = nullptr for thread local functions (char**) +// - profile counters = 'prof_counters' (int64*) +extern void __xla_cpu_runtime_KeyValueSort( + tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS8( - tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU8( - tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS16( - tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU16( - tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 
values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF16( - Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS32( - tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU32( - tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF32( - float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS64( - tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU64( - tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF64( - double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); + tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, tensorflow::int64* prof_counters, + void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)); } #endif // 
TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 12ab6360c560d31fb70c416b8519006ea8675d41..f5d3d840e23cfe26f269e7569ce5e6f7a10f7db6 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -91,14 +91,14 @@ SimpleOrcJIT::InferTargetMachineForJIT( return target_machine; } -SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, - llvm::CodeGenOpt::Level opt_level, - bool optimize_for_size, bool enable_fast_math, - bool disable_expensive_passes, - LLVMCompiler::ModuleHook pre_optimization_hook, - LLVMCompiler::ModuleHook post_optimization_hook) +SimpleOrcJIT::SimpleOrcJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, + bool enable_fast_math, bool disable_expensive_passes, + LLVMCompiler::ModuleHook pre_optimization_hook, + LLVMCompiler::ModuleHook post_optimization_hook, + std::function post_codegen_hook) : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), - disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), symbol_resolver_(llvm::orc::createLegacyLookupResolver( execution_session_, @@ -128,12 +128,13 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, [this](VModuleKeyT, const llvm::object::ObjectFile& object) { this->NotifyObjectFreed(object); }), - compile_layer_(object_layer_, - CompilerFunctor(target_machine_.get(), &disassembler_, - opt_level, optimize_for_size, - enable_fast_math, disable_expensive_passes, - std::move(pre_optimization_hook), - std::move(post_optimization_hook))), + compile_layer_( + object_layer_, + CompilerFunctor(target_machine_.get(), opt_level, optimize_for_size, + enable_fast_math, disable_expensive_passes, + std::move(pre_optimization_hook), + std::move(post_optimization_hook), + 
std::move(post_codegen_hook))), gdb_jit_event_listener_( llvm::JITEventListener::createGDBRegistrationListener()) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() @@ -240,18 +241,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort); registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); @@ -348,6 +338,11 @@ bool RegisterKnownJITSymbols() { reinterpret_cast(memset_pattern16)); #endif +#ifdef MEMORY_SANITIZER + registry->Register("__msan_unpoison", + reinterpret_cast(__msan_unpoison)); +#endif + return true; } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 3307c2f93d796bbdcd49af7f68e9f6c388e402ca..075a018987d70feccc56bc8cc376791ec66ea0c9 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -29,7 +29,6 @@ limitations under the License. 
#include "llvm/IR/Module.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" -#include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -51,29 +50,20 @@ class SimpleOrcJIT { using VModuleKeyT = llvm::orc::VModuleKey; // Create a new JIT, targeting the host architecture. - // The |target_options| parameter allows customization of certain code - // generation properties of the TargetMachine (whether or not float point math - // can be reassociated, etc.). - // The |opt_level| parameter controls the optimization level of the code - // generator. - // The |optimize_for_size| parameter specifies that the code generator should - // optimize to reduce code size, potentially at the cost of performance. - // The |disable_expensive_passes| parameter will disable certain optimization - // passes - // The |pre_optimization_hook| is invoked on the module before any IR - // level optimizations are applied. - // The |post_optimization_hook| is invoked on the module after all IR - // level optimizations are applied. - SimpleOrcJIT(const llvm::TargetOptions& target_options, - llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, - bool enable_fast_math, bool disable_expensive_passes, - LLVMCompiler::ModuleHook pre_optimization_hook, - LLVMCompiler::ModuleHook post_optimization_hook); - - // Data layout this JIT was created with. + // + // {pre,post}_optimization_hook is invoked on the module before/after all + // LLVM IR-level optimizations. post_codegen_hook is invoked after + // compiling to machine code. 
+ SimpleOrcJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, + bool enable_fast_math, bool disable_expensive_passes, + LLVMCompiler::ModuleHook pre_optimization_hook, + LLVMCompiler::ModuleHook post_optimization_hook, + std::function post_codegen_hook); + const llvm::DataLayout& data_layout() const { return data_layout_; } - // Target triple (host) this JIT was created with. const llvm::Triple& target_triple() const { return target_machine_->getTargetTriple(); } @@ -107,7 +97,6 @@ class SimpleOrcJIT { std::vector module_keys_; std::unique_ptr target_machine_; - const Disassembler disassembler_; const llvm::DataLayout data_layout_; llvm::orc::ExecutionSession execution_session_; std::shared_ptr symbol_resolver_; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 0b4ac9dc29f88a19d967b7f04ffe42879711b54e..8b7f843582b697058fe328fe69990122d868ada4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -16,7 +16,6 @@ limitations under the License. // Tests that we call into Eigen for dot operations as needed. 
#include -#include #include #include "absl/strings/str_cat.h" diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index e30f95311fce229f9c559d3bb40142151e8bf3e3..f4da6856940f5f810d2d724c2f0607e780b06bf2 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -56,8 +56,8 @@ class CpuExternalConstantsTest : public CpuCodegenTest { TEST_F(CpuExternalConstantsTest, Basic) { TestWithArray(/*rows=*/1024, /*cols=*/1024, R"( -CHECK-NOT: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 -CHECK: @0 = private constant [4194304 x i8] {{.*}}, align 16 +CHECK-NOT: @constant_global_0 = external unnamed_addr constant [1024 x [1024 x float]], align 16 +CHECK: @0 = private unnamed_addr constant [4194304 x i8] {{.*}}, align 16 )"); } @@ -65,8 +65,8 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) { // The constant array in this test case is small enough that there is no need // to externalize it. 
TestWithArray(/*rows=*/4, /*cols=*/4, R"( -CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 -CHECK: @0 = private constant [64 x i8] {{.*}}, align 8 +CHECK-NOT: @constant_global_0 = external unnamed_addr constant [16 x float], align 8 +CHECK: @0 = private unnamed_addr constant [64 x i8] {{.*}}, align 8 )"); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 04a81dfd35f459ff1fdb3181dc8fc65c62a37d4f..a72ebe2beea9be59f10e45dc8b296690d47aaa3b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -39,6 +39,13 @@ class CpuFusionTest : public HloTestBase { CpuFusionTest() {} ErrorSpec error_spec_{0.0001, 1e-5}; + + private: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + return debug_options; + } }; TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { @@ -267,12 +274,9 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { EXPECT_EQ(4, fusion1->fused_instruction_count()); EXPECT_EQ(4, fusion2->fused_instruction_count()); - // Each fusion instruction should have one parameter and the parameter should - // be the constant. - EXPECT_EQ(1, fusion1->operand_count()); - EXPECT_EQ(constant, fusion1->operand(0)); - EXPECT_EQ(1, fusion2->operand_count()); - EXPECT_EQ(constant, fusion2->operand(0)); + // The fusion has no parameters, everything is fused including constants. + EXPECT_EQ(0, fusion1->operand_count()); + EXPECT_EQ(0, fusion2->operand_count()); } TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { @@ -324,10 +328,9 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { auto fusion_inst = tuple->operand(0); // There should be three fused instructions: negate2, exp2, and the fused - // parameter. + // constant. 
EXPECT_EQ(3, fusion_inst->fused_instruction_count()); - EXPECT_EQ(1, fusion_inst->operand_count()); - EXPECT_EQ(constant, fusion_inst->operand(0)); + EXPECT_EQ(0, fusion_inst->operand_count()); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 3fb0e3cd91fd2088884a0792f882fd7de72f0584..9078b8fd1ff6cb0ddac89d5fcd13a9ccfae07763 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" @@ -59,8 +59,9 @@ class CpuUnaryIntrinsicTest string features{spec.features.data(), spec.features.size()}; if (!features.empty()) { - std::replace_if(features.begin(), features.end(), - [](char c) { return c != '_' && !isalnum(c); }, '_'); + std::replace_if( + features.begin(), features.end(), + [](char c) { return c != '_' && !absl::ascii_isalnum(c); }, '_'); } else { features = ""; } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc index 3934c03a04c978009282b3cd0d39bacf9b12a356..e07ac9edc89de85f36dfdbbaa29886bc44b4c4a9 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -26,10 +26,16 @@ TEST_F(CpuKeyValueSortTest, SortR1) { const string hlo_text = R"( HloModule KeyValueSort +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY main { a = 
f32[10] parameter(0) - ROOT result = f32[10] sort(f32[10] a), dimensions={0} + ROOT result = f32[10] sort(f32[10] a), dimensions={0}, to_apply=compare } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 0584c0484f810a03ccccd522163f54535440ef8b..fc670201125c1c1a9182ddd9667b8d2bb76b7a03 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -56,8 +56,8 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [48 x i8] -CHECK-NOT: private constant [48 x i8] +CHECK: private unnamed_addr constant [48 x i8] +CHECK-NOT: private unnamed_addr constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -102,10 +102,10 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK-DAG: private constant [4 x i8] -CHECK-DAG: private constant [8 x i8] -CHECK-NOT: private constant [4 x i8] -CHECK-NOT: private constant [8 x i8] +CHECK-DAG: private unnamed_addr constant [4 x i8] +CHECK-DAG: private unnamed_addr constant [8 x i8] +CHECK-NOT: private unnamed_addr constant [4 x i8] +CHECK-NOT: private unnamed_addr constant [8 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index a7702c2aeeaff8a46a2c4f2785ccb873ea2c08e5..030bd41c2fc73eac41fe43c1acdf862d5dc97f98 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -75,8 +75,9 @@ TEST_F(CpuNoAliasTest, Concat) { // the buffers in the HLO module. We'll inspect these loads to ensure that // they have the expected alias information. 
llvm::Module ir_module("test", context); - llvm::Function* func = llvm::cast( - ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context))); + llvm::Function* func = llvm::dyn_cast( + ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context)) + .getCallee()); llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func); llvm::IRBuilder<> b(bb); auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index aab7f0b393881642437f1891256bd138823a3b87..76727c481bc394b24581e46afdb157ba6041a019 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -38,7 +38,7 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [48 x i8] +CHECK: private unnamed_addr constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc index eb6c44b70ab34d0a294880b5de4fe0b3ba5e19e5..7668f364bad050a60f74db4e9054b4f9c6df04d2 100644 --- a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc @@ -938,24 +938,76 @@ void TiledSmallGemmEmitter::EmitTiledGemm( }); } +llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) { + llvm::Type* type = + llvm::cast(pointer_type)->getElementType(); + while (auto* array_type = llvm::dyn_cast(type)) { + type = array_type->getElementType(); + } + + return type->getPointerTo(); +} + +struct GemvBuffersWithCanonicalType { + llvm::Value* lhs_canonicalized; + llvm::Value* rhs_canonicalized; + llvm::Value* addend_canonicalized; + llvm::Value* result_canonicalized; +}; + +GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType( + llvm::Value* lhs, llvm::Value* 
rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b) { + // We characterize a GEMV operation via M and K, since N is implicitly 1. + // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented + // by the same GEMV that multiplies [5,6] with [1,6]. However, the + // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial + // sense -- the in memory representations are the same) since they're computed + // from the `xla::Shape`s. Since we want to be able to call the same + // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV + // inputs here into the same type. + GemvBuffersWithCanonicalType buffers_with_canonical_type; + llvm::Type* lhs_type = lhs->getType(); + llvm::Type* rhs_type = rhs->getType(); + llvm::Type* addend_type = addend ? addend->getType() : nullptr; + llvm::Type* result_type = result->getType(); + + buffers_with_canonical_type.lhs_canonicalized = + b->CreateBitCast(lhs, GetPointerToElementType(lhs_type)); + buffers_with_canonical_type.rhs_canonicalized = + b->CreateBitCast(rhs, GetPointerToElementType(rhs_type)); + buffers_with_canonical_type.addend_canonicalized = + addend ? 
b->CreateBitCast(addend, GetPointerToElementType(addend_type)) + : nullptr; + buffers_with_canonical_type.result_canonicalized = + b->CreateBitCast(result, GetPointerToElementType(result_type)); + + return buffers_with_canonical_type; +} + } // namespace void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* b, - bool enable_fast_math, bool optimize_for_size) { + const HloModuleConfig& module_config) { RowMajorMatrixVectorProductEmitter::Config config( /*scalar_type=*/scalar_type, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + GemvBuffersWithCanonicalType canonical_inputs = + GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); + KernelSupportLibrary::EmitAndCallOutlinedKernel( - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, - rhs, addend, result, - [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result) { + module_config, b, config.GetCacheKey(), + canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, + canonical_inputs.addend_canonicalized, + canonical_inputs.result_canonicalized, + [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, + llvm::Value* result) { RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, result, b); emitter.Emit(); @@ -966,18 +1018,23 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* b, - bool enable_fast_math, bool optimize_for_size) { + const HloModuleConfig& module_config) { ColumnMajorMatrixVectorProductEmitter::Config config( /*scalar_type=*/scalar_type, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, 
/*k=*/k, /*has_addend=*/addend != nullptr); + GemvBuffersWithCanonicalType canonical_inputs = + GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); + KernelSupportLibrary::EmitAndCallOutlinedKernel( - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, - rhs, addend, result, - [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result) { + module_config, b, config.GetCacheKey(), + canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, + canonical_inputs.addend_canonicalized, + canonical_inputs.result_canonicalized, + [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, + llvm::Value* result) { ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, result, b); emitter.Emit(); @@ -989,7 +1046,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n, int64 min_vectorization_width, int64 tile_size_m, int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, llvm::IRBuilder<>* b, - bool enable_fast_math, bool optimize_for_size) { + const HloModuleConfig& module_config) { TiledSmallGemmEmitter::Config config( /*scalar_type=*/scalar_type, TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, @@ -999,9 +1056,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n, /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); KernelSupportLibrary::EmitAndCallOutlinedKernel( - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, - rhs, result, + module_config, b, config.GetCacheKey(), lhs, rhs, result, [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) { TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, /*rhs=*/rhs, diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h index 
0a82326cc3704bce8c122261383249c60eda1f3a..77581a53cfb9bc6330a38f0029486a708d837d4f 100644 --- a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ #include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -29,15 +30,15 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, tensorflow::int64 tile_cols, tensorflow::int64 m, tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, - llvm::IRBuilder<>* b, bool enable_fast_math, - bool optimize_for_size); + llvm::IRBuilder<>* b, + const HloModuleConfig& module_config); void EmitColumnMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, tensorflow::int64 tile_cols, tensorflow::int64 m, tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* b, - bool enable_fast_math, bool optimize_for_size); + const HloModuleConfig& module_config); void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m, tensorflow::int64 k, tensorflow::int64 n, @@ -46,8 +47,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m, tensorflow::int64 min_vectorization_width, tensorflow::int64 tile_size_m, tensorflow::int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* b, bool enable_fast_math, - bool optimize_for_size); + llvm::IRBuilder<>* b, const HloModuleConfig& module_config); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 
5690d2be2fe3e21c96b51a5226e0b29148217fd1..c444fd7d4aa88fa21b1aa2b2f058bd689b234b15 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -114,6 +114,9 @@ class VectorSupportLibrary { // raison d'etre) less cluttered. llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs)); + } llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) { diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index ed37099a5428075928ec98b134632867d58bbfe7..490e057fcbcae66e90873fd0009fc80af431d901 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/defuser.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" -#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" namespace xla { @@ -49,7 +48,6 @@ Despecializer::Despecializer() : pipeline_("despecializer") { pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); - pipeline_.AddPass(); pipeline_.AddPass(); } diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index 46dcc3a438cbdf3ff1b3c99fa15b35ee7a4e280e..b6afaa17aa24608189bd29bd0371bc95709a5aaf 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -28,8 +28,8 @@ namespace xla { // optimized for one specific platform on a different platform (undoing platform // specific passes) with matching numerics for comparison. // -// Current despecialization passes are Defuser, ImplicitBroadcastRemover, -// and BFloat16MixedPrecisionRemoval. +// Current despecialization passes are HloDescheduler, ControlDepRemover, +// Defuser and BFloat16MixedPrecisionRemoval. 
class Despecializer : public HloModulePass { public: Despecializer(); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 2132468b9067ad4d5644d6cf3908a488a20ced05..246f2af09b5539612ef0e75929833f532dfa4083 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -103,11 +103,20 @@ class DfsHloVisitorBase { virtual Status HandlePower(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } + virtual Status HandleSqrt(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleRsqrt(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; + virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; + virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 680dd256bb15bd3a9eaff7241174c1d2833002c6..79ce3f82e8c1fe91d590ea7c47fa219ce8e8a80f 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -91,6 +91,12 @@ class DfsHloVisitorWithDefaultBase Status HandleFft(HloInstructionPtr fft) override { return DefaultAction(fft); } + Status HandleTriangularSolve(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status 
HandleCholesky(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleAllReduce(HloInstructionPtr crs) override { return DefaultAction(crs); } @@ -100,6 +106,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCollectivePermute(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleReplicaId(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc index 825e1436f0ec6d49b555e5e3e9c2c7a19fb7b062..bd638917ccf4398d478d8b465a2029c1f6e3cc02 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -73,15 +73,14 @@ ENTRY TestComputation { abs = f32[] abs(arg) add = f32[] add(arg, gte) broadcast = f32[42] broadcast(add), dimensions={} - slice = f32[0] slice(broadcast), slice={[1:2]} + slice = f32[1] slice(broadcast), slice={[1:2]} copy = f32[] copy(arg) - eq = pred[] equal-to(arg, gte) + eq = pred[] compare(arg, gte), direction=EQ neg = f32[] negate(arg) ROOT convert = f64[] convert(f32[] arg) })"; std::unique_ptr module = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()) - .ConsumeValueOrDie(); + ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie(); ElementwiseTestVisitor visitor; TF_EXPECT_OK(module->entry_computation()->Accept(&visitor)); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 855424067d26d4968270e5f24b11f5a053b70a55..559b9c1f2c9f341293ca89adc61e3312fd9f313c 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -297,7 +297,12 @@ StatusOr DotDecomposer::Run(HloModule* module) { const 
DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); // A dot it not canonical if there are more than one contracting // dimension. - if (dnums.lhs_contracting_dimensions_size() > 1) { + if (dnums.lhs_contracting_dimensions_size() != 1) { + non_canonical_dots.push_back(instruction); + continue; + } + if (dnums.lhs_batch_dimensions().empty() && + dnums.lhs_contracting_dimensions().empty()) { non_canonical_dots.push_back(instruction); continue; } diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc new file mode 100644 index 0000000000000000000000000000000000000000..06d045650297a1efa52742e3a06066376235de5e --- /dev/null +++ b/tensorflow/compiler/xla/service/dump.cc @@ -0,0 +1,407 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dump.h" +#include "absl/strings/ascii.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { + +namespace { + +using absl::StrCat; +using absl::StrFormat; +using absl::string_view; + +struct CanonicalDebugOptions { + explicit CanonicalDebugOptions(const DebugOptions& opts) + : dump_to(opts.xla_dump_to()), + dump_as_text(opts.xla_dump_hlo_as_text()), + dump_as_proto(opts.xla_dump_hlo_as_proto()), + dump_as_dot(opts.xla_dump_hlo_as_dot()), + dump_as_html(opts.xla_dump_hlo_as_html()), + dump_as_url(opts.xla_dump_hlo_as_url()), + dump_snapshots(opts.xla_dump_hlo_snapshots()) { + // This constructor examines the values in `opts` and turns on other flags + // based on what we think is the user's intent. To reduce confusion about + // what was a user-specified value versus an extrapolated value, within this + // function we treat this struct's members as write-only, and read only from + // `opts`. + + // If dump_to is empty, default to dumping to stdout. + if (opts.xla_dump_to().empty()) { + dump_to = "-"; + } + + // Did the user specify an explicit format for dumping? + bool output_format_specified = + opts.xla_dump_hlo_as_text() || opts.xla_dump_hlo_as_proto() || + opts.xla_dump_hlo_as_dot() || opts.xla_dump_hlo_as_html() || + opts.xla_dump_hlo_as_url() || opts.xla_dump_hlo_snapshots(); + + // If we haven't specified an output format, default to dumping as text. 
+ if (!output_format_specified) { + dump_as_text = true; + } + + // If we specified a regular expression restricting which modules to dump, + // respect that. + // + // If we didn't specify which modules to dump but we passed some other flag + // which implies dumping modules, dump all modules. + // + // Otherwise, don't dump any HLO modules. + if (!opts.xla_dump_hlo_module_re().empty()) { + // RE2 object is not copyable, and we can't capture "by move", so we + // resort to this hack. + string pattern = opts.xla_dump_hlo_module_re(); + should_dump_module = [pattern](string_view module_name) { + return RE2::PartialMatch(string(module_name), pattern); + }; + } else if (!opts.xla_dump_hlo_pass_re().empty() || + !opts.xla_dump_to().empty() || output_format_specified) { + should_dump_module = [](string_view) { return true; }; + } else { + should_dump_module = [](string_view) { return false; }; + } + + // Initialize should_dump_pass. This one is easy: We only dump per-pass + // data if the user asked for it explicitly. + if (!opts.xla_dump_hlo_pass_re().empty()) { + string pattern = opts.xla_dump_hlo_pass_re(); + should_dump_pass = [pattern](string_view pass_name) { + return RE2::PartialMatch(string(pass_name), pattern); + }; + } else { + should_dump_pass = [](string_view) { return false; }; + } + + // Output dirs "sponge" and "test_undeclared_outputs_dir" (case-insensitive) + // have a special meaning: Dump into the directory specified by the + // environment variable TEST_UNDECLARED_OUTPUTS_DIR. 
+ string dump_to_lower = absl::AsciiStrToLower(opts.xla_dump_to()); + if (dump_to_lower == "sponge" || + dump_to_lower == "test_undeclared_outputs_dir") { + const char* dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); + if (dir != nullptr) { + dump_to = dir; + } else { + LOG(ERROR) << "--xla_dump_to=" << opts.xla_dump_to() + << ", but environment variable TEST_UNDECLARED_OUTPUTS_DIR " + "is not set, so cannot dump anywhere."; + should_dump_module = [](string_view) { return false; }; + should_dump_pass = [](string_view) { return false; }; + } + } + } + + bool dumping_to_stdout() const { return dump_to == "-"; } + + string dump_to; + std::function should_dump_module; + std::function should_dump_pass; + + // dump_ir isn't present here because this file is mostly concerned with + // dumping HLO. + bool dump_as_text; + bool dump_as_proto; + bool dump_as_dot; + bool dump_as_html; + bool dump_as_url; + bool dump_snapshots; +}; + +string FilenameFor(const HloModule& module, string_view suffix) { + return StrFormat("module_%04d.%s", module.unique_id(), suffix); +} + +void DumpToFileInDirImpl(string_view filename, string_view contents, + const CanonicalDebugOptions& opts) { + if (opts.dumping_to_stdout()) { + LOG(ERROR) << "Refusing to write " << filename + << " to stdout. Pass --xla_dump_to= to write to a file."; + return; + } + + const string& dir = opts.dump_to; + VLOG(1) << "Dumping " << filename << " to " << dir; + + tensorflow::Env* env = tensorflow::Env::Default(); + // Two threads can race to observe the absence of the dump directory and + // simultaneously try to create it, causing the "losing" thread to get a + // "directory already exists" error. We can work around this by checking + // again whether the dir exists. 
+ if (!env->IsDirectory(dir).ok()) { + auto status = env->RecursivelyCreateDir(dir); + if (!status.ok() && !env->IsDirectory(dir).ok()) { + LOG(ERROR) << "Could not create directory " << dir + << " for dumping XLA debug data: " << status; + return; + } + } + + string file_path = + tensorflow::io::JoinPath(dir, SanitizeFileName(string(filename))); + auto status = tensorflow::WriteStringToFile(env, file_path, contents); + if (!status.ok()) { + LOG(ERROR) << "Could not write XLA debug data to " << file_path << ": " + << status; + } +} + +void DumpToFileInDirOrStdoutImpl(string_view filename, string_view contents, + const CanonicalDebugOptions& opts) { + // Dump to stdout if that's called for. + if (opts.dumping_to_stdout()) { + std::cout << "*** Begin " << filename << " ***\n" + << contents << "\n*** End " << filename << " ***" << std::endl; + return; + } + + // Otherwise, dump to a file. + DumpToFileInDirImpl(filename, contents, opts); +} + +void DumpHloModuleImpl(const HloModule& module, + const BufferAssignment* buffer_assn, + const HloExecutionProfile* profile, string_view suffix, + const CanonicalDebugOptions& opts) { + string filename = FilenameFor(module, suffix); + + if (opts.dump_as_text) { + DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"), module.ToString(), + opts); + } + + if (opts.dump_as_proto) { + HloProto module_proto = + buffer_assn ? 
MakeHloProto(module, *buffer_assn) : MakeHloProto(module); + string pb; + if (!tensorflow::SerializeToStringDeterministic(module_proto, &pb)) { + pb = "Failed to serialize HLO module proto."; + } + DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts); + } + + auto render_graph = [&](RenderedGraphFormat format) { + StatusOr rendered_graph = RenderGraph( + *module.entry_computation(), + /*label=*/filename, module.config().debug_options(), format, profile); + if (rendered_graph.ok()) { + return std::move(rendered_graph).ValueOrDie(); + } + return StrFormat("Error rendering graph: %s", + rendered_graph.status().ToString()); + }; + + if (opts.dump_as_dot) { + DumpToFileInDirImpl(StrFormat("%s.dot", filename), + render_graph(RenderedGraphFormat::kDot), opts); + } + + if (opts.dump_as_html) { + DumpToFileInDirImpl(StrFormat("%s.html", filename), + render_graph(RenderedGraphFormat::kHtml), opts); + } + + // Special case for rendering graphs as URLs. We'll dump them to a file + // because why not, but we always log them to stdout as well. + if (opts.dump_as_url) { + string url = render_graph(RenderedGraphFormat::kUrl); + std::cout << filename << " --> " << url << std::endl; + if (!opts.dumping_to_stdout()) { + DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts); + } + } +} + +static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + +// Maps a module's unique ID to a counter indicating how many times we've dumped +// this module during the compilation pipeline. This lets us keep the filenames +// ordered nicely. +// +// Entries added here leak forever; we have no way to GC them when a module +// dies. But we only add an entry if dumping is enabled for this module, and +// dumping a module leaks buffer space in stdout or bytes on disk *way* faster +// than this hashtable leaks memory. 
+static auto& module_id_to_step_number GUARDED_BY(mu) = + *new absl::flat_hash_map(); + +} // namespace + +void DumpToFileInDir(const HloModule& module, string_view suffix, + string_view contents) { + DumpToFileInDirImpl(FilenameFor(module, suffix), contents, + CanonicalDebugOptions(module.config().debug_options())); +} + +void DumpToFileInDirOrStdout(const HloModule& module, string_view suffix, + string_view contents) { + DumpToFileInDirOrStdoutImpl( + FilenameFor(module, suffix), contents, + CanonicalDebugOptions(module.config().debug_options())); +} + +void DumpHloModuleIfEnabled(const HloModule& module, string_view name) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (opts.should_dump_module(module.name())) { + DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr, + name, opts); + } +} +void DumpHloModuleIfEnabled(const HloModule& module, + const BufferAssignment& buffer_assn, + string_view name) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (opts.should_dump_module(module.name())) { + DumpHloModuleImpl(module, &buffer_assn, /*profile=*/nullptr, name, opts); + } +} + +void DumpHloModuleIfEnabled(const HloModule& module, + const HloExecutionProfile& profile, + string_view name) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (opts.should_dump_module(module.name())) { + DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, &profile, name, opts); + } +} + +bool DumpingEnabledForHloModule(string_view hlo_module_name, + const DebugOptions& opts) { + return CanonicalDebugOptions(opts).should_dump_module(hlo_module_name); +} + +bool DumpingToStdout(const DebugOptions& opts) { + return CanonicalDebugOptions(opts).dumping_to_stdout(); +} + +void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name, + string_view before_pass_name, + string_view after_pass_name, + const HloModule& module) { + CanonicalDebugOptions opts(module.config().debug_options()); + if 
(!opts.should_dump_module(module.name())) { + return; + } + + if (!opts.should_dump_pass(before_pass_name) && + !opts.should_dump_pass(after_pass_name)) { + return; + } + + int64 step_number; + { + tensorflow::mutex_lock lock(mu); + step_number = module_id_to_step_number[module.unique_id()]++; + } + + string filename_suffix = + StrFormat("%04d.%s.after_%s.before_%s", step_number, pipeline_name, + after_pass_name, before_pass_name); + DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr, + filename_suffix, opts); +} + +void DumpHloModuleDuringPassIfEnabled(string_view pass_name, + string_view step_name, + const HloModule& module) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (!opts.should_dump_module(module.name()) || + !opts.should_dump_pass(pass_name)) { + return; + } + + int64 step_number; + { + tensorflow::mutex_lock lock(mu); + step_number = module_id_to_step_number[module.unique_id()]++; + } + + string filename_suffix = + StrFormat("%04d.%s.%s", step_number, pass_name, step_name); + DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr, + filename_suffix, opts); +} + +void DumpHloSnapshotIfEnabled(const HloModule& module, + const HloSnapshot& snapshot) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (!opts.should_dump_module(module.name()) || !opts.dump_snapshots) { + return; + } + int64 execution_count; + { + static auto& module_id_to_execution_count GUARDED_BY(mu) = + *new absl::flat_hash_map(); + tensorflow::mutex_lock lock(mu); + execution_count = module_id_to_execution_count[module.unique_id()]++; + } + string filename = + StrCat(FilenameFor(module, StrFormat("execution_%04d", execution_count)), + ".hlo_snapshot.pb"); + if (opts.dumping_to_stdout()) { + LOG(ERROR) << "Refusing to write HLO snapshot proto for " << filename + << " to stdout. 
Pass --xla_dump_to= to write to a file."; + return; + } + string pb; + if (!tensorflow::SerializeToStringDeterministic(snapshot, &pb)) { + LOG(ERROR) << "Failed to serialize HLO snapshot proto " << filename; + } + DumpToFileInDirImpl(filename, pb, opts); +} + +void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot, + const DebugOptions& opts) { + CanonicalDebugOptions canonical_opts(opts); + string name = snapshot.hlo().hlo_module().name(); + if (!canonical_opts.should_dump_module(name) || + !canonical_opts.dump_snapshots) { + return; + } + + // We don't have a unique id for an HloSnapshot, so in this overload we just + // have to use its name. + int64 execution_count; + { + static auto& module_name_to_execution_count GUARDED_BY(mu) = + *new absl::flat_hash_map(); + tensorflow::mutex_lock lock(mu); + execution_count = module_name_to_execution_count[name]++; + } + string filename = StrFormat("module_%s.execution_%04d.hlo_snapshot.pb", name, + execution_count); + if (canonical_opts.dumping_to_stdout()) { + LOG(ERROR) << "Refusing to write HLO snapshot proto for " << filename + << " to stdout. Pass --xla_dump_to= to write to a file."; + return; + } + string pb; + if (!tensorflow::SerializeToStringDeterministic(snapshot, &pb)) { + LOG(ERROR) << "Failed to serialize HLO snapshot proto " << filename; + } + DumpToFileInDirImpl(filename, pb, canonical_opts); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dump.h b/tensorflow/compiler/xla/service/dump.h new file mode 100644 index 0000000000000000000000000000000000000000..6edc9b28ddeb73a453bca4b96f1e34e3538c1e50 --- /dev/null +++ b/tensorflow/compiler/xla/service/dump.h @@ -0,0 +1,109 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DUMP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DUMP_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/xla.pb.h" + +// Consolidated utilities for logging information during compilation, usually +// based on the options specified in the DebugOptions proto. +// +// Most functions here take an HloModule and read the DebugOptions from the +// module's config. + +namespace xla { + +class BufferAssignment; +class HloExecutionProfile; +class HloSnapshot; + +// Writes the given string to a file in the xla_dump_to directory specified by +// module's DebugOptions. +// +// If module doesn't have an xla_dump_to directory, does nothing. +void DumpToFileInDir(const HloModule& module, absl::string_view file_suffix, + absl::string_view contents); + +// Like DumpToFileInDir, except if module doesn't have an xla_dump_to directory +// specified, or if that directory is equal to "-", writes to stdout instead. +void DumpToFileInDirOrStdout(const HloModule& module, + absl::string_view file_suffix, + absl::string_view contents); + +// Dumps the given HLO module if dumping is enabled for the module. Exactly +// where and in what formats it's dumped is determined by the module's config. +// +// If you pass an HloExecutionProfile, note that currently only DOT-based output +// formats (i.e. 
--xla_dump_hlo_as_{dot,html,url}) are able to incorporate it into +// their output. Other formats will just ignore the profile. +void DumpHloModuleIfEnabled(const HloModule& module, absl::string_view name); +void DumpHloModuleIfEnabled(const HloModule& module, + const BufferAssignment& buffer_assn, + absl::string_view name); +void DumpHloModuleIfEnabled(const HloModule& module, + const HloExecutionProfile& profile, + absl::string_view name); + +// Dumps the given HLO module after running one HLO pass and before running +// another, if that's enabled. +void DumpHloModuleBetweenPassesIfEnabled(absl::string_view pipeline_name, + absl::string_view before_pass_name, + absl::string_view after_pass_name, + const HloModule& module); + +// Dumps the given HLO module during the given HLO pass, if that's enabled. +// +// "step" is a human-readable description of where we are in the middle of this +// pass. For example, "before-assigning-layouts". +void DumpHloModuleDuringPassIfEnabled(absl::string_view pass_name, + absl::string_view step, + const HloModule& module); + +// Dumps the given HloSnapshot to the module's xla_dump_to directory, if this is enabled. +// +// Prefer the first overload below, as this will give filenames that are +// consistent with the other methods here. The second overload (which doesn't +// take an HloModule) is useful in the cases when you're dumping an HloSnapshot +// and simply don't have an HloModule. +void DumpHloSnapshotIfEnabled(const HloModule& module, + const HloSnapshot& snapshot); +void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot, + const DebugOptions& opts); + +// Returns true if we should dump data for an HloModule. This is useful if you +// want to check if DumpToFileInDir{,OrStdout} will do anything before +// generating an expensive string. 
+bool DumpingEnabledForHloModule(absl::string_view hlo_module_name, + const DebugOptions& opts); +inline bool DumpingEnabledForHloModule(const HloModule& module) { + return DumpingEnabledForHloModule(module.name(), + module.config().debug_options()); +} + +// Returns true if DumpToFileInDirOrStdout and DumpHloModuleIfEnabled will write +// to stdout, rather than to a file on disk. +// +// This is useful if you want to do something different when writing to stdout. +// For example, maybe you have (almost-)duplicate data that you wouldn't mind +// writing to two files, but you don't want to print twice. +bool DumpingToStdout(const DebugOptions& opts); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DUMP_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 2b158d7a6ec510ce4cbc56bddc5cca71ac4f14f4..5f7d8a761252c03d43acb4869da6b1f59c823357 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -18,19 +18,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/window_util.h" namespace xla { -namespace { -bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { - return window_dimension.size() == 1 && window_dimension.stride() == 1 && - window_dimension.padding_low() == 0 && - window_dimension.padding_high() == 0 && - window_dimension.window_dilation() == 1 && - window_dimension.base_dilation() == 1; -} -} // namespace - class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { public: explicit DynamicDimensionInferenceVisitor( @@ -53,6 +45,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleDot(HloInstruction* hlo) override; + Status HandleTuple(HloInstruction* hlo) override; + Status HandleTranspose(HloInstruction* hlo) override; Status HandleReshape(HloInstruction* hlo) override; @@ -77,6 +71,10 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleElementwiseBinary(HloInstruction* hlo) override; + Status HandleWhile(HloInstruction* hlo) override; + + Status HandleSlice(HloInstruction* hlo) override; + private: using OperandDynamicDimensionFn = std::functionSetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + }); +} + Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, int64 operand_index, HloInstruction* dynamic_size) { int64 broadcast_dim = hlo->dimensions(dimension); - parent_->SetDynamicSize(hlo, index, broadcast_dim, dynamic_size); + parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size); return Status::OK(); }); } @@ -349,7 +357,7 @@ Status 
DynamicDimensionInferenceVisitor::HandleReduceWindow( const WindowDimension& window_dimension = reduce_window->window().dimensions(dimension); - if (!IsTrivialWindowDimension(window_dimension)) { + if (!window_util::IsTrivialWindowDimension(window_dimension)) { return Unimplemented( "Dynamic Spatial reduce window is not supported: %s", reduce_window->ToString()); @@ -370,7 +378,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( const WindowDimension& window_dimension = select_and_scatter->window().dimensions(dimension); - if (!IsTrivialWindowDimension(window_dimension)) { + if (!window_util::IsTrivialWindowDimension(window_dimension)) { return Unimplemented( "Dynamic Spatial select and scatter is not supported: %s", select_and_scatter->ToString()); @@ -383,6 +391,140 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( }); } +Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64 dimension, + int64 /*operand_index*/, HloInstruction* dynamic_size) { + if (hlo->slice_starts(dimension) != 0 || + hlo->slice_strides(dimension) != 1 || + hlo->slice_limits(dimension) != + operand->shape().dimensions(dimension)) { + return Unimplemented( + "Dynamic dimension propagation on Slice where it doesn't slice " + "out an entire dimension is not supported %s", + hlo->ToString()); + } + + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { + // While loop is handled by passing dynamic size hlos as parameters into the + // hlo while loop. This is done by replacing the original while with a new + // one. + // + // Before: + // + // op1 = ... + // op2 = ... + // op1_x = ... // dynamic dimension size of op1 + // while = while(op1, op2) + // + // + // After: + // + // op1 = ... + // op2 = ... + // op1_x = ... 
// dynamic dimension size of op1 + // while = while(op1, op2, op1_x) + // + // In the above graph, op_x is the bound of the dynamic dimension size of op1 + // and is wired into the while loop as new parameter. + // + // TODO(b/119843103): Once we implement dynamic bounds in XLA backend, dynamic + // bound can be propagated through native xla values instead of relying on + // additional parameter. + + // dynamic_size_to_operand_id_index_map keeps track of dynamic size operations + // to their operand ids in the new while loop. + absl::flat_hash_map + dynamic_size_to_operand_id_index_map; + + // operands_to_add collects dynamic sizes that need to be added to the while + // loop as parameters. Note that a dynamic size is ignored if it is already + // part of the parameter. i.e.: + // + // We don't do: + // + // op1 = ... + // op2 = ... + // op_x = ... // dynamic dimension size of both op1 and op2 + // while = while(op1, op2, op_x, op_x) // 4 parameters + // + // But we do: + // + // op1 = ... + // op2 = ... + // op_x = ... // dynamic dimension size of both op1 and op2 + // while = while(op1, op2, op_x) + // + // An alternative is to do this in a while loop CSE pass. 
+ // + std::vector operands_to_add; + int64 operand_count = hlo->shape().tuple_shapes_size(); + TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + hlo, [&](HloInstruction*, ShapeIndex, int64, int64, + HloInstruction* dynamic_size) { + const HloInstruction* tuple_operand = hlo->operand(0); + for (int64 i = 0; i < tuple_operand->operand_count(); ++i) { + if (dynamic_size == tuple_operand->operand(i)) { + dynamic_size_to_operand_id_index_map[dynamic_size] = i; + return Status::OK(); + } + } + auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size); + if (iter == dynamic_size_to_operand_id_index_map.end()) { + operands_to_add.push_back(dynamic_size); + dynamic_size_to_operand_id_index_map[dynamic_size] = operand_count++; + } + return Status::OK(); + })); + + if (!operands_to_add.empty()) { + // Only replace the while loop if there are new parameters to add. + HloInstruction* old_tuple_operand = hlo->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + WhileUtil::MakeInstructionsLiveInResult result, + WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add)); + // WhileUtil creates a new while hlo and tuple. Update the dynamic size + // mapping for the newly created tuple. + HloInstruction* new_tuple_operand = + result.new_while_instr->mutable_operand(0); + parent_->CopyMapping(/*from=*/old_tuple_operand, /*to=*/new_tuple_operand); + hlo = result.new_while_instr; + } + + // We have replaced the while loop, now set the dynamic dimensions for the + // newly created while loop so that the hlos that consumes the while loop can + // see the dynamic dimensions. Also sets the dynamic parameter binding for + // running inference in the while loop. 
+ DynamicParameterBinding binding_for_while; + TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + DynamicParameterBinding::DynamicParameter dynamic_parameter{ + operand_index, + {dynamic_size_to_operand_id_index_map[dynamic_size]}}; + DynamicParameterBinding::DynamicDimension dynamic_dimension{ + operand_index, index, dimension}; + TF_RETURN_IF_ERROR( + binding_for_while.Bind(dynamic_parameter, dynamic_dimension)); + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + })); + + // Run inference in while body and condition. + TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( + hlo->while_body(), binding_for_while, parent_)); + TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( + hlo->while_condition(), binding_for_while, parent_)); + + return Status::OK(); +} + Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) { return param_bindings_.ForEachBinding( [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter, @@ -430,6 +572,20 @@ Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( return Status::OK(); } +void DynamicDimensionInference::CopyMapping(HloInstruction* from, + HloInstruction* to) { + auto iter = per_hlo_dynamic_dimensions_.find(from); + if (iter != per_hlo_dynamic_dimensions_.end()) { + for (auto& dynamic_dimension : iter->second) { + HloInstruction* dynamic_size = + GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim); + SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim, + dynamic_size); + } + } +} + /* static */ StatusOr DynamicDimensionInference::Run( HloModule* module) { @@ -439,6 +595,20 @@ StatusOr DynamicDimensionInference::Run( return inference; } +string DynamicDimensionInference::ToString() const { + std::vector pieces; + pieces.push_back("DynamicDimensionInference: 
"); + for (const auto& mapping : dynamic_mapping_) { + const DynamicDimension& dynamic_dimension = mapping.first; + pieces.push_back(absl::StrFormat( + " -- instruction %s at %s has dim %lld as dynamic" + " dimension, which is represented by instruction %s", + dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(), + dynamic_dimension.dim, mapping.second->ToString())); + } + return absl::StrJoin(pieces, "\n"); +} + DynamicDimensionInference::DynamicDimensionInference(HloModule* module) : module_(module) {} diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index 164d15bf111a92e3da957f609b54ee0662ef18b1..d0f2998328f3028ccbd5b33690a514371a03b5a1 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -88,6 +88,11 @@ class DynamicDimensionInference { iter.first->second.emplace(DynamicDimension{inst, index, dim}); } + // Copies the internal mapping from instruction `from` to instruction `to`. + // This is useful when an instruction is replaced by the other during the + // inferencing process. + void CopyMapping(HloInstruction* from, HloInstruction* to); + // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in // module_. Status AnalyzeDynamicDimensions(); @@ -101,6 +106,8 @@ class DynamicDimensionInference { using DynamicMapping = absl::flat_hash_map; DynamicMapping dynamic_mapping_; + // A convenient mapping from an hlo to the set of dynamic dimensions that it + // holds. 
using PerHloDynamicDimensions = absl::flat_hash_map>; diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index b42e67b4bbcf731d89dd8af9e46b405235a92d8a..36456e5649ad53b3996d41d8659feaf6205b48b3 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -43,7 +43,6 @@ class DynamicDimensionInferenceTest : public HloTestBase { } Status RunInference() { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, DynamicDimensionInference::Run(module_.get())); @@ -68,8 +67,8 @@ class DynamicDimensionInferenceTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {}), "lhs")); auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( 1, ShapeUtil::MakeShape(F32, {}), "rhs")); - embedded_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGe, lhs, rhs)); + embedded_builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe)); return module_->AddEmbeddedComputation(embedded_builder.Build()); } @@ -88,6 +87,8 @@ TEST_F(DynamicDimensionInferenceTest, ParamTest) { HloInstruction::CreateParameter(1, scalar_shape_, "param")); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + // Set up dynamic parameter binding. 
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( DynamicParameterBinding::DynamicParameter{1, {}}, @@ -112,6 +113,7 @@ TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) { DynamicParameterBinding::DynamicParameter{0, {1}}, DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), op::GetTupleElement(param, 1)); @@ -137,6 +139,7 @@ TEST_F(DynamicDimensionInferenceTest, GetTupleElement) { DynamicParameterBinding::DynamicParameter{0, {1}}, DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), op::GetTupleElement(param, 1)); @@ -167,6 +170,7 @@ TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(negate, {}, 1), size_param); } @@ -197,6 +201,7 @@ TEST_F(DynamicDimensionInferenceTest, ReduceTestI) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size_param); } @@ -228,6 +233,7 @@ TEST_F(DynamicDimensionInferenceTest, ReduceTestII) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 2})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size_param); EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr); @@ -271,6 +277,7 @@ TEST_F(DynamicDimensionInferenceTest, DotTest) { DynamicParameterBinding::DynamicParameter{2, {}}, DynamicParameterBinding::DynamicDimension{1, {}, 0})); + SCOPED_TRACE(module_->ToString()); 
TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param); EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); @@ -319,6 +326,7 @@ TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { DynamicParameterBinding::DynamicParameter{2, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size_param); EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr); @@ -356,6 +364,7 @@ TEST_F(DynamicDimensionInferenceTest, TransposeTest) { DynamicParameterBinding::DynamicParameter{3, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 2})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3); EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_2); @@ -386,6 +395,7 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 3})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), nullptr); EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), size_param); @@ -415,6 +425,7 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 1})); + SCOPED_TRACE(module_->ToString()); Status status = RunInference(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); } @@ -439,12 +450,103 @@ TEST_F(DynamicDimensionInferenceTest, BroadcastTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 0})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 0), nullptr); EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 1), size_param); 
EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 2), nullptr); } +TEST_F(DynamicDimensionInferenceTest, WhileTest) { + // Test the ability to trace into while loops. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); + + // Body: + // + // Param + // | | + // GTE1 GTE2 + // | | + // ADD + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto gte_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, body_param, 0)); + auto gte_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, body_param, 1)); + auto add = body_builder.AddInstruction( + HloInstruction::CreateBinary(input_shape, HloOpcode::kAdd, gte_0, gte_1)); + body_builder.AddInstruction(HloInstruction::CreateTuple({add, add})); + + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + // Entry: + // + // Param + // | + // While + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, tuple_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, a_param)); + + module_->AddEntryComputation(builder.Build()); + + 
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 0})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + // Test that dynamic dimension inference does the right thing. A lambda is + // used here since we want to test twice by running inference again + // (idempotency). + auto test_dynamic_dimension = [&]() { + HloInstruction* while_hlo = nullptr; + // The while hlo has been replaced, find the new one. + for (HloInstruction* inst : module_->entry_computation()->instructions()) { + if (inst->opcode() == HloOpcode::kWhile) { + while_hlo = inst; + } + } + ASSERT_NE(while_hlo, nullptr); + // The original while shape has 2 parameters. With dynamic size passed in + // as an extra parameter, the tuple should have 3 elements. + EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 3); + HloInstruction* add = nullptr; + for (HloInstruction* inst : while_hlo->while_body()->instructions()) { + if (inst->opcode() == HloOpcode::kAdd) { + add = inst; + } + } + EXPECT_NE(add, nullptr); + EXPECT_NE(inference_->GetDynamicSize(add, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {0}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {1}, 0), size_param); + }; + + TF_ASSERT_OK(RunInference()); + test_dynamic_dimension(); + TF_ASSERT_OK(RunInference()); + test_dynamic_dimension(); +} + TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) { // Test the ability to trace reduce window batch dimensions. 
auto builder = HloComputation::Builder(TestName()); @@ -490,6 +592,7 @@ TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 0})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); } @@ -543,9 +646,32 @@ TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{2, {}, 0})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(sns, {}, 0), size_param); } +TEST_F(DynamicDimensionInferenceTest, SliceTest) { + auto builder = HloComputation::Builder(TestName()); + + auto data_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto* slice = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {5, 7}), data_param, /*start_indices=*/{0, 0}, + /*limit_indices=*/{5, 7}, /*strides=*/{1, 1})); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 1), size_param); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 4db280f817141bd52e3a5b9564600a618f81aeac..9e1efa44299609803f168832360eaa2f3e579fe5 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -67,12 +67,18 @@ StatusOr ChooseIdentityValue(HloInstruction* inst) { case HloOpcode::kPad: { return inst->mutable_operand(1); } + + case HloOpcode::kSelectAndScatter: { + return inst->mutable_operand(2); + } case HloOpcode::kParameter: case HloOpcode::kGetDimensionSize: case HloOpcode::kReshape: case HloOpcode::kTuple: case HloOpcode::kAllReduce: case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + case HloOpcode::kSlice: return nullptr; default: return UnimplementedStrCat("Unimplimented padding for instruction: ", @@ -80,6 +86,22 @@ StatusOr ChooseIdentityValue(HloInstruction* inst) { } } +bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num, + int64 dimension) { + if ((inst->opcode() == HloOpcode::kReduceWindow || + inst->opcode() == HloOpcode::kSelectAndScatter) && + operand_num == 0 && inst->window().dimensions(dimension).size() == 1) { + return true; + } + + if (operand_num == 0 && inst->opcode() == HloOpcode::kConvolution && + inst->convolution_dimension_numbers().input_batch_dimension() == + dimension) { + return true; + } + return false; +} + } // namespace StatusOr DynamicPadder::Run(HloModule* module) { @@ -105,6 +127,11 @@ StatusOr DynamicPadder::Run(HloModule* module) { } VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @" << dim; + + if (ShouldSkipPadOnOperand(inst, operand_num, dim)) { 
+ continue; + } + TF_ASSIGN_OR_RETURN(HloInstruction * identity_value, ChooseIdentityValue(inst)); if (identity_value == nullptr) { @@ -133,9 +160,10 @@ StatusOr DynamicPadder::Run(HloModule* module) { HloInstruction* broadcasted_effective_size = computation->AddInstruction(HloInstruction::CreateBroadcast( mask_shape, dynamic_size, {})); - HloInstruction* pred = computation->AddInstruction( - HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota, - broadcasted_effective_size)); + HloInstruction* pred = + computation->AddInstruction(HloInstruction::CreateCompare( + pred_shape, iota, broadcasted_effective_size, + ComparisonDirection::kLt)); HloInstruction* broadcasted_identity_value = computation->AddInstruction(HloInstruction::CreateBroadcast( diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index 55a11286e4596d87c330315322cae704fc5cd707..2963deaa317e1b04346046bde35a6025abe0924e 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -22,8 +22,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -41,10 +43,7 @@ class DynamicPadderTest : public HloTestBase { DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); } StatusOr RunPadder() { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before padder"); - DynamicPadder padder; - return padder.Run(module_.get()); } @@ -133,19 +132,84 @@ TEST_F(DynamicPadderTest, ConvolutionTest) { module_->AddEntryComputation(builder.Build()); + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunPadder().status()); + + ExpectPadded(conv->operand(0)); +} + +TEST_F(DynamicPadderTest, ConvolutionNoPad) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = 
XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding for non-contracting dimension. TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( DynamicParameterBinding::DynamicParameter{2, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 0})); - // Set up binding for contracting dimensions. + TF_ASSERT_OK(RunPadder().status()); + + EXPECT_THAT(conv->operand(0), op::Parameter()); +} + +TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {4, 5}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {3, 5}); + + auto input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "input")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + TF_ASSERT_OK_AND_ASSIGN(Window window, ParseWindow("size=2x1 pad=0_0x0_0")); + auto output = builder.AddInstruction(HloInstruction::CreateReduceWindow( + reduce_shape, input, init, window, GetScalarAddComputation())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. 
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( - DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 1})); TF_ASSERT_OK(RunPadder().status()); - ExpectPadded(conv->operand(0)); + EXPECT_THAT(output->operand(0), op::Parameter()); } } // namespace diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc index 5549cccfa86f9445ae0aa68748fde2b131ee5a5e..7f0ae692f7414dbdcccda8b287c9059bcf920df1 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc @@ -29,7 +29,8 @@ Status DynamicParameterBinding::Bind( } absl::optional -DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) { +DynamicParameterBinding::GetBinding( + const DynamicDimension& dynamic_dimension) const { auto param_iter = bindings_.find(dynamic_dimension); if (param_iter == bindings_.end()) { return absl::nullopt; @@ -111,7 +112,8 @@ Status DynamicParameterBinding::Verify(const HloModule& module) const { return ForEachBinding([&](const DynamicParameter& dynamic_parameter, const DynamicDimension& dynamic_dimension) -> Status { - TF_RET_CHECK(dynamic_parameter.parameter_num < entry->num_parameters()); + TF_RET_CHECK(dynamic_parameter.parameter_num >= 0 && + dynamic_parameter.parameter_num < entry->num_parameters()); TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters()); TF_RET_CHECK(ShapeUtil::IndexIsValid( entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(), diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h index dd474d8eed1b2c30ddb8f624a864198c74eacaba..57af2c43d3c65f7340e6a9f04e5abbf052ebceea 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h +++ 
b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h @@ -89,7 +89,7 @@ class DynamicParameterBinding { // // Returns nullopt if the binding is not set. absl::optional GetBinding( - const DynamicDimension& dynamic_dimension); + const DynamicDimension& dynamic_dimension) const; using BindingFn = std::functionToProto(); + TF_ASSERT_OK_AND_ASSIGN(*binding, + DynamicParameterBinding::CreateFromProto(proto)); + } +}; TEST_F(DynamicParameterBindingTest, SimpleBinding) { // 'b' is a dynamic shape; 'a' represents the real size of b's first @@ -56,15 +64,20 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}}, DynamicParameterBinding::DynamicDimension{1, {}, 0})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, - /*parameter_index=*/{}, - /*dimension=*/0}); - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({})); - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, + /*parameter_index=*/{}, + /*dimension=*/0}); + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + test(binding); + SerializeAndDeserialize(&binding); + test(binding); } TEST_F(DynamicParameterBindingTest, TupleBinding) { @@ -89,16 +102,21 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, DynamicParameterBinding::DynamicDimension{0, {1}, 0})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({0})); - TF_EXPECT_OK(binding.Verify(*module)); + auto test = 
[&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + test(binding); + SerializeAndDeserialize(&binding); + test(binding); } TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) { @@ -127,26 +145,35 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, DynamicParameterBinding::DynamicDimension{0, {1}, 1})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({0})); - - absl::optional param2 = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - EXPECT_TRUE(param2); - EXPECT_EQ(param2->parameter_num, 0); - EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); - - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + + absl::optional param2 = + + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + EXPECT_TRUE(param2); + EXPECT_EQ(param2->parameter_num, 0); + EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + + test(binding); + + SerializeAndDeserialize(&binding); + + // Test 
the binding again after deserialization. + test(binding); } } // namespace diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 727e0bfa52d45b6f8c67d7d04613e4865f18a53c..53513fa52268dc00de75f644ac3ed77648238337 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -49,7 +49,6 @@ limitations under the License. namespace xla { using absl::StrCat; -using llvm_ir::AsStringRef; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; @@ -208,10 +207,8 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) { - if (op->opcode() == HloOpcode::kCopy) { - return operand_value; - } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || - op->operand(0)->shape().element_type() == PRED) { + if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + op->operand(0)->shape().element_type() == PRED) { return EmitIntegerUnaryOp(op, operand_value); } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { return EmitComplexUnaryOp(op, operand_value); @@ -423,6 +420,10 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitSin(op->shape().element_type(), operand_value); case HloOpcode::kTanh: return EmitTanh(op->shape().element_type(), operand_value); + case HloOpcode::kSqrt: + return EmitSqrt(op->shape().element_type(), operand_value); + case HloOpcode::kRsqrt: + return EmitRsqrt(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {operand_value}, @@ -436,18 +437,18 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( {operand_value}, {operand_value->getType()}, b_); case HloOpcode::kRoundNearestAfz: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round, - {operand_value}, - 
{operand_value->getType()}, b_); + return EmitRoundNearestAfz(op->shape().element_type(), operand_value); case HloOpcode::kSign: { - // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = FCmpOEQ(operand_value, zero); - auto olt = FCmpOLT(operand_value, zero); - return Select(oeq, zero, - Select(olt, llvm::ConstantFP::get(type, -1.0), - llvm::ConstantFP::get(type, 1.0))); + auto ne0_i1 = FCmpONE(operand_value, zero); + auto ne0_float = UIToFP(ne0_i1, type); + llvm::Value* result = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::copysign, {ne0_float, operand_value}, + {operand_value->getType()}, b_); + auto is_nan = FCmpUNO(operand_value, operand_value); + result = Select(is_nan, operand_value, result); + return result; } case HloOpcode::kIsFinite: { // abs(x) o!= inf, this works because the comparison returns false if @@ -653,6 +654,20 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs), FDiv(EmitExtractImag(operand_value), cplx_abs))); } + case HloOpcode::kSqrt: { + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + auto c = llvm::ConstantFP::get(a->getType(), 0.5); + auto d = llvm::ConstantFP::get(b->getType(), 0.0); + return EmitComplexPower(op, a, b, c, d); + } + case HloOpcode::kRsqrt: { + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + auto c = llvm::ConstantFP::get(a->getType(), -0.5); + auto d = llvm::ConstantFP::get(b->getType(), 0.0); + return EmitComplexPower(op, a, b, c, d); + } case HloOpcode::kNegate: return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), FNeg(EmitExtractImag(operand_value))); @@ -703,25 +718,28 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( // We use ordered comparisons for everything except kNe, where we use an // unordered comparison. 
This makes x != y equivalent to !(x == y), and // matches C++'s semantics. - case HloOpcode::kEq: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, - rhs_value, b_); - case HloOpcode::kNe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, - rhs_value, b_); - case HloOpcode::kLt: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, - rhs_value, b_); - case HloOpcode::kGt: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value, - rhs_value, b_); - case HloOpcode::kLe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value, - rhs_value, b_); - case HloOpcode::kGe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value, - rhs_value, b_); - + case HloOpcode::kCompare: { + switch (op->comparison_direction()) { + case ComparisonDirection::kEq: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, + rhs_value, b_); + case ComparisonDirection::kNe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, + rhs_value, b_); + case ComparisonDirection::kLt: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, + rhs_value, b_); + case ComparisonDirection::kGt: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value, + rhs_value, b_); + case ComparisonDirection::kLe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value, + rhs_value, b_); + case ComparisonDirection::kGe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value, + rhs_value, b_); + } + } case HloOpcode::kMaximum: return EmitFloatMax(lhs_value, rhs_value); case HloOpcode::kMinimum: @@ -736,6 +754,43 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } } +// (a+bi)^(c+di) = +// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), +// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) +StatusOr ElementalIrEmitter::EmitComplexPower( + const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c, + llvm::Value* d) { + PrimitiveType 
component_type = + primitive_util::ComplexComponentType(op->shape().element_type()); + auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); + auto zero = llvm::ConstantFP::get(a->getType(), 0); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto one = llvm::ConstantFP::get(a->getType(), 1); + auto half_c = FMul(one_half, c); + + TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, + EmitPow(component_type, aa_p_bb, half_c)); + + auto neg_d = FNeg(d); + TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); + auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, + EmitExp(component_type, neg_d_arg_lhs)); + auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); + auto half_d = FMul(one_half, d); + auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); + TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); + // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see + // Branch Cuts for Complex Elementary Functions or Much Ado About + // Nothing's Sign Bit, W. Kahan, Section 10. + return Select( + And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)), + EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero), + EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q))); +} + StatusOr ElementalIrEmitter::EmitComplexBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { @@ -786,58 +841,34 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // We use ordered comparisons for everything except kNe, where we use an // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. 
- case HloOpcode::kEq: - return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); - case HloOpcode::kNe: - return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); - + case HloOpcode::kCompare: { + switch (op->comparison_direction()) { + case ComparisonDirection::kEq: + return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); + case ComparisonDirection::kNe: + return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); + default: + return Unimplemented( + "complex comparison '%s'", + ComparisonDirectionToString(op->comparison_direction())); + } + } case HloOpcode::kPower: { - // (a+bi)^(c+di) = - // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), - // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) - PrimitiveType component_type = - primitive_util::ComplexComponentType(op->shape().element_type()); auto a = EmitExtractReal(lhs_value); auto b = EmitExtractImag(lhs_value); auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); - auto zero = llvm::ConstantFP::get(a->getType(), 0); - auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto one = llvm::ConstantFP::get(a->getType(), 1); - auto half_c = FMul(one_half, c); - - TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, - 
EmitPow(component_type, aa_p_bb, half_c)); - - auto neg_d = FNeg(d); - TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); - auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); - TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, - EmitExp(component_type, neg_d_arg_lhs)); - auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); - TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = FMul(one_half, d); - auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); - TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); - TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - // 0^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see - // Branch Cuts for Complex Elementary Functions or Much Ado About - // Nothing's Sign Bit, W. Kahan, Section 10. - return Select( - And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)), - EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero), - EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q))); + return EmitComplexPower(op, a, b, c, d); } default: return Unimplemented("binary complex op '%s'", @@ -855,6 +886,9 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } +// TODO(b/123355973): We have an implementation of erfinv in math.cc. We +// shouldn't have two implementations, especially since this one isn't testable +// (it's only observable via a normally-distributed RNG). 
StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Value* x) { if (prim_type != F16 && prim_type != F32 && prim_type != F64) { @@ -1047,6 +1081,18 @@ StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, return Select(x_is_small, for_small_x, for_large_x); } +StatusOr ElementalIrEmitter::EmitSqrt(PrimitiveType prim_type, + llvm::Value* value) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value}, + {value->getType()}, b_); +} + +StatusOr ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value) { + TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value)); + return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt); +} + StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, @@ -1088,6 +1134,12 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, return Select(x_is_small, for_small_x, for_large_x); } +StatusOr ElementalIrEmitter::EmitRoundNearestAfz( + PrimitiveType /*prim_type*/, llvm::Value* value) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round, {value}, + {value->getType()}, b_); +} + StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) { @@ -1235,28 +1287,32 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( return EmitIntegerDivide(lhs_value, rhs_value, is_signed); case HloOpcode::kRemainder: return EmitIntegerRemainder(lhs_value, rhs_value, is_signed); - case HloOpcode::kEq: - return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, - rhs_value, b_); - case HloOpcode::kNe: - return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value, - rhs_value, b_); - case HloOpcode::kLt: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT, - lhs_value, rhs_value, b_); - case HloOpcode::kGt: - return llvm_ir::EmitComparison( - is_signed ? 
llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT, - lhs_value, rhs_value, b_); - case HloOpcode::kLe: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE, - lhs_value, rhs_value, b_); - case HloOpcode::kGe: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, - lhs_value, rhs_value, b_); + case HloOpcode::kCompare: { + switch (op->comparison_direction()) { + case ComparisonDirection::kEq: + return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, + rhs_value, b_); + case ComparisonDirection::kNe: + return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value, + rhs_value, b_); + case ComparisonDirection::kLt: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT, + lhs_value, rhs_value, b_); + case ComparisonDirection::kGt: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT, + lhs_value, rhs_value, b_); + case ComparisonDirection::kLe: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE, + lhs_value, rhs_value, b_); + case ComparisonDirection::kGe: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, + lhs_value, rhs_value, b_); + } + } case HloOpcode::kMinimum: return EmitIntegralMin(lhs_value, rhs_value, is_signed); case HloOpcode::kMaximum: @@ -1309,46 +1365,6 @@ llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, lhs_value, rhs_value); } -llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( - const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) { - CHECK(hlo.IsElementwise()) - << "HLO " << hlo.ToString() << " is not elementwise."; - - const Shape& operand_shape = hlo.operand(operand_no)->shape(); - // If the operand is scalar, the source index is always {}. 
- if (ShapeUtil::IsScalar(operand_shape)) { - return llvm_ir::IrArray::Index(target_index.GetType()); - } - - // If no implicit broadcast is needed for this operand, returns the target - // index as the source index. - // - // `IrArray::Index` may contain a physical linear which we can propagate to - // our operand only if our layouts match. "only if" is a bit strong since - // e.g. we can still forward the linear index if the operand shape is - // [5,1,1,5]{3,2,1,0} and the HLO shape is[5,1,1,5]{3,1,2,0}, but those cases - // are probably not worth handling here for now. - if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape()) && - LayoutUtil::Equal(operand_shape.layout(), hlo.shape().layout())) { - return target_index; - } - - // If implicit broadcast is needed, the source dimensions that are broadcast - // have index 0. - CHECK_EQ(operand_shape.rank(), hlo.shape().rank()); - llvm_ir::IrArray::Index source_index(target_index.GetType()); - for (int64 i = 0; i < hlo.shape().rank(); ++i) { - if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { - source_index.push_back(target_index[i]); - } else { - CHECK_EQ(1, operand_shape.dimensions(i)); - source_index.push_back(target_index.GetConstantWithIndexType(0)); - } - } - return source_index; -} - StatusOr ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, @@ -1362,26 +1378,69 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_); llvm::Type* raw_value_ty = raw_value->getType(); - // Convert raw integer to float in range [0, 1) if the element is a float. + // If we're generating a floating-point value, convert the raw integer R (i.e. + // `raw_value`) to a float in the range [0, 1). 
+ // + // The basic approach is to choose a significand and exponent such that the + // significand is uniformly distributed and the exponent is distributed, well, + // exponentially (it's more likely to be close to 0 than far from 0). + // + // An easy way to do this is to say that the significand is the first S bits + // of R, and the exponent is determined by the number of trailing zeroes in R, + // exp = 2^-(cttz(R) + 1). (+1 because the largest exponent should be -1; + // this way the largest value we can return is 1.999... * 2^-1 = 1-ε.) + // + // This results in a small bias. Namely, if R has enough trailing zeroes, the + // significand and exponent will "overlap". As a concrete example, consider + // + // 20 X's 12 zeroes + // R = 0bXXXXXXXXXXXXXXXXXXXX000000000000 + // + // Here the exponent is 2^-13 because R has 12 trailing zeroes. The + // significand is made up of the first 23 most-significant bits of R, which we + // observe contain 3 zeroes. This is biased because any random value with + // exponent 2^-12 will have a significand which ends in `000`. + // + // For f32s, this problem occurs only when there are more than 32-23 = 9 + // trailing zeros, which happens with probability 0.5^10 = ~0.1%. Moreover the + // probability of a large bias (i.e. many trailing 0s in the significand) is + // exponentially low. So we deem this acceptable. llvm::Value* elem_value = raw_value; if (elem_ir_ty->isFloatingPointTy()) { - unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); - CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); - // Perform the division using the float type with the same number of bits - // as the raw value to avoid overflow. 
- if (raw_value_size_in_bits == 32) { - elem_value = UIToFP(elem_value, b_->getFloatTy()); - elem_value = FDiv(elem_value, - llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); - } else { - elem_value = UIToFP(elem_value, b_->getDoubleTy()); - elem_value = FDiv( - elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); - } - - if (elem_ir_ty != elem_value->getType()) { - elem_value = FPTrunc(elem_value, elem_ir_ty); - } + const auto& dest_flt_semantics = elem_ir_ty->getFltSemantics(); + const int bits = raw_value_ty->getPrimitiveSizeInBits(); + CHECK_GE(bits, llvm::APFloat::semanticsSizeInBits(dest_flt_semantics)); + + // Subtract 1 because semanticsPrecision includes the "hidden bit", i.e. the + // implicit "1." at the beginning of the significand. + const int significand_bits = + llvm::APFloat::semanticsPrecision(dest_flt_semantics) - 1; + + llvm::Value* cttz = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::cttz, {raw_value, /*is_zero_undef=*/b_->getFalse()}, + {raw_value->getType()}, b_); + llvm::Value* significand = LShr(raw_value, bits - significand_bits); + + // Exponent bias is -127 for f32, meaning that if the exponent is E and the + // significand is S, then the value of the number is 2^(E - 127) * (1.S). + // + // We want cttz == 0 to correspond to 2^-1, so our exponent is computed as + // E = 126 - cttz. + // + // For f64, this is all the same, except the bias is -1023. + // + // In IEEE floating point, the absolute value of the exponent bias equals + // the value of the largest possible exponent. + const int bias = -llvm::APFloat::semanticsMaxExponent(dest_flt_semantics); + llvm::Value* exponent = + Sub(llvm::ConstantInt::get(cttz->getType(), -bias - 1), cttz); + + // Now just slot everything into place! The `Trunc` is here because + // raw_value may be larger than our float destination. 
+ elem_value = + BitCast(Trunc(Or(Shl(exponent, significand_bits), significand), + b_->getIntNTy(elem_ir_ty->getPrimitiveSizeInBits())), + elem_ir_ty); } // Convert the value for the requested distribution. @@ -1611,14 +1670,11 @@ StatusOr ElementalIrEmitter::EmitElementalSelect( const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(hlo->operand(0))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); + operand_to_generator.at(hlo->operand(1))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); + operand_to_generator.at(hlo->operand(2))(index)); return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value, on_false_value); } @@ -1628,14 +1684,11 @@ StatusOr ElementalIrEmitter::EmitElementalClamp( const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * min_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(hlo->operand(0))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); + operand_to_generator.at(hlo->operand(1))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * max_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); + operand_to_generator.at(hlo->operand(2))(index)); PrimitiveType prim_type = hlo->shape().element_type(); if (primitive_util::IsFloatingPointType(prim_type)) { return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); @@ -1668,8 +1721,8 @@ StatusOr 
ElementalIrEmitter::EmitElementalConcatenate( exit_block = llvm_ir::CreateBasicBlock( /*insert_before=*/nullptr, IrName(hlo, "merge"), b_); } else { - exit_block = init_block->splitBasicBlock(b_->GetInsertPoint(), - AsStringRef(IrName(hlo, "merge"))); + exit_block = + init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge")); init_block->getTerminator()->eraseFromParent(); } @@ -1715,37 +1768,40 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_); source_index_phis[operand_id] = PHI(source_index.GetType(), operand_usage_count[operand_id]); - auto operand_index = source_index; - operand_index[concat_dim] = source_index_phis[operand_id]; + std::vector operand_multi_index = source_index.multidim(); + operand_multi_index[concat_dim] = source_index_phis[operand_id]; // Create the terminator of the block before calling operand generators, // because they require non-degenerate basic blocks. b_->SetInsertPoint(llvm::BranchInst::Create( exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id])); + llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(), + source_index.GetType()); TF_ASSIGN_OR_RETURN(llvm::Value * value, operand_to_generator.at(operand)(operand_index)); output->addIncoming(value, b_->GetInsertBlock()); b_->SetInsertPoint(init_block, saved_insert_point); } + std::vector source_multi_index = source_index.multidim(); for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); auto false_block = llvm_ir::CreateBasicBlock( exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_); - auto concat_dim_size = - llvm::ConstantInt::get(source_index[concat_dim]->getType(), - operand->shape().dimensions(concat_dim)); + auto concat_dim_size = source_index.GetConstantWithIndexType( + operand->shape().dimensions(concat_dim)); int64 operand_id = 
to_unique_operand_id[operand]; - source_index_phis[operand_id]->addIncoming(source_index[concat_dim], + source_index_phis[operand_id]->addIncoming(source_multi_index[concat_dim], b_->GetInsertBlock()); - CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), + CondBr(ICmpULT(source_multi_index[concat_dim], concat_dim_size), emit_operand_blocks[operand_id], false_block); // Subtract the size of the concat dimension of the current operand // from the source index. b_->SetInsertPoint(false_block); - source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size); + source_multi_index[concat_dim] = + Sub(source_multi_index[concat_dim], concat_dim_size); } Unreachable(); @@ -1762,23 +1818,15 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const int64 rank = input_hlo->shape().rank(); // Use the same index type for all tensor accesses in the same kernel. llvm::Type* index_type = index.GetType(); - llvm_ir::IrArray::Index slice_start_index(index_type, rank); + std::vector slice_start_multi_index(rank); for (int64 i = 0; i < rank; ++i) { auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - // TODO(b/118437727): Remove the R1 path. 
- llvm::Value* start_index_value; - if (hlo->operand(1)->shape().rank() == 1) { - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); - } else { - llvm_ir::IrArray::Index zero_index(index_type); - TF_ASSIGN_OR_RETURN( - start_index_value, - operand_to_generator.at(hlo->operand(1 + i))(zero_index)); - } + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(1 + i))(zero_index)); // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) @@ -1793,17 +1841,18 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( EmitIntegralMax(index_typed_const(0), start_index_value, is_signed), is_signed); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = start_index_value; + start_index_value->setName(IrName(hlo, StrCat("start_idx", i))); + slice_start_multi_index[i] = start_index_value; } - llvm_ir::IrArray::Index input_index(index_type, rank); + std::vector input_multi_index(rank); for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index - input_index[i] = Add(slice_start_index[i], index[i]); + input_multi_index[i] = Add(slice_start_multi_index[i], index[i]); } + llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(), + index_type); return operand_to_generator.at(input_hlo)(input_index); } @@ -1825,7 +1874,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( llvm::Type* index_type = index.GetType(); // This is the index into `operand` that holds the element we want to // generate. - IrArray::Index operand_index(index_type); + std::vector operand_multi_index; // First copy in the window indices to operand_index. 
Also collect a mapping // from operand dimension to output window dimension. Elided window dimensions @@ -1834,26 +1883,29 @@ StatusOr ElementalIrEmitter::EmitElementalGather( for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { - operand_index.push_back(index.GetConstantWithIndexType(0)); + operand_multi_index.push_back(index.GetConstantWithIndexType(0)); } else { int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++); operand_to_output_dim[i] = output_window_dim; - operand_index.push_back(index[output_window_dim]); + operand_multi_index.push_back(index[output_window_dim]); } } // This is the index of the index vector in the start_indices tensor. - IrArray::Index gather_index_index(index_type); + std::vector gather_index_index_components; { - std::vector gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { - gather_index_index.push_back(index[i]); + gather_index_index_components.push_back(index[i]); } } - if (gather_index_index.size() != indices_shape.dimensions_size()) { - gather_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + if (gather_index_index_components.size() != + indices_shape.dimensions_size()) { + gather_index_index_components.insert( + gather_index_index_components.begin() + + dim_numbers.index_vector_dim(), + nullptr); } } @@ -1881,11 +1933,14 @@ StatusOr ElementalIrEmitter::EmitElementalGather( gather_dim_component_extended, is_signed), is_signed); - operand_index[operand_dim] = - Add(operand_index[operand_dim], gather_dim_component_extended_inbound); + operand_multi_index[operand_dim] = + Add(operand_multi_index[operand_dim], + gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { + IrArray::Index gather_index_index(gather_index_index_components, 
+ indices_shape, index_type); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); add_to_operand_index(gather_dim_component, 0); @@ -1893,13 +1948,16 @@ StatusOr ElementalIrEmitter::EmitElementalGather( int64 index_vector_size = indices_shape.dimensions(dim_numbers.index_vector_dim()); for (int64 i = 0; i < index_vector_size; i++) { - gather_index_index[dim_numbers.index_vector_dim()] = + gather_index_index_components[dim_numbers.index_vector_dim()] = index.GetConstantWithIndexType(i); + IrArray::Index gather_index_index(gather_index_index_components, + indices_shape, index_type); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); add_to_operand_index(gather_dim_component, i); } } + IrArray::Index operand_index(operand_multi_index, operand_shape, index_type); return operand_generator(operand_index); } @@ -1912,8 +1970,8 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. const int64 rank = input_hlo->shape().rank(); - llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); - llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); + std::vector slice_start_multi_index(rank); + std::vector slice_limit_multi_index(rank); // Slice intersection gathers (ANDs) conditions on all ranks for which // 'input' is set to 'update' llvm::Value* slice_intersection = b_->getTrue(); @@ -1924,18 +1982,10 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( return llvm::ConstantInt::get(index_type, c); }; - llvm::Value* start_index_value; - // TODO(b/118437727): Remove the R1 path. 
- if (hlo->operand(2)->shape().rank() == 1) { - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(start_index_value, - operand_to_generator.at(hlo->operand(2))(dim_index)); - } else { - llvm_ir::IrArray::Index zero_index(index_type); - TF_ASSIGN_OR_RETURN( - start_index_value, - operand_to_generator.at(hlo->operand(2 + i))(zero_index)); - } + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(2 + i))(zero_index)); // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) @@ -1952,16 +2002,16 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( EmitIntegralMax(index_typed_const(0), start_index_value, is_signed), is_signed); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = start_index_value; - slice_limit_index[i] = Add(slice_start_index[i], update_dim_size); + start_index_value->setName(IrName(hlo, StrCat("start_idx", i))); + slice_start_multi_index[i] = start_index_value; + slice_limit_multi_index[i] = + Add(slice_start_multi_index[i], update_dim_size); slice_intersection = - And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]), + And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]), "slice_intersection"); slice_intersection = - And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]), + And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]), "slice_intersection"); } @@ -1977,10 +2027,12 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Handle true BB (return data from 'update') SetToFirstInsertPoint(if_data.true_block, b_); // Compute update index for intersection case. 
- llvm_ir::IrArray::Index update_index(index.GetType(), rank); + std::vector update_multi_index(rank); for (int64 i = 0; i < rank; ++i) { - update_index[i] = Sub(index[i], slice_start_index[i]); + update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]); } + llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(), + index.GetType()); TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); Store(true_value, ret_value_addr); @@ -1999,27 +2051,28 @@ StatusOr ElementalIrEmitter::EmitElementalPad( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& padded_index) { - auto index = padded_index; + std::vector multi_index = padded_index.multidim(); llvm::Value* in_bounds = b_->getTrue(); - for (size_t i = 0; i < index.size(); ++i) { + for (size_t i = 0; i < multi_index.size(); ++i) { auto index_typed_const = [=](int64 n) { - return llvm::ConstantInt::get(index[i]->getType(), n); + return padded_index.GetConstantWithIndexType(n); }; const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low())); + multi_index[i] = + Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)), + "in_bounds"); in_bounds = - And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds"); - in_bounds = And( - in_bounds, - ICmpEQ( - index_typed_const(0), - URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = - SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1)); + And(in_bounds, + ICmpEQ(index_typed_const(0), + URem(multi_index[i], + index_typed_const(pad_dim.interior_padding() + 1))), + "in_bounds"); + multi_index[i] = + SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1)); in_bounds = And(in_bounds, - 
ICmpSLT(index[i], + ICmpSLT(multi_index[i], index_typed_const(hlo->operand(0)->shape().dimensions(i))), "in_bounds"); } @@ -2035,6 +2088,8 @@ StatusOr ElementalIrEmitter::EmitElementalPad( llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); SetToFirstInsertPoint(if_data.true_block, b_); + llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(), + padded_index.GetType()); TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); Store(operand_value, ret_value_addr); @@ -2094,17 +2149,27 @@ StatusOr ElementalIrEmitter::EmitElementalDot( // Given an output index [a,b,c,d,e] in the result, we compute: // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - IrArray::Index lhs_index(index_type), rhs_index(index_type); - + std::vector lhs_multi_index, rhs_multi_index; for (int64 i = 0; i < lhs_dims - 1; i++) { - lhs_index.push_back(dot_result_index[i]); + lhs_multi_index.push_back(dot_result_index[i]); } - lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue()); - - for (int64 i = 0; i < rhs_dims - 1; i++) { - rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); + lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim, + inner_loop->GetIndVarValue()); + IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(), + index_type); + + int64 num_batch_dims = dim_numbers.rhs_batch_dimensions_size(); + for (int64 i = 0; i < num_batch_dims; i++) { + rhs_multi_index.push_back( + dot_result_index[dim_numbers.rhs_batch_dimensions(i)]); + } + for (int64 i = 0; i < rhs_dims - 1 - num_batch_dims; i++) { + rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]); } - rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); + rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim, + inner_loop->GetIndVarValue()); + IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(), + index_type); llvm::Value* current_accumulator = 
Load(accumulator_alloca); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); @@ -2144,7 +2209,6 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kClz: case HloOpcode::kConvert: case HloOpcode::kBitcastConvert: - case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kExpm1: @@ -2156,30 +2220,26 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kReal: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(hlo->operand(0))(index)); return EmitUnaryOp(hlo, operand_value); }; case HloOpcode::kAdd: case HloOpcode::kAnd: case HloOpcode::kAtan2: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kOr: case HloOpcode::kXor: case HloOpcode::kPower: @@ -2193,11 +2253,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* lhs = hlo->operand(0); const HloInstruction* rhs = hlo->operand(1); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, - operand_to_generator.at(lhs)( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(lhs)(index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, - operand_to_generator.at(rhs)( - ElementwiseSourceIndex(index, *hlo, 1))); + operand_to_generator.at(rhs)(index)); return EmitBinaryOp(hlo, lhs_value, rhs_value); }; case HloOpcode::kSelect: @@ -2214,8 +2272,7 @@ 
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(hlo->operand(0))(index)); return EmitReducePrecision(hlo, operand_value); }; case HloOpcode::kConcatenate: @@ -2228,13 +2285,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return [this, hlo, &operand_to_generator]( const IrArray::Index& target_index) -> StatusOr { const HloInstruction* operand = hlo->operand(0); - auto source_index = target_index; + std::vector source_multi_index = target_index.multidim(); for (int64 dim : hlo->dimensions()) { - source_index[dim] = - Sub(llvm::ConstantInt::get(target_index[dim]->getType(), - hlo->shape().dimensions(dim) - 1), - target_index[dim]); + source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType( + hlo->shape().dimensions(dim) - 1), + target_index[dim]); } + llvm_ir::IrArray::Index source_index( + source_multi_index, operand->shape(), target_index.GetType()); return operand_to_generator.at(operand)(source_index); }; case HloOpcode::kBroadcast: @@ -2308,7 +2366,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { IrArray::Index sliced_index = index.SourceIndexOfSlice( - /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(), + /*operand_shape=*/hlo->operand(0)->shape(), + /*starts=*/hlo->slice_starts(), /*strides=*/hlo->slice_strides(), /*builder=*/b_); return operand_to_generator.at(hlo->operand(0))(sliced_index); }; @@ -2345,6 +2404,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return operand_to_generator.at(operand)( index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_)); }; + case HloOpcode::kCopy: + return [hlo, 
&operand_to_generator]( + const IrArray::Index& target_index) -> StatusOr { + IrArray::Index source_index(target_index.multidim(), + hlo->operand(0)->shape(), + target_index.GetType()); + TF_ASSIGN_OR_RETURN( + llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(source_index)); + return operand_value; + }; case HloOpcode::kTranspose: return [this, hlo, &operand_to_generator](const IrArray::Index& target_index) { @@ -2366,6 +2436,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( -> StatusOr { return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; + case HloOpcode::kReplicaId: + return [this, hlo](const IrArray::Index&) -> StatusOr { + if (hlo_module_config_.replica_count() != 1) { + return Unimplemented("Replication is not implemented on CPU/GPU."); + } + llvm::Type* type = llvm_ir::PrimitiveTypeToIrType( + hlo->shape().element_type(), module_); + return llvm::ConstantInt::getNullValue(type); + }; default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index d3e2acaabd4f602171def70ccd3d4fd5adce0d0d..6b1c85b7e5f45037a84b0fc97570dc61bdcd3313 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -119,6 +119,12 @@ class ElementalIrEmitter : public IrBuilderMixin { virtual StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value); + virtual StatusOr EmitSqrt(PrimitiveType prim_type, + llvm::Value* value); + + virtual StatusOr EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value); + virtual StatusOr EmitLog1p(PrimitiveType prim_type, llvm::Value* value); @@ -140,6 +146,9 @@ class ElementalIrEmitter : public IrBuilderMixin { virtual StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value); + virtual StatusOr 
EmitRoundNearestAfz(PrimitiveType prim_type, + llvm::Value* value); + virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x); @@ -150,15 +159,6 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, llvm::Value* imag); - // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and - // the target array index, computes the source array index of its - // `operand_no`-th operand. - // - // Precondition: `hlo` is an elementwise op. - llvm_ir::IrArray::Index ElementwiseSourceIndex( - const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no); - // Identifier of the thread unique among all threads on the device virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } @@ -211,13 +211,21 @@ class ElementalIrEmitter : public IrBuilderMixin { const HloModuleConfig& hlo_module_config_; private: + // Computes the complex power function, returns (a + i*b)^(c + i*d). + StatusOr EmitComplexPower(const HloInstruction* op, + llvm::Value* a, llvm::Value* b, + llvm::Value* c, llvm::Value* d); + // Returns a ElementGenerator for an RNG HloInstruction using the Philox // random number generation algorithm. llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator); + // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. + // + // Precondition: raw_value has at least as many bits as hlo's element type. 
StatusOr ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 852f34e06df35242b13110ae4411b8c969c26019..ac18346faa120f2d08ad6eba437266198aa6aff7 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -32,7 +32,7 @@ class ElementalIrEmitterExecutionTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text, config)); + ParseAndReturnVerifiedModule(hlo_text, config)); EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); } }; @@ -60,5 +60,38 @@ ENTRY main { Literal rhs = LiteralUtil::CreateR3({{{3}, {4}}}); RunTest(hlo_text, {&lhs, &rhs}); } + +XLA_TEST_F(ElementalIrEmitterExecutionTest, BatchDot) { + const char* hlo_text = R"( +HloModule BatchDot + +fused_computation.1 { + param_0 = f64[1,1,8]{2,1,0} parameter(0) + r.1 = f64[2,4]{1,0} reshape(param_0) + param_1 = f64[1,2,2,2,1]{4,3,2,1,0} parameter(1) + r.2 = f64[2,4,1]{2,1,0} reshape(param_1) + ROOT dot = f64[2,1]{1,0} dot(r.1, r.2), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} +} + +ENTRY resampler_Resampler.49 { + p0 = f64[1,1,8]{2,1,0} parameter(0) + p1 = f64[1,2,2,2,1]{4,3,2,1,0} parameter(1) + ROOT f = f64[2,1]{1,0} fusion(p0, p1), kind=kLoop, calls=fused_computation.1 +} +)"; + + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + // Disable the layout assignment pass because it would throw away the layouts + // in the fusion computation, but not recreate them. 
+ debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{4e-3, 4e-3})); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 10b8c01ff1383658fcfb2271c177ba54347f985a..7b60c983b3093f1e8f91cacef1c45b50e0d9ac38 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -26,7 +27,6 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" - namespace xla { StatusOr> Executable::ExecuteOnStreams( @@ -138,8 +138,6 @@ StatusOr Executable::ExecuteOnStreamWrapper( XLA_LOG_LINES( tensorflow::INFO, profile_ptr->ToString(stream->parent()->GetDeviceDescription())); - hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", - profile_ptr.get()); } return return_value; @@ -147,37 +145,4 @@ StatusOr Executable::ExecuteOnStreamWrapper( int64 Executable::SizeInBytes() { return -1; } -Status Executable::DumpHloSnapshot() { - TF_RET_CHECK(dumping_snapshot()); - TF_RET_CHECK(hlo_snapshot_->has_hlo() && - hlo_snapshot_->hlo().has_hlo_module()); - const string& directory_path = - module_config().debug_options().xla_dump_executions_to(); - const auto& module = hlo_snapshot_->hlo().hlo_module(); - string filename = - absl::StrFormat("computation_%d__%s__execution_%d", module.id(), - module.entry_computation_name(), ++execution_count_); - return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); -} - -/* static */ Status Executable::DumpToDirectory( - const string& directory_path, string filename, - const HloSnapshot& hlo_session) { - tensorflow::Env* env = tensorflow::Env::Default(); - if (!env->IsDirectory(directory_path).ok()) { - // NB! CreateDir does not work reliably with multiple XLA threads -- two - // threads can race to observe the absence of the dump directory and - // simultaneously try to create it, causing the "losing" thread to get a - // "directory already exists" error. 
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); - } - filename = SanitizeFileName(std::move(filename)); - string file_path = tensorflow::io::JoinPath(directory_path, filename); - string result; - TF_RET_CHECK( - tensorflow::SerializeToStringDeterministic(hlo_session, &result)); - return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, - result); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index b34bca55a48b113c325dbf28c03f7a0f5b71f658..a08ec181d494cb575b291d3aabbec6c1e3d56226 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -184,11 +184,6 @@ class Executable { } bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; } HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } - Status DumpHloSnapshot(); - - // Dump hlo snapshot to directory_path/filename. - static Status DumpToDirectory(const string& directory_path, string filename, - const HloSnapshot& hlo_session); protected: mutable tensorflow::mutex mutex_; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index 85409b330b11537158059dcce8c2a96c98d38f30..f16a4485550a4262be8089c7d6c7c8252830dc1b 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -26,7 +26,7 @@ namespace xla { namespace { -// Helper to replace the called computation at a while-, call-, or +// Helper to replace the called computation at a while-, call-, case-, or // conditional-instruction. This function replaces exactly one instance of // 'computation' with 'new_computation' even if 'instruction' calls // 'computation' more than once. 
@@ -49,11 +49,14 @@ void ReplaceCalledComputation(HloInstruction* instruction, break; } case HloOpcode::kConditional: { - if (computation == instruction->true_computation()) { - instruction->set_true_computation(new_computation); - } else { - CHECK_EQ(computation, instruction->false_computation()); - instruction->set_false_computation(new_computation); + for (int b = 0; b < instruction->branch_count(); ++b) { + if (b == instruction->branch_count() - 1) { + CHECK_EQ(computation, instruction->branch_computation(b)); + } + if (computation == instruction->branch_computation(b)) { + instruction->set_branch_computation(b, new_computation); + break; + } } break; } diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 8eeb930b48165a2e3c622581e05cb5f7063fa1fa..ef35311b08b8ada098b3dd1f3c70692b8f148add 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -81,8 +81,9 @@ class FlattenCallGraphTest : public HloTestBase { HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + zero, ComparisonDirection::kGt)); return builder.Build(); } @@ -158,9 +159,9 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { 0, ShapeUtil::MakeShape(PRED, {}), "param0")); HloInstruction* false_constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); - builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kEq, param0, false_constant)); + builder.AddInstruction(HloInstruction::CreateCompare( + 
ShapeUtil::MakeShape(PRED, {}), param0, false_constant, + ComparisonDirection::kEq)); cond_computation = module->AddEmbeddedComputation(builder.Build()); } diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 01cef499665c050d4453382289168276028e1d26..1838f65e6eae3108b8c5cbd43e221140c3777a2b 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -112,6 +112,14 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( int64 operand_rank) { HloComputation* computation = index_vector->parent(); const Shape& index_shape = index_vector->shape(); + + if (operand_rank == 0) { + // This is Gather from a scalar. So, the index vector in operand space must + // be a zero-sized vector. + return computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); + } + HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); @@ -153,10 +161,9 @@ static StatusOr> GatherLoopBody( dim_numbers.index_vector_dim() == gather.operand(1)->shape().dimensions_size()); - TF_ASSIGN_OR_RETURN( - HloInstruction * induction_var_as_vector, + HloInstruction* induction_var_as_vector = MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/{1})); + /*result_shape_bounds=*/{1}); HloInstruction* index_vector; @@ -222,7 +229,7 @@ static StatusOr> GatherLoopBody( {operand, start_indices, updated_accumulator}}; } -static StatusOr CreateGatherLoopAccumulatorInitValue( +static HloInstruction* CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, absl::Span slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { @@ -297,7 +304,7 @@ static StatusOr PermuteBatchAndOffsetDims( // [3,1] out of operand into an 
accumulator of shape [4,3,1]. We then // reshape this result to [2,2,3] and finally transpose it to [2,3,2]. -StatusOr GatherExpander::ExpandGather( +StatusOr GatherExpander::ExpandInstruction( HloInstruction* gather_instr) { CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape())); @@ -332,12 +339,10 @@ StatusOr GatherExpander::ExpandGather( CHECK_EQ(gather_loop_trip_count, canonical_start_indices->shape().dimensions(0)); - TF_ASSIGN_OR_RETURN( - HloInstruction * accumulator_init, - CreateGatherLoopAccumulatorInitValue( - computation, output_shape.element_type(), - gather_instr->gather_slice_sizes(), gather_loop_trip_count, - gather_instr->gather_dimension_numbers())); + HloInstruction* accumulator_init = CreateGatherLoopAccumulatorInitValue( + computation, output_shape.element_type(), + gather_instr->gather_slice_sizes(), gather_loop_trip_count, + gather_instr->gather_dimension_numbers()); StatusOr> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( @@ -364,25 +369,11 @@ StatusOr GatherExpander::ExpandGather( output_rank); } -StatusOr GatherExpander::Run(HloModule* module) { - auto is_nontrivial_gather = [](HloInstruction* inst) { - return inst->opcode() == HloOpcode::kGather && - // Avoid expanding gather ops that produce zero sized tensors, - // instead punt these to ZeroSizedHloElimination. 
- !ShapeUtil::IsZeroElementArray(inst->shape()); - }; - - std::vector gather_instrs; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - absl::c_copy_if(computation->instructions(), - std::back_inserter(gather_instrs), is_nontrivial_gather); - } - - for (HloInstruction* inst : gather_instrs) { - TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandGather(inst)); - TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); - } - - return !gather_instrs.empty(); +bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) { + return inst->opcode() == HloOpcode::kGather && + // Avoid expanding gather ops that produce zero sized tensors, + // instead punt these to ZeroSizedHloElimination. + !ShapeUtil::IsZeroElementArray(inst->shape()); } + } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 8af9c6b71fbc391bf7c0e9809e979b65135a6df3..5625a37cb46ca5b70f69d86bc424f6512bfb293f 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -16,20 +16,22 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic // slices. This lets backends that don't support gather directly to // nevertheless have a minimum level of support. 
-class GatherExpander : public HloModulePass { +class GatherExpander : public OpExpanderPass { public: absl::string_view name() const override { return "gather_expander"; } - StatusOr Run(HloModule* module) override; protected: - StatusOr ExpandGather(HloInstruction* gather_instr); + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* gather_inst) override; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 7d450f4b53cdea209f2ef10ba785be6ec3b8bf8d..d6a7ec90b59adf72d1e9b447e419cfc8d79fcf4e 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -58,7 +57,8 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( void GenericTransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - MutableBorrowingLiteral literal, std::function done) { + MutableBorrowingLiteral literal, std::function done, + const TransferMetadata* /*transfer_metadata*/) { Status status = stream->BlockHostUntilDone(); if (!status.ok()) { return done(status); @@ -98,7 +98,8 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal( Status GenericTransferManager::TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) { + const ShapedBuffer& device_buffer, + const TransferMetadata* /*transfer_metadata*/) { const Shape& shape = literal.shape(); 
VLOG(2) << "transferring literal shape to device: " << ShapeUtil::HumanString(shape) diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 86c8b1c145a25149a25e7b272babc5c858d476af..acfd8dd64c1a907977bd8d893c9e9dc87425eab5 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -40,14 +40,15 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - void TransferLiteralFromDevice(se::Stream* stream, - const ShapedBuffer& device_buffer, - MutableBorrowingLiteral literal, - std::function done) override; + void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, std::function done, + const TransferMetadata* transfer_metadata) override; Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) override; + const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index dc17aa4426236f54e5f03c28634278d45f462158..53cb8c4f49e67d5cc2c37a8c8b7b95d470bce919 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -7,7 +7,7 @@ load( "//tensorflow/core:platform/default/build_config_root.bzl", "tf_cuda_tests_tags", ) -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library", "if_cuda") licenses(["notice"]) # Apache 2.0 @@ -156,7 +156,6 @@ cc_library( "ir_emitter_unnested.h", ], deps = [ - ":backend_configs", ":buffer_allocations", ":cudnn_conv_runner", ":elemental_ir_emitter", @@ -164,8 +163,10 @@ 
cc_library( ":gpu_executable", ":hlo_to_ir_bindings", ":ir_emission_utils", + ":nccl_all_reduce_thunk", ":parallel_loop_emitter", ":partition_assignment", + ":thunk", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -179,6 +180,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", @@ -287,9 +289,44 @@ cc_library( ], ) +cc_library( + name = "thunk", + srcs = ["thunk.cc"], + hdrs = ["thunk.h"], + deps = [ + ":buffer_allocations", + ":hlo_execution_profiler", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +tf_cuda_library( + name = "nccl_all_reduce_thunk", + srcs = ["nccl_all_reduce_thunk.cc"], + hdrs = ["nccl_all_reduce_thunk.h"], + deps = [ + ":buffer_allocations", + ":hlo_execution_profiler", + ":thunk", + "@com_google_absl//absl/synchronization", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor/cuda:cuda_activation", + "//tensorflow/stream_executor/cuda:cuda_gpu_executor", + ] + if_cuda([ + "@local_config_nccl//:nccl", + ]), +) + cc_library( name = "gpu_executable", srcs = [ + "cholesky_thunk.cc", "conditional_thunk.cc", "convolution_thunk.cc", "copy_thunk.cc", @@ -303,12 +340,13 @@ cc_library( "memset_thunk.cc", "outfeed_thunk.cc", "sequential_thunk.cc", - "thunk.cc", "thunk_schedule.cc", + "triangular_solve_thunk.cc", "tuple_thunk.cc", "while_thunk.cc", ], hdrs = [ + 
"cholesky_thunk.h", "conditional_thunk.h", "convolution_thunk.h", "copy_thunk.h", @@ -322,20 +360,23 @@ cc_library( "memset_thunk.h", "outfeed_thunk.h", "sequential_thunk.h", - "thunk.h", "thunk_schedule.h", + "triangular_solve_thunk.h", "tuple_thunk.h", "while_thunk.h", ], deps = [ ":buffer_allocations", ":cudnn_conv_runner", + ":cusolver_context", ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", + ":nccl_all_reduce_thunk", # fixdeps: keep ":outfeed_manager", ":partition_assignment", ":stream_assignment", + ":thunk", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_tree", @@ -364,6 +405,10 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "//tensorflow/stream_executor:blas", + "//tensorflow/stream_executor:device_memory", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -399,16 +444,21 @@ cc_library( ":backend_configs", ":buffer_comparator", ":cudnn_conv_runner", + ":gpu_autotuning_proto", ":gpu_executable", ":ir_emission_utils", + ":scratch_allocator", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:autotuning_proto_cc", "//tensorflow/core:lib", + "//tensorflow/core:logger", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/util/proto:proto_utils", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", @@ -416,6 +466,18 @@ 
cc_library( ], ) +cc_library( + name = "scratch_allocator", + srcs = ["scratch_allocator.cc"], + hdrs = ["scratch_allocator.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + cc_library( name = "cudnn_conv_runner", srcs = ["cudnn_conv_runner.cc"], @@ -472,6 +534,43 @@ tf_cc_test( ], ) +cc_library( + name = "cusolver_context", + srcs = ["cusolver_context.cc"], + hdrs = ["cusolver_context.h"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:blas", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cusolver", + ], +) + +cc_library( + name = "cusolver_rewriter", + srcs = ["cusolver_rewriter.cc"], + hdrs = ["cusolver_rewriter.h"], + deps = [ + ":cusolver_context", + ":ir_emission_utils", + ":scratch_allocator", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:blas", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "instruction_fusion", srcs = ["instruction_fusion.cc"], @@ -484,6 +583,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -555,6 +655,44 @@ cc_library( ], ) +cc_library( + name = "gpu_sanitize_constant_names", + srcs 
= ["gpu_sanitize_constant_names.cc"], + hdrs = ["gpu_sanitize_constant_names.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "gpu_sanitize_constant_names_test", + srcs = ["gpu_sanitize_constant_names_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_sanitize_constant_names", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "fusion_merger", srcs = ["fusion_merger.cc"], @@ -567,6 +705,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", @@ -640,7 +779,8 @@ cc_library( srcs = ["gpu_transfer_manager.cc"], hdrs = ["gpu_transfer_manager.h"], deps = [ - ":gpu_compiler", + ":infeed_manager", + ":nvptx_compiler", ":outfeed_manager", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -653,7 +793,6 @@ cc_library( 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", @@ -663,15 +802,17 @@ cc_library( ) cc_library( - name = "gpu_compiler", + name = "nvptx_compiler", srcs = ["nvptx_compiler.cc"], hdrs = ["nvptx_compiler.h"], deps = [ + ":cudnn_batchnorm_rewriter", ":cudnn_conv_algorithm_picker", ":cudnn_conv_pad_for_tensor_cores", ":cudnn_conv_padding_legalization", ":cudnn_conv_rewriter", ":cudnn_fused_conv_rewriter", + ":cusolver_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", @@ -679,6 +820,7 @@ cc_library( ":gpu_hlo_schedule", ":gpu_hlo_support_checker", ":gpu_layout_assignment", + ":gpu_sanitize_constant_names", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", @@ -700,6 +842,7 @@ cc_library( "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", @@ -719,12 +862,13 @@ cc_library( "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:sort_simplifier", + "//tensorflow/compiler/xla/service:stable_sort_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_simplifier", + "//tensorflow/compiler/xla/service:while_loop_trip_count_annotator", 
"//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", - "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:cuda_libdevice_path", @@ -960,6 +1104,7 @@ cc_library( hdrs = ["gpu_fusible.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", ], ) @@ -998,8 +1143,8 @@ tf_cc_test( srcs = ["cudnn_fused_conv_rewriter_test.cc"], tags = tf_cuda_tests_tags(), deps = [ + ":ir_emission_utils", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", @@ -1042,3 +1187,13 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) + +xla_proto_library( + name = "gpu_autotuning_proto", + srcs = ["gpu_autotuning.proto"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:autotuning_proto_cc", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..7daef16cb62338cfa5b027136ecd4262288eec8d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc @@ -0,0 +1,119 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/blas.h" +#include "tensorflow/stream_executor/device_memory.h" + +namespace xla { +namespace gpu { + +CholeskyThunk::CholeskyThunk(const CholeskyOptions& options, + BufferAllocation::Slice a_buffer, + BufferAllocation::Slice workspace_buffer, + BufferAllocation::Slice info_buffer, + PrimitiveType type, int64 batch_size, int64 n, + const HloInstruction* hlo) + : Thunk(Kind::kCholesky, hlo), + uplo_(options.lower() ? 
se::blas::UpperLower::kLower + : se::blas::UpperLower::kUpper), + a_buffer_(a_buffer), + workspace_buffer_(workspace_buffer), + info_buffer_(info_buffer), + type_(type), + batch_size_(batch_size), + a_batch_stride_(n * n * + ShapeUtil::ByteSizeOfPrimitiveType( + hlo->operand(0)->shape().element_type())), + n_(n) {} + +Status CholeskyThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(3) << "type=" << PrimitiveType_Name(type_) + << " uplo=" << se::blas::UpperLowerString(uplo_) + << " batch_size=" << batch_size_ << " n=" << n_ + << " a=" << a_buffer_.ToString() + << " workspace=" << workspace_buffer_.ToString() + << " info=" << info_buffer_.ToString(); + + CusolverContext* context; + { + tensorflow::mutex_lock lock(mu_); + auto result = contexts_.emplace(stream, CusolverContext()); + if (result.second) { + TF_ASSIGN_OR_RETURN(result.first->second, + CusolverContext::Create(stream)); + } + context = &result.first->second; + } + + char* a_base = static_cast( + buffer_allocations.GetDeviceAddress(a_buffer_).opaque()); + int* info_base = static_cast( + buffer_allocations.GetDeviceAddress(info_buffer_).opaque()); + se::DeviceMemoryBase workspace_data = + buffer_allocations.GetDeviceAddress(workspace_buffer_); + for (int64 i = 0; i < batch_size_; ++i) { + se::DeviceMemoryBase a_data = + se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_); + se::DeviceMemory info_data( + se::DeviceMemoryBase(info_base + i, sizeof(int))); + switch (type_) { + case F32: { + TF_RETURN_IF_ERROR( + context->Potrf(uplo_, n_, se::DeviceMemory(a_data), n_, + info_data, se::DeviceMemory(workspace_data))); + break; + } + case F64: { + TF_RETURN_IF_ERROR(context->Potrf( + uplo_, n_, se::DeviceMemory(a_data), n_, info_data, + se::DeviceMemory(workspace_data))); + break; + } + case C64: { + TF_RETURN_IF_ERROR(context->Potrf( + uplo_, n_, se::DeviceMemory>(a_data), n_, + info_data, 
se::DeviceMemory>(workspace_data))); + break; + } + case C128: { + TF_RETURN_IF_ERROR(context->Potrf( + uplo_, n_, se::DeviceMemory>(a_data), n_, + info_data, se::DeviceMemory>(workspace_data))); + break; + } + default: + return InvalidArgument("Invalid type for cholesky %s", + PrimitiveType_Name(type_)); + } + } + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..cde245a7e8bc0909059d4643cae3de138bddcdec --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_ + +#include "absl/base/thread_annotations.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/blas.h" + +namespace xla { +namespace gpu { + +// This class stores everything that StreamExecutor needs to launch a Cholesky +// decomposition (LAPACK potrf). It is generated by IrEmitter. +// +// Thread-compatible. 
+class CholeskyThunk : public Thunk { + public: + static StatusOr ScratchBufferSize(int64 n); + CholeskyThunk(const CholeskyOptions& options, + BufferAllocation::Slice a_buffer, + BufferAllocation::Slice workspace_buffer, + BufferAllocation::Slice info_buffer, + PrimitiveType type, + int64 batch_size, int64 n, const HloInstruction* hlo); + + CholeskyThunk(const CholeskyThunk&) = delete; + CholeskyThunk& operator=(const CholeskyThunk&) = delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream, + HloExecutionProfiler* profiler) override; + + private: + se::blas::UpperLower uplo_; + + const BufferAllocation::Slice a_buffer_; + const BufferAllocation::Slice workspace_buffer_; + const BufferAllocation::Slice info_buffer_; + + const PrimitiveType type_; + const int64 batch_size_; + const int64 a_batch_stride_; + const int64 n_; + + tensorflow::mutex mu_; + absl::flat_hash_map contexts_ GUARDED_BY(mu_); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 9ed523998bf07567133fdac0e40b12b8ce4ea3b0..ea6392498264f25d53bec2309bfdf7bdcf6a2a2e 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -24,25 +24,35 @@ namespace xla { namespace gpu { ConditionalThunk::ConditionalThunk( - const BufferAllocation::Slice& predicate_buffer_index, - const BufferAllocation::Slice& true_operand_buffer_index, - const BufferAllocation::Slice& false_operand_buffer_index, - ThunkSequence true_thunk_sequence, ThunkSequence false_thunk_sequence, + const BufferAllocation::Slice& branch_index_buffer_index, + absl::Span branch_operand_buffer_indexes, + std::vector branch_thunk_sequences, const HloInstruction* hlo) : Thunk(Kind::kConditional, hlo), - 
predicate_buffer_index_(predicate_buffer_index), - true_operand_buffer_index_(true_operand_buffer_index), - false_operand_buffer_index_(false_operand_buffer_index), - // Pass nullptr as the HloInstruction* to the true_thunk_ and false_thunk_ - // constructors because these SequentialThunks are logically "part of" - // this ConditionalThunk, and shouldn't be profiled separately from it. - true_thunk_(std::move(true_thunk_sequence), nullptr), - false_thunk_(std::move(false_thunk_sequence), nullptr) {} + branch_index_is_bool_(hlo->operand(0)->shape().element_type() == PRED), + branch_index_buffer_index_(branch_index_buffer_index), + branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(), + branch_operand_buffer_indexes.end()) { + // Pass nullptr as the HloInstruction* to the branch_thunks_ + // constructors because these SequentialThunks are logically "part of" + // this ConditionalThunk, and shouldn't be profiled separately from it. + branch_thunks_.reserve(branch_thunk_sequences.size()); + for (auto& branch_thunk_sequence : branch_thunk_sequences) { + branch_thunks_.emplace_back( + new SequentialThunk(std::move(branch_thunk_sequence), nullptr)); + } +} Status ConditionalThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { - TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable, executor)); - TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable, executor)); + if (branch_index_is_bool_) { + TF_RET_CHECK(branch_thunks_.size() == 2); + } else { + TF_RET_CHECK(!branch_thunks_.empty()); + } + for (auto& branch_thunk : branch_thunks_) { + TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor)); + } return Status::OK(); } @@ -51,31 +61,38 @@ Status ConditionalThunk::ExecuteOnStream( HloExecutionProfiler* profiler) { auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); // Copy the predicate value from device. 
- bool predicate; - se::DeviceMemoryBase predicate_address = - buffer_allocations.GetDeviceAddress(predicate_buffer_index_); - stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool)); + int32 branch_index = -1; + bool pred = false; + se::DeviceMemoryBase branch_index_address = + buffer_allocations.GetDeviceAddress(branch_index_buffer_index_); + if (branch_index_is_bool_) { + stream->ThenMemcpy(&pred, branch_index_address, sizeof(bool)); + } else { + stream->ThenMemcpy(&branch_index, branch_index_address, sizeof(int32)); + } Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("Failed to retrieve predicate value on stream %p: %s.", - stream, block_status.error_message()); + return InternalError( + "Failed to retrieve branch_index value on stream %p: %s.", stream, + block_status.error_message()); } - - // Execute the true or the false computation depending on the value of the - // predicate. - if (predicate) { - profiler->StartHloComputation(); - TF_RETURN_IF_ERROR( - true_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); - profiler->FinishHloComputation(hlo_instruction()->true_computation()); + if (branch_index_is_bool_) { + branch_index = pred ? 0 : 1; } else { - profiler->StartHloComputation(); - TF_RETURN_IF_ERROR( - false_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); - profiler->FinishHloComputation(hlo_instruction()->false_computation()); + // Handle default scenario for branch_index not in [0, num_branches). + if (branch_index < 0 || branch_index >= hlo_instruction()->branch_count()) { + branch_index = hlo_instruction()->branch_count() - 1; + } } + // Execute the branch computation corresponding to the value of branch_index. 
+ profiler->StartHloComputation(); + TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream( + buffer_allocations, stream, profiler)); + profiler->FinishHloComputation( + hlo_instruction()->branch_computation(branch_index)); + return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index aef24342c9fe182eb54b1c2beff840a76e7b8115..c0093ca6397e636bee953ddf0af8c48caaaadae0 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ +#include +#include + +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -38,12 +42,11 @@ namespace gpu { // false computation share the same allocation. 
class ConditionalThunk : public Thunk { public: - ConditionalThunk(const BufferAllocation::Slice& predicate_buffer_index, - const BufferAllocation::Slice& true_operand_buffer_index, - const BufferAllocation::Slice& false_operand_buffer_index, - ThunkSequence true_thunk_sequence, - ThunkSequence false_thunk_sequence, - const HloInstruction* hlo); + ConditionalThunk( + const BufferAllocation::Slice& branch_index_buffer_index, + absl::Span branch_operand_buffer_indexes, + std::vector branch_thunk_sequences, + const HloInstruction* hlo); ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; @@ -55,11 +58,10 @@ class ConditionalThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - BufferAllocation::Slice predicate_buffer_index_; - BufferAllocation::Slice true_operand_buffer_index_; - BufferAllocation::Slice false_operand_buffer_index_; - SequentialThunk true_thunk_; - SequentialThunk false_thunk_; + const bool branch_index_is_bool_; + BufferAllocation::Slice branch_index_buffer_index_; + std::vector branch_operand_buffer_indexes_; + std::vector> branch_thunks_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index 60289506524759580dbb9b82147c78c4ce1cb25e..2cceb0422d08ff7951308b0727941f5437785447 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -188,13 +188,8 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { computation_->AddInstruction(HloInstruction::CreateBroadcast( batch_norm->operand(3)->shape(), epsilon, {})))); HloInstruction* inverse_stddev = - computation_->AddInstruction(HloInstruction::CreateBinary( - var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon, - computation_->AddInstruction(HloInstruction::CreateBroadcast( - 
var_plus_epsilon->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(-.5))), - {})))); + computation_->AddInstruction(HloInstruction::CreateUnary( + var_plus_epsilon->shape(), HloOpcode::kRsqrt, var_plus_epsilon)); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 309b0aca64954e64509d731dce28ce9d8da4ee43..02eb191cf58a2e7723de2c37bdcf3875b350be9a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -14,17 +14,23 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" +#include "google/protobuf/any.pb.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/logger.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/proto/proto_utils.h" namespace xla { namespace gpu { @@ -32,49 +38,8 @@ namespace { using absl::optional; using se::DeviceMemoryBase; -using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmDesc; 
- -class ScratchAllocator : public se::ScratchAllocator { - public: - ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) - : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - - int64 GetMemoryLimitInBytes(se::Stream* stream) override { - return 1LL << 32; // 4GB. TODO(jlebar): Tune this? - } - int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - - StatusOr> AllocateBytes(se::Stream* stream, - int64 byte_size) override; - - private: - const int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; - int64 total_allocated_bytes_ = 0; -}; - -StatusOr> ScratchAllocator::AllocateBytes( - se::Stream* stream, int64 byte_size) { - CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { - return se::port::Status( - se::port::error::RESOURCE_EXHAUSTED, - absl::StrFormat( - "Allocating %d bytes exceeds the memory limit of %d bytes.", - byte_size, GetMemoryLimitInBytes(stream))); - } - - TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false)); - total_allocated_bytes_ += byte_size; - - se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); - allocated_buffers_.push_back(std::move(allocated_buffer)); - return se::DeviceMemory(buffer_addr); -} +using tensorflow::AutotuneResult; std::vector GetAlgorithms(CudnnConvKind kind, se::StreamExecutor* stream_exec) { @@ -132,6 +97,31 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { return tensorflow::mutex_lock{it->second}; } +tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { + tensorflow::CudnnVersion cudnn_version; + if (auto* dnn = stream_executor->AsDnn()) { + StatusOr version_or = dnn->GetVersion(); + if (version_or.ok()) { + const auto& version = version_or.ValueOrDie(); + 
cudnn_version.set_major(version.major_version()); + cudnn_version.set_minor(version.minor_version()); + cudnn_version.set_patch(version.patch()); + } + } + return cudnn_version; +} + +tensorflow::ComputeCapability GetComputeCapability( + se::StreamExecutor* stream_executor) { + tensorflow::ComputeCapability cc; + int cc_major, cc_minor; + stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + cc.set_major(cc_major); + cc.set_minor(cc_minor); + return cc; +} + } // anonymous namespace // We could have caching here so that we don't redo this work for two identical @@ -145,8 +135,7 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -StatusOr -CudnnConvAlgorithmPicker::PickBestAlgorithm( +StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( const HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. @@ -233,8 +222,6 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm( &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); initialize_buffer(result_buffer); - se::dnn::ProfileResult best_result; - int64 best_result_bytes_used = 0; TF_ASSIGN_OR_RETURN(auto backend_config, instr->backend_config()); @@ -244,6 +231,7 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm( // this algorithm considered correct, though. 
optional first_algorithm; TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); + std::vector profile_results; for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; @@ -254,73 +242,111 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm( RunConvOptions options; options.profile_result = &profile_result; options.algo_override = alg; - bool launch_ok = + Status launch_status = RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, - &scratch_allocator, &stream, options) - .ok(); - - if (launch_ok && profile_result.is_valid()) { - const bool crash_on_checking_failure = - instr->GetModule() - ->config() - .debug_options() - .xla_gpu_crash_on_verification_failures(); - if (comparator.has_value()) { - StatusOr result = comparator->CompareEqual( - se::DeviceMemory(result_buffer)); - if (!result.ok()) { - LOG(ERROR) << "Unable to compare " - << AlgorithmToString(*first_algorithm) << " against " - << AlgorithmToString(alg) << " for " << instr->ToString() - << ": " << result.status(); - CHECK(!crash_on_checking_failure); - } else if (!result.ValueOrDie()) { - LOG(ERROR) << "Results mismatch between different convolution " - "algorithms. This is likely a bug in convolution, or " - "an excessive loss of precision in convolution. 
" - << instr->ToString() << " for " - << AlgorithmToString(*first_algorithm) << " vs " - << AlgorithmToString(alg); - CHECK(!crash_on_checking_failure); - } - } else if (cross_check_enabled) { - auto comp = F16BufferComparator::Create( - se::DeviceMemory(result_buffer), compiler_, allocator, - &stream); - if (comp.ok()) { - comparator.emplace(comp.ConsumeValueOrDie()); - first_algorithm.emplace(alg); - } else { - LOG(ERROR) << "Fail to initialize buffer comparator: " - << comp.status() << ", instruction: " << instr->ToString(); - CHECK(!crash_on_checking_failure); - } + &scratch_allocator, &stream, options); + + if (!launch_status.ok()) { + continue; + } + + if (!profile_result.is_valid()) { + continue; + } + + profile_results.emplace_back(); + AutotuneResult& result = profile_results.back(); + result.mutable_conv()->set_algorithm(alg.algo_id()); + result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled()); + + int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); + result.mutable_success()->set_scratch_bytes(scratch_bytes_used); + *result.mutable_success()->mutable_run_time() = + tensorflow::proto_utils::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); + + const bool crash_on_checking_failure = + instr->GetModule() + ->config() + .debug_options() + .xla_gpu_crash_on_verification_failures(); + + if (comparator.has_value()) { + StatusOr compare_result = comparator->CompareEqual( + se::DeviceMemory(result_buffer)); + if (!compare_result.ok()) { + LOG(ERROR) << "Unable to compare " + << AlgorithmToString(*first_algorithm) << " against " + << AlgorithmToString(alg) << " for " << instr->ToString() + << ": " << compare_result.status(); + CHECK(!crash_on_checking_failure); + } else if (!compare_result.ValueOrDie()) { + LOG(ERROR) << "Results mismatch between different convolution " + "algorithms. This is likely a bug in convolution, or " + "an excessive loss of precision in convolution. 
" + << instr->ToString() << " for " + << AlgorithmToString(*first_algorithm) << " vs " + << AlgorithmToString(alg); + CHECK(!crash_on_checking_failure); + auto* failure = result.mutable_reference_conv(); + failure->set_algorithm(first_algorithm->algo_id()); + failure->set_tensor_ops_enabled(first_algorithm->tensor_ops_enabled()); } - int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); - VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) - << " succeeded, taking " << profile_result.elapsed_time_in_ms() - << "ms and using " << NumBytesToString(scratch_bytes_used) - << " of scratch (Best result: " - << best_result.elapsed_time_in_ms() << "ms, " - << NumBytesToString(best_result_bytes_used) << " of scratch)"; - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - best_result_bytes_used = scratch_bytes_used; + } else if (cross_check_enabled) { + auto comp = F16BufferComparator::Create( + se::DeviceMemory(result_buffer), compiler_, allocator, + &stream); + if (comp.ok()) { + comparator.emplace(comp.ConsumeValueOrDie()); + first_algorithm.emplace(alg); + } else { + LOG(ERROR) << "Fail to initialize buffer comparator: " << comp.status() + << ", instruction: " << instr->ToString(); + CHECK(!crash_on_checking_failure); } - } else { - VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed."; } } - if (best_result.is_valid()) { - VLOG(2) << "Best algorithm for " << instr->ToString() << ": " - << AlgorithmToString(best_result.algorithm()) << ", takes " - << best_result.elapsed_time_in_ms() << "ms, and uses " - << best_result_bytes_used << "B of scratch memory."; - return AutotuneResult{best_result.algorithm().algo_id(), - best_result.algorithm().tensor_ops_enabled(), - best_result_bytes_used, - absl::Milliseconds(best_result.elapsed_time_in_ms())}; + + // Log the autotuning result. 
+ { + tensorflow::AutotuningLog log; + { + ConvInstructionLog instr_log; + *instr_log.mutable_instruction() = instr->ToProto(); + for (const auto* op : instr->operands()) { + *instr_log.add_operand_shapes() = op->shape().ToProto(); + } + log.mutable_instr()->PackFrom(instr_log); + } + for (const auto& profile : profile_results) { + *log.add_results() = profile; + } + *log.mutable_compute_capability() = GetComputeCapability(stream_exec_); + *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_); + VLOG(2) << "Autotuning result:\n" << log.DebugString(); + tensorflow::Logger::Singleton()->LogProto(log); + } + + auto* profile_results_end = profile_results.data() + profile_results.size(); + + const AutotuneResult* best_result = std::min_element( + profile_results.data(), profile_results_end, + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + // The successful one should have a smaller key, since we are doing + // min_element. If they are both unsuccessful, keep the earlier one in + // the vector by comparing pointers. 
+ return std::make_tuple(!lhs.has_success(), + tensorflow::proto_utils::FromDurationProto( + lhs.success().run_time()), + &lhs) < + std::make_tuple(!rhs.has_success(), + tensorflow::proto_utils::FromDurationProto( + rhs.success().run_time()), + &rhs); + }); + + if (best_result != profile_results_end && best_result->has_success()) { + return *best_result; } return InternalError( @@ -341,22 +367,23 @@ StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( } auto best_algo = std::move(best_algo_or).ValueOrDie(); - VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm - << " and " << NumBytesToString(best_algo.scratch_bytes) + VLOG(1) << "Setting cudnn conv to use algorithm " + << best_algo.conv().algorithm() << " and " + << NumBytesToString(best_algo.success().scratch_bytes()) << " of scratch memory: " << instr->ToString() - << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled; + << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled(); // Replace instr with a new CustomCall which has the correct algorithm, and // whose output shape has the appropriate amount of scratch memory. 
HloComputation* computation = instr->parent(); Shape new_call_shape = ShapeUtil::MakeTupleShape( {instr->shape().tuple_shapes(0), - ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})}); + ShapeUtil::MakeShape(U8, {best_algo.success().scratch_bytes()})}); TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, instr->backend_config()); - backend_config.set_algorithm(best_algo.algorithm); - backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled); + backend_config.set_algorithm(best_algo.conv().algorithm()); + backend_config.set_tensor_ops_enabled(best_algo.conv().tensor_ops_enabled()); HloInstruction* new_call = computation->AddInstruction( instr->CloneWithNewOperands(new_call_shape, instr->operands())); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index 4991db0948589e479a202f4082d96df275f6e088..6ab9c7a9ecec98c9a70ea0578cdb3eb4f1d3c12d 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/protobuf/autotuning.pb.h" namespace xla { namespace gpu { @@ -47,16 +48,9 @@ class CudnnConvAlgorithmPicker : public HloModulePass { StatusOr Run(HloModule* module) override; private: - struct AutotuneResult { - int64 algorithm; - bool tensor_ops_enabled; - int64 scratch_bytes; - absl::Duration runtime; - }; - StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr PickBestAlgorithm( + StatusOr PickBestAlgorithm( const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc new file mode 100644 index 0000000000000000000000000000000000000000..923b7bc452870f47505711e8abd4ce236be7815a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc @@ -0,0 +1,159 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h" + +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { + +namespace { + +// Type traits to get CUDA complex types from std::complex. +template +struct CUDAComplexT { + typedef T type; +}; +template <> +struct CUDAComplexT> { + typedef cuComplex type; +}; +template <> +struct CUDAComplexT> { + typedef cuDoubleComplex type; +}; + +template +inline typename CUDAComplexT::type* ToDevicePointer(se::DeviceMemory p) { + return static_cast::type*>(p.opaque()); +} + +cublasFillMode_t CUDABlasUpperLower(se::blas::UpperLower uplo) { + switch (uplo) { + case se::blas::UpperLower::kUpper: + return CUBLAS_FILL_MODE_UPPER; + case se::blas::UpperLower::kLower: + return CUBLAS_FILL_MODE_LOWER; + default: + LOG(FATAL) << "Invalid value of blas::UpperLower."; + } +} + +// Converts a cuSolver status to a Status. +Status CusolverStatusToStatus(cusolverStatus_t status) { + switch (status) { + case CUSOLVER_STATUS_SUCCESS: + return Status::OK(); + case CUSOLVER_STATUS_NOT_INITIALIZED: + return FailedPrecondition("cuSolver has not been initialized"); + case CUSOLVER_STATUS_ALLOC_FAILED: + return ResourceExhausted("cuSolver allocation failed"); + case CUSOLVER_STATUS_INVALID_VALUE: + return InvalidArgument("cuSolver invalid value error"); + case CUSOLVER_STATUS_ARCH_MISMATCH: + return FailedPrecondition("cuSolver architecture mismatch error"); + case CUSOLVER_STATUS_MAPPING_ERROR: + return Unknown("cuSolver mapping error"); + case CUSOLVER_STATUS_EXECUTION_FAILED: + return Unknown("cuSolver execution failed"); + case CUSOLVER_STATUS_INTERNAL_ERROR: + return Internal("cuSolver internal error"); + case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return Unimplemented("cuSolver matrix type not supported error"); + case CUSOLVER_STATUS_NOT_SUPPORTED: + return Unimplemented("cuSolver not supported error"); + case 
CUSOLVER_STATUS_ZERO_PIVOT: + return InvalidArgument("cuSolver zero pivot error"); + case CUSOLVER_STATUS_INVALID_LICENSE: + return FailedPrecondition("cuSolver invalid license error"); + default: + return Unknown("Unknown cuSolver error"); + } +} + +} // namespace + +StatusOr CusolverContext::Create(se::Stream* stream) { + cusolverDnHandle_t handle; + TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCreate(&handle))); + CusolverContext context(stream, handle); + + // StreamExecutor really should just expose the Cuda stream to clients... + const cudaStream_t* cuda_stream = + CHECK_NOTNULL(reinterpret_cast( + stream->implementation()->GpuStreamMemberHack())); + TF_RETURN_IF_ERROR( + CusolverStatusToStatus(cusolverDnSetStream(handle, *cuda_stream))); + + return std::move(context); +} + +CusolverContext::CusolverContext(se::Stream* stream, cusolverDnHandle_t handle) + : stream_(stream), handle_(handle) {} + +CusolverContext::CusolverContext(CusolverContext&& other) { + handle_ = other.handle_; + stream_ = other.stream_; + other.handle_ = nullptr; + other.stream_ = nullptr; +} + +CusolverContext& CusolverContext::operator=(CusolverContext&& other) { + std::swap(handle_, other.handle_); + std::swap(stream_, other.stream_); + return *this; +} + +CusolverContext::~CusolverContext() { + if (handle_) { + Status status = CusolverStatusToStatus(cusolverDnDestroy(handle_)); + if (!status.ok()) { + LOG(ERROR) << "cusolverDnDestroy failed: " << status; + } + } +} + +#define CALL_LAPACK_TYPES(m) \ + m(float, S) m(double, D) m(std::complex, C) m(std::complex, Z) + +#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method + +#define POTRF_BUFFER_SIZE_INSTANCE(T, type_prefix) \ + StatusOr CusolverContext::PotrfBufferSize( \ + se::blas::UpperLower uplo, int n, se::DeviceMemory A, int lda) { \ + int size = -1; \ + TF_RETURN_IF_ERROR(CusolverStatusToStatus(DN_SOLVER_FN( \ + potrf_bufferSize, type_prefix)(handle(), CUDABlasUpperLower(uplo), n, \ + ToDevicePointer(A), 
lda, &size))); \ + return size; \ + } + +CALL_LAPACK_TYPES(POTRF_BUFFER_SIZE_INSTANCE); + +#define POTRF_INSTANCE(T, type_prefix) \ + Status CusolverContext::Potrf( \ + se::blas::UpperLower uplo, int n, se::DeviceMemory A, int lda, \ + se::DeviceMemory lapack_info, se::DeviceMemory workspace) { \ + return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)( \ + handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda, \ + ToDevicePointer(workspace), workspace.ElementCount(), \ + ToDevicePointer(lapack_info))); \ + } + +CALL_LAPACK_TYPES(POTRF_INSTANCE); + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.h b/tensorflow/compiler/xla/service/gpu/cusolver_context.h new file mode 100644 index 0000000000000000000000000000000000000000..fdd89c3a8d599e2291b60abcd67e267a96d3ac8f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.h @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_ + +#include + +#include "cuda/include/cublas_v2.h" +#include "cuda/include/cusolverDn.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/blas.h" + +namespace xla { +namespace gpu { + +class CusolverContext { + public: + static StatusOr Create(se::Stream* stream); + CusolverContext() = default; + ~CusolverContext(); + + CusolverContext(const CusolverContext&) = delete; + CusolverContext(CusolverContext&&); + CusolverContext& operator=(const CusolverContext&) = delete; + CusolverContext& operator=(CusolverContext&&); + + se::Stream* stream() const { return stream_; } + cusolverDnHandle_t handle() const { return handle_; } + + // Computes the Cholesky factorization A = L * L^T for a single matrix. + // Returns Status::OK() if the kernel was launched successfully. See: + // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf + Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory dev_A, + int lda, se::DeviceMemory dev_lapack_info, + se::DeviceMemory workspace); + Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory dev_A, + int lda, se::DeviceMemory dev_lapack_info, + se::DeviceMemory workspace); + Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory> dev_A, int lda, + se::DeviceMemory dev_lapack_info, + se::DeviceMemory> workspace); + Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory> dev_A, int lda, + se::DeviceMemory dev_lapack_info, + se::DeviceMemory> workspace); + + // Returns the size of the `workspace` required by Potrf, in number of + // elements of size T. 
+ StatusOr PotrfBufferSize(se::blas::UpperLower uplo, int n, + se::DeviceMemory dev_A, int lda); + StatusOr PotrfBufferSize(se::blas::UpperLower uplo, int n, + se::DeviceMemory dev_A, int lda); + StatusOr PotrfBufferSize(se::blas::UpperLower uplo, int n, + se::DeviceMemory> dev_A, + int lda); + StatusOr PotrfBufferSize(se::blas::UpperLower uplo, int n, + se::DeviceMemory> dev_A, + int lda); + + private: + CusolverContext(se::Stream* stream, cusolverDnHandle_t handle); + + se::Stream* stream_ = nullptr; + cusolverDnHandle_t handle_ = nullptr; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..7861eb1ef04d4fa5ba5690ee388b77a3f354f88e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc @@ -0,0 +1,216 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/blas.h" + +namespace xla { +namespace gpu { + +namespace { + +void SetFortranLayout(Shape* shape) { + LayoutUtil::SetToDefaultLayout(shape); + int n = shape->mutable_layout()->minor_to_major_size(); + CHECK_GE(n, 2); + std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0), + shape->mutable_layout()->mutable_minor_to_major()->at(1)); +} + +StatusOr CreateCholesky(CusolverContext* context, + ScratchAllocator* allocator, + HloInstruction* operand, + const CholeskyOptions& options, + const OpMetadata& metadata) { + HloComputation* computation = operand->parent(); + + Shape a_shape = operand->shape(); + int ndim = a_shape.dimensions_size(); + CHECK_GE(ndim, 2); + int64 n = a_shape.dimensions(ndim - 1); + + int64 batch_size = std::accumulate(a_shape.dimensions().begin(), + a_shape.dimensions().end() - 2, int64{1}, + [](int64 a, int64 b) { return a * b; }); + + // Find the workspace size. + se::blas::UpperLower uplo = options.lower() ? 
se::blas::UpperLower::kLower + : se::blas::UpperLower::kUpper; + int64 workspace_size; // Number of elements of size a_shape.element_type() + switch (a_shape.element_type()) { + case F32: { + TF_ASSIGN_OR_RETURN(auto a, + allocator->Allocate(context->stream(), n * n)); + TF_ASSIGN_OR_RETURN(workspace_size, + context->PotrfBufferSize(uplo, n, a, n)); + break; + } + case F64: { + TF_ASSIGN_OR_RETURN( + auto a, allocator->Allocate(context->stream(), n * n)); + TF_ASSIGN_OR_RETURN(workspace_size, + context->PotrfBufferSize(uplo, n, a, n)); + break; + } + case C64: { + TF_ASSIGN_OR_RETURN(auto a, allocator->Allocate>( + context->stream(), n * n)); + TF_ASSIGN_OR_RETURN(workspace_size, + context->PotrfBufferSize(uplo, n, a, n)); + break; + } + case C128: { + TF_ASSIGN_OR_RETURN(auto a, allocator->Allocate>( + context->stream(), n * n)); + TF_ASSIGN_OR_RETURN(workspace_size, + context->PotrfBufferSize(uplo, n, a, n)); + break; + } + default: + return InvalidArgument("Invalid type for cholesky decomposition: %s", + a_shape.ToString()); + } + + // TODO(phawkins): Ideally we would relax this constraint. What we actually + // want is that: + // a) the batch dimensions are major, in no particular order. + // b) the two minor dimensions are in fortran (column-major) order, + + SetFortranLayout(&a_shape); + + // This call returns a tuple of (cholesky_result, workspace, info) where: + // * cholesky_result is the result of the Cholesky decomposition, + // * workspace is temporary scratch memory used by cuSolver. + // * info contains the Potrf success/failure status. + // Currently we have no meaningful way to report an error, so we simply + // discard the success/failure information. Obviously this is suboptimal. 
+ Shape call_shape = ShapeUtil::MakeTupleShape( + {a_shape, + ShapeUtil::MakeShape(operand->shape().element_type(), {workspace_size}), + ShapeUtil::MakeShape(S32, {batch_size})}); + + HloInstruction* custom_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + call_shape, {operand}, kCusolverCholeskyCallTarget, {a_shape})); + custom_call->set_metadata(metadata); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(options)); + return custom_call; +} + +} // namespace + +// Tries to rewrite a single Cholesky instruction into a call to cuSolver. +StatusOr RunOnInstruction(CusolverContext* context, + ScratchAllocator* allocator, + HloInstruction* instruction) { + if (instruction->opcode() != HloOpcode::kCholesky) { + return false; + } + + TF_ASSIGN_OR_RETURN( + HloInstruction * custom_call, + CreateCholesky(context, allocator, instruction->mutable_operand(0), + instruction->cholesky_options(), instruction->metadata())); + + VLOG(1) << "Replacing " << instruction->ToString() << " with " + << custom_call->ToString(); + + // The CustomCall returns a tuple (cholesky_result, workspace, info). + // Extract the Cholesky result and replace `instruction` with it. + TF_RETURN_IF_ERROR(instruction->parent()->ReplaceWithNewInstruction( + instruction, HloInstruction::CreateGetTupleElement(instruction->shape(), + custom_call, 0))); + return true; +} + +// Rewrites the Cholesky instructions in the given computation into calls to +// cuSolver. Returns true if it made any changes. +StatusOr CusolverRewriter::RunOnComputation(HloComputation* computation) { + std::vector cusolver_calls; + for (auto* hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kCholesky) { + cusolver_calls.push_back(hlo); + } + } + + if (cusolver_calls.empty()) { + return false; + } + + // Create a stream for us to do our work on. We don't really need to do any + // work, just allocate memory, but that's the cuSolver API.
+ se::Stream stream{stream_exec_}; + stream.Init(); + const auto device_ordinal = stream_exec_->device_ordinal(); + + // allocator either points to this->allocator_ or, if that's null, to a + // StreamExecutorMemoryAllocator for stream_exec_. + DeviceMemoryAllocator* allocator; + absl::optional se_allocator; + if (allocator_ != nullptr) { + allocator = allocator_; + } else { + se_allocator.emplace(stream_exec_->platform(), + absl::Span({stream_exec_})); + allocator = &*se_allocator; + } + ScratchAllocator scratch_allocator(device_ordinal, allocator); + + TF_ASSIGN_OR_RETURN(CusolverContext context, + CusolverContext::Create(&stream)); + + bool changed = false; + for (HloInstruction* instruction : cusolver_calls) { + TF_ASSIGN_OR_RETURN( + bool result, + RunOnInstruction(&context, &scratch_allocator, instruction)); + changed |= result; + } + return changed; +} + +CusolverRewriter::CusolverRewriter(se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator) + : stream_exec_(stream_exec), allocator_(allocator) {} + +StatusOr CusolverRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..c82233188f7de1e188876f13465f7face76a0a8b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// Rewrites Cholesky calls into CustomCall HLOs that call into cuSolver. 
+class CusolverRewriter : public HloModulePass { + public: + CusolverRewriter(se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator); + absl::string_view name() const override { return "cusolver-rewriter"; } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation* computation); + + se::StreamExecutor* stream_exec_; // never null + DeviceMemoryAllocator* allocator_; // may be null +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 2ab754a471070d5f90a3eaebd0600ff180d2fe5d..551f7d773aad0356d9e31c4f952908592936f2b0 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "llvm/IR/DerivedTypes.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" // IWYU pragma: no_include "llvm/IR/Attributes.gen.inc" @@ -191,39 +192,6 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); - llvm::Type* llvm_ty = lhs_value->getType(); - - auto make_sqrt = [&, this]() -> StatusOr { - // NVPTX has four relevant square root instructions: - // sqrt.approx{.ftz}.f32 - // sqrt.rn{.ftz}.f32 - // sqrt.rn.f64 - // rsqrt.approx.f64 - // We rely on LLVM's NVPTX backend to pick the right one based on our - // fast-math options. (If fast-math is enabled, llvm may compute the 64-bit - // sqrt from the rsqrt approximation.) 
- return EmitLlvmIntrinsicMathCall("llvm.sqrt", {lhs_value}, {lhs_input_type}, - output_type); - }; - - const HloInstruction* rhs = op->operand(1); - if (IsFPLiteralWithValue(rhs, .5)) { - VLOG(10) << "emitting pow(A, .5) as sqrt(A): " << op->ToString(); - return make_sqrt(); - } - - if (IsFPLiteralWithValue(rhs, -.5)) { - VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString(); - // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX - // rsqrt.approx instruction. - // - // TODO(jlebar): Does this happen with fastmath disabled? If not, should - // we force-enable it? - TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); - return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); - } - - VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString(); return EmitLibdeviceMathCall("__nv_pow", {lhs_value, rhs_value}, {lhs_input_type, rhs_input_type}, output_type); } @@ -270,6 +238,16 @@ StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, prim_type); } +StatusOr GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type, + llvm::Value* value) { + return EmitLibdeviceMathCall("__nv_sqrt", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value) { + return EmitLibdeviceMathCall("__nv_rsqrt", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) { @@ -293,6 +271,16 @@ StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, return FPCast(fast_tanh, value->getType()); } +StatusOr GpuElementalIrEmitter::EmitRoundNearestAfz( + PrimitiveType prim_type, llvm::Value* value) { + // Use libdevice __nv_round instead of llvm.round. This is to workaround a + // bug in the PTX backend, which implements llvm.round with PTX cvt.rni. 
+ // When the llvm.round is fixed, we may still want to use __nv_round here as + // expanding the non-trivial implementation early while inlining allows better + // optimizations. + return EmitLibdeviceMathCall("__nv_round", {value}, {prim_type}, prim_type); +} + llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( const string& callee_name, absl::Span operands, absl::Span input_types, PrimitiveType output_type, @@ -308,9 +296,11 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( false); // No variadic arguments. // Declares the callee if it is not declared already. - llvm::Function* callee = llvm::cast( - b_->GetInsertBlock()->getModule()->getOrInsertFunction( - llvm_ir::AsStringRef(callee_name), callee_type)); + llvm::Function* callee = llvm::dyn_cast( + b_->GetInsertBlock() + ->getModule() + ->getOrInsertFunction(callee_name, callee_type) + .getCallee()); for (auto attribute : attributes) { callee->addFnAttr(attribute); @@ -395,12 +385,12 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); - IrArray::Index input_index(index_type, index.size()); + std::vector input_multi_index(index.size()); llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); - input_index[i] = NSWSub( + input_multi_index[i] = NSWSub( NSWAdd(stridden_index, NSWMul(window_index[i], index_typed_const( @@ -409,24 +399,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( // We need to verify that we are not in the dilated base area. llvm::Value* dilation_condition = ICmpEQ( - SRem(input_index[i], + SRem(input_multi_index[i], index_typed_const(window.dimensions(i).base_dilation())), index_typed_const(0)); in_bounds = And(in_bounds, dilation_condition); // Apply base dilation to the index. 
- input_index[i] = - SDiv(input_index[i], + input_multi_index[i] = + SDiv(input_multi_index[i], index_typed_const(window.dimensions(i).base_dilation())); - // We must check whether 0 ≤ input_index[i] < bound, as otherwise - // we are in the pad and so can skip the computation. This + // We must check whether 0 ≤ input_multi_index[i] < bound, as + // otherwise we are in the pad and so can skip the computation. This // comparison is equivalent to the unsigned comparison - // input_index[i] < bound, as a negative value wraps to a large + // input_multi_index[i] < bound, as a negative value wraps to a large // positive value. in_bounds = And(in_bounds, - ICmpULT(input_index[i], + ICmpULT(input_multi_index[i], index_typed_const(operand->shape().dimensions(i)))); } @@ -435,6 +425,8 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(if_data.true_block, b_); // We are not in pad, so do the computation. + IrArray::Index input_index(input_multi_index, operand->shape(), + index_type); TF_ASSIGN_OR_RETURN(llvm::Value * input_value, operand_to_generator.at(operand)(input_index)); TF_ASSIGN_OR_RETURN( @@ -446,7 +438,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( return Load(accum_ptr); }; case HloOpcode::kReduce: - // TODO(b/112040122): This should be supported. + // TODO(b/118332391): This should be supported. 
CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce"; return [=, &operand_to_generator]( const IrArray::Index& output_index) -> StatusOr { @@ -461,19 +453,22 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( b()->CreateStore(init_value, accum_ptr); llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); - IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions( - operand->shape(), hlo->dimensions(), "reduction_dim"); + std::vector input_multi_index = + loops.AddLoopsForShapeOnDimensions( + operand->shape(), hlo->dimensions(), "reduction_dim"); if (!ShapeUtil::IsScalar(hlo->shape())) { - // Here only input_index[hlo->dimensions()] are non-null, so we must - // set the rest. + // Here only input_multi_index[hlo->dimensions()] are non-null, so we + // must set the rest. size_t j = 0; - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = output_index[j++]; + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = output_index[j++]; } } CHECK_EQ(output_index.size(), j); } + llvm_ir::IrArray::Index input_index( + input_multi_index, hlo->operand(0)->shape(), index_type); SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index e8b56a39ce58b6aab35c1c977553c7ff7e753273..e9d08177ad979871890a32374657d8479c0cf669 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -76,6 +76,12 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitExpm1(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr EmitSqrt(PrimitiveType prim_type, + llvm::Value* value) override; + + StatusOr EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value) override; + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* 
lhs, llvm::Value* rhs) override; @@ -85,6 +91,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr EmitRoundNearestAfz(PrimitiveType prim_type, + llvm::Value* value) override; + llvm::Value* EmitThreadId() override; private: diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 91930eccdff94bb2fc85636f3a4b2d661c618d87..0649f42e54e552bded9d4f5a7f9c01c90b0e46fa 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -62,7 +63,7 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Iterate through 'instructions' accumulating byte sizes of each instruction // shape. For each 'instruction' in 'instructions', if all users of - // 'instruction' are Slice instructions, accumuates the byte sizes of each + // 'instruction' are Slice instructions, accumulates the byte sizes of each // Slice for a more accurate estimate of bytes read. double bytes = 0.0; for (auto& instruction : instructions) { @@ -95,27 +96,6 @@ double CalculateBytesReadByFusionInstruction(HloInstruction* fusion) { return bytes; } -// Returns the flops to bytes transferred ratio of instruction 'fusion'. -double CalculateFlopsToBytesRatio(HloInstruction* fusion) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - // Calculate total bytes transferred in/out. 
- double bytes = CalculateBytesReadByFusionInstruction(fusion); - // Add bytes written to root instructions buffer. - if (fusion->IsMultiOutputFusion()) { - for (auto& operand : fusion->fused_expression_root()->operands()) { - bytes += ShapeUtil::ByteSizeOf(operand->shape()); - } - } else { - bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); - } - // Calculate flops for all fused instructions. Use a null shape size function - // because we don't care about bytes accessed by the ops. - HloCostAnalysis analysis([](const Shape& shape) { return 0; }); - TF_CHECK_OK(fusion->fused_expression_root()->Accept(&analysis)); - // Return flops / bytes. - return bytes > 0.0 ? analysis.flop_count() / bytes : analysis.flop_count(); -} - // Returns bytes transferred by instruction 'fusion', including the bytes // that would be read by all users. double GetCurrentBytesTransferred(HloInstruction* fusion) { @@ -169,8 +149,8 @@ class FusionInstructionMerger { int num_fail_not_loop_fusion_ = 0; int num_fail_merge_all_users_ = 0; int num_fail_expensive_fused_instruction_ = 0; - int num_fail_flops_to_byte_ratio_ = 0; int num_fail_net_bytes_transferred_ratio_ = 0; + int num_fail_inefficient_fusion_emitter_ = 0; TF_DISALLOW_COPY_AND_ASSIGN(FusionInstructionMerger); }; @@ -190,15 +170,13 @@ Status FusionInstructionMerger::Run() { << " not_loop_fusion: " << num_fail_not_loop_fusion_ << " merge_all_users: " << num_fail_merge_all_users_ << " expensive_instruction: " << num_fail_expensive_fused_instruction_ - << " flops_to_byte_ratio: " << num_fail_flops_to_byte_ratio_ << " net_bytes_transferred: " << num_fail_net_bytes_transferred_ratio_ - << " }"; + << " inefficient_fusion_emitter: " + << num_fail_inefficient_fusion_emitter_ << " }"; return Status::OK(); } Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { - VLOG(3) << "FusionInstructionMerger ENTRY fusion: " << fusion->name() - << " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion); 
++total_visited_; // Skip 'fusion' instruction if there are no users into which we can merge. if (fusion->users().empty()) { @@ -256,15 +234,6 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - // Skip 'fusion' instruction if its flops to bytes transferred ratio - // exceeds the threshold value. - if (CalculateFlopsToBytesRatio(fusion) > - FusionMerger::GetThresholdFlopsToBytesRatio()) { - VLOG(3) << "Not merging " << fusion->name() - << ": flops-to-bytes ratio is not favorable."; - ++num_fail_flops_to_byte_ratio_; - return Status::OK(); - } // Skip 'fusion' instruction if merging it into all users would result in a // net increase in bytes transferred (currently allowing the net bytes // transferred to be exceeded up to ~10% in exhange for eliminating the @@ -280,6 +249,23 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { ++num_fail_net_bytes_transferred_ratio_; return Status::OK(); } + + // Skip 'fusion' instruction if merging it into at least one of the users + // would cause too much code duplication because of inefficiencies in the + // fusion emitter. + // TODO(b/119692968): Remove this once the fusion emitter can handle arbitrary + // fusion nodes. + if (absl::c_any_of(fusion->users(), [fusion](const HloInstruction* user) { + return FusedIrEmitter::IsFusedIrEmitterInefficient(/*consumer=*/user, + /*producer=*/fusion); + })) { + VLOG(3) << "Not merging " << fusion->name() + << ": Contains one or more users where fusing would cause " + "inefficiencies in the fusion emitter."; + ++num_fail_inefficient_fusion_emitter_; + return Status::OK(); + } + // Merge fused instructions from 'fusion' into each user. 
std::vector users = fusion->users(); for (HloInstruction* user : users) { @@ -288,7 +274,6 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { } ++total_merged_; VLOG(2) << "Merged fusion instruction: " << fusion->name() - << " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion) << " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio << " into users { " << absl::StrJoin(users, ", ", diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index f19996edfe3dd923aa686a19621ce28a4aed5a45..a49d68002f8de5bb5640731f3cd31572593ee837 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -37,8 +37,6 @@ class FusionMerger : public HloModulePass { absl::string_view name() const override { return "fusion merger"; } StatusOr Run(HloModule* module) override; - - static double GetThresholdFlopsToBytesRatio() { return 1.0; } }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 7cc869ed9e89688d6ea06428a7bade3ebe55ea23..31b73fd250cc812807149f6d1028761cf1f35ebf 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -99,62 +99,6 @@ ENTRY MergeSharedFusionInstruction.Computation0 { EXPECT_EQ(7, operand2->fused_instruction_count()); } -// Tests that we do not merge a fusion instruction that above flops to bytes -// threshold. -// -// Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. 
-TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - auto module = ParseHloString(R"( -HloModule FlopsToBytesRatioThresholdExceeded - -comp.2 { - state.param_1.1 = (f32[4]{0}, f32[4]{0}) parameter(0) - get-tuple-element.3 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 - get-tuple-element.4 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 - multiply.29 = f32[4]{0} multiply(get-tuple-element.3, get-tuple-element.4) - multiply.30 = f32[4]{0} multiply(get-tuple-element.3, multiply.29) - multiply.31 = f32[4]{0} multiply(get-tuple-element.3, multiply.30) - multiply.32 = f32[4]{0} multiply(get-tuple-element.3, multiply.31) - multiply.33 = f32[4]{0} multiply(get-tuple-element.3, multiply.32) - multiply.34 = f32[4]{0} multiply(get-tuple-element.3, multiply.33) - multiply.35 = f32[4]{0} multiply(get-tuple-element.3, multiply.34) - multiply.36 = f32[4]{0} multiply(get-tuple-element.3, multiply.35) - multiply.37 = f32[4]{0} multiply(get-tuple-element.3, multiply.36) - multiply.38 = f32[4]{0} multiply(get-tuple-element.3, multiply.37) - multiply.39 = f32[4]{0} multiply(get-tuple-element.3, multiply.38) - multiply.40 = f32[4]{0} multiply(get-tuple-element.3, multiply.39) - ROOT multiply.41 = f32[4]{0} multiply(get-tuple-element.3, multiply.40) -} - -comp.1 { - multiply.12.param_1.1 = f32[4]{0} parameter(1) - constant.param_1.3 = f32[4]{0} parameter(0) - add.3 = f32[4]{0} add(multiply.12.param_1.1, constant.param_1.3) - ROOT multiply.16 = f32[4]{0} multiply(add.3, constant.param_1.3) -} - -comp { - multiply.12.param_1 = f32[4]{0} parameter(1) - constant.param_1.1 = f32[4]{0} parameter(0) - multiply.15 = f32[4]{0} multiply(multiply.12.param_1, constant.param_1.1) - ROOT add.2 = f32[4]{0} add(multiply.15, constant.param_1.1) -} - -ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { - constant = f32[4]{0} constant({1, 1, 1, 1}) - state = (f32[4]{0}, f32[4]{0}) parameter(0) - fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 - fusion.3 = 
f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 - fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp - ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) -})") - .ValueOrDie(); - // Run fusion merger pass, which should detect that the flops/bytes of the - // shared fusion instruction exceeds the threshold ratio, and therefore - // cannot be merged with other fusion instructions. - EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); -} - // Tests that threshold for bytes transferred if merged is exceeded. // // Fusion2 is not merged because it exceeds the threshold bytes transferred. @@ -319,6 +263,62 @@ TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) { EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); } +// TODO(b/119692968): Remove this test once fusion emitter is fixed. +TEST_F(FusionMergerTest, WillNotMergeIfFusionEmitterIsInefficient) { + auto module = ParseHloString(R"( + HloModule m + + %fused_computation (param_0.10: f32[6]) -> f32[1] { + %param_0.10 = f32[6]{0} parameter(0) + %add.7 = f32[6]{0} add(%param_0.10, %param_0.10) + %slice.21 = f32[5]{0} slice(%add.7), slice={[0:5]} + %slice.18 = f32[5]{0} slice(%add.7), slice={[1:6]} + %add.5 = f32[5]{0} add(%slice.21, %slice.18) + %slice.15 = f32[4]{0} slice(%add.5), slice={[0:4]} + %slice.12 = f32[4]{0} slice(%add.5), slice={[1:5]} + %add.4 = f32[4]{0} add(%slice.15, %slice.12) + %slice.9 = f32[3]{0} slice(%add.4), slice={[0:3]} + %slice.6 = f32[3]{0} slice(%add.4), slice={[1:4]} + %add.2 = f32[3]{0} add(%slice.9, %slice.6) + %slice.3 = f32[2]{0} slice(%add.2), slice={[0:2]} + %slice.2 = f32[2]{0} slice(%add.2), slice={[1:3]} + %add.1 = f32[2]{0} add(%slice.3, %slice.2) + %slice.1 = f32[1]{0} slice(%add.1), slice={[0:1]} + %slice.0 = f32[1]{0} slice(%add.1), slice={[1:2]} + ROOT %add.0 = f32[1]{0} add(%slice.1, %slice.0) + } + + %fused_computation.1 (param_0.21: f32[11], param_1.21: f32[11]) -> f32[6] { + %param_0.21 = f32[11]{0} 
parameter(0) + %param_1.21 = f32[11]{0} parameter(1) + %add.16 = f32[11]{0} add(%param_0.21, %param_1.21) + %slice.51 = f32[10]{0} slice(%add.16), slice={[0:10]} + %slice.48 = f32[10]{0} slice(%add.16), slice={[1:11]} + %add.14 = f32[10]{0} add(%slice.51, %slice.48) + %slice.45 = f32[9]{0} slice(%add.14), slice={[0:9]} + %slice.42 = f32[9]{0} slice(%add.14), slice={[1:10]} + %add.13 = f32[9]{0} add(%slice.45, %slice.42) + %slice.39 = f32[8]{0} slice(%add.13), slice={[0:8]} + %slice.36 = f32[8]{0} slice(%add.13), slice={[1:9]} + %add.11 = f32[8]{0} add(%slice.39, %slice.36) + %slice.33 = f32[7]{0} slice(%add.11), slice={[0:7]} + %slice.30 = f32[7]{0} slice(%add.11), slice={[1:8]} + %add.10 = f32[7]{0} add(%slice.33, %slice.30) + %slice.27 = f32[6]{0} slice(%add.10), slice={[0:6]} + %slice.24 = f32[6]{0} slice(%add.10), slice={[1:7]} + ROOT %add.8 = f32[6]{0} add(%slice.27, %slice.24) + } + + ENTRY entry { + p0 = f32[11]{0} parameter(0) + p1 = f32[11]{0} parameter(1) + f1 = f32[6]{0} fusion(p0, p1), kind=kLoop, calls=%fused_computation.1 + ROOT f2 = f32[1] fusion(f1), kind=kLoop, calls=%fused_computation + })") + .ValueOrDie(); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 86c9bc6a345047fb5329af0be45c8981cc427f50..a7053e6a013be3ccf5725cbe003558be77104af1 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -428,7 +428,8 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, scratch_data = scratch_mem->device_memory(); } const MatrixDescriptor scratch_descriptor( - scratch_data, false, output_num_cols, output_num_rows, batch_size); + scratch_data, false, output_matrix.num_rows, output_matrix.num_cols, + batch_size); StatusOr best_algorithm = GetGemmAutotuneFn( 
element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto new file mode 100644 index 0000000000000000000000000000000000000000..ec4f6e9c91331f0142ed5434949871b5edc27462 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto @@ -0,0 +1,13 @@ +// This is used for convolution logging. Also see +// tensorflow/core/protobuf/autotuning.h +syntax = "proto3"; + +package xla.gpu; + +import "tensorflow/compiler/xla/service/hlo.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + +message ConvInstructionLog { + xla.HloInstructionProto instruction = 1; + repeated xla.ShapeProto operand_shapes = 2; +} diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 842ba2fdcd31a451cec1be543e102e0a46077f38..0cbd92a9553fca3cf73fdef30ad6a58c5c82b923 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -15,7 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" +#include +#include + #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" namespace xla { namespace gpu { @@ -131,5 +138,51 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, get_loop_shape(instr_2)); } +bool IsInputFusibleScatter(const HloInstruction& instr) { + if (instr.opcode() == HloOpcode::kScatter || + (instr.opcode() == HloOpcode::kFusion && + instr.fusion_kind() == HloInstruction::FusionKind::kInput && + instr.fused_expression_root()->opcode() == HloOpcode::kScatter)) { + return true; + } + return false; +} + +bool IsInputFusible(const HloInstruction& instr) { + // Input fusion only handles non-elemental reduction and scatter operations. + return IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr); +} + +bool IsLoopFusible(const HloInstruction& instr) { + // Don't fuse get-tuple-element on GPU: We can, but it's slower than not + // fusing. We never generate kernels for unfused GTEs. Instead, if an + // unfused GTE is an input to a kernel (including a fusion kernel), we + // compute the address of the GTE at the top of the kernel. Often we know the + // address of the GTE result statically, so we can do this without chasing any + // pointers. 
+ return (instr.IsElementwise() && instr.operand_count() > 0) || + instr.opcode() == HloOpcode::kBitcast || + instr.opcode() == HloOpcode::kBroadcast || + instr.opcode() == HloOpcode::kConcatenate || + instr.opcode() == HloOpcode::kDynamicSlice || + instr.opcode() == HloOpcode::kDynamicUpdateSlice || + (instr.opcode() == HloOpcode::kFusion && + instr.fusion_kind() == HloInstruction::FusionKind::kLoop) || + instr.opcode() == HloOpcode::kGather || + instr.opcode() == HloOpcode::kIota || + instr.opcode() == HloOpcode::kPad || + (instr.opcode() == HloOpcode::kReduce && + !IsReductionToVector(instr)) || + instr.opcode() == HloOpcode::kReduceWindow || + instr.opcode() == HloOpcode::kReshape || + instr.opcode() == HloOpcode::kReverse || + instr.opcode() == HloOpcode::kSlice || + instr.opcode() == HloOpcode::kTranspose; +} + +bool IsFusible(const HloInstruction& instr) { + return IsInputFusible(instr) || IsLoopFusible(instr); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index e9d7ba1c4cfa865532a0d06c2ed883a2fea4e2cd..3a59c74f64b5486c52e9cdac43f343d70f0d2558 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -24,6 +24,15 @@ limitations under the License. namespace xla { namespace gpu { +// Whether 'instr' can occur inside fusions, i.e. whether it is a candidate +// for being fused. Note that further restrictions apply, e.g. Scatter must +// be the root of an input fusion. +bool IsFusible(const HloInstruction& instr); + +bool IsInputFusible(const HloInstruction& instr); + +bool IsLoopFusible(const HloInstruction& instr); + // The code emitted for reduce-rooted input fusions (EmitReductionToVector) // suffers from poor data locality if the layouts of input parameters differ. In // such situtations it is better not to fuse. 
Only input params with @@ -46,9 +55,13 @@ bool IsReduceInputFusion(const HloInstruction& instr); // is either an unfused reduction-to-vector op or a reduce input fusion. bool IsInputFusibleReduction(const HloInstruction& instr); +// Whether `instr` is fusible as root of a scatter input fusion, i.e. `instr` +// is either an unfused scatter op or a scatter input fusion. +bool IsInputFusibleScatter(const HloInstruction& instr); + // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. -// This function works for both, sibling and producer-conumser multi-output +// This function works for both, sibling and producer-consumer multi-output // fusion. // So far, multi-output fusion is supported for loop fusions and reduce // input fusions only. It is up to the caller to ensure the instructions diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc index 15d4ee206ce8debcb8a5dbc6ec65d29ba257d302..ee64b3a7596170a2bbf0a430cd6c7de5cebf2da1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -62,7 +62,7 @@ TEST_F(GpuFusibleTest, copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) c0 = f16[] constant(0) broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} - greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) } fused_reduce { @@ -122,7 +122,7 @@ TEST_F(GpuFusibleTest, p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) c0 = f16[] constant(0) broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={} - greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast) + greater-than = pred[128,1024,32,32]{3,2,1,0} compare(p1.1, 
broadcast), direction=GT select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast) ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select) } @@ -507,7 +507,7 @@ TEST_F(GpuFusibleTest, p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT p0.1 = f32[2,2,2]{2,1,0} parameter(0) ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc index 4268fb2c7a813b3b53e4cd48746028a7b369f28e..4765f67c4b17e97419182e341573f75ad3d6ac30 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 58bdd4209a2315cdb7d29e920faded4d1a6a5876..a6d80f0b6dddb3d8d0fd00c639e11c71da6a9f09 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -240,6 +240,32 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(keys_layout, *output_buffer)); } + } else if (instruction->opcode() == HloOpcode::kTriangularSolve) { + // TODO(phawkins): Ideally we would relax this constraint. 
What we + // actually want is that: + // a) the batch dimensions are major, in no particular order. + // b) the two minor dimensions are in fortran (column-major) order, + // although for the 'a' argument we could potentially accept row-major + // order and fold the transpose into the operator. + auto set_fortran_layout = [](Shape* shape) { + LayoutUtil::SetToDefaultLayout(shape); + int n = shape->mutable_layout()->minor_to_major_size(); + CHECK_GE(n, 2); + std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0), + shape->mutable_layout()->mutable_minor_to_major()->at(1)); + }; + Shape op0_shape = instruction->operand(0)->shape(); + Shape op1_shape = instruction->operand(1)->shape(); + Shape output_shape = instruction->shape(); + set_fortran_layout(&op0_shape); + set_fortran_layout(&op1_shape); + set_fortran_layout(&output_shape); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op0_shape, instruction, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op1_shape, instruction, 1)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction)); } } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 29756d27260b0f41b2dd4b649ea9b1610ff90268..3630c3e38c59c8a2557befe38f3415d6bab1ee38 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -368,12 +368,21 @@ TEST_F(LayoutAssignmentTest, DotLayout) { TEST_F(LayoutAssignmentTest, SortLayout) { const char* hlo_text = R"( HloModule SortLayout + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + ENTRY sort { keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}}) values = f32[2,3]{1,0} parameter(0) transpose = 
f32[3,2]{1,0} transpose(values), dimensions={1,0} ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), - dimensions={1} + dimensions={1}, to_apply=compare })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.cc b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e38ceca18de30e0e1fa75a7a4bd865e000b7d22 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.cc @@ -0,0 +1,70 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace gpu { + +StatusOr GpuSanitizeConstantNames::Run(HloModule* module) { + bool changed = false; + + NameUniquer instr_name_uniquer(/*separator=*/"_"); + // Collect the names used for the non-constant HLO instructions. + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kConstant) { + continue; + } + + const string& old_name = instr->name(); + instr->UniquifyName(&instr_name_uniquer); + CHECK_EQ(old_name, instr->name()); + } + } + + // Sanitize the names for the constant HLO instructions and make them unique. + // This is not merged into the above loop because we don't want this pass to + // change the names of non-constant instructions, that is, if a constant HLO + // conflicts with a non-constant HLO, we change the name of the constant HLO + // even though the non-constant HLO comes after in the HLO module. 
+ for (HloComputation* computation : module->computations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() != HloOpcode::kConstant) { + continue; + } + string sanitized_name = llvm_ir::SanitizeConstantName(*instr); + instr->SetAndSanitizeName(sanitized_name); + instr->UniquifyName(&instr_name_uniquer); + changed = true; + } + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h similarity index 56% rename from tensorflow/compiler/xla/service/implicit_broadcast_remover.h rename to tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h index 9c48b7db613b049536c76237b4cfebbbc47448f3..8d583d047e25698e86032020b7fc20df87f5ab68 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,30 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ - -#include +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { +namespace gpu { -// Pass which replaces all implicit broadcasts with their equivalent sequence of -// explicit broadcast and reshape instructions. -class ImplicitBroadcastRemover : public HloModulePass { +// Sanitizes HLO instruction names for the GPU backend. Currently, it only +// replaces . and - in the HLO constant instruction names with _ to please the +// LLVM PTX backend. +class GpuSanitizeConstantNames : public HloModulePass { public: - ImplicitBroadcastRemover() {} - ~ImplicitBroadcastRemover() override {} - - absl::string_view name() const override { - return "implicit-broadcast-remover"; - } + absl::string_view name() const override { return "sanitize-constant-names"; } StatusOr Run(HloModule* module) override; }; +} // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5adee8cc61f18f356406d8c089dd43565957739 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using SanitizeConstantNamesTest = HloTestBase; + +TEST_F(SanitizeConstantNamesTest, InstructionNameWithHyphenSanitized) { + const char *const kHloString = R"( + HloModule HyphenInInstructionName + ENTRY kernelEntry { + ROOT equal-to = s32[2]{0} constant({42, 73}) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kHloString)); + + EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).ValueOrDie()); + HloInstruction *root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "equal_to"); +} + +TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) { + const char *const kHloString = R"( + HloModule HyphenInInstructionName + ENTRY kernelEntry { + ROOT equal.to = s32[2]{0} constant({42, 73}) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kHloString)); + + 
EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).ValueOrDie()); + HloInstruction *root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "equal_to"); +} + +TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) { + const char *const kHloString = R"( + HloModule BufferSanitizedName + ENTRY kernelEntry { + equal.to = s32[2]{0} constant({42, 73}) + equal-to = s32[2]{0} constant({67, 3}) + ROOT equal_to = s32[2]{0} add(equal.to, equal-to) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kHloString)); + + EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).ValueOrDie()); + EXPECT_THAT(FindInstruction(module.get(), "equal_to_1"), op::Constant()); + EXPECT_THAT(FindInstruction(module.get(), "equal_to_2"), op::Constant()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 8c6a6914792a96ab517fa5f20ff2215e4785490e..e593f535642e15f28a4a1c1f321881ba3c694548 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -30,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 69aaaceca112364a4fd562f6a5eff1629fd3fc54..3c50c2b1d8e185975958ba08527aeb21686050e1 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -113,10 +113,9 @@ void HloToIrBindings::EmitBasePointersForHlos( BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type), index); } else if (slice.allocation()->is_constant()) { - llvm::Value* global_for_constant = - module_->getGlobalVariable(llvm_ir::AsStringRef( - llvm_ir::ConstantBufferAllocationToGlobalName( - *slice.allocation()))); + llvm::Value* global_for_constant = module_->getGlobalVariable( + llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation())); BindHloToIrValue(*non_io_hlo, global_for_constant); } else { const int64 offset = slice.offset(); @@ -136,11 +135,11 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_, module_); + GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_); } return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_); + EmitGetTupleElement(gte->operand(0), base_ptr), b_); } // Returns true if `value` has a name that should not be changed. 
@@ -166,11 +165,10 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo()); } if (!HasMeaningfulName(ir_value)) { - ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); + ir_value->setName(llvm_ir::IrName(&hlo, "raw")); } if (!HasMeaningfulName(typed_ir_value)) { - typed_ir_value->setName( - llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); + typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed")); } return typed_ir_value; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 6151dd8ff4c92bb81bd756c68cc9377633c8c9d5..62f625defc3e3eb3b1dd01e0992281edc810454d 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -28,31 +29,6 @@ namespace gpu { namespace { -bool IsFusible(const HloInstruction& hlo) { - // Don't fuse get-tuple-element on GPU: We can, but it's slower than not - // fusing. We never generate kernels for unfused GTEs. Instead, if an - // unfused GTE is an input to a kernel (including a fusion kernel), we - // compute the address of the GTE at the top of the kernel. Often we know the - // address of the GTE result statically, so we can do this without chasing any - // pointers. 
- return (hlo.IsElementwise() && hlo.operand_count() > 0) || - hlo.opcode() == HloOpcode::kBitcast || - hlo.opcode() == HloOpcode::kBroadcast || - hlo.opcode() == HloOpcode::kConcatenate || - hlo.opcode() == HloOpcode::kDynamicSlice || - hlo.opcode() == HloOpcode::kDynamicUpdateSlice || - hlo.opcode() == HloOpcode::kFusion || - hlo.opcode() == HloOpcode::kGather || - hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || - hlo.opcode() == HloOpcode::kReduce || - hlo.opcode() == HloOpcode::kReduceWindow || - hlo.opcode() == HloOpcode::kReshape || - hlo.opcode() == HloOpcode::kReverse || - hlo.opcode() == HloOpcode::kScatter || - hlo.opcode() == HloOpcode::kSlice || - hlo.opcode() == HloOpcode::kTranspose; -} - bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { if (constant->opcode() != HloOpcode::kConstant || !ShapeUtil::IsScalar(constant->shape())) { @@ -138,8 +114,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion; } -bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, - int64 operand_index) { +bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer, + int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha @@ -275,29 +251,29 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, !InstructionFusion::ShouldFuse(consumer, operand_index)) { return false; } - - // We put this check last because it's potentially expensive. - return !FusionWouldBeTooLarge(consumer, producer); + return true; } -bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, - int64 operand_index) { - const HloInstruction* producer = consumer->operand(operand_index); - // The IR emitter has limited support for non-loop fusions with multi output - // at present. 
- // TODO(tjoerg): Relax this constraint to allow for arbitraty kinds of fusion. - if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) { +bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, + int64 operand_index) { + if (!ShouldFuseInexpensiveChecks(consumer, operand_index)) { return false; } - // Multi-output fusion requires instructions with compatible shapes. - if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) { + auto producer = consumer->operand(operand_index); + // The following checks are potentially expensive. + if (FusionWouldBeTooLarge(consumer, producer)) { return false; } - // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for - // multi-output fusion. In particular, do not check whether an instruction is - // expensive to duplicate, since this doesn't matter here. - return GpuInstructionFusion::ShouldFuse(consumer, operand_index); + // Also check that our emitter can handle the fusion node. We currently can + // have exponential time/memory requirements for emitting certain fusion + // kernels, in which case we don't want to fuse. + // TODO(b/119692968): Remove this once we have fixed our fusion emitter. 
+ return !FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer); +} + +bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + return false; } HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index c91f6343a69268ca687004dbe0ffbb863271a95c..2f8f40b4b5ef4f0d203f8d476ebfd21032c27b62 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -49,6 +49,12 @@ class GpuInstructionFusion : public InstructionFusion { HloInstruction::FusionKind ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) override; + + private: + // This method is called by ShouldFuse() to do all the computationally + // inexpensive checks whether we should fuse the operand into 'consumer'. + bool ShouldFuseInexpensiveChecks(HloInstruction* consumer, + int64 operand_index); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 688604cd36e5a45debf855aacd29d05ecda92341..a05ab86cf77a134a1fc387d93cb482aa1ff5345b 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -506,202 +506,11 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) { })") .ValueOrDie(); - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. 
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT( - fusion->fused_expression_root(), - op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract())); -} - -TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { - // tanh --> add --> tuple - // \---------------/ - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - tanh = f32[4,3]{1,0} tanh(p0) - add = f32[4,3]{1,0} add(tanh, p1) - ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add) - })") - .ValueOrDie(); - - // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh. + // Multi-output fusion is disabled here and performed in the + // GpuMultiOutputFusion pass instead. ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()) - << module->ToString(); -} - -TEST_F(InstructionFusionTest, MultiOutputFusion2) { - // sub --> add1 --\--------\ - // \----------> add2 --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[4,3]{1,0} parameter(2) - sub = f32[4,3]{1,0} subtract(p0, p2) - add1 = f32[4,3]{1,0} add(sub, p1) - add2 = f32[4,3]{1,0} add(sub, add1) - ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. 
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT(fusion->fused_expression_root(), - op::Tuple(op::Add(op::Subtract(), op::Add()), - op::Add(op::Subtract(), op::Parameter()))); -} - -TEST_F(InstructionFusionTest, MultiOutputFusion3) { - // sub --> add1 ----\--------\ - // \ --> add2 --> add3 --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[4,3]{1,0} parameter(2) - p3 = f32[4,3]{1,0} parameter(3) - sub = f32[4,3]{1,0} subtract(p0, p2) - add1 = f32[4,3]{1,0} add(sub, p1) - add2 = f32[4,3]{1,0} add(p2, sub) - add3 = f32[4,3]{1,0} add(add1, add2) - ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. 
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT(fusion->fused_expression_root(), - op::Tuple(op::Add(op::Add(), op::Add()), - op::Add(op::Parameter(), op::Subtract()))); -} - -TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { - // sub --> mul ---\ - // \--> call --> add --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - c = f32[] constant(42) - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - sub = f32[4,3]{1,0} subtract(p0, p1) - mul = f32[4,3]{1,0} multiply(sub, c) - call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo" - add = f32[4,3]{1,0} add(mul, call) - ROOT tuple = (f32[4,3]{1,0}) tuple(add) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - // Visit instructions in post order to detect cycles. - // TODO(tjoerg): Add cycle detection to the HloVerifier. - class DummyVisitor : public DfsHloVisitorWithDefault { - public: - DummyVisitor() {} - Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - } visitor; - for (const HloComputation* computation : module->MakeComputationPostOrder()) { - // Accept will return a FailedPrecondition when a cycle is detected. 
- EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok()); - } -} - -TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { - // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) - // \-------------------------/ - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[2,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[2,3]{1,0} parameter(2) - sub = f32[2,3]{1,0} subtract(p0, p2) - add = f32[4,3]{1,0} add(sub, p1) - ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) - })") - .ValueOrDie(); - - // Multi-output fusion requires shapes to be compatible. Since `sub` and `add` - // have incompatible shapes, expect that no multi-output fusion happens. - ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()) - << module->ToString(); -} - -TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { - auto module = ParseHloString(R"( - HloModule test_module - - add_computation { - add_lhs = f32[] parameter(0) - add_rhs = f32[] parameter(1) - ROOT add_root = f32[] add(add_lhs, add_rhs) - } - - fused_computation { - p1 = f32[10] parameter(0) - zero = f32[] constant(0) - ROOT f2_root = f32[] reduce(p1, zero), dimensions={0}, - to_apply=add_computation - } - - ENTRY entry { - p0 = f32[10] parameter(0) - mul = f32[10] multiply(p0, p0) - fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation - ROOT tuple = (f32[10], f32[]) tuple(fusion, mul) - })") - .ValueOrDie(); - - // Multi-output fusion is not supported for non-loop fusions at present. Since - // `fused_computation` is a input fusion, expect no multi-output fusion to - // happen. 
- ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()) - << module->ToString(); + .ValueOrDie()); } TEST_F(InstructionFusionTest, FuseScalarConstant) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 82bdd677d96d3d0826bb4127b32d074eb632b1a3..6b9cbdd94b334ab7a4f61a4e3e43250ed9648cd0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,7 +20,6 @@ limitations under the License. #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -143,6 +142,16 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { target == kCudnnConvBiasActivationForwardCallTarget; } +const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky"; + +bool IsCustomCallToCusolver(const HloInstruction& hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; + } + const auto& target = hlo.custom_call_target(); + return target == kCusolverCholeskyCallTarget; +} + bool ImplementedAsLibraryCall(const HloInstruction& hlo) { return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) || IsCustomCallToDnnConvolution(hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index ebf4d926b7a280e10b09a2532caba7ad6ab3ceb2..f1a7aabb4db57b6818b29bdde73d87f0706f2827 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -131,6 +131,19 @@ extern const char* const kCudnnConvBiasActivationForwardCallTarget; // kConvolution opcode. 
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); +// Returns true if `hlo` will be implemented as a call to a cuSolver routine. +// +// This returns true if `hlo` is a CustomCall HLO with a call target equal to +// one of the kCusolver... constants, but returns *false* for HLOs with +// say, a kCholesky opcode. +bool IsCustomCallToCusolver(const HloInstruction& hlo); + +// Cholesky decomposition. Takes a (batched) matrix as input, and returns a +// tuple of (result, workspace, info), where result is the result of the +// Cholesky decomposition, workspace is scratch space for cuSolver, and info +// is a success/failure code per batch element. +extern const char* const kCusolverCholeskyCallTarget; + // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 0007a9a8a3369d8ac010640127e1561615a6d813..f04e8241e5767da539964e93ebf48a79a2b5024c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -115,7 +115,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { get_tuple_element->shape(), get_tuple_element->tuple_index(), // TODO(b/26344050): tighten the alignment here // based on the real element type. 
- /*alignment=*/1, GetBasePointer(*operand), &b_, module_)); + /*alignment=*/1, GetBasePointer(*operand), &b_)); return Status::OK(); } @@ -144,7 +144,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_); return Status::OK(); } @@ -434,7 +434,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select), GetIrArray(*pred, *tuple_select), GetBasePointer(*on_true), GetBasePointer(*on_false), - &b_, module_); + &b_); return Status::OK(); } @@ -492,8 +492,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); result = InsertValue(result, value.first, {0}); result = InsertValue(result, value.second, {1}); - } else { + } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { result = FMul(lhs_value, rhs_value); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); + result = Mul(lhs_value, rhs_value); } target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); @@ -525,16 +528,18 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. 
llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_); - llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( - lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); - llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( - rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); + std::vector lhs_multi_index = + loop_nest.EmitOperandArrayLoopNest( + lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); + std::vector rhs_multi_index = + loop_nest.EmitOperandArrayLoopNest( + rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); // We don't have to iterate over the batch dimensions in both arrays, simplify // the loop nest of the rhs. for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i)); - rhs_index[i] = lhs_index[i]; + rhs_multi_index[i] = lhs_multi_index[i]; } // Create the reduction loop which does the sum of products reduction. @@ -545,8 +550,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // The final entry in the rhs and lhs indexes is the indvar of the reduction // loop. - lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); - rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); // For computing the sum of products we alloca a single location to store the // dot product result as we accumulate it within the reduction loop. 
After the @@ -571,7 +576,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty()); b_.SetInsertPoint( &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); + llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_array.GetShape(), + b_.getInt64Ty()); llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); + llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_array.GetShape(), + b_.getInt64Ty()); llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); llvm::Value* accum = Load(accum_address); llvm::Value* updated_accum; @@ -583,9 +592,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm::Value* accum_imag = Imag(accum, &b_); llvm::Value* imag_sum = FAdd(accum_imag, value.second); updated_accum = InsertValue(updated_accum, imag_sum, {1}); - } else { + } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { llvm::Value* product = FMul(lhs_element, rhs_element); updated_accum = FAdd(accum, product); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); + llvm::Value* product = Mul(lhs_element, rhs_element); + updated_accum = Add(accum, product); } Store(updated_accum, accum_address); @@ -593,20 +606,22 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // address. The index into the target address is the concatenation of the rhs // and lhs indexes with the reduction dimensions removed. The terms from the // rhs index are the lower dimensions in the index so we add them first. - llvm_ir::IrArray::Index target_index(index_type); + std::vector target_multi_index; for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { - target_index.push_back(lhs_index[dimension]); + target_multi_index.push_back(lhs_index[dimension]); } } // Skip over the batch dimensions to not have them in the index twice. 
for (size_t dimension = dnums.lhs_batch_dimensions_size(); dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { - target_index.push_back(rhs_index[dimension]); + target_multi_index.push_back(rhs_index[dimension]); } } SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); + llvm_ir::IrArray::Index target_index(target_multi_index, + target_array.GetShape(), index_type); target_array.EmitWriteArrayElement( target_index, Load(accum_address), // The value written to the target array. @@ -647,7 +662,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { } Status IrEmitter::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Support variadic reduce. + // TODO(b/118332391): Support variadic reduce. if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on GPU"); } @@ -671,7 +686,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // Value*s are placed for each dimension in dimensions, and all the rest // are nullptrs. llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); - const llvm_ir::IrArray::Index reduced_dims_index = + std::vector input_multi_index = loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, "reduction_dim"); @@ -682,17 +697,18 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // filled in. We fill in the rest of the dimensions with induction // Value*s taken from 'index' which iterates over the target array. // See the high-level description in the XLA documentation for details. - llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = *it++; + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = *it++; } } CHECK(index.end() == it); // Apply the reduction function to the loaded value. 
+ llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), + b_.getInt64Ty()); llvm::Value* input_address = GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 66c65f69758e5a2f4420935279835eaf086fea45..8c02416eef452c932e2adeebf0da7ff245f87447 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -77,10 +77,10 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( llvm::Function* function = llvm::Function::Create( function_type, // The function type. llvm::GlobalValue::InternalLinkage, // The linkage type. - llvm_ir::AsStringRef(ir_emitter_context_->name_uniquer()->GetUniqueName( + ir_emitter_context_->name_uniquer()->GetUniqueName( llvm_ir::SanitizeFunctionName( - nested_computation.name()))), // The name of the function. - ir_emitter_context_->llvm_module()); // The parent LLVM module. + nested_computation.name())), // The name of the function. + ir_emitter_context_->llvm_module()); // The parent LLVM module. 
for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size(); ++arg_no) { int64 arg_size = argument_dereferenceable_bytes[arg_no]; @@ -123,7 +123,7 @@ Status IrEmitterNested::EmitTargetElementLoop( ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_); return Status::OK(); } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 294a454931b5cfa368bf094c428a1e942f4556b8..07038607bdc783238eb2f349bf636e420e6ae20b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -38,8 +38,8 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -55,11 +55,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -74,6 +76,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -102,6 +105,8 @@ using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; +namespace m = match; + // If a dimensions is smaller than this, untiled transposition may be more // efficient. 
const int64 kMinDimensionToTransposeTiled = 16; @@ -226,7 +231,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( if (alloc->IsPreallocatedTempBuffer()) { fn_arg->setName("temp_buf"); } else { - fn_arg->setName(llvm_ir::AsStringRef(StrCat("alloc", alloc->index()))); + fn_arg->setName(StrCat("alloc", alloc->index())); } } @@ -476,6 +481,51 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { return Status::OK(); } + if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) { + TF_ASSIGN_OR_RETURN(CholeskyOptions options, + custom_call->backend_config()); + + const Shape& shape = custom_call->operand(0)->shape(); + int ndim = shape.dimensions_size(); + CHECK_GE(ndim, 2); + int64 n = shape.dimensions(ndim - 1); + + const auto& dims = shape.dimensions(); + int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1}, + [](int64 a, int64 b) { return a * b; }); + + auto operand_buffer = GetAllocationSlice(*custom_call->operand(0)); + + const auto& assn = ir_emitter_context_->buffer_assignment(); + auto a_buffer = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); + auto workspace_buffer = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); + auto info_buffer = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); + + std::vector> thunks; + + if (operand_buffer != a_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/a_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call)); + } + + thunks.push_back(absl::make_unique( + options, a_buffer, workspace_buffer, info_buffer, + custom_call->operand(0)->shape().element_type(), batch_size, n, + custom_call)); + + // Elide the sequential thunk if there's no copy. 
+ if (thunks.size() == 1) { + AddThunkToThunkSequence(std::move(thunks[0])); + } else { + AddThunkToThunkSequence( + absl::make_unique(std::move(thunks), custom_call)); + } + + return Status::OK(); + } + return IrEmitter::HandleCustomCall(custom_call); } @@ -487,6 +537,41 @@ Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { return Status::OK(); } +Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) { + auto has_fortran_layout = [](const Layout& layout) { + int n = layout.minor_to_major_size(); + return layout.minor_to_major(0) == n - 2 && + layout.minor_to_major(1) == n - 1; + }; + TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout())); + TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout())); + TF_RET_CHECK(has_fortran_layout(hlo->shape().layout())); + + std::vector> thunks; + + // Triangular solve is in-place on 'b', so copy 'b' to the output if they + // aren't the same buffer. + auto operand_buffer = GetAllocationSlice(*hlo->operand(1)); + auto destination_buffer = GetAllocationSlice(*hlo); + if (operand_buffer != destination_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo)); + } + + thunks.push_back(BuildTriangularSolveThunk(hlo)); + + // Elide the sequential thunk if there's no copy. + if (thunks.size() == 1) { + AddThunkToThunkSequence(std::move(thunks[0])); + } else { + AddThunkToThunkSequence( + absl::make_unique(std::move(thunks), hlo)); + } + return Status::OK(); +} + Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { @@ -546,7 +631,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // a 1D array. 
The specialized version requires a initializer thunk that // initializes the output array to the initial value of the reduce. if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) { - // TODO(b/112040122): Support variadic reduce. + // TODO(b/118332391): Support variadic reduce. return Unimplemented("Variadic reduce is not supported on GPU"); } return EmitReductionToVector(fusion); @@ -635,7 +720,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( } Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Support multi-output reduce. + // TODO(b/118332391): Support multi-output reduce. if (!reduce->shape().IsArray()) { return Unimplemented("Multi-output reduce is not supported on GPU"); } @@ -778,16 +863,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. 
- IrArray::Index operand_index(index_type, source_index.size()); + std::vector operand_multi_index(source_index.size()); llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = NSWMul( source_index[i], index_typed_constant(window.dimensions(i).stride())); - operand_index[i] = + operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), index_typed_constant(window.dimensions(i).padding_low())); llvm::Value* index_condition = ICmpULT( - operand_index[i], + operand_multi_index[i], index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); in_bounds_condition = And(in_bounds_condition, index_condition); } @@ -812,6 +897,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( } }; IrArray operand_array = GetIrArray(*operand, *select_and_scatter); + IrArray::Index operand_index(operand_multi_index, operand->shape(), + index_type); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); Store(operand_data, selected_value_address); @@ -822,7 +909,6 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // potentially update the selected value and index with the currently // visiting operand. llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_); - const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( @@ -854,15 +940,18 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // value and the current output value. 
llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_); - IrArray::Index selected_index(operand_index.GetType()); + std::vector selected_multi_index; for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = InBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(Load(selected_index_address_slot)); + selected_multi_index.push_back(Load(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) .EmitArrayElementAddress(source_index, &b_); + IrArray::Index selected_index(selected_multi_index, + select_and_scatter->shape(), + operand_index.GetType()); llvm::Value* output_value_address = GetIrArray(*select_and_scatter, *select_and_scatter) .EmitArrayElementAddress(selected_index, &b_); @@ -891,13 +980,12 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; // Build ForThunk for conformant while loops, otherwise build WhileThunk. - // TODO(b/112163966): Move trip count computation earlier in the pipeline. 
- if (auto loop_trip_count = ComputeWhileLoopTripCount(xla_while)) { - AddThunkToThunkSequence(BuildForThunk(xla_while, *loop_trip_count)); - VLOG(3) << "Built ForThunk for while: " << xla_while->name(); + auto config = xla_while->backend_config(); + if (config.ok() && config.ValueOrDie().has_known_trip_count()) { + AddThunkToThunkSequence( + BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n())); } else { AddThunkToThunkSequence(BuildWhileThunk(xla_while)); - VLOG(3) << "Built WhileThunk for while: " << xla_while->name(); } return Status::OK(); } @@ -959,18 +1047,18 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { BuildKernelThunk(scatter, /*implements_whole_instruction=*/thunks.empty())); - TF_RETURN_IF_ERROR( - EmitScatter(thunks.back().get(), scatter, - /*scatter_indices_gen=*/ - [=](const IrArray::Index& index) { - return GetIrArray(*scatter_indices, *scatter) - .EmitReadArrayElement(index, &b_, "scatter_index"); - }, - /*updates_gen=*/ - [=](const IrArray::Index& index) { - return GetIrArray(*updates, *scatter) - .EmitReadArrayElement(index, &b_, "update"); - })); + TF_RETURN_IF_ERROR(EmitScatter( + thunks.back().get(), scatter, + /*scatter_indices_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*scatter_indices, *scatter) + .EmitReadArrayElement(index, &b_, "scatter_index"); + }, + /*updates_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*updates, *scatter) + .EmitReadArrayElement(index, &b_, "update"); + })); // Elide the sequential thunk if there's no copy. if (thunks.size() == 1) { @@ -1041,16 +1129,20 @@ Status IrEmitterUnnested::EmitScatter( // Now load the indices corresponding to the current window from // scatter_indices. 
- llvm_ir::IrArray::Index raw_scatter_index_index(input_scatter_multidim, - index.GetType()); - raw_scatter_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + std::vector raw_scatter_index_multidim = + input_scatter_multidim; + raw_scatter_index_multidim.insert( + raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(), + nullptr); llvm::Value* is_in_bounds = b_.getTrue(); for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size(); i != e; ++i) { // Our index is stored along index_vector_dim, insert that into the lookup // index into scatter_indices. - raw_scatter_index_index[dim_numbers.index_vector_dim()] = - raw_scatter_index_index.GetConstantWithIndexType(i); + raw_scatter_index_multidim[dim_numbers.index_vector_dim()] = + index.GetConstantWithIndexType(i); + llvm_ir::IrArray::Index raw_scatter_index_index( + raw_scatter_index_multidim, scatter_indices_shape, index.GetType()); int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i); TF_ASSIGN_OR_RETURN( @@ -1118,17 +1210,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; Shape keys_shape = sort->operand(0)->shape(); int64 dimension_to_sort = sort->dimensions(0); - // In case there is a 'values' parameter that is a iota, we take note and use - // it later to ensure a stable sort. Otherwise, we don't guarantee a stable - // sort. - int64 iota_values_parameter_index = -1; for (int64 i = 0; i < sort->operand_count(); ++i) { - if (i > 0 && sort->operand(i)->opcode() == HloOpcode::kIota && - ShapeUtil::ElementIsIntegral(sort->operand(i)->shape()) && - Cast(sort->operand(i))->iota_dimension() == - dimension_to_sort) { - iota_values_parameter_index = i; - } ShapeIndex shape_index = sort->operand_count() > 1 ? 
ShapeIndex({i}) : ShapeIndex({}); // We assume that the layout of all involved operands and outputs is the @@ -1241,25 +1323,23 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { : standard_launch_dimensions; UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); - IrArray keys_array; std::vector values_arrays; - values_arrays.reserve(sort->operand_count() - 1); + values_arrays.reserve(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { ShapeIndex shape_index = sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); - if (i == 0) { - keys_array = GetIrArray(*sort, *sort, shape_index); - } else { - values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); - } + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); } return llvm_ir::EmitSortInPlace( - dimension_to_sort, keys_array, values_arrays, - iota_values_parameter_index, IrName(sort), xor_masks, &b_, + dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_, launch_dimensions, xor_masks.size() > 1 ? 
num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, - kTileSize); + kTileSize, + [&](absl::Span operands, llvm::Value* output) { + return EmitCallToNestedComputation(*sort->to_apply(), operands, + output); + }); }; std::vector xor_masks; for (int64 stage = 0; stage < num_stages; ++stage) { @@ -1296,11 +1376,55 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { return IrEmitter::HandleTupleSelect(tuple_select); } +namespace { + +bool IsScalarAddComputation(HloComputation* computation) { + return Match(computation->root_instruction(), + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)) + .WithShape(m::Shape().IsEffectiveScalar())); +} + +} // namespace + Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { + VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count() + << "; operand count: " << crs->operand_count() + << "; NCCL is enabled: " << NcclAllReduceThunk::NcclIsEnabled(); + + // Note the replica_count == 1 case is handled via device-to-device copy + // below. + bool should_use_nccl_thunk = + hlo_module_config_.replica_count() > 1 && + crs->IsCrossReplicaAllReduce() && + crs->operand_count() == 1 && // One array to reduce. + crs->operand(0)->shape().element_type() == F32 && + // Check the computation is a summation. + IsScalarAddComputation(crs->to_apply()); + + if (should_use_nccl_thunk) { + CHECK(crs->operand(0)->shape().IsArray()) + << "Operands to all-reduce must be arrays: " << crs->ToString(); + AddThunkToThunkSequence(absl::make_unique( + /*replica_count=*/hlo_module_config_.replica_count(), + /*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()), + /*source_address=*/GetAllocationSlice(*crs->operand(0)), + /*destination_buffer=*/GetAllocationSlice(*crs), crs)); + return Status::OK(); + } + if (hlo_module_config_.replica_count() != 1) { - // TODO(b/33011107): Support nontrivial cross replica sum on GPU. 
- return Unimplemented( - "AllReduce with >1 replica is not implemented on GPU."); + // TODO(b/33011107): Support more AllReduce configurations on GPU. + string message = absl::StrFormat( + "Requested AllReduce not implemented on GPU; replica_count: %d; " + "operand_count: %d; IsCrossReplicaAllReduce: %d; NCCL support: %d", + hlo_module_config_.replica_count(), crs->operand_count(), + crs->IsCrossReplicaAllReduce(), NcclAllReduceThunk::NcclIsEnabled()); + if (crs->operand_count() > 0) { + absl::StrAppendFormat( + &message, "; first operand array element-type: %s", + PrimitiveType_Name(crs->operand(0)->shape().element_type())); + } + return Unimplemented("%s", message); } // CRS with one operand and one replica is simply the identity function. @@ -1543,8 +1667,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::Value* loc; if (slice.allocation()->is_constant()) { loc = ir_emitter_context_->llvm_module()->getGlobalVariable( - llvm_ir::AsStringRef(llvm_ir::ConstantBufferAllocationToGlobalName( - *slice.allocation()))); + llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation())); CHECK_NE(loc, nullptr); } else { loc = InBoundsGEP(kernel_args.at(slice.allocation()), @@ -1573,7 +1696,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( } return absl::make_unique( - non_constant_buffers, llvm_ir::AsString(kernel->getName()), + non_constant_buffers, kernel->getName(), implements_whole_instruction ? 
inst : nullptr, unroll_factor); } @@ -1758,6 +1881,29 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( /*output_shape=*/inst->shape(), inst); } +std::unique_ptr IrEmitterUnnested::BuildTriangularSolveThunk( + const HloInstruction* inst) { + const HloInstruction* a = inst->operand(0); + const HloInstruction* b = inst->operand(1); + int64 m = b->shape().dimensions(b->shape().rank() - 2); + int64 n = b->shape().dimensions(b->shape().rank() - 1); + int64 batch_size = std::accumulate( + b->shape().dimensions().begin(), b->shape().dimensions().end() - 2, + int64{1}, [](int64 a, int64 b) { return a * b; }); + int64 elem_size = + ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type()); + int64 a_batch_stride = inst->triangular_solve_options().left_side() + ? m * m * elem_size + : n * n * elem_size; + int64 b_batch_stride = m * n * elem_size; + return absl::make_unique( + inst->triangular_solve_options(), + /*a_input_buffer=*/GetAllocationSlice(*a), + /*b_input_buffer=*/GetAllocationSlice(*inst), + inst->shape().element_type(), batch_size, m, n, a_batch_stride, + b_batch_stride, inst); +} + StatusOr> IrEmitterUnnested::BuildInitializerThunk( HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); @@ -1931,41 +2077,32 @@ Status CheckWhileBuffersShareAllocation( // Checks that the buffers used in a conditional instruction are shared with the // operands and result as follows: // * The result buffer of the conditional should share the allocation with the -// result buffers of the true and false computations. -// * The buffer of operand 1 should share the allocation with the buffer of -// the parameter 0 instruction of the true computation. -// * The buffer of operand 2 should share the allocation with the buffer of -// the parameter 0 instruction of the false computation. +// result buffers of each branch computation. 
+// * The buffer of operand b+1 should share the allocation with the buffer of +// the parameter 0 instruction of the b'th computation. Status CheckConditionalBuffersShareAllocation( const HloInstruction* conditional, const BufferAssignment& buffer_assignment) { TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( conditional->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { - TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( - conditional, conditional->true_computation()->root_instruction(), - index, buffer_assignment)); - TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( - conditional, conditional->false_computation()->root_instruction(), - index, buffer_assignment)); + for (auto branch_computation : conditional->branch_computations()) { + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + conditional, branch_computation->root_instruction(), index, + buffer_assignment)); + } return Status::OK(); })); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - conditional->operand(1)->shape(), - [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { - return CheckHloBuffersShareAllocation( - conditional->operand(1), - conditional->true_computation()->parameter_instruction(0), index, - buffer_assignment); - })); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - conditional->operand(2)->shape(), - [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { - return CheckHloBuffersShareAllocation( - conditional->operand(2), - conditional->false_computation()->parameter_instruction(0), index, - buffer_assignment); - })); + for (int j = 0; j < conditional->branch_count(); ++j) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + conditional->operand(j + 1)->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { + return CheckHloBuffersShareAllocation( + conditional->operand(j + 1), + conditional->branch_computation(j)->parameter_instruction(0), + index, 
buffer_assignment); + })); + } return Status::OK(); } @@ -2018,22 +2155,20 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( TF_CHECK_OK(CheckConditionalBuffersShareAllocation( hlo, ir_emitter_context_->buffer_assignment())); - HloComputation* true_computation = hlo->true_computation(); - IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation, - ir_emitter_context_); - TF_CHECK_OK(true_computation->Accept(&ir_emitter_true)); - - HloComputation* false_computation = hlo->false_computation(); - IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation, - ir_emitter_context_); - TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); + std::vector branch_operands; + std::vector branch_thunks; + for (int j = 0; j < hlo->branch_count(); ++j) { + branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1))); + HloComputation* branch_computation = hlo->branch_computation(j); + IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation, + ir_emitter_context_); + TF_CHECK_OK(branch_computation->Accept(&ir_emitter)); + branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence())); + } return absl::make_unique( - GetAllocationSlice(*hlo->operand(0)), - GetAllocationSlice(*hlo->operand(1)), - GetAllocationSlice(*hlo->operand(2)), - std::move(*ir_emitter_true.ConsumeThunkSequence()), - std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo); + GetAllocationSlice(*hlo->operand(0)), branch_operands, + std::move(branch_thunks), hlo); } Status IrEmitterUnnested::EmitTargetElementLoopInThunk( @@ -2066,7 +2201,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( // kernel *anyway*. 
std::vector output_arrays = ConstructIrArrayForOutputs(hlo); KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_); }); // For multioutput fusion, we need to emit each operand and the root. @@ -2133,7 +2268,6 @@ std::vector IrEmitterUnnested::ConstructIrArrayForInputs( return param_arrays; } - int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, @@ -2782,13 +2916,6 @@ void IrEmitterUnnested::EmitTileElementForReduction( reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( index, GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); - int num_partial_results = reduction_info->GetNumberOfPartialResults(); - if (num_partial_results > 1) { - // Clear the linear index field of the IrArray::Index to enable the use of - // GetElementPointer with array types. This enables the vectorization of - // the computation for different partial results. - input_index.ClearLinearIndex(); - } absl::Span partial_reduction_result_addresses = reduction_info->GetPartialResultAddresses(); absl::Span reduction_input_addresses = @@ -2967,12 +3094,10 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel // *anyway*. 
if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) { - KernelSupportLibrary{&b_}.If( - "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), - ConstructIrArrayForOutputs(*unnested_hlo), &b_, - module_); - }); + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), + ConstructIrArrayForOutputs(*unnested_hlo), &b_); + }); } // For each tiled parameter, cast its input IrArray to the corresponding @@ -3710,8 +3835,7 @@ Status IrEmitterUnnested::EmitConstantGlobals() { global_type, /*isConstant=*/should_emit_initializer, llvm::GlobalValue::ExternalLinkage, /*Initializer=*/initializer, - llvm_ir::AsStringRef( - llvm_ir::ConstantBufferAllocationToGlobalName(allocation))); + llvm_ir::ConstantBufferAllocationToGlobalName(allocation)); global_for_const->setAlignment(kConstantBufferAlignBytes); ir_emitter_context_->llvm_module()->getGlobalList().push_back( global_for_const); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 21b842bb2cd63ac454f85556df20ae5877cecbe1..9890ce122dfdc7444d769b6eb695a7c0932408c3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -176,6 +176,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; + Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleAllReduce(HloInstruction* crs) override; Status HandleAfterAll(HloInstruction* after_all) override; @@ -319,6 +320,12 @@ class IrEmitterUnnested : public IrEmitter { // Returns a FftThunk that calls cuFFT to implement `inst`. 
std::unique_ptr BuildFftThunk(const HloInstruction* inst); + // Returns a CholeskyThunk that calls cuSolver to implement `inst`. + std::unique_ptr BuildCholeskyThunk(const HloInstruction* inst); + + // Returns a TriangularSolveThunk that calls cuBlas to implement `inst`. + std::unique_ptr BuildTriangularSolveThunk(const HloInstruction* inst); + // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr BuildGemmThunk(const HloInstruction* inst); @@ -352,9 +359,9 @@ class IrEmitterUnnested : public IrEmitter { std::unique_ptr BuildForThunk(const HloInstruction* hlo, const int64 loop_limit); - // Returns a ConditionalThunk that executes the thunk sequence for - // 'true_computation' or 'false_computation' depending on the value of the - // predicate in the given conditional instruction. + // Returns a ConditionalThunk which executes the thunk sequence for the + // 'branch_computation' corresponding to the predicate/branch_index of the + // given conditional instruction. std::unique_ptr BuildConditionalThunk(const HloInstruction* hlo); Status Postprocess(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 153aab97d9eb971734c5ea95564895631bc2a9fa..275a638e8c02d8f841790d39cac0cba175d31fab 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -140,10 +140,9 @@ static string GetSmName(std::pair compute_capability) { // Convenience function for producing a name of a temporary compilation product // from the input filename. 
-string MakeNameForTempProduct(const std::string& input_filename, +string MakeNameForTempProduct(absl::string_view input_filename, absl::string_view extension) { - return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename( - llvm_ir::AsString(input_filename))), + return ReplaceFilenameExtension(tensorflow::io::Basename(input_filename), extension); } @@ -254,11 +253,8 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { llvm::buffer_ostream pstream(stream); // The extension is stripped by IrDumpingPassManager, so we need to // get creative to add a suffix. - string module_id(llvm_ir::AsString(module->getModuleIdentifier())); IrDumpingPassManager codegen_passes( - ReplaceFilenameExtension( - absl::string_view(tensorflow::io::Basename(module_id)), - "-nvptx.dummy"), + MakeNameForTempProduct(module->getModuleIdentifier(), "-nvptx.dummy"), "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -336,7 +332,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // If the module has no functions or globals, there's nothing to compile. Just // return an empty string. if (module->empty() && module->global_empty()) { - VLOG(2) << "Module '" << llvm_ir::AsString(module->getName()) + VLOG(2) << "Module '" << module->getName().str() << "' is empty. 
Skipping compilation."; return string(); } @@ -492,11 +488,10 @@ StatusOr CompileToPtx(llvm::Module* module, string ptx; { - tensorflow::tracing::ScopedActivity activity( - "Compiling IR", llvm_ir::AsString(module->getName()), - /*is_expensive=*/true); - XLA_SCOPED_LOGGING_TIMER("Compile module " + - llvm_ir::AsString(module->getName())); + tensorflow::tracing::ScopedActivity activity("Compiling IR", + module->getName().str(), + /*is_expensive=*/true); + XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); TF_ASSIGN_OR_RETURN( ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config, libdevice_dir_path)); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 40b87b16a195564c9b98497f79a70f1db0539d87..4b78d48210a5d65713d3ff63ef943335d0ddd4db 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -437,7 +437,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT p0.1 = f32[2,2,2]{2,1,0} parameter(0) ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) } @@ -505,7 +505,7 @@ TEST_F(MultiOutputFusionTest, p1.1 = f16[2,2,2]{2,1,0} parameter(1) c0 = f16[] constant(0) broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast) + greater-than = pred[2,2,2]{2,1,0} compare(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} 
broadcast), direction=GT p0.1 = f16[2,2,2]{2,1,0} parameter(0) ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast) } @@ -548,7 +548,7 @@ TEST_F(MultiOutputFusionTest, copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) c0 = f16[] constant(0) broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} - greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) } fused_reduce { diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..3051db3af4ae4380e4a38f50ad8ebc89642e645f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -0,0 +1,356 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" + +#include "tensorflow/compiler/xla/util.h" + +#if GOOGLE_CUDA +#include "absl/synchronization/blocking_counter.h" +#include "third_party/nccl/nccl.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" +#endif + +namespace xla { +namespace gpu { + +/* static */ bool NcclAllReduceThunk::NcclIsEnabled() { +#if GOOGLE_CUDA + return true; +#else + return false; +#endif +} + +#if GOOGLE_CUDA +namespace { + +// GPU-replica-driving host threads (i.e. the threads that call +// GpuExecutable::Execute) build up this structure to describe their +// participating replica, and then call to +// GlobalRendezvousManager::SubmitParticipant. +struct ParticipantData { + // Number of replicas participating in the AllReduce. + int64 replica_count; + + int64 element_count; + int64 device_ordinal; + int64 generation_counter; + + // TODO(b/125951860): We should vet that we're buffer allocating such that + // source_buffer == destination_buffer if that avoids a NCCL copy (will depend + // on how well the NCCL in-place implementation performs vs the out-of-place + // implementation). + se::DeviceMemoryBase source_data; + se::DeviceMemoryBase destination_data; + se::Stream* stream; + + NcclAllReduceThunk* originator; + + string ToString() const { + return absl::StrFormat( + "ParticipantData{replica_count=%d, element_count=%d, " + "device_ordinal=%d, generation_counter=%d, stream=%p, originator=%p}", + replica_count, element_count, device_ordinal, generation_counter, + stream, originator); + } +}; + +// Class that gets instantiated as a singleton in GetGlobalRendezvous() to +// coordinate participating threads in performing an AllReduce operation.
+// +// This manager is responsible for establishing communication channels and +// ultimately enqueueing the NCCL library operation onto the participating +// streams. +class GlobalRendezvousManager { + public: + // The GpuExecutable-executing threads call this in order to a) establish the + // all-reduce rendezvous and b) enqueue the AllReduce operation on the caller + // thread's associated stream (given in "participant"). + // + // Implementation note: since the rendezvous we're creating here is global, we + // try to be paranoid about the fact that the *correct* one is happening. In + // an ideal world we'd have some StreamExecutor se::Platform level construct + // that we could use for cross-device networking primitives (e.g. via a + // NetworkSupport interface) that could be shared between TensorFlow and XLA, + // but this is a reasonable stopgap measure to get multi-GPU-replica up and + // running properly for single-host, single-concurrent-XLA-module usage. + Status SubmitParticipant(ParticipantData participant); + + // Returns the current generation number of AllReduce operations. + // (Currently one AllReduce operation occurs per generation.) + int64 GetCurrentGeneration() { + tensorflow::mutex_lock lock(mutex_); + return current_generation_; + } + + private: + // Called by the primary thread to set up the communication links. + // + // TODO(b/125951860): This performs lots of (presumably) unnecessary host-side + // synchronization so that we can be paranoid about semantics in the earliest + // implementation. In the limit we should only need to synchronize host + // replica threads when the "number of replicas" or "participating device + // ordinals" change, to set up a new NCCL "communication" context, at which + // point we can enqueue onto device streams without host synchronization in + // our code -- this will likely be helpful for "lots of little AllReduce" + // cases. 
+ Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Called when all necessary participants are present, the functionality + // that's implemented by all executing threads lives in here. + Status DoAllReduce(ParticipantData data, ncclComm_t comm); + + // Puts all state back into a "reset" state for the next generation of + // AllReduce requests. + void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + for (ncclComm_t& comm : comms_) { + ncclCommDestroy(comm); + } + comms_.clear(); + participants_.clear(); + current_generation_++; + initialized_ = false; + done_ = absl::nullopt; + } + + tensorflow::mutex mutex_; + tensorflow::condition_variable all_participants_present_; + tensorflow::condition_variable deinitialized_; + + // Communication handles that correspond to the participants below. + std::vector comms_ GUARDED_BY(mutex_); + + Status initialize_status_ GUARDED_BY(mutex_); + std::vector participants_ GUARDED_BY(mutex_); + int64 current_generation_ GUARDED_BY(mutex_) = 0; + bool initialized_ GUARDED_BY(mutex_) = false; + + // The participating threads wait for this to count down in order to know we + // can begin the teardown process. + absl::optional done_; +}; + +Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) { + auto all_participants_present = [this, &participant]() + EXCLUSIVE_LOCKS_REQUIRED(mutex_) -> bool { + return participants_.size() >= participant.replica_count; + }; + + // We remember the participant index at which we are inserted and use that + // same index for referring to auxiliary metadata (e.g. the ncclComm_t handle + // index) below. + int64 index; + + { + tensorflow::mutex_lock lock(mutex_); + + // Spot check for consistent replica counts among submitting threads. 
+ if (!participants_.empty() && + (participants_.back().replica_count != participant.replica_count || + participants_.back().originator != participant.originator)) { + return InvalidArgument( + "Running two XLA modules with AllReduces in parallel is not " + "supported. It is possible this is due to a bug where were try to " + "run two different AllReduces from the same module at once. " + "(Attempted a rendezvous with a different replica count from other " + "participants; existing: %s; submitted: %s)", + participants_.back().ToString(), participant.ToString()); + } + index = participants_.size(); + participants_.push_back(participant); + + if (all_participants_present()) { + all_participants_present_.notify_all(); + } + } + + // We pull into our thread a) the communication handle and b) whether we're + // the "primary" thread for this rendezvous -- the "primary" thread has some + // additional responsibilities for setup/teardown. + ncclComm_t comm; + bool primary; + + { + tensorflow::mutex_lock lock(mutex_); + while (!all_participants_present()) { + // Once all the participants have arrived, all participating threads will + // cross this barrier, though only (the first) one will be the "primary". + all_participants_present_.wait(lock); + } + + // Somebody will be the first -- that thread has some additional + // responsibilities. + primary = !initialized_; + + CHECK_EQ(participant.generation_counter, current_generation_); + + // Bump the generation counter so the other threads know we've completed the + // global rendezvous and have set up the AllReduce. 
+ if (primary) { + VLOG(3) << "Primary initializing accounting data."; + initialized_ = true; + done_.emplace(participant.replica_count); + initialize_status_ = InitializeCommunicationChannels(); + VLOG(3) << "Done initializing communication channels; status: " + << initialize_status_; + if (!initialize_status_.ok()) { + DeinitializeGeneration(); + } + } + + if (!initialize_status_.ok()) { + // TODO(b/125951860): If this fails once, it will fail forever. + return initialize_status_; + } + + comm = comms_[index]; + + // Drop the lock at the end of scope so other participants may enter. + } + + VLOG(3) << "Performing all reduce from device ordinal: " + << participant.device_ordinal; + + Status all_reduce_status = DoAllReduce(participant, comm); + + VLOG(3) << "Waiting for all participants to complete enqueue."; + + done_->DecrementCount(); + + if (primary) { + // Primary thread clears out the AllReduce state when everybody is done to + // make it clean-slate for any subsequent AllReduce request (e.g. number of + // replicas may change in the next request). + // + // Note surrounding TODOs for only reinitializing this when the replica + // count / participants actually change -- lots of "playing it safe" + // happening in this first cut. 
+ done_->Wait(); + VLOG(3) << "All participants completed enqueue."; + VLOG(3) << "Primary thread clearing."; + tensorflow::mutex_lock lock(mutex_); + DeinitializeGeneration(); + VLOG(3) << "Generation is now: " << current_generation_; + deinitialized_.notify_all(); + } else { + VLOG(3) << "Waiting to deinitialize."; + tensorflow::mutex_lock lock(mutex_); + while (initialized_) { + deinitialized_.wait(lock); + } + } + + VLOG(3) << "Returning status: " << all_reduce_status; + return all_reduce_status; +} + +Status GlobalRendezvousManager::InitializeCommunicationChannels() { + std::vector ordinals; + for (ParticipantData& data : participants_) { + ordinals.push_back(data.device_ordinal); + } + comms_.resize(ordinals.size()); + VLOG(3) << "Participants: " << participants_.size() + << "; initializing comms."; + ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(), + /*devlist=*/ordinals.data()); + if (result != ncclSuccess) { + comms_.clear(); + return InternalError( + "Failed to initialize NCCL communication channels for %d participants: " + "%s", + participants_.size(), ncclGetErrorString(result)); + } + return Status::OK(); +} + +Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant, + ncclComm_t comm) { + se::StreamExecutor* executor = participant.stream->parent(); + se::cuda::ScopedActivateExecutorContext scoped_context(executor); + cudaStream_t* cu_stream = reinterpret_cast( + participant.stream->implementation()->GpuStreamMemberHack()); + VLOG(3) << "Using stream pointer: " << cu_stream + << " on device: " << participant.device_ordinal; + void* send_buffer = participant.source_data.opaque(); + void* recv_buffer = participant.destination_data.opaque(); + ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer, + /*count=*/participant.element_count, + /*datatype=*/ncclFloat, + /*op=*/ncclSum, + /*comm=*/comm, + /*stream=*/*cu_stream); + TF_RET_CHECK(ncclSuccess == result) + << "Failed to perform all-reduce: " << 
ncclGetErrorString(result); + + VLOG(3) << "Done performing all reduce for ordinal: " + << participant.device_ordinal; + + return Status::OK(); +} + +static GlobalRendezvousManager* GetGlobalRendezvous() { + static auto* manager = new GlobalRendezvousManager; + return manager; +} + +} // namespace + +Status NcclAllReduceThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + auto* global_rendezvous = GetGlobalRendezvous(); + + ParticipantData participant; + participant.replica_count = replica_count_; + participant.element_count = element_count_; + participant.device_ordinal = stream->parent()->device_ordinal(); + participant.generation_counter = global_rendezvous->GetCurrentGeneration(); + participant.source_data = buffer_allocations.GetDeviceAddress(source_buffer_); + participant.destination_data = + buffer_allocations.GetDeviceAddress(destination_buffer_); + participant.stream = stream; + participant.originator = this; + + return GetGlobalRendezvous()->SubmitParticipant(std::move(participant)); +} +#else + +Status NcclAllReduceThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + return Unimplemented( + "NCCL support is not available: this binary was not built with a CUDA " + "compiler, which is necessary to build the NCCL source library."); +} + +#endif // GOOGLE_CUDA + +NcclAllReduceThunk::NcclAllReduceThunk( + int64 replica_count, int64 element_count, + const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, + const HloInstruction* all_reduce) + : Thunk(Thunk::kNcclAllReduce, all_reduce), + replica_count_(replica_count), + element_count_(element_count), + source_buffer_(source_buffer), + destination_buffer_(destination_buffer) {} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h 
b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..1a8d1356c0023e2c7f49c3731693e10beba54a6d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas. +class NcclAllReduceThunk : public Thunk { + public: + // Returns whether NCCL operations appear possible to perform; e.g. if we + // haven't done a build with the CUDA compiler enabled, we can't compile the + // NCCL header, and thus this will be false. + // + // When this is false, the ExecuteOnStream() call will simply return a status + // error. 
+ static bool NcclIsEnabled(); + + // TODO(b/125951860): Plumb more datatypes / reduction operators. Initial + // implementation is simply F32 summation. + NcclAllReduceThunk(int64 replica_count, int64 element_count, + const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, + const HloInstruction* all_reduce); + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream, + HloExecutionProfiler* profiler) override; + + private: + const int64 replica_count_; + const int64 element_count_; + const BufferAllocation::Slice source_buffer_; + const BufferAllocation::Slice destination_buffer_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 48f718b514cc9809d4100627f85af7aa05445d36..39cb71c09133164f299f56f755d31bd74ebc3c5b 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" @@ -46,6 +47,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -53,6 +55,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" @@ -81,10 +84,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -117,6 +122,9 @@ std::vector GetCudaRootCandidates( const HloModuleConfig& hlo_module_config) { std::vector potential_cuda_roots = tensorflow::CandidateCudaRoots(); + // "." is our last resort, even though it probably won't work. + potential_cuda_roots.push_back("."); + // CUDA location explicitly specified by user via --xla_gpu_cuda_data_dir has // highest priority. string xla_gpu_cuda_data_dir = @@ -128,9 +136,23 @@ std::vector GetCudaRootCandidates( return potential_cuda_roots; } +void PrintCantFindCudaMessage(absl::string_view msg, + const HloModuleConfig& hlo_module_config) { + LOG(WARNING) << msg; + LOG(WARNING) << "Searched in the following directories:"; + for (const auto& dir : GetCudaRootCandidates(hlo_module_config)) { + LOG(WARNING) << " " << dir; + } + LOG(WARNING) + << "You can choose the search directory by setting xla_gpu_cuda_data_dir " + "in HloModule's DebugOptions. For most apps, setting the environment " + "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work."; +} + // Returns the directory containing nvvm libdevice files. 
string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { - for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + const auto& candidate_dirs = GetCudaRootCandidates(hlo_module_config); + for (const string& cuda_root : candidate_dirs) { string libdevice_dir = tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); VLOG(2) << "Looking for libdevice at " << libdevice_dir; @@ -139,8 +161,14 @@ string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { return libdevice_dir; } } - LOG(WARNING) << "Unable to find libdevice dir. Using '.'"; - // Last resort: maybe in the current folder. + PrintCantFindCudaMessage( + "Can't find directory containing CUDA libevice. This may result in " + "compilation or runtime failures, if the program we try to run uses " + "routines from libdevice.", + hlo_module_config); + + // GetCudaRotCandidates always inclues ".", but but if everything fails, we + // return it anyway. Better than returning the empty string. return "."; } @@ -171,6 +199,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/true); + // Expand the sort op to support stable sorting if required. + pipeline.AddPass(); // Convert BF16 operations to F32 operations so that the GPU backend can // support BF16 operations without directly implementing a BF16 lowering for // most ops. 
@@ -200,7 +230,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddPass(); AlgebraicSimplifierOptions options; - options.set_enable_permutation_sort_replacement(true); pass.AddPass(options); pass.AddPass(); pass.AddPass(); @@ -221,15 +250,27 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, TransposeFolding::NeverFoldTranspose); pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); + + // Run WhileLoopTripCountAnnotator at the end of the simplification + // pipeline, before layout assignment and fusion. This pass does some + // pattern-matching on while bodies/conditions, and this is where the HLO is + // "nicest". + // + // It's important that we don't make semantic changes (e.g. unrolling) to + // any `while` loops after this point, because otherwise the trip-count + // annotations added by this pass may not be correct after the + // modifications. + pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } { // Convert convolutions into CustomCalls to cudnn, then canonicalize them - // (CudnnConvPaddingLegalization). + // (CudnnConvPaddingLegalization). Also expand cuSolver calls. HloPassPipeline pipeline("conv_canonicalization"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(stream_exec, device_allocator); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -273,7 +314,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // duplicate or NOPs, so remove them with algebraic simplification and CSE. AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); - options.set_enable_permutation_sort_replacement(true); pipeline.AddPass>(options); // Choose the fastest algorithm for each conv. 
@@ -303,6 +343,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // wouldn't be able to simplify away the new_tuple bits. pipeline.AddPass(stream_exec, device_allocator, compiler); + // Clean up new_tuple described above. pipeline.AddPass(); @@ -377,6 +418,7 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } @@ -577,9 +619,6 @@ StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { // We dump the post-optimization HLO in RunBackend so no need to dump it here. - VLOG(3) << "*** HLO Before Optimization"; - XLA_VLOG_LINES(3, module->ToString()); - XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); @@ -633,19 +672,11 @@ StatusOr> NVPTXCompiler::RunBackend( [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); - // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() - // include headers, so no need for us to print them ourselves. 
- XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); - XLA_VLOG_LINES(2, buffer_assignment->ToString()); - VLOG(3) << "*** HLO After Optimization"; - XLA_VLOG_LINES(3, module->ToString()); - const string xla_dump_optimized_hlo_proto_to = - module->config().debug_options().xla_dump_optimized_hlo_proto_to(); - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *buffer_assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "buffer_assignment", + buffer_assignment->ToString()); } + DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), &stream_exec->GetDeviceDescription(), @@ -663,26 +694,16 @@ StatusOr> NVPTXCompiler::RunBackend( } if (user_pre_optimization_hook_) { - TF_CHECK_OK(user_pre_optimization_hook_(llvm_module)); + user_pre_optimization_hook_(llvm_module); } string ir_module_string_before_opt; const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - if (VLOG_IS_ON(3) || embed_ir_in_executable) { + if (embed_ir_in_executable) { ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); - VLOG(3) << "LLVM module before optimizations:"; - XLA_VLOG_LINES(3, ir_module_string_before_opt); } - const string& ir_dump_directory = - module->config().debug_options().xla_dump_ir_to(); - - if (!ir_dump_directory.empty()) { - TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( - /*directory_name=*/ir_dump_directory, - /*hlo_module_name=*/module->name(), llvm_module, - /*optimized=*/false)); - } + llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/false); { XLA_SCOPED_LOGGING_TIMER( @@ -696,7 +717,7 @@ StatusOr> NVPTXCompiler::RunBackend( << "Invalid LLVM IR before optimizations:\n" << err_stream.str() << "\nThis 
probably indicates a bug in the HLO -> LLVM IR lowering. " - "Rerun with --xla_dump_ir_to to get the IR. "; + "Rerun with --xla_dump_to to get the IR. "; } string libdevice_dir; @@ -729,35 +750,14 @@ StatusOr> NVPTXCompiler::RunBackend( module->config(), libdevice_dir)); } - if (!ir_dump_directory.empty()) { - TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( - /*directory_name=*/ir_dump_directory, - /*hlo_module_name=*/module->name(), llvm_module, - /*optimized=*/true)); - } + llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/true); if (user_post_optimization_hook_) { - TF_CHECK_OK(user_post_optimization_hook_(llvm_module)); + user_post_optimization_hook_(llvm_module); } - VLOG(3) << "LLVM module after optimizations:"; - XLA_VLOG_LINES(3, llvm_ir::DumpModuleToString(llvm_module)); - VLOG(3) << "PTX:"; - XLA_VLOG_LINES(3, ptx); - // Write PTX to IR dump directory, if IR dumping was requested. - if (!ir_dump_directory.empty()) { - const string ptx_outfile = tensorflow::io::JoinPath( - ir_dump_directory, absl::StrCat(module->name(), ".ptx")); - auto status = [&] { - auto* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); - TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_outfile, ptx)); - return Status::OK(); - }(); - if (!status.ok()) { - LOG(WARNING) << "Couldn't dump PTX for module " << module->name() - << " to " << ptx_outfile << ": " << status; - } + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "ptx", ptx); } const std::vector cubin = @@ -766,20 +766,27 @@ StatusOr> NVPTXCompiler::RunBackend( auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); - VLOG(3) << "Printing the thunk schedule..."; - XLA_VLOG_LINES(3, thunk_schedule->ToString()); + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "thunk_schedule", + thunk_schedule->ToString()); + } 
std::unique_ptr profile_index_map; std::unique_ptr profile_printer; - if (module->config().hlo_profiling_enabled()) { + if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); cost_analysis.set_bytes_per_second( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - profile_index_map = absl::make_unique(*module); - profile_printer = CreateHloProfilePrinterData( - *profile_index_map, cost_analysis, entry_computation->name()); + VLOG(1) << "HLO memory read+written: " + << tensorflow::strings::HumanReadableNumBytes( + cost_analysis.bytes_accessed()); + if (module->config().hlo_profiling_enabled()) { + profile_index_map = absl::make_unique(*module); + profile_printer = CreateHloProfilePrinterData( + *profile_index_map, cost_analysis, entry_computation->name()); + } } auto* gpu_executable = new GpuExecutable( @@ -843,10 +850,11 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( log_warning = !warning_done.exchange(true); } if (log_warning) { - LOG(WARNING) - << "Failed to compile ptx to cubin. Will attempt to let " - "GPU driver compile the ptx. " - << maybe_cubin.status(); + PrintCantFindCudaMessage( + "Can't find ptxas binary. Will back to the GPU driver " + "for PTX -> sass compilation. This is OK so long as you don't " + "see a warning below about an out-of-date driver version.", + hlo_module_config); } // We're going to use the driver to JIT our PTX->SASS, so warn if diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 8154d75d23a6d49153ccb6824402aff73f365617..cb012649200c6386d3ae25d088aa3b16bd40be82 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index bfed4f5230dfe37bca48560ce83a2dd82c8950a4..10bc82488ff56135f4585e62c2f71c11a359e542 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -41,7 +41,7 @@ std::ostream& operator<<(std::ostream& out, int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc) { int64 threads_per_block = device_desc.threads_per_block_limit(); - if (threads_per_block == 0) { + if (threads_per_block <= 0) { static std::atomic log_count{0}; if (log_count.fetch_add(1) < 8) { LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " @@ -71,18 +71,17 @@ LaunchDimensions CalculateLaunchDimensions( num_elements = num_elements / unroll_factor; // Since we don't do any inter-warp communication, we're free to choose any - // block size we want, subject to hardware constraints. We choose the - // smallest block size that allows the GPU to reach full occupancy (assuming - // the kernel uses sufficiently few registers). This gives us max performance - // when the kernel uses few registers, and lets us scale down gracefully as - // the kernel uses more registers. + // block size we want, subject to hardware constraints. We choose the largest + // block size allowed, as empirically, this is a performance win on almost + // (but not all) benchmarks. // - // Specifically, we choose the number of threads per block such that + // My guess is that using a larger block size encourages ptxas to decrease + // per-thread register usage, thus allowing for higher occupancy, but I + // haven't verified this. 
// - // * = - + // TODO(jlebar): Investigate this further, and tune this heuristic so we can + // run faster on the few benchmarks where smaller block size helps. int64 threads_per_block = ThreadsPerBlockLimit(device_desc); - if (num_elements < threads_per_block) { threads_per_block = num_elements; VLOG(2) << "Update # of threads per block to the element count (" diff --git a/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc b/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..197367e81687eeddea8778267075e66ef1819341 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h" + +namespace xla { +namespace gpu { + +StatusOr> ScratchAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); + total_allocated_bytes_ += byte_size; + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/scratch_allocator.h b/tensorflow/compiler/xla/service/gpu/scratch_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..620c7e78912eb7d9730bae02aab8f85b5fd2c096 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/scratch_allocator.h @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_ + +#include + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +class ScratchAllocator : public se::ScratchAllocator { + public: + ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + + int64 GetMemoryLimitInBytes(se::Stream* stream) override { + return 1LL << 32; // 4GB. TODO(jlebar): Tune this? + } + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + StatusOr> AllocateBytes(se::Stream* stream, + int64 byte_size) override; + + template + StatusOr> Allocate(se::Stream* stream, + int64 num_elements) { + TF_ASSIGN_OR_RETURN(se::DeviceMemory bytes, + AllocateBytes(stream, num_elements * sizeof(T))); + return se::DeviceMemory(bytes); + } + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_ diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index a1ed8499040359fe7265a7317b0577a990a2234c..d33e9cf714ee3810b1fb2fa8c05c3ed399d27bfb 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index 6814be779e0b02c38e3bc7008f036b845d88cb6f..963716e70500e1e4eac930e13937ae547ace0c1e 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -48,8 +48,9 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { HloInstruction::CreateParameter(0, param_shape, "x")); HloInstruction* param_y = builder.AddInstruction( HloInstruction::CreateParameter(1, param_shape, "y")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {5, 7, 2}), param_x, param_y, + ComparisonDirection::kGe)); auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); @@ -73,7 +74,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) { x = f32[5,7,2]{2,1,0} parameter(0) y = f32[5,14]{1,0} parameter(1) reshape = f32[5,7,2]{2,1,0} reshape(y) - ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, reshape) + ROOT gte = pred[5,7,2]{2,1,0} compare(x, reshape), direction=GE })", config) .ValueOrDie(); @@ -98,7 +99,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { y = f32[14]{0} parameter(1) reshape = f32[7,2]{1,0} reshape(y) broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2} - ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, broadcast) + ROOT gte = pred[5,7,2]{2,1,0} compare(x, broadcast), direction=GE })", config) .ValueOrDie(); diff --git 
a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index f91a22d482bc8bc046977870a7a4d18ca1acde68..06b06a5b1ee1fb9996be3ebe326893c4160a7e29 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc index c78605cebbc671272b8df9faf0e0cc54be2f5b1c..f43e05904dd6ffdc37c77a08d5ddc49bf08c8feb 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk.cc @@ -18,42 +18,52 @@ limitations under the License. 
namespace xla { namespace gpu { -std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { +absl::string_view ThunkKindToString(Thunk::Kind kind) { switch (kind) { + case Thunk::kCholesky: + return "kCholesky"; case Thunk::kConditional: - return os << "kConditional"; + return "kConditional"; case Thunk::kConvolution: - return os << "kConvolution"; + return "kConvolution"; case Thunk::kCopy: - return os << "kCopy"; + return "kCopy"; case Thunk::kCudnnBatchNormBackward: - return os << "kCudnnBatchNormBackward"; + return "kCudnnBatchNormBackward"; case Thunk::kCudnnBatchNormForwardInference: - return os << "kCudnnBatchNormForwardInference"; + return "kCudnnBatchNormForwardInference"; case Thunk::kCudnnBatchNormForwardTraining: - return os << "kCudnnBatchNormForwardTraining"; + return "kCudnnBatchNormForwardTraining"; + case Thunk::kNcclAllReduce: + return "kNcclAllReduce"; case Thunk::kFft: - return os << "kFft"; + return "kFft"; case Thunk::kGemm: - return os << "kGemm"; + return "kGemm"; case Thunk::kInfeed: - return os << "kInfeed"; + return "kInfeed"; case Thunk::kKernel: - return os << "kKernel"; + return "kKernel"; case Thunk::kMemset32BitValue: - return os << "kMemset32BitValue"; + return "kMemset32BitValue"; case Thunk::kMemzero: - return os << "kMemzero"; + return "kMemzero"; case Thunk::kOutfeed: - return os << "kOutfeed"; + return "kOutfeed"; case Thunk::kSequential: - return os << "kSequential"; + return "kSequential"; + case Thunk::kTriangularSolve: + return "kTriangularSolve"; case Thunk::kTuple: - return os << "kTuple"; + return "kTuple"; case Thunk::kWhile: - return os << "kWhile"; + return "kWhile"; } } +std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { + return os << ThunkKindToString(kind); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index e68bee035a029178844282995429eaa960cc4817..56d1176ff4ed3deede4006d16fe8f8a4c66a2a92 100644 --- 
a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -42,12 +42,14 @@ class GpuExecutable; class Thunk { public: enum Kind { + kCholesky, kConditional, kConvolution, kCopy, kCudnnBatchNormBackward, kCudnnBatchNormForwardInference, kCudnnBatchNormForwardTraining, + kNcclAllReduce, kFft, kGemm, kInfeed, @@ -56,6 +58,7 @@ class Thunk { kMemzero, kOutfeed, kSequential, + kTriangularSolve, kTuple, kWhile, }; @@ -103,6 +106,7 @@ class Thunk { // A sequence of thunks. using ThunkSequence = std::vector>; +absl::string_view ThunkKindToString(Thunk::Kind); std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 25bad67bab9375559c431466571c62acd0452b01..daa5f33e5604c434aaab7fa454fc3e89d68599d5 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -14,7 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" @@ -144,11 +147,32 @@ const std::list& ThunkSchedule::DependsOn( } string ThunkSchedule::ToString() const { + if (thunk_total_order_.empty()) { + return "No thunks."; + } + + const Thunk* thunk_with_longest_kind = *absl::c_max_element( + thunk_total_order_, [](const Thunk* a, const Thunk* b) { + return ThunkKindToString(a->kind()).length() < + ThunkKindToString(b->kind()).length(); + }); + int64 max_thunk_kind_len = + ThunkKindToString(thunk_with_longest_kind->kind()).length(); + string result = "Total order:\n"; for (Thunk* thunk : thunk_total_order_) { - absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n"); + // Write out the thunk kind, padded out to max_thunk_kind_len. 
+ absl::string_view kind_str = ThunkKindToString(thunk->kind()); + absl::StrAppend(&result, kind_str, + string(max_thunk_kind_len - kind_str.length(), ' '), "\t"); + if (thunk->hlo_instruction() != nullptr) { + absl::StrAppend(&result, thunk->hlo_instruction()->ToString()); + } else { + absl::StrAppend(&result, "(no HloInstruction)"); + } + absl::StrAppend(&result, "\n"); } - absl::StrAppend(&result, "Dependencies:\n"); + absl::StrAppend(&result, "\nDependencies:\n"); for (const auto& entry : depends_on_) { const Thunk* dependent = entry.first; for (const Thunk* dependency : entry.second) { diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..5200a2af412979c7e38d95c5a9bd5bc2ab64f086 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc @@ -0,0 +1,149 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
+
+#include <complex>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/device_memory.h"
+
+namespace xla {
+namespace gpu {
+
+// Translates `options` into the se::blas enums used by the BLAS call and
+// records the buffer slices, element type and problem dimensions.
+//
+// An unrecognized transpose_a value is logged at ERROR and treated as
+// kNoTranspose; construction itself does not fail.
+TriangularSolveThunk::TriangularSolveThunk(
+    const TriangularSolveOptions& options,
+    const BufferAllocation::Slice& a_buffer,
+    const BufferAllocation::Slice& b_buffer, PrimitiveType type,
+    int64 batch_size, int64 m, int64 n, int64 a_batch_stride,
+    int64 b_batch_stride, const HloInstruction* hlo)
+    : Thunk(Kind::kTriangularSolve, hlo),
+      uplo_(options.lower() ? se::blas::UpperLower::kLower
+                            : se::blas::UpperLower::kUpper),
+      side_(options.left_side() ? se::blas::Side::kLeft
+                                : se::blas::Side::kRight),
+      unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit
+                                             : se::blas::Diagonal::kNonUnit),
+      a_buffer_(a_buffer),
+      b_buffer_(b_buffer),
+      type_(type),
+      batch_size_(batch_size),
+      m_(m),
+      n_(n),
+      a_batch_stride_(a_batch_stride),
+      b_batch_stride_(b_batch_stride) {
+  transpose_a_ = [&] {
+    switch (options.transpose_a()) {
+      case TriangularSolveOptions::NO_TRANSPOSE:
+        return se::blas::Transpose::kNoTranspose;
+      case TriangularSolveOptions::TRANSPOSE:
+        return se::blas::Transpose::kTranspose;
+      case TriangularSolveOptions::ADJOINT:
+        return se::blas::Transpose::kConjugateTranspose;
+      default:
+        LOG(ERROR) << "Invalid triangular solve transpose value "
+                   << options.transpose_a();
+        return se::blas::Transpose::kNoTranspose;
+    }
+  }();
+}
+
+// Enqueues one BlasTrsm call per batch element on `stream`.
+//
+// The device addresses of a_buffer_ and b_buffer_ are advanced by
+// i * a_batch_stride_ and i * b_batch_stride_ bytes for batch element i. The
+// result is written through the `b` operand. `profiler` is not used.
+//
+// Returns InvalidArgument for an element type other than F32, F64, C64 or
+// C128, and InternalError as soon as a launch reports !ok().
+Status TriangularSolveThunk::ExecuteOnStream(
+    const BufferAllocations& buffer_allocations, se::Stream* stream,
+    HloExecutionProfiler* profiler) {
+  VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_)
+          << " side=" << se::blas::SideString(side_)
+          << " diagonal=" << se::blas::DiagonalString(unit_diagonal_)
+          << " batch_size=" << batch_size_ << " m=" << m_ << " n=" << n_
+          << " a_batch_stride=" << a_batch_stride_
+          << " b_batch_stride=" << b_batch_stride_;
+
+  // NOTE(review): m_ and n_ are int64 and are narrowed to int here; confirm
+  // callers never pass dimensions above INT_MAX.
+  const int lda = side_ == se::blas::Side::kLeft ? m_ : n_;
+  const int ldb = m_;
+
+  // char* so that the batch strides below are applied as byte offsets.
+  char* a_base = static_cast<char*>(
+      buffer_allocations.GetDeviceAddress(a_buffer_).opaque());
+  char* b_base = static_cast<char*>(
+      buffer_allocations.GetDeviceAddress(b_buffer_).opaque());
+  for (int64 i = 0; i < batch_size_; ++i) {
+    bool launch_ok = false;
+    se::DeviceMemoryBase a_data =
+        se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_);
+    se::DeviceMemoryBase b_data =
+        se::DeviceMemoryBase(b_base + i * b_batch_stride_, b_batch_stride_);
+    switch (type_) {
+      case F32: {
+        se::DeviceMemory<float> b_data_typed(b_data);
+        launch_ok = stream
+                        ->ThenBlasTrsm(side_, uplo_, transpose_a_,
+                                       unit_diagonal_, m_, n_, /*alpha=*/1.0f,
+                                       se::DeviceMemory<float>(a_data), lda,
+                                       &b_data_typed, ldb)
+                        .ok();
+        break;
+      }
+      case F64: {
+        se::DeviceMemory<double> b_data_typed(b_data);
+        launch_ok = stream
+                        ->ThenBlasTrsm(side_, uplo_, transpose_a_,
+                                       unit_diagonal_, m_, n_, /*alpha=*/1.0,
+                                       se::DeviceMemory<double>(a_data), lda,
+                                       &b_data_typed, ldb)
+                        .ok();
+        break;
+      }
+      case C64: {
+        se::DeviceMemory<std::complex<float>> b_data_typed(b_data);
+        launch_ok =
+            stream
+                ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
+                               n_, /*alpha=*/1.0f,
+                               se::DeviceMemory<std::complex<float>>(a_data),
+                               lda, &b_data_typed, ldb)
+                .ok();
+        break;
+      }
+      case C128: {
+        se::DeviceMemory<std::complex<double>> b_data_typed(b_data);
+        launch_ok =
+            stream
+                ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
+                               n_, /*alpha=*/1.0,
+                               se::DeviceMemory<std::complex<double>>(a_data),
+                               lda, &b_data_typed, ldb)
+                .ok();
+        break;
+      }
+      default:
+        return InvalidArgument("Invalid type for triangular solve %d", type_);
+    }
+    if (!launch_ok) {
+      return InternalError("Unable to launch triangular solve for thunk %p",
+                           this);
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h
new file mode 100644
index 0000000000000000000000000000000000000000..c947162ea32f197f808d099859eadbbc55a65ab1
--- /dev/null
+++
b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRIANGULAR_SOLVE_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRIANGULAR_SOLVE_THUNK_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/blas.h" + +namespace xla { +namespace gpu { + +// This class stores everything that StreamExecutor needs to launch a triangular +// solve (BlasTrsm). It is generated by IrEmitter. +// +// Thread-compatible. 
+class TriangularSolveThunk : public Thunk {
+ public:
+  // `options` supplies the lower/upper, left/right, unit-diagonal and
+  // transpose settings. `a_buffer` and `b_buffer` are the slices holding the
+  // two operands. `a_batch_stride` and `b_batch_stride` are added to
+  // char*-typed base addresses in ExecuteOnStream, i.e. they are byte offsets
+  // between consecutive batch elements.
+  TriangularSolveThunk(const TriangularSolveOptions& options,
+                       const BufferAllocation::Slice& a_buffer,
+                       const BufferAllocation::Slice& b_buffer,
+                       PrimitiveType type, int64 batch_size, int64 m, int64 n,
+                       int64 a_batch_stride, int64 b_batch_stride,
+                       const HloInstruction* hlo);
+
+  // Not copyable.
+  TriangularSolveThunk(const TriangularSolveThunk&) = delete;
+  TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete;
+
+  // Issues one BlasTrsm call per batch element on `stream`.
+  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+                         se::Stream* stream,
+                         HloExecutionProfiler* profiler) override;
+
+ private:
+  // BLAS parameters derived from TriangularSolveOptions in the constructor.
+  const se::blas::UpperLower uplo_;
+  const se::blas::Side side_;
+  const se::blas::Diagonal unit_diagonal_;
+  // Not const: assigned in the constructor body rather than the init list.
+  se::blas::Transpose transpose_a_;
+
+  const BufferAllocation::Slice a_buffer_;
+  const BufferAllocation::Slice b_buffer_;
+
+  // Element type; selects the typed BlasTrsm call in ExecuteOnStream.
+  const PrimitiveType type_;
+  const int64 batch_size_;
+  // Passed as the m and n arguments of BlasTrsm.
+  const int64 m_;
+  const int64 n_;
+  const int64 a_batch_stride_;
+  const int64 b_batch_stride_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRIANGULAR_SOLVE_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc
index c552c2925497f1c4808d74a615d35cdbeeba1858..bbbcc2dbb0f71d08462a1aad6d97e7fd07b2a1fb 100644
--- a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 2dce7749bbd8da2673ae607eee3d731d9917e8fe..64a5fe5fdd2ebb5430767973c2ba3d1bf498455c 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -44,9 +44,9 @@ class WhileTransformerTest : public HloTestBase { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, tuple_index)); - builder.AddInstruction( - HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt, - induction_variable, limit_const)); + builder.AddInstruction(HloInstruction::CreateCompare( + condition_result_shape_, induction_variable, limit_const, + ComparisonDirection::kLt)); return builder.Build(); } @@ -106,24 +106,6 @@ class WhileTransformerTest : public HloTestBase { return while_hlo; } - void RunFusionPasses() { - // Run standard fusion passes. 
- TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .status()); - TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .status()); - } - - void RunCopyInsertionPass() { - HloVerifier verifier(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false); - TF_ASSERT_OK(verifier.Run(module_.get()).status()); - CopyInsertion copy_insertion; - TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); - } - Shape GetLoopStateShape(const int64 ind_var_tuple_index) { if (ind_var_tuple_index == 0) { return ShapeUtil::MakeTupleShape( @@ -146,10 +128,6 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); auto while_hlo = BuildWhileInstruction(condition, body, 0, 0); - // Run HLO Optimization passes. - RunFusionPasses(); - RunCopyInsertionPass(); - auto result = ComputeWhileLoopTripCount(while_hlo); ASSERT_TRUE(result); EXPECT_EQ(10, *result); @@ -161,10 +139,6 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(1, 0, 1)); auto while_hlo = BuildWhileInstruction(condition, body, 1, 0); - // Run HLO Optimization passes. - RunFusionPasses(); - RunCopyInsertionPass(); - auto result = ComputeWhileLoopTripCount(while_hlo); ASSERT_TRUE(result); EXPECT_EQ(10, *result); @@ -176,10 +150,6 @@ TEST_F(WhileTransformerTest, ImpossibleLoopLimit) { module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); auto while_hlo = BuildWhileInstruction(condition, body, 0, 10); - // Run HLO Optimization passes. 
- RunFusionPasses(); - RunCopyInsertionPass(); - auto result = ComputeWhileLoopTripCount(while_hlo); ASSERT_TRUE(result); EXPECT_EQ(0, *result); @@ -191,10 +161,6 @@ TEST_F(WhileTransformerTest, InvalidLoopIncrement) { module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, -1)); auto while_hlo = BuildWhileInstruction(condition, body, 0, 0); - // Run HLO Optimization passes. - RunFusionPasses(); - RunCopyInsertionPass(); - auto result = ComputeWhileLoopTripCount(while_hlo); ASSERT_FALSE(result); } diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc deleted file mode 100644 index ef70b688778df5115e2b5fe572d253a6948d076f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Example HLO graph which demonstrates Graphviz dumper for HLO -// computations. When run, pushes the example DOT graph to the Graphviz service -// and prints the URL. Useful for seeing effect of changes to the graph -// generation code. 
- -#include -#include -#include - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace { - -// Adds a computation to the given HLO module which adds a scalar constant to -// its parameter and returns the result. -HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { - auto builder = HloComputation::Builder(absl::StrCat("add_", addend)); - auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "x_value")); - auto half = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.5))); - builder.AddInstruction(HloInstruction::CreateBinary( - half->shape(), HloOpcode::kAdd, x_value, half)); - return module->AddEmbeddedComputation(builder.Build()); -} - -// Adds a computation to the given HLO module which sums its two parameters and -// returns the result. 
-HloComputation* ScalarSumComputation(HloModule* module) { - auto builder = HloComputation::Builder("add"); - auto lhs = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "lhs")); - auto rhs = builder.AddInstruction( - HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "rhs")); - builder.AddInstruction( - HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); - return module->AddEmbeddedComputation(builder.Build()); -} - -// Adds a computation to the given HLO module which forwards its argument to a -// kCall instruction which then calls the given computation. -HloComputation* CallForwardingComputation(HloComputation* computation, - HloModule* module) { - auto builder = HloComputation::Builder("call_forward"); - auto arg = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "arg")); - builder.AddInstruction( - HloInstruction::CreateCall(arg->shape(), {arg}, computation)); - return module->AddEmbeddedComputation(builder.Build()); -} - -// Create a large, arbitrary computation with many different kinds of -// instructions. Sets the computation as the entry to an HLO module and returns -// the module. -std::unique_ptr MakeBigGraph() { - HloModuleConfig config; - auto module = absl::make_unique("BigGraph", config); - - auto builder = HloComputation::Builder("TestBigGraphvizGraph"); - - // Shapes used in the computation. - auto mshape = ShapeUtil::MakeShape(F32, {3, 5}); - auto vshape = ShapeUtil::MakeShape(F32, {3}); - auto sshape = ShapeUtil::MakeShape(F32, {3}); - - // Create a set of parameter instructions. 
- auto param_v0 = - builder.AddInstruction(HloInstruction::CreateParameter(0, vshape, "foo")); - auto param_v1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, vshape, "bar")); - auto param_v2 = - builder.AddInstruction(HloInstruction::CreateParameter(2, vshape, "baz")); - auto param_s = - builder.AddInstruction(HloInstruction::CreateParameter(3, sshape, "qux")); - auto param_m = - builder.AddInstruction(HloInstruction::CreateParameter(4, mshape, "zzz")); - - // Add an arbitrary expression of different instructions. - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); - auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( - vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfig precision_config; - precision_config.mutable_operand_precision()->Resize( - /*new_size=*/2, PrecisionConfig::DEFAULT); - auto dot = builder.AddInstruction(HloInstruction::CreateDot( - vshape, clamp, param_v0, dot_dnums, precision_config)); - auto tuple = builder.AddInstruction( - HloInstruction::CreateTuple({dot, param_s, clamp})); - auto scalar = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(sshape, tuple, 2)); - auto add_one = AddScalarConstantComputation(1.0, module.get()); - auto rng = builder.AddInstruction( - HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto add_computation = ScalarSumComputation(module.get()); - builder.AddInstruction( - HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); - auto map1 = builder.AddInstruction( - HloInstruction::CreateMap(sshape, {scalar}, add_one)); - auto map2 = builder.AddInstruction( - HloInstruction::CreateMap(sshape, {map1}, add_one)); - auto map3 = 
builder.AddInstruction( - HloInstruction::CreateMap(sshape, {map2}, add_one)); - - // Create a fusion instruction containing the chain of map instructions. - auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( - sshape, HloInstruction::FusionKind::kLoop, map3)); - fusion->FuseInstruction(map2); - fusion->FuseInstruction(map1); - - // Add a random trace instruction. - builder.AddInstruction(HloInstruction::CreateTrace("trace", dot)); - - // Add a call instruction will calls the call-forwarding computation to call - // another computation. - auto call_computation = CallForwardingComputation(add_one, module.get()); - builder.AddInstruction( - HloInstruction::CreateCall(fusion->shape(), {fusion}, call_computation)); - - module->AddEntryComputation(builder.Build()); - return module; -} - -} // namespace -} // namespace xla - -int main(int argc, char** argv) { - tensorflow::port::InitMain(argv[0], &argc, &argv); - - auto module = xla::MakeBigGraph(); - - printf("Graph URL: %s\n", xla::hlo_graph_dumper::DumpGraph( - *module->entry_computation(), - "Example computation", xla::DebugOptions()) - .c_str()); - return 0; -} diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index dbbf43082f2c1d21f5ef42f53804bf0969903a58..3e0631aeb4aa374cb5748650e1c7529e26e10b34 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -158,7 +158,7 @@ class HeapSimulator { void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, const BufferValue* buffer, const HloInstruction* instruction, - const BufferValue* shared_with_canonical); + const BufferValue* share_with_canonical); // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, // in which case we are calculating the same allocs/frees twice in the diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 
dc40b9446ad1bffcb757543e52fc9ab20de6d52e..2f1628038204833c1aa1061e81878c07d7f7529d 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -54,8 +54,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, + cond_data, ComparisonDirection::kLt)); HloComputation* cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -113,7 +113,8 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { // %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]} // %reshape = f32[] reshape(f32[1]{0} %slice) // %constant = f32[] constant(0) - // ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant) + // ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant), + // direction=NE // } // ENTRY %SubcomputationAccounting () -> f32[2,4] { @@ -143,9 +144,9 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice)); HloInstruction* zero = cond_builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); - HloInstruction* cond_comparison = - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero)); + HloInstruction* cond_comparison = cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape, + zero, ComparisonDirection::kNe)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); // param - 1 @@ -703,8 +704,8 @@ TEST_F(HeapSimulatorTest, WholeModule) { 
HloInstruction* cond_data = cond_builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, + cond_data, ComparisonDirection::kLt)); HloComputation* cond_computation = tracker.module()->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 263b42a29dbb0dbc0fb6eca7968674ff242f45ed..54ee92943cc6eef1d5961ab9cdc529bab15083d7 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -29,12 +29,13 @@ limitations under the License. syntax = "proto3"; package xla; + import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 59 +// Next ID: 64 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -145,6 +146,9 @@ message HloInstructionProto { // FFT length. repeated int64 fft_length = 32; + // Comparison direction only used for kCompare. + string comparison_direction = 63; + // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; repeated int64 gather_slice_sizes = 34; @@ -175,6 +179,9 @@ message HloInstructionProto { // partners. bool is_host_transfer = 47; + // Whether this Sort instruction should be stable. + bool is_stable = 60; + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; // Precision configuration for the instruction. Has backend-specific meaning. @@ -193,6 +200,15 @@ message HloInstructionProto { // operand. 
bool constrain_layout = 56; repeated xla.ShapeProto operand_shapes_with_layout = 57; + + // Options for TriangularSolve + xla.TriangularSolveOptions triangular_solve_options = 59; + + // Options for Cholesky + xla.CholeskyOptions cholesky_options = 62; + + // Describes how parameters behave with regards to replicas. + xla.ParameterReplication parameter_replication = 61; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index e511f1951c5dd07ebb64fa38fd5b7f6a0e87b429..7d02f4b3d756df9d1fcbddfa85df2a41a62d9169 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -293,7 +293,7 @@ class BufferValueMap { VLOG(3) << " value @ " << position << " is root of " << callsite.instruction()->name() - << "; true/false branch roots must share buffer among them : " + << "; branch computation roots must share buffer among them : " << cond_value.ToShortString(); aliased_buffers->push_back(GetBufferForValue(cond_value)); } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index b6dbf07959c541bceaa8eda5a0101503970ee832..e344fbc54a8a72c6195d1d8590f0f1c56428f641 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -48,7 +47,6 @@ class HloAliasAnalysisTest : public HloTestBase { // Run alias analysis on the member module. 
For convenience returns a // reference to the generated analysis stored in analysis_. HloAliasAnalysis& RunAnalysis() { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); analysis_ = HloAliasAnalysis::Run(module_.get(), /*fusion_can_share_buffer=*/nullptr) .ConsumeValueOrDie(); @@ -126,6 +124,7 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -160,6 +159,7 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) { builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -203,6 +203,7 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({param0, param1, param0})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -237,6 +238,8 @@ TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); @@ -281,6 +284,8 @@ TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( /*output_index=*/{0}, 
/*param_number=*/0, /*param_index=*/{1}, /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); @@ -370,6 +375,8 @@ TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); @@ -421,6 +428,7 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -462,6 +470,7 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -547,6 +556,7 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -647,6 +657,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { FlattenCallGraph flattener; TF_ASSERT_OK(flattener.Run(module_.get()).status()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -738,6 +749,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { auto entry_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition2, outer_body, tuple)); module_->AddEntryComputation(builder.Build()); + 
SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -811,6 +823,7 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -872,6 +885,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -960,6 +974,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { HloInstruction::CreateWhile(tuple_shape, condition, body, select)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -997,6 +1012,7 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { scalar_shape_, HloOpcode::kBitcast, constant)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -1017,6 +1033,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -1056,6 +1073,7 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { builder.AddInstruction(HloInstruction::CreateTuple({negate, xla_while})); HloComputation* entry = module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index f9b64d12ae83139efa21ca67e565908bd78f9780..48a51d302bbf054d904c54ab933d87fc910d0714 100644 --- 
a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -124,6 +124,24 @@ HloInstruction* HloComputation::AddParameter( return instructions_.back().get(); } +HloInstruction* HloComputation::AddEntryComputationParameter( + std::unique_ptr instruction) { + CHECK_EQ(instruction->opcode(), HloOpcode::kParameter); + CHECK_EQ(instruction->parameter_number(), num_parameters()); + CHECK(parent()->entry_computation() == this); + + HloModuleConfig config = parent()->config(); + config.mutable_entry_computation_layout()->add_parameter_layout( + ShapeLayout(instruction->shape())); + parent()->set_config(config); + + instruction->set_parent(this); + param_instructions_.push_back(instruction.get()); + AddInstructionInternal(std::move(instruction)); + + return instructions_.back().get(); +} + Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); @@ -296,7 +314,7 @@ void ComputeComputationPostOrder(HloComputation* computation, } // namespace void HloComputation::ComputeInstructionPostOrder( - const HloComputation::ChannelDependencyMap& channel_dependency_map, + const HloComputation::ChannelDependencyGroup& channel_dependency_group, std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const { std::vector dfs_stack; @@ -320,66 +338,75 @@ void HloComputation::ComputeInstructionPostOrder( visited->insert({current, kVisiting}); - // Add the operands to the stack in reverse order so the first operand is - // processed first. This will produce a more natural ordering and a nicer - // result for things like HLO stringification. 
- const auto& operands = current->operands(); - for (int64 i = operands.size() - 1; i >= 0; --i) { - dfs_stack.emplace_back(operands[i]); - } - - for (HloInstruction* op : current->control_predecessors()) { - dfs_stack.emplace_back(op); - } - - // Add inputs for send->recv_done dependencies and all-reduce - // dependencies. - switch (current->opcode()) { - case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(current->channel_id()); - if (it != channel_dependency_map.end()) { - for (HloInstruction* op : it->second) { - dfs_stack.emplace_back(op); - } - } - break; + const auto get_channel_id = + [](HloInstruction* inst) -> absl::optional { + switch (inst->opcode()) { + case HloOpcode::kRecvDone: + return inst->channel_id(); + case HloOpcode::kAllReduce: + return inst->all_reduce_id(); + default: + return absl::nullopt; } - case HloOpcode::kAllReduce: { - auto all_reduce_id = current->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - for (HloInstruction* op : it->second) { - dfs_stack.emplace_back(op); - } - } + }; + + // When adding a predecessor to the dfs_stack, we need to also add its + // associated channel dependencies. + const auto add_dfs_stack = [&](HloInstruction* inst) { + auto channel_id = get_channel_id(inst); + if (channel_id && channel_dependency_group.count(*channel_id)) { + auto it = channel_dependency_group.find(*channel_id); + for (HloInstruction* cinst : it->second) { + dfs_stack.emplace_back(cinst); } - break; + } else { + dfs_stack.emplace_back(inst); } - default: - break; + }; + + const auto add_predecessors = [&](HloInstruction* inst) { + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for things like HLO stringification. 
+ const auto& operands = inst->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + add_dfs_stack(operands[i]); + } + + for (HloInstruction* op : inst->control_predecessors()) { + add_dfs_stack(op); + } + }; + + // If the current instruction is a channel instruction, add the dependencies + // from all associated instructions of the channel. + auto channel_id = get_channel_id(current); + if (channel_id && channel_dependency_group.count(*channel_id)) { + auto it = channel_dependency_group.find(*channel_id); + for (HloInstruction* cinst : it->second) { + add_predecessors(cinst); + } + } else { + add_predecessors(current); } } } -HloComputation::ChannelDependencyMap +HloComputation::ChannelDependencyGroup HloComputation::ComputeChannelDependencies() const { - ChannelDependencyMap channel_dependency_map; + ChannelDependencyGroup channel_dependency_group; for (const auto& instruction : instructions_) { switch (instruction->opcode()) { - case HloOpcode::kSend: { - channel_dependency_map[instruction->channel_id()].push_back( + case HloOpcode::kSend: + case HloOpcode::kRecvDone: + channel_dependency_group[instruction->channel_id()].push_back( instruction.get()); break; - } case HloOpcode::kAllReduce: { auto all_reduce_id = instruction->all_reduce_id(); if (all_reduce_id) { - auto& dependencies = channel_dependency_map[all_reduce_id.value()]; - absl::c_copy(instruction->operands(), - std::back_inserter(dependencies)); - absl::c_copy(instruction->control_predecessors(), - std::back_inserter(dependencies)); + channel_dependency_group[all_reduce_id.value()].push_back( + instruction.get()); } break; } @@ -387,11 +414,11 @@ HloComputation::ComputeChannelDependencies() const { break; } } - return channel_dependency_map; + return channel_dependency_group; } std::vector HloComputation::MakeInstructionPostOrder() const { - auto channel_dependency_map = ComputeChannelDependencies(); + auto channel_dependency_group = ComputeChannelDependencies(); std::vector post_order; 
post_order.reserve(instruction_count()); std::vector trace_instructions; @@ -404,7 +431,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(channel_dependency_map, &post_order, + ComputeInstructionPostOrder(channel_dependency_group, &post_order, instruction.get(), &visited); } } @@ -695,21 +722,34 @@ bool HloComputation::operator==(const HloComputation& other) const { } absl::flat_hash_set> visited; - std::function eq = - [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { - // If are visited but not identical, the recursion should have - // been aborted. So, if are visited at this point, they must be - // identical. - if (visited.contains(std::make_pair(a, b))) { - return true; - } - visited.emplace(a, b); - return a->Identical( - *b, eq, [](const HloComputation* a, const HloComputation* b) { - return *a == *b; - }); - }; - return eq(root_instruction(), other.root_instruction()); + std::vector> worklist; + + worklist.push_back({root_instruction(), other.root_instruction()}); + + while (!worklist.empty()) { + auto pair = worklist.back(); + worklist.pop_back(); + + if (visited.contains(pair)) { + continue; + } + visited.emplace(pair); + // TODO(b/123082518): Avoid recursively invoking == because it may + // cause a stack overflow with deeply nested subcomputations. 
+ bool identical_ignoring_operands = pair.first->Identical( + *pair.second, + [](const HloInstruction*, const HloInstruction*) { return true; }, + [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }); + if (!identical_ignoring_operands) { + return false; + } + for (size_t i = 0; i < pair.first->operands().size(); ++i) { + worklist.push_back({pair.first->operand(i), pair.second->operand(i)}); + } + } + return true; } Status HloComputation::ReplaceWithNewInstruction( @@ -844,15 +884,15 @@ Status HloComputation::Accept( std::unique_ptr HloComputation::Clone( const string& suffix, HloCloneContext* context) { return CloneWithReplacements( - /*replacements=*/std::unordered_map>(), + /*replacements=*/absl::flat_hash_map>(), /*extra_parameters=*/{}, context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r1, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, @@ -863,7 +903,7 @@ std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r1, std::pair> r2, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); @@ -876,7 +916,7 @@ std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r2, std::pair> r3, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); @@ -886,7 +926,7 @@ std::unique_ptr HloComputation::CloneWithReplacementPairs( } std::unique_ptr HloComputation::CloneWithReplacements( - std::unordered_map> + absl::flat_hash_map> replacements, absl::Span extra_parameters, HloCloneContext* context, const string& suffix) { diff --git 
a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index e6a1eb89cfdb474f79c184ea0eb77dba8ccd5f03..a48cfa1f1b22ffd748fe9fe3ddb7f36d8d0dee4d 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -118,11 +117,20 @@ class HloComputation { // instruction. Status RemoveUnusedParameters(); - // Add new parameter instruction to the computation. + // Adds a new parameter instruction to a fusion computation. + // // This should be a new parameter. Instruction will be appended to parameters // and inserted to the instruction list. HloInstruction* AddParameter(std::unique_ptr instruction); + // Adds a new parameter instruction to the entry computation and update + // the parent module config to reflect the change. + // + // This should be a new parameter. Instruction will be appended to parameters + // and inserted to the instruction list. + HloInstruction* AddEntryComputationParameter( + std::unique_ptr instruction); + // Remove an instruction from the computation. The instruction must have no // users. Instruction is deallocated with this call. Status RemoveInstruction(HloInstruction* instruction); @@ -329,7 +337,8 @@ class HloComputation { // All relevant instructions are cloned, *including* unique_ptr in the // `replacements` map. std::unique_ptr CloneWithReplacements( - std::unordered_map> + absl::flat_hash_map> replacements, absl::Span extra_parameters = {}, HloCloneContext* context = nullptr, const string& suffix = "clone"); @@ -369,13 +378,13 @@ class HloComputation { // channel complete). bool IsRemovable(const HloInstruction* instruction); - // Returns a map from channel-id to directed dependencies of the channel - // instructions. 
For send&recv pairs it means the send instruction and for - // all-reduce the union of the dependencies for all participating - // instructions. - using ChannelDependencyMap = + // Returns a map from channel-id to the group of instructions associated with + // the channel. These instructions will be considered as a single node for + // dependency purposes. Send and RecvDone are in the group, and AllReduces + // with the same channel id are in the group. + using ChannelDependencyGroup = absl::flat_hash_map>; - ChannelDependencyMap ComputeChannelDependencies() const; + ChannelDependencyGroup ComputeChannelDependencies() const; // Returns true if this computation has a side effect. A computation has a // side effect if it contains one or more instructions with a side effect. @@ -391,6 +400,10 @@ class HloComputation { fusion_instruction_ = fusion_instruction; } + // Clear the unique ID of the computation so that it can be re-assigned, such + // as for the purpose of compacting the unique IDs. + void ClearUniqueIdInternal() { unique_id_ = -1; } + // The id of this computation should be unique within the module. void SetUniqueId(int64 id) { CHECK_EQ(unique_id_, -1); @@ -434,7 +447,7 @@ class HloComputation { enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( - const HloComputation::ChannelDependencyMap& channel_dependency_map, + const HloComputation::ChannelDependencyGroup& channel_dependency_map, std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 251c7bbec418d8c3e8b27277160e608840726996..3fa6f80b1b9d604bcf299e8e6694f852cca8e765 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -17,14 +17,16 @@ limitations under the License. 
#include #include -#include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -37,6 +39,7 @@ namespace xla { namespace { namespace m = match; +namespace op = xla::testing::opcode_matchers; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; @@ -506,12 +509,13 @@ TEST_F(HloComputationTest, CloneWithReplacements) { HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs")); auto param2 = builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1")); - auto lt = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1)); + auto lt = builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + param1, ComparisonDirection::kLt)); auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(/*root_instruction=*/lt)); - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(param2, HloInstruction::CreateParameter(2, r0s32, "p.1")); @@ -645,5 +649,57 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } +std::unique_ptr MakeAddNComputation(int n) { + auto builder = HloComputation::Builder("add_n"); + auto result = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "x_value")); + auto one = builder.AddInstruction( + 
HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + for (int i = 0; i < n; ++i) { + result = builder.AddInstruction(HloInstruction::CreateBinary( + one->shape(), HloOpcode::kAdd, result, one)); + } + return builder.Build(); +} + +TEST_F(HloComputationTest, DeepEquality) { + auto computation_a = MakeAddNComputation(200000); + auto computation_b = MakeAddNComputation(200000); + EXPECT_TRUE(*computation_a == *computation_b); + + auto computation_c = MakeAddNComputation(199999); + EXPECT_FALSE(*computation_a == *computation_c); + EXPECT_FALSE(*computation_c == *computation_b); +} + +// Tests that cross-module AllReduce instructions are ordered after all their +// predecessors and before all their successors. +TEST_F(HloComputationTest, InstructionPostOrderWithAllReduce) { + const char* const hlo_string = R"( +HloModule Module + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry { + param = f32[128] parameter(0), sharding={maximal device=0} + crs0 = f32[128] all-reduce(param), + replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add, + sharding={maximal device=0} + crs1 = f32[128] all-reduce(param), + replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add, + sharding={maximal device=1} + add = f32[128] add(crs0, crs0), sharding={maximal device=0} + ROOT t = (f32[128], f32[128]) tuple(add, crs1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + EXPECT_THAT(module->entry_computation()->MakeInstructionPostOrder(), + ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(), + op::Add(), op::Tuple())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index c4b4fa62ddcb46b8ac46567da5ab32a6a1f4914c..13b1c82709523fc98b02551d14bc9a9cdacc5fc1 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ 
b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -91,9 +91,10 @@ Status HloCostAnalysis::HandleElementwiseOp( auto opcode = hlo_instruction->opcode(); // We treat transcendental operations separately since one transcendental // operation can correspond to several floating point ops. - if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower || - opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin || - opcode == HloOpcode::kCos) { + if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog || + opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt || + opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh || + opcode == HloOpcode::kSin || opcode == HloOpcode::kCos) { current_properties_[kTranscendentalsKey] = computation_count; } else { // Note: transcendental operations are considered a separate category from @@ -237,24 +238,17 @@ Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); + const Shape& dot_shape = dot->shape(); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = - lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); - // First divide by reduction width before multiplying by rhs elements to avoid - // overflow. - int64 fma_count; - if (reduction_width == 0) { - fma_count = 0; - } else { - fma_count = (ShapeUtil::ElementsIn(lhs_shape) / reduction_width) * - ShapeUtil::ElementsIn(rhs_shape); + int64 reduction_width = 1; + for (auto dim : dnums.lhs_contracting_dimensions()) { + reduction_width *= lhs_shape.dimensions(dim); } - - // We count an FMA operation as 2 floating point operations. - current_properties_[kFlopsKey] = kFmaFlops * fma_count; + // Each output element requires reduction_width FMA operations. 
+ current_properties_[kFlopsKey] = + kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width; return Status::OK(); } @@ -531,7 +525,8 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { } const int64 fma_count = (input_feature / convolution->feature_group_count()) * - output_feature * batch * + output_feature * + (batch / convolution->batch_group_count()) * Product(valid_position_counts); current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); @@ -552,6 +547,32 @@ Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { return Status::OK(); } +Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) { + float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f; + bytes_accessed += GetShapeSize(hlo->operand(1)->shape()); + current_properties_[kBytesAccessedKey] = bytes_accessed; + + const Shape& a_shape = hlo->operand(0)->shape(); + const Shape& b_shape = hlo->operand(1)->shape(); + // Estimate as batch * mn^2 / 2 flops. + int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1); + elems *= ShapeUtil::ElementsIn(b_shape); + current_properties_[kFlopsKey] = kFmaFlops * elems; + return Status::OK(); +} + +Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) { + float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f; + current_properties_[kBytesAccessedKey] = bytes_accessed; + + const Shape& a_shape = hlo->operand(0)->shape(); + // Estimate as batch * n^3 / 3 flops. + int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1); + elems *= ShapeUtil::ElementsIn(a_shape); + current_properties_[kFlopsKey] = elems / 3; + return Status::OK(); +} + Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. 
@@ -577,6 +598,10 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } +Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandleRng(const HloInstruction* random) { // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume @@ -659,19 +684,22 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { } Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { - // Compute the cost of the true and false computations and take the maximum - // from those for each property. + // Compute the cost of the branch computations and take the maximum from those + // for each property. TF_ASSIGN_OR_RETURN( - const Properties true_computation_properties, - ProcessUnnestedSubcomputation(conditional->true_computation())); - TF_ASSIGN_OR_RETURN( - const Properties false_computation_properties, - ProcessUnnestedSubcomputation(conditional->false_computation())); - current_properties_ = true_computation_properties; - for (const auto& property : false_computation_properties) { - if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_, property)) { - current_properties_[property.first] = - std::max(current_properties_[property.first], property.second); + const Properties branch0_computation_properties, + ProcessUnnestedSubcomputation(conditional->branch_computation(0))); + current_properties_ = branch0_computation_properties; + for (int j = 1; j < conditional->branch_count(); ++j) { + TF_ASSIGN_OR_RETURN( + const Properties branch_computation_properties, + ProcessUnnestedSubcomputation(conditional->branch_computation(j))); + for (const auto& property : branch_computation_properties) { + if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_, + property)) { + auto& current_property = current_properties_[property.first]; 
current_property = std::max(current_property, property.second); + } } } current_should_compute_bottleneck_time_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index b52305626dd67336eb31098d086ad357f12d96c7..4480554de507f20b5d44b87a19e58236252bad1d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,9 +71,12 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; + Status HandleTriangularSolve(const HloInstruction* hlo) override; + Status HandleCholesky(const HloInstruction* hlo) override; Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; + Status HandleReplicaId(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleRng(const HloInstruction* random) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 82dd57d3c656ef7a61f6ab9e110f44ef551fac30..4d42770ba784ba15fae9518b40a75d8a2f038e66 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/compiler/xla/statusor.h" @@ -157,6 +158,87 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30)); } +TEST_F(HloCostAnalysisTest, DotGeneral) { + XlaBuilder builder("matrix_multiply"); + auto lhs = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs"); + auto rhs = + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs"); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(0); + dnums.add_rhs_contracting_dimensions(1); + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (1500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 10 * 30)); +} + +TEST_F(HloCostAnalysisTest, DotGeneral2) { + XlaBuilder builder("matrix_multiply"); + auto lhs = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs"); + auto rhs = + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs"); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(2); + dnums.add_rhs_contracting_dimensions(0); + dnums.add_rhs_batch_dimensions(1); + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. 
+ auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (1500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 5 * 10 * 30)); +} + +TEST_F(HloCostAnalysisTest, DotGeneral3) { + XlaBuilder builder("matrix_multiply"); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); + DotDimensionNumbers dnums; + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (1500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 + 5 * 30 + 5 * 5 * 10 * 30)); +} + TEST_F(HloCostAnalysisTest, Map) { XlaBuilder builder("map"); auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in"); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index bb5d21c654c73da257d53e4f8486b2e83019b534..d9c5f7c66de03a50f6566092ab274a6bb99d4229 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -17,10 +17,15 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -37,6 +42,18 @@ StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs)); } +StatusOr MakeCompareHlo(ComparisonDirection direction, + HloInstruction* lhs, + HloInstruction* rhs) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN( + Shape binary_op_shape, + ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs)); + return computation->AddInstruction( + HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction)); +} + StatusOr MakePadHlo(HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { @@ -156,9 +173,9 @@ StatusOr MakeDynamicUpdateSliceHlo( dynamic_update_slice_shape, operand, update, scalar_start_indices)); } -StatusOr MakeBroadcastHlo( - HloInstruction* operand, absl::Span broadcast_dimensions, - absl::Span result_shape_bounds) { +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + absl::Span result_shape_bounds) { HloComputation* computation = operand->parent(); Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_bounds); @@ -268,6 +285,29 @@ StatusOr 
MakeSelectHlo(HloInstruction* pred, select_shape, HloOpcode::kSelect, pred, on_true, on_false)); } +StatusOr MakeSortHlo( + const Shape& sort_shape, absl::Span operands, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, + HloModule* module) { + CHECK(!operands.empty()) << "Sort Hlo requires at least one operand."; + HloComputation* compare_computation; + XlaBuilder b("Sort.Compare"); + std::vector operand_types(operands.size()); + for (int64 i = 0; i < operands.size(); ++i) { + operand_types[i] = operands[i]->shape().element_type(); + } + XlaComputation comparator = CreateScalarLtComputation(operand_types, &b); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(comparator.proto(), config)); + HloCloneContext context(module); + compare_computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + return builder->AddInstruction(HloInstruction::CreateSort( + sort_shape, dimension_to_sort, operands, compare_computation, is_stable)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -394,9 +434,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, return MakePadHlo(operand, zero, padding_config); } -StatusOr BroadcastZeros( - HloComputation* computation, PrimitiveType element_type, - absl::Span broadcast_dimensions) { +HloInstruction* BroadcastZeros(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 8e5ddbbd503a501bd493aec43a2ccd4db883ef0c..f163112f7ff54bd525f24c235bd8b0d195c33a5f 100644 --- 
a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -32,6 +32,12 @@ namespace xla { StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs); +// Creates a compare HLO instruction and adds it to the computation containing +// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +StatusOr MakeCompareHlo(ComparisonDirection direction, + HloInstruction* lhs, + HloInstruction* rhs); + // Creates a pad HLO instruction and adds it to the computation containing // `operand` and `padding_value` (`operand` and `padding_value` must be in the // same computation). @@ -82,9 +88,9 @@ StatusOr MakeDynamicUpdateSliceHlo( // Creates a broadcast HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeBroadcastHlo( - HloInstruction* operand, absl::Span broadcast_dimensions, - absl::Span result_shape_bounds); +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + absl::Span result_shape_bounds); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. @@ -123,6 +129,15 @@ StatusOr MakeSelectHlo(HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false); +// Creates a Sort HLO instruction and adds it to the computation containing the +// operands. All operands must be in the same computation. Also creates a +// default compare sub-computation which sorts the first operand into ascending +// order. 'is_stable' specifies whether the sorting should be stable. +StatusOr MakeSortHlo( + const Shape& sort_shape, absl::Span operands, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, + HloModule* module); + // Creates an R1 Constant HLO instruction of the given PrimitiveType with the // given values and adds it to the given computation. 
template @@ -198,9 +213,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, // Broadcasts a zero value of type `element_type` into a tensor with element // type `element_type` and dimension bounds `broadcast_dimensions`. The // broadcast instruction is emitted into `computation`. -StatusOr BroadcastZeros( - HloComputation* computation, PrimitiveType element_type, - absl::Span broadcast_dimensions); +HloInstruction* BroadcastZeros(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions); // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 3715e12b4e2baf7bc2149237457c16c3919c5083..6025e6a77941369f75ebaa98bdf0979669b3a03c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -191,9 +191,8 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { /*output_shape_dims=*/{2, 2}, &param, &entry_computation); - TF_ASSERT_OK_AND_ASSIGN( - HloInstruction * zeros, - BroadcastZeros(module->entry_computation(), S32, {2, 2})); + HloInstruction* zeros = + BroadcastZeros(module->entry_computation(), S32, {2, 2}); entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; @@ -211,9 +210,8 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { /*output_shape_dims=*/{2, 2}, &param, &entry_computation); - TF_ASSERT_OK_AND_ASSIGN( - HloInstruction * zeros, - BroadcastZeros(module->entry_computation(), F32, {2, 2})); + HloInstruction* zeros = + BroadcastZeros(module->entry_computation(), F32, {2, 2}); entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index e602107cbe64320a8e8e740168cb294ec6be9667..849cac278ee379122ba1ff9fade3bf003969b8a7 100644 ---
a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 3144a84805454488f417391f40ed6b9e9facc752..06a832743864896c8feca3c807dfc4faa6445a03 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -24,8 +24,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" @@ -35,48 +37,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { -namespace { - -// We have this pattern in dynamaic update slice fusion, which should be -// supported: -// -// Parameters: p0, p1 -// Fusion -// ds = DynamicSlice(p0, p1) -// ROOT DynamicUpdateslice(p0, ds, p1) -// -// In this case, we should be able to reuse p0 and output, although p0 has -// multiple uses. 
-bool MultiDynamicSliceUseShareSameIndices(absl::Span uses) { - if (uses.empty()) { - return false; - } - const HloInstruction* indices = nullptr; - for (HloUse use : uses) { - auto user = use.instruction; - if (user->opcode() == HloOpcode::kDynamicUpdateSlice) { - if (indices == nullptr) { - indices = user->operand(2); - } else if (indices != user->operand(2)) { - return false; - } - if (use.operand_number != 0) { - return false; - } - } else if (user->opcode() == HloOpcode::kDynamicSlice) { - if (indices == nullptr) { - indices = user->operand(1); - } else if (indices != user->operand(1)) { - return false; - } - } else { - return false; - } - } - return true; -} - -} // namespace using absl::StrAppend; using absl::StrCat; @@ -414,11 +374,11 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { bool HloDataflowAnalysis::UpdateConditionalValueSet( HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); - const InstructionValueSet* const inputs[] = { - &GetInstructionValueSet( - conditional->true_computation()->root_instruction()), - &GetInstructionValueSet( - conditional->false_computation()->root_instruction())}; + std::vector inputs(conditional->branch_count()); + for (int j = 0; j < conditional->branch_count(); ++j) { + inputs[j] = &GetInstructionValueSet( + conditional->branch_computation(j)->root_instruction()); + } if (ssa_form_) { return Phi(conditional, inputs); } else { @@ -546,20 +506,23 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) { CHECK_EQ(parameter->parameter_number(), 0); auto conditional = callsite.instruction(); - // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is - // the argument to the true computation and operand 2 is the argument to - // the false computation. + // Conditional has branch_count+1 operands. 
Operand 0 is the branch_index, + // operands 1 and onward are the arguments to the branch computations. // - // If the parameter belongs to conditional's true computation, then + // If the parameter belongs to conditional's branch 0 computation, then // operand 1 is forwarded to this parameter instruction. If the parameter - // belongs to conditional's false computation, then operand 2 is forwarded - // to this parameter instruction. - if (parameter->parent() == conditional->true_computation()) { - inputs.push_back(&GetInstructionValueSet(conditional->operand(1))); - } else { - CHECK_EQ(parameter->parent(), conditional->false_computation()); - inputs.push_back(&GetInstructionValueSet(conditional->operand(2))); + // belongs to conditional's branch 5 computation, then operand 6 is + // forwarded to this parameter instruction. + bool found_parent = false; + for (int j = 0; j < conditional->branch_count(); ++j) { + if (parameter->parent() == conditional->branch_computation(j)) { + inputs.push_back( + &GetInstructionValueSet(conditional->operand(j + 1))); + found_parent = true; + break; + } } + CHECK(found_parent); need_phi = true; } else { LOG(FATAL) << "CallContext::kSequential computations should only be " @@ -710,19 +673,17 @@ void HloDataflowAnalysis::Propagate() { // parameter(s) of the computation need to be updated. if (user->opcode() == HloOpcode::kConditional) { // If operand 0 is the use of instruction, then no parameters need to be - // updated, since that is the predicate of the conditional. - // If operand 1 is the use of instruction, then the true_computation's - // parameter need to be updated. - // If operand 2 is the use of instruction, then the false_computation's - // parameter need to be updated. + // updated, since that is the branch_index of the conditional. + // If operand n+1 is the use of instruction, then the branch_computation + // n's parameter need to be updated. 
// - // Note that the same instruction can be used in both operand 1 and - // operand 2. - if (user->operand(1) == instruction) { - add_to_worklist(user->true_computation()->parameter_instruction(0)); - } - if (user->operand(2) == instruction) { - add_to_worklist(user->false_computation()->parameter_instruction(0)); + // Note that the same instruction can be used in multiple branches' + // operands. + for (int j = 0; j < user->branch_count(); ++j) { + if (user->operand(j + 1) == instruction) { + add_to_worklist( + user->branch_computation(j)->parameter_instruction(0)); + } } } else { for (HloComputation* called_computation : user->called_computations()) { @@ -744,8 +705,8 @@ void HloDataflowAnalysis::Propagate() { const CallGraphNode& call_graph_node = call_graph_->GetNode(instruction->parent()); for (const CallSite& callsite : call_graph_node.caller_callsites()) { - if ((callsite.instruction()->opcode() == HloOpcode::kCall) || - (callsite.instruction()->opcode() == HloOpcode::kConditional)) { + if (callsite.instruction()->opcode() == HloOpcode::kCall || + callsite.instruction()->opcode() == HloOpcode::kConditional) { add_to_worklist(callsite.instruction()); } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { // Add the while itself, and the body and condition parameters. @@ -983,6 +944,79 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer( return true; } +// Given a fusion whose root is a dynamic-update-slice op, determines whether +// the fusion's output buffer can be shared with the buffer of fusion_param, +// which must be a fused parameter of the fusion. +// +// Preconditions: +// +// - fusion's root is a dynamic-update-slice op. +// - fusion_param is a parameter within the fusion. +// +// fusion_param may point to a subelement of the actual parameter instruction if +// the param is a tuple; i.e. fusion_param->index() need not be the empty list. 
+// +// Returns true if: +// +// * fusion is a loop or input fusion, AND +// * fusion_param is used by the root of dynamic-update-slice as the "base" of +// the update, i.e. the thing being updated, AND +// * all other uses of fusion_param are dynamic-slices that slice the same +// indices as are overwritten in the dynamic-update-slice. +// +// In the case that there are no other uses of fusion_param (last bullet point +// is vacuously true) it's easy to see why an in-place DUS is safe; this is just +// the "natural" implementation of DUS. If there are other users, in-place DUS +// is safe on the assumption that the thread which writes element i of the +// output will be the only one to read element i of fusion_param (via the +// dynamic-slice ops). +static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion, + const HloValue& fusion_param_value) { + auto* root = + Cast(fusion->fused_expression_root()); + auto* fusion_param = fusion_param_value.instruction(); + CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter); + CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation()); + + // fusion must be a loop or input fusion. + auto kind = fusion->fusion_kind(); + if (kind != HloInstruction::FusionKind::kLoop && + kind != HloInstruction::FusionKind::kInput) { + return false; + } + + // fusion_param must be used by the root as the "base" of the + // dynamic-update-slice. The natural way to check this would be + // + // `if (root->operand(0) != fusion_param)` + // + // but we also have to handle the case where the fusion parameter is + // tuple-shaped and we're considering just one element of that tuple, i.e. + // fusion_param.index() != {}. + if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) { + return use.instruction == root; + }) != 1) { + return false; + } + + // All other uses of fusion_param must be dynamic-slices that slice the same + // indices as are overwritten by the dynamic-update-slice. 
+ for (const HloUse& use : fusion_param_value.uses()) { + auto* user = use.instruction; + if (user == root) { + continue; + } + + // Check that `user` is a dynamic-slice op and has the same slice indices as + // `root`. + auto* ds = DynCast(user); + if (!ds || ds->index_operands() != root->index_operands()) { + return false; + } + } + return true; +} + bool HloDataflowAnalysis::CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index) const { @@ -999,35 +1033,34 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (user->opcode() == HloOpcode::kFusion) { - if (fusion_can_share_buffer_ != nullptr) { - return fusion_can_share_buffer_(user, operand); - } // Get the parameter associated with 'operand'; HloInstruction* fusion_param = user->fused_parameter(user->operand_index(operand)); - const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); - if (MultiDynamicSliceUseShareSameIndices(value.uses())) { - return true; + const HloValue& fusion_param_value = + GetValueDefinedAt(fusion_param, operand_index); + + // TODO(b/80315712): This code is in a bit of a weird intermediate state + // at the moment. The in-place DUS check really needs to be common to all + // backends, so it runs first. Then we run the backend-specific check if + // provided, or go through the target-indepdendent check if not. + // Unfortunately, the notionally "target-independent" path actually contains + // some target-specific code, so we can't run all of it *in addition* to the + // target-specific function, like the interface documentation says. 
+ if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value); } + + if (fusion_can_share_buffer_ != nullptr) { + return fusion_can_share_buffer_(user, operand); + } + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || user->fusion_kind() == HloInstruction::FusionKind::kInput) { - if (user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - if (value.uses().size() == 1) { - const HloUse& use = value.uses()[0]; - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; - } - return false; - } return AreTransitiveUsesElementwiseOrTuple(fusion_param); } + if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -1049,8 +1082,8 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Returns true iff there is exactly one use of 'operand' at shape index // 'operand_index', and this singleton use is the fused root (at operand // index 'other_add_operand_index'). 
- if (value.uses().size() == 1) { - const HloUse& use = value.uses()[0]; + if (fusion_param_value.uses().size() == 1) { + const HloUse& use = fusion_param_value.uses()[0]; return use.instruction == user->fused_expression_root() && use.operand_number == other_add_operand_index; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4a7c4963b7b399e625da907b3810c42df7ee2bd3..f0b18d6fc9554e53fd2f4712fe9787f660d25b8a 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -49,7 +50,6 @@ class HloDataflowAnalysisTest : public HloTestBase, // reference to the generated analysis stored in analysis_. 
const HloDataflowAnalysis& RunAnalysis(bool ssa_form, bool bitcast_defines_value = false) { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis"); analysis_ = HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); @@ -108,6 +108,7 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -156,6 +157,7 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -211,6 +213,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { auto gte_out = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, gte_tuple, 0)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -266,6 +269,7 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -319,6 +323,7 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { auto sub = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kSubtract, call1, call2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool 
ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -371,6 +376,7 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -433,6 +439,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -508,6 +515,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -613,6 +621,7 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { auto xla_while2 = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -700,6 +709,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { auto entry_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -795,6 +805,7 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { auto xla_while = 
builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -850,6 +861,7 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -892,6 +904,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -963,6 +976,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1052,6 +1066,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1094,6 +1109,7 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { scalar_shape_, HloOpcode::kBitcast, constant)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); { @@ -1130,6 +1146,7 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { auto copy = builder.AddInstruction( HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple)); module_->AddEntryComputation(builder.Build()); + 
SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1163,6 +1180,7 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { HloInstruction::CreateSend(param, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1188,6 +1206,7 @@ TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1223,6 +1242,7 @@ TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); DependencyHloOrdering ordering(module_.get()); @@ -1260,6 +1280,7 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { vector_shape_, HloOpcode::kAdd, negate, exp)); auto entry = module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); HloSchedule schedule(module_.get()); @@ -1338,6 +1359,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { HloInstruction::CreateWhile(scalar_shape_, condition, body, param)); auto entry = module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); RunAnalysis(ssa_form); @@ -1408,6 +1430,7 @@ TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) { HloInstruction::CreateReverse(vector_shape_, negate, {0})); module_->AddEntryComputation(builder.Build()); + 
SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); DependencyHloOrdering ordering(module_.get()); @@ -1439,6 +1462,7 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValues) { vector_shape_, HloOpcode::kAdd, negate, exp)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); DependencyHloOrdering ordering(module_.get()); @@ -1478,6 +1502,7 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { vector_shape_, HloOpcode::kAdd, negate, exp)); auto entry = module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); HloSchedule schedule(module_.get()); @@ -1536,6 +1561,7 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { builder.AddInstruction(HloInstruction::CreateBinary( vector_shape_, HloOpcode::kAdd, negate, call)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); DependencyHloOrdering ordering(module_.get()); @@ -1588,6 +1614,7 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { scalar_shape_, pred, constant1, true_computation, constant2, false_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); @@ -1681,6 +1708,7 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { scalar_shape_, pred, tuple_operand, true_computation, tuple_operand, false_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); @@ -1815,6 +1843,7 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { scalar_shape_, pred1, tuple_operand, inner_conditional_computation, constant3, computation3)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); @@ -2169,6 +2198,66 @@ 
TEST_F(CanShareOperandBufferWithUserTest, dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithDifferentIndices) { + const char* kModule = R"( + HloModule test + + fused_computation { + p0 = f32[10,20,30] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30} + ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p3, p2) + } + + ENTRY test { + p0 = f32[10,20,30] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule)); + auto* fusion = module_->entry_computation()->root_instruction(); + auto* param = module_->entry_computation()->parameter_instruction(0); + + RunAnalysis(); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) { + const char* kModule = R"( + HloModule test + + fused_computation { + p0 = f32[10,20,30] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30} + ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p2, p3) + } + + ENTRY test { + p0 = f32[10,20,30] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule)); + auto* fusion = module_->entry_computation()->root_instruction(); + auto* param = module_->entry_computation()->parameter_instruction(0); + + RunAnalysis(); + EXPECT_TRUE( + 
dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { auto builder = HloComputation::Builder(TestName()); @@ -2178,8 +2267,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { HloInstruction::CreateParameter(0, in_shape, "param0")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + auto result = builder.AddInstruction(HloInstruction::CreateCompare( + out_shape, param0, param1, ComparisonDirection::kEq)); BuildModuleAndRunAnalysis(builder.Build()); @@ -2356,14 +2445,17 @@ TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto sort = - builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false, + &builder, module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); EXPECT_TRUE( dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); @@ -2371,6 +2463,7 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); Shape values_shape = ShapeUtil::MakeShape(F32, {8}); @@ -2378,11 +2471,14 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { 
HloInstruction::CreateParameter(0, keys_shape, "keys")); auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); - auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, - {values})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); // The buffer for the keys can be shared with the first tuple entry. EXPECT_TRUE( @@ -2495,8 +2591,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq)); return builder.Build(); }; diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index b5d72b386f89568cc3066b2e497be98428d1ed0c..d0073237ac2aab565175896e42c4503e26a9966b 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -223,8 +223,9 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { HloInstruction::CreateParameter(0, shape, "cond_param")); auto constant = cond_builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant)); + cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param, + 
constant, ComparisonDirection::kLt)); } auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index 19b5734825df833fd34d634e4c1630dd75e96c4c..3746fbbda02b09d0660c209d09698359375b9e02 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -27,8 +27,6 @@ namespace { StatusOr RunInternal(HloModule* module, HloDomainIsolator::DomainCreator* creator) { - hlo_graph_dumper::MaybeDumpHloModule(*module, "Before Domain Isolator"); - int64 added_domains = 0; for (HloComputation* computation : module->computations()) { // Walk in post order and place all the required kDomain instructions. @@ -56,9 +54,6 @@ StatusOr RunInternal(HloModule* module, } } VLOG(3) << "Added " << added_domains << " kDomain instructions"; - if (added_domains > 0) { - hlo_graph_dumper::MaybeDumpHloModule(*module, "After Domain Isolator"); - } return added_domains > 0; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc index 67fad0769f5eb5ceca64ebd2aa78c6469f2c813d..4975c3fbb93da266bbb542793953d4e365325d93 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -59,8 +59,6 @@ Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( StatusOr HloDomainRemover::RunContext::Run() { VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'"; - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Remover"); - int64 removed_domains = 0; for (HloComputation* computation : module_->computations()) { // First create the domain instruciton sets. 
A domain instruction set is @@ -97,9 +95,6 @@ StatusOr HloDomainRemover::RunContext::Run() { } VLOG(3) << "Removed " << removed_domains << " kDomain instructions of '" << remover_->kind_ << "' kind"; - if (removed_domains > 0) { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Remover"); - } return removed_domains > 0; } diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 9b0f2b2a0f4dd5d1d1191e9ab0637cc3034b50da..7d6b86056af3fc2128fe1642bbfa0ca6f9ef1da0 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -127,6 +127,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops where it does not make sense to convert them. if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kBitcastConvert || opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) { continue; @@ -145,7 +146,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || - opcode == HloOpcode::kConditional) { + opcode == HloOpcode::kSort || opcode == HloOpcode::kConditional) { continue; } TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index a3b56a44a0b02923585c1dcb69571479236188a3..4171f738620dbf545e5883b8c26169fae4b93643 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -28,15 +28,7 @@ 
using ::testing::Eq; using ::testing::Not; using ::testing::ResultOf; -class HloElementTypeConverterTest : public HloTestBase { - public: - std::unique_ptr CreateModuleFromHloString( - const string& hlo_string) { - return HloRunner::CreateModuleFromString(hlo_string, - GetDebugOptionsForTest()) - .ValueOrDie(); - } -}; +using HloElementTypeConverterTest = HloTestBase; TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { const string& hlo_string = R"( @@ -47,7 +39,7 @@ TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { custom_call_target="foo" } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_FALSE(converted); @@ -63,7 +55,7 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { outfeed = token[] outfeed(infeed.data, token0) } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_FALSE(converted); @@ -73,17 +65,16 @@ TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) { const string& hlo_string = R"( HloModule NestedTuples ENTRY NestedTuples.v5 { - constant.4 = bf16[] constant(42) constant.2 = f32[2]{0} constant({1, 2}) - constant.3 = bf16[] constant(42) - add = bf16[] add(constant.2, constant.3) - tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add) + constant.3 = bf16[2]{0} constant({42, 42}) + add = bf16[2]{0} add(constant.2, constant.3) + tuple = (f32[2]{0}, bf16[2]{0}) tuple(constant.2, add) constant.5 = bf16[2]{0} constant({22, 44}) - ROOT tuple.1 = ((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5) + ROOT tuple.1 = ((f32[2]{0}, bf16[2]{0}), bf16[2]{0}) tuple(tuple, constant.5) } )"; 
- auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -111,7 +102,7 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -135,7 +126,7 @@ ENTRY main { ROOT rng = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -161,7 +152,7 @@ ENTRY main { ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); @@ -185,5 +176,19 @@ ENTRY main { EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0)); } +TEST_F(HloElementTypeConverterTest, BitcastConvertIsUnmodified) { + const string& hlo_string = R"( + HloModule test + + ENTRY test { + p = bf16[] parameter(0) + ROOT c = u16[] bitcast-convert(p) + })"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + HloElementTypeConverter converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, RunHloPass(&converter, module.get())); + EXPECT_FALSE(converted); +} + } // 
namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 7589c992025d84355a7b19adeecfa93822196d5e..71c745149f1774ad83f4d1a41b67105255dddde9 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include #include +#include #include #include -#include #include #include "absl/algorithm/container.h" @@ -29,11 +29,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -56,43 +56,40 @@ namespace xla { namespace { template -StatusOr Compare(const Shape& shape, HloOpcode opcode, +StatusOr Compare(const Shape& shape, ComparisonDirection direction, LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: + switch (direction) { + case ComparisonDirection::kEq: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el == rhs_el; }; break; - case HloOpcode::kNe: + case ComparisonDirection::kNe: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el != rhs_el; }; break; - case HloOpcode::kGe: + case ComparisonDirection::kGe: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el >= rhs_el; }; break; - case HloOpcode::kGt: + case 
ComparisonDirection::kGt: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el > rhs_el; }; break; - case HloOpcode::kLe: + case ComparisonDirection::kLe: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el <= rhs_el; }; break; - case HloOpcode::kLt: + case ComparisonDirection::kLt: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el < rhs_el; }; break; - default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); } Literal result(shape); @@ -106,24 +103,25 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr Compare(const Shape& shape, HloOpcode opcode, +StatusOr Compare(const Shape& shape, + ComparisonDirection direction, LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: + switch (direction) { + case ComparisonDirection::kEq: compare_op = [](complex64 lhs_el, complex64 rhs_el) { return lhs_el == rhs_el; }; break; - case HloOpcode::kNe: + case ComparisonDirection::kNe: compare_op = [](complex64 lhs_el, complex64 rhs_el) { return lhs_el != rhs_el; }; break; default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); + LOG(FATAL) << "unhandled direction for conversion to Comparison: " + << ComparisonDirectionToString(direction); } Literal result(shape); @@ -137,24 +135,25 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr Compare(const Shape& shape, HloOpcode opcode, +StatusOr Compare(const Shape& shape, + ComparisonDirection direction, LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: + switch (direction) { + case ComparisonDirection::kEq: compare_op = [](complex128 lhs_el, complex128 rhs_el) { return lhs_el == rhs_el; }; break; - case HloOpcode::kNe: + case ComparisonDirection::kNe: compare_op = [](complex128 lhs_el, 
complex128 rhs_el) { return lhs_el != rhs_el; }; break; default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); + LOG(FATAL) << "unhandled direction for conversion to Comparison: " + << ComparisonDirectionToString(direction); } Literal result(shape); @@ -491,15 +490,52 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { auto operand = is_finite->operand(0); - if (!ShapeUtil::ElementIsFloating(operand->shape())) { - return InvalidArgument( - "expected element type in shape to be float for IsFinite op, got: %s", - PrimitiveType_Name(operand->shape().element_type())); - } + auto elem_ty = operand->shape().element_type(); + switch (elem_ty) { + case PRED: + case TUPLE: + case OPAQUE: + case TOKEN: + case S8: + case S16: + case S32: + case S64: + case U8: + case U16: + case U32: + case U64: + case C64: + case C128: + // Explicitly enumerate all types in this switch so that when we add a new + // type, we'll get a compile error here. 
+ case PRIMITIVE_TYPE_INVALID: + case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_: + case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_: + return InvalidArgument( + "expected element type in shape to be floating point, but " + "got: %s", + PrimitiveType_Name(elem_ty)); - switch (operand->shape().element_type()) { - case F16: - return Unimplemented("unhandled primitive type: F16."); + case F16: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](Eigen::half elem_operand) { + return std::isfinite(static_cast(elem_operand)); + }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + case BF16: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](bfloat16 elem_operand) { + return std::isfinite(static_cast(elem_operand)); + }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } case F32: { auto result_or = ElementWiseUnaryOpImpl( is_finite, @@ -516,9 +552,6 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); break; } - default: - LOG(FATAL) << "HandleIsFinite: unknown/unhandled primitive type: " - << PrimitiveType_Name(operand->shape().element_type()); } return Status::OK(); @@ -542,7 +575,7 @@ Status HloEvaluator::HandleReal(HloInstruction* real) { break; } case C128: { - auto result_or = ElementWiseUnaryOpImpl( + auto result_or = ElementWiseUnaryOpImpl( real, [](complex128 elem_operand) { return std::real(elem_operand); }, GetEvaluatedLiteralFor(operand)); TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); @@ -637,20 +670,11 @@ Status HloEvaluator::HandleComplex(HloInstruction* complex) { } Status HloEvaluator::HandleCompare(HloInstruction* compare) { - HloOpcode opcode = compare->opcode(); + ComparisonDirection direction = compare->comparison_direction(); auto lhs = compare->operand(0); auto rhs = compare->operand(1); - // 
TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s", - ShapeUtil::HumanString(compare->shape()), - ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape())); - } + DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); @@ -662,76 +686,76 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { case PRED: { TF_ASSIGN_OR_RETURN( evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case U8: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case U16: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case U32: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case U64: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case S8: { TF_ASSIGN_OR_RETURN( evaluated_[compare], - 
Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case S16: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case S32: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case S64: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case F16: { TF_ASSIGN_OR_RETURN( evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case BF16: { TF_ASSIGN_OR_RETURN(evaluated_[compare], - Compare(compare->shape(), opcode, + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case F32: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case F64: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case C64: { TF_ASSIGN_OR_RETURN(evaluated_[compare], - Compare(compare->shape(), opcode, + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case C128: { TF_ASSIGN_OR_RETURN(evaluated_[compare], - Compare(compare->shape(), opcode, + Compare(compare->shape(), 
direction, lhs_literal, rhs_literal)); } break; default: @@ -1192,8 +1216,8 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { HloEvaluator embedded_evaluator; embedded_evaluator.set_dynamic_dimension_inference( dynamic_dimension_inference_); - Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN(Literal result, + embedded_evaluator.Evaluate(*computation, arg_literals)); evaluated_[call] = std::move(result); return Status::OK(); @@ -1227,37 +1251,35 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { HloEvaluator embedded_evaluator; embedded_evaluator.set_dynamic_dimension_inference( dynamic_dimension_inference_); - Literal result = - embedded_evaluator.Evaluate(*readded_computation, arg_literals) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate( + *readded_computation, arg_literals)); evaluated_[fusion] = std::move(result); return Status::OK(); } Status HloEvaluator::HandleConditional(HloInstruction* conditional) { - const auto& pred = GetEvaluatedLiteralFor(conditional->operand(0)); - const auto& true_computation_arg = - GetEvaluatedLiteralFor(conditional->operand(1)); - const auto& false_computation_arg = - GetEvaluatedLiteralFor(conditional->operand(2)); - - auto* true_computation = conditional->true_computation(); - auto* false_computation = conditional->false_computation(); + const auto& branch_index_literal = + GetEvaluatedLiteralFor(conditional->operand(0)); + int branch_index; + if (conditional->operand(0)->shape().element_type() == PRED) { + branch_index = branch_index_literal.Get({}) ? 
0 : 1; + } else { + branch_index = branch_index_literal.Get({}); + if (branch_index < 0 || branch_index >= conditional->branch_count()) { + branch_index = conditional->branch_count() - 1; + } + } + const auto& branch_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(1 + branch_index)); HloEvaluator embedded_evaluator; embedded_evaluator.set_dynamic_dimension_inference( dynamic_dimension_inference_); - Literal result; - if (pred.Get({})) { - result = - embedded_evaluator.Evaluate(*true_computation, {&true_computation_arg}) - .ConsumeValueOrDie(); - } else { - result = embedded_evaluator - .Evaluate(*false_computation, {&false_computation_arg}) - .ConsumeValueOrDie(); - } + TF_ASSIGN_OR_RETURN(Literal result, + embedded_evaluator.Evaluate( + *conditional->branch_computation(branch_index), + {&branch_computation_arg})); evaluated_[conditional] = std::move(result); return Status::OK(); @@ -1327,168 +1349,212 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return Status::OK(); } -// Key-value sort is a special snowflake: it's templated on two different -// element types, one for the keys, and one for the values. Jump through some -// hoops to make this work. namespace { -template -StatusOr EvaluateSortInternal(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - auto rank = keys_literal.shape().rank(); - TF_RET_CHECK( - ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) - << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(sort->operand_count() >= 2) << "Expected key-value sort"; - // We need to sort an array of keys and an array of values, where the - // sorted order of the values is determined by the keys. The simplest(?) - // way to do this is to go to an array-of-pairs representation, sort the - // array using the keys, and then go back to pair-of-arrays. 
- VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); - - if (rank == 0) { - // Nothing to sort. - return LiteralUtil::MakeTuple({&keys_literal, &values_literal}); +template +Literal ExtractLiteralFromIndexPositions(const Literal& from, + absl::Span indices, + bool extract_as_scalar) { + if (extract_as_scalar) { + return LiteralUtil::CreateR0(from.Get({indices[0]})); + } + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector. + absl::InlinedVector values; + for (int64 index : indices) { + values.push_back(from.Get({index})); + } + return LiteralUtil::CreateR1(values); +} + +StatusOr ExtractFromIndexPositions(const Literal& from, + absl::Span indices, + bool extract_as_scalar = false) { + if (extract_as_scalar) { + CHECK_EQ(indices.size(), 1); + } + PrimitiveType type = from.shape().element_type(); + switch (type) { + case PRED: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U8: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S8: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case BF16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case F16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case F32: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U32: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S32: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case F64: { + return 
ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U64: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S64: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + default: + return InvalidArgument("Unsupported type for Sort: %s", + PrimitiveType_Name(type)); } +} +} // namespace - Literal keys_result_literal(keys_literal.shape()); - Literal values_result_literal(values_literal.shape()); +Status HloEvaluator::HandleSort(HloInstruction* sort) { + TF_RET_CHECK(sort->operand_count() >= 1) + << "Expected at least 1 operand for sort"; + for (int64 i = 1; i < sort->operand_count(); ++i) { + TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(i)->shape())) + << "All Sort operands must have the same dimensions"; + } + + if (VLOG_IS_ON(3)) { + for (int64 i = 0; i < sort->operand_count(); ++i) { + VLOG(3) << "HandleSort operand " << i << " literal: " + << GetEvaluatedLiteralFor(sort->operand(i)).ToString(); + } + } + Shape key_shape = sort->operand(0)->shape(); + auto rank = key_shape.rank(); + std::vector result_literals; + result_literals.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + result_literals.emplace_back(sort->operand(i)->shape()); + } std::vector zero_base(rank, 0); std::vector increment(rank, 1); int64 sort_dim = sort->dimensions(0); - int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim); + int64 sort_dim_elements = key_shape.dimensions(sort_dim); increment[sort_dim] = sort_dim_elements; + HloEvaluator embedded_evaluator(max_loop_iterations_); // Iterate through each dimension except 'sort_dim'. 
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - keys_literal.shape(), zero_base, - AsInt64Slice(keys_literal.shape().dimensions()), increment, + key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment, [&](absl::Span indices) -> StatusOr { - // Extract a slice from the keys and values literals that correspond to + // Extract a slice from each operand literal that corresponds to // exactly the row in dimension 'sort_dim'. std::vector limit_indices(indices.begin(), indices.end()); absl::c_for_each(limit_indices, [](int64& index) { ++index; }); limit_indices[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto keys_to_sort, - keys_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& keys_data = keys_to_sort.data(); - TF_ASSIGN_OR_RETURN(auto values_to_sort, - values_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& values_data = values_to_sort.data(); - using kv_pair = std::pair; - std::vector key_value_vector; - key_value_vector.reserve(keys_data.size()); - for (int i = 0; i < keys_data.size(); ++i) { - key_value_vector.push_back( - std::make_pair(keys_data[i], values_data[i])); + std::vector literals_to_sort; + literals_to_sort.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN(auto literal_to_sort, + GetEvaluatedLiteralFor(sort->operand(i)) + .Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + literals_to_sort.push_back(std::move(literal_to_sort)); } - std::stable_sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); - std::vector result_keys; - // We use a InlinedVector here because we need to convert it to an - // absl::Span later, and this would not work with std::vector. 
- absl::InlinedVector result_values; - for (const auto& key_value : key_value_vector) { - result_keys.push_back(key_value.first); - result_values.push_back(key_value.second); + std::vector indices_to_sort(sort_dim_elements); + std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0); + Status compare_status = Status::OK(); + auto comparator = [sort, &compare_status, &embedded_evaluator, + &literals_to_sort](int64 a, int64 b) { + std::vector literals; + literals.reserve(2 * sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a}, + /*extract_as_scalar=*/true); + if (!lhs.ok()) { + compare_status = lhs.status(); + return false; + } + literals.push_back(std::move(lhs.ValueOrDie())); + auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b}, + /*extract_as_scalar=*/true); + if (!rhs.ok()) { + compare_status = rhs.status(); + return false; + } + literals.push_back(std::move(rhs.ValueOrDie())); + } + std::vector literal_ptrs; + absl::c_transform(literals, std::back_inserter(literal_ptrs), + [](const Literal& literal) { return &literal; }); + + auto computed_result = + embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs); + // Clear visit states so that we can use the evaluator again + // on the same computation. 
+ embedded_evaluator.ResetVisitStates(); + if (!computed_result.ok()) { + compare_status = computed_result.status(); + return false; + } + return computed_result.ValueOrDie().Get({}); + }; + if (Cast(sort)->is_stable()) { + std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(), + comparator); + } else { + std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator); + } + if (!compare_status.ok()) { + return compare_status; } - Literal sorted_keys(ShapeUtil::MakeShape( - keys_literal.shape().element_type(), {sort_dim_elements})); - sorted_keys.PopulateR1(absl::Span(result_keys)); - Literal sorted_values(ShapeUtil::MakeShape( - values_literal.shape().element_type(), {sort_dim_elements})); - sorted_values.PopulateR1(absl::Span(result_values)); std::vector slice_dimensions(rank, 1); slice_dimensions[sort_dim] = sort_dim_elements; std::vector start_indices(rank, 0); - TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped, - sorted_keys.Reshape(slice_dimensions)); - TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( - sorted_keys_reshaped, start_indices, indices, slice_dimensions)); - TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped, - sorted_values.Reshape(slice_dimensions)); - TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( - sorted_values_reshaped, start_indices, indices, slice_dimensions)); + for (int64 i = 0; i < sort->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN( + Literal sorted_literal, + ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort)); + TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped, + sorted_literal.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom( + sorted_literal_reshaped, start_indices, indices, + slice_dimensions)); + } return true; })); - Literal result_tuple; - result_tuple = - LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); - VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); - return std::move(result_tuple); -} - -template -StatusOr 
EvaluateSortCurried(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - switch (values_literal.shape().element_type()) { - case PRED: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case F32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case U32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case S32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case BF16: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - default: - return InvalidArgument("Unsupported type for Sort"); - } -} - -StatusOr EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - switch (sort->operand(0)->shape().element_type()) { - case F32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case U32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case S32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case BF16: - return EvaluateSortCurried(sort, keys_literal, values_literal); - default: - return InvalidArgument("Unsupported type for Sort"); - } -} -} // namespace - -Status HloEvaluator::HandleSort(HloInstruction* sort) { - if (!sort->shape().IsTuple()) { - return DefaultAction(sort); + if (sort->operand_count() == 1) { + evaluated_[sort] = std::move(result_literals[0]); } else { - // This is a really stupid work-around for the fact it's hard to support a - // multi-value sort directly, due to the fact we need to template the - // evaluation function on all of the value types. 
- std::vector sort_results_backing; - for (int64 i = 0; i < sort->operand_count(); ++i) { - auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), - GetEvaluatedLiteralFor(sort->operand(i))); - if (!result.ok()) { - return result.status(); - } - sort_results_backing.push_back( - std::move(result.ValueOrDie().DecomposeTuple()[1])); - } - std::vector sort_results; - absl::c_transform(sort_results_backing, std::back_inserter(sort_results), + std::vector literal_ptrs; + absl::c_transform(result_literals, std::back_inserter(literal_ptrs), [](const Literal& literal) { return &literal; }); - evaluated_[sort] = LiteralUtil::MakeTuple(sort_results); - return Status::OK(); + + Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); + + evaluated_[sort] = std::move(result_tuple); } + return Status::OK(); } Status HloEvaluator::HandleReduce(HloInstruction* reduce) { @@ -1507,6 +1573,27 @@ Status HloEvaluator::HandleReduce(HloInstruction* reduce) { } } +Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) { + if (!custom_call_handler_) { + // No handler is registered; this means custom-calls are not allowed. + return DefaultAction(custom_call); + } + + // Evaluate input operands so the handler has access to the operand data. + std::vector operands; + operands.reserve(custom_call->operand_count()); + for (const HloInstruction* operand : custom_call->operands()) { + operands.push_back(&GetEvaluatedLiteralFor(operand)); + } + + // Synchronously issue the handler to populate the instruction output literal. 
+ TF_ASSIGN_OR_RETURN( + auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands))); + + evaluated_[custom_call] = std::move(output); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return ShapeUtil::ValidateShape(hlo->shape()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index ccb8af4fb07fedb054693b78e8bab49527d38700..357975a131d0c7e63c06e96852468b43d97a37f2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -16,12 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#include #include #include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -132,6 +134,23 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Enable the fast path for certain operations like dot or convolution. void set_use_fast_path(bool value) { use_fast_path_ = value; } + // Handles evaluation of a custom-call op. + // Operand literals are provided in |operands| and implementations must + // populate |output| before returning. + using CustomCallHandler = std::function( + HloInstruction* custom_call, absl::Span operands)>; + + // Sets a handler that is called during evaluation for custom-call ops. + // If no handler is defined the default error behavior will occur. The handler + // will be provided evaluated literals for all operands and is expected to + // return an output literal of the appropriate shape. 
+ void set_custom_call_handler( + std::function(HloInstruction* custom_call, + absl::Span operands)> + handler) { + custom_call_handler_ = std::move(handler); + } + // Returns the result of a matrix multiply `lhs x rhs`. static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); @@ -219,6 +238,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* reduce) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + // Unsupported HLOs, note some of them (such as BatchNorm*) are typically // expanded in a semantic-preserving way into other HLOs by adding exanpsion // HLO pass to the HLO optimization pass during compilation, which can then be @@ -310,16 +331,21 @@ class HloEvaluator : public DfsHloVisitorWithDefault { std::vector arg_literals_; // Max loop iterations to execute with no maximum if negative. - int64 max_loop_iterations_; + int64 max_loop_iterations_ = 0; // Module-level seed handle. - uint64 seed_; + uint64 seed_ = 0; // RNG engine. std::minstd_rand0 engine_; // DynamicDimensionInference is used to evaluate GetDimensionSize, which // returns the dynamic dimension size of its operand. - DynamicDimensionInference* dynamic_dimension_inference_; + DynamicDimensionInference* dynamic_dimension_inference_ = nullptr; + + // Optional handler for custom_call ops. 
+ std::function(HloInstruction* custom_call, + absl::Span operands)> + custom_call_handler_; TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index d34fa48efb2e8ec2d35cdbb24441964f3d7c8b92..335859fb52961aa0f88aec69b73bf073409db451 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -58,12 +58,12 @@ class HloEvaluatorTest : public HloTestBase { public: HloEvaluatorTest() : use_bfloat16_(false) {} - Literal Evaluate(absl::Span arg_literals = {}) { + StatusOr Evaluate( + absl::Span arg_literals = {}) { if (use_bfloat16_) { HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie(); } - return evaluator_.Evaluate(*m_->entry_computation(), arg_literals) - .ConsumeValueOrDie(); + return evaluator_.Evaluate(*m_->entry_computation(), arg_literals); } // Evaluate function that takes in a local module instead of using m_ @@ -86,7 +86,7 @@ class HloEvaluatorTest : public HloTestBase { b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto element_type = expected.shape().element_type(); if (element_type == F32 || element_type == F64) { @@ -106,7 +106,25 @@ class HloEvaluatorTest : public HloTestBase { HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); + } + + void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0, + Literal src1, Literal src2) { + HloComputation::Builder b(TestName()); + auto operand0 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src0))); + auto operand1 = + 
b.AddInstruction(HloInstruction::CreateConstant(std::move(src1))); + auto operand2 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src2))); + b.AddInstruction(HloInstruction::CreateTernary( + expected.shape(), opcode, operand0, operand1, operand2)); + m_->AddEntryComputation(b.Build()); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } @@ -145,13 +163,40 @@ TEST_P(HloEvaluatorBf16Test, DoesClamp) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +// Verifies that clamping of int64 does not cause loss of precision +TEST_P(HloEvaluatorBf16Test, DoesClampInt64) { + auto ones = [](int bits) { return (int64{1} << bits) - 1; }; + + auto low = + LiteralUtil::CreateR2({{0, ones(54)}, {ones(54), ones(58)}}); + auto value = LiteralUtil::CreateR2({{0, ones(56)}, {0, ones(58)}}); + auto high = LiteralUtil::CreateR2( + {{ones(54), ones(55)}, {ones(56), ones(58)}}); + + Shape shape = low.shape(); + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); + b.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); + m_->AddEntryComputation(b.Build()); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); + + auto expected = + LiteralUtil::CreateR2({{0, ones(55)}, {ones(54), ones(58)}}); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) { auto low = LiteralUtil::CreateR0(0.f); auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, 
{1.f, 2.f}}); @@ -166,7 +211,7 @@ TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); @@ -191,7 +236,7 @@ TEST_P(HloEvaluatorBf16Test, DoesSelect) { HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate({}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({})); auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); @@ -254,6 +299,20 @@ TEST_F(HloEvaluatorTest, DoesDivideInt64) { TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } + +TEST_F(HloEvaluatorTest, DoesClampS64) { + auto low = LiteralUtil::CreateR1( + {-8616761059752331528LL, 6780561065411491190LL, -8616761059752331528LL}); + auto value = LiteralUtil::CreateR1( + {-6780561065411491190LL, 6780561065411491180LL, 4241131823772864090LL}); + auto high = LiteralUtil::CreateR1( + {-6780561065411491180LL, 8616761059752331528LL, 3832151243857508051LL}); + auto expected = LiteralUtil::CreateR1( + {-6780561065411491190LL, 6780561065411491190LL, 3832151243857508051LL}); + TestTernaryOp(HloOpcode::kClamp, std::move(expected), std::move(low), + std::move(value), std::move(high)); +} + TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) { auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); @@ -308,6 +367,19 @@ TEST_F(HloEvaluatorTest, DoesNotR2) { {0, std::numeric_limits::min()}}); TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } + +TEST_F(HloEvaluatorTest, DoesRealC128) { + auto x = LiteralUtil::CreateR1({{1, 0}, {-100, 4}}); + auto expected_real = LiteralUtil::CreateR1({1, -100}); + TestUnaryOp(HloOpcode::kReal, std::move(expected_real), 
std::move(x)); +} + +TEST_F(HloEvaluatorTest, DoesImagC128) { + auto x = LiteralUtil::CreateR1({{1, 0}, {-100, 4}}); + auto expected_imag = LiteralUtil::CreateR1({0, 4}); + TestUnaryOp(HloOpcode::kImag, std::move(expected_imag), std::move(x)); +} + // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { @@ -332,7 +404,7 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { lhs_instruction, param_rhs2)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(args); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate(args)); auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); @@ -356,7 +428,7 @@ TEST_F(HloEvaluatorTest, DoesReshape) { HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate({}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({})); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; result.EachCell([&](absl::Span indices, NativeT value) { @@ -377,7 +449,7 @@ TEST_F(HloEvaluatorTest, DoesBroadcast) { output_literal.shape(), literal_instruction, {1, 2})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate({}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({})); EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } @@ -396,7 +468,7 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { /*broadcast_dimensions=*/{})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate({}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({})); EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } @@ -416,7 +488,7 @@ TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2( {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); @@ -438,7 +510,7 @@ 
TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR1({100, 200}); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); @@ -458,7 +530,7 @@ TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) { b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } @@ -478,7 +550,7 @@ TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) { b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } @@ -513,7 +585,7 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { shape, operand_instruction, padding_value_instruction, padding_config)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); @@ -540,7 +612,7 @@ TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) { shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -584,7 +656,7 @@ TEST_P(HloEvaluatorBf16Test, NegativePadding2D) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = absl::make_unique>(1, 5); @@ -629,7 
+701,7 @@ TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected_array = absl::make_unique>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); @@ -668,7 +740,7 @@ TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // clang-format off auto expected_array = Array2D({ @@ -714,7 +786,7 @@ TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR1({22.f, 28.f}); @@ -758,7 +830,7 @@ TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected_array = Array2D({ {22.f, 28.f}, @@ -800,7 +872,8 @@ TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); + float expected_1 = 0; for (float i = 1.0f; i < 7.0f; ++i) { expected_1 += i * i + i; @@ -856,7 +929,7 @@ TEST_P(HloEvaluatorBf16Test, SimpleConv1D) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); @@ -911,7 +984,7 @@ TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result 
= Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ -995,7 +1068,7 @@ TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -1073,7 +1146,7 @@ TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -1133,7 +1206,7 @@ TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -1197,7 +1270,7 @@ TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -1269,7 +1342,7 @@ TEST_P(HloEvaluatorBf16Test, /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1341,7 +1414,7 @@ TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - 
Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 1, 8); expected_array.FillWithYX( @@ -1459,7 +1532,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceAdd) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR1({6, 18}); @@ -1511,7 +1584,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{6, 7}}); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); @@ -1563,7 +1636,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxWindowDilation) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{11}}); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); @@ -1620,7 +1693,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); @@ -1681,7 +1754,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); std::vector output_dims = {4, 3, 3, 3, 4, 4}; Literal result_literal = @@ -1713,7 +1786,7 @@ TEST_P(HloEvaluatorBf16Test, StridedSlice) { /*strides=*/{2, 3})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {3}, @@ -1749,7 +1822,7 @@ TEST_P(HloEvaluatorBf16Test, DynamicSlice) { HloInstruction::CreateDynamicSlice(shape, operand, {zero, one}, {2, 3})); m_->AddEntryComputation(b.Build()); - 
Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, @@ -1787,7 +1860,7 @@ TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) { HloInstruction::CreateDynamicSlice(shape, operand, {two, one}, {2, 3})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, @@ -1826,7 +1899,7 @@ TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) { shape, operand, update, {zero, one})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, @@ -1862,7 +1935,7 @@ TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, @@ -1901,7 +1974,7 @@ TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto result_inner_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1939,7 +2012,7 @@ TEST_P(HloEvaluatorBf16Test, Reverse) { b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // clang-format off auto expected = LiteralUtil::CreateR4FromArray4D({ @@ -1979,11 +2052,12 @@ TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) { HloEvaluator evaluator; Literal param0_literal = LiteralUtil::CreateR1({1, 2, 3, 4}); Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); - auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, ¶m0_literal}, {square, &square_literal}}); - TF_ASSERT_OK(result.status()); + TF_ASSERT_OK_AND_ASSIGN( + 
Literal result, + evaluator.EvaluateWithSubstitutions( + add, {{param0, ¶m0_literal}, {square, &square_literal}})); EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result)); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -2004,11 +2078,11 @@ TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) { // Evaluate add with square = {10, 20, 30, 40}. HloEvaluator evaluator; Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); - auto result = - evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}); - TF_ASSERT_OK(result.status()); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}})); EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -2030,9 +2104,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), - Evaluate({&operand, &start_indices}))); + LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -2054,9 +2128,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - Evaluate({&operand, &start_indices}))); + LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), result)); } 
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -2078,10 +2152,11 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - Evaluate({&operand, &start_indices}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -2105,9 +2180,9 @@ ENTRY main { {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - EXPECT_TRUE( - LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), result)); } TEST_F(HloEvaluatorTest, @@ -2132,9 +2207,9 @@ ENTRY main { {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - EXPECT_TRUE( - LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -2156,8 +2231,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({1, 1}); - EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{5}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE( + 
LiteralTestUtil::Equal(LiteralUtil::CreateR2({{5}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -2179,9 +2255,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - EXPECT_TRUE( - LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{8}}, {{5}}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR3({{{8}}, {{5}}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2202,8 +2278,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); - EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2226,9 +2303,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal start_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - EXPECT_TRUE( - LiteralTestUtil::Equal(LiteralUtil::CreateR2({{0, 1}, {2, 1}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{0, 1}, {2, 1}}), result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { @@ -2257,9 +2334,11 @@ ENTRY main { LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + 
Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { @@ -2289,9 +2368,11 @@ ENTRY main { Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { @@ -2321,9 +2402,11 @@ ENTRY main { LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { @@ -2353,9 +2436,11 @@ ENTRY main { LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) { @@ -2386,10 +2471,12 @@ ENTRY main { Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); Literal 
updates = LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Near( LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), - Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); + result, ErrorSpec{0.1, 0.01})); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { @@ -2419,9 +2506,11 @@ ENTRY main { LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { @@ -2452,9 +2541,11 @@ ENTRY main { Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); Literal updates = LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { @@ -2489,8 +2580,9 @@ ENTRY main { LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // {{-40, 40}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } 
TEST_F(HloEvaluatorTest, @@ -2526,8 +2618,9 @@ ENTRY main { LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { @@ -2558,8 +2651,9 @@ ENTRY main { Literal updates = LiteralUtil::CreateR2({{10}}); Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { @@ -2590,8 +2684,9 @@ ENTRY main { Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { @@ -2619,8 +2714,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{}, {}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - operand, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(operand, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { @@ -2652,8 
+2748,9 @@ ENTRY main { LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); Literal expected = LiteralUtil::CreateR1({10, 61, 32}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { @@ -2776,8 +2873,16 @@ TEST_F(HloEvaluatorTest, DoesCompareBF16) { {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}}); auto expected = LiteralUtil::CreateR2({{false, true, true}, {false, true, true}}); - TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs), - std::move(rhs)); + + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); + b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2, + ComparisonDirection::kGe)); + m_->AddEntryComputation(b.Build()); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorBf16Test, Bf16Reduction) { @@ -2801,7 +2906,48 @@ ENTRY main { Literal arg = LiteralUtil::CreateR1( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); Literal expected = LiteralUtil::CreateR0(bfloat16(44.0f)); - EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&arg})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_F(HloEvaluatorTest, DontFailOnCallUnimplementedOps) { + // Infeed triggers unimplemented error within HandleCall, and we verify that + // the Evaluator does fail in such case. 
+ const string hlo_text = R"( +HloModule DontFailOnCall + +call { + token0 = token[] after-all() + ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) +} + +ENTRY main { + ROOT result = ((u32[3]{0}, pred[]), token[]) call(), to_apply=call +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto statusor = Evaluate(); + EXPECT_FALSE(statusor.status().ok()); +} + +TEST_F(HloEvaluatorTest, DontFailOnFusionWithUnimplementedOps) { + // Infeed triggers unimplemented error within HandleFusion, and we verify that + // the Evaluator does fail in such case. + const string hlo_text = R"( +HloModule DontFailOnFusion + +fused_computation { + token0 = token[] after-all() + ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) +} + +ENTRY main { + ROOT result = ((u32[3]{0}, pred[]), token[]) fusion(), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto statusor = Evaluate(); + EXPECT_FALSE(statusor.status().ok()); } TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) { @@ -2819,7 +2965,7 @@ ENTRY main { Literal arg = LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, LayoutUtil::MakeLayout({0, 1, 2})); - Literal actual = Evaluate({&arg}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&arg})); EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } @@ -2841,7 +2987,7 @@ ENTRY main { } TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); - Literal actual = Evaluate({&args[0]}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&args[0]})); if (use_bfloat16_) { EXPECT_TRUE( absl::c_equal(args[0].data(), actual.data())); @@ -2867,7 +3013,8 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); - std::vector actual = Evaluate({}).DecomposeTuple(); + TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({})); + 
std::vector actual = literal.DecomposeTuple(); ASSERT_EQ(actual.size(), 3); uint32 pow30 = uint32{1} << 30; @@ -2907,7 +3054,7 @@ ENTRY main { Literal size_arg = LiteralUtil::CreateR0(3); Literal data_arg = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal actual = Evaluate({&size_arg, &data_arg}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&size_arg, &data_arg})); EXPECT_EQ(actual.GetFirstElement(), static_cast(3)); } @@ -2979,7 +3126,8 @@ TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) { TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); - Literal actual = Evaluate({&args[0]}); + + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&args[0]})); EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); } @@ -3000,7 +3148,7 @@ TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) { TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); - Literal actual = Evaluate({&args[0]}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&args[0]})); EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); } @@ -3022,11 +3170,123 @@ TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) { TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); - Literal actual_tuple = Evaluate({&args[0]}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual_tuple, Evaluate({&args[0]})); std::vector actual_literals = actual_tuple.DecomposeTuple(); EXPECT_TRUE( absl::c_equal(args[0].data(), actual_literals[0].data())); } +// Tests that custom_calls fail to evaluate when no handler is specified. 
+TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_NoHandler + ENTRY kernel_entry { + parameter.0 = u32[2,2]{1,0} parameter(0) + ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + EXPECT_EQ(HloEvaluator().Evaluate(*m_, {&args[0]}).status().code(), + ::tensorflow::error::UNIMPLEMENTED); +} + +// Tests when a custom_call handler returns an error. +TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_HandlerError + ENTRY kernel_entry { + parameter.0 = u32[2,2]{1,0} parameter(0) + ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + HloEvaluator evaluator; + evaluator.set_custom_call_handler( + [](HloInstruction* custom_call, absl::Span operands) { + return InternalError("Test error"); + }); + EXPECT_EQ(evaluator.Evaluate(*m_, {&args[0]}).status().code(), + ::tensorflow::error::INTERNAL); +} + +// Tests the custom_call handler on calls with many inputs. +// We sum the operands so that we can verify the operand and output literals +// are properly mapped for access. 
+TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_ManyInputs + ENTRY kernel_entry { + parameter.0 = u32[1]{0} parameter(0) + parameter.1 = u32[1]{0} parameter(1) + ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + HloEvaluator evaluator; + evaluator.set_custom_call_handler( + [](HloInstruction* custom_call, absl::Span operands) { + EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode()); + EXPECT_EQ("_my_custom_call", custom_call->custom_call_target()); + EXPECT_EQ(2, custom_call->operand_count()); + EXPECT_EQ(2, operands.size()); + auto output = Literal::CreateFromShape(custom_call->shape()); + auto operand0_data = operands[0]->data(); + auto operand1_data = operands[1]->data(); + auto output_data = output.data(); + output_data[0] = operand0_data[0] + operand1_data[0]; + return output; + }); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + evaluator.Evaluate(*m_->entry_computation(), {&args[0], &args[1]})); + auto arg0_data = args[0].data(); + auto arg1_data = args[1].data(); + std::vector expected_data = {arg0_data[0] + arg1_data[0]}; + EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data())); +} + +TEST_F(HloEvaluatorTest, IsFiniteF16) { + constexpr absl::string_view hlo_text = R"( + HloModule test + + ENTRY IsFiniteTest { + c = f16[6] constant({nan, 7, nan, -1, inf, -inf}) + ROOT is-finite = pred[6] is-finite(c) + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_THAT(actual_literal.data(), + ::testing::ElementsAre(false, true, false, true, false, false)); +} + +TEST_F(HloEvaluatorTest, IsFiniteBf16) { + constexpr 
absl::string_view hlo_text = R"( + HloModule test + + ENTRY IsFiniteTest { + c = bf16[6] constant({nan, 7, nan, -1, inf, -inf}) + ROOT is-finite = pred[6] is-finite(c) + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_THAT(actual_literal.data(), + ::testing::ElementsAre(false, true, false, true, false, false)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 95a78408b0ba188487cc53acd5f641d0306cd8af..2d8a578985e8f603d4056bee8619725095ebc7bb 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #include +#include #include "absl/algorithm/container.h" #include "absl/base/casts.h" @@ -43,46 +44,6 @@ template using is_complex_t = absl::disjunction, std::is_same>; -// It's UB to use std::sort with std::less, because of NaNs. Define -// "safe" less functions which are actually strict weak orders. -NaN and NaN -// should appear at the beginning and end of the ordering, and -0.0 should -// appear before 0.0. -template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> -bool SafeLess(const NativeT& a, const NativeT& b) { - return a < b; -} - -template ::value>::type* = nullptr> -bool SafeLess(const NativeT& a, const NativeT& b) { - bool lhs_is_negative = std::signbit(a); - bool rhs_is_negative = std::signbit(b); - // If the signs are different, we can just compare the signs. - if (lhs_is_negative != rhs_is_negative) { - return lhs_is_negative && !rhs_is_negative; - } - bool lhs_nan = std::isnan(a); - bool rhs_nan = std::isnan(b); - // Exactly one number is nan? 
- if (lhs_nan != rhs_nan) { - if (lhs_nan) { - return lhs_is_negative; - } - return !rhs_is_negative; - } - return a < b; -} - -template ::value || - std::is_same::value>::type* = nullptr> -bool SafeLess(const NativeT& a, const NativeT& b) { - return SafeLess(static_cast(a), static_cast(b)); -} - // ToArithmeticSafeType(T t): // - converts `t` to the bitwise-equivalent `unsigned T` if T is a signed // integer, and @@ -368,10 +329,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleLog1p(HloInstruction* expm1) { + Status HandleLog1p(HloInstruction* log1p) { TF_ASSIGN_OR_RETURN( - parent_->evaluated_[expm1], - ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + parent_->evaluated_[log1p], + ElementWiseUnaryOp(log1p, [](ElementwiseT elem_operand) { return std::log1p(elem_operand); })); return Status::OK(); @@ -462,9 +423,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleNegate(negate); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { @@ -474,6 +435,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value || + std::is_same::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return std::isnan(elem_operand) + ? 
elem_operand + : std::copysign( + elem_operand != ElementwiseT(0), + elem_operand); + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -686,6 +664,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleSqrt(HloInstruction* sqrt) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sqrt], + ElementWiseUnaryOp(sqrt, [](ElementwiseT elem_operand) { + return std::sqrt(elem_operand); + })); + return Status::OK(); + } + + Status HandleRsqrt(HloInstruction* rsqrt) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[rsqrt], + ElementWiseUnaryOp(rsqrt, [](ElementwiseT elem_operand) { + return static_cast(1) / std::sqrt(elem_operand); + })); + return Status::OK(); + } + template ::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { @@ -916,9 +911,29 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleShiftRightLogical(shrl); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + // Special case for integral type due to MSVC's std::isnan being unable to + // handle integral type. 
+ template ::value && + std::is_integral::value>::type* = + nullptr> + Status HandleClamp(HloInstruction* clamp) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { + return static_cast( + std::min(high, std::max(value, low))); + }; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); + return Status::OK(); + } + + template ::value && + !std::is_integral::value>::type* = + nullptr> Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { @@ -926,7 +941,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(NAN); } return static_cast( - std::fmin(high, std::fmax(value, low))); + std::min(high, std::max(value, low))); }; TF_ASSIGN_OR_RETURN( parent_->evaluated_[clamp], @@ -1187,7 +1202,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleDot(HloInstruction* dot) override { - if (parent_->use_fast_path_) { + if (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() == 1 && + parent_->use_fast_path_) { return HandleDot(dot); } return HandleDotSlowPath(dot); @@ -1349,12 +1365,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { static_cast(lhs_literal.Get(lhs_index)) * static_cast(rhs_literal.Get(rhs_index)); - for (int64 i = accumulate_index_sizes.size() - 1; i >= 0; --i) { - int64 value = ++accumulate_index[i]; - if (value != accumulate_index_sizes[i]) { - break; + // If there are no contracting dimension accumulate_index_sizes is + // empty, do not try to count down from -1 to 0 since it is an + // infinite loop. 
+ if (!accumulate_index_sizes.empty()) { + for (int64 i = accumulate_index_sizes.size() - 1; i >= 0; --i) { + int64 value = ++accumulate_index[i]; + if (value != accumulate_index_sizes[i]) { + break; + } + accumulate_index[i] = 0; } - accumulate_index[i] = 0; } } @@ -1656,73 +1677,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleSort(HloInstruction* sort) { - auto keys = sort->operand(0); - TF_RET_CHECK(sort->operand_count() == 1) - << "Typed visitor does not support key-value sort"; - - const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - int64 sort_dim = sort->dimensions(0); - int64 sort_dim_elements = keys->shape().dimensions(sort_dim); - int64 rank = keys->shape().rank(); - if (rank == 0) { - // Nothing to sort. - parent_->evaluated_[sort] = keys_literal.Clone(); - return Status::OK(); - } - Literal result_literal(keys_literal.shape()); - std::vector zero_base(rank, 0); - std::vector increment(rank, 1); - increment[sort_dim] = sort_dim_elements; - // Iterate through each dimension except 'sort_dim'. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()), - increment, [&](absl::Span indices) -> StatusOr { - // Extract a slice from the literal that corresponds to exactly the - // row in dimension 'sort_dim'. 
- std::vector limit_indices(indices.begin(), indices.end()); - absl::c_for_each(limit_indices, [](int64& index) { ++index; }); - limit_indices[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto row_to_sort, - keys_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& row_data = row_to_sort.data(); - - std::vector result_data(row_data.begin(), row_data.end()); - std::stable_sort(result_data.begin(), result_data.end(), - [](const NativeT& a, const NativeT& b) { - return SafeLess(a, b); - }); - Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), - {sort_dim_elements})); - sorted_row.PopulateR1(absl::Span(result_data)); - std::vector slice_dimensions(rank, 1); - slice_dimensions[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped, - sorted_row.Reshape(slice_dimensions)); - std::vector start_indices(rank, 0); - TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( - sorted_row_reshaped, start_indices, indices, slice_dimensions)); - return true; - })); - parent_->evaluated_[sort] = std::move(result_literal); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleSort(HloInstruction* sort) { - return UnsupportedTypeError(sort); - } - Status HandleSort(HloInstruction* sort) override { - return HandleSort(sort); + return UnsupportedTypeError(sort); } Status HandleReduce(HloInstruction* hlo) override { @@ -2752,12 +2708,25 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& high = parent_->GetEvaluatedLiteralFor(random->operand(1)); - std::uniform_real_distribution generator( - low.Get({}), high.Get({})); - + // std::uniform_real_distribution(a, b) can sometimes return a value + // equal to b. Unclear if this is a spec bug or an implementation bug + // or WAI [0] [1] [2]. Anyway for our purposes we want a half-open + // interval, so we have to re-sample if we get `b` out. 
+ // + // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176 + // [1] https://bugs.llvm.org/show_bug.cgi?id=18767 + // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524 + auto low_val = low.Get({}); + auto high_val = high.Get({}); + std::uniform_real_distribution generator(low_val, high_val); TF_RETURN_IF_ERROR( result.Populate([&](absl::Span /*indexes*/) { - return generator(parent_->engine_); + while (true) { + NativeT v = generator(parent_->engine_); + if (v != high_val) { + return v; + } + } })); break; } @@ -2891,21 +2860,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { absl::Span start_indices, const Shape& result_shape) { std::vector start; - // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish - // between the cases, this currently assumes there is at least 1 index. That - // is wrong in the general case, because for scalar indices, if the operand - // is scalar, then there are no indices. This problem with resolve itself. - const HloInstruction* first_index = start_indices[0]; - if (first_index->shape().rank() == 1) { - auto start_indices_typed = - parent_->GetEvaluatedLiteralFor(first_index).data(); - start = std::vector(start_indices_typed.begin(), - start_indices_typed.end()); - } else { - for (HloInstruction* index : start_indices) { - start.push_back( - parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); - } + + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); } // Clamp the start indices so the slice is in-bounds w.r.t the operand. @@ -2938,22 +2896,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result = operand_literal.Clone(); const auto rank = result.shape().rank(); std::vector start; - // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish - // between the cases, this currently assumes there is at least 1 index. 
That - // is wrong in the general case, because for scalar indices, if the operand - // is scalar, then there are no indices. This problem with resolve itself. - const HloInstruction* first_index = start_indices[0]; - if (first_index->shape().rank() == 1) { - auto start_indices_typed = - parent_->GetEvaluatedLiteralFor(first_index).data(); - start = std::vector(start_indices_typed.begin(), - start_indices_typed.end()); - } else { - for (HloInstruction* index : start_indices) { - start.push_back( - parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); - } + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); } + // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. for (int64 i = 0; i < rank; ++i) { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 4c7f5e9e7dfb12a8cb699bdf397eab21983342a1..116b32f5f4c772b6a9771e6cf9e5095c7c959775 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -38,21 +38,21 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" namespace xla { -namespace hlo_graph_dumper { namespace { using absl::nullopt; @@ -259,14 +259,16 @@ optional MatchTrivialComputation(const HloComputation* computation) { // param0), check that the operation being performed is commutative. 
if (root->operand(0) == param1) { CHECK_EQ(root->operand(1), param0); - switch (root->opcode()) { - case HloOpcode::kLe: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLt: - return nullopt; - default: - break; + if (root->opcode() == HloOpcode()) { + switch (root->comparison_direction()) { + case ComparisonDirection::kLe: + case ComparisonDirection::kGe: + case ComparisonDirection::kGt: + case ComparisonDirection::kLt: + return nullopt; + default: + break; + } } } @@ -280,18 +282,22 @@ optional MatchTrivialComputation(const HloComputation* computation) { return "min"; case HloOpcode::kMaximum: return "max"; - case HloOpcode::kLe: - return "less-or-equal"; - case HloOpcode::kGe: - return "greater-or-equal"; - case HloOpcode::kGt: - return "greater-than"; - case HloOpcode::kLt: - return "less-than"; - case HloOpcode::kEq: - return "equal-to"; - case HloOpcode::kNe: - return "not-equal-to"; + case HloOpcode::kCompare: { + switch (root->comparison_direction()) { + case ComparisonDirection::kLe: + return "less-or-equal"; + case ComparisonDirection::kGe: + return "greater-or-equal"; + case ComparisonDirection::kGt: + return "greater-than"; + case ComparisonDirection::kLt: + return "less-than"; + case ComparisonDirection::kEq: + return "equal-to"; + case ComparisonDirection::kNe: + return "not-equal-to"; + } + } default: return nullopt; } @@ -536,7 +542,12 @@ stylesheet=< } } - return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n")); + // Browsers require that we URI-encode the contents of our data URI. (It + // seems this was a relatively recent change?) In practice, this means that we + // need to escape '#'. + return StrFormat( + fmt, graph_label, + absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}})); } string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } @@ -825,8 +836,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // collected from profiling tools. 
Those constants may not have a valid // literal. if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { - return StrFormat("%s (%s)", constant->literal().ToString(), - ShapeUtil::HumanString(constant->shape())); + return constant->literal().ToString(); } // Otherwise, print e.g. "%constant.42 (s32[100])". @@ -919,27 +929,22 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCos: case HloOpcode::kDivide: - case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kIota: case HloOpcode::kIsFinite: - case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLog1p: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: @@ -949,6 +954,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kRemainder: case HloOpcode::kRng: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRsqrt: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: @@ -957,6 +963,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSin: case HloOpcode::kSlice: case HloOpcode::kSort: + case HloOpcode::kSqrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: // De-emphasize scalar-shaped elementwise ops -- they're generally @@ -1011,6 +1018,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kConvolution: case HloOpcode::kDot: case HloOpcode::kFft: + case HloOpcode::kTriangularSolve: + case HloOpcode::kCholesky: return kDarkBlue; case HloOpcode::kReducePrecision: 
return kRed; @@ -1037,6 +1046,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kRecvDone: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kReplicaId: return kBrown; case HloOpcode::kCall: case HloOpcode::kConditional: @@ -1248,40 +1258,11 @@ const HloInstruction* HloDotDumper::GetNodeForEdge( return instr; } -class GraphRendererRegistry { - public: - void SetRenderer(std::shared_ptr graph_renderer) { - tensorflow::mutex_lock lock(mu_); - graph_renderer_ = graph_renderer; - } - - std::shared_ptr GetDefaultRenderer() { - tensorflow::mutex_lock lock(mu_); - return graph_renderer_; - } - - static GraphRendererRegistry* Default() { - static GraphRendererRegistry* registry = new GraphRendererRegistry(); - return registry; - } - - private: - tensorflow::mutex mu_; - std::shared_ptr graph_renderer_ GUARDED_BY(mu_); -}; - -} // namespace - -Registrar::Registrar(std::shared_ptr dumper) { - GraphRendererRegistry::Default()->SetRenderer(dumper); -} - -namespace { - // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. -NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, - int64 radius) { +NodeFilter MakeNodeRadiusAroundFilter( + const HloInstruction* root, int64 radius, + const absl::flat_hash_set& boundary) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. absl::flat_hash_map nodes; @@ -1297,6 +1278,9 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, if (depth == radius) { continue; } + if (boundary.contains(instr)) { + continue; + } // Traverse into instr's operands. 
// @@ -1436,157 +1420,7 @@ NodeFilter MakeNodeFromToFilter(const HloInstruction* from, }); } -string SaveGraph(const string& graph, - GraphRendererInterface::GraphKind graph_kind, - const string& dest_path) { - static std::atomic output_num(0); - string file_extension; - switch (graph_kind) { - case GraphRendererInterface::DOT_GRAPH: - file_extension = ".dot"; - break; - case GraphRendererInterface::TF_GRAPHDEF: - file_extension = ".pbtxt"; - break; - } - string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, ".")); - auto status = Status::OK(); - auto env = tensorflow::Env::Default(); - if (!env->CreateUniqueFileName(&path, file_extension)) { - status = - Status(tensorflow::error::Code::UNKNOWN, - StrCat("Failed to create temporary file to dump HLO graph: ", - strerror(errno))); - } else { - status = tensorflow::WriteStringToFile(env, path, graph); - } - if (!status.ok()) { - LOG(WARNING) << "Saving HLO graph failed: " << status; - } - return path; -} - -string ExportGraph(const string& graph, - GraphRendererInterface::GraphKind graph_kind, - const DebugOptions& debug_options) { - string path = debug_options.xla_hlo_graph_path(); - if (!path.empty() && !debug_options.xla_hlo_dump_as_html()) { - return SaveGraph(graph, graph_kind, path); - } else { - auto graph_renderer = - GraphRendererRegistry::Default()->GetDefaultRenderer(); - CHECK(graph_renderer != nullptr) - << "No registered renderer for the HLO graph. 
" - "Use --xla_hlo_graph_path=PATH --xla_hlo_dump_as_html=false to " - "export to local file system"; - return graph_renderer->RenderGraph(graph, graph_kind, debug_options); - } -} - -} // namespace - -string DumpGraph(const HloComputation& computation, const string& label, - const DebugOptions& debug_options, - const HloExecutionProfile* hlo_execution_profile, - bool show_backend_config) { - GraphRendererInterface::GraphKind graph_kind; - string graph; - if (debug_options.xla_hlo_dump_as_graphdef()) { - HloTfGraphBuilder builder(debug_options); - TF_CHECK_OK(builder.AddComputation(computation)); - CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), - &graph)); - graph_kind = GraphRendererInterface::TF_GRAPHDEF; - } else { - graph = - HloDotDumper(&computation, label, debug_options, show_backend_config, - hlo_execution_profile, NodeFilter()) - .Dump(); - graph_kind = GraphRendererInterface::DOT_GRAPH; - } - - string graph_url = ExportGraph(graph, graph_kind, debug_options); - LOG(INFO) << "computation " << computation.name() << " [" << label - << "]: " << graph_url; - return graph_url; -} - -string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_backend_config) { - auto debug_options = node.GetModule()->config().debug_options(); - string label = - StrCat("Neighborhood of ", radius, " nodes around ", node.name()); - NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius); - string graph = - HloDotDumper(node.parent(), label, debug_options, show_backend_config, - /*profile=*/nullptr, filter) - .Dump(); - return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); -} - -string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, - int64 max_nodes, bool show_backend_config) { - CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!"; - auto debug_options = from.GetModule()->config().debug_options(); - - bool hit_limit = false; - NodeFilter filter = 
MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit); - string label; - if (!hit_limit) { - label = StrCat("All paths from ", from.name(), " to ", to.name()); - } else { - label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(), - " to ", to.name(), - "

***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN " - "NODES***

"); - } - string graph = - HloDotDumper(from.parent(), label, debug_options, show_backend_config, - /*profile=*/nullptr, filter) - .Dump(); - return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); -} - -void DumpText(const HloModule& module, const string& label, - const string& directory_path, bool do_prefix) { - Env* env = Env::Default(); - TF_CHECK_OK(env->RecursivelyCreateDir(directory_path)); - string prefix = StrCat(env->NowMicros()); - string filename = - do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); - string path = JoinPath(directory_path, filename); - TF_CHECK_OK(WriteStringToFile( - env, path, - module.ToString(HloPrintOptions().set_print_large_constants(true)))); - LOG(INFO) << "dumping module '" << module.name() << "' to " << path; -} - -string MaybeDumpHloModule(const HloModule& module, const string& label, - const HloExecutionProfile* profile) { - const DebugOptions& debug_options = module.config().debug_options(); - VLOG(2) << "MaybeDumpHloModule called on module " << module.name() - << " with generate_hlo_graph regex \"" - << debug_options.xla_generate_hlo_graph() << "\""; - string graph_url; - if (!debug_options.xla_generate_hlo_graph().empty() && - RE2::PartialMatch(module.name(), - debug_options.xla_generate_hlo_graph())) { - graph_url = - DumpGraph(*module.entry_computation(), label, debug_options, profile); - } - if (!debug_options.xla_log_hlo_text().empty() && - RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) { - LOG(INFO) << "HLO for module " << module.name(); - LOG(INFO) << "Label: " << label; - XLA_LOG_LINES(2, module.ToString()); - } - if (!debug_options.xla_generate_hlo_text_to().empty()) { - DumpText(module, label, debug_options.xla_generate_hlo_text_to()); - } - return graph_url; -} - -string WrapDotInHTML(const string& dot) { +string WrapDotInHtml(absl::string_view dot) { static const char html_prefix[] = R"html( @@ -1627,6 +1461,9 @@ string WrapDotInHTML(const string& 
dot) { var css_data = '' if (results !== null) { css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field. + // CSS inside DOT is URL-escaped, so we must unescape it + // before we can insert it into SVG. + css_data = unescape(css_data); dot_data = data.replace(cssregex, ''); // Remove the stylesheet } @@ -1694,37 +1531,117 @@ string WrapDotInHTML(const string& dot) { )html"; - return html_prefix + dot + html_suffix; + return absl::StrCat(html_prefix, dot, html_suffix); } -string RenderDotAsHTMLFile(const string& dot, - const DebugOptions& debug_options) { - string html = WrapDotInHTML(dot); +tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED); +std::function(absl::string_view)>* url_renderer + GUARDED_BY(url_renderer_mu) = nullptr; - auto env = tensorflow::Env::Default(); - std::vector dirs; - string output_dir = debug_options.xla_hlo_graph_path(); - if (output_dir.empty()) { - env->GetLocalTempDirectories(&dirs); - } else { - dirs.push_back(output_dir); +// Precondition: url_renderer != nullptr. +// +// (We specify this as a precondition rather than checking it in here and +// returning an error because we want to fail quickly when there's no URL +// renderer available, and this function runs only after we've done all the work +// of producing dot for the graph.) +StatusOr WrapDotInFormat(absl::string_view dot, + RenderedGraphFormat format) + EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { + switch (format) { + case RenderedGraphFormat::kUrl: + CHECK(url_renderer != nullptr) + << "Should have checked url_renderer != null before calling."; + return (*url_renderer)(dot); + case RenderedGraphFormat::kHtml: + return WrapDotInHtml(dot); + case RenderedGraphFormat::kDot: + return string(dot); } - // Try each directory, as they might be full, have inappropriate - // permissions or have different problems at times. 
- string output; - for (const string& dir : dirs) { - string filename = tensorflow::io::JoinPath(dir, "graph-"); - if (env->CreateUniqueFileName(&filename, ".html")) { - output = filename; - break; - } +} + +} // namespace + +void RegisterGraphToURLRenderer( + std::function(absl::string_view)> renderer) { + tensorflow::mutex_lock lock(url_renderer_mu); + if (url_renderer != nullptr) { + LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call " + "wins, but because order of initialization in C++ is " + "nondeterministic, this may not be what you want."; + } + delete url_renderer; + url_renderer = new std::function(absl::string_view)>( + std::move(renderer)); +} + +StatusOr RenderGraph(const HloComputation& computation, + absl::string_view label, + const DebugOptions& debug_options, + RenderedGraphFormat format, + const HloExecutionProfile* hlo_execution_profile, + bool show_backend_config) { + tensorflow::mutex_lock lock(url_renderer_mu); + if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { + return Unavailable("Can't render as URL; no URL renderer was registered."); + } + + string rendered_dot = + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) + .Dump(); + return WrapDotInFormat(rendered_dot, format); +} + +StatusOr RenderNeighborhoodAround( + const HloInstruction& node, int radius, RenderedGraphFormat format, + bool show_backend_config, + const absl::flat_hash_set& boundary) { + tensorflow::mutex_lock lock(url_renderer_mu); + if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { + return FailedPrecondition( + "Can't render as URL; no URL renderer was registered."); } - if (output.empty()) { - LOG(FATAL) << "Failed to create unique output file name."; + + string label = + StrCat("Neighborhood of ", radius, " nodes around ", node.name()); + string rendered_dot = + HloDotDumper(node.parent(), label, + node.GetModule()->config().debug_options(), + 
show_backend_config, /*profile=*/nullptr, + MakeNodeRadiusAroundFilter(&node, radius, boundary)) + .Dump(); + return WrapDotInFormat(rendered_dot, format); +} + +StatusOr RenderAllPathsFromTo(const HloInstruction& from, + const HloInstruction& to, int64 max_nodes, + RenderedGraphFormat format, + bool show_backend_config) { + tensorflow::mutex_lock lock(url_renderer_mu); + if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { + return FailedPrecondition( + "Can't render as URL; no URL renderer was registered."); } - TF_CHECK_OK(tensorflow::WriteStringToFile(env, output, html)); - return "file://" + output; + + CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!"; + auto debug_options = from.GetModule()->config().debug_options(); + + bool hit_limit = false; + NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit); + string label; + if (!hit_limit) { + label = StrCat("All paths from ", from.name(), " to ", to.name()); + } else { + label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(), + " to ", to.name(), + "

***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN " + "NODES***

"); + } + string rendered_dot = + HloDotDumper(from.parent(), label, debug_options, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); + return WrapDotInFormat(rendered_dot, format); } -} // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 8e51454ef1cf992386cc7325e32705c08bf7712f..324ac67a6dd565f45bcb32455212ae08c925bc66 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -23,94 +23,76 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" -namespace xla { -namespace hlo_graph_dumper { - -// Abstract interface for classes that render HLO graphs (e.g. DOT graph, -// tensorflow GraphDef). -class GraphRendererInterface { - public: - enum GraphKind { - DOT_GRAPH, - TF_GRAPHDEF, - }; +// This file contains routines for rendering HLO computations into a +// human-readable graphical format. +// +// Fundamentally all graphs are rendered using the DOT language, but they can be +// packaged three different ways: +// +// - as a raw DOT file, which can be rendered using `graphviz`. +// +// - as an HTML file with an embedded DOT file, which can be viewed in a +// browser using a version of graphviz compiled to JavaScript +// +// - as a URL hosted somewhere which somehow embeds the DOT file. +// +// This last option is not implemented by default, but you can add a plugin to +// implement it via RegisterGraphToURLRenderer. +// +// TODO(jlebar): Rename this file to hlo_graph_renderer. - virtual ~GraphRendererInterface() = default; +namespace xla { - // Renders a DOT graph, returning a description of the rendered output - // (e.g., a URL) - virtual string RenderGraph(const string& graph, GraphKind graph_kind, - const DebugOptions& debug_options) = 0; +// Different formats that a graph can be packaged as. 
+enum class RenderedGraphFormat { + kDot, + kHtml, + kUrl, }; -// Dump the given HLO module if a dump is requested in its debug options. Based -// on the debug options, either a graph dump, a text dump or both may be -// generated. If a graph dump is generated, the description (e.g. an URL) is -// returned; otherwise an empty string is returned. -string MaybeDumpHloModule(const HloModule& module, const string& label, - const HloExecutionProfile* profile = nullptr); - -// Dumps a graph of the computation and returns a description of the rendered -// graph (e.g., a URL) based on the renderer. The "best" renderer in the -// registry is used. -string DumpGraph(const HloComputation& computation, const string& label, - const DebugOptions& debug_options, - const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_backend_config = false); - -// Like DumpGraph, but renders only nodes "near" the given node in the graph. +// Renders an HLO module as a human-readable visual graph. +// +// Note that this only works well for relatively small graphs (no more than a +// few hundred nodes). Beyond that, the dot is usually unrenderable, +// unreadable, or both. To view such graphs, use a tool such as +// interactive_graphviz, which calls RenderNeighborhoodAround to render subsets +// of a graph. +StatusOr RenderGraph( + const HloComputation& computation, absl::string_view label, + const DebugOptions& debug_options, RenderedGraphFormat format, + const HloExecutionProfile* hlo_execution_profile = nullptr, + bool show_backend_config = false); + +// Like RenderGraph, but renders only nodes "near" the given node in the graph. // // The number of nodes dumped is controlled by the radius parameter, which // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. 
-string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_backend_config = false); - -// Dumps nodes on any of the paths from `from` to `to`. If there are more than -// max_nodes on all paths, restricts to the max_nodes nodes on the shortest +// +// The optional boundary specifies a set of boundary nodes, beyond which nodes +// will be omitted even if they are within the radius. +StatusOr RenderNeighborhoodAround( + const HloInstruction& node, int radius, RenderedGraphFormat format, + bool show_backend_config = false, + const absl::flat_hash_set& boundary = {}); + +// Renders nodes on any of the paths from `from` to `to`. If there are more +// than max_nodes on all paths, restricts to the max_nodes nodes on the shortest // paths. -string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, - int64 max_nodes, bool show_backend_config = false); +StatusOr RenderAllPathsFromTo(const HloInstruction& from, + const HloInstruction& to, int64 max_nodes, + RenderedGraphFormat format, + bool show_backend_config = false); -// Dumps the HloModule::ToString() as a file into the provided directory path -// suffixed with the provided label. +// Registers a function which implements RenderedGraphFormat::kUrl. // -// If do_prefix is true, a timestamp will be prepended onto the label to -// construct a filename in the directory path; otherwise, the label is used -// as the filename directly. -void DumpText(const HloModule& module, const string& label, - const string& directory_path, bool do_prefix = true); - -// Renders DOT graph as inline SVG and saves it in an HTML file in a temprary -// directory or directory specified via --xla_hlo_graph_path. Returns the file -// URI pointing to the file. 
-string RenderDotAsHTMLFile(const string& dot, - const DebugOptions& debug_options); - -// Graph renderers may be added using a registration mechanism, e.g.: -// XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) -// The renderer with the highest numeric priority value is used. - -#define XLA_REGISTER_GRAPH_RENDERER(factory, ...) \ - XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, __COUNTER__, ##__VA_ARGS__) - -// Internal implementation details below this point. - -// Class that registers a graph renderer. -class Registrar { - public: - Registrar(std::shared_ptr dumper); -}; - -#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \ - static ::xla::hlo_graph_dumper::Registrar \ - XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)( \ - std::make_shared(), ##__VA_ARGS__) - -// __COUNTER__ must go through another macro to be properly expanded -#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr) ___##ctr##__object_ +// The input to the function is dot, and the output should be a URL or an error. +// +// There can only be one active renderer, and the last call to this function +// wins. +void RegisterGraphToURLRenderer( + std::function(absl::string_view dot)> renderer); -} // namespace hlo_graph_dumper } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GRAPH_DUMPER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 064c53252c0ac4d4e7b93169ad7cbee4807cb963..fa1ff49de876ea21073c09616412e535438c8a02 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -31,24 +32,13 @@ namespace { using absl::StrCat; using ::testing::HasSubstr; +using HloGraphDumperTest = HloTestBase; + string TestName() { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } -class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { - public: - string RenderGraph(const string& graph, GraphKind graph_kind, - const DebugOptions& debug_options) override { - return graph; - } - - private: - string last_graph_; -}; - -XLA_REGISTER_GRAPH_RENDERER(DotRenderer); - -TEST(HloGraphDumperTest, NestedFusion) { +TEST_F(HloGraphDumperTest, NestedFusion) { HloComputation::Builder b("b"); // Build param0 + param1 + param2 + param3 + param4. @@ -90,8 +80,9 @@ TEST(HloGraphDumperTest, NestedFusion) { {fused_sums[1], fused_sums[0]}, HloInstruction::FusionKind::kLoop); // Generate the graph; all nodes should be present. 
- string graph = hlo_graph_dumper::DumpGraph(*root_computation, /*label=*/"", - DebugOptions()); + TF_ASSERT_OK_AND_ASSIGN( + string graph, RenderGraph(*root_computation, /*label=*/"", DebugOptions(), + RenderedGraphFormat::kDot)); for (const HloComputation* computation : {root_computation, // inner_fusion->fused_instructions_computation(), @@ -113,12 +104,13 @@ TEST(HloGraphDumperTest, NestedFusion) { } } ASSERT_NE(inner_sum, nullptr); - EXPECT_THAT( - hlo_graph_dumper::DumpNeighborhoodAround(*inner_sum, /*radius=*/1), - HasSubstr(inner_sum->name())); + TF_ASSERT_OK_AND_ASSIGN(string neighborhood_graph, + RenderNeighborhoodAround(*inner_sum, /*radius=*/1, + RenderedGraphFormat::kDot)); + EXPECT_THAT(neighborhood_graph, HasSubstr(inner_sum->name())); } -TEST(HloGraphDumperTest, Constant) { +TEST_F(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(-42))); @@ -126,13 +118,14 @@ TEST(HloGraphDumperTest, Constant) { HloModuleConfig config; HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); - string graph = hlo_graph_dumper::DumpGraph( - *root_computation, /*label=*/"an_empty_graph", DebugOptions()); + TF_ASSERT_OK_AND_ASSIGN( + string graph, RenderGraph(*root_computation, /*label=*/"an_empty_graph", + DebugOptions(), RenderedGraphFormat::kDot)); EXPECT_THAT(graph, HasSubstr("an_empty_graph")); EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); } -TEST(HloGraphDumperTest, TupleConstant) { +TEST_F(HloGraphDumperTest, TupleConstant) { Shape tuple_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})}); HloComputation::Builder b("b"); @@ -144,11 +137,30 @@ TEST(HloGraphDumperTest, TupleConstant) { HloModuleConfig config; HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build(gte)); - string graph = 
hlo_graph_dumper::DumpGraph( - *root_computation, /*label=*/"tuple_constant", DebugOptions()); + TF_ASSERT_OK_AND_ASSIGN( + string graph, RenderGraph(*root_computation, /*label=*/"tuple_constant", + DebugOptions(), RenderedGraphFormat::kDot)); EXPECT_THAT(graph, HasSubstr("tuple_constant")); EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])")); } +TEST_F(HloGraphDumperTest, Compare) { + const char* hlo_string = R"( + HloModule comp + + ENTRY comp { + param.0 = f32[10] parameter(0) + param.1 = f32[10] parameter(1) + ROOT lt = pred[10] compare(param.0, param.1), direction=LT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + string graph, + RenderGraph(*module->entry_computation(), /*label=*/"tuple_constant", + DebugOptions(), RenderedGraphFormat::kDot)); + EXPECT_THAT(graph, HasSubstr("direction=LT")); +} + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc deleted file mode 100644 index 84c4cf18df69816c611f4eb159ba247320ebc20e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// Implementation of an DOT graph renderer that uses Javascript to render DOT to -// SVG in a browser. - -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -class GraphHtmlRenderer : public GraphRendererInterface { - public: - string RenderGraph(const string& graph, GraphKind graph_kind, - const DebugOptions& debug_options) override { - switch (graph_kind) { - case DOT_GRAPH: - return RenderDotAsHTMLFile(graph, debug_options); - default: - LOG(FATAL) << "Only DOT graphs can be rendered"; - } - } -}; - -XLA_REGISTER_GRAPH_RENDERER(GraphHtmlRenderer); - -} // namespace -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h index b0b71dece81b561f492767db8c1ccbe3fde442d4..cd13c7a3ac7afe03fb99ed3114bdc6ac0f8ad6a7 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -56,7 +56,8 @@ class HloInputOutputAliasConfig { HloInputOutputAliasConfig() = default; - explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} + explicit HloInputOutputAliasConfig(Shape output_shape) + : alias_(output_shape) {} virtual ~HloInputOutputAliasConfig() = default; diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc index a46a107723de30176241aae01b268a8c10d991d3..265bfdf7f989b0821a98c1f774cb408b78f348fe 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 3c92554ad4ec48686d64c74a00f732a3bfee87bc..fe8a178f80fa3469f193aca467fc1bd9a9c0c6bc 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -64,7 +64,35 @@ StatusOr> HloInstruction::CreateFromProto( const absl::flat_hash_map& instruction_map, const absl::flat_hash_map& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); - TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); + HloOpcode opcode; + auto opcode_or = StringToHloOpcode(proto.opcode()); + absl::optional comparison_direction; + if (opcode_or.ok()) { + opcode = opcode_or.ConsumeValueOrDie(); + } else { + // Unknown opcode. Try auto-upgrading deprecated "less-than", + // "greater-than", etc opcodes, which are now rolled into the kCompare + // opcode. 
+ if (proto.opcode() == "equal-to") { + comparison_direction = ComparisonDirection::kEq; + } else if (proto.opcode() == "not-equal-to") { + comparison_direction = ComparisonDirection::kNe; + } else if (proto.opcode() == "greater-than-or-equal-to") { + comparison_direction = ComparisonDirection::kGe; + } else if (proto.opcode() == "greater-than") { + comparison_direction = ComparisonDirection::kGt; + } else if (proto.opcode() == "less-than-or-equal-to") { + comparison_direction = ComparisonDirection::kLe; + } else if (proto.opcode() == "less-than") { + comparison_direction = ComparisonDirection::kLt; + } + if (comparison_direction) { + opcode = HloOpcode::kCompare; + } else { + return InvalidArgument("Unknown opcode: %s", proto.opcode()); + } + } + TF_RET_CHECK(proto.has_shape()); std::unique_ptr instruction; @@ -82,6 +110,15 @@ StatusOr> HloInstruction::CreateFromProto( const auto computations = [&computation_map, &proto](int index) { return computation_map.at(proto.called_computation_ids(index)); }; + const auto all_computations = [&computation_map, &proto]() { + std::vector result(proto.called_computation_ids_size()); + std::transform(proto.called_computation_ids().begin(), + proto.called_computation_ids().end(), result.begin(), + [&computation_map](int64 computation_id) { + return computation_map.at(computation_id); + }); + return result; + }; TF_RET_CHECK( absl::c_all_of(proto.operand_ids(), @@ -96,72 +133,73 @@ StatusOr> HloInstruction::CreateFromProto( Shape shape(proto.shape()); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + absl::optional arity = HloOpcodeArity(opcode); + if (arity) { + TF_RET_CHECK(proto.operand_ids_size() == *arity) + << proto.opcode() << " instruction should have " << *arity + << " operands but sees " << proto.operand_ids_size(); + } + switch (opcode) { // Ops migrated to subclasses. 
case HloOpcode::kBatchNormTraining: - TF_RET_CHECK(proto.operand_ids_size() == 3) - << "BatchNormTraining instruction should have 3 operands but sees " - << proto.operand_ids_size(); instruction = CreateBatchNormTraining(shape, operands(0), operands(1), operands(2), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormInference: - TF_RET_CHECK(proto.operand_ids_size() == 5) - << "BatchNormInference instruction should have 5 operands but sees " - << proto.operand_ids_size(); instruction = CreateBatchNormInference( shape, operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormGrad: - TF_RET_CHECK(proto.operand_ids_size() == 5) - << "BatchNormGrad instruction should have 5 operands but sees " - << proto.operand_ids_size(); instruction = CreateBatchNormGrad(shape, operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kFft: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Fft instruction should have 1 operand but sees " - << proto.operand_ids_size(); std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); instruction = CreateFft(shape, operands(0), proto.fft_type(), absl::Span(fft_length)); break; } + case HloOpcode::kCompare: { + // Auto-upgraded from deprecated opcode skips the following. 
+ if (!comparison_direction) { + TF_ASSIGN_OR_RETURN( + comparison_direction, + StringToComparisonDirection(proto.comparison_direction())); + } + instruction = + CreateCompare(shape, operands(0), operands(1), *comparison_direction); + break; + } + case HloOpcode::kTriangularSolve: { + instruction = CreateTriangularSolve(shape, operands(0), operands(1), + proto.triangular_solve_options()); + break; + } + case HloOpcode::kCholesky: { + instruction = + CreateCholesky(shape, operands(0), proto.cholesky_options()); + break; + } case HloOpcode::kSend: - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Send instruction should have 2 operand but sees " - << proto.operand_ids_size(); instruction = CreateSend(operands(0), operands(1), proto.channel_id(), proto.is_host_transfer()); break; case HloOpcode::kSendDone: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "SendDone instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateSendDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kRecv: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Recv instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateRecv(shape.tuple_shapes(0), operands(0), proto.channel_id(), proto.is_host_transfer()); break; case HloOpcode::kRecvDone: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "RecvDone instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateRecvDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kReverse: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Reverse instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateReverse(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); @@ -173,6 +211,26 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateConcatenate(shape, all_operands(), proto.dimensions(0)); break; + case HloOpcode::kConditional: 
{ + TF_RET_CHECK(proto.called_computation_ids_size() > 0) + << "conditional should have at least 1 called computation"; + if (operands(0)->shape().element_type() == PRED) { + TF_RET_CHECK(proto.called_computation_ids_size() == 2) + << "conditional should have exactly 2 called computations but got " + << proto.called_computation_ids_size(); + } + TF_RET_CHECK(proto.operand_ids_size() == + proto.called_computation_ids_size() + 1) + << "conditional should have one branch_index operand plus one " + "operand per called computation but got " + << proto.operand_ids_size() << " operands for " + << proto.called_computation_ids_size() << " branch computations"; + auto cond_operands = all_operands(); + instruction = + CreateConditional(shape, cond_operands[0], all_computations(), + absl::MakeSpan(cond_operands).subspan(1)); + break; + } case HloOpcode::kReduce: TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) << "Reduce instruction should have an even number of operands but " @@ -201,26 +259,21 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); TF_RET_CHECK(proto.dimensions().size() == 1) << "Sort instruction should have 1 dimension"; + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Sort instruction should one called computation but sees " + << proto.called_computation_ids_size(); auto sort_operands = all_operands(); - HloInstruction* keys = sort_operands[0]; - instruction = CreateSort( - shape, proto.dimensions(0), keys, - absl::Span(sort_operands).subspan(1)); + instruction = CreateSort(shape, proto.dimensions(0), all_operands(), + computations(0), proto.is_stable()); break; } case HloOpcode::kTranspose: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Transpose instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateTranspose(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; case HloOpcode::kBroadcast: - TF_RET_CHECK(proto.operand_ids_size() == 1) - 
<< "Broadcast instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateBroadcast(shape, operands(0), std::vector(proto.dimensions().begin(), @@ -233,9 +286,6 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateMap(shape, all_operands(), computations(0)); break; case HloOpcode::kSlice: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Slice instruction should have 1 operand but sees " - << proto.operand_ids_size(); std::vector slice_starts, slice_limits, slice_strides; for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { @@ -259,9 +309,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kTrace: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Trace instruction should have 1 operand but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); @@ -295,18 +342,16 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kParameter: instruction = CreateParameter(proto.parameter_number(), shape, proto.name()); + if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) { + instruction->set_parameter_replicated_at_leaf_buffers( + proto.parameter_replication().replicated_at_leaf_buffers()); + } break; case HloOpcode::kGetTupleElement: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "GetTupleElement instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateGetTupleElement(shape, operands(0), proto.tuple_index()); break; case HloOpcode::kReducePrecision: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "ReducePrecision instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateReducePrecision( shape, operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; @@ -316,16 +361,10 @@ StatusOr> HloInstruction::CreateFromProto( << "Infeed should have a tuple 
shape with 2 operands, but has: " << shape; const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0); - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Infeed instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Outfeed instruction should have 2 operands but sees " - << proto.operand_ids_size(); Shape outfeed_shape(proto.outfeed_shape()); TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape)); @@ -359,9 +398,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kCollectivePermute: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "CollectivePermute instruction should have 1 operand but sees " - << proto.operand_ids_size(); std::vector> source_target_pairs( proto.source_target_pairs_size()); for (int i = 0; i < source_target_pairs.size(); i++) { @@ -372,10 +408,11 @@ StatusOr> HloInstruction::CreateFromProto( CreateCollectivePermute(shape, operands(0), source_target_pairs); break; } + case HloOpcode::kReplicaId: { + instruction = CreateReplicaId(); + break; + } case HloOpcode::kConvolution: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Convolution instruction should have 2 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); PrecisionConfig precision_config = proto.precision_config(); @@ -389,9 +426,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kReduceWindow: - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "ReduceWindow instruction should have 2 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "ReduceWindow should have 1 called computation but sees " << proto.called_computation_ids_size(); @@ -399,9 +433,6 @@ StatusOr> 
HloInstruction::CreateFromProto( proto.window(), computations(0)); break; case HloOpcode::kSelectAndScatter: - TF_RET_CHECK(proto.operand_ids_size() == 3) - << "SelectAndScatter instruction should have 3 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.called_computation_ids_size() == 2) << "SelectAndScatter should have 2 called computations but sees " << proto.called_computation_ids_size(); @@ -444,9 +475,6 @@ StatusOr> HloInstruction::CreateFromProto( std::max(static_cast(proto.batch_group_count()), 1LL)); break; case HloOpcode::kPad: - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Pad instruction should have 2 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_padding_config()); instruction = CreatePad(shape, operands(0), operands(1), proto.padding_config()); @@ -492,9 +520,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kGather: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Gather instruction should have 2 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_gather_dimension_numbers()) << "Gather instruction should have GatherDimensionNumbers set."; std::unique_ptr gather_dimension_numbers = @@ -509,9 +534,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kScatter: { - TF_RET_CHECK(proto.operand_ids_size() == 3) - << "Scatter instruction should have 3 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_scatter_dimension_numbers()) << "Scatter instruction should have ScatterDimensionNumbers set."; TF_RET_CHECK(proto.called_computation_ids_size() == 1) @@ -533,9 +555,6 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kDot: { TF_RET_CHECK(proto.has_dot_dimension_numbers()) << "Dot instruction should have dot_dimension_numbers."; - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Dot instruction should have 2 operands but sees " - << proto.operand_ids_size(); PrecisionConfig precision_config = 
proto.precision_config(); precision_config.mutable_operand_precision()->Resize( proto.operand_ids_size(), PrecisionConfig::DEFAULT); @@ -545,9 +564,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kDomain: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Domain instruction should have 1 operands but sees " - << proto.operand_ids_size(); std::shared_ptr entry_hlo_sharding; std::shared_ptr exit_hlo_sharding; if (proto.has_domain_entry_sharding()) { @@ -569,7 +585,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kGetDimensionSize: - TF_RET_CHECK(proto.operand_ids_size() == 1); TF_RET_CHECK(proto.dimensions_size() == 1); instruction = CreateGetDimensionSize(shape, operands(0), proto.dimensions(0)); @@ -689,8 +704,10 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: break; default: @@ -710,15 +727,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kComplex: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: @@ -783,6 +794,24 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, fft_length); } +/* static */ std::unique_ptr HloInstruction::CreateCompare( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + ComparisonDirection direction) { + return absl::make_unique(shape, lhs, rhs, direction); +} + +/* static */ std::unique_ptr +HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a, + HloInstruction* b, + const TriangularSolveOptions& options) { + return 
absl::make_unique(shape, a, b, options); +} + +/* static */ std::unique_ptr HloInstruction::CreateCholesky( + const Shape& shape, HloInstruction* a, const CholeskyOptions& options) { + return absl::make_unique(shape, a, options); +} + /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, @@ -825,6 +854,11 @@ HloInstruction::CreateCollectivePermute( shape, operand, source_target_pairs); } +/* static */ std::unique_ptr HloInstruction::CreateReplicaId() { + return absl::WrapUnique( + new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {}))); +} + /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { @@ -932,6 +966,21 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateConditional( + const Shape& shape, HloInstruction* branch_index, + absl::Span branch_computations, + absl::Span branch_computation_args) { + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); + instruction->AppendOperand(branch_index); + CHECK_EQ(branch_computations.size(), branch_computation_args.size()); + for (int i = 0; i < branch_computations.size(); ++i) { + instruction->called_computations_.push_back(branch_computations[i]); + instruction->AppendOperand(branch_computation_args[i]); + } + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateSlice( const Shape& shape, HloInstruction* operand, absl::Span start_indices, @@ -1141,9 +1190,11 @@ HloInstruction::CreateBroadcastSequence( } /* static */ std::unique_ptr HloInstruction::CreateSort( - const Shape& shape, int64 dimension, HloInstruction* keys, - absl::Span values) { - return absl::make_unique(shape, dimension, keys, values); + const Shape& shape, int64 dimension, + absl::Span 
operands, HloComputation* compare, + bool is_stable) { + return absl::make_unique(shape, dimension, operands, + compare, is_stable); } /* static */ std::unique_ptr HloInstruction::CreateFusion( @@ -1299,6 +1350,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kFft: + case HloOpcode::kCompare: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kRecv: @@ -1335,6 +1387,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kDot: case HloOpcode::kDomain: case HloOpcode::kGetDimensionSize: + case HloOpcode::kTriangularSolve: + case HloOpcode::kCholesky: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1355,8 +1409,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1368,12 +1424,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kDivide: case HloOpcode::kMultiply: case HloOpcode::kSubtract: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1425,10 +1475,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CreateWhile(shape, while_condition(), while_body(), new_operands[0]); break; case HloOpcode::kConditional: - CHECK_EQ(new_operands.size(), 3); - clone = CreateConditional(shape, new_operands[0], new_operands[1], - true_computation(), new_operands[2], - false_computation()); + CHECK_EQ(new_operands.size(), branch_count() + 1); + clone = CreateConditional(shape, new_operands[0], + absl::MakeSpan(branch_computations()), + 
new_operands.subspan(1)); break; case HloOpcode::kAfterAll: if (new_operands.empty()) { @@ -1441,6 +1491,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 2); clone = CreateAddDependency(new_operands[0], new_operands[1]); break; + case HloOpcode::kReplicaId: + CHECK_EQ(new_operands.size(), 0); + clone = CreateReplicaId(); + break; } // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); @@ -1685,38 +1739,35 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kIsFinite: - case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLog1p: case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: case HloOpcode::kXor: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: case HloOpcode::kReshape: + case HloOpcode::kReplicaId: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRsqrt: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: @@ -1732,16 +1783,16 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); case HloOpcode::kConditional: - return eq_computations(true_computation(), other.true_computation()) && - eq_computations(false_computation(), other.false_computation()); - - case HloOpcode::kWhile: { - if (eq_computations(while_body(), other.while_body()) && - 
eq_computations(while_condition(), other.while_condition())) { - return true; + for (int j = 0; j < branch_count(); ++j) { + if (!eq_computations(branch_computation(j), + other.branch_computation(j))) { + return false; + } } - return false; - } + return true; + case HloOpcode::kWhile: + return (eq_computations(while_body(), other.while_body()) && + eq_computations(while_condition(), other.while_condition())); // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. @@ -1749,6 +1800,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kFft: + case HloOpcode::kCompare: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kRecv: @@ -1785,6 +1837,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDot: case HloOpcode::kDomain: case HloOpcode::kGetDimensionSize: + case HloOpcode::kTriangularSolve: + case HloOpcode::kCholesky: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1839,7 +1893,11 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); + return ReplaceUseWithDifferentShape(user, new_producer); +} +Status HloInstruction::ReplaceUseWithDifferentShape( + HloInstruction* user, HloInstruction* new_producer) { VLOG(3) << "Replacing uses of " << name() << " in " << user->name() << " with " << new_producer->name(); @@ -1935,6 +1993,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kReduce: case HloOpcode::kAllReduce: case HloOpcode::kScatter: + case HloOpcode::kSort: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -1954,6 +2013,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kReduce: case HloOpcode::kAllReduce: case 
HloOpcode::kScatter: + case HloOpcode::kSort: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -1996,28 +2056,41 @@ HloInstruction* HloInstruction::while_init() const { HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); + CHECK_EQ(PRED, operand(0)->shape().element_type()); return called_computations_[kTrueComputationIndex]; } HloComputation* HloInstruction::false_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); + CHECK_EQ(PRED, operand(0)->shape().element_type()); return called_computations_[kFalseComputationIndex]; } -void HloInstruction::set_true_computation(HloComputation* true_computation) { - // Don't allow changing the computation for fused instructions so we don't - // have to recompute called_instructions for the entire fusion instruction. - CHECK(!IsFused()); - CHECK_EQ(HloOpcode::kConditional, opcode_); - called_computations_[kTrueComputationIndex] = true_computation; +const std::vector& HloInstruction::branch_computations() + const { + CHECK(HloOpcode::kConditional == opcode_); + return called_computations_; +} + +int HloInstruction::branch_count() const { + CHECK(HloOpcode::kConditional == opcode_); + return called_computations_.size(); +} + +HloComputation* HloInstruction::branch_computation(int b) const { + CHECK(HloOpcode::kConditional == opcode_); + CHECK_GE(b, 0); + CHECK_LT(b, called_computations_.size()); + return called_computations_[b]; } -void HloInstruction::set_false_computation(HloComputation* false_computation) { +void HloInstruction::set_branch_computation(int b, + HloComputation* computation) { // Don't allow changing the computation for fused instructions so we don't // have to recompute called_instructions for the entire fusion instruction. 
CHECK(!IsFused()); CHECK_EQ(HloOpcode::kConditional, opcode_); - called_computations_[kFalseComputationIndex] = false_computation; + called_computations_[b] = computation; } string HloInstruction::SignatureString() const { @@ -2064,8 +2137,10 @@ bool HloInstruction::IsElementwiseImpl( case HloOpcode::kNegate: case HloOpcode::kReal: case HloOpcode::kReducePrecision: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: CHECK_EQ(1, operand_count()); return true; @@ -2073,17 +2148,12 @@ bool HloInstruction::IsElementwiseImpl( // Binary elementwise operations, the same as in IsElementwiseBinary(). case HloOpcode::kAdd: case HloOpcode::kAtan2: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: @@ -2218,15 +2288,27 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("scatter=", PrintName(scatter()->name(), options))); } else if (opcode() == HloOpcode::kConditional) { - extra.push_back(StrCat("true_computation=", - PrintName(true_computation()->name(), options))); - extra.push_back(StrCat("false_computation=", - PrintName(false_computation()->name(), options))); + if (operand(0)->shape().element_type() == PRED) { + extra.push_back(StrCat("true_computation=", + PrintName(true_computation()->name(), options))); + extra.push_back( + StrCat("false_computation=", + PrintName(false_computation()->name(), options))); + } else { + extra.push_back(StrCat( + "branch_computations={", + StrJoin(branch_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, PrintName(computation->name(), options)); + }), + "}")); + } } else if (opcode() 
== HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kAllReduce || - opcode() == HloOpcode::kScatter) { + opcode() == HloOpcode::kScatter || + opcode() == HloOpcode::kSort) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2252,10 +2334,20 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options))); break; case HloOpcode::kConditional: - extra.push_back(StrCat("true_computation=\n", - true_computation()->ToString(new_options))); - extra.push_back(StrCat("false_computation=\n", - false_computation()->ToString(new_options))); + if (operand(0)->shape().element_type() == PRED) { + extra.push_back(StrCat("true_computation=\n", + true_computation()->ToString(new_options))); + extra.push_back(StrCat("false_computation=\n", + false_computation()->ToString(new_options))); + } else { + extra.push_back(StrCat( + "branch_computations={\n", + StrJoin(branch_computations(), ",\n", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }), + "\n}")); + } break; case HloOpcode::kCall: case HloOpcode::kMap: @@ -2263,6 +2355,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kReduce: case HloOpcode::kAllReduce: case HloOpcode::kScatter: + case HloOpcode::kSort: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2403,12 +2496,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleGetTupleElement(this); case HloOpcode::kParameter: return visitor->HandleParameter(this); - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: + case HloOpcode::kCompare: return visitor->HandleCompare(this); case HloOpcode::kComplex: 
return visitor->HandleComplex(this); @@ -2464,6 +2552,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleAllToAll(this); case HloOpcode::kCollectivePermute: return visitor->HandleCollectivePermute(this); + case HloOpcode::kReplicaId: + return visitor->HandleReplicaId(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2498,6 +2588,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleCos(this); case HloOpcode::kSin: return visitor->HandleSin(this); + case HloOpcode::kSqrt: + return visitor->HandleSqrt(this); + case HloOpcode::kRsqrt: + return visitor->HandleRsqrt(this); case HloOpcode::kReal: return visitor->HandleReal(this); case HloOpcode::kImag: @@ -2566,6 +2660,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleIota(this); case HloOpcode::kGetDimensionSize: return visitor->HandleGetDimensionSize(this); + case HloOpcode::kTriangularSolve: + return visitor->HandleTriangularSolve(this); + case HloOpcode::kCholesky: + return visitor->HandleCholesky(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2756,11 +2854,6 @@ bool HloInstruction::IsElementwise() const { return IsElementwiseImpl(absl::nullopt); } -bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { - CHECK(IsElementwise()); - return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); -} - bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { return IsElementwiseImpl(operand_idx); } @@ -2870,9 +2963,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { } return UseKind::kReuse; default: - return IsElementwise() && !ImplicitlyBroadcastsOperand(i) - ? UseKind::kUse - : UseKind::kReuse; + return IsElementwise() ? 
UseKind::kUse : UseKind::kReuse; } } @@ -3277,6 +3368,19 @@ int64 HloInstruction::parameter_number() const { return Cast(this)->parameter_number(); } +void HloInstruction::set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers) { + return Cast(this) + ->set_parameter_replicated_at_leaf_buffers( + parameter_replicated_at_leaf_buffers); +} + +const absl::optional>& +HloInstruction::parameter_replicated_at_leaf_buffers() const { + return Cast(this) + ->parameter_replicated_at_leaf_buffers(); +} + int64 HloInstruction::tuple_index() const { return Cast(this)->tuple_index(); } @@ -3433,4 +3537,17 @@ const DomainMetadata& HloInstruction::operand_side_metadata() const { const DomainMetadata& HloInstruction::user_side_metadata() const { return Cast(this)->user_side_metadata(); } + +ComparisonDirection HloInstruction::comparison_direction() const { + return Cast(this)->direction(); +} + +const TriangularSolveOptions& HloInstruction::triangular_solve_options() const { + return Cast(this)->triangular_solve_options(); +} + +const CholeskyOptions& HloInstruction::cholesky_options() const { + return Cast(this)->cholesky_options(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 2c29b6c243bffccc346af12277dd4fc061250cbe..6f6a1b8505ede690ca3926db79be7bdd1f34c51a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -47,6 +48,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -384,6 +386,14 @@ class HloInstruction { // Creates a random number generation instruction that fills a shape with // random numbers from a given distribution. + // + // The parameters to the instruction are interpreted as follows: + // + // - If `distribution` is RNG_UNIFORM, generates a number in range + // [param0, param1). + // + // - If `distribution` is RNG_NORMAL, generates a normally-distributed value + // with mean `param0` and standard deviation `param1`. static std::unique_ptr CreateRng( const Shape& shape, RandomDistribution distribution, absl::Span parameters); @@ -435,6 +445,18 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, FftType fft_type, absl::Span fft_length); + // Creates a compare op, performing the comparison specified in direction. + static std::unique_ptr CreateCompare( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + ComparisonDirection direction); + + static std::unique_ptr CreateTriangularSolve( + const Shape& shape, HloInstruction* a, HloInstruction* b, + const TriangularSolveOptions& options); + + static std::unique_ptr CreateCholesky( + const Shape& shape, HloInstruction* a, const CholeskyOptions& options); + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch // dimensions specified in 'dimension_numbers'. static std::unique_ptr CreateDot( @@ -489,11 +511,14 @@ class HloInstruction { // Data is sent/received according to the (source_replica_id, // target_replica_id) pairs in `source_target_pairs`. 
If a replica id is not a // target_replica_id in any pair, the output on that replica is a tensor - // conssits of 0(s) in `shape`. + // consists of 0(s) in `shape`. static std::unique_ptr CreateCollectivePermute( const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs); + // Creates an instruction that returns a U32 replica ID. + static std::unique_ptr CreateReplicaId(); + // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr CreateConvert(const Shape& shape, @@ -595,7 +620,6 @@ class HloInstruction { // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1, // ..., inputN.value1) // ... - // TODO(b/112040122): Add support to this in HLO passes and in backends. static std::unique_ptr CreateReduce( const Shape& shape, absl::Span operands, absl::Span init_values, @@ -668,10 +692,15 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, absl::Span dimensions); - // Creates a sort op, with a keys operand, and optional values operands. + // Creates a n-ary sort op with a 'compare' computation which is used for + // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters, + // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at + // specific index positions which should be compared, and should return a + // PRED. 'is_stable' specifies whether stable sorting is required. static std::unique_ptr CreateSort( - const Shape& shape, int64 dimension, HloInstruction* keys, - absl::Span values = {}); + const Shape& shape, int64 dimension, + absl::Span operands, HloComputation* compare, + bool is_stable); // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. 
For @@ -688,6 +717,11 @@ class HloInstruction { HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation); + static std::unique_ptr CreateConditional( + const Shape& shape, HloInstruction* branch_index, + absl::Span branch_computations, + absl::Span branch_computation_args); + static std::unique_ptr CreateGather( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, @@ -929,6 +963,10 @@ class HloInstruction { // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); + // Same as ReplaceUseWith(), but new_producer can have a different shape. + Status ReplaceUseWithDifferentShape(HloInstruction* user, + HloInstruction* new_producer); + // Replaces the specified operand with new_operand. The old and new operands // must have compatible shapes ignoring floating-point precision. // @@ -1030,14 +1068,23 @@ class HloInstruction { HloInstruction* while_init() const; - // Gets/sets the true and false HloComputation for Conditional. The setters - // should only be called by HloModule or HloComputation methods. + // Gets/sets the true and false HloComputation for Conditional. // - // Precondition: The instruction is a Conditional instruction. + // Precondition: The instruction is a predicated Conditional instruction. HloComputation* true_computation() const; HloComputation* false_computation() const; - void set_true_computation(HloComputation* true_computation); - void set_false_computation(HloComputation* false_computation); + + // Gets the branch HloComputations for Conditional. + // + // Precondition: The instruction is a Conditional instruction. + const std::vector& branch_computations() const; + int branch_count() const; + HloComputation* branch_computation(int b) const; + // Sets a branch HloComputation for Conditional. 
+ // The setter should only be called by HloModule or HloComputation methods. + // + // Precondition: The instruction is a Conditional instruction. + void set_branch_computation(int b, HloComputation* computation); // Returns a string for the signature of this instruction if considered as a // function, e.g. the signature of an F32 add is (F32, F32) -> F32. @@ -1179,10 +1226,8 @@ class HloInstruction { // Returns true if this instruction performs an elementwise operation on // `operand_idx`-th operand. An instruction is elementwise on an operand iff, - // after performing necessary implicit broadcast - // (cs/IrArray::EmitArrayElementAddress), to compute the output at index - // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is - // the element at {i_0,i_1,...,i_n}. + // to compute the output at index {i_0,i_1,...,i_n}, the only element required + // from the operand (if any) is the element at {i_0,i_1,...,i_n}. // // Note on performance: when this instruction is kFusion, this method, in the // worst case, scans all fused instructions. We could speed this up by @@ -1198,12 +1243,6 @@ class HloInstruction { // Returns true if this is a cross-replica all-reduce instruction. bool IsCrossReplicaAllReduce() const; - // Returns true if this elementwise instruction implicitly broadcasts operand - // `operand_idx`. - // - // Precondition: this instruction should be an elementwise operation. - bool ImplicitlyBroadcastsOperand(int64 operand_idx) const; - // Returns true if this instruction is binary and elementwise. bool IsElementwiseBinary() const; @@ -1239,6 +1278,10 @@ class HloInstruction { // on the instruction's existing name. void UniquifyName(NameUniquer* name_uniquer); + // Clear the unique ID of the instruction so that it can be re-assigned, such + // as for the purpose of compacting the instruction unique IDs. 
+ void ClearUniqueIdInternal() { unique_id_ = -1; } + // Set the unique id for this instruction to "id" void SetUniqueId(int id) { CHECK_EQ(unique_id_, -1); // Should not be assigned already @@ -1272,6 +1315,9 @@ class HloInstruction { backend_config_ = std::move(config_str); } + bool is_default_config() const { return is_default_config_; } + void set_default_config() { is_default_config_ = true; } + // Returns a string representation of a proto in the format used by // raw_backend_config_string. // @@ -1442,6 +1488,15 @@ class HloInstruction { // Delegates to HloParameterInstruction::parameter_number. int64 parameter_number() const; + // Delegates to + // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers. + void set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers); + + // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers. + const absl::optional>& + parameter_replicated_at_leaf_buffers() const; + // Delegates to HloGetTupleElementInstruction::tuple_index. int64 tuple_index() const; @@ -1551,6 +1606,15 @@ class HloInstruction { // Delegates to HloDomainInstruction::user_side_metadata(). const DomainMetadata& user_side_metadata() const; + // Delegates to HloCompareInstruction::direction(). + ComparisonDirection comparison_direction() const; + + // Delegates to HloTriangularSolveInstruction::triangular_solve_options(). + const TriangularSolveOptions& triangular_solve_options() const; + + // Delegates to HloCholeskyInstruction::cholesky_options(). + const CholeskyOptions& cholesky_options() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1717,6 +1781,10 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // This field is assigned to true when backend_config_ is assigned to + // a default configuration. + bool is_default_config_ = false; + // String identifier for instruction. 
string name_; @@ -1730,6 +1798,10 @@ class HloInstruction { TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); }; +// Explicit instantiations in hlo_instruction.cc. +extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool); +extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool); + string ToString(HloInstruction::FusionKind kind); StatusOr StringToFusionKind( const string& kind_name); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 35f031f29a7aca8db7ebe2fbcfdcebb7a778d703..85f2ddba8d303a74b8b72c97dd99952a5a57bcb0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1655,7 +1655,7 @@ body (bparam: s32[]) -> s32[] { condition (cparam: s32[]) -> pred[] { xconstant = s32[] constant(5) cparam = s32[] parameter(0) - ROOT greater-than = pred[] greater-than(xconstant, cparam) + ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT } ENTRY entry (param: s32[]) -> s32[] { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 3a0d71dd88b6f16eb5b8492f87ac6a5136584424..41b4ba2138061de7e31037d1f9ae49ec9137fff4 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/platform/protobuf.h" namespace xla { namespace { @@ -201,6 +202,164 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( fft_length_); } +HloCompareInstruction::HloCompareInstruction(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs, + ComparisonDirection direction) + : HloInstruction(HloOpcode::kCompare, shape), direction_(direction) { + AppendOperand(lhs); + AppendOperand(rhs); +} + +HloInstructionProto HloCompareInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_comparison_direction(ComparisonDirectionToString(direction_)); + return proto; +} + +std::vector HloCompareInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("direction=", ComparisonDirectionToString(direction()))}; +} + +bool HloCompareInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return direction() == casted_other.direction(); +} + +std::unique_ptr HloCompareInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique(shape, new_operands[0], + new_operands[1], direction()); +} + +namespace { + +// Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector +// of "key=value" attribute strings generically, using protocol buffer +// reflection. +// +// Currently implements a small subset of cases; feel free to add more as +// needed. 
+std::vector AttributeProtoToStringVector( + const tensorflow::protobuf::Message& message) { + const tensorflow::protobuf::Reflection* reflection = message.GetReflection(); + std::vector fields; + reflection->ListFields(message, &fields); + + std::vector output; + for (const tensorflow::protobuf::FieldDescriptor* field : fields) { + string s = absl::StrCat(field->name(), "="); + CHECK(!field->is_repeated()) << "Repeated fields aren't implemented"; + switch (field->type()) { + case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: { + bool val = reflection->GetBool(message, field); + absl::StrAppend(&s, val ? "true" : "false"); + break; + } + case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: { + const tensorflow::protobuf::EnumValueDescriptor* evd = + reflection->GetEnum(message, field); + absl::StrAppend(&s, evd->name()); + break; + } + default: + LOG(FATAL) << "Unimplemented field type: " << field->DebugString(); + } + output.push_back(std::move(s)); + } + return output; +} + +} // namespace + +HloTriangularSolveInstruction::HloTriangularSolveInstruction( + const Shape& shape, HloInstruction* a, HloInstruction* b, + const TriangularSolveOptions& options) + : HloInstruction(HloOpcode::kTriangularSolve, shape), + triangular_solve_options_(options) { + AppendOperand(a); + AppendOperand(b); +} + +HloInstructionProto HloTriangularSolveInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_triangular_solve_options() = triangular_solve_options_; + return proto; +} + +std::vector HloTriangularSolveInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return AttributeProtoToStringVector(triangular_solve_options_); +} + +bool HloTriangularSolveInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + const auto& options = triangular_solve_options(); + const auto& other_options = 
casted_other.triangular_solve_options(); + + return options.left_side() == other_options.left_side() && + options.lower() == other_options.lower() && + options.unit_diagonal() == other_options.unit_diagonal() && + options.transpose_a() == other_options.transpose_a(); +} + +std::unique_ptr +HloTriangularSolveInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique( + shape, new_operands[0], new_operands[1], triangular_solve_options()); +} + +HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape, + HloInstruction* a, + const CholeskyOptions& options) + : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) { + AppendOperand(a); +} + +HloInstructionProto HloCholeskyInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_cholesky_options() = cholesky_options_; + return proto; +} + +std::vector HloCholeskyInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return AttributeProtoToStringVector(cholesky_options_); +} + +bool HloCholeskyInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + const auto& options = cholesky_options(); + const auto& other_options = casted_other.cholesky_options(); + + return options.lower() == other_options.lower(); +} + +std::unique_ptr +HloCholeskyInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique(shape, new_operands[0], + cholesky_options()); +} + HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, int64 channel_id, @@ -383,6 +542,15 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { return proto; } +bool 
HloAllReduceInstruction::IsNoop() const { + for (auto replica_group : replica_groups()) { + if (replica_group.replica_ids().size() != 1) { + return false; + } + } + return !all_reduce_id(); +} + std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { std::vector result = @@ -600,14 +768,17 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( dimensions(), to_apply()); } -HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, - HloInstruction* keys, - absl::Span values) - : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { - AppendOperand(keys); - for (auto* value : values) { +HloSortInstruction::HloSortInstruction( + const Shape& shape, int64 dimension, + absl::Span operands, HloComputation* compare, + bool is_stable) + : HloInstruction(HloOpcode::kSort, shape), + dimensions_({dimension}), + is_stable_(is_stable) { + for (auto* value : operands) { AppendOperand(value); } + AppendComputation(compare); } HloInstructionProto HloSortInstruction::ToProto() const { @@ -615,12 +786,18 @@ HloInstructionProto HloSortInstruction::ToProto() const { for (int64 dimension : dimensions_) { proto.add_dimensions(dimension); } + proto.set_is_stable(is_stable()); return proto; } std::vector HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; + std::vector attrs; + attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}")); + if (is_stable()) { + attrs.push_back("is_stable=true"); + } + return attrs; } bool HloSortInstruction::IdenticalSlowPath( @@ -628,15 +805,20 @@ bool HloSortInstruction::IdenticalSlowPath( const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return dimensions() == casted_other.dimensions(); + if (dimensions() != casted_other.dimensions()) { + return false; + } + if (is_stable() != 
casted_other.is_stable()) { + return false; + } + return eq_computations(to_apply(), other.to_apply()); } std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - HloInstruction* keys = new_operands[0]; - return absl::make_unique(shape, dimensions(0), keys, - new_operands.subspan(1)); + return absl::make_unique( + shape, dimensions(0), new_operands, to_apply(), is_stable()); } HloTransposeInstruction::HloTransposeInstruction( @@ -1464,9 +1646,30 @@ HloParameterInstruction::HloParameterInstruction(int64 parameter_number, HloInstructionProto HloParameterInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_parameter_number(parameter_number_); + if (parameter_replicated_at_leaf_buffers_) { + for (bool replicated : *parameter_replicated_at_leaf_buffers_) { + proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers( + replicated); + } + } return proto; } +std::vector HloParameterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + if (!parameter_replicated_at_leaf_buffers_) { + return result; + } + std::vector buffers_replicated_strs; + for (bool replicated : *parameter_replicated_at_leaf_buffers_) { + buffers_replicated_strs.push_back(replicated ? 
"true" : "false"); + } + result.push_back(StrCat("parameter_replication={", + StrJoin(buffers_replicated_strs, ","), "}")); + return result; +} + string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -1686,6 +1889,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; proto.set_feature_group_count(feature_group_count_); + proto.set_batch_group_count(batch_group_count_); *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1723,6 +1927,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( if (feature_group_count_ != other.feature_group_count()) { return false; } + if (batch_group_count_ != other.batch_group_count()) { + return false; + } return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), @@ -1841,6 +2048,7 @@ HloCustomCallInstruction::HloCustomCallInstruction( custom_call_target_(custom_call_target.begin(), custom_call_target.end()), opaque_(opaque.begin(), opaque.end()), feature_group_count_(1), + batch_group_count_(1), layout_constrained_(false) { for (auto operand : operands) { AppendOperand(operand); @@ -1855,6 +2063,7 @@ HloCustomCallInstruction::HloCustomCallInstruction( custom_call_target_(custom_call_target.begin(), custom_call_target.end()), opaque_(opaque.begin(), opaque.end()), feature_group_count_(1), + batch_group_count_(1), layout_constrained_(true), operand_shapes_with_layout_(operand_shapes_with_layout.begin(), operand_shapes_with_layout.end()) { @@ -1875,6 +2084,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { proto.set_custom_call_target(custom_call_target_); proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); + proto.set_batch_group_count(batch_group_count_); if (layout_constrained()) { 
proto.set_constrain_layout(true); for (const Shape& shape : operand_shapes_with_layout_) { @@ -1898,6 +2108,9 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( if (feature_group_count_ != 1) { extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + if (batch_group_count_ != 1) { + extra.push_back(StrCat("batch_group_count=", batch_group_count_)); + } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. @@ -1941,6 +2154,20 @@ bool HloCustomCallInstruction::IdenticalSlowPath( if (feature_group_count_ != casted_other.feature_group_count_) { return false; } + if (batch_group_count_ != casted_other.batch_group_count_) { + return false; + } + if (layout_constrained() != casted_other.layout_constrained()) { + return false; + } + if (layout_constrained()) { + for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) { + if (!ShapeUtil::Equal(operand_shapes_with_layout_[i], + casted_other.operand_shapes_with_layout_[i])) { + return false; + } + } + } return custom_call_target_ == casted_other.custom_call_target_ && opaque_ == casted_other.opaque_; } @@ -1951,6 +2178,10 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { auto cloned = absl::make_unique( shape, new_operands, custom_call_target(), opaque()); + if (layout_constrained()) { + cloned->layout_constrained_ = true; + cloned->operand_shapes_with_layout_ = operand_shapes_with_layout(); + } if (window_ != nullptr) { cloned->set_window(*window_); } @@ -1958,6 +2189,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); } cloned->set_feature_group_count(feature_group_count_); + cloned->set_batch_group_count(batch_group_count_); return std::move(cloned); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h 
b/tensorflow/compiler/xla/service/hlo_instructions.h index e6111cfb57581589070b8e34556bdfe8239b4fd3..0bc0db41c0a1751589415cb03289794ba111c5e8 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -131,6 +131,81 @@ class HloFftInstruction : public HloInstruction { std::vector fft_length_; }; +class HloCompareInstruction : public HloInstruction { + public: + explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, + ComparisonDirection direction); + ComparisonDirection direction() const { return direction_; } + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + ComparisonDirection direction_; +}; + +class HloTriangularSolveInstruction : public HloInstruction { + public: + explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, + HloInstruction* b, + const TriangularSolveOptions& options); + const TriangularSolveOptions& triangular_solve_options() const { + return triangular_solve_options_; + } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + TriangularSolveOptions triangular_solve_options_; +}; + +class HloCholeskyInstruction : public HloInstruction { + public: + explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a, + const CholeskyOptions& options); + const CholeskyOptions& cholesky_options() const { return cholesky_options_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + CholeskyOptions cholesky_options_; +}; + class HloSendRecvInstruction : public HloInstruction { public: // Returns the channel id associated with the instruction. The id is @@ -253,6 +328,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns true if the AllReduce does no communication, so it's equivalent + // to a mem copy. + bool IsNoop() const; + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -414,8 +493,8 @@ class HloReduceInstruction : public HloInstruction { class HloSortInstruction : public HloInstruction { public: explicit HloSortInstruction(const Shape& shape, int64 dimension, - HloInstruction* keys, - absl::Span values = {}); + absl::Span operands, + HloComputation* compare, bool is_stable); // Returns the dimension sizes or numbers associated with this instruction. 
const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -428,6 +507,7 @@ class HloSortInstruction : public HloInstruction { HloInstruction* mutable_keys() { return mutable_operand(0); } // Returns the number of value operands. int64 values_count() const { return operand_count() - 1; } + bool is_stable() const { return is_stable_; } private: std::vector ExtraAttributesToStringImpl( @@ -442,6 +522,7 @@ class HloSortInstruction : public HloInstruction { HloCloneContext* context) const override; std::vector dimensions_; + bool is_stable_; }; class HloTransposeInstruction : public HloInstruction { @@ -783,10 +864,28 @@ class HloParameterInstruction : public HloInstruction { explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, const string& name); int64 parameter_number() const { return parameter_number_; } + + // Sets and gets the whether all replicas will receive the same parameter data + // for each leaf buffer in data parallelism. + void set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers) { + CHECK_EQ(ShapeUtil::GetLeafCount(shape()), + parameter_replicated_at_leaf_buffers.size()); + parameter_replicated_at_leaf_buffers_.emplace( + parameter_replicated_at_leaf_buffers.begin(), + parameter_replicated_at_leaf_buffers.end()); + } + const absl::optional>& + parameter_replicated_at_leaf_buffers() const { + return parameter_replicated_at_leaf_buffers_; + } + // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; bool IdenticalSlowPath( const HloInstruction& other, const std::function& @@ -800,6 +899,10 @@ class HloParameterInstruction : public HloInstruction { HloCloneContext* context) const override; int64 parameter_number_ = 0; + + // Specifies whether each buffer has the same parameter value on all replicas + // in data parallelism. + absl::optional> parameter_replicated_at_leaf_buffers_; }; class HloGetTupleElementInstruction : public HloInstruction { @@ -899,9 +1002,7 @@ class HloOutfeedInstruction : public HloInstruction { HloInstruction* token_operand, absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. - const Shape& outfeed_shape() const { - return outfeed_shape_; - } + const Shape& outfeed_shape() const { return outfeed_shape_; } // Returns the config for the Outfeed instruction. const string& outfeed_config() const { return outfeed_config_; } // Returns a serialized representation of this instruction. diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 5e81515134256a3ec4b790b38af3f42f68a79b56..2255383322873a39c7076e0f4f0dd541bc79014d 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/ascii.h" #include "absl/strings/escaping.h" #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" @@ -37,8 +38,8 @@ constexpr int kError = -2; // [a-zA-Z0-9_.-] bool IsIdentifierChar(char c) { - return isalnum(static_cast(c)) || c == '-' || c == '.' || - c == '_'; + return absl::ascii_isalnum(static_cast(c)) || c == '-' || + c == '.' 
|| c == '_'; } } // namespace @@ -105,7 +106,7 @@ TokKind HloLexer::LexToken() { switch (current_char) { default: // [a-zA-Z_] - if (isalpha(static_cast(current_char)) || + if (absl::ascii_isalpha(static_cast(current_char)) || current_char == '_') { return LexIdentifier(); } @@ -152,6 +153,8 @@ TokKind HloLexer::LexToken() { return LexPercent(); case ':': return TokKind::kColon; + case '*': + return TokKind::kAsterisk; case '[': return TokKind::kLsquare; case ']': @@ -211,6 +214,15 @@ TokKind HloLexer::LexToken() { // A lone '/' is an error. return TokKind::kError; } + case '.': + if (PeekCurrentChar() == '.') { + current_ptr_++; + if (PeekCurrentChar() == '.') { + current_ptr_++; + return TokKind::kDots; + } + } + return TokKind::kError; case '"': return LexString(); } @@ -300,7 +312,7 @@ TokKind HloLexer::LexIdentifier() { // name ::= [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexPercent() { const char* name_start = current_ptr_; - if (isalpha(static_cast(PeekCurrentChar())) || + if (absl::ascii_isalpha(static_cast(PeekCurrentChar())) || PeekCurrentChar() == '_') { current_ptr_++; while (IsIdentifierChar(PeekCurrentChar())) { @@ -454,6 +466,8 @@ string TokKindToString(TokKind kind) { return "kComma"; case TokKind::kColon: return "kColon"; + case TokKind::kAsterisk: + return "kAsterisk"; case TokKind::kLsquare: return "kLsquare"; case TokKind::kRsquare: @@ -512,6 +526,8 @@ string TokKindToString(TokKind kind) { return "kInt"; case TokKind::kDecimal: return "kDecimal"; + case TokKind::kDots: + return "kDots"; } } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 94fac3cd8e9da7f273e7e521e21510f5188702e6..383fb4e862b8e32771879d055e663dc821a5c839 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -38,15 +38,17 @@ enum class TokKind { kError, // Tokens with no info. 
- kEqual, // = - kComma, // , - kColon, // : + kEqual, // = + kComma, // , + kColon, // : + kAsterisk, // * kLsquare, kRsquare, // [ ] kLbrace, kRbrace, // { } kLparen, kRparen, // ( ) + kDots, // ... kArrow, // -> kLeq, // <= @@ -107,7 +109,7 @@ class HloLexer { LOG(FATAL) << "This token does not have string value"; } } - tensorflow::int64 GetInt64Val() const { + int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return token_state_.int64_val; } @@ -170,7 +172,7 @@ class HloLexer { const char* token_start = nullptr; TokKind current_kind; string str_val; - tensorflow::int64 int64_val; + int64 int64_val; double decimal_val; PrimitiveType primitive_type_val; }; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 436cccb1fb9ecf6f4efad772c700c611b28ce628..45d3e9c460e51ac0c1fe613b3765583113d6acc1 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -255,7 +255,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -308,7 +308,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(add.1, constant.2) + ROOT less-than = pred[] compare(add.1, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -360,7 +360,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { loop_var.2 = (s32[], 
s32[], s32[]) parameter(0) get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 constant.1 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1) + ROOT less-than = pred[] compare(get-tuple-element.4, constant.1), direction=LT } ENTRY SimpleLoop { constant.2 = s32[] constant(0) @@ -415,7 +415,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { cond_param = (s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -448,13 +448,13 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { cond_param = (s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } OuterWhileCondition { cond_param.2 = (s32[]) parameter(0) get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0 constant.5 = s32[] constant(5) - ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5) + ROOT less-than.2 = pred[] compare(get-tuple-element.5, constant.5), direction=LT } OuterWhileBody { body_param.2 = (s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index d28e79d41ad5d58a8881cfb80d488684af26564f..47ed85be1967f2a8f8f397021a85b806449ca98a 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -89,6 +89,22 @@ bool HloParameterMatcher::MatchAndExplain( return true; } +bool HloComparisonMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if 
(!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->comparison_direction() != direction_) { + *listener << "has wrong comparison direction (got " + << ComparisonDirectionToString( + instruction->comparison_direction()) + << ", want " << ComparisonDirectionToString(direction_) << ")"; + return false; + } + return true; +} + bool HloGetTupleElementMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 67488a6a9a0c9cba7f576f9036c3a0cbe1900fff..756f4d2c6bc65b04edea242ba15e3ee492d8cdb7 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -54,6 +54,21 @@ class HloParameterMatcher : public HloMatcher { int64 parameter_number_; }; +// Custom matcher for comparisons, which accepts a comparison direction. +class HloComparisonMatcher : public HloMatcher { + public: + explicit HloComparisonMatcher( + ComparisonDirection direction, + std::vector<::testing::Matcher> operands) + : HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + private: + ComparisonDirection direction_; +}; + // Custom matcher for get-tuple-element instructions, which accepts a tuple // index to match. 
class HloGetTupleElementMatcher : public HloMatcher { @@ -172,6 +187,7 @@ HLO_MATCHER(BatchNormGrad); HLO_MATCHER(Call); HLO_MATCHER(Ceil); HLO_MATCHER(Clamp); +HLO_MATCHER(Compare); HLO_MATCHER(Concatenate); HLO_MATCHER(Conditional); HLO_MATCHER(Constant); @@ -184,28 +200,22 @@ HLO_MATCHER(Divide); HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); -HLO_MATCHER(Eq); HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); -HLO_MATCHER(Ge); HLO_MATCHER(AfterAll); -HLO_MATCHER(Gt); HLO_MATCHER(Iota); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); -HLO_MATCHER(Le); HLO_MATCHER(Log); HLO_MATCHER(And); HLO_MATCHER(Not); HLO_MATCHER(Or); HLO_MATCHER(Xor); -HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); HLO_MATCHER(Minimum); HLO_MATCHER(Multiply); -HLO_MATCHER(Ne); HLO_MATCHER(Negate); HLO_MATCHER(Outfeed); HLO_MATCHER(Pad); @@ -256,6 +266,38 @@ inline ::testing::Matcher Parameter() { new ::xla::testing::HloMatcher(HloOpcode::kParameter, {})); } +// Comparison matchers below do not require any additional arguments. +template +inline ::testing::Matcher Eq(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kEq, {operands...})); +} +template +inline ::testing::Matcher Ne(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kNe, {operands...})); +} +template +inline ::testing::Matcher Ge(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kGe, {operands...})); +} +template +inline ::testing::Matcher Gt(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kGt, {operands...})); +} +template +inline ::testing::Matcher Le(M... 
operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kLe, {operands...})); +} +template +inline ::testing::Matcher Lt(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kLt, {operands...})); +} + // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th // tuple element of operand, while GetTupleElement(operand) matches any GTE // operation on operand, and GetTupleElement() matches any GTE operation at all. diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 7961aece541faeb66875885b380158756c503250..549fc603c705d1403da6e96f2ac6703079810623 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -220,5 +220,33 @@ ENTRY DotOperationFusion_TransposeFusion { "rhs_contracting_dimensions (got {0} want {1})"); } +TEST(HloMatchersTest, ComparisonMatcher) { + auto shape = ShapeUtil::MakeShape(F32, {1}); + auto p0 = HloInstruction::CreateParameter(0, shape, "param.0"); + auto p1 = HloInstruction::CreateParameter(1, shape, "param.1"); + auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kEq); + auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kNe); + auto add = + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get()); + auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(), + ComparisonDirection::kLe); + + EXPECT_THAT(eq.get(), op::Compare()); + EXPECT_THAT(eq.get(), op::Eq()); + EXPECT_THAT(ne.get(), op::Compare()); + EXPECT_THAT(ne.get(), op::Ne()); + EXPECT_THAT(le.get(), + op::Compare(op::Parameter(0), + op::Add(op::Parameter(0), op::Parameter(1)))); + EXPECT_THAT(le.get(), op::Le(op::Parameter(0), + op::Add(op::Parameter(0), op::Parameter(1)))); + + EXPECT_THAT(Explain(eq.get(), 
op::Add()), Eq("")); + EXPECT_THAT(Explain(eq.get(), op::Ne()), + Eq("has wrong comparison direction (got EQ, want NE)")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index d2740bcce26f04c5d7c8b64cfdaea53e3c697855..ca1046856d12b4ae870bc7e267dca34329ada665 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -23,7 +23,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -205,6 +207,18 @@ class ListScheduler { // than not taking subcomputations into account at all. In the future, we may // improve accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { + auto instruction = entry.instruction; + auto opcode = instruction->opcode(); + // To keep the device busy between a host send and send-done, we schedule + // the send done as late as possible. Same for host recv-done. This is a + // hack because packing of computation between channel instructions + // normally happens in the module group scheduler, and the memory scheduler + // only tries to minimize memory. 
+ if ((opcode == HloOpcode::kSendDone || opcode == HloOpcode::kRecvDone) && + DynCast(instruction)->is_host_transfer()) { + return INT_MIN; + } + int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { auto buffer = kv->first; @@ -216,7 +230,7 @@ class ListScheduler { // We only count the memory usage of the largest subcomputation, instead of // adding them all, because subcomputations won't execute in parallel. int64 max_subcomputation_bytes = 0; - for (const auto* c : entry.instruction->called_computations()) { + for (const auto* c : instruction->called_computations()) { auto it = memory_by_computation_.find(c); if (it != memory_by_computation_.end()) { int64 subcomputation_bytes = it->second; @@ -227,9 +241,8 @@ class ListScheduler { } int64 bytes_defined; if (max_subcomputation_bytes > 0 && - (entry.instruction->opcode() == HloOpcode::kWhile || - entry.instruction->opcode() == HloOpcode::kCall || - entry.instruction->opcode() == HloOpcode::kConditional)) { + (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kConditional)) { // The output buffer of while/call/conditional is always aliased with the // output buffer of the root instruction in the body. Don't double count. bytes_defined = max_subcomputation_bytes; diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 7227bfb27c74758d2b79e404afc9eb97a1ca894d..76cc29cbb7848eb424d07abf11a95ffd59e9eed6 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -118,7 +118,7 @@ class HloTrivialScheduler : public HloModulePass { }; // A trivial pass which clears the schedule currently set on the -// HloModule. After this pass runs HloModudle::has_schedule will return false. +// HloModule. After this pass runs HloModule::has_schedule will return false. 
class HloDescheduler : public HloModulePass { public: HloDescheduler() = default; diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index bc0d7e2bc00eab014f2660c95a51b966642eaee9..200d08c562e0e334cf49e0d3c57caf1d67399387 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -254,8 +254,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { HloInstruction* zero_vector = cond_builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({0, 0, 0, 0}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_param, + zero_vector, ComparisonDirection::kNe)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); // param - 1 diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 258f918f47a313b4b89fb260457b1b119dc16177..edcda8f9a7b974b95a12348577c335a3e6d8fcce 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -158,17 +158,12 @@ void HloModule::ReplaceComputations( break; } case HloOpcode::kConditional: { - HloComputation* new_true_computation = - tensorflow::gtl::FindWithDefault( - replacements, instruction->true_computation(), nullptr); - if (new_true_computation != nullptr) { - instruction->set_true_computation(new_true_computation); - } - HloComputation* new_false_computation = - tensorflow::gtl::FindWithDefault( - replacements, instruction->false_computation(), nullptr); - if (new_false_computation != nullptr) { - instruction->set_false_computation(new_false_computation); + for (int b = 0; b < instruction->branch_count(); ++b) { + 
HloComputation* new_computation = tensorflow::gtl::FindWithDefault( + replacements, instruction->branch_computation(b), nullptr); + if (new_computation != nullptr) { + instruction->set_branch_computation(b, new_computation); + } } break; } @@ -246,11 +241,39 @@ HloModuleProto HloModule::ToProto() const { return proto; } +Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const { + absl::flat_hash_set computation_names; + absl::flat_hash_set computation_ids; + absl::flat_hash_set instruction_names; + absl::flat_hash_set instruction_ids; + + for (const HloComputation* computation : computations()) { + TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) + << "Computation name is not unique: " << computation->name(); + computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); + + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) + << "Instruction name is not unique: " << instruction->name(); + instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); + } + } + return Status::OK(); +} + /* static */ StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { VLOG(2) << "CreateFromProto()"; - XLA_VLOG_LINES(2, proto.DebugString()); + XLA_VLOG_LINES(3, proto.DebugString()); // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. 
@@ -329,28 +352,8 @@ StatusOr> HloModule::CreateFromProto( DynamicParameterBinding::CreateFromProto( proto.dynamic_parameter_binding())); - absl::flat_hash_set computation_names; - absl::flat_hash_set instruction_names; - absl::flat_hash_set computation_ids; - absl::flat_hash_set instruction_ids; - for (HloComputation* computation : module->computations()) { - TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) - << "Computation name is not unique: " << computation->name(); - computation_names.insert(computation->name()); - - TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) - << "Computation id is not unique: " << computation->unique_id(); - computation_ids.insert(computation->unique_id()); - for (HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) - << "Instruction name is not unique: " << instruction->name(); - instruction_names.insert(instruction->name()); - - TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) - << "Instruction id is not unique: " << instruction->unique_id(); - instruction_ids.insert(instruction->unique_id()); - } - } + TF_RETURN_IF_ERROR( + module->CheckUniqueNamesAndIdsForComputationsAndInstructions()); if (proto.has_schedule()) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index f1310e4b270898a21dbb4f86123edde4ba8993d0..2c63247eea8292f52e95b6171100221336450c13 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -167,6 +167,12 @@ class HloModule { // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } + // Returns the mutable computation for the given index. 
+ HloComputation* mutable_computation(int64 idx) { + CHECK(idx >= 0 && idx < computations_.size()); + return computations_[idx].get(); + } + // Gets the number of instructions in this module. int64 instruction_count() const; @@ -187,6 +193,7 @@ class HloModule { std::vector MakeNonfusionComputations() const; const HloModuleConfig& config() const { return config_; } + void set_config(HloModuleConfig& config) { config_ = config; } // Return a string representation of the module. // @@ -264,6 +271,18 @@ class HloModule { const HloSchedule& schedule() const { return *schedule_; } HloSchedule& schedule() { return *schedule_; } + HloComputation* AddComputationAndUnifyNamesAndIds( + std::unique_ptr computation, bool is_entry) { + computation->ClearUniqueIdInternal(); + for (auto* instruction : computation->instructions()) { + instruction->ClearUniqueIdInternal(); + } + return AddComputationInternal(std::move(computation), is_entry, + /*uniquify_identifiers=*/true); + } + + Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const; + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 68c18836eb01484b819e7b7bd26f099dcf56e7ba..cee46fe10a2376555d82d2fcbce756aefaf4b982 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -101,6 +102,20 @@ class HloModuleConfig { return intra_op_parallelism_threads_; } + // Checks if this config has a static device assignment. 
+ bool has_static_device_assignment() const { + return static_device_assignment_.has_value(); + } + + // Getter and setter of the compile-time known device assignment. + const DeviceAssignment& static_device_assignment() const { + CHECK(static_device_assignment_.has_value()); + return *static_device_assignment_; + } + void set_static_device_assignment(const DeviceAssignment& device_assignment) { + static_device_assignment_ = device_assignment; + } + private: // If you add new members, be sure to update compilation_cache_key. @@ -117,6 +132,9 @@ class HloModuleConfig { int64 intra_op_parallelism_threads_ = -1; DebugOptions debug_options_; + + // Compile-time known device assignment. + absl::optional static_device_assignment_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index f6e2866204955ac024c2b6f972de449cc3df4c15..84988a9ecb31f3e5058a2c7aa3a44954bd9c9ac9 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -86,7 +86,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -125,7 +125,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { loop_var.2 = (s32[], f32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.3 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.3), direction=LT } ENTRY SimpleLoop { constant.4 = s32[] constant(0) @@ -163,7 +163,7 @@ TEST_F(HloModuleDceTest, 
OneWhileWithDeadTupleElement) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -206,7 +206,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { loop_var.2 = (s32[], s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -248,7 +248,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } SimpleLoop.body1 { loop_var.3 = (s32[], s32[3]{0}) parameter(0) @@ -263,7 +263,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { loop_var.4 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 constant.4 = s32[] constant(5) - ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT } ENTRY SimpleLoop { constant.5 = s32[] constant(0) @@ -316,7 +316,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { loop_var.2 = (s32[3]{0}, s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] 
compare(get-tuple-element.3, constant.2), direction=LT } SimpleLoop.body1 { loop_var.3 = (s32[], s32[3]{0}) parameter(0) @@ -331,7 +331,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { loop_var.4 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 constant.4 = s32[] constant(5) - ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT } ENTRY SimpleLoop { constant.5 = s32[] constant(0) @@ -383,7 +383,7 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { cond_param = (s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -418,7 +418,7 @@ TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { cond_param = (s32[], s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { p0 = (s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 47734bc55cc00d605f4e318400be88639450343c..bc258a77000d17cdb6b1d1005b6dac70e300e398 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -45,11 +45,8 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { case ComputationKind::kWhileBody: repr += ":WHILE_BODY"; break; - case ComputationKind::kConditionalTrue: - repr += ":CONDITIONAL_TRUE"; - break; - case 
ComputationKind::kConditionalFalse: - repr += ":CONDITIONAL_FALSE"; + case ComputationKind::kConditionalBranch: + repr += absl::StrCat(":CONDITIONAL_BRANCH_", index_); break; case ComputationKind::kCallFunction: repr += ":CALL"; @@ -307,10 +304,10 @@ Status HloModuleGroupMetadata::RecordInstructions() { tracked_instructions_[hlo->while_body()] = TrackedInstruction(hlo, ComputationKind::kWhileBody); } else if (hlo->opcode() == HloOpcode::kConditional) { - tracked_instructions_[hlo->true_computation()] = - TrackedInstruction(hlo, ComputationKind::kConditionalTrue); - tracked_instructions_[hlo->false_computation()] = - TrackedInstruction(hlo, ComputationKind::kConditionalFalse); + for (int b = 0; b < hlo->branch_count(); ++b) { + tracked_instructions_[hlo->branch_computation(b)] = + TrackedInstruction(hlo, ComputationKind::kConditionalBranch, b); + } } else if (hlo->opcode() == HloOpcode::kCall) { tracked_instructions_[hlo->to_apply()] = TrackedInstruction(hlo, ComputationKind::kCallFunction); @@ -389,9 +386,10 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, instruction1->opcode() == HloOpcode::kCall); VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " << instruction2->ToString(); - - if (!ContainsKey(companion_set_index_, instruction1) && - !ContainsKey(companion_set_index_, instruction2)) { + if (instruction1 == instruction2) { + return Status::OK(); + } else if (!ContainsKey(companion_set_index_, instruction1) && + !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( absl::make_unique>()); auto companion_set = companion_sets_.back().get(); @@ -419,7 +417,10 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, for (HloInstruction* hlo : Companions(instruction2)) { companion_set_index_[hlo] = companion_set_index_[instruction1]; } - companion_sets_.erase(companion_sets_.begin() + index_to_remove); + // We can't remove the set from the vector because 
companion_set_index_ + // references sets by their index in this vector, so we reset to nullptr + // instead. + companion_sets_[index_to_remove].reset(nullptr); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 3ed95c10504141139d83eb8679a0b8144b15ad0d..07becfc3638a550b661e2ee0d4f10ac5e836e481 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -67,8 +67,7 @@ class HloModuleGroupMetadata { kInvalid, kWhileCondition, kWhileBody, - kConditionalTrue, - kConditionalFalse, + kConditionalBranch, kCallFunction, }; @@ -80,12 +79,13 @@ class HloModuleGroupMetadata { class TrackedInstruction { public: TrackedInstruction() = default; - TrackedInstruction(HloInstruction* instruction, ComputationKind kind) - : instruction_(instruction), kind_(kind) {} + TrackedInstruction(HloInstruction* instruction, ComputationKind kind, + int index = -1) + : instruction_(instruction), kind_(kind), index_(index) {} bool operator==(const TrackedInstruction& rhs) const { return instruction_->opcode() == rhs.instruction_->opcode() && - kind_ == rhs.kind_; + kind_ == rhs.kind_ && index_ == rhs.index_; } bool operator!=(const TrackedInstruction& rhs) const { return !operator==(rhs); @@ -98,6 +98,7 @@ class HloModuleGroupMetadata { private: HloInstruction* instruction_ = nullptr; ComputationKind kind_ = ComputationKind::kInvalid; + int index_ = -1; }; // Represents a channel and the instructions that form the channel. @@ -173,7 +174,8 @@ class HloModuleGroupMetadata { // Returns the number of modules for devices (excluding the host module). int64 GetDeviceModulesCount() const; - // Returns the companion instructions for the given instruction. + // Returns the companion set for the given instruction, including the + // instruction itself. // // Precondition: IsCompanionWhile(instruction) is true. 
const std::vector& Companions( diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 91417bd2d9a6ca8a5192a37302e6a91e49a94d77..b2a361f504cb341bea04f50557099e46da3610d4 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -207,6 +207,39 @@ std::vector HloModuleGroupUtil::RootInstructions( return roots; } +string HloModuleGroupUtil::CycleToString(HloInstruction* init_instruction) { + std::vector names; + absl::flat_hash_set seen; + + std::function helper = + [&](HloInstruction* instruction) { + if (seen.find(instruction) != seen.end()) { + if (instruction == init_instruction) { + names.push_back(instruction->name()); + return true; + } + return false; + } + seen.insert(instruction); + for (HloInstruction* predecessor : GlobalPredecessors(instruction)) { + bool init_found = helper(predecessor); + if (init_found) { + names.push_back(instruction->name()); + return true; + } + } + return false; + }; + + helper(init_instruction); + std::vector pieces; + pieces.reserve(names.size()); + for (auto name : names) { + pieces.push_back(name); + } + return absl::StrJoin(pieces, " --> "); +} + Status HloModuleGroupUtil::VisitTopologicalOrder( VisitStates* visit_state, const VisitFunction& visit_function, HloInstruction* root) { @@ -269,22 +302,9 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( // a cycle. Generate an error with the list of instructions in the // cycle. if ((*visit_state)[predecessor] == VisitState::kVisiting) { - string cyclic_instructions; - for (const auto& state : *visit_state) { - if (state.second == VisitState::kVisiting) { - absl::StrAppend(&cyclic_instructions, state.first->ToString(), - "\n"); - } - } - // TODO(b/64305524): Improve the error message to print out the - // instructions in a deterministic order that forms the cycle. 
return FailedPrecondition( - "Cross-computation cycle detected via communicating nodes. The " - "cycle contains the node %s. The cycle is found among the " - "following nodes. Note that the order of the nodes is arbitrary " - "and that the list may include nodes that are not part of the " - "cycle.\n%s", - predecessor->ToString(), cyclic_instructions); + "Cross-computation cycle detected via communicating nodes.\n%s", + CycleToString(predecessor)); } stack.push(predecessor); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index 862666b48c9aa423ba4eeea3052c17fcc1064fd2..d388fe51d0ddf8cebca678d13f2cea96ba8f6114 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -108,6 +108,8 @@ class HloModuleGroupUtil { HloInstruction* instruction, HloReachabilityMap* reachability_map); private: + string CycleToString(HloInstruction* instruction); + const HloModuleGroupMetadata& metadata_; }; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 4551a1c2e259b06818f913cb6a9e782436b7e594..8f459107b3262d6b110258d784e101e52b6f8f0d 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -44,36 +44,29 @@ StatusOr StringToHloOpcode(const string& opcode_name) { return it->second; } -#define CHECK_DEFAULT(property_name, opcode_name) false -#define CHECK_PROPERTY(property_name, opcode_name, value) \ - (value & property_name) -#define RESOLVE(_1, _2, target, ...) target -#define HAS_PROPERTY(property, ...) \ - RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__) - bool HloOpcodeIsComparison(HloOpcode opcode) { - switch (opcode) { -#define CASE_IS_COMPARISON(enum_name, ...) 
\ - case HloOpcode::enum_name: \ - return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__); - HLO_OPCODE_LIST(CASE_IS_COMPARISON) -#undef CASE_IS_COMPARISON - } + return opcode == HloOpcode::kCompare; } bool HloOpcodeIsVariadic(HloOpcode opcode) { switch (opcode) { -#define CASE_IS_VARIADIC(enum_name, ...) \ - case HloOpcode::enum_name: \ - return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__); +#define CASE_IS_VARIADIC(enum_name, opcode_name, arity, ...) \ + case HloOpcode::enum_name: \ + return arity == kHloOpcodeIsVariadic; HLO_OPCODE_LIST(CASE_IS_VARIADIC) #undef CASE_IS_VARIADIC } } -#undef HAS_PROPERTY -#undef RESOLVE -#undef CHECK_DEFAULT -#undef CHECK_PROPERTY +absl::optional HloOpcodeArity(HloOpcode opcode) { + switch (opcode) { +#define CASE_ARITY(enum_name, opcode_name, arity, ...) \ + case HloOpcode::enum_name: \ + return arity == kHloOpcodeIsVariadic ? absl::nullopt \ + : absl::make_optional(arity); + HLO_OPCODE_LIST(CASE_ARITY) +#undef CASE_ARITY + } +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 94122ac38ff2a3f7053b19e55f9a400c80ae2134..c5ccd49552a87512547b72b6ae1830e582521125 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -18,8 +18,11 @@ limitations under the License. #include #include +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -30,9 +33,9 @@ namespace xla { // See the XLA documentation for the semantics of each opcode. // // Each entry has the format: -// (enum_name, opcode_name) +// (enum_name, opcode_name, arity) // or -// (enum_name, opcode_name, p1 | p2 | ...) +// (enum_name, opcode_name, arity, p1 | p2 | ...) // // with p1, p2, ... are members of HloOpcodeProperty. 
They are combined // using bitwise-or. @@ -44,102 +47,102 @@ namespace xla { // - In fully qualified names (HloInstruction::FullyQualifiedName()), to // separate the qualifiers (name of the computation and potentially the // fusion instruction) from the name -#define HLO_OPCODE_LIST(V) \ - V(kAbs, "abs") \ - V(kAdd, "add") \ - V(kAddDependency, "add-dependency") \ - V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ - V(kAllReduce, "all-reduce") \ - V(kAllToAll, "all-to-all") \ - V(kAtan2, "atan2") \ - V(kBatchNormGrad, "batch-norm-grad") \ - V(kBatchNormInference, "batch-norm-inference") \ - V(kBatchNormTraining, "batch-norm-training") \ - V(kBitcast, "bitcast") \ - V(kBitcastConvert, "bitcast-convert") \ - V(kBroadcast, "broadcast") \ - V(kCall, "call", kHloOpcodeIsVariadic) \ - V(kCeil, "ceil") \ - V(kClamp, "clamp") \ - V(kCollectivePermute, "collective-permute") \ - V(kClz, "count-leading-zeros") \ - V(kComplex, "complex") \ - V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ - V(kConditional, "conditional") \ - V(kConstant, "constant") \ - V(kConvert, "convert") \ - V(kConvolution, "convolution") \ - V(kCopy, "copy") \ - V(kCos, "cosine") \ - V(kCustomCall, "custom-call") \ - V(kDivide, "divide") \ - V(kDomain, "domain") \ - V(kDot, "dot") \ - V(kDynamicSlice, "dynamic-slice") \ - V(kDynamicUpdateSlice, "dynamic-update-slice") \ - V(kEq, "equal-to", kHloOpcodeIsComparison) \ - V(kExp, "exponential") \ - V(kExpm1, "exponential-minus-one") \ - V(kFft, "fft") \ - V(kFloor, "floor") \ - V(kFusion, "fusion", kHloOpcodeIsVariadic) \ - V(kGather, "gather") \ - V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ - V(kGetDimensionSize, "get-dimension-size") \ - V(kGetTupleElement, "get-tuple-element") \ - V(kGt, "greater-than", kHloOpcodeIsComparison) \ - V(kImag, "imag") \ - V(kInfeed, "infeed") \ - V(kIota, "iota") \ - V(kIsFinite, "is-finite") \ - V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ - V(kLog, "log") \ - V(kLog1p, "log-plus-one") \ 
- V(kAnd, "and") \ - V(kNot, "not") \ - V(kOr, "or") \ - V(kXor, "xor") \ - V(kLt, "less-than", kHloOpcodeIsComparison) \ - V(kMap, "map", kHloOpcodeIsVariadic) \ - V(kMaximum, "maximum") \ - V(kMinimum, "minimum") \ - V(kMultiply, "multiply") \ - V(kNe, "not-equal-to", kHloOpcodeIsComparison) \ - V(kNegate, "negate") \ - V(kOutfeed, "outfeed") \ - V(kPad, "pad") \ - V(kParameter, "parameter") \ - V(kPower, "power") \ - V(kReal, "real") \ - V(kRecv, "recv") \ - V(kRecvDone, "recv-done") \ - V(kReduce, "reduce") \ - V(kReducePrecision, "reduce-precision") \ - V(kReduceWindow, "reduce-window") \ - V(kRemainder, "remainder") \ - V(kReshape, "reshape") \ - V(kReverse, "reverse") \ - V(kRng, "rng") \ - V(kRoundNearestAfz, "round-nearest-afz") \ - V(kScatter, "scatter") \ - V(kSelect, "select") \ - V(kSelectAndScatter, "select-and-scatter") \ - V(kSend, "send") \ - V(kSendDone, "send-done") \ - V(kShiftLeft, "shift-left") \ - V(kShiftRightArithmetic, "shift-right-arithmetic") \ - V(kShiftRightLogical, "shift-right-logical") \ - V(kSign, "sign") \ - V(kSin, "sine") \ - V(kSlice, "slice") \ - V(kSort, "sort") \ - V(kSubtract, "subtract") \ - V(kTanh, "tanh") \ - V(kTrace, "trace") \ - V(kTranspose, "transpose") \ - V(kTuple, "tuple", kHloOpcodeIsVariadic) \ - V(kTupleSelect, "tuple-select") \ - V(kWhile, "while") +#define HLO_OPCODE_LIST(V) \ + V(kAbs, "abs", 1) \ + V(kAdd, "add", 2) \ + V(kAddDependency, "add-dependency", 2) \ + V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ + V(kAllReduce, "all-reduce", kHloOpcodeIsVariadic) \ + V(kAllToAll, "all-to-all", kHloOpcodeIsVariadic) \ + V(kAtan2, "atan2", 2) \ + V(kBatchNormGrad, "batch-norm-grad", 5) \ + V(kBatchNormInference, "batch-norm-inference", 5) \ + V(kBatchNormTraining, "batch-norm-training", 3) \ + V(kBitcast, "bitcast", 1) \ + V(kBitcastConvert, "bitcast-convert", 1) \ + V(kBroadcast, "broadcast", 1) \ + V(kCall, "call", kHloOpcodeIsVariadic) \ + V(kCeil, "ceil", 1) \ + V(kCholesky, "cholesky", 1) \ + V(kClamp, 
"clamp", 3) \ + V(kCollectivePermute, "collective-permute", 1) \ + V(kClz, "count-leading-zeros", 1) \ + V(kCompare, "compare", 2) \ + V(kComplex, "complex", 2) \ + V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ + V(kConditional, "conditional", kHloOpcodeIsVariadic) \ + V(kConstant, "constant", 0) \ + V(kConvert, "convert", 1) \ + V(kConvolution, "convolution", 2) \ + V(kCopy, "copy", 1) \ + V(kCos, "cosine", 1) \ + V(kCustomCall, "custom-call", kHloOpcodeIsVariadic) \ + V(kDivide, "divide", 2) \ + V(kDomain, "domain", 1) \ + V(kDot, "dot", 2) \ + V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \ + V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \ + V(kExp, "exponential", 1) \ + V(kExpm1, "exponential-minus-one", 1) \ + V(kFft, "fft", 1) \ + V(kFloor, "floor", 1) \ + V(kFusion, "fusion", kHloOpcodeIsVariadic) \ + V(kGather, "gather", 2) \ + V(kGetDimensionSize, "get-dimension-size", 1) \ + V(kGetTupleElement, "get-tuple-element", 1) \ + V(kImag, "imag", 1) \ + V(kInfeed, "infeed", 1) \ + V(kIota, "iota", 0) \ + V(kIsFinite, "is-finite", 1) \ + V(kLog, "log", 1) \ + V(kLog1p, "log-plus-one", 1) \ + V(kAnd, "and", 2) \ + V(kNot, "not", 1) \ + V(kOr, "or", 2) \ + V(kXor, "xor", 2) \ + V(kMap, "map", kHloOpcodeIsVariadic) \ + V(kMaximum, "maximum", 2) \ + V(kMinimum, "minimum", 2) \ + V(kMultiply, "multiply", 2) \ + V(kNegate, "negate", 1) \ + V(kOutfeed, "outfeed", 2) \ + V(kPad, "pad", 2) \ + V(kParameter, "parameter", 0) \ + V(kPower, "power", 2) \ + V(kReal, "real", 1) \ + V(kRecv, "recv", 1) \ + V(kRecvDone, "recv-done", 1) \ + V(kReduce, "reduce", kHloOpcodeIsVariadic) \ + V(kReducePrecision, "reduce-precision", 1) \ + V(kReduceWindow, "reduce-window", 2) \ + V(kRemainder, "remainder", 2) \ + V(kReplicaId, "replica-id", 0) \ + V(kReshape, "reshape", 1) \ + V(kReverse, "reverse", 1) \ + V(kRng, "rng", kHloOpcodeIsVariadic) \ + V(kRoundNearestAfz, "round-nearest-afz", 1) \ + V(kRsqrt, "rsqrt", 1) \ + V(kScatter, "scatter", 3) \ 
+ V(kSelect, "select", 3) \ + V(kSelectAndScatter, "select-and-scatter", 3) \ + V(kSend, "send", 2) \ + V(kSendDone, "send-done", 1) \ + V(kShiftLeft, "shift-left", 2) \ + V(kShiftRightArithmetic, "shift-right-arithmetic", 2) \ + V(kShiftRightLogical, "shift-right-logical", 2) \ + V(kSign, "sign", 1) \ + V(kSin, "sine", 1) \ + V(kSlice, "slice", 1) \ + V(kSort, "sort", kHloOpcodeIsVariadic) \ + V(kSqrt, "sqrt", 1) \ + V(kSubtract, "subtract", 2) \ + V(kTanh, "tanh", 1) \ + V(kTrace, "trace", 1) \ + V(kTranspose, "transpose", 1) \ + V(kTriangularSolve, "triangular-solve", 2) \ + V(kTuple, "tuple", kHloOpcodeIsVariadic) \ + V(kTupleSelect, "tuple-select", 3) \ + V(kWhile, "while", 1) enum class HloOpcode { #define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name, @@ -147,12 +150,16 @@ enum class HloOpcode { #undef DECLARE_ENUM }; +// Arity value that denotes that an operator is variadic. +enum { + kHloOpcodeIsVariadic = -1, +}; + // List of properties associated with opcodes. // Properties are defined as increasing powers of two, so that we can use // bitwise-or to combine properties, and bitwise-and to test for them. enum HloOpcodeProperty { kHloOpcodeIsComparison = 1 << 0, - kHloOpcodeIsVariadic = 1 << 1, }; // Returns a string representation of the opcode. @@ -171,6 +178,10 @@ bool HloOpcodeIsComparison(HloOpcode opcode); // Returns true iff the given opcode has variadic operands. bool HloOpcodeIsVariadic(HloOpcode opcode); +// Returns the arity of opcode. If the opcode is variadic, +// returns nullopt. +absl::optional HloOpcodeArity(HloOpcode opcode); + // Returns the number of HloOpcode values. inline const uint32_t HloOpcodeCount() { #define HLO_COUNT_ONE(...) 
+1 diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 6f3f83f63a05fafaa3f3ddcff8a7cac7cb7b06d5..136e6702b21c05583abaf0e328800cedfd778aa8 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -42,23 +42,27 @@ TEST(HloOpcodeTest, OpcodeProperties) { // Test some properties. switch (opcode) { - case HloOpcode::kEq: - case HloOpcode::kNe: - case HloOpcode::kGt: - case HloOpcode::kLt: - case HloOpcode::kGe: - case HloOpcode::kLe: + case HloOpcode::kCompare: EXPECT_TRUE(HloOpcodeIsComparison(opcode)); break; default: EXPECT_FALSE(HloOpcodeIsComparison(opcode)); } switch (opcode) { + case HloOpcode::kAfterAll: + case HloOpcode::kAllReduce: + case HloOpcode::kAllToAll: case HloOpcode::kCall: case HloOpcode::kConcatenate: + case HloOpcode::kConditional: + case HloOpcode::kCustomCall: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kFusion: case HloOpcode::kMap: - case HloOpcode::kAfterAll: + case HloOpcode::kReduce: + case HloOpcode::kRng: + case HloOpcode::kSort: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 0cec61c257bb84e467290fb52ec9063a32ed558d..831771fe63b8dd4c276ad3ec05ea90b4d475e7e0 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -66,24 +66,31 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, } } - // If the common ancestor is a conditional instruction, even though the true - // and false computations are not really ordered per-se, we define the true - // computation to be ordered before the false one. 
- // This ensures that buffers can still be shared among the two computations + // If the common ancestor is a conditional instruction, even though the branch + // computations are not really ordered per-se, we define the 0th branch + // computation to be ordered before the 1st one, before the 2nd and so forth. + // This ensures that buffers can still be shared among branch computations // as they will forcibly have disjoint liveness. if (a_ancestor == b_ancestor && - a_ancestor->opcode() == HloOpcode::kConditional) { - const HloComputation* true_computation = a_ancestor->true_computation(); - const HloComputation* false_computation = a_ancestor->false_computation(); - if (call_graph_->InstructionIsNestedIn(a, true_computation) && - call_graph_->InstructionIsNestedIn(b, false_computation)) { + (a_ancestor->opcode() == HloOpcode::kConditional)) { + int a_branch = -1; + int b_branch = -1; + for (int j = 0; j < a_ancestor->branch_count(); ++j) { + if (call_graph_->InstructionIsNestedIn( + a, a_ancestor->branch_computation(j))) { + a_branch = j; + } + if (call_graph_->InstructionIsNestedIn( + b, a_ancestor->branch_computation(j))) { + b_branch = j; + } + } + if (a_branch != -1 && a_branch < b_branch) { return true; } - // If 'b' is the conditional ancestor, and 'a' is within the true or false - // computations, 'a' executes before 'b'. - if (b == a_ancestor && - (call_graph_->InstructionIsNestedIn(a, true_computation) || - call_graph_->InstructionIsNestedIn(a, false_computation))) { + // If 'b' is the conditional ancestor, and 'a' is within a branch + // computation, 'a' executes before 'b'. + if (b == a_ancestor && a_branch != -1) { return true; } } @@ -144,17 +151,17 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { b.defining_instruction()->while_condition()))) { return true; } - // If 'b' is a conditional phi and 'a' is in the true or false computation, - // then 'a' executes before 'b'. 
+ // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a' + // executes before 'b'. if (b.is_phi() && - b.defining_instruction()->opcode() == HloOpcode::kConditional && - (call_graph_->InstructionIsNestedIn( - a.defining_instruction(), - b.defining_instruction()->true_computation()) || - call_graph_->InstructionIsNestedIn( - a.defining_instruction(), - b.defining_instruction()->false_computation()))) { - return true; + b.defining_instruction()->opcode() == HloOpcode::kConditional) { + for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) { + if (call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->branch_computation(j))) { + return true; + } + } } return ExecutesBefore(a.defining_instruction(), b.defining_instruction()); } @@ -225,17 +232,14 @@ bool HloOrdering::UseIsBeforeValueDefinition( if (use.instruction->opcode() == HloOpcode::kConditional) { const HloInstruction* conditional = use.instruction; - if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), - conditional->true_computation())) { - VLOG(4) << " use is conditional " << use.instruction->name() - << " and def is in TRUE computation"; - return true; - } - if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), - conditional->false_computation())) { - VLOG(4) << " use is conditional " << use.instruction->name() - << " and def is in FALSE computation"; - return true; + for (int j = 0; j < conditional->branch_count(); ++j) { + if (call_graph_->InstructionIsNestedIn( + value.defining_instruction(), + conditional->branch_computation(j))) { + VLOG(4) << " use is conditional " << use.instruction->name() + << " and def is in " << j << "th branch computation"; + return true; + } } if (value.defining_instruction() == use.instruction) { VLOG(4) << " use is conditional " << use << " and def is " diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc 
index 3ca77e60cd5275c22eb0e338cd5437fc44b49958..8e8b9d663ea2540dec0b1011c32204c52ef6beca 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -306,7 +306,7 @@ condition.v4 { constant.2 = s32[] constant(2) prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0) get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0 - ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8) + ROOT greater-than = pred[] compare(constant.2, get-tuple-element.8), direction=GT } fused_computation { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 638396308c2a9c1f20e47f78b594d54f07c0c4e5..8e76a1f262e988d19a0689f4c691844d9f28a559 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include #include "absl/algorithm/container.h" #include "absl/memory/memory.h" @@ -21,17 +22,21 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/protobuf.h" namespace xla { @@ -44,8 +49,6 @@ using absl::StrCat; using absl::StrFormat; using absl::StrJoin; -const double kF16max = 65504; - // Creates and returns a schedule created using the order of the instructions in // the HloComputation::instructions() vectors in the module. HloSchedule ScheduleFromInstructionOrder(HloModule* module) { @@ -60,6 +63,10 @@ HloSchedule ScheduleFromInstructionOrder(HloModule* module) { return schedule; } +// Some functions accept either a linear index or a multi-dimensional index +// (used for indexing into sparse literals). +using LinearOrMultiIndex = absl::variant>; + // Parser for the HloModule::ToString() format text. class HloParser { public: @@ -77,6 +84,7 @@ class HloParser { // Stand alone parsing utils for various aggregate data types. StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); + StatusOr> ParseParameterReplicationOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); StatusOr ParsePaddingConfigOnly(); @@ -102,7 +110,7 @@ class HloParser { // Parse a single instruction worth of text. 
bool ParseSingleInstruction(HloModule* module); - // ParseXXX returns false if an error occurred. + // Parses a module, returning false if an error occurred. bool ParseHloModule(HloModule* module); bool ParseComputations(HloModule* module); @@ -118,21 +126,30 @@ class HloParser { bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); bool ParseDenseLiteral(Literal* literal, const Shape& shape); bool ParseSparseLiteral(Literal* literal, const Shape& shape); - template - bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); - // Sets the sub-value of literal at the given index to the given value. The - // literal's shape must have the default layout. - bool SetValueInLiteral(tensorflow::int64 value, - tensorflow::int64 linear_index, Literal* literal); - bool SetValueInLiteral(double value, tensorflow::int64 linear_index, + // Sets the sub-value of literal at the given linear or sparse index to the + // given value. If the literal is dense, it must have the default layout. + // + // `loc` should be the source location of the value. + bool SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index, Literal* literal); - bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, + bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index, Literal* literal); + bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index, + Literal* literal); + bool SetValueInLiteral(LocTy loc, std::complex value, + LinearOrMultiIndex index, Literal* literal); + // `loc` should be the source location of the value. template - bool SetValueInLiteralHelper(ParsedElemT value, - tensorflow::int64 linear_index, - Literal* literal); + bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, + LinearOrMultiIndex index, Literal* literal); + + // Checks whether the given value is within the range of LiteralNativeT. + // `loc` should be the source location of the value. + template
+ template + bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value); + template + bool CheckParsedValueIsInRange(LocTy loc, std::complex value); bool ParseOperands(std::vector* operands); // Fills parsed operands into 'operands' and expects a certain number of @@ -143,9 +160,9 @@ class HloParser { // Describes the start, limit, and stride on every dimension of the operand // being sliced. struct SliceRanges { - std::vector starts; - std::vector limits; - std::vector strides; + std::vector starts; + std::vector limits; + std::vector strides; }; // The data parsed for the kDomain instruction. @@ -164,10 +181,13 @@ class HloParser { kBracedInt64List, kBracedInt64ListList, kHloComputation, + kBracedHloComputationList, kFftType, + kComparisonDirection, kWindow, kConvolutionDimensionNumbers, kSharding, + kParameterReplication, kInstructionList, kSliceRanges, kPaddingConfig, @@ -220,6 +240,21 @@ class HloParser { bool ParseAttributeHelper(const std::unordered_map& attrs, std::unordered_set* seen_attrs); + // Parses an attribute string into a protocol buffer `message`. + // Since proto3 has no notion of mandatory fields, `required_attrs` gives the + // set of mandatory attributes. + bool ParseAttributesAsProtoMessage( + const std::unordered_set& required_attrs, + tensorflow::protobuf::Message* message); + + // Parses one attribute. If it has already been seen, return error. Returns + // true and adds to seen_attrs on success. + // + // Do not call this except in ParseAttributesAsProtoMessage. + bool ParseAttributeAsProtoMessageHelper( + tensorflow::protobuf::Message* message, + std::unordered_set* seen_attrs); + // Parses a name and finds the corresponding hlo computation. bool ParseComputationName(HloComputation** value); // Parses a list of names and finds the corresponding hlo instructions. 
@@ -232,21 +267,23 @@ class HloParser { bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + bool ParseParameterReplication(ParameterReplication* parameter_replication); // Parses the metadata behind a kDOmain instruction. bool ParseDomain(DomainData* domain); // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector>* pad); + bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); bool ParsePrecisionList(std::vector* result); + bool ParseHloComputation(HloComputation** result); + bool ParseHloComputationList(std::vector* result); bool ParseShapeList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, - std::vector* result); + const TokKind delim, std::vector* result); // 'parse_and_add_item' is an lambda to parse an element in the list and add // the parsed element to the result. It's supposed to capture the result. 
bool ParseList(const TokKind start, const TokKind end, const TokKind delim, @@ -261,13 +298,16 @@ class HloParser { std::vector* dynamic_dimensions); bool ParseShape(Shape* result); bool ParseLayout(Layout* layout); + bool ParseTiles(std::vector* tiles); bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); + bool ParseComparisonDirection(ComparisonDirection* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); bool ParsePrecision(PrecisionConfig::Precision* result); - bool ParseInt64(tensorflow::int64* result); + bool ParseInt64(int64* result); bool ParseDouble(double* result); + bool ParseComplex(std::complex* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); @@ -627,6 +667,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, std::unordered_map attrs; optional sharding; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + optional parameter_replication; + attrs["parameter_replication"] = {/*required=*/false, + AttrTy::kParameterReplication, + ¶meter_replication}; optional> predecessors; attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, &predecessors}; @@ -640,11 +684,17 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { - tensorflow::int64 parameter_number; + int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || - !ParseInt64(¶meter_number) || - !ParseToken(TokKind::kRparen, "expects ')' after parameter number") || + !ParseInt64(¶meter_number)) { + return false; + } + if (parameter_number < 0) { + Error(lexer_.GetLoc(), "parameter number must be >= 0"); + return false; + } + if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") || !ParseAttributes(attrs)) { return false; } @@ -666,7 +716,7 @@ bool 
HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kIota: { - optional iota_dimension; + optional iota_dimension; attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, &iota_dimension}; if (!ParseOperands(&operands, /*expected_size=*/0) || @@ -695,8 +745,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -713,12 +765,6 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kSubtract: case HloOpcode::kAtan2: case HloOpcode::kComplex: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -830,6 +876,14 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); break; } + case HloOpcode::kReplicaId: { + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateReplicaId()); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -861,17 +915,21 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kSort: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; + optional is_stable = false; + attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable}; + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; if 
(!ParseOperands(&operands) || !ParseAttributes(attrs) || dimensions->size() != 1) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), - /*keys=*/operands[0], - /*values=*/absl::Span(operands).subspan(1))); + instruction = builder->AddInstruction( + HloInstruction::CreateSort(shape, dimensions->at(0), operands, + to_apply.value(), is_stable.value())); break; } case HloOpcode::kTuple: { @@ -897,7 +955,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kRecv: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -913,7 +971,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kRecvDone: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -931,7 +989,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kSend: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -946,7 +1004,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kSendDone: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. 
optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -964,7 +1022,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kGetTupleElement: { - optional index; + optional index; attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -1047,7 +1105,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } case HloOpcode::kFft: { optional fft_type; - optional> fft_length; + optional> fft_length; attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, &fft_length}; @@ -1059,8 +1117,43 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, shape, operands[0], *fft_type, *fft_length)); break; } + case HloOpcode::kTriangularSolve: { + TriangularSolveOptions options; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributesAsProtoMessage( + /*required_attrs=*/std::unordered_set(), &options)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateTriangularSolve( + shape, operands[0], operands[1], options)); + break; + } + case HloOpcode::kCompare: { + optional direction; + attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection, + &direction}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateCompare( + shape, operands[0], operands[1], *direction)); + break; + } + case HloOpcode::kCholesky: { + CholeskyOptions options; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributesAsProtoMessage( + /*required_attrs=*/std::unordered_set(), &options)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCholesky(shape, operands[0], options)); + 
break; + } case HloOpcode::kBroadcast: { - optional> broadcast_dimensions; + optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &broadcast_dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1072,7 +1165,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kConcatenate: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || @@ -1087,7 +1180,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { @@ -1103,7 +1196,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; - optional> dimensions_to_reduce; + optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { @@ -1124,7 +1217,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kReverse: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1168,7 +1261,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kDynamicSlice: { - optional> dynamic_slice_sizes; + optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; 
LocTy loc = lexer_.GetLoc(); @@ -1207,7 +1300,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kTranspose: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1221,7 +1314,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBatchNormTraining: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/3) || @@ -1237,7 +1330,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBatchNormInference: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -1254,7 +1347,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBatchNormGrad: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -1336,8 +1429,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kReducePrecision: { - optional exponent_bits; - optional mantissa_bits; + optional exponent_bits; + optional mantissa_bits; attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, &exponent_bits}; attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, @@ -1355,18 +1448,36 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* 
builder, case HloOpcode::kConditional: { optional true_computation; optional false_computation; - attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation, - &true_computation}; - attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation, - &false_computation}; - if (!ParseOperands(&operands, /*expected_size=*/3) || - !ParseAttributes(attrs)) { + optional> branch_computations; + if (!ParseOperands(&operands)) { + return false; + } + const bool branch_index_is_bool = + operands[0]->shape().element_type() == PRED; + if (branch_index_is_bool) { + attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation, + &true_computation}; + attrs["false_computation"] = { + /*required=*/true, AttrTy::kHloComputation, &false_computation}; + } else { + attrs["branch_computations"] = {/*required=*/true, + AttrTy::kBracedHloComputationList, + &branch_computations}; + } + if (!ParseAttributes(attrs)) { + return false; + } + if (branch_index_is_bool) { + branch_computations.emplace({*true_computation, *false_computation}); + } + if (branch_computations->empty() || + operands.size() != branch_computations->size() + 1) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateConditional( - shape, /*pred=*/operands[0], - /*true_computation_arg=*/operands[1], *true_computation, - /*false_computation_arg=*/operands[2], *false_computation)); + shape, /*branch_index=*/operands[0], + absl::MakeSpan(*branch_computations), + absl::MakeSpan(operands).subspan(1))); break; } case HloOpcode::kCustomCall: { @@ -1375,6 +1486,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional window; optional dnums; optional feature_group_count; + optional batch_group_count; optional> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; @@ -1384,6 +1496,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, 
AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64, + &batch_group_count}; attrs["operand_layout_constraints"] = { /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { @@ -1439,19 +1553,22 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (feature_group_count.has_value()) { instruction->set_feature_group_count(*feature_group_count); } + if (batch_group_count.has_value()) { + instruction->set_batch_group_count(*batch_group_count); + } break; } case HloOpcode::kDot: { - optional> lhs_contracting_dims; + optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; - optional> rhs_contracting_dims; + optional> rhs_contracting_dims; attrs["rhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; - optional> lhs_batch_dims; + optional> lhs_batch_dims; attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &lhs_batch_dims}; - optional> rhs_batch_dims; + optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; optional> operand_precision; @@ -1495,19 +1612,19 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> offset_dims; + optional> offset_dims; attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &offset_dims}; - optional> collapsed_slice_dims; + optional> collapsed_slice_dims; attrs["collapsed_slice_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims}; - optional> start_index_map; + optional> start_index_map; attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List, &start_index_map}; - optional 
index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> slice_sizes; + optional> slice_sizes; attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List, &slice_sizes}; @@ -1529,17 +1646,17 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kScatter: { - optional> update_window_dims; + optional> update_window_dims; attrs["update_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims}; - optional> inserted_window_dims; + optional> inserted_window_dims; attrs["inserted_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims}; - optional> scatter_dims_to_operand_dims; + optional> scatter_dims_to_operand_dims; attrs["scatter_dims_to_operand_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &scatter_dims_to_operand_dims}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; @@ -1580,7 +1697,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); case HloOpcode::kGetDimensionSize: - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1605,6 +1722,18 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); } + if (parameter_replication) { + int leaf_count = ShapeUtil::GetLeafCount(instruction->shape()); + const auto& replicated = + parameter_replication->replicated_at_leaf_buffers(); + if (leaf_count != replicated.size()) { + return Error(lexer_.GetLoc(), + StrCat("parameter has ", leaf_count, + " leaf buffers, but parameter_replication has ", + 
replicated.size(), " elements.")); + } + instruction->set_parameter_replicated_at_leaf_buffers(replicated); + } if (predecessors) { for (auto* pre : *predecessors) { Status status = pre->AddControlDependencyTo(instruction); @@ -1669,8 +1798,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; - std::vector devices; - std::vector tile_assignment_dimensions; + std::vector devices; + std::vector tile_assignment_dimensions; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -1696,7 +1825,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } do { - tensorflow::int64 dim; + int64 dim; if (!ParseInt64(&dim)) { return false; } @@ -1708,7 +1837,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } do { - tensorflow::int64 device; + int64 device; if (!ParseInt64(&device)) { return false; } @@ -1752,10 +1881,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, "dimensions"); } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); - for (tensorflow::int64 dim : tile_assignment_dimensions) { + for (int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } - for (tensorflow::int64 device : devices) { + for (int64 device : devices) { sharding->add_tile_assignment_devices(device); } } @@ -1764,6 +1893,32 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// parameter_replication ::= +// '{' ('true' | 'false')* (',' ('true' | 'false'))* '}' +bool HloParser::ParseParameterReplication( + ParameterReplication* parameter_replication) { + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start parameter_replication attribute")) { + return false; + } + + if (lexer_.GetKind() != TokKind::kRbrace) { + do { + if (lexer_.GetKind() == TokKind::kw_true) { + parameter_replication->add_replicated_at_leaf_buffers(true); + } else if 
(lexer_.GetKind() == TokKind::kw_false) { + parameter_replication->add_replicated_at_leaf_buffers(false); + } else { + return false; + } + lexer_.Lex(); + } while (EatIfPresent(TokKind::kComma)); + } + + return ParseToken(TokKind::kRbrace, + "expected '}' to end parameter_replication attribute"); +} + // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' // 'exit=' exit_sharding '}' bool HloParser::ParseDomain(DomainData* domain) { @@ -1816,130 +1971,146 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(tensorflow::int64 value, - tensorflow::int64 linear_index, - Literal* literal) { +bool HloParser::SetValueInLiteral(LocTy loc, int64 value, + LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(value, linear_index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case S16: - return SetValueInLiteralHelper(value, linear_index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case S32: - return SetValueInLiteralHelper(value, linear_index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case S64: - return SetValueInLiteralHelper(value, linear_index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case U8: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U16: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U32: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U64: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case PRED: // Bool type literals with rank >= 1 are printed in 0s and 1s. 
- return SetValueInLiteralHelper(static_cast(value), - linear_index, literal); + return SetValueInLiteralHelper(loc, static_cast(value), index, + literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, - Literal* literal) { +bool HloParser::SetValueInLiteral(LocTy loc, double value, + LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); case BF16: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case F32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); case F64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); default: LOG(FATAL) << "unknown floating point primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, - Literal* literal) { +bool HloParser::SetValueInLiteral(LocTy loc, bool value, + LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case PRED: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); default: LOG(FATAL) << PrimitiveType_Name(shape.element_type()) << " is not PRED type"; } } +bool HloParser::SetValueInLiteral(LocTy loc, std::complex value, + LinearOrMultiIndex index, Literal* literal) { + const Shape& shape = literal->shape(); + switch (shape.element_type()) { + case C64: + return SetValueInLiteralHelper>(loc, value, index, + literal); + case C128: + return 
SetValueInLiteralHelper>(loc, value, index, + literal); + default: + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " is not a complex type type"; + } +} + +template +string StringifyValue(T val) { + return StrCat(val); +} +template <> +string StringifyValue(std::complex val) { + return StrFormat("(%f, %f)", std::real(val), std::imag(val)); +} + template -bool HloParser::SetValueInLiteralHelper(ParsedElemT value, - tensorflow::int64 linear_index, +bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value, + LinearOrMultiIndex index, Literal* literal) { - // Check that linear_index is in range. - if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { - return TokenError( - StrCat("trys to set value ", value, " to a literal in shape ", - ShapeUtil::HumanString(literal->shape()), " at linear index ", - linear_index, ", but the index is out of range")); + if (!CheckParsedValueIsInRange(loc, value)) { + return false; } - if (std::isnan(value) || - (std::numeric_limits::has_infinity && - (std::numeric_limits::infinity() == value || - -std::numeric_limits::infinity() == value))) { - // Skip range checking for non-finite value. 
- } else if (literal->shape().element_type() == F16 || - literal->shape().element_type() == BF16) { - if (value > kF16max || value < -kF16max) { - return TokenError(StrCat( - "value ", value, " is out of range for literal's primitive type ", - PrimitiveType_Name(literal->shape().element_type()))); + // Check that the index is in range and assign into the literal + if (auto* linear_index = absl::get_if(&index)) { + if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) { + return Error(loc, StrCat("trys to set value ", StringifyValue(value), + " to a literal in shape ", + ShapeUtil::HumanString(literal->shape()), + " at linear index ", *linear_index, + ", but the index is out of range")); } - } else if (std::is_unsigned::value) { - CHECK((std::is_same::value || - std::is_same::value)) - << "Unimplemented checking for ParsedElemT"; - - ParsedElemT upper_bound; - if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { - upper_bound = std::numeric_limits::max(); - } else { - upper_bound = - static_cast(std::numeric_limits::max()); + literal->data().at(*linear_index) = + static_cast(value); + } else { + auto* multi_index = absl::get_if>(&index); + CHECK(multi_index != nullptr); + + auto invalid_idx = [&](string msg) { + return Error(loc, StrFormat("Invalid sparse index [%s]. %s", + absl::StrJoin(*multi_index, ", "), msg)); + }; + + const auto& shape = literal->shape(); + if (shape.rank() != multi_index->size()) { + return invalid_idx( + StrFormat("Has rank %d, but constant has shape %s, which has rank %d", + multi_index->size(), shape.ToString(), shape.rank())); } - if (value > upper_bound || value < 0) { - // Value is out of range for LiteralNativeT. 
- return TokenError(StrCat( - "value ", value, " is out of range for literal's primitive type ", - PrimitiveType_Name(literal->shape().element_type()))); - } - } else if (value > static_cast( - std::numeric_limits::max()) || - value < static_cast( - std::numeric_limits::lowest())) { - // Value is out of range for LiteralNativeT. - return TokenError(StrCat( - "value ", value, " is out of range for literal's primitive type ", - PrimitiveType_Name(literal->shape().element_type()))); + for (int64 i = 0; i < shape.rank(); ++i) { + auto idx = (*multi_index)[i]; + if (idx < 0) { + return invalid_idx(StrFormat( + "Sub-index value at %d, namely %d, cannot be negative.", i, idx)); + } + if (idx >= shape.dimensions(i)) { + return invalid_idx( + StrFormat("Sub-index at %d, namely %d, doesn't fit within shape " + "dimension %d in %s", + i, idx, shape.dimensions(i), shape.ToString())); + } + } + literal->AppendSparseElement(*multi_index, + static_cast(value)); } - - literal->data().at(linear_index) = - static_cast(value); return true; } @@ -1996,12 +2167,16 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { } bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { - const tensorflow::int64 rank = shape.rank(); + // Cast `rank` to int because we call shape.dimensions(int rank) below, and if + // `rank` is an int64, that's an implicit narrowing conversion, which is + // implementation-defined behavior. + const int rank = static_cast(shape.rank()); + // Create a literal with the given shape in default layout. *literal = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); - tensorflow::int64 nest_level = 0; - tensorflow::int64 linear_index = 0; + int64 nest_level = 0; + int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for // the dimension i. 
For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, // when we are parsing the 2nd '{' (right before '1'), we are seeing a @@ -2009,17 +2184,35 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { // the first '}' (right after '3'), it means the sub-array ends, and the // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. - std::vector elems_seen_per_dim(rank); + std::vector elems_seen_per_dim(rank); auto get_index_str = [&elems_seen_per_dim](int dim) -> string { - std::vector elems_seen_until_dim( - elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); + std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), + elems_seen_per_dim.begin() + dim); return StrCat("[", StrJoin(elems_seen_until_dim, ",", - [](string* out, const tensorflow::int64& num_elems) { + [](string* out, const int64& num_elems) { StrAppend(out, num_elems - 1); }), "]"); }; + + auto add_one_elem_seen = [&] { + if (rank > 0) { + if (nest_level != rank) { + return TokenError(absl::StrFormat( + "expects nested array in rank %d, but sees %d", rank, nest_level)); + } + elems_seen_per_dim[rank - 1]++; + if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { + return TokenError(absl::StrFormat( + "expects %d elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); + } + } + return true; + }; + do { switch (lexer_.GetKind()) { default: @@ -2055,6 +2248,31 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { lexer_.Lex(); break; } + case TokKind::kLparen: { + if (!primitive_util::IsComplexType(shape.element_type())) { + return TokenError( + absl::StrFormat("unexpected '(' in literal. 
Parens are only " + "valid for complex literals")); + } + + std::complex value; + LocTy loc = lexer_.GetLoc(); + if (!add_one_elem_seen() || !ParseComplex(&value) || + !SetValueInLiteral(loc, value, linear_index++, literal)) { + return false; + } + break; + } + case TokKind::kDots: { + if (nest_level != 1) { + return TokenError(absl::StrFormat( + "expects `...` at nest level 1, but sees it at nest level %d", + nest_level)); + } + elems_seen_per_dim[0] = shape.dimensions(0); + lexer_.Lex(); + break; + } case TokKind::kComma: // Skip. lexer_.Lex(); @@ -2066,23 +2284,11 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { case TokKind::kw_nan: case TokKind::kw_inf: case TokKind::kNegInf: { - if (rank > 0) { - if (nest_level != rank) { - return TokenError( - absl::StrFormat("expects nested array in rank %d, but sees %d", - rank, nest_level)); - } - elems_seen_per_dim[rank - 1]++; - if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { - return TokenError(absl::StrFormat( - "expects %d elements on the minor-most dimension, but " - "sees more", - shape.dimensions(rank - 1))); - } - } + add_one_elem_seen(); if (lexer_.GetKind() == TokKind::kw_true || lexer_.GetKind() == TokKind::kw_false) { - if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, + if (!SetValueInLiteral(lexer_.GetLoc(), + lexer_.GetKind() == TokKind::kw_true, linear_index++, literal)) { return false; } @@ -2090,12 +2296,12 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } else if (primitive_util::IsIntegralType(shape.element_type()) || shape.element_type() == PRED) { LocTy loc = lexer_.GetLoc(); - tensorflow::int64 value; + int64 value; if (!ParseInt64(&value)) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal)) { + if (!SetValueInLiteral(loc, value, linear_index++, literal)) { return false; } } else if 
(primitive_util::IsFloatingPointType(shape.element_type())) { @@ -2106,7 +2312,7 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal)) { + if (!SetValueInLiteral(loc, value, linear_index++, literal)) { return false; } } else { @@ -2123,48 +2329,7 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { - switch (shape.element_type()) { - case PRED: - return ParseSparseLiteralHelper(literal, shape); - case S8: - return ParseSparseLiteralHelper(literal, shape); - case S16: - return ParseSparseLiteralHelper(literal, shape); - case S32: - return ParseSparseLiteralHelper(literal, shape); - case S64: - return ParseSparseLiteralHelper(literal, shape); - case U8: - return ParseSparseLiteralHelper(literal, shape); - case U16: - return ParseSparseLiteralHelper(literal, shape); - case U32: - return ParseSparseLiteralHelper(literal, shape); - case U64: - return ParseSparseLiteralHelper(literal, shape); - case F16: - return ParseSparseLiteralHelper(literal, shape); - case F32: - return ParseSparseLiteralHelper(literal, shape); - case BF16: - return ParseSparseLiteralHelper(literal, shape); - case F64: - return ParseSparseLiteralHelper(literal, shape); - default: - return Error(lexer_.GetLoc(), - StrCat("invalid primitive type for sparse literal: ", - PrimitiveType_Name(shape.element_type()))); - } -} - -template -bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { - std::vector index; - - tensorflow::int64 rank = shape.rank(); - *literal = Literal(shape); - if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { return false; @@ -2176,61 +2341,66 @@ bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { 
break; } - LocTy index_loc = lexer_.GetLoc(); - index.clear(); + std::vector index; if (lexer_.GetKind() == TokKind::kInt) { - tensorflow::int64 single_index = lexer_.GetInt64Val(); + int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); - if (rank != 1) { - return Error( - index_loc, - StrCat("invalid single-dimensional index for shape with rank ", - rank, ": ", single_index)); - } index.push_back(single_index); } else { if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, &index)) { return false; } - if (index.size() != rank) { - return Error( - index_loc, - StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", StrJoin(index, ", "), "]")); - } } if (!ParseToken(TokKind::kColon, "expects ':' after after the sparse array index and before " "the sparse array value")) { return false; } + LocTy value_loc = lexer_.GetLoc(); - LiteralNativeT value; if (lexer_.GetKind() == TokKind::kw_true || lexer_.GetKind() == TokKind::kw_false) { - value = static_cast(lexer_.GetKind() == TokKind::kw_true); + bool value = lexer_.GetKind() == TokKind::kw_true; + if (!SetValueInLiteral(lexer_.GetLoc(), value, index, literal)) { + return false; + } lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - tensorflow::int64 value_s64; - if (!ParseInt64(&value_s64)) { + int64 value; + if (!ParseInt64(&value)) { return Error(value_loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - value = static_cast(value_s64); + if (!SetValueInLiteral(value_loc, value, index, literal)) { + return false; + } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { - double value_f64; - if (!ParseDouble(&value_f64)) { + double value; + if (!ParseDouble(&value)) { return Error(value_loc, StrCat("expects floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - value = static_cast(value_f64); + if (!SetValueInLiteral(value_loc, value, index, 
literal)) { + return false; + } + } else if (primitive_util::IsComplexType(shape.element_type())) { + std::complex value; + if (!ParseComplex(&value)) { + return Error(value_loc, + StrCat("expects complex value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + if (!SetValueInLiteral(value_loc, value, index, literal)) { + return false; + } } else { LOG(FATAL) << "Unexpected element type: " << PrimitiveType_Name(shape.element_type()); } + if (lexer_.GetKind() != TokKind::kRbrace && !ParseToken(TokKind::kComma, "expects ',' separator between sparse array elements")) { @@ -2244,14 +2414,114 @@ bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { StrCat("number of sparse elements exceeds maximum for layout: ", ShapeUtil::HumanStringWithLayout(shape))); } - - literal->AppendSparseElement(index, value); } literal->SortSparseElements(); return true; } +// MaxFiniteValue is a type-traits helper used by +// HloParser::CheckParsedValueIsInRange. +template +struct MinMaxFiniteValue { + static T max() { return std::numeric_limits::max(); } + static T min() { return std::numeric_limits::lowest(); } +}; + +template <> +struct MinMaxFiniteValue { + static double max() { + // Sadly this is not constexpr, so this forces `value` to be a method. + return static_cast(Eigen::NumTraits::highest()); + } + static double min() { return -max(); } +}; + +template <> +struct MinMaxFiniteValue { + static double max() { return static_cast(bfloat16::highest()); } + static double min() { return -max(); } +}; + +template +bool HloParser::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { + PrimitiveType literal_ty = + primitive_util::NativeToPrimitiveType(); + if (std::isnan(value) || + (std::numeric_limits::has_infinity && + (std::numeric_limits::infinity() == value || + -std::numeric_limits::infinity() == value))) { + // Skip range checking for non-finite value. 
+ } else if (std::is_unsigned::value) { + CHECK((std::is_same::value || + std::is_same::value)) + << "Unimplemented checking for ParsedElemT"; + + ParsedElemT upper_bound; + if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { + upper_bound = std::numeric_limits::max(); + } else { + upper_bound = + static_cast(std::numeric_limits::max()); + } + if (value > upper_bound || value < 0) { + // Value is out of range for LiteralNativeT. + return Error(loc, StrCat("value ", value, + " is out of range for literal's primitive type ", + PrimitiveType_Name(literal_ty), " namely [0, ", + upper_bound, "].")); + } + } else if (value > MinMaxFiniteValue::max() || + value < MinMaxFiniteValue::min()) { + // Value is out of range for LiteralNativeT. + return Error(loc, StrCat("value ", value, + " is out of range for literal's primitive type ", + PrimitiveType_Name(literal_ty), " namely [", + MinMaxFiniteValue::min(), ", ", + MinMaxFiniteValue::max(), "].")); + } + return true; +} + +template +bool HloParser::CheckParsedValueIsInRange(LocTy loc, + std::complex value) { + // e.g. `float` for std::complex + using LiteralComplexComponentT = + decltype(std::real(std::declval())); + + // We could do simply + // + // return CheckParsedValueIsInRange(std::real(value)) && + // CheckParsedValueIsInRange(std::imag(value)); + // + // but this would give bad error messages on failure. + + auto check_component = [&](absl::string_view name, double v) { + if (std::isnan(v) || v == std::numeric_limits::infinity() || + v == -std::numeric_limits::infinity()) { + // Skip range-checking for non-finite values. + return true; + } + + double min = MinMaxFiniteValue::min(); + double max = MinMaxFiniteValue::max(); + if (v < min || v > max) { + // Value is out of range for LitearlComplexComponentT. 
+ return Error( + loc, + StrCat(name, " part ", v, + " is out of range for literal's primitive type ", + PrimitiveType_Name( + primitive_util::NativeToPrimitiveType()), + ", namely [", min, ", ", max, "].")); + } + return true; + }; + return check_component("real", std::real(value)) && + check_component("imaginary", std::imag(value)); +} + // operands ::= '(' operands1 ')' // operands1 // ::= /*empty*/ @@ -2409,24 +2679,23 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kInt64: { - tensorflow::int64 result; + int64 result; if (!ParseInt64(&result)) { return false; } - static_cast*>(attr_out_ptr) - ->emplace(result); + static_cast*>(attr_out_ptr)->emplace(result); return true; } case AttrTy::kInt32: { - tensorflow::int64 result; + int64 result; if (!ParseInt64(&result)) { return false; } - if (result != static_cast(result)) { + if (result != static_cast(result)) { return Error(attr_loc, "value out of range for int32"); } - static_cast*>(attr_out_ptr) - ->emplace(static_cast(result)); + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); return true; } case AttrTy::kFloat: { @@ -2444,20 +2713,21 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kHloComputation: { HloComputation* result = nullptr; - if (lexer_.GetKind() == TokKind::kLbrace) { - // This means it is a nested computation. - if (!ParseInstructionList(&result, /*computation_name=*/"_")) { - return false; - } - } else { - // This means it is a computation name. 
- if (!ParseComputationName(&result)) { - return false; - } + if (!ParseHloComputation(&result)) { + return false; } static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kBracedHloComputationList: { + std::vector result; + if (!ParseHloComputationList(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kFftType: { FftType result; if (!ParseFftType(&result)) { @@ -2466,6 +2736,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kComparisonDirection: { + ComparisonDirection result; + if (!ParseComparisonDirection(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kWindow: { Window result; if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { @@ -2491,6 +2770,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(sharding); return true; } + case AttrTy::kParameterReplication: { + ParameterReplication parameter_replication; + if (!ParseParameterReplication(¶meter_replication)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(parameter_replication); + return true; + } case AttrTy::kInstructionList: { std::vector result; if (!ParseInstructionNames(&result)) { @@ -2510,19 +2798,19 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kBracedInt64List: { - std::vector result; + std::vector result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &result)) { return false; } - static_cast>*>(attr_out_ptr) + static_cast>*>(attr_out_ptr) ->emplace(result); return true; } case AttrTy::kBracedInt64ListList: { - std::vector> result; + std::vector> result; auto parse_and_add_item = [&]() { - std::vector item; + std::vector item; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &item)) { return false; @@ -2534,8 +2822,7 @@ bool 
HloParser::ParseAttributeHelper( parse_and_add_item)) { return false; } - static_cast>>*>( - attr_out_ptr) + static_cast>>*>(attr_out_ptr) ->emplace(result); return true; } @@ -2610,6 +2897,95 @@ bool HloParser::ParseAttributeHelper( return true; } +// attributes ::= (',' attribute)* +bool HloParser::ParseAttributesAsProtoMessage( + const std::unordered_set& required_attrs, + tensorflow::protobuf::Message* message) { + LocTy loc = lexer_.GetLoc(); + std::unordered_set seen_attrs; + while (EatIfPresent(TokKind::kComma)) { + if (!ParseAttributeAsProtoMessageHelper(message, &seen_attrs)) { + return false; + } + } + // Check that all required attrs were seen. + for (const string& attr : required_attrs) { + if (seen_attrs.find(attr) == seen_attrs.end()) { + return Error(loc, + StrFormat("attribute %s is expected but not seen", attr)); + } + } + return true; +} + +bool HloParser::ParseAttributeAsProtoMessageHelper( + tensorflow::protobuf::Message* message, + std::unordered_set* seen_attrs) { + LocTy loc = lexer_.GetLoc(); + string name; + if (!ParseAttributeName(&name)) { + return Error(loc, "error parsing attributes"); + } + VLOG(1) << "Parsing attribute " << name; + if (!seen_attrs->insert(name).second) { + return Error(loc, StrFormat("attribute %s already exists", name)); + } + const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor(); + const tensorflow::protobuf::FieldDescriptor* fd = + descriptor->FindFieldByName(name); + if (!fd) { + string allowed_attrs = "Allowed attributes: "; + + for (int i = 0; i < descriptor->field_count(); ++i) { + if (i == 0) { + absl::StrAppend(&allowed_attrs, descriptor->field(i)->name()); + } else { + absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name()); + } + } + return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name, + allowed_attrs)); + } + const tensorflow::protobuf::Reflection* reflection = message->GetReflection(); + CHECK(!fd->is_repeated()); // Repeated fields not implemented. 
+ bool success = [&] { + switch (fd->type()) { + case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: { + bool result; + if (!ParseBool(&result)) { + return false; + } + reflection->SetBool(message, fd, result); + return true; + } + case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: { + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError( + StrFormat("expects %s type", fd->enum_type()->name())); + } + string val = lexer_.GetStrVal(); + const tensorflow::protobuf::EnumValueDescriptor* evd = + fd->enum_type()->FindValueByName(val); + if (evd == nullptr) { + return TokenError(StrFormat("expects %s type but sees: %s", + fd->enum_type()->name(), val)); + } + reflection->SetEnum(message, fd, evd); + lexer_.Lex(); + return true; + } + default: + LOG(ERROR) << "Unimplemented protocol buffer type " + << fd->DebugString(); + return false; + } + }(); + if (!success) { + return Error(loc, StrFormat("error parsing attribute %s", name)); + } + return true; +} + bool HloParser::ParseComputationName(HloComputation** value) { string name; LocTy loc = lexer_.GetLoc(); @@ -2736,7 +3112,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( absl::string_view rhs = split2[0]; absl::string_view out = split2[1]; - const tensorflow::int64 rank = lhs.length(); + const int64 rank = lhs.length(); if (rank != rhs.length() || rank != out.length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); @@ -2847,7 +3223,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } - std::vector> ranges; + std::vector> ranges; if (lexer_.GetKind() == TokKind::kRbrace) { // empty return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); @@ -2894,6 +3270,29 @@ bool HloParser::ParsePrecisionList( parse_and_add_item); } +bool HloParser::ParseHloComputation(HloComputation** result) { + if (lexer_.GetKind() == TokKind::kLbrace) { + // This means it is a nested 
computation. + return ParseInstructionList(result, /*computation_name=*/"_"); + } + // This means it is a computation name. + return ParseComputationName(result); +} + +bool HloParser::ParseHloComputationList(std::vector* result) { + auto parse_and_add_item = [&]() { + HloComputation* computation; + if (!ParseHloComputation(&computation)) { + return false; + } + LOG(INFO) << "parsed computation " << computation->name(); + result->push_back(computation); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // shapelist ::= '{' shapes '}' // precision_elements // ::= /*empty*/ @@ -2917,9 +3316,9 @@ bool HloParser::ParseShapeList(std::vector* result) { // ::= int64_val (delim int64_val)* bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, - std::vector* result) { + std::vector* result) { auto parse_and_add_item = [&]() { - tensorflow::int64 i; + int64 i; if (!ParseInt64(&i)) { return false; } @@ -2995,7 +3394,7 @@ bool HloParser::ParseParamList() { bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes, std::vector* dynamic_dimensions) { auto parse_and_add_item = [&]() { - tensorflow::int64 i; + int64 i; bool is_dynamic = false; if (lexer_.GetKind() == TokKind::kLeq) { is_dynamic = true; @@ -3012,22 +3411,108 @@ bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes, parse_and_add_item); } -// layout ::= '{' int64_list '}' +// tiles +// ::= /*empty*/ +// ::= 'T' '(' dim_list ')' +// dim_list +// ::= /*empty*/ +// ::= (int64 | '*') (',' (int64 | '*'))* +bool HloParser::ParseTiles(std::vector* tiles) { + auto parse_and_add_tile_dimension = [&]() { + tensorflow::int64 i; + if (ParseInt64(&i)) { + tiles->back().add_dimensions(i); + return true; + } + if (lexer_.GetKind() == TokKind::kAsterisk) { + tiles->back().add_dimensions(Tile::kCombineDimension); + lexer_.Lex(); + return true; + } + return false; + }; + + do { + tiles->push_back(Tile()); + if 
(!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma, + parse_and_add_tile_dimension)) { + return false; + } + } while (lexer_.GetKind() == TokKind::kLparen); + return true; +} + +// layout ::= '{' int64_list (':' tiles element_size_in_bits)? '}' +// element_size_in_bits +// ::= /*empty*/ +// ::= 'E' '(' int64 ')' bool HloParser::ParseLayout(Layout* layout) { std::vector minor_to_major; + std::vector tiles; + tensorflow::int64 element_size_in_bits = 0; + auto parse_and_add_item = [&]() { - tensorflow::int64 i; + int64 i; if (!ParseInt64(&i)) { return false; } minor_to_major.push_back(i); return true; }; - if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, - parse_and_add_item)) { + + if (!ParseToken(TokKind::kLbrace, + StrCat("expects layout to start with ", + TokKindToString(TokKind::kLbrace)))) { return false; } - *layout = LayoutUtil::MakeLayout(minor_to_major); + if (lexer_.GetKind() != TokKind::kRbrace) { + if (lexer_.GetKind() == TokKind::kInt) { + // Parse minor to major. + do { + if (!parse_and_add_item()) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + + if (lexer_.GetKind() == TokKind::kColon) { + lexer_.Lex(); + if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") { + lexer_.Lex(); + ParseTiles(&tiles); + } + + if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") { + // Parse element size in bits. 
+ lexer_.Lex(); + if (!ParseToken(TokKind::kLparen, + StrCat("expects element size in bits to start with ", + TokKindToString(TokKind::kLparen)))) { + return false; + } + if (!ParseInt64(&element_size_in_bits)) { + return false; + } + if (!ParseToken(TokKind::kRparen, + StrCat("expects element size in bits to end with ", + TokKindToString(TokKind::kRparen)))) { + return false; + } + } + } + } + if (!ParseToken(TokKind::kRbrace, + StrCat("expects layout to end with ", + TokKindToString(TokKind::kRbrace)))) { + return false; + } + + std::vector vec_tiles(tiles.size()); + for (int i = 0; i < tiles.size(); i++) { + vec_tiles[i] = Tile(tiles[i]); + } + *layout = + LayoutUtil::MakeLayout(minor_to_major, vec_tiles, element_size_in_bits); return true; } @@ -3079,7 +3564,7 @@ bool HloParser::ParseShape(Shape* result) { lexer_.Lex(); const string message = "expects a brace-bracketed integer for sparse layout"; - tensorflow::int64 max_sparse_elements; + int64 max_sparse_elements; if (!ParseToken(TokKind::kLbrace, message) || !ParseInt64(&max_sparse_elements) || !ParseToken(TokKind::kRbrace, message)) { @@ -3099,13 +3584,20 @@ bool HloParser::ParseShape(Shape* result) { // // The open brace could either be the start of a computation or the start of a // layout for the f32[123] shape. We consider it the start of a layout if the - // next token after the open brace is a integer + // next token after the open brace is an integer or a colon. 
if (lexer_.GetKind() == TokKind::kLbrace && - lexer_.LookAhead() == TokKind::kInt) { + (lexer_.LookAhead() == TokKind::kInt || + lexer_.LookAhead() == TokKind::kColon)) { Layout layout; if (!ParseLayout(&layout)) { return false; } + if (layout.minor_to_major_size() != result->rank()) { + return Error( + lexer_.GetLoc(), + StrFormat("Dimensions size is %ld, but minor to major size is %ld.", + result->rank(), layout.minor_to_major_size())); + } *result->mutable_layout() = layout; } return true; @@ -3148,15 +3640,14 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, - std::vector* result) { +bool HloParser::ParseDxD(const string& name, std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); } // 1D if (lexer_.GetKind() == TokKind::kInt) { - tensorflow::int64 number; + int64 number; if (!ParseInt64(&number)) { return Error(loc, StrFormat("expects sub-attribute '%s=i'", name)); } @@ -3175,8 +3666,7 @@ bool HloParser::ParseDxD(const string& name, return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad( - std::vector>* pad) { +bool HloParser::ParseWindowPad(std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -3186,7 +3676,7 @@ bool HloParser::ParseWindowPad( } string str = lexer_.GetStrVal(); for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { - std::vector low_high; + std::vector low_high; if (!SplitToInt64s(padding_dim_str, '_', &low_high) || low_high.size() != 2) { return Error(loc, @@ -3209,7 +3699,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { - std::vector padding_dim; + std::vector padding_dim; if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || 
(padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, @@ -3231,7 +3721,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { optional op_type; optional op_name; optional source_file; - optional source_line; + optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -3283,6 +3773,22 @@ bool HloParser::ParseFftType(FftType* result) { return true; } +bool HloParser::ParseComparisonDirection(ComparisonDirection* result) { + VLOG(1) << "ParseComparisonDirection"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects comparison direction"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToComparisonDirection(val); + if (!status_or_result.ok()) { + return TokenError( + StrFormat("expects comparison direction but sees: %s", val)); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(1) << "ParseFusionKind"; if (lexer_.GetKind() != TokKind::kIdent) { @@ -3334,7 +3840,7 @@ bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { return true; } -bool HloParser::ParseInt64(tensorflow::int64* result) { +bool HloParser::ParseInt64(int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -3346,9 +3852,18 @@ bool HloParser::ParseInt64(tensorflow::int64* result) { bool HloParser::ParseDouble(double* result) { switch (lexer_.GetKind()) { - case TokKind::kDecimal: - *result = lexer_.GetDecimalVal(); + case TokKind::kDecimal: { + double val = lexer_.GetDecimalVal(); + // If GetDecimalVal returns +/-inf, that means that we overflowed + // `double`. 
+ if (std::isinf(val)) { + return TokenError(StrCat("Constant is out of range for double (+/-", + std::numeric_limits::max(), + ") and so is unparsable.")); + } + *result = val; break; + } case TokKind::kInt: *result = static_cast(lexer_.GetInt64Val()); break; @@ -3368,6 +3883,42 @@ bool HloParser::ParseDouble(double* result) { return true; } +bool HloParser::ParseComplex(std::complex* result) { + if (lexer_.GetKind() != TokKind::kLparen) { + return TokenError("expects '(' before complex number"); + } + lexer_.Lex(); + + double real; + LocTy loc = lexer_.GetLoc(); + if (!ParseDouble(&real)) { + return Error(loc, + "expect floating-point value for real part of complex number"); + } + + if (lexer_.GetKind() != TokKind::kComma) { + return TokenError( + absl::StrFormat("expect comma after real part of complex literal")); + } + lexer_.Lex(); + + double imag; + loc = lexer_.GetLoc(); + if (!ParseDouble(&imag)) { + return Error( + loc, + "expect floating-point value for imaginary part of complex number"); + } + + if (lexer_.GetKind() != TokKind::kRparen) { + return TokenError(absl::StrFormat("expect ')' after complex number")); + } + + *result = std::complex(real, imag); + lexer_.Lex(); + return true; +} + bool HloParser::ParseBool(bool* result) { if (lexer_.GetKind() != TokKind::kw_true && lexer_.GetKind() != TokKind::kw_false) { @@ -3441,6 +3992,21 @@ StatusOr HloParser::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } +StatusOr> HloParser::ParseParameterReplicationOnly() { + lexer_.Lex(); + ParameterReplication parameter_replication; + if (!ParseParameterReplication(¶meter_replication)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after parameter replication"); + } + return std::vector( + parameter_replication.replicated_at_leaf_buffers().begin(), + parameter_replication.replicated_at_leaf_buffers().end()); +} + StatusOr 
HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; @@ -3556,6 +4122,11 @@ StatusOr ParseSharding(absl::string_view str) { return parser.ParseShardingOnly(); } +StatusOr> ParseParameterReplication(absl::string_view str) { + HloParser parser(str); + return parser.ParseParameterReplicationOnly(); +} + StatusOr ParseWindow(absl::string_view str) { HloParser parser(str); return parser.ParseWindowOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 450a54c54c156c2ae27475d145a8e83dc841b431..a96260b4d75e515a4cb23d315444142cae1b9587 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -44,11 +44,16 @@ Status ParseHloString(absl::string_view str, HloModule* module); // creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); -// ParseHloString sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string, -// e.g., "{replicated}". +// Parses sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., +// "{replicated}". StatusOr ParseSharding(absl::string_view str); +// Parses parameter replication from str. str is supposed to contain the body of +// the parameter replication, i.e. just the rhs of the +// "parameter_replication={...}" attribute string, e.g., "{true, false}". +StatusOr> ParseParameterReplication(absl::string_view str); + // Parses the result of window_util::ToString(const Window&). 
StatusOr ParseWindow(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 6ba16cc82ac1da2a30610d9dfb56cacc100ae05f..6f4171bca82b1c287cfe1fef16e95f96215ac702 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -63,6 +63,19 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } +)" +}, +// parameter replication +{ +"ParamReplication", +R"(HloModule param_replication_module + +ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) { + %a = f32[] parameter(0), parameter_replication={true} + %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true} + ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b) +} + )" }, // pred constant @@ -209,7 +222,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} - %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated} + %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated} ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={} } @@ -279,7 +292,7 @@ R"(HloModule WhileWithScalarS32Result_module %condition.v3 (prev.2: s32[]) -> pred[] { %constant.1 = s32[] constant(5) %prev.2 = s32[] parameter(0) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT } ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { @@ -461,7 
+474,7 @@ R"(HloModule R4F32OverlapSmall_module %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) %rhs = f32[] parameter(1) - ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) + ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE } %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { @@ -487,7 +500,7 @@ R"(HloModule select_and_scatter_scalar %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) %rhs = f32[] parameter(1) - ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) + ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE } %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { @@ -562,6 +575,19 @@ ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] { ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2} } +)" +}, +// Triangular solve +{ +"TriangularSolve", +R"(HloModule TriangularSolve_module + +ENTRY %SimpleRightLowerNotranspose.4 (a.1: f32[4,4], b.2: f32[3,4]) -> f32[3,4] { + %a.1 = f32[4,4]{1,0} parameter(0) + %b.2 = f32[3,4]{1,0} parameter(1) + ROOT %triangular-solve.3 = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a.1, f32[3,4]{1,0} %b.2), lower=true, transpose_a=NO_TRANSPOSE +} + )" }, // Dynamic slice @@ -782,7 +808,17 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { R"(HloModule sparse_f32 ENTRY %sparse () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) + ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3}) +} + +)" +}, +{ +"SparseC128", +R"(HloModule sparse_c128 + +ENTRY %sparse () -> c128[2,3,4] { + ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)}) } )" @@ -894,6 +930,58 @@ ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123 ROOT %custom-call = (f32[1,2,3]{0,2,1}, 
f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}} } +)" +}, +// Parse c64 literal +{ +"ParseC64Literal", +R"(HloModule ParseC64Literal + +ENTRY %ParseC64Literal () -> c64[2] { + ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)}) +} + +)" +}, +// Parse c128 literal +{ +"ParseC128Literal", +R"(HloModule ParseC128Literal + +ENTRY %ParseC128Literal () -> c128[2] { + ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)}) +} + +)" +}, +// Indexed Conditional +{ +"IndexedConditional", +R"(HloModule indexed_conditional + +%Negate (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + ROOT %negate = f32[] negate(f32[] %x) +} + +%Identity (y: f32[]) -> f32[] { + %y = f32[] parameter(0) + ROOT %copy = f32[] copy(f32[] %y) +} + +%Floor (z: f32[]) -> f32[] { + %z = f32[] parameter(0) + ROOT %floor = f32[] floor(f32[] %z) +} + +ENTRY %Parameters1.v4 () -> f32[] { + %constant = s32[] constant(1) + %constant.1 = f32[] constant(56) + %constant.2 = f32[] constant(12) + %constant.3 = f32[] constant(13) + ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor} +} + )" }, }); @@ -949,7 +1037,7 @@ R"(HloModule TupleReduce max_argmax { value = f32[] parameter(2) prev_max = f32[] parameter(0) - is_next_larger = pred[] greater-than-or-equal-to(value, prev_max) + is_next_larger = pred[] compare(value, prev_max), direction=GE max = f32[] select(is_next_larger, value, prev_max) index = s32[] parameter(3) prev_argmax = s32[] parameter(1) @@ -1015,9 +1103,15 @@ ENTRY ReducePrecision { "SortKey", R"(HloModule sort +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY Sort { x = f32[1024]{0} parameter(0) - ROOT sorted = f32[1024]{0} sort(x), dimensions={0} + 
ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare } )" @@ -1027,10 +1121,18 @@ ENTRY Sort { "SortKeyValue", R"(HloModule sort +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY Sort { keys = f32[1024]{0} parameter(0) values = s32[1024]{0} parameter(1) - ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare } )" @@ -1040,9 +1142,15 @@ ENTRY Sort { "SortKeyR2", R"(HloModule sort +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY Sort { x = f32[1024,16]{0,1} parameter(0) - ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0} + ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare } )" @@ -1052,10 +1160,18 @@ ENTRY Sort { "SortKeyValueR2", R"(HloModule sort +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY Sort { keys = f32[1024,16]{0,1} parameter(0) values = s32[1024,16]{0,1} parameter(1) - ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0} + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare } )" @@ -1065,20 +1181,80 @@ ENTRY Sort { "SortManyValues", R"(HloModule sort +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.2.lhs = u32[] parameter(4) + p.2.rhs = u32[] parameter(5) + p.3.lhs = f32[] parameter(6) + p.3.rhs = f32[] parameter(7) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY Sort { keys = f32[1024,16]{0,1} 
parameter(0) values.0 = s32[1024,16]{0,1} parameter(1) values.1 = u32[1024,16]{0,1} parameter(2) values.2 = f32[1024,16]{0,1} parameter(3) - ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0} + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare +} + +)" +}, +// Sort (Key) is_stable=true +{ +"SortKeyStable", +R"(HloModule sort + +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare } )" }, -// Conditional +// Indexed Conditional { -"Conditional", -R"(HloModule conditional +"IndexedConditional", +R"(HloModule indexed_conditional + +Negate { + x = f32[] parameter(0) + ROOT negate = f32[] negate(x) +} + +Identity { + y = f32[] parameter(0) + ROOT copy = f32[] copy(y) +} + +Floor { + z = f32[] parameter(0) + ROOT floor = f32[] floor(z) +} + +ENTRY Parameters1.v4 { + constant = s32[] constant(1) + constant.1 = f32[] constant(56) + constant.2 = f32[] constant(12) + constant.3 = f32[] constant(13) + ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor} +} + +)" +}, +// Predicated Conditional +{ +"PredicatedConditional", +R"(HloModule pred_conditional Negate { x = f32[] parameter(0) @@ -1248,6 +1424,17 @@ ENTRY CollectivePermute { ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } +)" +}, +// replica-id +{ +"ReplicaId", +R"(HloModule replica-id + +ENTRY Replica-id { + ROOT replica-id = u32[] replica-id() +} + )" }, // Iota @@ -1277,10 +1464,18 @@ ENTRY Computation { "ScheduledModule", R"(HloModule scheduled_module, is_scheduled=true 
+compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY Sort { keys = f32[1024]{0} parameter(0) values = s32[1024]{0} parameter(1) - ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare } )" @@ -1314,6 +1509,30 @@ ENTRY AddDependency { ROOT sum = f32[] add(neg, exp) } +)" +}, + +// A module containing constants equal to the min/max values of various data +// types. +{ +"MinMaxValues", +R"(HloModule MinMaxValues + +ENTRY MinMaxValues { + x.s8 = s8[2]{0} constant({-128, 127}) + x.s16 = s16[2]{0} constant({-32768, 32767}) + x.s32 = s32[2]{0} constant({-2147483648, 2147483647}) + x.u8 = u8[2]{0} constant({0, 255}) + x.u16 = u16[2]{0} constant({0, 65535}) + x.u32 = u32[2]{0} constant({0, 4294967295}) + x.f16 = f16[2]{0} constant({-65504, 65504}) + x.bf16 = bf16[2]{0} constant({-3.38953e+38, 3.38953e+38}) + x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38}) + x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308}) + x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)}) + ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)}) +} + )" }, }); @@ -1340,7 +1559,7 @@ class HloParameterizedParserTest protected: // Expects "ToString(ParseHloString(string)) == string", that is, parses the // string, asserts that it succeeded, stringifies the parsed module, and - // checks that the it equals the original string. + // checks that it equals the original string. 
void ExpectEqual() { const string& original = GetParam().module_string; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -1437,7 +1656,7 @@ TEST_F(HloParserTest, WrongOperandsSize) { ENTRY %blabla (x: f32[]) -> pred[] { %x = f32[]{} parameter(0) - %eq = pred[]{} equal-to(f32[]{} %x) + %eq = pred[]{} compare(f32[]{} %x), direction=EQ } )"; @@ -1449,7 +1668,7 @@ TEST_F(HloParserTest, OperandNotFound) { const string original = R"(HloModule operand_not_found: ENTRY %blabla (x: f32[]) -> pred[] { %x = f32[]{} parameter(0) - %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) + %eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ } )"; auto result = ParseHloString(original); @@ -1543,6 +1762,37 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { "is out of range for literal's primitive type F16"); } +TEST_F(HloParserTest, ConstantBf16NoOverflow) { + // 65505 is in range for bf16. + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = bf16[] constant(-65505) + })"; + EXPECT_EQ(Status::OK(), ParseHloString(original).status()); +} + +TEST_F(HloParserTest, ConstantBf16Overflow) { + // 1e100 is out of range for bf16. 
+ const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = bf16[] constant(1e100) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "out of range"); +} + +TEST_F(HloParserTest, ConstantF16OverflowInSparseArray) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({[0]: 0, [1]: -65505}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "is out of range for literal's primitive type F16"); +} + TEST_F(HloParserTest, ConstantUnsignedUnderflow) { const string original = R"( HloModule ConstantUnsignedUnderflow_module @@ -1577,6 +1827,46 @@ TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { EXPECT_NE(Status::OK(), result.status()); } +TEST_F(HloParserTest, ConstantC64Overflow) { + const string original = R"( + HloModule test_module + ENTRY test () -> c64[] { + ROOT c = c64[] constant((1e100, 0)) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + +TEST_F(HloParserTest, ConstantC64Underflow) { + const string original = R"( + HloModule test_module + ENTRY test () -> c64[] { + ROOT c = c64[] constant((0, -1e100)) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + +TEST_F(HloParserTest, ConstantF64Overflow) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f64[] constant(1.8e308) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + +TEST_F(HloParserTest, ConstantF64Underflow) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f64[] constant(-1.8e308) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + TEST_F(HloParserTest, ConstantWithExp) { const string original = R"(HloModule ConstantWithExp_module @@ -1592,6 +1882,19 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { // printed as "300". 
} +TEST_F(HloParserTest, ShortConstant) { + const string original = R"(HloModule ShortCOnstant_module + +ENTRY %ShortConstant.v4 () -> f32[67,89] { + ROOT %constant.1 = f32[67,89]{1,0} constant({...}) +} + +)"; + auto result = ParseHloString(original); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); +} + TEST_F(HloParserTest, AttibutesAnyOrder) { const string original = R"(HloModule any_order_module @@ -2074,6 +2377,31 @@ TEST(HloParserSingleOpTest, CanonicalOpWithNested) { text); } +TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) { + const string text = + R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={ +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0) +}, +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0) +}, +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0) +} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + TEST(HloParserSingleOpTest, SingleOpWithNested) { const string text = R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls= @@ -2291,6 +2619,46 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { " with the shape of the operand instruction f32[2,2]{1,0}."); } +TEST_F(HloParserTest, OutOfRangeSparseIndex) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({[100]: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, NegativeSparseIndex) { + const string original = R"( + 
HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({-1: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, SparseIndexWithRankTooLarge) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({[0, 0]: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, SparseIndexWithRankTooSmall) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5, 5]sparse{10} constant({[0]: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + TEST_F(HloParserTest, ParseShapeStringR2F32) { string shape_string = "f32[123,456]"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); @@ -2335,6 +2703,60 @@ TEST_F(HloParserTest, ParseShapeStringWithLayout) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) { + // One tile. + string shape_string = "f32[123,456]{0,1:T(2,128)}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {Tile({2, 128})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Tile with negative dimension size for combining dimensions. + shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = + ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2}, + {Tile({2, Tile::kCombineDimension, 128})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Two tiles. 
+ shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithLayout( + BF16, {123, 456, 789}, {2, 1, 0}, + {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Tile with element size in bits. + shape_string = "pred[123,456]{1,0:T(2,128)E(1)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, + {Tile({2, 128})}, 1); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Element size in bits without tile. + shape_string = "pred[123,456]{1,0:E(1)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Wrong minor_to_major. 
+ shape_string = "f32[123,456,789]{1:T(2, * , 128)}"; + auto result = ParseShape(shape_string); + ExpectHasSubstr(result.status().error_message(), + "Dimensions size is 3, but minor to major size is 1."); +} + TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) { string shape_string = "f32[123,456]sparse{10}"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); @@ -2391,5 +2813,24 @@ TEST_F(HloParserTest, ParseDynamicTuple) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST_F(HloParserTest, NegativeParameterNumber) { + const string hlo_string = "par0 = f32[3,5] parameter(-1)"; + auto result = ParseHloString(hlo_string); + ASSERT_FALSE(result.status().ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("parameter number must be >= 0")); +} + +TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) { + const string hlo_string = + "par0 = (f32[3,5], f32[]) parameter(0), " + "parameter_replication={true,false,true}"; + auto result = ParseHloString(hlo_string); + ASSERT_FALSE(result.status().ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("parameter has 2 leaf buffers, but " + "parameter_replication has 3 elements")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index 791b1a97b0b82edf19ff1588fd8d5d996ac0fef4..35dc9c0029f9871334cb500c6b71f0c86ab136d7 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -39,9 +40,36 @@ class HloPassFix : public Pass { int64 iteration_count = 0; int64 limit = std::max(static_cast(1000), module->instruction_count()); + VLOG(3) << "Running HloPassFix."; while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; + VLOG(3) << "changed_this_iteration: " << changed_this_iteration; + ++iteration_count; + if (iteration_count == limit) { + LOG(ERROR) + << "Unexpectedly high number of iterations in HLO passes (" + << iteration_count + << ")\nIf compilation hangs here, please file a bug with XLA."; + } + } + return changed; + } + + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + bool changed_this_iteration = true; + int64 iteration_count = 0; + int64 limit = 1000; + for (const HloModule* module : module_group->modules()) { + limit = std::max(limit, module->instruction_count()); + } + VLOG(3) << "Running HloPassFix."; + while (changed_this_iteration) { + TF_ASSIGN_OR_RETURN(changed_this_iteration, + Pass::RunOnModuleGroup(module_group)); + changed |= changed_this_iteration; + VLOG(3) << "changed_this_iteration: " << changed_this_iteration; ++iteration_count; if (iteration_count == limit) { LOG(ERROR) diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index ae8c08cf1d16ad6738962f3be7c1b5512110b1d1..0ca04cf8c58fd835449df035f9c486825b842942 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -99,30 +100,8 @@ std::vector HloPassPipeline::GetEnabledPasses( void HloPassPipeline::MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name, absl::string_view before_pass_name) { - const string& proto_dump_path = - module.config().debug_options().xla_dump_per_pass_hlo_proto_to(); - if (!proto_dump_path.empty()) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static auto* const module_id_to_pass_number = - new absl::flat_hash_map(); - - tensorflow::mutex_lock lock(mu); - const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; - - const string filename = SanitizeFileName( - absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), - pass_number, name(), after_pass_name)); - - TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory( - MakeHloProto(module), proto_dump_path, filename)); - } - - const string message = - absl::StrCat("after ", after_pass_name, ", before ", before_pass_name); - hlo_graph_dumper::MaybeDumpHloModule(module, message); - VLOG(3) << "HLO " << message << ":"; - VLOG(3) << module.entry_computation_layout().ToString(); - XLA_VLOG_LINES(3, module.ToString()); + DumpHloModuleBetweenPassesIfEnabled(name(), before_pass_name, after_pass_name, + module); } void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group, diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 0fced7f15bdaf1dbe349e3b0fc6ada68393c6512..af07eb83a5c48d3380bf527ff3292e80bb441698 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -54,7 
+54,9 @@ void HloReachabilityMap::SetReachabilityToUnionHelper( } bit_vector->Set(GetIndex(instruction)); for (const HloInstruction* input : inputs) { - bit_vector->OrWith(GetBitVector(input)); + if (input != instruction) { + bit_vector->OrWith(GetBitVector(input)); + } } } @@ -77,28 +79,51 @@ std::unique_ptr HloReachabilityMap::Build( const HloComputation* computation) { const auto& all = computation->MakeInstructionPostOrder(); auto result = absl::make_unique(all); - auto channel_dependency_map = computation->ComputeChannelDependencies(); + auto channel_group = computation->ComputeChannelDependencies(); - std::vector inputs; for (const HloInstruction* hlo : all) { - inputs.assign(hlo->operands().begin(), hlo->operands().end()); - inputs.insert(inputs.end(), hlo->control_predecessors().begin(), - hlo->control_predecessors().end()); + std::vector inputs; + const auto add_input = [&channel_group, &inputs](HloInstruction* input) { + inputs.push_back(input); + if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) { + auto it = channel_group.find(*input->all_reduce_id()); + if (it != channel_group.end()) { + inputs.insert(inputs.end(), it->second.begin(), it->second.end()); + } + } + }; + + const auto add_dependencies = [&add_input](const HloInstruction* hlo) { + for (HloInstruction* operand : hlo->operands()) { + add_input(operand); + } + for (HloInstruction* predecessor : hlo->control_predecessors()) { + add_input(predecessor); + } + }; + + add_dependencies(hlo); switch (hlo->opcode()) { case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(hlo->channel_id()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); + auto it = channel_group.find(hlo->channel_id()); + if (it != channel_group.end()) { + for (HloInstruction* channel : it->second) { + if (channel->opcode() == HloOpcode::kSend) { + add_input(channel); + } + } } break; } case HloOpcode::kAllReduce: { auto all_reduce_id = 
hlo->all_reduce_id(); if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); + auto it = channel_group.find(all_reduce_id.value()); + if (it != channel_group.end()) { + for (HloInstruction* all_reduce : it->second) { + add_dependencies(all_reduce); + } } } break; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 5c1793933ca2f927d3172de7135a0e583f70d8f9..a175e4643de2ac6ce07ac00da914d7ab7acca541 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -188,7 +188,8 @@ class InstructionList { Item* CreateItem(HloInstruction* inst) { Item* item = new Item; item->instruction = inst; - CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice"; + CHECK(item_map_.insert({inst, item}).second) + << "inserting inst twice " << inst->name(); return item; } @@ -426,11 +427,12 @@ class MemoryUsageTracker { // the given uses. 
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item, ItemList&& rematerialized_uses) { - CHECK(original_buffer.defining_instruction->placed); - CHECK(!original_buffer.has_indirect_uses); - CHECK(!original_buffer.live_out); + CHECK(original_buffer.defining_instruction->placed) + << original_buffer.defining_instruction->instruction->name(); + CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString(); + CHECK(!original_buffer.live_out) << original_buffer.ToString(); for (Item* use : rematerialized_uses) { - CHECK(!use->placed); + CHECK(!use->placed) << use->instruction->name(); } return NewBuffer(remat_item, original_buffer.size, std::move(rematerialized_uses), /*live_out=*/false, @@ -683,8 +685,8 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, << ", remat_instruction = " << remat_item->instruction->name(); TF_RET_CHECK(in_progress_item_ != nullptr); - TF_RET_CHECK(original_item->placed); - TF_RET_CHECK(!remat_item->placed); + TF_RET_CHECK(original_item->placed) << original_item->instruction->name(); + TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name(); // Construct the list of buffers used and defined by the rematerialization. remat_item->buffers_used = original_item->buffers_used; @@ -713,7 +715,7 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, ItemList unplaced_users; for (Item* user : old_buffer.users) { if (user->placed) { - CHECK(IsFinished(user)); + CHECK(IsFinished(user)) << user->instruction->name(); placed_users.push_back(user); } else { unplaced_users.push_back(user); @@ -1098,7 +1100,7 @@ StatusOr HloRematerialization::RematerializeComputation( Item* successor_item = instruction_list.GetItem(successor); // Assert to make sure we never remat an operation with control // successor already placed. 
- CHECK(!successor_item->placed); + CHECK(!successor_item->placed) << successor_item->instruction->name(); place_before.push_back(successor_item); } instruction_list.InsertBeforeInstructions(remat_item, place_before); @@ -1168,7 +1170,7 @@ StatusOr HloRematerialization::RematerializeComputation( // Verify some invariants on the memory tracker. CHECK_EQ(memory_tracker.memory_usage(), 0); for (auto* instruction : computation->instructions()) { - CHECK(memory_tracker.IsPlaced(instruction)); + CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name(); } VLOG(1) << "In computation " << computation->name() << " rematerialized " diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index d7d66ae1c4592723ca991d5ee971fa72cc1af90a..837367745866386921368cc7d49105029aadaf98 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -168,6 +169,35 @@ StatusOr HloRunner::Execute(std::unique_ptr module, /*profile=*/profile); } +StatusOr HloRunner::Execute( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + TF_ASSIGN_OR_RETURN(std::vector argument_buffers, + TransferLiteralsToDevice(arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + ExecuteWithDeviceBuffers( + /*executable=*/executable.get(), + /*arguments=*/argument_buffers, + /*profile=*/profile)); + return TransferLiteralFromDevice(result); +} + +StatusOr HloRunner::Execute(std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + // Construct a vector of plain 
pointers for the arguments. + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); + } + return Execute( + /*module=*/std::move(executable), + /*arguments=*/argument_pointers, + /*profile=*/profile); +} + StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, const absl::Span arguments, bool run_hlo_passes, @@ -206,7 +236,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { // Get service run options. @@ -225,7 +255,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { std::vector argument_pointers; @@ -240,14 +270,11 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr> HloRunner::ExecuteReplicated( - std::unique_ptr module, - const ReplicatedExecuteOptions& options) { + std::unique_ptr module, const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment, bool use_threads) { TF_ASSIGN_OR_RETURN( std::unique_ptr executable, CreateExecutable(std::move(module), options.run_hlo_passes)); - TF_ASSIGN_OR_RETURN( - DeviceAssignment device_assignment, - backend().computation_placer()->AssignDevices(options.num_replicas, 1)); std::vector> streams; std::vector service_run_options; @@ -264,13 +291,13 @@ StatusOr> HloRunner::ExecuteReplicated( std::vector> argument_buffer_slices; int64 index = 0; for (int64 i = 0; i < options.num_replicas; ++i) { - int64 device = device_assignment(i, 0); + int64 device = (*device_assignment)(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); streams.push_back(absl::make_unique(executor)); streams.back()->Init(); 
service_run_options.emplace_back(GetServiceRunOptionsForDevice( - device, streams.back().get(), &device_assignment)); + device, streams.back().get(), device_assignment)); // Copy arguments to device. for (const Literal* argument : options.arguments) { @@ -300,7 +327,7 @@ StatusOr> HloRunner::ExecuteReplicated( } if (options.infeed != nullptr) { for (int64 i = 0; i < options.num_replicas; ++i) { - int64 device = device_assignment(i, 0); + int64 device = (*device_assignment)(i, 0); pool->Schedule([this, device, &options]() { se::StreamExecutor* executor = backend().stream_executor(device).ValueOrDie(); @@ -318,7 +345,7 @@ StatusOr> HloRunner::ExecuteReplicated( } if (ShapeUtil::IsInitialized(options.outfeed_shape)) { for (int64 i = 0; i < options.num_replicas; ++i) { - int64 device = device_assignment(i, 0); + int64 device = (*device_assignment)(i, 0); pool->Schedule([this, device, &options]() { se::StreamExecutor* executor = backend().stream_executor(device).ValueOrDie(); @@ -340,9 +367,39 @@ StatusOr> HloRunner::ExecuteReplicated( } LOG(INFO) << "Replicated execution started"; - TF_ASSIGN_OR_RETURN(std::vector results, - executable->ExecuteOnStreams(service_run_options, - argument_buffer_slices)); + std::vector results; + if (!use_threads) { + TF_ASSIGN_OR_RETURN(results, + executable->ExecuteOnStreams(service_run_options, + argument_buffer_slices)); + } else { + tensorflow::mutex mutex; + std::vector> thread_results( + options.num_replicas); + { + LOG(INFO) << "Creating thread pool for " << options.num_replicas + << " replicas"; + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), + "replicas", options.num_replicas); + for (int64 i = 0; i < options.num_replicas; ++i) { + pool.Schedule([&, i] { + auto result = executable->ExecuteOnStream( + &service_run_options[i], argument_buffer_slices[i], nullptr); + tensorflow::mutex_lock lock(mutex); + thread_results[i] = std::move(result); + }); + } + + // Note: the thread pool destructor guarantees it completes 
all work + // before we leave this scope. + } + for (auto& thread_result : thread_results) { + if (!thread_result.ok()) { + return thread_result.status(); + } + results.push_back(std::move(thread_result).ValueOrDie()); + } + } LOG(INFO) << "Replicated execution terminated"; std::vector exec_results; @@ -356,6 +413,16 @@ StatusOr> HloRunner::ExecuteReplicated( return std::move(exec_results); } +StatusOr> HloRunner::ExecuteReplicated( + std::unique_ptr module, const ReplicatedExecuteOptions& options, + bool use_threads) { + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + backend().computation_placer()->AssignDevices(options.num_replicas, 1)); + return ExecuteReplicated(std::move(module), options, &device_assignment, + use_threads); +} + StatusOr> HloRunner::CreateExecutable( std::unique_ptr module, bool run_hlo_passes) { if (run_hlo_passes) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index bb792cf8c9825ff67ca33bbcf2c3c32b1a0ecb85..0c1ae3a0abbd1be8e92de79c88ca4073ca09f569 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -60,7 +60,7 @@ class HloRunner { // The number of times the infeed literal should be fed to the HLO module. // For a clean exit, this should match the iterations-per-loop parameter // used when generating the HLO module proto (that is usually the main - // while bounary counter). A value higher then iterations-per-loop would + // while boundary counter). A value higher then iterations-per-loop would // lead to infeed threads feeding to a gone computation, while a lower // value would trigger a stuck ExecuteReplicated() call (the computation // will be trying to infeed data which will never come). 
@@ -124,6 +124,14 @@ class HloRunner { bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + + StatusOr Execute(std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + // As Execute(), but accepts and returns device buffers instead of host // buffers. StatusOr ExecuteWithDeviceBuffers( @@ -136,13 +144,16 @@ class HloRunner { const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + // In the following two calls, "executable" is not a unique_ptr to allow + // reuse of the Executable. This call may update the profile information in + // *executable. StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); @@ -154,9 +165,19 @@ class HloRunner { // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. + // + // use_threads indicates whether this replicated computation will be executed + // with a thread-per-replica, vs using an implicitly async call such as + // Executable::ExecuteOnStreams. + StatusOr> ExecuteReplicated( + std::unique_ptr module, + const ReplicatedExecuteOptions& options, bool use_threads = false); + + // Same as above, but with specified device assignment. StatusOr> ExecuteReplicated( std::unique_ptr module, - const ReplicatedExecuteOptions& options); + const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment, bool use_threads = false); // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. 
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index 0e56e6f760e35ddcb45c6f58771d78405a09acfe..ecc8dbe6560fdfaa4b7801b9f482b3482f2eb083 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -228,7 +228,7 @@ HloModule UpdateScheduleWithMultipleComputations %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %WhileLoop () -> s32[] { @@ -297,7 +297,7 @@ HloModule UpdateScheduleWithMultipleComputations %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %WhileLoop () -> s32[] { diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 37cc146bd7a6f2aef9373bd4afd8572ffac6473c..f1d7e60f2b5a68408f6d428a0ec47fba3c9c4f12 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -96,13 +96,13 @@ string HloSharding::ToString() const { if (replicated_) { return "{replicated}"; - } else if (maximal_) { + } + if (maximal_) { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); - } else { - return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), - "]", StrJoin(tile_assignment_, ","), "}"); } + return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]", + StrJoin(tile_assignment_, ","), "}"); } bool 
HloSharding::UsesDevice(int64 device) const { @@ -328,8 +328,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, status = tensorflow::errors::InvalidArgument( StrCat("core ", core, " is not unique in tile assignment")); } + seen_cores.insert(core); } - seen_cores.insert(core); }); if (!status.ok()) { return status; @@ -347,7 +347,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, ToString(), ", input_shape=", ShapeUtil::HumanString(shape)); } - // The correct constructor have to be used to create tile maximal shardings. + // The correct constructor has to be used to create tile maximal shardings. if (tile_assignment_.num_elements() == 1) { return tensorflow::errors::InvalidArgument( "Tile assignment only contains a single device. If a replicated " diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 5789ae09988d2a85247c5b8c037a172b3699f3b7..dd57ea83f1cb33aa052facb607bc040d2e708633 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -118,7 +118,7 @@ class HloSharding { // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; - // Retrieves an histogram of the devices used by the sharding. The returned + // Retrieves a histogram of the devices used by the sharding. The returned // map has the device number as key, and the occurrence count as value. // If a sharding does not have a device, it will not be incuded in the // histogram. The count argument, if not nullptr, will receive the total @@ -260,6 +260,19 @@ class HloSharding { bool replicated_; bool maximal_; bool tuple_; + // This field is only used if replicated_ is false. If maximal_ is true, then + // the field contains a rank 1 array with a single element, which is the + // device the HLO is assigned to. If maximal_ is false, the field contains an + // array with the same rank as the corresponding HLO. 
The dimension sizes of + // the array describe the number of ways the HLO is partitioned along each + // dimension. The values of the array specify which device each tile of + // the HLO is assigned to. The index of each value determines which tile it + // takes. + // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is + // "{devices=[2,1,2]2,3,5,7}"), means that dimension 1 is split two way and + // dimension 3 is split 2 way. Core 5, whose index is [2,1,1] will take the + // tile that contains the 2nd half of dimension 1 and the 1st half of + // dimension 3. Array tile_assignment_; // Only non-empty when tuple_ is true. If a tuple is empty then one entry is // present for the root. This is a flattened list of all the leaf shardings in diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 80634677e78e4a35dcb9bf7de018a88122c3c030..9e234e025586ff14f99da73afc5610c627303a36 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -84,7 +84,7 @@ TEST_F(HloShardingTest, Tile) { } { - // Test should fail because of more devices used then `num_device`. + // Test should fail because of more devices used than `num_device`. HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}), /*num_devices=*/2)); diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index c1073911ea9dc3811c195e27bcbae9b00929ad17..6c0a1926c414819933cdf4c142702bcf0b65f2f4 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -16,7 +16,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -87,17 +86,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", - module->config().debug_options()); - } EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", - module->config().debug_options()); - } EXPECT_EQ(2, module->computation_count()); EXPECT_EQ(x->to_apply(), y->to_apply()); } @@ -126,17 +115,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", - module->config().debug_options()); - } EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", - module->config().debug_options()); - } EXPECT_EQ(2, module->computation_count()); EXPECT_EQ(x->to_apply(), y->to_apply()); } @@ -166,17 +145,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", - module->config().debug_options()); - } 
EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", - module->config().debug_options()); - } EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc deleted file mode 100644 index c1f69db74eafb7743e85f499f2f4828ed0375501..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ /dev/null @@ -1,242 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -LIcensed under the Apache License, Version 2.0 (the "License"); -You may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -using absl::StrAppend; -using absl::StrCat; -using tensorflow::GraphDef; -using tensorflow::NodeDef; -using tensorflow::TensorShapeProto; - -string GetOpDefName(const HloInstruction* instruction) { - string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); - tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok - name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); - - if (instruction->opcode() == HloOpcode::kFusion) { - string fusion_name = ToString(instruction->fusion_kind()); - StrAppend(&name, absl::string_view(fusion_name).substr(1)); - } - return name; -} - -TensorShapeProto GetTensorShape(const HloInstruction* instruction) { - TensorShapeProto tensor_shape; - const Shape& shape = instruction->shape(); - for (auto dim : shape.dimensions()) { - tensor_shape.add_dim()->set_size(dim); - } - return tensor_shape; -} - -string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } - -void CleanNodeName(string* name) { - name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); - const string chars_to_replace = "<>[]"; - auto pred = [&](char c) { - return absl::c_linear_search(chars_to_replace, c); - }; - std::replace_if(name->begin(), name->end(), pred, '_'); -} - -} // namespace - -HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options) - : 
debug_options_(debug_options) {} - -Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { - VLOG(2) << "Adding computation " << computation.name(); - for (auto embedded : computation.MakeEmbeddedComputationsList()) { - for (auto* instruction : embedded->instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction)); - } - } - for (auto* instruction : computation.instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction)); - } - return Status::OK(); -} - -const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; } - -const string& HloTfGraphBuilder::GetNodeNameForInstruction( - const HloInstruction* instruction) { - if (ContainsKey(instruction_to_node_name_, instruction)) { - return instruction_to_node_name_[instruction]; - } - auto append = [](string* str, const string& other) { - if (str->empty()) { - *str = other; - } else if (!other.empty()) { - StrAppend(str, "/", other); - } - }; - string node_name; - if (debug_options_.xla_hlo_tfgraph_device_scopes()) { - auto device = instruction->sharding_unique_device(); - if (device) { - node_name = StrCat("dev", *device); - } - } - // If an instruction is fused, put it in the subgraph of the fusion; - // otherwise, put it in the computation subgraph. - const HloComputation* computation = instruction->parent(); - if (computation->IsFusionComputation()) { - append(&node_name, - GetNodeNameForInstruction(computation->FusionInstruction())); - } else { - append(&node_name, computation->name()); - if (!instruction->metadata().op_name().empty()) { - // Always make computations contain TF ops but not the other way around. 
- append(&node_name, instruction->metadata().op_name()); - } - } - string instruction_name = instruction->name(); - if (instruction->opcode() == HloOpcode::kParameter) { - StrAppend(&instruction_name, ".", instruction->parameter_number()); - } - append(&node_name, instruction_name); - CleanNodeName(&node_name); - auto ret = - instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); - CHECK(ret.second); - return ret.first->second; -} - -void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, - NodeDef* node_def) const { - auto& attrs = *node_def->mutable_attr(); - - // Set the number of arguments for instructions that have variadic operands. - if (HloOpcodeIsVariadic(instruction->opcode())) { - tensorflow::AttrValue attr_value; - attr_value.set_i(instruction->operands().size()); - attrs["arg_num"] = attr_value; - } - - // Set the node type. - attrs["type"].set_s( - xla::PrimitiveType_Name(instruction->shape().element_type())); - - // Set the framework op (e.g. Tensorflow op) that generated this XLA op. - attrs["tf_op_type"].set_s(instruction->metadata().op_type()); - attrs["tf_op_name"].set_s(instruction->metadata().op_name()); - - // Set the shape of the output tensor. "_output_shapes" is a special attribute - // name used by Tensorboard for shapes of output tensors. - tensorflow::AttrValue shapes; - *shapes.mutable_list()->add_shape() = GetTensorShape(instruction); - attrs["_output_shapes"] = shapes; - - // Set the layout. - if (LayoutUtil::HasLayout(instruction->shape())) { - string layout_string; - if (instruction->shape().IsTuple()) { - // For tuples, emit the full shape because the layout of a tuple is not - // represented in a single Layout field. - layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); - } else { - layout_string = StrCat( - "{", - absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","), - "}"); - } - attrs["layout"].set_s(layout_string); - } - - // Set op-specific attributes. 
- switch (instruction->opcode()) { - case HloOpcode::kConcatenate: - case HloOpcode::kBroadcast: - case HloOpcode::kReduce: - case HloOpcode::kReverse: - case HloOpcode::kTranspose: - for (auto dim : instruction->dimensions()) { - attrs["dims"].mutable_list()->add_i(dim); - } - break; - case HloOpcode::kGetTupleElement: - attrs["index"].set_i(instruction->tuple_index()); - break; - case HloOpcode::kRng: - attrs["dist"].set_s( - RandomDistribution_Name(instruction->random_distribution())); - break; - case HloOpcode::kConstant: - if (ShapeUtil::IsScalar(instruction->shape())) { - attrs["value"].set_s(instruction->literal().GetAsString({})); - } - break; - case HloOpcode::kCustomCall: - attrs["custom_call_target"].set_s(instruction->custom_call_target()); - break; - case HloOpcode::kSend: - case HloOpcode::kRecv: - attrs["channel_id"].set_i(instruction->channel_id()); - break; - default: - break; - } -} - -Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { - if (!visited_instructions_.insert(instruction).second) { - // Skip instructions that have already been added. - return Status::OK(); - } - - NodeDef* node_def = graph_def_.add_node(); - node_def->set_name(GetNodeNameForInstruction(instruction)); - node_def->set_op(GetOpDefName(instruction)); - - auto device = instruction->sharding_unique_device(); - if (device) { - node_def->set_device(GetDeviceName(*device)); - } - SetNodeAttrs(instruction, node_def); - if (instruction->opcode() == HloOpcode::kFusion) { - for (auto* fused_instruction : instruction->fused_instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(fused_instruction)); - } - } - // Add all edges including control edges. - for (unsigned i = 0; i < instruction->operands().size(); ++i) { - *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i)); - } - // Called computations are control dependencies. 
- for (const auto* called_computation : instruction->called_computations()) { - *node_def->add_input() = StrCat( - "^", GetNodeNameForInstruction(called_computation->root_instruction())); - } - return Status::OK(); -} - -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h deleted file mode 100644 index c4876b852e32d34693202f4023aa20ad2b301ffd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" - -namespace xla { -namespace hlo_graph_dumper { - -// This constructs a tensorflow graph for HLO computations. -class HloTfGraphBuilder { - public: - HloTfGraphBuilder(const DebugOptions& debug_options = DebugOptions()); - - // Adds a computation to the graph. 
- Status AddComputation(const HloComputation& computation); - - const tensorflow::GraphDef& GetGraphDef() const; - - private: - // Gets the node name of an instruction. The node name is hierarchical. For - // example, if an instruction is fused, it will be put in a subgraph of the - // fusion instruction. - const string& GetNodeNameForInstruction(const HloInstruction* instruction); - - void SetNodeAttrs(const HloInstruction* instruction, - tensorflow::NodeDef* node_def) const; - - Status AddInstruction(const HloInstruction* instruction); - - DebugOptions debug_options_; - tensorflow::GraphDef graph_def_; - // This records instructions that have been visited. - std::unordered_set visited_instructions_; - // A cache that maps instruction to the node name. - std::unordered_map instruction_to_node_name_; -}; - -} // namespace hlo_graph_dumper -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc deleted file mode 100644 index 1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ /dev/null @@ -1,183 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -using ::tensorflow::GraphDef; - -class HloTfGraphBuilderTest : public HloTestBase { - protected: - HloTfGraphBuilderTest() {} - HloTfGraphBuilder generator_; - - // Create a computation which takes a scalar and returns its negation. - std::unique_ptr CreateNegateComputation() { - auto builder = HloComputation::Builder("Negate"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - return builder.Build(); - } - - // Creates a computation which calls map with the given computation. 
- std::unique_ptr CreateMapComputation( - HloComputation *map_computation) { - auto builder = HloComputation::Builder("Map"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map_computation)); - return builder.Build(); - } - Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {}); -}; - -static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node, - const string &attr_name) { - auto attr = node.attr().find(attr_name); - CHECK(attr != node.attr().end()); - return attr->second; -} - -TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { - auto builder = HloComputation::Builder("Concatenate"); - Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, shape, "param1")); - builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1)); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - const auto &node = graph_def.node(2); - EXPECT_EQ(node.name(), "Concatenate/concatenate"); - - // Check dimensions. - auto dims_value = GetNodeAttr(node, "dims"); - EXPECT_EQ(dims_value.list().i_size(), 1); - EXPECT_EQ(dims_value.list().i(0), 1); - - // Check shapes. 
- auto shape_value = GetNodeAttr(node, "_output_shapes"); - EXPECT_EQ(shape_value.list().shape_size(), 1); - EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2); - EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2); - EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4); -} - -TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { - auto builder = HloComputation::Builder("Const"); - HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); - OpMetadata metadata; - metadata.set_op_name("x"); - metadata.set_op_type("y"); - instruction->set_metadata(metadata); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 1); - const auto &node = graph_def.node(0); - EXPECT_EQ(GetNodeAttr(node, "value").s(), "123"); - EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32"); - EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x"); - EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y"); -} - -TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) { - auto negate_computation = CreateNegateComputation(); - TF_CHECK_OK(generator_.AddComputation(*negate_computation)); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 2); - EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0"); - EXPECT_EQ(graph_def.node(0).op(), "HloParameter"); - EXPECT_EQ(graph_def.node(1).name(), "Negate/negate"); - EXPECT_EQ(graph_def.node(1).op(), "HloNegate"); - EXPECT_EQ(graph_def.node(1).input_size(), 1); - EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0"); -} - -TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { - auto builder = HloComputation::Builder("GE"); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32_, "param1")); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, 
HloOpcode::kGe, param_1, param_2)); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); - EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); - EXPECT_EQ(graph_def.node(2).input_size(), 2); - EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to"); - EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); -} - -TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) { - auto builder = HloComputation::Builder("GE"); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32_, "param1")); - auto ge = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); - OpMetadata metadata; - metadata.set_op_name("x/y"); - metadata.set_op_type("Y"); - ge->set_metadata(metadata); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); - EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); - EXPECT_EQ(graph_def.node(2).input_size(), 2); - EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to"); - EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); -} - -TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { - // Create computations with a diamond-shaped callgraph. 
- auto negate_computation = CreateNegateComputation(); - auto map1_computation = CreateMapComputation(negate_computation.get()); - auto map2_computation = CreateMapComputation(negate_computation.get()); - - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto map1 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); - auto map2 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); - auto computation = builder.Build(); - TF_CHECK_OK(generator_.AddComputation(*computation)); - EXPECT_GT(generator_.GetGraphDef().node_size(), 0); -} - -} // namespace -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 36340d3d78e059aae7ac8b341cf1b87384818ba9..375ae2c477d7a0aea8445d9c237991eee3353a04 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -50,6 +50,7 @@ bool IsCallerInstruction(HloInstruction* hlo) { case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: case HloOpcode::kFusion: return true; default: @@ -57,15 +58,6 @@ bool IsCallerInstruction(HloInstruction* hlo) { } } -Status ShapeVerifier::Preprocess(HloInstruction* hlo) { - if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) { - return InternalError( - "Called computations specified for non-caller instruction %s", - hlo->ToString()); - } - return VerifyNotSparse(hlo->shape()); -} - namespace { Status CheckOperandCount(const HloInstruction* hlo, int expected) { @@ -90,6 +82,21 @@ Status CheckParameterCount(const HloInstruction* calling_instruction, } // namespace +Status 
ShapeVerifier::Preprocess(HloInstruction* hlo) { + if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) { + return InternalError( + "Called computations specified for non-caller instruction %s", + hlo->ToString()); + } + TF_RETURN_IF_ERROR(VerifyNotSparse(hlo->shape())); + + absl::optional arity = HloOpcodeArity(hlo->opcode()); + if (arity) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity)); + } + return Status::OK(); +} + Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { return CheckUnaryShape(hlo); } @@ -121,14 +128,12 @@ Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { } Status ShapeVerifier::HandleConvert(HloInstruction* convert) { - TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { - TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferBitcastConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); @@ -139,7 +144,6 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status ShapeVerifier::HandleDot(HloInstruction* dot) { - TF_RETURN_IF_ERROR(CheckOperandCount(dot, 2)); TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferDotOpShape( dot->operand(0)->shape(), dot->operand(1)->shape(), @@ -148,7 +152,6 @@ Status ShapeVerifier::HandleDot(HloInstruction* dot) { } Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { - TF_RETURN_IF_ERROR(CheckOperandCount(convolution, 2)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferConvolveShape( @@ -159,7 +162,6 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { } Status ShapeVerifier::HandleFft(HloInstruction* fft) { - TF_RETURN_IF_ERROR(CheckOperandCount(fft, 1)); TF_ASSIGN_OR_RETURN( const Shape expected, 
ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), @@ -167,6 +169,21 @@ Status ShapeVerifier::HandleFft(HloInstruction* fft) { return CheckShape(fft, expected); } +Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) { + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferTriangularSolveShape( + hlo->operand(0)->shape(), hlo->operand(1)->shape(), + hlo->triangular_solve_options())); + return CheckShape(hlo, expected); +} + +Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); + TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape( + hlo->operand(0)->shape())); + return CheckShape(hlo, expected); +} + Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) { std::vector operand_shapes; for (const HloInstruction* operand : crs->operands()) { @@ -184,14 +201,16 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { ShapeInference::InferAllToAllTupleShape(operand_shapes)); } +Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) { + return CheckShape(hlo, ShapeUtil::MakeShape(U32, {})); +} + Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { - TF_RETURN_IF_ERROR(CheckOperandCount(reduce_precision, 1)); return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), reduce_precision->exponent_bits(), @@ -225,7 +244,6 @@ Status ShapeVerifier::CheckOperandAndParameter( } Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); HloInfeedInstruction* infeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -236,7 +254,6 @@ Status 
ShapeVerifier::HandleInfeed(HloInstruction* instruction) { } Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); HloOutfeedInstruction* outfeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); @@ -312,7 +329,6 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { - TF_RETURN_IF_ERROR(CheckOperandCount(reverse, 1)); return CheckShape( reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), reverse->dimensions())); @@ -323,13 +339,48 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { return InternalError("Expected at least 1 operand for %s instruction: %s", HloOpcodeString(sort->opcode()), sort->ToString()); } + HloComputation* compare = sort->to_apply(); + + // Check that the 'compare' computation returns a PRED. + Shape compare_shape = compare->root_instruction()->shape(); + if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) { + return InternalError( + "The Sort compare computation shape does not lead to a scalar " + "predicate shape: %s", + StringifyShape(compare_shape)); + } + + // Check that the number of parameters of the 'compare' computation is + // correct. + TF_RETURN_IF_ERROR( + CheckParameterCount(sort, compare, sort->operand_count() * 2)); + + // Verify that the operands of the compare computation have the correct scalar + // shapes. 
+ for (int64 parameter_idx = 0; parameter_idx < compare->num_parameters(); + ++parameter_idx) { + int64 operand_idx = parameter_idx / 2; + Shape expected_scalar_shape = ShapeUtil::MakeShape( + sort->operand(operand_idx)->shape().element_type(), {}); + Shape actual_parameter_shape = + compare->parameter_instruction(parameter_idx)->shape(); + if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape, + actual_parameter_shape)) { + return InternalError( + "Expected the %lld-th parameter of the compare computation of sort " + "to have shape %s, but got %s", + parameter_idx, StringifyShape(expected_scalar_shape), + StringifyShape(actual_parameter_shape)); + } + } + + // Verify that all operand shapes have the same dimensions. for (int64 operand = 1; operand < sort->operand_count(); ++operand) { if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(), sort->operand(operand)->shape())) { return InternalError( - "Expected sort to have to have the same dimensions for the keys " - "and the values. Keys shape is: %s\n, Values shape (operand index " - "%lld) is: %s", + "Expected sort to have to have the same dimensions for all operands. 
" + "First operand shape is: %s\n, shape (operand index %lld) is: %s", StringifyShape(sort->operand(0)->shape()), operand, StringifyShape(sort->operand(operand)->shape())); } @@ -338,7 +389,6 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { - TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0)); if (!Cast(constant)->HasLiteral()) { return InternalError("Constant is required to have a valid literal: %s", constant->ToString()); @@ -347,8 +397,10 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { } Status ShapeVerifier::HandleIota(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); auto* iota = Cast(instruction); + if (!iota->shape().IsArray()) { + return InternalError("Iota does not support non-array result."); + } const int64 rank = iota->shape().rank(); if (rank == 0) { return InternalError("Iota does not support scalars."); @@ -362,13 +414,30 @@ Status ShapeVerifier::HandleIota(HloInstruction* instruction) { } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { - TF_RETURN_IF_ERROR(CheckOperandCount(get_tuple_element, 1)); return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( get_tuple_element->operand(0)->shape(), get_tuple_element->tuple_index())); } +namespace { +Status SameElementTypesForOperandsAndToApplyParameters( + const HloInstruction& instruction, int64 num_operands_to_check) { + const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape(); + for (int i = 0; i < num_operands_to_check; ++i) { + const Shape& parameter_shape = to_apply.parameters(i); + const Shape& operand_shape = instruction.operands()[i]->shape(); + if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) { + return InvalidArgument( + "Shape mismatch between to_apply computation" + " parameter and operand %d in %s.", + i, instruction.ToString().c_str()); + } + } + return Status::OK(); 
+} +} // namespace + Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { if (reduce->operand_count() % 2 != 0) { return InternalError( @@ -380,13 +449,18 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { for (const HloInstruction* operand : reduce->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(reduce, ShapeInference::InferReduceShape( - operand_shapes, reduce->dimensions(), - reduce->to_apply()->ComputeProgramShape())); + TF_RETURN_IF_ERROR( + CheckShape(reduce, ShapeInference::InferReduceShape( + operand_shapes, reduce->dimensions(), + reduce->to_apply()->ComputeProgramShape()))); + + return allow_mixed_precision_ + ? Status::OK() + : SameElementTypesForOperandsAndToApplyParameters( + *reduce, reduce->operands().size() - 1); } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { - TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); // Bitcasts are not allowed to change the element type. if (bitcast->operand(0)->shape().element_type() != bitcast->shape().element_type()) { @@ -399,7 +473,6 @@ Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { - TF_RETURN_IF_ERROR(CheckOperandCount(broadcast, 1)); // HLO broadcast has no exact analog at the proto level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); @@ -419,7 +492,6 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { - TF_RETURN_IF_ERROR(CheckOperandCount(reshape, 1)); // Check for mixed precision. 
const Shape& operand_shape = reshape->operand(0)->shape(); TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape)); @@ -429,14 +501,12 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { - TF_RETURN_IF_ERROR(CheckOperandCount(transpose, 1)); return CheckShape( transpose, ShapeInference::InferTransposeShape( transpose->operand(0)->shape(), transpose->dimensions())); } Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 0)); return Status::OK(); } @@ -496,7 +566,6 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { } Status ShapeVerifier::HandleSlice(HloInstruction* slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(slice, 1)); return CheckShape(slice, ShapeInference::InferSliceShape( slice->operand(0)->shape(), slice->slice_starts(), @@ -538,23 +607,33 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { // arbitrary map dimensions. std::vector map_dims(max_operand_rank); std::iota(map_dims.begin(), map_dims.end(), 0); - return CheckShape(map, ShapeInference::InferMapShape( - operand_shapes, - map->to_apply()->ComputeProgramShape(), map_dims)); + + TF_RETURN_IF_ERROR(CheckShape( + map, + ShapeInference::InferMapShape( + operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims))); + + return allow_mixed_precision_ + ? Status::OK() + : SameElementTypesForOperandsAndToApplyParameters( + *map, map->operands().size()); } Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { - TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2)); - return CheckShape( + TF_RETURN_IF_ERROR(CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( reduce_window->operand(0)->shape(), reduce_window->operand(1)->shape(), reduce_window->window(), - reduce_window->to_apply()->ComputeProgramShape())); + reduce_window->to_apply()->ComputeProgramShape()))); + + return allow_mixed_precision_ + ? 
Status::OK() + : SameElementTypesForOperandsAndToApplyParameters(*reduce_window, + 1); } Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape( instruction, ShapeInference::InferSelectAndScatterShape( @@ -565,7 +644,6 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { } Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { - TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1)); TF_RETURN_IF_ERROR( CheckParameterCount(xla_while, xla_while->while_body(), 1)); TF_RETURN_IF_ERROR( @@ -589,33 +667,32 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { - TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3)); - TF_RETURN_IF_ERROR( - CheckParameterCount(conditional, conditional->true_computation(), 1)); - TF_RETURN_IF_ERROR( - CheckParameterCount(conditional, conditional->false_computation(), 1)); - TF_RETURN_IF_ERROR(CheckOperandAndParameter( - conditional, 1, conditional->true_computation(), 0)); - TF_RETURN_IF_ERROR(CheckOperandAndParameter( - conditional, 2, conditional->false_computation(), 0)); - TF_RETURN_IF_ERROR( - CheckShape(conditional, - conditional->true_computation()->root_instruction()->shape())); - TF_RETURN_IF_ERROR(CheckShape( - conditional, - conditional->false_computation()->root_instruction()->shape())); + const int num_branches = conditional->branch_count(); + if (conditional->operand(0)->shape().element_type() == PRED) { + TF_RET_CHECK(num_branches == 2); + } else { + TF_RET_CHECK(num_branches >= 1); + } + TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1)); + for (int j = 0; j < num_branches; ++j) { + TF_RETURN_IF_ERROR(CheckParameterCount( + conditional, conditional->branch_computation(j), 1)); + TF_RETURN_IF_ERROR(CheckOperandAndParameter( + conditional, j + 1, conditional->branch_computation(j), 0)); + 
TF_RETURN_IF_ERROR(CheckShape( + conditional, + conditional->branch_computation(j)->root_instruction()->shape())); + } return Status::OK(); } Status ShapeVerifier::HandlePad(HloInstruction* pad) { - TF_RETURN_IF_ERROR(CheckOperandCount(pad, 2)); return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), pad->operand(1)->shape(), pad->padding_config())); } Status ShapeVerifier::HandleSend(HloInstruction* send) { - TF_RETURN_IF_ERROR(CheckOperandCount(send, 2)); return CheckShape(send, ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {}), @@ -623,12 +700,10 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { - TF_RETURN_IF_ERROR(CheckOperandCount(send_done, 1)); return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { - TF_RETURN_IF_ERROR(CheckOperandCount(recv, 1)); return CheckShape( recv, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(recv->shape(), 0), @@ -636,7 +711,6 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) { } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { - TF_RETURN_IF_ERROR(CheckOperandCount(recv_done, 1)); return CheckShape( recv_done, ShapeUtil::MakeTupleShape( @@ -646,7 +720,6 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { Status ShapeVerifier::HandleBatchNormTraining( HloInstruction* batch_norm_training) { - TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_training, 3)); return CheckShape(batch_norm_training, ShapeInference::InferBatchNormTrainingShape( batch_norm_training->operand(0)->shape(), @@ -657,7 +730,6 @@ Status ShapeVerifier::HandleBatchNormTraining( Status ShapeVerifier::HandleBatchNormInference( HloInstruction* batch_norm_inference) { - TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_inference, 5)); return CheckShape(batch_norm_inference, ShapeInference::InferBatchNormInferenceShape( 
batch_norm_inference->operand(0)->shape(), @@ -669,7 +741,6 @@ Status ShapeVerifier::HandleBatchNormInference( } Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { - TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_grad, 5)); return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( batch_norm_grad->operand(0)->shape(), batch_norm_grad->operand(1)->shape(), @@ -737,7 +808,6 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { } // namespace Status ShapeVerifier::HandleGather(HloInstruction* gather) { - TF_RETURN_IF_ERROR(CheckOperandCount(gather, 2)); return CheckShape( gather, ShapeInference::InferGatherShape( @@ -746,7 +816,6 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { - TF_RETURN_IF_ERROR(CheckOperandCount(scatter, 3)); return CheckShape( scatter, ShapeInference::InferScatterShape( scatter->operand(0)->shape(), scatter->operand(1)->shape(), @@ -764,7 +833,6 @@ Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { } Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) { - TF_RETURN_IF_ERROR(CheckOperandCount(add_dependency, 2)); TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1)); return CheckShape(add_dependency, add_dependency->operand(0)->shape()); } @@ -846,14 +914,12 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); return CheckShape(instruction, ShapeInference::InferUnaryOpShape(instruction->opcode(), instruction->operand(0))); } Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); return CheckShape( instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), instruction->operand(0), @@ -861,7 +927,6 @@ Status 
ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { } Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape(instruction, ShapeInference::InferTernaryOpShape( instruction->opcode(), instruction->operand(0), @@ -1237,8 +1302,8 @@ Status CheckFusionInstruction(HloInstruction* fusion) { return Status::OK(); } -// Checks that the non-scalar operand shapes are compatible to the output -// shape, i.e., that there are no implicit broadcasts of size-one dimensions. +// Checks that the operand shapes are compatible to the output shape, i.e., +// that there are no implicit broadcasts. Status CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { @@ -1307,17 +1372,13 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { } Status HandleConditional(HloInstruction* conditional) override { - if (conditional->true_computation()->num_parameters() != 1) { - return FailedPrecondition( - "True computation %s of %s must have 1 parameter insted of %d", - conditional->true_computation()->name(), conditional->ToString(), - conditional->true_computation()->num_parameters()); - } - if (conditional->false_computation()->num_parameters() != 1) { - return FailedPrecondition( - "False computation %s of %s must have 1 parameter insted of %d", - conditional->false_computation()->name(), conditional->ToString(), - conditional->false_computation()->num_parameters()); + for (int b = 0; b < conditional->branch_count(); ++b) { + if (conditional->branch_computation(b)->num_parameters() != 1) { + return FailedPrecondition( + "Branch computation %s of %s must have 1 parameter insted of %d", + conditional->branch_computation(b)->name(), conditional->ToString(), + conditional->branch_computation(b)->num_parameters()); + } } return Status::OK(); } diff --git 
a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 479905b317d5639ff2cebc4d1044e21b527693f6..d427a1586c3cd1d1abbd6606f33067e36cabad98 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -52,9 +52,12 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; + Status HandleCholesky(HloInstruction* hlo) override; + Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; + Status HandleReplicaId(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index de0335a5e8f18d1321e9ca6e6cf5057999f9adc4..523890b3c7268c06cdb6aaa67749f26a1cb62855 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -535,5 +535,84 @@ TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) { ASSERT_TRUE(status.ok()); } +TEST_F(HloVerifierTest, IotaNonArrayResult) { + const char* const hlo_string = R"( + HloModule IotaTupleResult + + ENTRY kernelEntry { + ROOT iota = () iota(), iota_dimension=24 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("does not support non-array result")); +} + +static const char* const kMapOperandComputationMismatchHlo = R"( + 
HloModule MapOperandComputationMismatch + + Computation { + param0 = f32[] parameter(0) + constant = f32[] constant(1) + ROOT add = f32[] add(param0, constant) + } + + ENTRY kernelEntry { + param = f64[] parameter(0) + ROOT map = f32[] map(param), dimensions={}, to_apply=Computation +})"; + +TEST_F(HloVerifierTest, MapOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kMapOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT( + status.error_message(), + HasSubstr( + "Shape mismatch between to_apply computation parameter and operand")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kMapOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +static const char* const kReduceOperandComputationMismatchHlo = R"( + HloModule ReduceOperandComputationMismatch + computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY kernelEntry { + arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0) + constant = f16[] constant(0) + reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation + })"; + +TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kReduceOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected instruction to have shape equal to f32[64]")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kReduceOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + } // namespace } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc deleted file mode 100644 index ada21345014dac70d61129aaf7bbc7466a7db914..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -namespace { - -// Visitor for removing implicit broadcasts. 
-class ImplicitBroadcastVisitor : public DfsHloVisitorWithDefault { - public: - Status DefaultAction(HloInstruction* hlo_instruction) override { - return Status::OK(); - } - - Status HandleElementwiseBinary(HloInstruction* hlo) override { - return ReplaceImplicitBroadcastOperands(hlo); - } - - Status HandleClamp(HloInstruction* hlo) override { - // Clamp is the only element-wise ternary operation. - return ReplaceImplicitBroadcastOperands(hlo); - } - - // Returns whether any modification has been made to any visited instruction. - bool changed() const { return changed_; } - - private: - // Iterates through the operands of 'hlo' and replace any operands which are - // implicitly broadcast with the equivalent sequence of broadcast and reshape - // instructions. An operand is considered to be implicitly broadcast if the - // operand shape does have the same dimensions as the shape of 'hlo'. - Status ReplaceImplicitBroadcastOperands(HloInstruction* hlo) { - auto fadd = [hlo](std::unique_ptr x) { - return hlo->parent()->AddInstruction(std::move(x)); - }; - std::vector operands; - bool operands_changed = false; - for (int i = 0; i < hlo->operand_count(); ++i) { - HloInstruction* operand = hlo->mutable_operand(i); - if (!ShapeUtil::SameDimensions(hlo->shape(), operand->shape())) { - HloInstruction* new_operand = hlo->parent()->AddInstruction( - HloInstruction::CreateBroadcastSequence(hlo->shape(), operand, - fadd)); - operands.push_back(new_operand); - operands_changed = true; - } else { - operands.push_back(operand); - } - } - if (operands_changed) { - // Create a new HLO instruction because the HloInstruction::Replace* - // methods check that the shape does not change with the replacement. 
- HloInstruction* new_hlo = hlo->parent()->AddInstruction( - hlo->CloneWithNewOperands(hlo->shape(), operands)); - TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo)); - changed_ = true; - } - return Status::OK(); - } - - bool changed_ = false; -}; - -} // namespace - -StatusOr ImplicitBroadcastRemover::Run(HloModule* module) { - VLOG(1) << "Removing implicit broadcast from module " << module->name(); - XLA_VLOG_LINES(2, - "Before removing implicit broadcasts:\n" + module->ToString()); - - ImplicitBroadcastVisitor visitor; - for (HloComputation* computation : module->computations()) { - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); - } - - if (visitor.changed()) { - // HLO instructions with implicitly broadcast operands are cloned and left - // for dead. Remove them. - HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); - } - - XLA_VLOG_LINES(2, - "After removing implicit broadcasts:\n" + module->ToString()); - - return visitor.changed(); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc deleted file mode 100644 index cf6cf897fe11eda01ba6b22119bba34ac2bef8fe..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" - -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -class ImplicitBroadcastRemoverTest : public HloTestBase { - protected: - ImplicitBroadcastRemover remover_; -}; - -TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); - auto param0 = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); - auto param1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); - builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_FALSE(remover_.Run(m.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Parameter(), op::Parameter())); -} - -TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); - auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_param")); - auto param1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); - builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kPower, param0, param1)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - - 
EXPECT_FALSE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - root = computation->root_instruction(); - - EXPECT_THAT(root, op::Power(op::Broadcast(op::Parameter()), op::Parameter())); - - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); -} - -TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); - auto param0 = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); - auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 4, 1}), "p1")); - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, param0, param1)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Subtract(op::Parameter(), - op::Broadcast(op::Reshape(op::Parameter())))); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); -} - -TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {1, 4, 1}); - auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_param")); - auto param1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); - 
builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, param0, param1)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, - op::Subtract(op::Broadcast(op::Parameter()), op::Parameter())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); -} - -TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6, 8}); - auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 4, 1, 8}), "p0")); - auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 6, 8}), "p1")); - auto param2 = builder.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(F32, {2, 1, 6, 8}), "p2")); - builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, - param0, param1, param2)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Reshape(op::Parameter())), - op::Broadcast(op::Reshape(op::Parameter())), - op::Broadcast(op::Reshape(op::Parameter())))); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(2)->shape())); -} - -TEST_F(ImplicitBroadcastRemoverTest, - TernaryScalarAndDegenerateDimensionBroadcast) { - auto m = CreateNewVerifiedModule(); - auto 
builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); - auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 4, 6}), "p1")); - auto param2 = - builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2")); - builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, - param0, param1, param2)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Parameter()), - op::Broadcast(op::Reshape(op::Parameter())), - op::Parameter())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(2)->shape())); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 76bf48870d55e82497ba5f63e9e2e2a322cb330e..c5d32a4b9ad8c708ec0870173fa72320238e8464 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -27,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" namespace xla { -namespace gtl = ::tensorflow::gtl; namespace { using Analysis = IndexedArrayAnalysis; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 295465c8481bcb7d1385192febe0d09614e393b3..62107b5a88d4e37552fa5a6384700a9291a9c655 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "absl/strings/ascii.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -43,7 +42,7 @@ class IndexedArrayAnalysisTest : public HloTestBase { string result; for (char c : text) { - if (!isspace(c)) { + if (!absl::ascii_isspace(c)) { result.push_back(c); } else if (!result.empty() && result.back() != ' ') { result.push_back(' '); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index b97060535d998e174639dceca5cde517cef01e30..4868cf961aa1bb02e37b2207c4231ca52f3d28a7 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -65,6 +65,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConstant: @@ -72,21 +73,15 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kEq: 
case HloOpcode::kFloor: - case HloOpcode::kGe: case HloOpcode::kGetTupleElement: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kInfeed: case HloOpcode::kIota: case HloOpcode::kIsFinite: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: @@ -95,6 +90,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kPad: case HloOpcode::kReal: case HloOpcode::kReducePrecision: + case HloOpcode::kReplicaId: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: @@ -125,6 +121,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: case HloOpcode::kCall: + case HloOpcode::kCholesky: case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kAllReduce: @@ -150,13 +147,16 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: + case HloOpcode::kRsqrt: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kSort: + case HloOpcode::kSqrt: case HloOpcode::kTanh: case HloOpcode::kTrace: + case HloOpcode::kTriangularSolve: case HloOpcode::kWhile: case HloOpcode::kGetDimensionSize: return true; diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 545662543cca40e42b0f0302e14152e5283f9e4f..599489b3785be50ba7a145f298a13d6bb995a1cf 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:algebraic_simplifier", + 
"//tensorflow/compiler/xla/service:cholesky_expander", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:dynamic_index_splitter", @@ -50,7 +51,9 @@ cc_library( "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:triangular_solve_expander", "//tensorflow/compiler/xla/service:while_loop_simplifier", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:lib", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 4818b2dae0a9951346600a9b2906488c3ef7e06e..a8f8ab4f725d904a529dbd50c1c199972a1c0895 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -20,7 +20,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/cholesky_expander.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -34,6 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -42,10 +45,44 @@ limitations under the License. namespace xla { namespace interpreter { +namespace { + +// Handles custom_call ops during evaluation by routing them through the global +// CPU registry used by other CPU-based backends. +StatusOr HandleEvaluatorCustomCall( + HloInstruction* custom_call, absl::Span operands) { + // Find the target C function in the global registry. + auto* registry = xla::cpu::CustomCallTargetRegistry::Global(); + void* target_fn = registry->Lookup(custom_call->custom_call_target()); + if (!target_fn) { + return NotFound("Custom call target '%s' was not registered", + custom_call->custom_call_target()); + } + + // Populate pointers to operand and output literal data. + std::vector operand_data; + operand_data.reserve(operands.size()); + for (const auto* literal : operands) { + operand_data.push_back(literal->untyped_data()); + } + auto output = Literal::CreateFromShape(custom_call->shape()); + void* output_data = output.untyped_data(); + + // Call the target function matching the C ABI used by the CPU backends. 
+ auto* typed_fn = reinterpret_cast(target_fn); + (*typed_fn)(output_data, operand_data.data()); + + return std::move(output); +} + +} // namespace + Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout); @@ -83,10 +120,12 @@ StatusOr> InterpreterCompiler::RunBackend( // In this case we are using an HloEvaluator at execution time, so we don't // need to compile anything - // Create executable from only the Hlo module. auto evaluator = absl::make_unique(); evaluator->set_use_fast_path( hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path()); + evaluator->set_custom_call_handler(HandleEvaluatorCustomCall); + + // Create executable from only the Hlo module. std::unique_ptr executable = absl::make_unique(std::move(hlo_module), std::move(evaluator)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 10ff7bb6d46ee3b2cd1228b4b7a49269be8c65d3..039954a1837abfc6d5205a39b448238403d74c02 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -588,48 +587,56 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( body_layout.result_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kConditional) { - // The layout of the true and false computations must match, and must - // be the layout of the kConditional instruction. - TF_RET_CHECK(instruction->operand_count() == 3); - - HloComputation* true_computation = instruction->true_computation(); - HloComputation* false_computation = instruction->false_computation(); - const HloInstruction* true_operand = instruction->operand(1); - const HloInstruction* false_operand = instruction->operand(2); - - TF_RET_CHECK(true_computation->num_parameters() == 1); - TF_RET_CHECK(false_computation->num_parameters() == 1); - ComputationLayout& true_computation_layout = - FindOrDie(computation_layouts_, true_computation); - ComputationLayout& false_computation_layout = - FindOrDie(computation_layouts_, false_computation); - - DCHECK(ShapeUtil::Compatible(true_operand->shape(), - true_computation_layout.parameter_shape(0))); - DCHECK(ShapeUtil::Compatible( - false_operand->shape(), false_computation_layout.parameter_shape(0))); - if (true_computation_layout.result_layout() != - false_computation_layout.result_layout()) { - // We assign layouts in DFS fashion, so the true and false computations - // might have negotiated a different layout. But for the conditional - // instruction POV the layout must match, so we run again on the false - // computation, this time with proper computation layout. 
- VLOG(2) << "Reset %conditional false computation result layout: " - "false_computation=" - << false_computation->name() - << " conditional=" << instruction->name() << " shape=" - << true_computation_layout.result_layout().ToString(); - *false_computation_layout.mutable_result_layout() = - true_computation_layout.result_layout(); + // Find the conditional branch with the most instructions and force all + // other computations to match that layout. A potentially better decison + // could count the number FLOPs or how constrained the layouts are. + int64 largest_branch = 0; + int64 largest_instruction_count = + instruction->branch_computation(0)->instruction_count(); + for (int j = 1; j < instruction->branch_count(); ++j) { + const int64 instruction_count = + instruction->branch_computation(j)->instruction_count(); + if (instruction_count > largest_instruction_count) { + largest_branch = j; + largest_instruction_count = instruction_count; + } + } + ComputationLayout& best_branch_computation_layout = + FindOrDie(computation_layouts_, + instruction->branch_computation(largest_branch)); + for (int k = 0; k < instruction->branch_count(); ++k) { + // Visit the best branch first. + int j = (k + largest_branch) % instruction->branch_count(); + TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1); + ComputationLayout& branch_computation_layout = + FindOrDie(computation_layouts_, instruction->branch_computation(j)); + + DCHECK(ShapeUtil::Compatible( + instruction->operand(j + 1)->shape(), + branch_computation_layout.parameter_shape(0))); + if (best_branch_computation_layout.result_layout() != + branch_computation_layout.result_layout()) { + // We assign layouts in DFS fashion, so the largest_branch and current + // branch computations might have negotiated a different layout. But + // for the case instruction POV the layout must match, so we run again + // on the branch j computation, this time with proper computation + // layout. 
+ VLOG(2) << "Reset %conditional branch " << j + << " computation result layout: branch_computation=" + << instruction->branch_computation(j)->name() + << " case=" << instruction->name() << " shape=" + << best_branch_computation_layout.result_layout().ToString(); + *branch_computation_layout.mutable_result_layout() = + best_branch_computation_layout.result_layout(); + } + if (k == 0) { + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + best_branch_computation_layout.result_shape(), instruction)); + } + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + branch_computation_layout.parameter_shape(0), instruction, j + 1, + /*mandatory=*/true)); } - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( - true_computation_layout.result_shape(), instruction)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - true_computation_layout.parameter_shape(0), instruction, 1, - /*mandatory=*/true)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - false_computation_layout.parameter_shape(0), instruction, 2, - /*mandatory=*/true)); } } // Finally set the result layout to match ComputationLayout, if there is one. 
@@ -699,28 +706,21 @@ Status CheckWhileLayout(HloInstruction* while_inst, Status CheckConditionalLayout( HloInstruction* instruction, - const ComputationLayout& true_computation_layout, - const ComputationLayout& false_computation_layout) { - HloComputation* true_computation = instruction->true_computation(); - HloComputation* false_computation = instruction->false_computation(); - const HloInstruction* true_operand = instruction->operand(1); - const HloInstruction* false_operand = instruction->operand(2); - - TF_RET_CHECK(true_computation_layout.result_layout() == - false_computation_layout.result_layout()); - TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( - instruction->shape())); - TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( - true_computation->root_instruction()->shape())); - TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( - instruction->shape())); - TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( - false_computation->root_instruction()->shape())); - TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape( - true_operand->shape())); - TF_RET_CHECK( - false_computation_layout.parameter_layout(0).MatchesLayoutInShape( - false_operand->shape())); + absl::Span branch_computation_layouts) { + for (int j = 0; j < instruction->branch_count(); ++j) { + const HloInstruction* branch_operand = instruction->operand(j + 1); + TF_RET_CHECK(branch_computation_layouts[0].result_layout() == + branch_computation_layouts[j].result_layout()); + TF_RET_CHECK( + branch_computation_layouts[j].result_layout().MatchesLayoutInShape( + instruction->shape())); + TF_RET_CHECK( + branch_computation_layouts[j].result_layout().MatchesLayoutInShape( + instruction->branch_computation(j)->root_instruction()->shape())); + TF_RET_CHECK( + branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape( + branch_operand->shape())); + } return Status::OK(); 
} @@ -937,13 +937,16 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, instruction->while_condition()), FindOrDie(computation_layouts_, instruction->while_body()))); break; - case HloOpcode::kConditional: + case HloOpcode::kConditional: { + std::vector branch_computation_layouts; + for (auto branch_computation : instruction->branch_computations()) { + branch_computation_layouts.emplace_back( + FindOrDie(computation_layouts_, branch_computation)); + } TF_RETURN_IF_ERROR(CheckConditionalLayout( - instruction, - FindOrDie(computation_layouts_, instruction->true_computation()), - FindOrDie(computation_layouts_, - instruction->false_computation()))); + instruction, absl::MakeSpan(branch_computation_layouts))); break; + } default: break; } @@ -1008,7 +1011,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // operations. For similar reasons, if the operand and output have the same // rank, try to match the operand's layout to the output. if (ShapeUtil::TrueRank(operand->shape()) == 1 && - instruction->shape().rank() == 1) { + ShapeUtil::TrueRank(instruction->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. 
return nullptr; } @@ -1019,16 +1022,6 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( Shape operand_shape = operand->shape(); *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); - if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return absl::make_unique(operand_shape.layout()); - } - if (operand_shape.rank() == output_shape.rank()) { - *operand_shape.mutable_layout() = output_layout; - if (ShapeUtil::ReshapeIsBitcast(operand_shape, - output_shape_with_layout)) { - return absl::make_unique(output_layout); - } - } auto aligned_operand_shape = ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); if (aligned_operand_shape) { @@ -1078,7 +1071,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( // reshape is a bitcast when using the same layout. This may avoid copy // operations. For similar reasons, if the operand and output have the same // rank, try to match the outputs's layout to the operand. - if (operand->shape().rank() == 1 && + if (ShapeUtil::TrueRank(operand->shape()) == 1 && ShapeUtil::TrueRank(user->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. 
return nullptr; @@ -1090,16 +1083,6 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( Shape output_shape = user->shape(); *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); - if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return absl::make_unique(output_shape.layout()); - } - if (operand->shape().rank() == output_shape.rank()) { - *output_shape.mutable_layout() = operand_layout; - if (ShapeUtil::ReshapeIsBitcast(output_shape, - operand_shape_with_layout)) { - return absl::make_unique(operand_layout); - } - } auto aligned_user_shape = ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); if (aligned_user_shape) { @@ -1585,8 +1568,9 @@ Status SetFusionLayouts(HloInstruction* fusion) { fused_instruction->mutable_shape())); } else if (fused_instruction->opcode() == HloOpcode::kInfeed) { // Nop; leave the infeed layout alone. - } else { + } else if (fusion->fusion_kind() != HloInstruction::FusionKind::kCustom) { // Other instructions don't have layouts inside of fusion nodes. + // But do not clear layouts for other instructions in custom fusion nodes. LayoutUtil::ClearLayout(fused_instruction->mutable_shape()); } } @@ -1932,12 +1916,6 @@ Status LayoutAssignment::PropagateComputationLayouts( StatusOr LayoutAssignment::Run(HloModule* module) { VLOG(2) << "Running layout assignment on module " << module->name(); - XLA_VLOG_LINES(3, module->ToString()); - if (VLOG_IS_ON(10)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before layout assignment", - module->config().debug_options()); - } TF_RETURN_IF_ERROR(Init()); // Verify computation layout is sane. 
@@ -1992,13 +1970,6 @@ StatusOr LayoutAssignment::Run(HloModule* module) { entry_computation_layout_)); TF_RETURN_IF_ERROR(CheckLayouts(module)); - VLOG(3) << "After layout assignment:"; - XLA_VLOG_LINES(3, module->ToString()); - if (VLOG_IS_ON(10)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after layout assignment", - module->config().debug_options()); - } // All layouts are reset then reassigned by this pass. return true; } @@ -2016,6 +1987,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConditional: @@ -2027,24 +1999,18 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFft: case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kIsFinite: - case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLog1p: - case HloOpcode::kLt: case HloOpcode::kMap: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: @@ -2057,6 +2023,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kRemainder: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRsqrt: case HloOpcode::kScatter: case HloOpcode::kSelect: case HloOpcode::kSelectAndScatter: @@ -2067,8 +2034,11 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kSin: case HloOpcode::kSlice: case HloOpcode::kSort: + case HloOpcode::kSqrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: + case HloOpcode::kTriangularSolve: + case HloOpcode::kCholesky: case HloOpcode::kTupleSelect: case HloOpcode::kWhile: return false; @@ 
-2094,6 +2064,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReduce: + case HloOpcode::kReplicaId: case HloOpcode::kReshape: case HloOpcode::kRng: case HloOpcode::kSend: diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index c8cf3c47d380012fdb0206c0d20d67e6a13017ae..efca6be331ee93f69e506ff1240db1ee089924a2 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -1084,7 +1084,7 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) counter.1 = s32[] get-tuple-element(tup.1), index=0 five = s32[] constant(5) - ROOT lt = pred[] less-than(counter.1, five) + ROOT lt = pred[] compare(counter.1, five), direction=LT } body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) { diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index 182d8edbe30da292f28aeab53be646ce6651839f..afd9f37038387e2ba11d7f3c5e184dc4ad163584 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -37,7 +37,7 @@ class LLVMCompiler : public Compiler { // A callback of this type can be run before and/or after IR-level // optimization to e.g. dump out the generated IR to disk or gather some // statistics. 
- using ModuleHook = std::function; + using ModuleHook = std::function; void SetPreOptimizationHook(ModuleHook hook) { CHECK(!user_pre_optimization_hook_) diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index c5d59fb28e02ce229967fb3856012d608fb83c5d..e51a394827d1b7c145c54f4202bc6403d0fa517b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -67,9 +67,11 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service/cpu:cpu_options", "//tensorflow/core:lib", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", @@ -161,6 +163,7 @@ cc_library( ":llvm_util", ":loop_emitter", ":tuple_ops", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -169,6 +172,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -236,7 +240,7 @@ cc_library( hdrs = ["kernel_support_library.h"], deps = [ ":llvm_loop", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + ":llvm_util", "@com_google_absl//absl/strings", "@llvm//:core", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index ce3d922ca7a9bdea3a520959a8b8d284bc3e0d64..c915a472707f8e591af6edcb3ce84e837b96016b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ 
b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -115,7 +115,7 @@ llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( llvm::MDBuilder metadata_builder(domain->getContext()); llvm::MDNode* scope = metadata_builder.createAliasScope( - AsStringRef("buffer: " + buffer_slice.ToString()), domain); + "buffer: " + buffer_slice.ToString(), domain); llvm::MDNode* scope_list = llvm::MDNode::get(domain->getContext(), scope); return scope_list; } @@ -197,7 +197,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( std::vector scopes; for (const BufferAllocation::Slice noalias_slice : buffers) { llvm::MDNode* scope = metadata_builder.createAliasScope( - AsStringRef("buffer: " + noalias_slice.ToString()), domain); + "buffer: " + noalias_slice.ToString(), domain); scopes.push_back(scope); } llvm::MDNode* noalias_list = diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index b6ae4932f5707f1d15af1e09a735a7de2e48fac5..db900856993b3a22f79767b2c98d79e051571194 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -46,7 +46,7 @@ condition { condition.state = f32[] parameter(0) addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget" add = f32[] add(addend, condition.state) - ROOT greater-than = pred[] greater-than(const.100, add) + ROOT greater-than = pred[] compare(const.100, add), direction=GT } ENTRY while3 { diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index bdce4a171b8a58f617f1d56e6cf6db5354846703..1ea5a42b0b398818b0946eaa9e214100007bada4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow 
Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,14 +41,26 @@ static const HloInstruction& InstrForConstantBufferAllocation( return *const_instr; } -string ConstantBufferAllocationToGlobalName( - const BufferAllocation& allocation) { - string instr_name = InstrForConstantBufferAllocation(allocation).name(); +string SanitizeConstantName(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kConstant); + string instr_name = instr.name(); for (char& c : instr_name) { - if (c == '.') { + // Having a hyphen or a dot in a global variable name can crash the LLVM PTX + // backend. + if (c == '.' || c == '-') { c = '_'; } } + return instr_name; +} + +string ConstantBufferAllocationToGlobalName( + const BufferAllocation& allocation) { + const HloInstruction& instr = InstrForConstantBufferAllocation(allocation); + string instr_name = instr.name(); + // Check that names are sanitized and stored in the HLO instructions + // before constant buffer allocation. + DCHECK_EQ(instr_name, SanitizeConstantName(instr)); return absl::StrCat("buffer_for_", instr_name); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h index bfb6eecb87f6a1b756b3a8da3377f608dd7f0be7..03e98a66900095889292cbff9d9924a9abe83ab0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h @@ -20,6 +20,10 @@ limitations under the License. namespace xla { namespace llvm_ir { +// Sanitizes the HLO constant instruction name so that it can be used for the +// name of the corresponding constant buffer. In particular, it replaces . and +// - with _. 
+string SanitizeConstantName(const HloInstruction& instr); // In XLA:GPU we map constant buffer allocations to globals in the generated // LLVM IR. This function gives us the name of the global variable a constant // buffer is mapped to. Not used on XLA:CPU. diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index c66eaec8fb0e4c03f6967fec0cf0ae9661cdf470..4974cb57db3bb85f90382bfeb7794fe414befdbd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -47,29 +47,30 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // Read start indices from start_indices_generator. const int64 rank = output_shape.rank(); - IrArray::Index start_index(b->getInt64Ty(), rank); + std::vector start_multi_index(rank); for (int64 i = 0; i < rank; ++i) { - TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(i)); + TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i)); llvm::Value* output_dim_size = llvm::ConstantInt::get( - start_index[i]->getType(), output_shape.dimensions(i)); + start_multi_index[i]->getType(), output_shape.dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( - start_index[i]->getType(), update_shape.dimensions(i)); + start_multi_index[i]->getType(), update_shape.dimensions(i)); // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size); - llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); - start_index[i] = + llvm::Value* zero = + llvm::ConstantInt::get(start_multi_index[i]->getType(), 0); + start_multi_index[i] = b->CreateSelect(b->CreateICmp(is_signed ? 
llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, - zero, start_index[i]), - zero, start_index[i]); + zero, start_multi_index[i]), + zero, start_multi_index[i]); - start_index[i] = + start_multi_index[i] = b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, - max_bound, start_index[i]), - max_bound, start_index[i]); + max_bound, start_multi_index[i]), + max_bound, start_multi_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { @@ -78,14 +79,16 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // // output_index[dim] = start_index[dim] + update_index[dim] // - IrArray::Index output_index(start_index.GetType(), rank); + std::vector output_multi_index(rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* start_index0 = - b->CreateSExtOrBitCast(start_index[i], update_index[i]->getType()); - output_index[i] = b->CreateAdd(start_index0, update_index[i]); + llvm::Value* start_index0 = b->CreateSExtOrBitCast( + start_multi_index[i], update_index[i]->getType()); + output_multi_index[i] = b->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. + IrArray::Index output_index(output_multi_index, output_shape, + b->getInt64Ty()); TF_ASSIGN_OR_RETURN(llvm::Value * update_data, update_array_generator(update_index)); output_array.EmitWriteArrayElement(output_index, update_data, b); @@ -113,20 +116,10 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, Shape output_shape = output_array.GetShape(); Shape update_shape = update_array.GetShape(); - IndexGenerator start_indices_generator; - // TODO(b/118437727): Remove the R1 path, and rename the variables. 
- if (start_indices_array.GetShape().rank() == 1) { - start_indices_generator = [&](int64 index) { - return start_indices_array.EmitReadArrayElement( - IrArray::Index({b->getInt64(index)}), b); - }; - } else { - start_indices_generator = [&](int64 index) { - return operand_arrays[2 + index].EmitReadArrayElement( - IrArray::Index(b->getInt64Ty()), b); - }; - } - + IndexGenerator start_indices_generator = [&](int64 index) { + return operand_arrays[2 + index].EmitReadArrayElement( + IrArray::Index(b->getInt64Ty()), b); + }; ElementGenerator update_array_generator = [&](const IrArray::Index& index) { return update_array.EmitReadArrayElement(index, b); }; @@ -178,21 +171,11 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); - // TODO(b/118437727): Remove the R1 path, and rename the variables. - IndexGenerator start_indices_generator; - if (start_indices->shape().rank() == 1) { - start_indices_generator = [&](int64 index) { - return fused_emitter.GetGenerator(start_indices)( - IrArray::Index({b->getInt64(index)})); - }; - } else { - start_indices_generator = [&](int64 index) { - ElementGenerator element_generator = - fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index)); - return element_generator(IrArray::Index(b->getInt64Ty())); - }; - } - + IndexGenerator start_indices_generator = [&](int64 index) { + ElementGenerator element_generator = + fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index)); + return element_generator(IrArray::Index(b->getInt64Ty())); + }; bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); return EmitDynamicUpdateSliceInPlaceImpl( update_shape, start_indices_generator, is_signed, update_array_generator, diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 
e440f05e2b2f0d4a2a4c7b326b4881183de4d235..cf9df95ad575425404e4f88919078f158ec5c1de 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -15,14 +15,22 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -58,9 +66,9 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { } VLOG(3) << "The cached generated value can't be reused, because it is in " "a different BB (" - << llvm_ir::AsString(generated_value_bb->getName()) + << generated_value_bb->getName().str() << ") from the current insertion block (" - << llvm_ir::AsString(b_->GetInsertBlock()->getName()) << ")."; + << b_->GetInsertBlock()->getName().str() << ")."; } TF_ASSIGN_OR_RETURN(generated_value_cache_[hlo][index.multidim()], @@ -78,8 +86,11 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { llvm_ir::ConvertLiteralToIrConstant(literal, module_); llvm::GlobalVariable* global = new llvm::GlobalVariable( *b_->GetInsertBlock()->getModule(), initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, 
+ /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/initializer, /*Name=*/""); + global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global); llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); @@ -110,9 +121,9 @@ Status FusedIrEmitter::HandleGetTupleElement( } // Lookup tuple element pointer. - return llvm_ir::EmitGetTupleElement( - get_tuple_element->shape(), get_tuple_element->tuple_index(), - /*alignment=*/1, tuple_ptr, b_, module_); + return llvm_ir::EmitGetTupleElement(get_tuple_element->shape(), + get_tuple_element->tuple_index(), + /*alignment=*/1, tuple_ptr, b_); }; if (!get_tuple_element->shape().IsTuple()) { @@ -192,4 +203,101 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator( return indexed_generators_.at(instruction); } +bool FusedIrEmitter::IsFusedIrEmitterInefficient( + const HloInstruction* consumer, const HloInstruction* producer) { + if (consumer->opcode() != HloOpcode::kFusion) { + return false; + } + // Collects for each instruction in the fusion node from which (indirect) + // users newly created index values are passed. Roughly speaking, we reuse + // index values if the shapes are equal when ignoring the element type (we may + // reuse also if the shape change is a bitcast, but we don't consider that + // here). By ignoring potential reuses our estimate whether the fusion emitter + // is inefficient is a bit more conservative than necessary. + absl::flat_hash_map> + indexing_users; + // Stores the number of different index accesses for each instruction in the + // fusion node. The fusion emitter caches access with the same index, so this + // value indicates how many times a specific instruction will be emitted. 
+ absl::flat_hash_map index_usage_count; + index_usage_count[consumer] = 1; + + auto evaluate_fusion_computation = [&indexing_users, &index_usage_count]( + const HloInstruction* fusion) { + auto postorder = + fusion->fused_instructions_computation()->MakeInstructionPostOrder(); + std::reverse(postorder.begin(), postorder.end()); + for (const auto* instruction : postorder) { + if (instruction->opcode() == HloOpcode::kParameter) { + continue; + } + int64& total = index_usage_count[instruction]; + if (indexing_users[instruction].empty()) { + total = index_usage_count[fusion]; + } else { + total = 0; + for (const auto* user : indexing_users[instruction]) { + int64 weight = 1; + // Concatenate is special: the index differs for each operand, so + // in the worst case we have to deal with as many index values as + // the number of operands of Concatenate. By considering the worst + // case, we are more conservative than necessary regarding + // refusing to fuse. + if (user->opcode() == HloOpcode::kConcatenate) { + weight = user->operand_count(); + } + total += index_usage_count[user] * weight; + } + } + for (const auto* operand : instruction->operands()) { + // For simplicity we assume that all shape and layout changing + // operations invalidate index reuse. + if (Shape::Equal().IgnoreElementType()(operand->shape(), + instruction->shape())) { + // If the index is reused, it means the operand gets index values + // from the same set of (indirect) users as 'instruction' itself. + indexing_users[operand].insert(indexing_users[instruction].begin(), + indexing_users[instruction].end()); + } else { + // If the index is not reused, it means 'instruction' computes a + // new index derived from the index it gets. + indexing_users[operand].insert(instruction); + } + } + } + }; + evaluate_fusion_computation(consumer); + + // Also account for the 'producer' if it would be fused. Find the operand it + // corresponds to. 
+ for (int64 operand_num = 0; operand_num < consumer->operand_count(); + ++operand_num) { + if (consumer->operand(operand_num) == producer) { + auto instruction = consumer->fused_parameter(operand_num); + int64& total = index_usage_count[producer]; + total = 0; + for (const auto* user : indexing_users[instruction]) { + total += index_usage_count[user]; + } + break; + } + } + + // If 'producer' is a fusion node as well, also evaluate it. + if (producer->opcode() == HloOpcode::kFusion) { + evaluate_fusion_computation(producer); + } + + // Sum up the total number of emitted ops. + int64 total = 0; + for (const auto& entry : index_usage_count) { + total += entry.second; + } + + // Check that the code duplication has at most a factor of 8 (where 8 is an + // arbitrary constant that seems to work). + return total > 8 * index_usage_count.size(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index e6d52a580c04a920d3f0e8ed6f39c1cae587cf1b..b1aa6d59634463956491b586d84fb6a6945a3fdf 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -91,6 +91,14 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { tiled_parameter_info_ = info; } + // Evaluates whether fusing 'producer' into 'consumer' might cause exponential + // behavior in FusedIrEmitter. We currently can have exponential time/memory + // requirements for emitting certain fusion kernels, in which case we don't + // want to fuse. + // TODO(b/119692968): Remove this once we have fixed our fusion emitter. + static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer, + const HloInstruction* producer); + protected: // Returns the IrArrays for the fusion instruction operands. 
llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 8ee07ae8331e986f9d271be5e39065f0d87853b1..7bf8bde9dededbfefed48449f19987798f62c434 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -29,6 +29,14 @@ limitations under the License. namespace xla { namespace llvm_ir { +IrArray::Index::Index(absl::Span multidim, + llvm::Value* linear, const Shape& shape, + llvm::Type* index_type) + : Index(multidim, shape, index_type) { + CHECK_NE(linear, nullptr); + linear_ = linear; +} + void IrArray::Index::Delinearize(std::vector* multidim, llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b) const { @@ -74,36 +82,22 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, } IrArray::Index::Index(absl::Span multidim, - llvm::Value* linear, const Shape& shape) + const Shape& shape, llvm::Type* index_type) : multidim_(multidim.begin(), multidim.end()), - linear_(linear), + linear_(nullptr), layout_(shape.layout()), - dims_(shape.dimensions().begin(), shape.dimensions().end()) { - if (size()) { - index_type_ = multidim_[0]->getType(); - } else { - CHECK_NE(linear_, nullptr); - index_type_ = linear_->getType(); - } + dims_(shape.dimensions().begin(), shape.dimensions().end()), + index_type_(index_type) { CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); + for (const auto* dim : multidim) { + CHECK_NE(dim, nullptr); + } CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; } -IrArray::Index::Index(absl::Span multidim, - const Shape& shape, llvm::IRBuilder<>* b) - : multidim_(multidim.begin(), multidim.end()), - layout_(shape.layout()), - dims_(shape.dimensions().begin(), shape.dimensions().end()) { - CHECK_GT(multidim_.size(), 0); - index_type_ = 
multidim[0]->getType(); - CHECK_NE(index_type_, nullptr); - CHECK_EQ(shape.dimensions_size(), multidim.size()); - CHECK(LayoutUtil::HasLayout(shape)); -} - IrArray::IrArray(llvm::Value* base_ptr, Shape shape) : base_ptr_(base_ptr), shape_(std::move(shape)) { TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); @@ -117,10 +111,10 @@ IrArray::IrArray(llvm::Value* base_ptr, Shape shape) ++depth; } - if (!shape_->IsArray() || ShapeUtil::IsScalar(*shape_)) { + if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { - DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString(); + DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString(); } } @@ -178,30 +172,30 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) && LayoutUtil::HasLayout(output_shape) && ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { - return Index(source_multidim_index, linear(), input_shape); + return Index(source_multidim_index, linear(), input_shape, index_type_); } return Index(source_multidim_index, index_type_); } IrArray::Index IrArray::Index::SourceIndexOfSlice( - const Shape& shape, absl::Span starts, + const Shape& operand_shape, absl::Span starts, absl::Span strides, llvm::IRBuilder<>* builder) const { - Index source_index(index_type_, multidim_.size()); + std::vector source_multi_index(multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; auto type = multidim_[i]->getType(); if (stride != 1) { - source_index[i] = builder->CreateAdd( + source_multi_index[i] = builder->CreateAdd( builder->CreateMul(multidim_[i], llvm::ConstantInt::get(type, stride)), llvm::ConstantInt::get(type, starts[i])); } else { - source_index[i] = builder->CreateAdd( + source_multi_index[i] = builder->CreateAdd( multidim_[i], llvm::ConstantInt::get(type, starts[i])); } } - return source_index; + return Index(source_multi_index, operand_shape, index_type_); } 
IrArray::Index IrArray::Index::SourceIndexOfTranspose( @@ -214,7 +208,7 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose( if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) && LayoutUtil::HasLayout(shape) && ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { - return Index(operand_multidim_index, linear(), operand_shape); + return Index(operand_multidim_index, linear(), operand_shape, index_type_); } return Index(operand_multidim_index); @@ -250,7 +244,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( std::vector multi_index(operand_shape.dimensions_size()); Delinearize(&multi_index, linear_index, operand_shape, builder); - return Index(multi_index, linear_index, operand_shape); + return Index(multi_index, linear_index, operand_shape, index_type_); } IrArray::Index IrArray::Index::SourceIndexOfBroadcast( @@ -303,9 +297,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } if (divisor > 1) { - linear = builder->CreateUDiv( - linear, - IrArray::Index(linear->getType()).GetConstantWithIndexType(divisor)); + linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor)); } if (min_broadcasted_dimension > 0) { int64 mod = 1; @@ -313,11 +305,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( ++i) { mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - linear = builder->CreateURem( - linear, - IrArray::Index(linear->getType()).GetConstantWithIndexType(mod)); + linear = builder->CreateURem(linear, GetConstantWithIndexType(mod)); } - return Index(source_index, linear, operand_shape); + return Index(source_index, linear, operand_shape, index_type_); } llvm::Value* IrArray::Index::Linearize(absl::Span dimensions, @@ -341,20 +331,21 @@ llvm::Value* IrArray::Index::Linearize(absl::Span dimensions, llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, llvm::IRBuilder<>* b, - absl::string_view name) 
const { - if (ShapeUtil::IsScalar(*shape_)) { + absl::string_view name, + bool use_linear_index) const { + if (ShapeUtil::IsScalar(shape_)) { // Special handling of scalars: a scalar pretends to have the same value for // every index, thus effectively implementing broadcasting of its value // over higher-rank arrays. return base_ptr_; } - CHECK_EQ(index.size(), shape_->rank()); + CHECK_EQ(index.size(), shape_.rank()); - if (index.LinearValidOnShape(*shape_)) { + if (use_linear_index && index.LinearValidOnShape(shape_)) { llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); return b->CreateInBoundsGEP( b->CreateBitCast(base_ptr_, - PrimitiveTypeToIrType(shape_->element_type(), module) + PrimitiveTypeToIrType(shape_.element_type(), module) ->getPointerTo()), {index.linear()}, llvm_ir::AsStringRef(name)); } @@ -364,7 +355,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, // When dimension i is of size 1, LLVM optimization is able to replace // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to // produce better code in some cases. - auto dim = shape_->dimensions(i); + auto dim = shape_.dimensions(i); actual_index.push_back( dim == 1 ? 
llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]); } @@ -377,8 +368,8 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, CHECK_GT(index.size(), 0); std::vector gep_indices( 1, llvm::ConstantInt::get(index[0]->getType(), 0)); - for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { - int64 dimension = LayoutUtil::Major(shape_->layout(), i); + for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_.layout(), i); gep_indices.push_back(actual_index[dimension]); } return b->CreateInBoundsGEP(base_ptr_, gep_indices, @@ -399,16 +390,20 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - absl::string_view name) const { - llvm::Value* element_address = EmitArrayElementAddress(index, b, name); + absl::string_view name, + bool use_linear_index) const { + llvm::Value* element_address = + EmitArrayElementAddress(index, b, name, use_linear_index); llvm::LoadInst* load = b->CreateLoad(element_address); AnnotateLoadStoreInstructionWithMetadata(load); return load; } void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, - llvm::IRBuilder<>* b) const { - llvm::Value* element_address = EmitArrayElementAddress(index, b); + llvm::IRBuilder<>* b, + bool use_linear_index) const { + llvm::Value* element_address = + EmitArrayElementAddress(index, b, "", use_linear_index); llvm::StoreInst* store = b->CreateStore(value, element_address); AnnotateLoadStoreInstructionWithMetadata(store); } @@ -423,18 +418,5 @@ IrArray IrArray::CastToShape(const Shape& new_shape, return new_irarray; } -/* static */ IrArray::Index IrArray::BumpIndex(const Index& index, - int64 which_dimension, - int64 addend, - llvm::IRBuilder<>* b) { - Index new_index = index; - new_index[which_dimension] = b->CreateAdd( - index[which_dimension], - llvm::ConstantInt::get(index[which_dimension]->getType(), 
addend), "", - /*HasNUW=*/true, - /*HasNSW=*/true); - return new_index; -} - } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index b706ebd311cbb706e7e4698b93319e37e664d10a..d0a8b6356d6d76022e801520a1436481672bf678 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -55,17 +55,8 @@ class IrArray { // multidimensional index, which LLVM DCE can delete. class Index { public: - // Constructs an index of rank "size". Each dimension of the index is - // initialized to "value". - explicit Index(size_t size, llvm::Value* value) - : multidim_(size, value), index_type_(value->getType()) { - CHECK_NE(index_type_, nullptr); - } - - // Constructs an index of rank "size". Each dimension of the index is - // initialized to nullptr. - explicit Index(llvm::Type* index_ty, size_t size = 0) - : multidim_(size, nullptr), index_type_(index_ty) { + // Constructs an index for a scalar shape. + explicit Index(llvm::Type* index_ty) : index_type_(index_ty) { CHECK(index_ty->isIntegerTy()); } @@ -77,7 +68,10 @@ class IrArray { if (size() == 0) { index_type_ = index_ty; } else { - index_type_ = (*this)[0]->getType(); + for (const auto* dim : multidim) { + CHECK_NE(dim, nullptr); + } + index_type_ = multidim[0]->getType(); if (index_ty != nullptr) { CHECK_EQ(index_type_, index_ty); } @@ -96,25 +90,20 @@ class IrArray { // Precondition: "shape" has a layout. Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b); - // Constructs an index from the given multi-dimensional index and the shape - // that it indexes into. + // Constructs an index from a multi-dimensional index. 'shape' is the shape + // for which the multi-dimensional index is used. 'index_type' is the type + // of the index. // // Precondition: "shape" has a layout. 
Index(absl::Span multidim, const Shape& shape, - llvm::IRBuilder<>* b); - - // Constructs an index from both a multi-dimensional index and a linear - // index. "shape" has the same meaning as that in the constructor that takes - // only a linear index. - Index(absl::Span multidim, llvm::Value* linear, - const Shape& shape); + llvm::Type* index_type); // Returns an index that adds `addend` to the given `dim` of the object. Index AddOffsetToDim(llvm::Value* addend, int64 dim, llvm::IRBuilder<>* b) const { - IrArray::Index index = *this; - index[dim] = b->CreateAdd(index[dim], addend); - return index; + std::vector multi_index = multidim(); + multi_index[dim] = b->CreateAdd(multi_index[dim], addend); + return Index(multi_index, index_type_); } const std::vector& multidim() const { return multidim_; } @@ -123,43 +112,28 @@ class IrArray { size_t size() const { return multidim().size(); } llvm::Value* operator[](size_t i) const { return multidim()[i]; } - llvm::Value*& operator[](size_t i) { return mutable_multidim()[i]; } - - void push_back(llvm::Value* value) { mutable_multidim().push_back(value); } - void InsertAt(int64 index, llvm::Value* value) { - CHECK_LE(index, size()); - mutable_multidim().insert(mutable_multidim().begin() + index, value); - } - void InsertAt(int64 index, int64 count, llvm::Value* value) { - CHECK_LE(index, size()); - mutable_multidim().insert(mutable_multidim().begin() + index, count, - value); - } - using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; - iterator begin() { return mutable_multidim().begin(); } - iterator end() { return mutable_multidim().end(); } - const_iterator begin() const { return multidim().begin(); } const_iterator end() const { return multidim().end(); } - llvm::Value* back() const { return multidim().back(); } - bool LinearValidOnShape(const Shape& a) const; // Given that "this" is the target index of a reshape from `operand_shape` // to `shape`, returns the source index. 
- Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape, + Index SourceIndexOfReshape(const Shape& output_shape, + const Shape& input_shape, llvm::IRBuilder<>* builder) const; // Returns the index into the source operand from which a slice operation // selects a value to be placed into index "this". The slice is described // by starting indices `starts` and stride values `strides`. // - // Precondition: "this" is an index into a slice whose shape is `shape`. - Index SourceIndexOfSlice(const Shape& shape, absl::Span starts, + // Precondition: "this" is an index into a slice whose operand shape is + // `operand_shape`. + Index SourceIndexOfSlice(const Shape& operand_shape, + absl::Span starts, absl::Span strides, llvm::IRBuilder<>* builder) const; @@ -194,14 +168,14 @@ class IrArray { return llvm::ConstantInt::get(index_type_, c); } - void ClearLinearIndex() { linear_ = nullptr; } - private: - // Changing the multi-dimensional index invalidates the linear index. - std::vector& mutable_multidim() { - linear_ = nullptr; - return multidim_; - } + // Constructs an index from both a multi-dimensional index and a linear + // index. 'shape' is the shape on which the index is used. 'index_type' is + // the type of the index. + // + // Precondition: "shape" has a layout. + Index(absl::Span multidim, llvm::Value* linear, + const Shape& shape, llvm::Type* index_type); void Delinearize(std::vector* multidim, llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b) const; @@ -242,9 +216,7 @@ class IrArray { llvm::Value* GetBasePointer() const { return base_ptr_; } llvm::Type* GetElementLlvmType() const { return element_type_; } - const Shape& GetShape() const { - return *shape_; - } + const Shape& GetShape() const { return shape_; } // Emit a sequence of instructions to compute the address of the element in // the given array at the given index. 
Returns the address of the element as @@ -253,7 +225,8 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, - absl::string_view name = "") const; + absl::string_view name = "", + bool use_linear_index = true) const; // Attach metadata this IrArray instance knows about to "instruction". void AnnotateLoadStoreInstructionWithMetadata( @@ -266,15 +239,23 @@ class IrArray { // // The optional name is useful for debugging when looking at // the emitted LLVM IR. + // 'use_linear_index' can be used to specify whether the linear index (if + // available) or the multi-dimensional index should be used. llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - absl::string_view name = "") const; + absl::string_view name = "", + bool use_linear_index = true) const; // Emit IR to write the given value to the array element at the given index. + // 'use_linear_index' can be used to specify whether the linear index (if + // available) or the multi-dimensional index should be used. void EmitWriteArrayElement(const Index& index, llvm::Value* value, - llvm::IRBuilder<>* b) const; + llvm::IRBuilder<>* b, + bool use_linear_index = true) const; // Returns a new IrArray whose shape is "new_shape" and base pointer is a // bitcast of the base pointer of "this" IrArray. + // 'use_linear_index' can be used to specify whether the linear index (if + // available) or the multi-dimensional index should be used. IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const; void AddAliasScopeMetadata(llvm::MDNode* alias_scope) { @@ -318,11 +299,6 @@ class IrArray { const std::map& metadata() const { return metadata_; } - // Bumps the "which_dimension" value within the provided index by the provided - // addend. 
- static Index BumpIndex(const Index& index, int64 which_dimension, - int64 addend, llvm::IRBuilder<>* b); - private: // Add the specified LLVM IR metadata to loads/stores associated with this // IrArray. @@ -337,7 +313,7 @@ class IrArray { llvm::Type* element_type_; // Shape of the XLA array. - absl::optional shape_; + Shape shape_; // The list of key/value pairs used when attaching metadata to emitted // loads/stores for this array. They keys are the metadata kinds and the diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h index cf5083e8c13b9485035923895cec1ad05049c644..02c719502ee7b0a732ae74acec364f89d51ae0c1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -269,6 +269,11 @@ class IrBuilderMixin { return mixin_builder()->CreateFCmpUNE(std::forward(args)...); } + template + llvm::Value* FCmpUNO(Args&&... args) { + return mixin_builder()->CreateFCmpUNO(std::forward(args)...); + } + template llvm::Value* FDiv(Args&&... 
args) { return mixin_builder()->CreateFDiv(std::forward(args)...); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 5eeb29c478a371dae83251771f2dc4844672d3e9..e1dc7e747654a2539c3216d418226003e89de3b1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -70,7 +70,7 @@ Status KernelSupportLibrary::IfWithStatus( } void KernelSupportLibrary::EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + const HloModuleConfig& module_config, llvm::IRBuilder<>* b, absl::string_view kernel_name, KernelSupportLibrary::ArgumentVector arguments, const std::function& @@ -101,10 +101,9 @@ void KernelSupportLibrary::EmitAndCallOutlinedKernel( auto* function_type = llvm::FunctionType::get(b->getVoidTy(), arg_types, /*isVarArg=*/false); - function = llvm_ir::CreateFunction( - function_type, llvm::GlobalValue::InternalLinkage, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, kernel_name, module); + function = llvm_ir::CreateCpuFunction(function_type, + llvm::GlobalValue::InternalLinkage, + module_config, kernel_name, module); llvm::IRBuilder<>::InsertPointGuard guard(*b); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 612b839cfa15711061e1ae53358a72d5220e1801..b66ce6b835e6bdecd606cc9919575c11e32f6a0a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -263,33 +263,33 @@ class KernelSupportLibrary { // in a nullptr llvm::Value* in its position to `kernel_body_generator`. // Currently we only support at most one nullptr value in `arguments`. 
static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + const HloModuleConfig& module_config, llvm::IRBuilder<>* b, absl::string_view kernel_name, ArgumentVector arguments, const std::function& kernel_body_generator); // Thin wrappers around the more general EmitAndCallOutlinedKernel above. static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + const HloModuleConfig& module_config, llvm::IRBuilder<>* b, absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, const std::function& kernel_body_generator) { - EmitAndCallOutlinedKernel( - enable_fast_math, optimize_for_size, b, kernel_name, {arg0, arg1, arg2}, - [&](ArgumentVector args) { - kernel_body_generator(args[0], args[1], args[2]); - }); + EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2}, + [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], + args[2]); + }); } static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + const HloModuleConfig& module_config, llvm::IRBuilder<>* b, absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, llvm::Value* arg3, const std::function& kernel_body_generator) { EmitAndCallOutlinedKernel( - enable_fast_math, optimize_for_size, b, kernel_name, - {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) { + module_config, b, kernel_name, {arg0, arg1, arg2, arg3}, + [&](ArgumentVector args) { kernel_body_generator(args[0], args[1], args[2], args[3]); }); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index cd8dd72cd775d5e0b52f96a2326367da0775e7eb..3c4aee70bca657e4a6232834ad1bfcbfa460571b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -185,15 +185,15 
@@ IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( const IrArray::Index& tile_index) { - IrArray::Index elem_index = tile_index; + std::vector elem_multi_index = tile_index.multidim(); for (int i = DimY; i < DimTot; ++i) { - elem_index[i] = + elem_multi_index[i] = b_->CreateMul(tile_index[i], llvm::ConstantInt::get(tile_index[i]->getType(), GetTileSizeForDimension(i)), "tile_origin." + std::to_string(i)); } - return elem_index; + return IrArray::Index(elem_multi_index, tile_index.GetType()); } llvm::GlobalVariable* KernelMappingScheme::GetSharedMemoryBufferForElementType( diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index fe320bbe727111fbc986cc1fbc217feed74d30f1..e08248b9cc46134675311a82f69f02572a495fe4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -74,8 +73,8 @@ void ForLoop::Emit(llvm::IRBuilder<>* b) { // Split the preheader to create an exit basic block. The exit basic block // will contain all instructions at or after insert_point. - exit_bb_ = preheader_bb_->splitBasicBlock( - insert_point, AsStringRef(GetQualifiedName("loop_exit"))); + exit_bb_ = preheader_bb_->splitBasicBlock(insert_point, + GetQualifiedName("loop_exit")); // splitBasicBlock adds an unconditional branch between the split basic // blocks. Remove it. 
An unconditional branch will be added below from the @@ -95,9 +94,8 @@ void ForLoop::Emit(llvm::IRBuilder<>* b) { llvm::Function* func = preheader_bb_->getParent(); b->SetInsertPoint(&func->getEntryBlock(), func->getEntryBlock().getFirstInsertionPt()); - llvm::Value* indvar_address = - b->CreateAlloca(start_index_->getType(), nullptr, - AsStringRef(GetQualifiedName("invar_address"))); + llvm::Value* indvar_address = b->CreateAlloca( + start_index_->getType(), nullptr, GetQualifiedName("invar_address")); // Preheader basic block. // Initialize induction variable starting index. Create branch to the header. @@ -111,8 +109,7 @@ void ForLoop::Emit(llvm::IRBuilder<>* b) { // Emit the loop conditional branch. Load and compare indvar with ending // index and jump to loop exit if equal. Jump to body otherwise. b->SetInsertPoint(header_bb_); - indvar_ = - b->CreateLoad(indvar_address, AsStringRef(GetQualifiedName("indvar"))); + indvar_ = b->CreateLoad(indvar_address, GetQualifiedName("indvar")); llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_); b->CreateCondBr(/*Cond=*/exit_cond, /*True=*/exit_bb_, /*False=*/body_bb_); @@ -237,25 +234,26 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, absl::string_view suffix) { std::vector dimensions(shape.rank()); std::iota(dimensions.begin(), dimensions.end(), 0); - return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); + return IrArray::Index(AddLoopsForShapeOnDimensions(shape, dimensions, suffix), + shape, index_type_); } -IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( +std::vector ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, absl::Span dimensions, absl::string_view suffix) { - llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); + std::vector multi_index(shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr loop = AddLoop( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ llvm_ir::IrName(suffix, 
absl::StrCat(dimension))); - index[dimension] = loop->GetIndVarValue(); + multi_index[dimension] = loop->GetIndVarValue(); } - return index; + return multi_index; } -IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( +std::vector ForLoopNest::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, absl::string_view name_suffix) { // Prepares the dimension list we will use to emit the loop nest. Outermost @@ -271,18 +269,18 @@ IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( // Create loop nest with one for-loop for each dimension of the // output. - llvm_ir::IrArray::Index index = + std::vector multi_index = AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); // Verify every dimension except the 'dimension_to_skip' dimension was set in // the index. - for (size_t dimension = 0; dimension < index.size(); ++dimension) { + for (size_t dimension = 0; dimension < multi_index.size(); ++dimension) { if (dimension == dimension_to_skip) { - DCHECK_EQ(nullptr, index[dimension]); + DCHECK_EQ(nullptr, multi_index[dimension]); } else { - DCHECK_NE(nullptr, index[dimension]); + DCHECK_NE(nullptr, multi_index[dimension]); } } - return index; + return multi_index; } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index ac3bba3c9fd6a9eb4e7822474963fcc5a394baf7..1dbc9745c080e845140e0976b788d8465deaa2b4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -43,6 +43,9 @@ enum class UnrollMode { // A class for constructing a for-loop in LLVM IR. class ForLoop { public: + ForLoop(const ForLoop&) = delete; + ForLoop& operator=(const ForLoop&) = delete; + // Emit a for-loop at the current insert point of the given IRBuilder. // // start_index and end_index are the loop bounds (end_index is not inclusive). 
@@ -169,18 +172,11 @@ class ForLoop { llvm::Value* indvar_; UnrollMode unroll_mode_; bool prevent_vectorization_; - - TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); }; // A simple class for constructing nested for-loops. class ForLoopNest { public: - explicit ForLoopNest(llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) - : ForLoopNest(/*name=*/"", b) { - SetIndexType(index_ty); - } - ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) : name_(name), @@ -190,6 +186,8 @@ class ForLoopNest { b_(b) { SetIndexType(index_ty); } + ForLoopNest(const ForLoopNest&) = delete; + ForLoopNest& operator=(const ForLoopNest&) = delete; // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have @@ -241,7 +239,7 @@ class ForLoopNest { // The return value is an index with the induction variables. The // size equals the rank of shape and there is a null for each // dimension that is not in "dimensions". - IrArray::Index AddLoopsForShapeOnDimensions( + std::vector AddLoopsForShapeOnDimensions( const Shape& shape, absl::Span dimensions, absl::string_view suffix); @@ -252,9 +250,9 @@ class ForLoopNest { // dimensions of the index are filled except for 'dimension_to_skip'. // name_suffix is the string to append to the names of LLVM constructs (eg, // basic blocks) constructed by this method. - IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array, - int64 dimension_to_skip, - absl::string_view name_suffix); + std::vector EmitOperandArrayLoopNest( + const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, + absl::string_view name_suffix); // Convenience methods which return particular basic blocks of the outermost // or innermost loops. 
These methods return nullptr if no loops have been @@ -289,8 +287,6 @@ class ForLoopNest { llvm::IRBuilder<>* b_; llvm::Type* index_type_; - - TF_DISALLOW_COPY_AND_ASSIGN(ForLoopNest); }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 807296329c07b8e4ac630486a1e1f59e4fdfa009..66219c156674318f074d7342f9f2a8a218b4f7e8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -31,6 +31,8 @@ limitations under the License. #include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -58,14 +60,6 @@ llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) { } // namespace -string AsString(const std::string& str) { - return string(str.data(), str.length()); -} - -llvm::StringRef AsStringRef(absl::string_view str) { - return llvm::StringRef(str.data(), str.size()); -} - std::unique_ptr DropConstantInitializers( const llvm::Module& module) { std::unique_ptr cloned_module = CloneModule(module); @@ -81,7 +75,7 @@ string DumpModuleToString(const llvm::Module& module) { llvm::raw_string_ostream ostream(buffer_string); module.print(ostream, nullptr); ostream.flush(); - return AsString(buffer_string); + return buffer_string; } llvm::CallInst* EmitCallToIntrinsic( @@ -248,7 +242,7 @@ StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, return InternalError("Encoded shape size exceeded int32 size limit."); } *shape_size = static_cast(encoded_shape.size()); - return b->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape)); + return 
b->CreateGlobalStringPtr(encoded_shape); } StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, @@ -293,7 +287,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, absl::string_view name, llvm::IRBuilder<>* b, int alignment) { - llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP(); + llvm::IRBuilder<>::InsertPointGuard guard(*b); llvm::Function* function = b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), function->getEntryBlock().getFirstInsertionPt()); @@ -302,7 +296,6 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, if (alignment != 0) { alloca->setAlignment(alignment); } - b->restoreIP(insert_point); return alloca; } @@ -334,7 +327,7 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, b->CreateBr(if_data.after_block); } else { if_data.after_block = if_data.if_block->splitBasicBlock( - b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after"))); + b->GetInsertPoint(), absl::StrCat(name, "-after")); } // Our basic block should now end with an unconditional branch. Remove it; @@ -507,24 +500,25 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) { return ShapeUtil::ByteSizeOf(shape, pointer_size); } -llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) { +llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) { llvm::FastMathFlags flags; - if (fast_math_enabled) { - // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, - // AllowReciprocal, AllowContract, and ApproxFunc. - flags.setFast(); + if (!module_config.debug_options().xla_cpu_enable_fast_math()) { + return flags; } - return flags; -} -void SetTargetOptions(bool fast_math_enabled, - llvm::TargetOptions* target_options) { - // In LLVM backend flags, UnsafeFPMath does not explicitly imply - // NoInfs, etc. 
- target_options->UnsafeFPMath = fast_math_enabled; - target_options->NoInfsFPMath = fast_math_enabled; - target_options->NoNaNsFPMath = fast_math_enabled; - target_options->NoSignedZerosFPMath = fast_math_enabled; + // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal, + // AllowContract, and ApproxFunc. + flags.setFast(); + + if (module_config.debug_options().xla_cpu_fast_math_honor_nans()) { + flags.setNoNaNs(false); + } + + if (module_config.debug_options().xla_cpu_fast_math_honor_infs()) { + flags.setNoInfs(false); + } + + return flags; } std::map MergeMetadata( @@ -575,14 +569,6 @@ std::map MergeMetadata( return result; } -static string GetProcessUniqueIrFileName(absl::string_view prefix) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); - - tensorflow::mutex_lock lock(mu); - return uniquer->GetUniqueName(prefix); -} - static Status CreateAndWriteStringToFile(const string& directory_name, const string& file_name, const string& text) { @@ -596,35 +582,34 @@ static Status CreateAndWriteStringToFile(const string& directory_name, return Status::OK(); } -Status DumpIRToDirectory(const string& directory_name, - const string& hlo_module_name, - const llvm::Module& llvm_module, bool optimized) { +void DumpIrIfEnabled(const HloModule& hlo_module, + const llvm::Module& llvm_module, bool optimized) { + const auto& debug_opts = hlo_module.config().debug_options(); + if (!DumpingEnabledForHloModule(hlo_module)) { + return; + } // We can end up compiling different modules with the same name when using // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously // dumped from the same process in such cases. - string unique_and_safe_file_name = GetProcessUniqueIrFileName( - absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", - optimized ? 
"with" : "no", "-opt")); - - string ir_file_name = tensorflow::io::JoinPath( - directory_name, absl::StrCat(unique_and_safe_file_name, ".ll")); + string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt"); + DumpToFileInDirOrStdout(hlo_module, absl::StrCat(suffix, ".ll"), + DumpModuleToString(llvm_module)); // For some models the embedded constants can be huge, so also dump the module - // with the constants stripped to get IR that is easier to manipulate. - string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( - directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll")); - - TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( - directory_name, ir_file_name, DumpModuleToString(llvm_module))); - return CreateAndWriteStringToFile( - directory_name, ir_no_constant_initializers_file_name, - DumpModuleToString(*DropConstantInitializers(llvm_module))); + // with the constants stripped to get IR that is easier to manipulate. Skip + // this if we're dumping to stdout; there's no point in duplicating everything + // when writing to the terminal. + if (!DumpingToStdout(debug_opts)) { + DumpToFileInDir(hlo_module, absl::StrCat(suffix, "-noconst.ll"), + DumpModuleToString(*DropConstantInitializers(llvm_module))); + } } -llvm::Function* CreateFunction(llvm::FunctionType* function_type, - llvm::GlobalValue::LinkageTypes linkage, - bool enable_fast_math, bool optimize_for_size, - absl::string_view name, llvm::Module* module) { +llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + const HloModuleConfig& module_config, + absl::string_view name, + llvm::Module* module) { llvm::Function* function = llvm::Function::Create(function_type, linkage, AsStringRef(name), module); function->setCallingConv(llvm::CallingConv::C); @@ -634,17 +619,23 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type, // created by the JIT compiled code. 
function->setHasUWTable(); - if (enable_fast_math) { + if (module_config.debug_options().xla_cpu_enable_fast_math()) { function->addFnAttr("unsafe-fp-math", "true"); - function->addFnAttr("no-infs-fp-math", "true"); - function->addFnAttr("no-nans-fp-math", "true"); function->addFnAttr("no-signed-zeros-fp-math", "true"); + + if (!module_config.debug_options().xla_cpu_fast_math_honor_nans()) { + function->addFnAttr("no-nans-fp-math", "true"); + } + + if (!module_config.debug_options().xla_cpu_fast_math_honor_infs()) { + function->addFnAttr("no-infs-fp-math", "true"); + } } // Add the optize attribute to the function if optimizing for size. This // controls internal behavior of some optimization passes (e.g. loop // unrolling). - if (optimize_for_size) { + if (cpu::options::OptimizeForSizeRequested(module_config)) { function->addFnAttr(llvm::Attribute::OptimizeForSize); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index c604c7c870adf734a29017e6accbd159317a9548..7b7d86364e2786f5989b9fa8c13de257fc1a80e5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -45,14 +45,13 @@ class TargetOptions; namespace xla { namespace llvm_ir { -// Convert a std::string (used by LLVM's interfaces) to string. -string AsString(const std::string& str); - // Convert a absl::string_view to a llvm::StringRef. Note: both // absl::string_view and llvm::StringRef are non-owning pointers into a // string in memory. This method is used to feed strings to LLVM // & Clang APIs that expect llvm::StringRef. 
-llvm::StringRef AsStringRef(absl::string_view str); +inline llvm::StringRef AsStringRef(absl::string_view str) { + return llvm::StringRef(str.data(), str.size()); +} template llvm::ArrayRef AsArrayRef(const std::vector& vec) { @@ -71,7 +70,7 @@ string DumpToString(const T& entity) { llvm::raw_string_ostream ostream(buffer_string); entity.print(ostream); ostream.flush(); - return AsString(buffer_string); + return buffer_string; } // Dump the given LLVM module to a string. This requires a function distinct @@ -264,12 +263,7 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout); // Gets an llvm::FastMathFlags that reflects the settings in the given // module config. -llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled); - -// Sets values in the given TargetOptions struct according to the given -// compilation options. -void SetTargetOptions(bool fast_math_enabled, - llvm::TargetOptions* target_options); +llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config); // Computes a conservative union of the metadata in "a" and "b". For // aliasing-related metadata, this means the result can be applied to @@ -279,19 +273,19 @@ std::map MergeMetadata( llvm::LLVMContext* context, const std::map& a, const std::map& b); -// Dumps out `llvm_module` to a file in the directory named `directory_name`, -// creating the directory if necessary. A sanitized version of -// `hlo_module_name` is incorporated into the file name. If `optimized` is true -// then a suffix of "-with-opt.ll" is used, else a suffix of "-no-opt.ll" is -// used. 
-Status DumpIRToDirectory(const string& directory_name, - const string& hlo_module_name, - const llvm::Module& llvm_module, bool optimized); - -llvm::Function* CreateFunction(llvm::FunctionType* function_type, - llvm::GlobalValue::LinkageTypes linkage, - bool enable_fast_math, bool optimize_for_size, - absl::string_view name, llvm::Module* module); +// Dumps out `llvm_module` to the path specified in DebugOptions, if dumping is +// enabled for the given HLO module. +// +// A sanitized version of `hlo_module_name` is incorporated into the file name. +// If `optimized` is true then a suffix of "-with-opt.ll" is used, else a suffix +// of "-no-opt.ll" is used. +void DumpIrIfEnabled(const HloModule& hlo_module, + const llvm::Module& llvm_module, bool optimized); + +llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + const HloModuleConfig& module_config, + absl::string_view name, llvm::Module* module); // Extracts the xla_backend_extra_options from `config` and passes those that // don't start with xla_ to LLVM. diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 0dc120e0b0df47f261435f490a8459b49d989b53..638ff2c6e0b38ec15b8ee4a06098d0442dc1d203 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -99,15 +98,16 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( // class so emit loops in order from most-major dimension down to most-minor // dimension (of the target shape). 
ForLoopNest loop_nest(loop_name, b_); - IrArray::Index array_index(index_type, shape_.dimensions_size()); + std::vector array_multi_index(shape_.dimensions_size()); for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_.layout(), i); std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), /*suffix=*/absl::StrFormat("dim.%d", dimension)); - array_index[dimension] = loop->GetIndVarValue(); + array_multi_index[dimension] = loop->GetIndVarValue(); } + IrArray::Index array_index(array_multi_index, shape_, index_type); // Set IR builder insertion point to the loop body basic block of the // innermost loop. diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 89b6a36f96beedbcb7322e6164ac59221650d3d8..3ac8eb78cb03b216fc487c18c0a8d91caea1bc2a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -45,13 +46,14 @@ namespace llvm_ir { namespace { // Adds the inner comparison loop body where we compare elements. 
-void EmitCompareLoopBody( - int64 iteration_bound, PrimitiveType key_type, int64 num_values, - int64 iota_values_parameter_index, llvm::Value* element_pair_index, +Status EmitCompareLoopBody( + int64 iteration_bound, int64 num_values, llvm::Value* element_pair_index, int64 xor_mask, llvm::Type* index_type, - std::function read_element, + std::function + element_address, std::function write_element, + const EmitCallToNestedComputationCallback& emit_compare_callback, llvm::IRBuilder<>* b, bool needs_bounds_checks = true) { auto index_typed_constant = [&](int64 value) { return llvm::ConstantInt::get(index_type, value); @@ -108,74 +110,44 @@ void EmitCompareLoopBody( // if (is_smaller_index && index_is_inbounds) KernelSupportLibrary ksl(b); - ksl.If("smaller_comparison_index", do_comparison, [&]() { - auto key1 = read_element(0, current_keys_index); - auto key2 = read_element(0, compare_keys_index); - auto compare_key1 = key1; - auto compare_key2 = key2; - bool is_signed_comparison = true; - if (primitive_util::IsFloatingPointType(key_type)) { - // We would like a total order of floating point numbers so that the - // sort has a predictable behavior in the presence of NaNs. Rather - // than using floating point comparison, we use the following trick: - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? 0x7FFFFFFF - x : x; - // then y is ordered as an int32 such that finite values have the - // obvious order, -0 is ordered before 0, and -NaN and NaN appear at - // the beginning and end of the ordering. 
- auto k = b->getInt(llvm::APInt::getSignedMaxValue( - key1->getType()->getPrimitiveSizeInBits())); - auto comparison_type = k->getType(); - auto zero = llvm::ConstantInt::get(comparison_type, 0); - auto maybe_flip = [&](llvm::Value* v) { - return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), - b->CreateSub(k, v), v); - }; - compare_key1 = b->CreateBitCast(key1, comparison_type); - compare_key2 = b->CreateBitCast(key2, comparison_type); - compare_key1 = maybe_flip(compare_key1); - compare_key2 = maybe_flip(compare_key2); - } else if (!primitive_util::IsSignedIntegralType(key_type)) { - is_signed_comparison = false; - } - // If key2 < key1 - auto is_smaller_than = - b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - compare_key2, compare_key1); - if (iota_values_parameter_index >= 0) { - auto keys_equal = b->CreateICmpEQ(compare_key1, compare_key2); - auto key_index1 = - read_element(iota_values_parameter_index, current_keys_index); - auto key_index2 = - read_element(iota_values_parameter_index, compare_keys_index); - auto index_is_smaller_than = - b->CreateICmp(llvm::ICmpInst::ICMP_ULT, key_index2, key_index1); - is_smaller_than = b->CreateOr( - is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); + return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() { + std::vector values_to_compare; + for (int i = 0; i < num_values; ++i) { + values_to_compare.push_back(element_address(i, compare_keys_index)); + values_to_compare.push_back(element_address(i, current_keys_index)); } + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(PRED, module), "compare_return_buffer", + b); + TF_RETURN_IF_ERROR( + emit_compare_callback(values_to_compare, compare_return_buffer)); + llvm::Value* result = b->CreateLoad(compare_return_buffer); + + // Check if the 'compare' 
function returns true. + llvm::Value* is_smaller_than = + b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0), + "boolean_predicate"); ksl.If("is_smaller_than", is_smaller_than, [&]() { - // Swap key1 with key2. - write_element(0, current_keys_index, key2); - write_element(0, compare_keys_index, key1); - for (int64 i = 1; i <= num_values; ++i) { - // Also swap the values. - auto value1 = read_element(i, current_keys_index); - auto value2 = read_element(i, compare_keys_index); - write_element(i, current_keys_index, value2); - write_element(i, compare_keys_index, value1); + for (int64 i = 0; i < num_values; ++i) { + // Swap the values. + auto value1 = b->CreateLoad(values_to_compare[i * 2]); + auto value2 = b->CreateLoad(values_to_compare[i * 2 + 1]); + write_element(i, current_keys_index, value1); + write_element(i, compare_keys_index, value2); } }); + return Status::OK(); }); } -void EmitTiledCompareLoop( +Status EmitTiledCompareLoop( const IrArray::Index& tiled_keys_index, int64 dimension_to_sort, - int64 dimension_to_sort_bound, PrimitiveType keys_type, - absl::Span xor_masks, const std::vector& params, - const std::vector& param_shmem_buffers, - int64 iota_values_parameter_index, int64 tile_size, llvm::IRBuilder<>* b) { + int64 dimension_to_sort_bound, absl::Span xor_masks, + const std::vector& params, + const std::vector& param_shmem_buffers, int64 tile_size, + const EmitCallToNestedComputationCallback& emit_compare_callback, + llvm::IRBuilder<>* b) { KernelSupportLibrary ksl(b); llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); @@ -200,7 +172,7 @@ void EmitTiledCompareLoop( [&]() { auto cache_index = b->CreateShl(thread_id, value_one); read_or_write(cache_index, current_keys_index); - // Increment to go the next index position. + // Increment to go to the next index position. 
current_keys_index = b->CreateAdd(current_keys_index, value_one); // Here we check whether the next index position is within bounds. ksl.If("inner_smaller_keys_index", @@ -215,10 +187,12 @@ void EmitTiledCompareLoop( }; // Copy operand tiles from the operand buffers to shared memory. - IrArray::Index keys_index = tiled_keys_index; + std::vector keys_multi_index = tiled_keys_index.multidim(); for (int64 i = 0; i < params.size(); ++i) { copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { - keys_index[dimension_to_sort] = index; + keys_multi_index[dimension_to_sort] = index; + IrArray::Index keys_index(keys_multi_index, params[i].GetShape(), + tiled_keys_index.GetType()); auto value = params[i].EmitReadArrayElement(keys_index, b); b->CreateStore(value, b->CreateGEP(param_shmem_buffers[i], @@ -230,10 +204,18 @@ void EmitTiledCompareLoop( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); // Now emit the bodies of the comparison loops. - auto read_element = [&](int64 operand, llvm::Value* index) { - return b->CreateLoad( + auto element_address = [&](int64 operand, llvm::Value* index) { + auto shared_memory_address = b->CreateGEP(param_shmem_buffers[operand], - {tiled_keys_index.GetConstantWithIndexType(0), index})); + {tiled_keys_index.GetConstantWithIndexType(0), index}); + auto ptr_type = shared_memory_address->getType(); + // We need a generic pointer with address space 0 instead of a pointer to + // shared memory (address space 3) so that we can pass it to the comparison + // computation. + return b->CreateAddrSpaceCast( + shared_memory_address, + llvm::PointerType::get(ptr_type->getPointerElementType(), + /*AddressSpace=*/0)); }; auto write_element = [&](int64 operand, llvm::Value* index, llvm::Value* value) { @@ -252,7 +234,7 @@ void EmitTiledCompareLoop( if (dimension_to_sort_bound % tile_size) { // Otherwise we need a bounds check for the last tile. The last tile has // size 'dimension_to_sort_bound' % 'tile_size'. 
- ksl.If( + TF_RETURN_IF_ERROR(ksl.IfWithStatus( "is_last_tile", b->CreateICmpUGE( b->CreateMul(tiled_keys_index[dimension_to_sort], @@ -260,24 +242,24 @@ void EmitTiledCompareLoop( tiled_keys_index.GetConstantWithIndexType( RoundDownToNearest(dimension_to_sort_bound, tile_size))), [&]() { - EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type, - params.size() - 1, iota_values_parameter_index, - element_pair_index, xor_mask, - tiled_keys_index.GetType(), read_element, - write_element, b); + return EmitCompareLoopBody( + dimension_to_sort_bound % tile_size, params.size(), + element_pair_index, xor_mask, tiled_keys_index.GetType(), + element_address, write_element, emit_compare_callback, b); }, [&]() { - EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, - iota_values_parameter_index, element_pair_index, - xor_mask, tiled_keys_index.GetType(), - read_element, write_element, b, - /*needs_bounds_checks=*/false); - }); + return EmitCompareLoopBody( + tile_size, params.size(), element_pair_index, xor_mask, + tiled_keys_index.GetType(), element_address, write_element, + emit_compare_callback, b, + /*needs_bounds_checks=*/false); + })); } else { - EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, - iota_values_parameter_index, element_pair_index, - xor_mask, tiled_keys_index.GetType(), read_element, - write_element, b, /*needs_bounds_checks=*/false); + TF_RETURN_IF_ERROR(EmitCompareLoopBody( + tile_size, params.size(), element_pair_index, xor_mask, + tiled_keys_index.GetType(), element_address, write_element, + emit_compare_callback, b, + /*needs_bounds_checks=*/false)); } // Wait until all comparisons have happened. llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); @@ -286,7 +268,9 @@ void EmitTiledCompareLoop( // Copy the operand tiles back from shared memory to the operand buffers. 
for (int64 i = 0; i < params.size(); ++i) { copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { - keys_index[dimension_to_sort] = index; + keys_multi_index[dimension_to_sort] = index; + IrArray::Index keys_index(keys_multi_index, params[i].GetShape(), + tiled_keys_index.GetType()); auto value = b->CreateLoad(b->CreateGEP( param_shmem_buffers[i], {tiled_keys_index.GetConstantWithIndexType(0), cache_index})); @@ -301,17 +285,16 @@ void EmitTiledCompareLoop( // same location in shared memory because we have exactly tile_size / 2 many // threads, and the linear index calculated by ParallelLoopEmitter uses // linear_index = blockIdx.x * blockDim.x + threadIdx.x; + return Status::OK(); } } // namespace -Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const std::vector& values_arrays, - int64 iota_values_parameter_index, - absl::string_view name, - absl::Span xor_masks, llvm::IRBuilder<>* b, - const gpu::LaunchDimensions& launch_dimensions, - int64 num_iterations_in_sort_dim, - const int64 tile_size) { +Status EmitSortInPlace( + int64 dimension_to_sort, const std::vector& values_arrays, + absl::string_view name, absl::Span xor_masks, + llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, const int64 tile_size, + const EmitCallToNestedComputationCallback& emit_compare_callback) { // Iterate through the keys shape in physical order, but skip the dimension to // sort and make it the innermost loop which is the loop where the comparisons // happen. In the dimension to sort, if we use tiling, we iterate through it @@ -321,7 +304,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, // within those 64 elements and are therefore independent of the other // comparisons). 
- const Shape& keys_shape = keys_array.GetShape(); + const Shape& keys_shape = values_arrays[0].GetShape(); int64 rank = keys_shape.rank(); int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); std::vector dimensions_in_iteration_order(rank); @@ -338,18 +321,16 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(), dimensions_in_iteration_order); - std::vector params(1, keys_array); - params.insert(params.end(), values_arrays.begin(), values_arrays.end()); // Allocate shared memory for the tiled compare loop. - std::vector param_shmem_buffers(params.size(), nullptr); + std::vector param_shmem_buffers(values_arrays.size(), nullptr); if (xor_masks.size() > 1) { llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); - for (int64 i = 0; i < params.size(); ++i) { - llvm::Type* tile_type = - llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( - params[i].GetShape().element_type(), module), - tile_size); + for (int64 i = 0; i < values_arrays.size(); ++i) { + llvm::Type* tile_type = llvm::ArrayType::get( + llvm_ir::PrimitiveTypeToIrType( + values_arrays[i].GetShape().element_type(), module), + tile_size); param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile( module, tile_type, absl::StrCat(name, "_tile_param_", i)); } @@ -371,30 +352,37 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, // // This follows the algorithm described on Wikipedia: // https://en.wikipedia.org/wiki/Bitonic_sorter - IrArray::Index keys_index(tiles_index.GetType(), rank); + std::vector keys_multi_index(rank); for (int64 i = 0; i < rank; ++i) { - keys_index[iteration_order_to_logical_order[i]] = tiles_index[i]; + keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i]; } if (xor_masks.size() > 1) { - EmitTiledCompareLoop(keys_index, dimension_to_sort, - dimension_to_sort_bound, keys_shape.element_type(), - xor_masks, 
params, param_shmem_buffers, - iota_values_parameter_index, tile_size, b); + IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(), + tiles_index.GetType()); + TF_RETURN_IF_ERROR(EmitTiledCompareLoop( + keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks, + values_arrays, param_shmem_buffers, tile_size, emit_compare_callback, + b)); } else { - auto read_element = [&](int64 operand, llvm::Value* index) { - keys_index[dimension_to_sort] = index; - return params[operand].EmitReadArrayElement(keys_index, b); + auto element_address = [&](int64 operand, llvm::Value* index) { + keys_multi_index[dimension_to_sort] = index; + IrArray::Index keys_index(keys_multi_index, + values_arrays[operand].GetShape(), + tiles_index.GetType()); + return values_arrays[operand].EmitArrayElementAddress(keys_index, b); }; auto write_element = [&](int64 operand, llvm::Value* index, llvm::Value* value) { - keys_index[dimension_to_sort] = index; - params[operand].EmitWriteArrayElement(keys_index, value, b); + keys_multi_index[dimension_to_sort] = index; + IrArray::Index keys_index(keys_multi_index, + values_arrays[operand].GetShape(), + tiles_index.GetType()); + values_arrays[operand].EmitWriteArrayElement(keys_index, value, b); }; - EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(), - values_arrays.size(), iota_values_parameter_index, - tiles_index[rank - 1], xor_masks[0], - tiles_index.GetType(), read_element, write_element, - b); + TF_RETURN_IF_ERROR(EmitCompareLoopBody( + dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1], + xor_masks[0], tiles_index.GetType(), element_address, write_element, + emit_compare_callback, b)); } return Status::OK(); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 685f9383acba416f51681270e4037d56abb4b6ea..b9341a34d1f2203db6e02c3df5d607174b6d0f74 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ 
b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -28,19 +28,18 @@ limitations under the License. namespace xla { namespace llvm_ir { +using EmitCallToNestedComputationCallback = + std::function, llvm::Value*)>; // Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' -// dimension of 'keys_array'. All other dimensions are kept as-is. This -// implements the inner loop of BitonicSort. It is assumed that 'xor_masks' -// contains only powers of 2, or values 2^k - 1 (k > 0). If -// 'iota_values_parameter_index' is >= 0, it points at a 'values_arrays' operand -// that is a iota and can be used to make the sorting stable. -Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const std::vector& values_arrays, - int64 iota_values_parameter_index, - absl::string_view name, - absl::Span xor_masks, llvm::IRBuilder<>* b, - const gpu::LaunchDimensions& launch_dimensions, - int64 num_iterations_in_sort_dim, int64 tile_size); +// dimension of each array in 'values_arrays'. All other dimensions are kept +// as-is. This implements the inner loop of BitonicSort. It is assumed that +// 'xor_masks' contains only powers of 2, or values 2^k - 1 (k > 0). +Status EmitSortInPlace( + int64 dimension_to_sort, const std::vector& values_arrays, + absl::string_view name, absl::Span xor_masks, + llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, int64 tile_size, + const EmitCallToNestedComputationCallback& emit_compare_callback); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index d8d2700e1934fd202d44a1dc60e71a99913d4537..3a4814b1857ec6e1fec3b79ef75d40db4fb99269 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -29,9 +29,14 @@ limitations under the License. 
namespace xla { namespace llvm_ir { +static llvm::Module* getModuleFromBuilder(llvm::IRBuilder<>* b) { + return b->GetInsertBlock()->getModule(); +} + void EmitTupleSelect(const IrArray& select, const IrArray& pred, llvm::Value* on_true, llvm::Value* on_false, - llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::IRBuilder<>* b) { + llvm::Module* module = getModuleFromBuilder(b); CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = @@ -65,7 +70,8 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, } void EmitTuple(const IrArray& tuple, absl::Span operands, - llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::IRBuilder<>* b) { + llvm::Module* module = getModuleFromBuilder(b); for (size_t i = 0; i < operands.size(); ++i) { auto* store = b->CreateStore( b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)), @@ -76,18 +82,19 @@ void EmitTuple(const IrArray& tuple, absl::Span operands, } void EmitTuple(const IrArray& tuple, absl::Span buffers, - llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::IRBuilder<>* b) { std::vector buffer_ptrs; buffer_ptrs.reserve(buffers.size()); absl::c_transform( buffers, std::back_inserter(buffer_ptrs), [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); }); - llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module); + llvm_ir::EmitTuple(tuple, buffer_ptrs, b); } llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, - llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::IRBuilder<>* b) { + llvm::Module* module = getModuleFromBuilder(b); llvm::Value* element_ptr = b->CreateInBoundsGEP(operand, {b->getInt64(0), b->getInt64(index)}); llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr); diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index 94340b91d8eeea1ba4681c2e49c0894eab2f6cc0..67d6323aba2f0bc10e8099014a214fc3025893ac 100644 --- 
a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -61,17 +61,17 @@ namespace llvm_ir { // output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] void EmitTupleSelect(const IrArray& select, const IrArray& pred, llvm::Value* on_true, llvm::Value* on_false, - llvm::IRBuilder<>* b, llvm::Module* module); + llvm::IRBuilder<>* b); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. void EmitTuple(const IrArray& tuple, absl::Span operands, - llvm::IRBuilder<>* b, llvm::Module* module); + llvm::IRBuilder<>* b); // Similar to EmitTuple above, except that the output buffers are provided in // the form of IrArray. void EmitTuple(const IrArray& tuple, absl::Span buffers, - llvm::IRBuilder<>* b, llvm::Module* module); + llvm::IRBuilder<>* b); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. A GetTupleElement instruction @@ -79,7 +79,7 @@ void EmitTuple(const IrArray& tuple, absl::Span buffers, // Returns an llvm value representing a pointer to the tuple element buffer. 
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, - llvm::IRBuilder<>* b, llvm::Module* module); + llvm::IRBuilder<>* b); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 600b069ecdbabf6b05e6abb3a6b8d9b1a4b0ecf4..3470fe5b2c34bf832207ed546fad176319446f31 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -110,6 +110,7 @@ ExecutionOptions CreateExecutionOptions( *execution_options.mutable_shape_with_output_layout() = result_shape.ToProto(); } + execution_options.set_num_replicas(build_options.num_replicas()); return execution_options; } diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index f6feed29935a1446499559d947dff0a8eefe5d2e..e55b83d17e90bc2ca0053a0421cf80ef6edd5bca 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -28,7 +29,7 @@ namespace { bool IsAllowed(char character) { auto c = static_cast(character); - return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; + return (absl::ascii_isalnum(c) != 0) || c == '_' || c == '.' 
|| c == '-'; } } // namespace @@ -46,7 +47,7 @@ NameUniquer::NameUniquer(const string& separator) { string result = name; char c = static_cast(result[0]); - if (!isalpha(c) && c != '_') { + if (!absl::ascii_isalpha(c) && c != '_') { result[0] = '_'; } for (int i = 1; i < result.length(); i++) { diff --git a/tensorflow/compiler/xla/service/op_expander_pass.cc b/tensorflow/compiler/xla/service/op_expander_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..02c9d4b387b112be39c204d35fe4fa1013ed064c --- /dev/null +++ b/tensorflow/compiler/xla/service/op_expander_pass.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +StatusOr OpExpanderPass::Run(HloModule* module) { + std::vector matching_instructions; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + absl::c_copy_if( + computation->instructions(), std::back_inserter(matching_instructions), + [&](HloInstruction* inst) { return InstructionMatchesPattern(inst); }); + } + + for (HloInstruction* inst : matching_instructions) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, + ExpandInstruction(inst)); + if (expanded_root == nullptr) { + continue; + } + TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); + } + + return !matching_instructions.empty(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/op_expander_pass.h b/tensorflow/compiler/xla/service/op_expander_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..276e3d70b8ecd8742e0b277698765063198fe872 --- /dev/null +++ b/tensorflow/compiler/xla/service/op_expander_pass.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OP_EXPANDER_PASS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OP_EXPANDER_PASS_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass is an abstract superclass for passes that replace operations that +// match a pattern. It is intended to be subclassed, not used directly. +// +// This pass is useful for legalizing HLO instructions that a particular backend +// does not support into other HLO instructions. +class OpExpanderPass : public HloModulePass { + public: + StatusOr Run(HloModule* module) override; + + protected: + // Returns `true` if `instruction` should be expanded by this pass. + virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0; + + // Returns a replacement for `instruction`, or nullptr if no replacement is + // neeeded (e.g. only the to_apply subcomputation of the instruction was + // modified). + virtual StatusOr ExpandInstruction( + HloInstruction* instruction) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OP_EXPANDER_PASS_H_ diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc new file mode 100644 index 0000000000000000000000000000000000000000..701c629add52a217f16877a085b9ef2d096623d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc @@ -0,0 +1,106 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// Returns true if the given shape is a non-nested tuple. +bool IsNonNestedTuple(const Shape& shape) { + return shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape); +} + +} // namespace + +StatusOr OptimizeInputOutputBufferAlias::Build( + const Shape& input_shape, const Shape& output_shape, + HloInputOutputAliasConfig* alias_config) { + bool changed = false; + TF_RET_CHECK(LayoutUtil::HasLayout(input_shape)); + TF_RET_CHECK(LayoutUtil::HasLayout(output_shape)); + VLOG(1) << "input_shape:" << input_shape.ToString(); + VLOG(1) << "output_shape:" << output_shape.ToString(); + + // For all buffers defined by the parameter, build a map from the byte + // size to the list of the buffers of that size. 
+ absl::flat_hash_map> size_to_input_index; + ShapeUtil::ForEachSubshape( + input_shape, [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return; + } + int64 bytes = size_func_(subshape); + size_to_input_index[bytes].push(index); + }); + + // For each result buffer shape index, take the first unused parameter + // buffer that matches the size. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + output_shape, [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return Status::OK(); + } + int64 bytes = size_func_(subshape); + + auto it = size_to_input_index.find(bytes); + if (it != size_to_input_index.end() && !it->second.empty()) { + changed = true; + const ShapeIndex& input_index = it->second.front(); + const ShapeIndex& output_index = index; + if (!alias_config->ParameterHasAlias(0, input_index) && + !alias_config->OutputHasAlias(output_index)) { + TF_RETURN_IF_ERROR(alias_config->SetUpAlias( + output_index, 0, input_index, + HloInputOutputAliasConfig::AliasKind::kSystemAlias)); + } + VLOG(3) << "Set up alias from with param index " + << it->second.front().ToString() << ", shape size " << bytes + << " and result subshape " + << ShapeUtil::HumanStringWithLayout(subshape) << " at index " + << index.ToString(); + it->second.pop(); + } + return Status::OK(); + })); + return changed; +} + +StatusOr OptimizeInputOutputBufferAlias::Run(HloModule* module) { + // User buffer alias only work for modules with 1 parameter. 
+ if (module->entry_computation()->num_parameters() != 1) { + return false; + } + + HloInputOutputAliasConfig* alias_config = + &module->input_output_alias_config(); + + return Build(module->entry_computation()->parameter_instruction(0)->shape(), + module->entry_computation()->root_instruction()->shape(), + alias_config); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h new file mode 100644 index 0000000000000000000000000000000000000000..79ce468e975300ed703ae0fd780f4b9d5328a4b3 --- /dev/null +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// This pass opportunistically finds input and output buffers that can be +// aliased, and writes the alias config into the HloModule. +// +// The input and the output buffers can be in any shape, and each output buffer +// can alias with an input buffer with the same size. Each input buffer may only +// alias with a single output buffer. For example, for the following parameter +// and the output buffers, +// +// Parameters : { P1(2MiB), P2(4MiB), P3(8MiB), P4(4MiB), P5(4MiB), ... } +// Outputs : { O1(4MiB), O2(2MiB), O3(4MiB), O4(6MiB), O5(4MiB), ... } +// +// one potential aliasing would be (O1, P2), (O2, P1), (O3, P4), (O5, P5), .. 
+class OptimizeInputOutputBufferAlias : public HloModulePass { + using ShapeSizeFunction = std::function; + + public: + OptimizeInputOutputBufferAlias(ShapeSizeFunction size_func) + : size_func_(size_func) {} + ~OptimizeInputOutputBufferAlias() override = default; + + absl::string_view name() const override { + return "optimize_input_output_buffer_alias.h"; + } + + StatusOr Run(HloModule* module) override; + + private: + friend class OptimizeInputOutputBufferAliasTest; + + StatusOr Build(const Shape& input_shape, const Shape& output_shape, + HloInputOutputAliasConfig* alias_config); + ShapeSizeFunction size_func_ = nullptr; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..41e90f9b6931619fd9824e2eda25e12e4c7197b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h" + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +// Tests that UserBufferAlias properly maps input and output buffer indices of +// various shapes for aliasing. +class OptimizeInputOutputBufferAliasTest : public HloTestBase { + protected: + OptimizeInputOutputBufferAliasTest() { + r1f32_ = ShapeUtil::MakeShape(F32, {4}); + r2f32_ = ShapeUtil::MakeShape(F32, {4, 5}); + r3f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6}); + r4f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + + auto size_func = [](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape); + }; + + optimize_pass_ = + absl::make_unique(size_func); + } + + // Returns the number of output indices that aliases with the input. + int64 AliasCount() { + int64 count = 0; + + config_.ForEachAlias( + [&](const ShapeIndex&, const HloInputOutputAliasConfig::Alias&) { + count++; + }); + return count; + } + + bool BuildAliasConfig(const Shape& input_shape, const Shape& output_shape) { + config_ = HloInputOutputAliasConfig(output_shape); + auto changed = optimize_pass_->Build(input_shape, output_shape, &config_); + TF_CHECK_OK(changed.status()); + + return changed.ValueOrDie(); + } + + std::unique_ptr optimize_pass_; + + HloInputOutputAliasConfig config_; + + Shape r1f32_; + Shape r2f32_; + Shape r3f32_; + Shape r4f32_; +}; + +// All shapes are different, so no aliasing is available. 
+TEST_F(OptimizeInputOutputBufferAliasTest, AllDifferentBufferSizes) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_}); + Shape output = ShapeUtil::MakeTupleShape({r3f32_, r4f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_FALSE(changed); + EXPECT_EQ(AliasCount(), 0); +} + +// Input and output shapes are equal, so buffers can alias at the same index. +TEST_F(OptimizeInputOutputBufferAliasTest, OrderedNonNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + Shape output = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + EXPECT_EQ(AliasCount(), 4); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{0}); + EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex{1}); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{2}); + EXPECT_EQ(config_.GetAliasedOutput(0, {3}), ShapeIndex{3}); +} + +// Only a subset of the tuple element shapes match between the input and the +// output. +TEST_F(OptimizeInputOutputBufferAliasTest, PartialReuseNonNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r1f32_, r2f32_, r2f32_}); + Shape output = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + + EXPECT_EQ(AliasCount(), 2); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{0}); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{1}); +} + +// The output shape is reverse of the input shape, but we can still reuse all +// the buffers. 
+TEST_F(OptimizeInputOutputBufferAliasTest, UnorderedNonNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + Shape output = ShapeUtil::MakeTupleShape({r4f32_, r3f32_, r2f32_, r1f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + + EXPECT_EQ(AliasCount(), 4); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{3}); + EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex{2}); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{1}); + EXPECT_EQ(config_.GetAliasedOutput(0, {3}), ShapeIndex{0}); +} + +TEST_F(OptimizeInputOutputBufferAliasTest, UnorderedNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({r1f32_}), r2f32_, r3f32_, r4f32_}); + Shape output = ShapeUtil::MakeTupleShape( + {r1f32_, ShapeUtil::MakeTupleShape({r3f32_, r2f32_}), r2f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + + EXPECT_EQ(AliasCount(), 3); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0, 0}), ShapeIndex{0}); + EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex({1, 1})); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex({1, 0})); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 9e3d1060210790f60243195a1c1dff13f1fc7fc5..ae1df60d350babda12f0dc37aea41e01b8a51561 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -67,6 +67,7 @@ namespace xla { // - WithOneUse: Instruction is used as an operand exactly once. // - WithOneUser: Instruction is used by exactly one other instruction, but // is possibly used more than once as an operand (e.g. multiply(x,x)). 
+// - WithComparisonDirection: instr has the given direction // // Shape(): // - EqualTo @@ -1671,6 +1672,40 @@ class HloInstructionPatternOneUserImpl } }; +class HloInstructionPatternComparisonDirectionImpl { + public: + explicit constexpr HloInstructionPatternComparisonDirectionImpl( + ComparisonDirection direction) + : direction_(direction) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has comparison direction " + << ComparisonDirectionToString(direction_); + } + + private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kCompare || + inst->comparison_direction() != direction_) { + EXPLAIN << "HloInstruction is not comparison " + << ComparisonDirectionToString(direction_); + return false; + } + return true; + } + + ComparisonDirection direction_; +}; + // Matches a constant scalar or effective scalar, optionally with a given value. template class HloConstantScalarImpl { @@ -1956,6 +1991,14 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternOneUserImpl()); } + // Modifies the pattern to match only if the instruction has the given + // comparison direction. 
+ auto WithComparisonDirection(ComparisonDirection direction) const + -> decltype(this->AppendImpl( + HloInstructionPatternComparisonDirectionImpl(direction))) { + return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction)); + } + void DescribeTo(std::ostream* os, int64 indent = 0) const { impl_.DescribeTo(os, indent); } @@ -2053,10 +2096,12 @@ XLA_UNOP_PATTERN(RecvDone) XLA_UNOP_PATTERN(ReducePrecision) XLA_UNOP_PATTERN(Reshape) XLA_UNOP_PATTERN(Reverse) +XLA_UNOP_PATTERN(Rsqrt) XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) XLA_UNOP_PATTERN(Slice) +XLA_UNOP_PATTERN(Sqrt) XLA_UNOP_PATTERN(Tanh) XLA_UNOP_PATTERN(Transpose) #undef XLA_UNOP_PATTERN @@ -2116,18 +2161,13 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) +XLA_BINOP_PATTERN(Compare) XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) -XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) -XLA_BINOP_PATTERN(Ge) -XLA_BINOP_PATTERN(Gt) -XLA_BINOP_PATTERN(Le) -XLA_BINOP_PATTERN(Lt) XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) -XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Pad) XLA_BINOP_PATTERN(Power) @@ -2240,6 +2280,73 @@ XLA_VARIADIC_OP_PATTERN(Reduce); XLA_VARIADIC_OP_PATTERN(Sort); XLA_VARIADIC_OP_PATTERN(Tuple); +// Helpers for comparison instructions. 
+#define XLA_COMPARE_PATTERN(NAME) \ + inline auto NAME()->decltype( \ + Op().WithOpcode(HloOpcode::kCompare) \ + .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op().WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } + +#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \ + XLA_COMPARE_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs))) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ + } + 
+XLA_COMMUTATIVE_COMPARE_PATTERN(Eq); +XLA_COMMUTATIVE_COMPARE_PATTERN(Ne); +XLA_COMPARE_PATTERN(Ge); +XLA_COMPARE_PATTERN(Gt); +XLA_COMPARE_PATTERN(Le); +XLA_COMPARE_PATTERN(Lt); + // Helpers for matching non-constant instructions. inline auto NonConstant() -> decltype(Op().IsNonConstant()) { return Op().IsNonConstant(); diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc index 9ca2fb05c1f7ef093c58237cf21fbc7c813a592a..f51a18b13894d75300c46835fabd82a4ce0699af 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc @@ -23,7 +23,6 @@ namespace xla { namespace { namespace m = ::xla::match; -using ::testing::Eq; using ::testing::Not; template diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 5c3c009a68bffbda8642fceedfb724879fbf1530..cbe8c4a2410d3f569933fe86ae0a8056b6b5ed85 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -931,5 +931,48 @@ TEST(PatternMatcherTest, OneUseAndOneUser) { "in p0 = f32[] parameter(0)"); } +TEST(HloMatchersTest, Comparison) { + auto shape = ShapeUtil::MakeShape(F32, {1}); + auto p0 = HloInstruction::CreateParameter(0, shape, "param.0"); + auto p1 = HloInstruction::CreateParameter(1, shape, "param.1"); + auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kEq); + auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kNe); + auto add = + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get()); + auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(), + ComparisonDirection::kLe); + + EXPECT_TRUE(Match(eq.get(), m::Compare())); + EXPECT_TRUE(Match(eq.get(), m::Eq())); + EXPECT_TRUE(Match(eq.get(), 
m::Eq(m::Parameter(0), m::Parameter(1)))); + EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0)))); + EXPECT_TRUE(Match(ne.get(), m::Compare())); + EXPECT_TRUE(Match(ne.get(), m::Ne())); + EXPECT_TRUE(Match( + le.get(), + m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1))))); + EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0), + m::Add(m::Parameter(0), m::Parameter(1))))); + + EXPECT_FALSE(Match(eq.get(), m::Add())); + EXPECT_FALSE(Match(eq.get(), m::Ne())); + EXPECT_FALSE( + Match(le.get(), + m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1))))); + EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0)))); + EXPECT_DESC_AND_EXPLANATION( + eq, m::Ne().WithOneUser(), + "an HloInstruction:\n" + " * with opcode compare AND\n" + " * which has comparison direction NE AND\n" + " * which has exactly one user (but possibly is used " + "multiple times by that instruction)", + "HloInstruction is not comparison NE\n" + "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), " + "direction=EQ"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index a62118df157edf67114ff41befbdce3da129fe93..9e2d740694012b05510fc098048c762b3057da0d 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -29,11 +29,6 @@ limitations under the License. // // Where the instruction must be elementwise, and both reshapes and transposes // are moved. -// -// Most elementwise instructions support implicit broadcast of scalar operands, -// but select is a special-case. The signature is Select(Pred, A, B), and the -// only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or -// transposes to a scalar should be cheap, we simply never move them. 
#include "tensorflow/compiler/xla/service/reshape_mover.h" @@ -64,20 +59,14 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) { // // But it's not that simple. E.g. reshape(reshape(rng)) is only trivially // reshapable if *all* instructions in the chain have user_count == 1. And - // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar; we - // rely on implicit scalar broadcast for scalars to be trivial. In addition, - // these cases make it harder to maintain correctness of the UpdateOperand - // logic below. + // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar. + // In addition, these cases make it harder to maintain correctness of the + // UpdateOperand logic below. // // So don't handle these chains, unless you update the tests and code to deal // with these properly. One idea is to add a pass immediately beforehand that // collapses trivial runs of reshapes / transposes. - // Scalars can operate with any shape. - if (ShapeUtil::IsScalar(instruction->shape())) { - return true; - } - // A constant can trivially reshape the literal it holds. if (instruction->opcode() == HloOpcode::kConstant) { return true; @@ -143,8 +132,8 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { // This function is called once we've decided to sink reshape/transpose operands // across an instruction. It returns an updated `operand` with a shape that -// plays nicely with `new_operand_shape`; either it has the same shape (of the -// correct type), or it is a scalar that may be implicitly broadcast. +// plays nicely with `new_operand_shape`; it has the same shape (of the +// correct type). 
HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, const Shape& new_operand_shape, HloInstruction* operand) { @@ -221,9 +210,8 @@ StatusOr PerformSinkReshapeOrTranspose( UpdateOperand(first_reshape_operand, new_operand_shape, operands[i]); } if (HloOpcode::kFusion == instruction->opcode()) { - // Here we already know `instruction` is elementwise, and no operand is - // implicit broadcast as if it were the operands would not have easy shape - // changes, so all the fused instructions have the same dimensions. + // Here we already know `instruction` is elementwise, and all the fused + // instructions have the same dimensions. for (const auto& fused_instruction : instruction->fused_instructions()) { Shape* shape = fused_instruction->mutable_shape(); shape->clear_dimensions(); @@ -287,21 +275,17 @@ bool IsReshapeMoveCandidate(HloInstruction* instruction) { } // Check whether all operands: - // 0. Have the same dimensions as the output -- if not, they may be - // implicitly broadcast, which can confound the movement's - // correctness. + // 0. Have the same dimensions as the output. // // And one of the following: // 1. Are reshapes or transposes that have the same input and // output shapes as all other reshaped or transposed operands. // or - // 2. Are one of kConstant, kRng, broadcast of a scalar value, and scalars - // that can change shape trivially. + // 2. Are one of kConstant, kRng, broadcast of a scalar value. 
const HloInstruction* first_reshape_operand = nullptr; for (const HloInstruction* operand : instruction->operands()) { if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " + VLOG(5) << "Operand shape differs from output shape; so preventing " "movement\n\toperand: " << operand->ToString(print_no_metadata) << "\n\tinstruction: " << instruction->ToString(print_no_metadata); diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 036c3c36f648daf8963a6b25e300b93c1bdf78d9..e3a3feb86404634f0114f4cb8aa9b1c883e78f95 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -134,6 +134,13 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( int64 operand_rank) { HloComputation* computation = index_vector->parent(); const Shape& index_shape = index_vector->shape(); + + // Scatter of a scalar. Return a zero-sized vector of indices. + if (operand_rank == 0) { + return computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); + } + HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); @@ -171,12 +178,12 @@ static StatusOr CheckIndexValidity( // Valid range for the index: [0, operand_dims - window_sizes] // Check if the index has any negative values. 
- TF_ASSIGN_OR_RETURN( - HloInstruction * zero_index, + HloInstruction* zero_index = BroadcastZeros(computation, index->shape().element_type(), - AsInt64Slice(index->shape().dimensions()))); - TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, - MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); + AsInt64Slice(index->shape().dimensions())); + TF_ASSIGN_OR_RETURN( + HloInstruction * negative_index_check, + MakeCompareHlo(ComparisonDirection::kLe, zero_index, index)); // Check if the index is OOB w.r.t. the operand dimensions and window sizes. std::vector max_valid_index(operand_dims.size()); @@ -187,9 +194,9 @@ static StatusOr CheckIndexValidity( HloInstruction * max_valid_index_constant, MakeR1ConstantHlo(computation, index->shape().element_type(), max_valid_index)); - TF_ASSIGN_OR_RETURN( - HloInstruction * oob_index_check, - MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index)); + TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check, + MakeCompareHlo(ComparisonDirection::kGe, + max_valid_index_constant, index)); // Combine the results of the two checks above. TF_ASSIGN_OR_RETURN( @@ -222,10 +229,9 @@ static StatusOr> ScatterLoopBody( bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; // Build a vector form of the induction variable of the while loop. - TF_ASSIGN_OR_RETURN( - HloInstruction * induction_var_as_vector, + HloInstruction* induction_var_as_vector = MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/{1})); + /*result_shape_bounds=*/{1}); // Pick the index to scatter from scatter_indices based on the induction_var // and transform that to an index into the `operand` space. 
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 32707de700b5926105e15922f67c19e0ed7bd7b8..49c346d87fcbafb1e6780b0207f6250077c1e297 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -28,7 +28,9 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -63,6 +65,10 @@ namespace { using absl::StrCat; using absl::StrFormat; +// Argument used when calling DumpHloModuleIfEnabled before optimizations are +// performed on an HloModule. +constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations"; + // Records the arguments used to invoke a computation in an HloSnapshot proto. 
Status RecordArguments(const absl::Span arguments, se::Stream* stream, TransferManager* transfer_manager, @@ -296,11 +302,16 @@ StatusOr> Service::CreateModuleConfig( computation_layout->mutable_result_layout()->SetToDefaultLayout(); } - config->set_replica_count(options_.number_of_replicas()); if (execution_options != nullptr) { + if (execution_options->num_replicas() > 0) { + config->set_replica_count(execution_options->num_replicas()); + } else { + config->set_replica_count(options_.number_of_replicas()); + } config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); } else { + config->set_replica_count(options_.number_of_replicas()); config->set_debug_options(GetDebugOptionsFromFlags()); } @@ -309,6 +320,15 @@ StatusOr> Service::CreateModuleConfig( config->set_intra_op_parallelism_threads( execute_backend_->eigen_intra_op_thread_pool()->NumThreads()); } + + if (execution_options != nullptr && + execution_options->has_device_assignment()) { + TF_ASSIGN_OR_RETURN( + auto device_assignment, + DeviceAssignment::Deserialize(execution_options->device_assignment())); + config->set_static_device_assignment(*device_assignment); + } + return std::move(config); } @@ -333,21 +353,8 @@ StatusOr>> Service::BuildExecutables( // Dump computation proto state if flag is set. 
std::vector> hlo_snapshots; for (int64 i = 0; i < module_protos.size(); ++i) { - const string& directory_path = - module_configs[i]->debug_options().xla_dump_computations_to(); - const string& execution_directory_path = - module_configs[i]->debug_options().xla_dump_executions_to(); - if (directory_path.empty() && execution_directory_path.empty()) { - continue; - } auto hlo_snapshot = absl::make_unique(); *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; - if (!directory_path.empty()) { - string filename = StrFormat("computation_%d__%s", module_protos[i]->id(), - module_protos[i]->entry_computation_name()); - TF_RETURN_IF_ERROR( - Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); - } hlo_snapshots.push_back(std::move(hlo_snapshot)); } @@ -363,6 +370,7 @@ StatusOr>> Service::BuildExecutables( const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); + DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); module_group->push_back(std::move(module)); } @@ -372,7 +380,9 @@ StatusOr>> Service::BuildExecutables( std::move(executors), device_allocator)); for (size_t i = 0; i < module_protos.size(); ++i) { - if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { + const auto& debug_opts = module_configs[i]->debug_options(); + if (DumpingEnabledForHloModule(module_protos[i]->name(), debug_opts) && + debug_opts.xla_dump_hlo_snapshots()) { executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i])); } } @@ -470,24 +480,6 @@ Service::ExecuteParallelAndRegisterResult( } } - // For every stream that had profiling enabled, obtain and debug-dump the HLO - // profile. 
- for (auto& index_to_profiled_stream : index_to_profiled_streams) { - int64 device = index_to_profiled_stream.first; - se::Stream* stream = index_to_profiled_stream.second; - Executable* executable = executables[device]; - const HloModule& module = executable->module(); - HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(), - &executable->hlo_profile_index_map()); - TF_RETURN_IF_ERROR( - executable->PopulateExecutionProfile(&hlo_profile, stream)); - XLA_LOG_LINES( - tensorflow::INFO, - hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); - hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute", - &hlo_profile); - } - if (profile != nullptr) { CHECK(!timers.empty()); std::vector timer_nanoseconds; @@ -746,16 +738,17 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, } for (int i = 0; i < executable_ptrs.size(); i++) { - if (executable_ptrs[i]->dumping_snapshot()) { + Executable* executable = executable_ptrs[i]; + if (executable->dumping_snapshot()) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, allocation_tracker_.ResolveForReplica(outputs[i], 0)); TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(all_executors[i][0])); TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), - executable_ptrs[i]->hlo_snapshot())); - // Dump out the ith snapshot. - TF_RETURN_IF_ERROR(executable_ptrs[i]->DumpHloSnapshot()); + executable->hlo_snapshot())); + DumpHloSnapshotIfEnabled(executable->module(), + *executable->hlo_snapshot()); } } @@ -795,26 +788,9 @@ StatusOr> Service::BuildExecutable( "BuildExecutable on service %p with serialized module proto: %s", this, module_proto.name()); - // Dump computation proto state if flag is set. 
- auto hlo_snapshot = absl::make_unique(); - const string& directory_path = - module_config->debug_options().xla_dump_computations_to(); - const string& execution_directory_path = - module_config->debug_options().xla_dump_executions_to(); - if (!directory_path.empty() || !execution_directory_path.empty()) { - *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; - if (!directory_path.empty()) { - string filename = StrFormat("computation_%d__%s", module_proto.id(), - module_proto.entry_computation_name()); - TF_RETURN_IF_ERROR( - Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); - } - } - TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(module_proto, *module_config)); - - TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); + DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, @@ -824,7 +800,11 @@ StatusOr> Service::BuildExecutable( backend->compiler()->RunBackend( std::move(module), executor, device_allocator)); - if (!execution_directory_path.empty()) { + const auto& debug_opts = module_config->debug_options(); + if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) && + debug_opts.xla_dump_hlo_snapshots()) { + auto hlo_snapshot = absl::make_unique(); + *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; executable->set_hlo_snapshot(std::move(hlo_snapshot)); } @@ -934,7 +914,7 @@ Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), executable->hlo_snapshot())); - TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); + DumpHloSnapshotIfEnabled(executable->module(), *executable->hlo_snapshot()); } VLOG(1) << "successfully completed 'execute' request"; @@ -1156,9 +1136,7 @@ Status Service::GetComputationGraphStats( config.set_debug_options(arg->debug_options()); 
TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(arg->computation(), config)); - - hlo_graph_dumper::MaybeDumpHloModule(*module, - "computation statistics subject"); + DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); // Run HLO analysis to get the computation statistics. HloCostAnalysis analysis( @@ -1197,16 +1175,4 @@ StatusOr> Service::Replicas( return replicas; } -Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const { - const string xla_dump_unoptimized_hlo_proto_to = - module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); - if (xla_dump_unoptimized_hlo_proto_to.empty()) { - return Status::OK(); - } - HloProto proto = MakeHloProto(module); - return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_unoptimized_hlo_proto_to, - StrCat(module.name(), ".unoptimized")); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 6e31bea7283d5808adf90b06fe7ef927c2f7fbdc..f127e340b5950ae77bcfa22b638c1d9fc8a2024b 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -53,7 +53,7 @@ class ServiceOptions { ServiceOptions& set_platform(se::Platform* platform); se::Platform* platform() const; - // Set the number of replicas to use when compiling replicated + // Set the default number of replicas to use when compiling replicated // programs. ServiceOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; @@ -275,10 +275,6 @@ class Service : public ServiceInterface { StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; - // Dumps the (unoptimized) module given if the corresponding DebugOptions - // field has been set. - Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const; - // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. 
DeviceHandle SingleComputationDeviceHandle() const; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index fad7afefafd86d56f0d60b7ecfa1742d53d9452b..e1536684c066a1e29df6ad956ec1bebcfbccba72 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -167,6 +167,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } std::vector output_dimensions(window.dimensions_size()); + std::vector output_is_dynamic(window.dimensions_size()); for (int64 i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { @@ -196,6 +197,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, window.DebugString()); } + if (base_shape.is_dynamic_dimension(i) && + !window_util::IsTrivialWindowDimension(dim)) { + return Unimplemented( + "Dynamic shape is not supported for non trivial window: %s", + window_util::ToString(window)); + } + const int64 dilated_base = window_util::DilatedBound( ShapeUtil::GetDimension(base_shape, i), dim.base_dilation()); const int64 padded_dilated_base = @@ -205,9 +213,11 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, output_dimensions[i] = window_util::StridedBound( padded_dilated_base, dilated_window, dim.stride()); + output_is_dynamic[i] = base_shape.is_dynamic_dimension(i); } - return ShapeUtil::MakeValidatedShape(element_type, output_dimensions); + return ShapeUtil::MakeValidatedShape(element_type, output_dimensions, + output_is_dynamic); } } // namespace @@ -245,6 +255,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kExpm1: case HloOpcode::kLog: case HloOpcode::kLog1p: + case HloOpcode::kRsqrt: + case HloOpcode::kSqrt: case HloOpcode::kTanh: if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { @@ -500,17 +512,33 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, 
padding_config.ShortDebugString()); } + if (!padding_value_shape.is_static()) { + return InvalidArgument("Dynamic padding value is not supported"); + } + std::vector dimensions(operand_shape.rank()); + std::vector is_dynamic(operand_shape.rank()); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { const auto& p = padding_config.dimensions(i); + if (operand_shape.is_dynamic_dimension(i) && p.edge_padding_high() != 0 && + p.edge_padding_low() != 0 && p.interior_padding() != 0) { + return InvalidArgument( + "Dynamic dimension on padding dimension is not supported."); + } dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + p.edge_padding_high() + std::max(operand_shape.dimensions(i) - 1, 0LL) * p.interior_padding(); + if (dimensions[i] < 0) { + return InvalidArgument("Padding result in negative size for dimension %d", + i); + } + is_dynamic[i] = operand_shape.is_dynamic_dimension(i); } + return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), - dimensions); + dimensions, is_dynamic); } // Current DotDimensionNumbers Requirements: @@ -620,7 +648,9 @@ Status ValidateDotDimensionNumbers( const int64 rhs_contracting_dimension = dimension_numbers.rhs_contracting_dimensions(i); if (lhs.dimensions(lhs_contracting_dimension) != - rhs.dimensions(rhs_contracting_dimension)) { + rhs.dimensions(rhs_contracting_dimension) || + lhs.is_dynamic_dimension(lhs_contracting_dimension) != + rhs.is_dynamic_dimension(rhs_contracting_dimension)) { return fail("Contracting dimension sizes do not match."); } } @@ -634,7 +664,10 @@ Status ValidateDotDimensionNumbers( // Check that batch dimension numbers and sizes match. 
for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != - rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) || + lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.is_dynamic_dimension( + dimension_numbers.rhs_batch_dimensions(i))) { return fail("Batch dimension sizes must match for lhs/rhs."); } } @@ -645,14 +678,17 @@ Status ValidateDotDimensionNumbers( // Generate the result dimensions in order, rhs dimensions followed by lhs // dimensions except the contracted and batch dimensions. std::vector dimensions; + std::vector is_dynamic; for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) { dimensions.push_back(lhs.dimensions(lhs_dim)); + is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim)); } for (int64 i = 0; i < lhs.rank(); i++) { if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(), i) && !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) { dimensions.push_back(lhs.dimensions(i)); + is_dynamic.push_back(lhs.is_dynamic_dimension(i)); } } for (int64 i = 0; i < rhs.rank(); i++) { @@ -660,10 +696,11 @@ Status ValidateDotDimensionNumbers( i) && !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) { dimensions.push_back(rhs.dimensions(i)); + is_dynamic.push_back(rhs.is_dynamic_dimension(i)); } } Shape result = ShapeUtil::MakeShape( - ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions); + ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -681,13 +718,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // dimension). In that case, the output shape has the non-1 dimension size // from the lhs/rhs pair in every index. 
std::vector output_dimensions(lhs.rank()); + std::vector output_dimensions_is_dynamic(lhs.rank()); for (int64 i = 0; i < lhs.rank(); ++i) { if (lhs.dimensions(i) == rhs.dimensions(i)) { output_dimensions[i] = lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else if (lhs.dimensions(i) == 1) { output_dimensions[i] = rhs.dimensions(i); + output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i); } else if (rhs.dimensions(i) == 1) { output_dimensions[i] = lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", @@ -696,7 +737,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - output_dimensions); + output_dimensions, output_dimensions_is_dynamic); } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( @@ -775,6 +816,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } int64 small_dimension_size = smaller_shape.dimensions(i); int64 large_dimension_size = larger_shape.dimensions(dimension_to_match); + bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i); + bool large_is_dynamic = + larger_shape.is_dynamic_dimension(dimension_to_match); // Dimension sizes must be compatible: match or be degenerate (degenerate // case is handled by degenerate dimension broadcasting which occurs after // InDim broadcasting). @@ -786,6 +830,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(smaller_shape), ShapeUtil::HumanString(larger_shape)); } + if (small_is_dynamic != large_is_dynamic) { + if (small_dimension_size == large_dimension_size || + (small_dimension_size == 1 && !small_is_dynamic) || + (large_dimension_size == 1 && !large_is_dynamic)) { + // Do nothing. It's OK when the size-1 dimension is not static. 
+ } else { + return InvalidArgument( + "Broadcast dimension %d dynamism mismatch: %s and %s.", i, + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); + } + } // Make sure the broadcast dimensions are listed in a strictly increasing // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { @@ -795,6 +851,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } output_shape.set_dimensions(dimension_to_match, small_dimension_size); + output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic); } return output_shape; @@ -924,12 +981,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: { + case HloOpcode::kCompare: { TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); @@ -1227,16 +1279,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR( ExpectArray(scale_shape, "scale input of batch norm inference")); - TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) == - Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == - Status::OK()); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape)); + 
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape)); if (feature_index >= operand_shape.rank()) { return InvalidArgument( @@ -1661,11 +1709,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (batch_group_count > 1 && input_batch % kernel_output_features != 0) { return InvalidArgument( - "Expected output feature dimension (value %d) to be divisible by " - "input_batch (value %d) for batch group count %d; " + "Expected input batch (value %d) to be divisible by output feature " + "dimension size (value %d) for batch group count %d; " "got (%s, %s)\n" "Dimension numbers: {%s}.", - kernel_output_features, input_batch, batch_group_count, + input_batch, kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } @@ -1682,7 +1730,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } + if (kernel_output_features % feature_group_count > 0) { + // A depthwise/grouped filter has the shape + // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When + // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape + // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output + // feature count (which is equal to kernel output features) has to be a + // multiple of feature_group_count. 
return InvalidArgument( "Expected output feature dimension (value %d) to be divisible by " "feature_group_count (value %d); " @@ -1731,8 +1786,33 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, dimensions[dnums.output_spatial_dimensions(i)] = window_output_shape.dimensions(i); } + std::vector is_dynamic(num_dims); + for (int i = 0; i < num_dims; i++) { + if (lhs.is_dynamic_dimension(i)) { + if (i == dnums.input_batch_dimension()) { + is_dynamic[dnums.output_batch_dimension()] = true; + } else if (i == dnums.input_feature_dimension()) { + // Input feature dimension is a contracting dimension, which does not + // affect the output dimension size. So we need to do nothing. + } else { + return InvalidArgument( + "Dynamic Spatial Convolution is not supported: lhs shape is %s ", + lhs.ToString()); + } + } + if (rhs.is_dynamic_dimension(i)) { + if (i == dnums.kernel_input_feature_dimension()) { + // Kernel feature dimension does not affect the output dimension size. + // So we need to do nothing. 
+ } else { + return InvalidArgument( + "Dynamic Spatial Convolution is not supported: rhs shape is %s ", + rhs.ToString()); + } + } + } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - dimensions); + dimensions, is_dynamic); } /* static */ StatusOr ShapeInference::InferFftShape( @@ -1776,6 +1856,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]); } } + if (ShapeUtil::IsZeroElementArray(in)) { + return in; + } Shape result = ShapeUtil::ChangeElementType(in, C64); result.set_dimensions(result.dimensions_size() - 1, fft_length[fft_rank - 1] / 2 + 1); @@ -1817,6 +1900,78 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, #undef RET_CHECK_RANK } +/* static */ StatusOr ShapeInference::InferTriangularSolveShape( + const Shape& a, const Shape& b, const TriangularSolveOptions& options) { + if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) || + a.element_type() != b.element_type()) { + return InvalidArgument( + "Expected element types in shape to be floating or complex and " + "identical for TriangularSolve; got %s and %s.", + PrimitiveType_Name(a.element_type()), + PrimitiveType_Name(b.element_type())); + } + if (a.rank() < 2) { + return InvalidArgument( + "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s", + a.ToString()); + } + if (b.rank() != a.rank()) { + return InvalidArgument( + "Arguments to triangular solve must have equal rank; got %s and %s.", + b.ToString(), a.ToString()); + } + if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) { + return InvalidArgument( + "The two minor dimensions of 'a' must have equal size, got %s.", + a.ToString()); + } + if (a.dimensions(a.rank() - 1) != + b.dimensions(b.rank() - (options.left_side() ? 
2 : 1))) { + return InvalidArgument( + "The shared dimension of 'a' and 'b' does not match, got shapes %s and " + "%s", + a.ToString(), b.ToString()); + } + absl::Span a_batch_dims(a.dimensions()); + absl::Span b_batch_dims(b.dimensions()); + a_batch_dims.remove_suffix(2); + b_batch_dims.remove_suffix(2); + if (a_batch_dims != b_batch_dims) { + return InvalidArgument( + "The leading batch dimensions of the arguments to triangular solve " + "must be equal; got %s and %s.", + b.ToString(), a.ToString()); + } + if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) || + options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) { + return InvalidArgument( + "Invalid transpose option value for triangular solve (%d).\n", + options.transpose_a()); + } + return b; +} + +/* static */ StatusOr ShapeInference::InferCholeskyShape( + const Shape& a) { + if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) { + return InvalidArgument( + "Expected element type in shape to be floating or complex for " + "Cholesky; got %s.", + PrimitiveType_Name(a.element_type())); + } + if (a.rank() < 2) { + return InvalidArgument( + "The 'a' argument to Cholesky must have rank >= 2, got shape %s", + a.ToString()); + } + if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) { + return InvalidArgument( + "The two minor dimensions of 'a' must have equal size, got %s.", + a.ToString()); + } + return a; +} + /* static */ StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { @@ -1904,7 +2059,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 i = 1; i < num_reduced_args; ++i) { if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( - "All reduced tensors must have the sime dimension. Tensor 0 has " + "All reduced tensors must have the same dimension. 
Tensor 0 has " "shape %s, Tensor %d has shape %s", ShapeUtil::HumanString(*reduced_args[0]), i, ShapeUtil::HumanString(*reduced_args[i])); @@ -1933,20 +2088,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); std::vector new_dimensions; + std::vector new_is_dynamic; for (int i = 0; i < arg.rank(); ++i) { if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) { new_dimensions.push_back(arg.dimensions(i)); + new_is_dynamic.push_back(arg.is_dynamic_dimension(i)); } } if (ShapeUtil::IsScalar(to_apply.result())) { return ShapeUtil::MakeShape(to_apply.result().element_type(), - new_dimensions); + new_dimensions, new_is_dynamic); } else { std::vector result_subshapes; for (const Shape& subshape : to_apply.result().tuple_shapes()) { - result_subshapes.push_back( - ShapeUtil::MakeShape(subshape.element_type(), new_dimensions)); + result_subshapes.push_back(ShapeUtil::MakeShape( + subshape.element_type(), new_dimensions, new_is_dynamic)); } return ShapeUtil::MakeTupleShape(result_subshapes); } @@ -2020,6 +2177,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(source_shape), ShapeUtil::HumanString(window_result_shape)); } + return operand_shape; } @@ -2260,8 +2418,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (operand_shape.rank() != number_of_indices) { return InvalidArgument( - "Dynamic update slice start number of dimensions %d must match rank " - "%d of slice input (%s).", + "Dynamic update slice start number of dimensions %d must match " + "rank %d of slice input (%s).", number_of_indices, operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); } @@ -2348,7 +2506,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(arg)); } - if (index >= arg.tuple_shapes_size()) { + if (index < 0 || 
index >= arg.tuple_shapes_size()) { return InvalidArgument( "Cannot infer shape: attempt to index out of tuple bounds: %d " ">= %d in shape %s.", @@ -2395,59 +2553,55 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferConditionalShape( - const Shape& predicate, const Shape& true_operand, - const Shape& false_operand, const ProgramShape& true_computation, - const ProgramShape& false_computation) { - if (!ShapeUtil::Equal(predicate, ShapeUtil::MakeShape(PRED, {}))) { - return InvalidArgument("Predicate must be a boolean; got %s.", - ShapeUtil::HumanString(predicate)); - } - - if (true_computation.parameters_size() != 1) { - return InvalidArgument("true_computation must take 1 argument; got %d.", - true_computation.parameters_size()); - } - if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { - auto true_shape_string = [&]() { - return StrFormat("true_operand: %s; true_computation: %s", - ShapeUtil::HumanString(true_operand), - ShapeUtil::HumanString(true_computation)); - }; - return InvalidArgument( - "true_operand must match the shape of the only parameter of " - "true_computation: got %s.", - true_shape_string()); + const Shape& branch_index, + absl::Span branch_computations, + absl::Span branch_operands) { + if (!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(PRED, {})) && + !ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(S32, {}))) { + return InvalidArgument("branch_index must be bool or int32; got %s.", + ShapeUtil::HumanString(branch_index)); + } + if (branch_index.element_type() == PRED) { + TF_RET_CHECK(2 == branch_computations.size()); + } else { + TF_RET_CHECK(!branch_computations.empty()); } + TF_RET_CHECK(branch_computations.size() == branch_operands.size()); - if (false_computation.parameters_size() != 1) { - return InvalidArgument("false_computation must take 1 argument; got %d.", - false_computation.parameters_size()); - } - if 
(!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { - auto false_shape_string = [&]() { - return StrFormat("false_operand: %s; false_computation: %s", - ShapeUtil::HumanString(false_operand), - ShapeUtil::HumanString(false_computation)); - }; - return InvalidArgument( - "false_operand must match the shape of the only parameter of " - "false_computation: got %s.", - false_shape_string()); - } - if (!ShapeUtil::Compatible(true_computation.result(), - false_computation.result())) { - auto shape_string = [&]() { - return StrFormat( - "true_computation result: %s; false_computation result: %s.", - ShapeUtil::HumanString(true_computation.result()), - ShapeUtil::HumanString(false_computation.result())); - }; - return InvalidArgument( - "the result of true_computation and false_computation must have the " - "same shape: got %s.", - shape_string()); + for (int j = 0; j < branch_computations.size(); ++j) { + if (branch_computations[j].parameters_size() != 1) { + return InvalidArgument( + "branch computation %d must take 1 argument; got %d.", j, + branch_computations[j].parameters_size()); + } + if (!ShapeUtil::Compatible(branch_computations[j].parameters(0), + branch_operands[j])) { + auto shape_string = [&]() { + return StrFormat("operand: %s; computation: %s", + ShapeUtil::HumanString(branch_operands[j]), + ShapeUtil::HumanString(branch_computations[j])); + }; + return InvalidArgument( + "branch operand %d must match the shape of the only parameter of " + "branch computation %d: got %s.", + j, j, shape_string()); + } + + if (!ShapeUtil::Compatible(branch_computations[0].result(), + branch_computations[j].result())) { + auto shape_string = [&]() { + return StrFormat( + "branch 0 computation result: %s; branch %d computation result: %s", + ShapeUtil::HumanString(branch_computations[0].result()), j, + ShapeUtil::HumanString(branch_computations[j].result())); + }; + return InvalidArgument( + "the result of branch 0 computation and branch %d computation 
must " + "have the same shape: got %s.", + j, shape_string()); + } } - return true_computation.result(); + return branch_computations[0].result(); } /* static */ StatusOr ShapeInference::InferBroadcastShape( @@ -2497,11 +2651,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, operand_shape.dimensions(i) != 1) { return InvalidArgument( "Input dimension should be either 1 or equal to the output dimension " - "it's broadcasting into; the %lldth operand dimension is %lld, the " + "it is broadcasting into; the %lldth operand dimension is %lld, the " "%lldth output dimension is %lld.", i, operand_shape.dimensions(i), broadcast_dimensions[i], output_shape.dimensions(broadcast_dimensions[i])); } + if (operand_shape.is_dynamic_dimension(i) != + output_shape.is_dynamic_dimension(broadcast_dimensions[i])) { + return InvalidArgument( + "Broadcast input and output dynamism mismatch: %s and %s", + operand_shape.ToString(), output_shape.ToString()); + } // Make sure the broadcast dimensions are listed in a strictly increasing // order. 
if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) { @@ -2544,6 +2704,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand, inferred_shape); + for (auto& unmodified : unmodified_dims) { + if (operand.is_dynamic_dimension(unmodified.first)) { + inferred_shape.set_dynamic_dimension(unmodified.second, true); + } + } + return inferred_shape; } @@ -2551,11 +2719,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); - std::vector indices(operand.rank()); - std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != operand.rank() || - !std::is_permutation(dimensions.begin(), dimensions.end(), - indices.begin())) { + if (!IsPermutation(dimensions, operand.rank())) { return InvalidArgument( "Transpose dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", @@ -2617,19 +2781,31 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "Select's pred operand must have PRED element type; got %s.", ShapeUtil::HumanString(pred)); } - if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || + if (Shape::Equal() + .IgnoreElementType() + .IgnoreLayout() + .IgnoreDynamicDimension()(pred, on_true) || ShapeUtil::IsScalar(pred)) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. 
- return ShapeUtil::ChangeElementType( + Shape inferred_shape = ShapeUtil::ChangeElementType( on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); - } else { - return InvalidArgument( - "Select operation with non-scalar predicate with dimensionality " - " different from the other operands: %s.", - ShapeUtil::HumanString(pred)); + + // Propagate dynamic dimensions if pred is not a scalar. + if (!ShapeUtil::IsScalar(pred)) { + for (int i = 0; i < inferred_shape.rank(); i++) { + if (pred.is_dynamic_dimension(i)) { + inferred_shape.set_dynamic_dimension(i, true); + } + } + } + return inferred_shape; } + return InvalidArgument( + "Select operation with non-scalar predicate with dimensionality " + "different from the other operands: %s.", + ShapeUtil::HumanString(pred)); } /* static */ StatusOr ShapeInference::InferTupleSelectShape( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 7d39ef38e05abf0a81683c1fb0f3999908b27d23..590a664224e6786bf387494139c66a69a43a5247 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -116,6 +116,13 @@ class ShapeInference { static StatusOr InferFftShape(const Shape& in, FftType fft_type, absl::Span fft_length); + // Infers the shape produced by the given triangular solve operation. + static StatusOr InferTriangularSolveShape( + const Shape& a, const Shape& b, const TriangularSolveOptions& options); + + // Infers the shape produced by the given triangular solve operation. + static StatusOr InferCholeskyShape(const Shape& a); + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferAllReduceShape( @@ -201,11 +208,11 @@ class ShapeInference { const ProgramShape& body, const Shape& init); - // Infers the shape produced by a conditional operation. + // Infers the shape produced by a predicated or indexed conditional operation. 
static StatusOr InferConditionalShape( - const Shape& predicate, const Shape& true_operand, - const Shape& false_operand, const ProgramShape& true_computation, - const ProgramShape& false_computation); + const Shape& branch_index, + absl::Span branch_computations, + absl::Span branch_operands); // Infers the shape produced by a broadcast operation. static StatusOr InferBroadcastShape( diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 26120a06b823c9fddf378991cec434a880fb888d..a9cab3f3e694d3b03fbdf57484ca1b584a0b55bf 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -252,7 +252,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, - const absl::Span& bcast) { + absl::Span bcast) { return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, bcast); }; @@ -896,6 +896,20 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) { ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie())); } +TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) { + Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_}); + auto inferredNegative_status = + ShapeInference::InferGetTupleElementShape(tuple_shape, -1); + auto inferred2_status = + ShapeInference::InferGetTupleElementShape(tuple_shape, 2); + ASSERT_FALSE(inferredNegative_status.ok()); + ASSERT_FALSE(inferred2_status.ok()); + EXPECT_THAT(inferredNegative_status.status().error_message(), + HasSubstr("attempt to index out of tuple bounds")); + EXPECT_THAT(inferred2_status.status().error_message(), + HasSubstr("attempt to index out of tuple bounds")); +} + TEST_F(ShapeInferenceTest, InferPowShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = ShapeInference::InferBinaryOpShape( @@ -904,55 +918,10 @@ 
TEST_F(ShapeInferenceTest, InferPowShape) { ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); } -TEST_F(ShapeInferenceTest, InferCompareShapeEq) { +TEST_F(ShapeInferenceTest, InferCompareShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeGe) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeGt) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeLe) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeLt) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeNe) { - auto ten_floats = 
ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {}); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kCompare, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -1467,6 +1436,14 @@ TEST_F(ShapeInferenceTest, Pad) { Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape)); + + dimension1->set_edge_padding_low(-20); + dimension1->set_edge_padding_high(-10); + auto negative_dimension_size = ShapeInference::InferPadShape( + input_shape, padding_value_shape, padding_config); + ASSERT_FALSE(negative_dimension_size.ok()); + ASSERT_THAT(negative_dimension_size.status().error_message(), + HasSubstr("negative size for dimension 1")); } TEST_F(ShapeInferenceTest, Reverse) { @@ -1550,79 +1527,176 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } -TEST_F(ShapeInferenceTest, Conditional) { +TEST_F(ShapeInferenceTest, Rank1Transpose) { + Shape a_shape = ShapeUtil::MakeShape(F32, {5}); + auto inferred_shape_and_status = + ShapeInference::InferTransposeShape(a_shape, {0}); + EXPECT_IS_OK(inferred_shape_and_status); + Shape inferred_shape = inferred_shape_and_status.ValueOrDie(); + EXPECT_TRUE( + ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5}))); +} + +TEST_F(ShapeInferenceTest, ConditionalPred) { auto inferred_status0 = ShapeInference::InferConditionalShape( - pred_, vector_32_, vector_64_, - ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + pred_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_64_}); EXPECT_IS_OK(inferred_status0.status()); EXPECT_TRUE(ShapeUtil::Equal(f32_, 
inferred_status0.ValueOrDie())); auto inferred_status1 = ShapeInference::InferConditionalShape( - pred_, matrix_32_48_, vector_32_, - ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), - ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)); + pred_, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), + ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)}, + {matrix_32_48_, vector_32_}); EXPECT_IS_OK(inferred_status1.status()); EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); auto inferred_status2 = ShapeInference::InferConditionalShape( - pred_, matrix_32_48_, tuple_f32_v32, - ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), - ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)); + pred_, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, + {matrix_32_48_, tuple_f32_v32}); EXPECT_IS_OK(inferred_status2.status()); EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); auto inferred_status_error0 = ShapeInference::InferConditionalShape( - s32_, vector_32_, vector_64_, - ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + f32_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error0.ok()); EXPECT_THAT(inferred_status_error0.status().error_message(), - HasSubstr("Predicate must be a boolean")); + HasSubstr("must be bool or int32")); auto inferred_status_error1 = ShapeInference::InferConditionalShape( - pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, - ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), - ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)); + pred_, + {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), + 
ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, + {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_}); EXPECT_FALSE(inferred_status_error1.ok()); EXPECT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("true_computation must take 1 argument")); + HasSubstr("branch computation 0 must take 1 argument")); auto inferred_status_error2 = ShapeInference::InferConditionalShape( - pred_, vector_32_, vector_64_, - ShapeUtil::MakeProgramShape({vector_64_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + pred_, + {ShapeUtil::MakeProgramShape({vector_64_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error2.ok()); EXPECT_THAT(inferred_status_error2.status().error_message(), - HasSubstr("true_operand must match the shape of the only " - "parameter of true_computation")); + HasSubstr("branch operand 0 must match the shape of the only " + "parameter of branch computation 0")); auto inferred_status_error3 = ShapeInference::InferConditionalShape( - pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), - ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), - ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)); + pred_, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)}, + {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})}); EXPECT_FALSE(inferred_status_error3.ok()); EXPECT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("false_computation must take 1 argument")); + HasSubstr("branch computation 1 must take 1 argument")); auto inferred_status_error4 = ShapeInference::InferConditionalShape( - pred_, vector_32_, vector_64_, - ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + pred_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, 
f32_)}, + {vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error4.ok()); EXPECT_THAT(inferred_status_error4.status().error_message(), - HasSubstr("false_operand must match the shape of the only " - "parameter of false_computation")); + HasSubstr("branch operand 1 must match the shape of the only " + "parameter of branch computation 1")); auto inferred_status_error5 = ShapeInference::InferConditionalShape( - pred_, vector_32_, vector_64_, - ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)); + pred_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, + {vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error5.ok()); EXPECT_THAT(inferred_status_error5.status().error_message(), - HasSubstr("the result of true_computation and false_computation " - "must have the same shape")); + HasSubstr("the result of branch 0 computation and branch 1 " + "computation must have the same shape")); +} + +TEST_F(ShapeInferenceTest, ConditionalIndexed) { + auto r0s32 = ShapeUtil::MakeShape(S32, {}); + auto inferred_status0 = ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_64_, vector_64_}); + EXPECT_IS_OK(inferred_status0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); + + auto inferred_status1 = ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), + ShapeUtil::MakeProgramShape({vector_32_}, vector_64_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)}, + {matrix_32_48_, vector_32_, matrix_32_48_}); + EXPECT_IS_OK(inferred_status1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); + + auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); + 
auto inferred_status2 = ShapeInference::InferConditionalShape( + r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, + {tuple_f32_v32}); + EXPECT_IS_OK(inferred_status2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); + + auto inferred_status_error0 = ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_32_, vector_64_}); + EXPECT_FALSE(inferred_status_error0.ok()); + EXPECT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("2 == branch_computations.size()")); + + auto inferred_status_error1 = ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, + {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), + matrix_32_48_}); + EXPECT_FALSE(inferred_status_error1.ok()); + EXPECT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("branch computation 1 must take 1 argument")); + + auto inferred_status_error2 = ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({r0s32}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, + {r0s32, vector_32_, vector_64_}); + EXPECT_FALSE(inferred_status_error2.ok()); + EXPECT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("branch operand 2 must match the shape of the only " + "parameter of branch computation 2")); + + auto inferred_status_error3 = ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, 
+ {vector_32_, vector_32_, vector_32_, vector_64_}); + EXPECT_FALSE(inferred_status_error3.ok()); + EXPECT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("the result of branch 0 computation and branch 3 " + "computation must have the same shape")); + + auto inferred_status_error4 = + ShapeInference::InferConditionalShape(r0s32, {}, {}); + EXPECT_FALSE(inferred_status_error4.ok()); + EXPECT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("!branch_computations.empty()")); } TEST_F(ShapeInferenceTest, BadSlice) { diff --git a/tensorflow/compiler/xla/service/sort_simplifier.cc b/tensorflow/compiler/xla/service/sort_simplifier.cc index 4a00e8d7b227f14d462ca53f695189f3f48754ee..122366a0f322a66963b364e1b19629cbd2d9aabe 100644 --- a/tensorflow/compiler/xla/service/sort_simplifier.cc +++ b/tensorflow/compiler/xla/service/sort_simplifier.cc @@ -14,12 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/sort_simplifier.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/statusor.h" + +#include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" namespace xla { namespace { @@ -39,8 +42,7 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { return false; } - // Index 0 is the sorting key used by the sort HLO itself. - absl::flat_hash_set used_indices{0}; + absl::flat_hash_set used_indices; for (const HloInstruction* user : sort->users()) { if (user->opcode() != HloOpcode::kGetTupleElement) { // Can't analyse users other then get-tuple-element. 
@@ -49,15 +51,25 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { used_indices.insert(user->tuple_index()); } + // Also note which parameters are used by the comparator computation. + auto comparator = sort->to_apply(); + for (int64 i = 0; i < sort->operand_count() * 2; ++i) { + if (comparator->parameter_instruction(i)->user_count() > 0) { + // operand i corresponds to parameters 2 * i and 2 * i + 1 of the + // computation. + used_indices.insert(i / 2); + } + } + if (used_indices.size() == sort->operand_count()) { // All operands are used. return false; } - std::vector operands{sort->mutable_operand(0)}; - std::vector new_shapes{sort->operand(0)->shape()}; - for (int64 i = 1; i < sort->operand_count(); ++i) { - if (used_indices.count(i)) { + std::vector operands; + std::vector new_shapes; + for (int64 i = 0; i < sort->operand_count(); ++i) { + if (used_indices.contains(i)) { operands.push_back(sort->mutable_operand(i)); new_shapes.push_back(sort->operand(i)->shape()); } @@ -68,6 +80,32 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { : ShapeUtil::MakeTupleShape(new_shapes); HloInstruction* new_sort = computation->AddInstruction( sort->CloneWithNewOperands(new_sort_shape, operands)); + absl::flat_hash_map> + replacements; + int64 parameter_number = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto* old_lhs_parameter = comparator->parameter_instruction(i * 2); + auto* old_rhs_parameter = comparator->parameter_instruction(i * 2 + 1); + if (used_indices.contains(i)) { + Shape scalar_shape = + ShapeUtil::MakeShape(sort->operand(i)->shape().element_type(), {}); + replacements[old_lhs_parameter] = HloInstruction::CreateParameter( + parameter_number, scalar_shape, + absl::StrCat("p.", parameter_number / 2, ".lhs")); + ++parameter_number; + replacements[old_rhs_parameter] = HloInstruction::CreateParameter( + parameter_number, scalar_shape, + absl::StrCat("p.", parameter_number / 2, ".rhs")); + ++parameter_number; + } else { + 
replacements[old_lhs_parameter] = nullptr; + replacements[old_rhs_parameter] = nullptr; + } + } + HloModule* module = sort->GetModule(); + HloComputation* new_compare = module->AddEmbeddedComputation( + comparator->CloneWithReplacements(std::move(replacements))); + new_sort->set_to_apply(new_compare); // Map from original get-tuple-element tuple index to new HLO instruction absl::flat_hash_map result_map; @@ -83,7 +121,8 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { } } } else { - result_map[0] = new_sort; + CHECK_EQ(used_indices.size(), 1); + result_map[*used_indices.begin()] = new_sort; } std::vector users(sort->users().begin(), sort->users().end()); diff --git a/tensorflow/compiler/xla/service/sort_simplifier_test.cc b/tensorflow/compiler/xla/service/sort_simplifier_test.cc index cd05fcf830d32e8bac4f8b260d3dd143ab98ad7b..284d50952776bd21c8131cbec7de3e4b6692ebc5 100644 --- a/tensorflow/compiler/xla/service/sort_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/sort_simplifier_test.cc @@ -34,13 +34,21 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) { const char* hlo_string = R"( HloModule permutation_sort - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} parameter(1) - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), - dimensions={1} - ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 - })"; + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); @@ 
-58,17 +66,27 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) { const char* hlo_string = R"( HloModule permutation_sort - ENTRY sort_computation { - keys = f32[64,87] parameter(0) - values.0 = s32[64,87] parameter(1) - values.1 = u32[64,87] parameter(2) - sort = (f32[64,87], s32[64,87], u32[64,87]) sort( - keys, values.0, values.1), - dimensions={1} - gte.0 = f32[64,87] get-tuple-element(sort), index=0 - gte.1 = u32[64,87] get-tuple-element(sort), index=2 - ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1) - })"; + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.2.lhs = u32[] parameter(4) + p.2.rhs = u32[] parameter(5) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,87] parameter(0) + values.0 = s32[64,87] parameter(1) + values.1 = u32[64,87] parameter(2) + sort = (f32[64,87], s32[64,87], u32[64,87]) sort( + keys, values.0, values.1), + dimensions={1}, to_apply=compare + gte.0 = f32[64,87] get-tuple-element(sort), index=0 + gte.1 = u32[64,87] get-tuple-element(sort), index=2 + ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1) + })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); @@ -86,17 +104,57 @@ TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) { const char* hlo_string = R"( HloModule permutation_sort - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} parameter(1) - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} - ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 - })"; + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} 
parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1}, to_apply=compare + ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); SortSimplifier simplifier; EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } + +TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare + ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + uint64 num_executions = 0; + do { + num_executions++; + } while (simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(num_executions, 2); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(1)))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.cc b/tensorflow/compiler/xla/service/stable_sort_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae4ce32569a3e3aa56100837621994336445f088 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.cc @@ -0,0 +1,204 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Looks for a iota operand that can be used as tie breaker in the computation. +// If no matching iota operand is found, a iota operand is added to Sort. The +// comparison computation is adjusted to break ties using the values from the +// iota operand. +StatusOr StableSortExpander::ExpandInstruction( + HloInstruction* instruction) { + auto* sort = Cast(instruction); + HloComputation* computation = sort->parent(); + + HloInstruction* expanded_sort = nullptr; + absl::flat_hash_set used_indices; + int64 iota_index = -1; + for (const HloInstruction* operand : sort->operands()) { + // We can only use the iota operand if it has an iota dimension which is the + // same as the dimension to sort. Also it should have an integral type that + // is large enough for the number of elements in the sort dimension. For + // now, we only allow S32, because we expect to find a S32 iota operand for + // all Sort ops which are created by TopK. 
+ // TODO(b/122298745): Also support other types. + if (operand->opcode() == HloOpcode::kIota && + Cast(operand)->iota_dimension() == + sort->sort_dimension() && + operand->shape().element_type() == S32) { + iota_index = sort->operand_index(operand); + break; + } + } + + // If there is currently no iota operand which we could use for making the + // sort stable, we will have to add a new such operand. + if (iota_index == -1) { + Shape iota_shape = sort->operand(0)->shape(); + // We might need to use S64 if the number of elements in the sort dimension + // is bigger than 2^31 - 1. + // TODO(b/122298745): Handle Sort ops where S32 is too small for the number + // of elements in the sort dimension. + if (iota_shape.dimensions(sort->sort_dimension()) > + std::numeric_limits::max()) { + return Unimplemented( + "Stable sorting of more than 2^31-1 elements is not implemented"); + } + iota_shape.set_element_type(S32); + auto iota = computation->AddInstruction( + HloInstruction::CreateIota(iota_shape, sort->sort_dimension())); + + // Create a new comparator. + auto comparator = sort->to_apply(); + absl::flat_hash_map> + replacements; + std::vector> extra_parameters; + std::vector extra_parameter_ptrs; + Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".lhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2 + 1, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".rhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + sort->set_to_apply(sort->GetModule()->AddEmbeddedComputation( + comparator->CloneWithReplacements(std::move(replacements), + extra_parameter_ptrs))); + + // Replace the original sort op. 
+ std::vector new_operands(sort->operands().begin(), + sort->operands().end()); + new_operands.push_back(iota); + std::vector new_shapes = sort->operand_count() == 1 + ? std::vector{sort->shape()} + : sort->shape().tuple_shapes(); + new_shapes.push_back(iota_shape); + Shape new_sort_shape = ShapeUtil::MakeTupleShape(new_shapes); + HloInstruction* new_sort = computation->AddInstruction( + sort->CloneWithNewOperands(new_sort_shape, new_operands)); + + // Add a "wrapper" around the new sort op to make sure we have the same + // shape as before. For the rank 1 case, we only need a GetTupleElement, + // otherwise we create a Tuple consisting of GetTupleElements of the new + // sort. + std::vector tuple_elements; + tuple_elements.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + sort->operand(i)->shape(), new_sort, i))); + } + expanded_sort = tuple_elements[0]; + if (tuple_elements.size() > 1) { + expanded_sort = computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + } + sort = Cast(new_sort); + iota_index = sort->operand_count() - 1; + } + + // Modify the computation to break ties using the iota operand. + auto comparator = sort->to_apply(); + std::vector instructions_postorder = + comparator->MakeInstructionPostOrder(); + absl::flat_hash_map replacements; + // Look up instr in the replacements map, and return either the replacement, + // or instr, if the replacement isn't present. 
+ auto replace = [&](HloInstruction* instr) { + auto it = replacements.find(instr); + if (it == replacements.end()) { + return instr; + } + return it->second; + }; + HloInstruction* old_root = comparator->root_instruction(); + // The comparison computation gets 2 * n parameters (n being the number of + // operands of Sort), where parameters 2 * i and 2 * i + 1 correspond to two + // different scalars of operand i of Sort which are to be compared. The + // comparison computation should induce a strict weak order, so if + // to_apply(p1.lhs, p1.rhs, ..., pn.lhs, pn.rhs) is equal to + // to_apply(p1.rhs, p1.lhs, ..., pn.rhs, pn.lhs), we can conclude that the + // values to be compared are equivalent, and perform a tie-breaker comparison. + // + // We clone each instruction with at least one operand, but use as new + // operands of the instruction the replacements of the original operands. + // Parameter 2 * i is replaced by parameter 2 * i + 1 and vice versa. This + // should make sure that the cloned root instruction gives the result of the + // comparison computation when being called with each scalar pair reversed. + // This swap also covers the parameters corresponding to the iota operand.
+ for (int64 i = 0; i < comparator->num_parameters(); ++i) { + replacements[comparator->parameter_instruction(i)] = + comparator->parameter_instruction(i ^ 1); + } + HloInstruction* cloned_root = nullptr; + for (HloInstruction* inst : instructions_postorder) { + if (inst->operand_count() == 0) { + continue; + } + std::vector new_operands; + new_operands.reserve(inst->operand_count()); + for (HloInstruction* operand : inst->operands()) { + new_operands.push_back(replace(operand)); + } + auto new_instruction = + inst->CloneWithNewOperands(inst->shape(), new_operands); + replacements[inst] = new_instruction.get(); + if (inst == old_root) { + cloned_root = new_instruction.get(); + } + comparator->AddInstruction(std::move(new_instruction)); + } + CHECK_NE(cloned_root, nullptr); + Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); + HloInstruction* same = + comparator->AddInstruction(HloInstruction::CreateCompare( + scalar_pred, old_root, cloned_root, ComparisonDirection::kEq)); + HloInstruction* tie_breaker = + comparator->AddInstruction(HloInstruction::CreateCompare( + scalar_pred, comparator->parameter_instruction(2 * iota_index), + comparator->parameter_instruction(2 * iota_index + 1), + ComparisonDirection::kLt)); + HloInstruction* new_root = + comparator->AddInstruction(HloInstruction::CreateTernary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker, + old_root)); + comparator->set_root_instruction(new_root); + + return expanded_sort; +} + +bool StableSortExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSort && + Cast(instruction)->is_stable(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.h b/tensorflow/compiler/xla/service/stable_sort_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..31b6fd92d25370218017c58072f1aa5e64df00c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.h @@ 
-0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which expands Sort ops that have the is_stable field set to true +// into equivalent Sort ops which guarantee stable sorting without relying on +// the is_stable field. 
+class StableSortExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "stable-sort-expander"; } + + private: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/stable_sort_expander_test.cc b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..61fb4392a32b73e912b6878b0aceed0f4e88a140 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc @@ -0,0 +1,359 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace m = match; + +using StableSortExpanderTest = HloTestBase; + +// Checks whether 'a' and 'b' are roots of equivalent computations, except that +// parameters 2 * i and 2 * i + 1 are switched. +bool IsSameComputationExceptParams(const HloInstruction* a, + const HloInstruction* b) { + if (a->opcode() != b->opcode() || a->operand_count() != b->operand_count()) { + return false; + } + if (a->opcode() == HloOpcode::kParameter) { + // Check that parameters were switched. + return a->parameter_number() == (b->parameter_number() ^ 1); + } + // If the operation has no operands, it should actually be the same. + if (a->operand_count() == 0) { + return a == b; + } + // Otherwise recursively compare all operands. + for (int64 i = 0; i < a->operand_count(); ++i) { + if (!IsSameComputationExceptParams(a->operand(i), b->operand(i))) { + return false; + } + } + return true; +} + +// Check that the comparison computation has been modified to add a tie breaker +// using 'iota_parameter'. 
+void CheckComputationHasTieBreaker(const HloInstruction* root, + int64 iota_parameter) { + // With the tie breaker, the root instruction should be + // Select(Eq(Comp(), CompReverse()), Lt(), Comp()) + // with Comp() being the original comparison function, and CompReverse() being + // the copied comparison function where the parameters are reversed. Lt() is + // the tie breaker comparison using the Iota operand. + ASSERT_EQ(root->opcode(), HloOpcode::kSelect); + ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kCompare); + ASSERT_EQ(root->operand(0)->comparison_direction(), ComparisonDirection::kEq); + + // Check that the tie breaker instruction is correct. + EXPECT_THAT(root->operand(1), + GmockMatch(m::Lt(m::Parameter(iota_parameter * 2), + m::Parameter(iota_parameter * 2 + 1)))); + EXPECT_EQ(root->operand(2), root->operand(0)->operand(0)); + + // Check that Comp() and CompReverse() are equivalent except that + // CompReverse() has reversed parameters. + EXPECT_TRUE(IsSameComputationExceptParams(root->operand(0)->operand(0), + root->operand(0)->operand(1))); +} + +TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, 
GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortReuseIotaOperandComplicatedComparison) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + max = u32[] constant(2147483647) + zero = s32[] constant(0) + lhs.signed = s32[] bitcast-convert(p.0.lhs) + lhs.unsigned = u32[] bitcast-convert(p.0.lhs) + lhs.flipped = u32[] subtract(max, lhs.unsigned) + lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped) + lhs.is_negative = pred[] compare(lhs.flipped.signed, zero), direction=LT + lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed) + rhs.signed = s32[] bitcast-convert(p.0.rhs) + rhs.unsigned = u32[] bitcast-convert(p.0.rhs) + rhs.flipped = u32[] subtract(max, rhs.unsigned) + rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped) + rhs.is_negative = pred[] compare(rhs.flipped.signed, zero), direction=LT + rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed) + ROOT lt = pred[] compare(lhs.converted, rhs.converted), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + 
CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + ROOT sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, GmockMatch(m::Tuple( + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 0), + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 1)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, HonorIsStableFlag) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=false + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + 
StableSortExpander stabilizer; + EXPECT_FALSE(stabilizer.Run(module.get()).ValueOrDie()); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortDontReuseIotaOperandWrongDimension) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=0 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = f32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, 
to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] compare(lhs, rhs), direction=LT + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + ROOT sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare, + is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] 
and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] compare(lhs, rhs), direction=LT + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare, + is_stable=true + ROOT neg = s32[64,8732]{1,0} negate(sort) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Negate(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/1); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 15ef623cc7b2dbc31e9cba5c4783c39b8805a5aa..29a9d26ce2ccdd9fa4352b81c2be14950725124e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -42,8 +42,11 @@ TransferManager::GetPlatformTransferManagers() { return r; } +TransferManager::TransferMetadata::~TransferMetadata() {} + StatusOr TransferManager::TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer) { + se::Stream* stream, const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata) { StatusOr ret; se::Stream* substream = stream->GetOrCreateSubStream(); @@ -54,11 +57,13 @@ StatusOr TransferManager::TransferLiteralFromDevice( tensorflow::Notification n; Status s; Literal literal(device_buffer.on_host_shape()); - TransferLiteralFromDevice(substream, device_buffer, literal, - [&](Status status) { - s = status; - n.Notify(); - }); + TransferLiteralFromDevice( + substream, device_buffer, literal, + [&](Status status) { + s = status; + n.Notify(); + }, + 
transfer_metadata); n.WaitForNotification(); if (!s.ok()) { return s; @@ -68,25 +73,29 @@ StatusOr TransferManager::TransferLiteralFromDevice( Status TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - const MutableBorrowingLiteral& literal) { + const MutableBorrowingLiteral& literal, + const TransferMetadata* transfer_metadata) { se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); Status ret; tensorflow::Notification n; - TransferLiteralFromDevice(substream, device_buffer, literal, - [&](Status status) { - ret = status; - n.Notify(); - }); + TransferLiteralFromDevice( + substream, device_buffer, literal, + [&](Status status) { + ret = status; + n.Notify(); + }, + transfer_metadata); n.WaitForNotification(); return ret; } Status TransferManager::TransferLiteralToDevice( se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) { + const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata) { // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. 
@@ -94,14 +103,14 @@ Status TransferManager::TransferLiteralToDevice( substream->ThenWaitFor(stream); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); - TF_RETURN_IF_ERROR( - TransferLiteralToDeviceAsync(substream, literal, device_buffer)); + TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync( + substream, literal, device_buffer, transfer_metadata)); return substream->BlockHostUntilDone(); } StatusOr TransferManager::TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source) { + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, + const TransferMetadata* transfer_metadata) { StatusOr ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't @@ -113,11 +122,13 @@ StatusOr TransferManager::TransferArrayFromDevice( tensorflow::Notification n; Literal literal(shape); Status s; - TransferArrayFromDevice(substream, shape, source, literal, - [&](Status status) { - s = status; - n.Notify(); - }); + TransferArrayFromDevice( + substream, shape, source, literal, + [&](Status status) { + s = status; + n.Notify(); + }, + transfer_metadata); n.WaitForNotification(); if (!s.ok()) { return s; @@ -127,20 +138,23 @@ StatusOr TransferManager::TransferArrayFromDevice( Status TransferManager::TransferArrayToDevice( se::Stream* stream, const LiteralSlice& literal, - const se::DeviceMemoryBase& dest) { + const se::DeviceMemoryBase& dest, + const TransferMetadata* transfer_metadata) { // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. 
se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); - TF_RETURN_IF_ERROR(TransferArrayToDeviceAsync(substream, literal, dest)); + TF_RETURN_IF_ERROR( + TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata)); return substream->BlockHostUntilDone(); } Status TransferManager::TransferArrayToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, - const se::DeviceMemoryBase& dest) { + const se::DeviceMemoryBase& dest, + const TransferMetadata* transfer_metadata) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(on_device_shape.IsArray()) << "On-device representation of " @@ -156,12 +170,14 @@ Status TransferManager::TransferArrayToDeviceAsync( stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(dest, /*index=*/{}); - return TransferLiteralToDevice(stream, literal, shaped_buffer); + return TransferLiteralToDevice(stream, literal, shaped_buffer, + transfer_metadata); } void TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, - const MutableBorrowingLiteral& literal, std::function done) { + const MutableBorrowingLiteral& literal, std::function done, + const TransferMetadata* transfer_metadata) { if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", @@ -179,7 +195,7 @@ void TransferManager::TransferArrayFromDevice( stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); return TransferLiteralFromDevice(stream, shaped_buffer, literal, - std::move(done)); + std::move(done), transfer_metadata); } /* static */ void TransferManager::RegisterTransferManager( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h 
b/tensorflow/compiler/xla/service/transfer_manager.h index 43a50487c636da75224547286a31625db3f91330..2a934563b0b0a0cc095ca6a964a8f00de39f161b 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -52,16 +52,38 @@ class TransferManager { return host_shape; } + // Base class for specifying platform specific transfer metadata that can be + // used to tell the underlying implementation to perform specific optimization + // to a transfer. Actual metadata passed to supported transfer methods should + // subclass this class. + class TransferMetadata { + public: + virtual ~TransferMetadata() = 0; + }; // Returns a literal containing the data held in the given ShapedBuffer // using the provided executor. This operation is performed synchronously // without waiting for any other operation on a stream to complete. // // This function should be avoided in favor of the asynchronous version below. + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. 
virtual StatusOr TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer); + se::Stream* stream, const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata); + StatusOr TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer) { + return TransferLiteralFromDevice(stream, device_buffer, nullptr); + } virtual Status TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - const MutableBorrowingLiteral& literal); + const MutableBorrowingLiteral& literal, + const TransferMetadata* transfer_metadata); + Status TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal) { + return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr); + } // Begins transferring a literal containing the data held in the given // ShapedBuffer using the provided executor. @@ -72,10 +94,20 @@ class TransferManager { // // device_buffer is copied by reference and must live at least until done() is // invoked. - virtual void TransferLiteralFromDevice(se::Stream* stream, - const ShapedBuffer& device_buffer, - MutableBorrowingLiteral literal, - std::function done) = 0; + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. + virtual void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, std::function done, + const TransferMetadata* transfer_metadata) = 0; + void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) { + return TransferLiteralFromDevice(stream, device_buffer, literal, done, + nullptr); + } // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. 
The shape @@ -85,9 +117,18 @@ class TransferManager { // This operation is performed synchronously without waiting for any other // operation on a stream to complete. This function should be avoided in favor // of the asynchronous version below. - virtual Status TransferLiteralToDevice(se::Stream* stream, - const LiteralSlice& literal, - const ShapedBuffer& device_buffer); + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. + virtual Status TransferLiteralToDevice( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata); + Status TransferLiteralToDevice(se::Stream* stream, + const LiteralSlice& literal, + const ShapedBuffer& device_buffer) { + return TransferLiteralToDevice(stream, literal, device_buffer, nullptr); + } // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape @@ -102,26 +143,44 @@ class TransferManager { // immediately after this function returns, however their constituent buffers // on both host and device must remain valid until the enqueued transfer has // completed on 'stream'. + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. virtual Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) = 0; + const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata) = 0; + Status TransferLiteralToDeviceAsync(se::Stream* stream, + const LiteralSlice& literal, + const ShapedBuffer& device_buffer) { + return TransferLiteralToDeviceAsync(stream, literal, device_buffer, + nullptr); + } // Convenience methods for transferring an array to or from the device at a // known address. 
This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. - Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, - const se::DeviceMemoryBase& dest); - void TransferArrayFromDevice(se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source, - const MutableBorrowingLiteral& literal, - std::function done); - - Status TransferArrayToDeviceAsync(se::Stream* stream, - const LiteralSlice& literal, - const se::DeviceMemoryBase& dest); - StatusOr TransferArrayFromDevice(se::Stream* stream, - const Shape& shape, - const se::DeviceMemoryBase& source); + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. + Status TransferArrayToDevice( + se::Stream* stream, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest, + const TransferMetadata* transfer_metadata = nullptr); + void TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, std::function done, + const TransferMetadata* transfer_metadata = nullptr); + + Status TransferArrayToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest, + const TransferMetadata* transfer_metadata = nullptr); + StatusOr TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const TransferMetadata* transfer_metadata = nullptr); // Transfers the given literal into the Infeed interface of the device, // using the given executor. 
diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc similarity index 81% rename from tensorflow/compiler/xla/client/lib/triangular_solve.cc rename to tensorflow/compiler/xla/service/triangular_solve_expander.cc index c2f31742e9eff9f325fb71160b4ec3aea928d15e..790074ab834cd057a54b55c150a208e2b49d67b6 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include #include @@ -33,6 +33,8 @@ limitations under the License. namespace xla { +namespace { + // Get the diagonal blocks of the coefficient matrix XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { XlaBuilder* builder = a.builder(); @@ -140,9 +142,7 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, // zero (which can happen if the last block was padded) otherwise it will // introduce nans which will propagate auto diags = GetMatrixDiagonal(diag_blocks); - TF_ASSIGN_OR_RETURN(Shape diags_shape, builder->GetShape(diags)); - auto one = ScalarLike(diags, 1); - auto ones = Broadcast(one, AsInt64Slice(diags_shape.dimensions())); + auto ones = FullLike(diags, 1); diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); @@ -347,9 +347,10 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, }); } -XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool transpose_a, bool conjugate_a, int64 block_size, - PrecisionConfig::Precision precision) { +XlaOp BuildTriangularSolve(XlaOp a, XlaOp 
b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + bool unit_diagonal, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -402,12 +403,29 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, block_size); } + block_size = std::max( + int64{1}, std::min(block_size, ShapeUtil::GetDimension(a_shape, -1))); + if (ShapeUtil::IsZeroElementArray(b_shape)) { // The output has the same shape as 'b', and since the output has zero // elements, any such array will do. return b; } + // TODO(phawkins): consider pushing triangle masking into + // InvertDiagonalBlocks. + if (unit_diagonal) { + // Mask everything but the subdiagonal/superdiagonal elements. + a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a)) + : Select(TriangleMask(a, 0), ZerosLike(a), a); + int64 k = ShapeUtil::GetDimension(a_shape, -1); + a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k), + /*broadcast_dimensions=*/{ndims - 2, ndims - 1}); + } else { + // Mask off the ignored elements of the triangular matrix a. + a = Triangle(a, lower); + } + // We find the diagonal blocks of the coefficient matrix auto diag_blocks = DiagonalBlocks(a, block_size); @@ -415,11 +433,6 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a, precision); - // Mask off the ignored elements of the triangular matrix a. - // TODO(phawkins): it would probably be preferable to perform this masking - // block by block inside SolveWithInvertedDiagonalBlocks. 
- a = Triangle(a, lower); - // We now find the solution using GEMMs auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, @@ -429,4 +442,66 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, }); } +} // namespace + +bool TriangularSolveExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kTriangularSolve; +} + +StatusOr TriangularSolveExpander::ExpandInstruction( + HloInstruction* instruction) { + const TriangularSolveOptions& options = + instruction->triangular_solve_options(); + const string name = absl::StrFormat( + "xla.triangular_solve_%s_%s_%s_%s_%s_%s", + instruction->operand(0)->shape().ToString(), + instruction->operand(1)->shape().ToString(), + options.left_side() ? "left" : "right", + options.lower() ? "lower" : "upper", + TriangularSolveOptions_Transpose_Name(options.transpose_a()), + options.unit_diagonal() ? "unit" : "nonunit"); + + HloModule* module = instruction->parent()->parent(); + + HloComputation*& computation = + computation_cache_.emplace(name, nullptr).first->second; + if (!computation) { + // Builds a new expansion. + // + // We do something unusual here: we build the computation using the + // XlaBuilder API, which is nominally an XLA client API. We do this because + // the external APIs for building complicated computations (XlaBuilder) + // are much more ergonomic than the internal ones. As it turns out, + // XlaBuilder isn't really a client API—what it does is build a + // HloModuleProto protocol buffer, that we can then deserialize and clone + // into our HloModule. Ideally we would avoid the protocol buffer step; + // that is left as an exercise for future work. 
+ XlaBuilder builder(name); + XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a"); + XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b"); + bool transpose_a = + options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE; + bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT; + + BuildTriangularSolve(a, b, options.left_side(), options.lower(), + transpose_a, conjugate_a, options.unit_diagonal(), + /*block_size=*/128, + /*precision=*/PrecisionConfig::HIGHEST); + TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); + + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( + xla_computation.proto(), config)); + HloCloneContext context(module); + computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + } + + return instruction->parent()->AddInstruction(HloInstruction::CreateCall( + instruction->shape(), instruction->operands(), computation)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.h b/tensorflow/compiler/xla/service/triangular_solve_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..be2374ef8c86254d8db5ac1acac385aa0de7d3a5 --- /dev/null +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +class TriangularSolveExpander : public OpExpanderPass { + public: + absl::string_view name() const override { + return "triangular_solve_expander"; + } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + private: + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 5e505aaf02f157d0cba9dff42b1a9b89a6691504..cc82e9bb0287b5a586fb21fee35d3124a6d6f121 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -699,6 +699,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index // 0. // (5) The 'user' of 'operand' is Sort, and it is the only user. +// (6) The 'user' of 'operand' is TriangularSolve, it is the second operand, +// and it is the only user. // // (2) and (3) can only be determined if points-to analysis is available. 
bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( @@ -779,6 +781,14 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; } + if (user->opcode() == HloOpcode::kTriangularSolve) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 1; + } if (user->opcode() == HloOpcode::kCall) { // TODO(b/62548313): Remove when buffer assignment is module scoped and // does not assign buffers to calls. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index fd5759e44230db8223822d6ae0f511027f73d8f9..61b98673cbef0d1d9e588f2888a07373405feb9e 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" @@ -933,8 +934,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { HloInstruction::CreateParameter(0, in_shape, "param0")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + auto result = builder.AddInstruction(HloInstruction::CreateCompare( + out_shape, param0, param1, ComparisonDirection::kEq)); BuildModuleAndRunAnalysis(builder.Build()); @@ -1065,14 +1066,17 @@ TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto sort = - builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, + &builder, module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); EXPECT_TRUE( points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); @@ -1080,6 +1084,7 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); Shape values_shape = ShapeUtil::MakeShape(F32, {8}); @@ -1087,11 
+1092,14 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { HloInstruction::CreateParameter(0, keys_shape, "keys")); auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); - auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, - {values})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); // The buffer for the keys can be shared with the first tuple entry. EXPECT_TRUE( @@ -1177,8 +1185,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq)); return builder.Build(); }; diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index c93a9ba3176002a34fe84a29e62075de4d19168f..ffa89b6a797673240e139b90052351cd825bf09b 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -14,15 +14,18 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_analysis.h" +#include "absl/base/casts.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" namespace xla { using absl::nullopt; using absl::optional; +namespace m = match; // Finds and returns the non-constant operand in instr. // @@ -48,41 +51,36 @@ static optional GetGTEOperandIndex(const HloInstruction* instr, const HloInstruction* gte_operand) { VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " << gte_operand->ToString() << ")"; - optional tuple_idx; + + // Among the operands of `instr`, find one that is a get-tuple-element op. + auto gte_it = c_find_if(instr->operands(), [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kGetTupleElement; + }); + if (gte_it == instr->operands().end()) { + VLOG(2) << "instr does not have a gte operand."; + return nullopt; + } + + // All operands of `instr` must be either constants or of the form + // get-tuple-element(gte_operand, tuple_idx) + // for the same value tuple_idx. + int64 tuple_idx = (*gte_it)->tuple_index(); for (const HloInstruction* operand : instr->operands()) { - if (operand->IsConstant()) { - continue; - } - // Look through copies. - // TODO(b/68830972): We wouldn't need this if for loop matching on the GPU - // would run before copy insertion. 
- if (operand->opcode() == HloOpcode::kCopy) { - operand = operand->operand(0); - } - if (operand->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "instr uses something other than gte(gte_operand): " - << operand->ToString(); - return nullopt; - } - if (operand->operand(0) != gte_operand) { - VLOG(2) << "instr has gte whose operand is not gte_operand: " - << operand->ToString(); + if (!Match(operand, m::Constant()) && + !Match(operand, + m::GetTupleElement(m::Op().Is(gte_operand), tuple_idx))) { + VLOG(2) + << "instr uses something other than a constant or gte(gte_operand, " + << tuple_idx << "): " << operand->ToString(); return nullopt; } - if (tuple_idx && tuple_idx != operand->tuple_index()) { - VLOG(2) << "instr has operands with conflicting gte indices, " - << *tuple_idx << " vs " << operand->tuple_index(); - return nullopt; - } - - tuple_idx = operand->tuple_index(); } return tuple_idx; } // Tries to get the tuple index of the induction variable of a while loop. // -// Checks that the loop condition and root both plumb the induction variable +// Checks that the loop condition and body both plumb the induction variable // through the same tuple index, and that they both apply exactly one op to the // induction variable before deciding whether to do another loop iteration (in // the loop condition's case) or packing the induction variable into the result @@ -98,8 +96,7 @@ static optional GetGTEOperandIndex(const HloInstruction* instr, // root = tuple(..., inc, ...) // inc is N'th operand of tuple(). // // If so, returns N. Otherwise, returns nullopt. 
-static optional GetLoopInductionVarTupleIdx( - const HloInstruction* while_op) { +optional GetLoopInductionVarTupleIdx(const HloInstruction* while_op) { CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); VLOG(2) << "Finding induction variable for loop " << while_op->ToShortString(); @@ -166,8 +163,171 @@ static optional GetLoopInductionVarTupleIdx( return indvar_tuple_idx; } +// Converts the given literal to a scalar int64, if possible. +// +// Fails if the literal is not an integral type or if the value it contains +// cannot be represented in an int64. +static optional LiteralAsScalarInt64(const Literal& l) { + if (!ShapeUtil::IsEffectiveScalar(l.shape())) { + VLOG(2) << "literal is not an effective scalar: " << l.ToString(); + return nullopt; + } + switch (l.shape().element_type()) { + case S8: + return l.GetFirstElement(); + case S16: + return l.GetFirstElement(); + case S32: + return l.GetFirstElement(); + case S64: + return l.GetFirstElement(); + case U8: + return l.GetFirstElement(); + case U16: + return l.GetFirstElement(); + case U32: + return l.GetFirstElement(); + case U64: { + uint64 v = l.GetFirstElement(); + if (v > static_cast(std::numeric_limits::max())) { + VLOG(2) << "uint64 literal is out of range for int64: " << v; + return nullopt; + } + return v; + } + default: + VLOG(2) << "literal is of non-integral type " << l.shape().ToString(); + return nullopt; + } +} + +// Computes a + b, returning nullopt if it overflows. +optional CheckedAdd(int64 a, int64 b) { + // Overflow occurred iff `a` and `b` have the same sign and `a + b` has a + // different sign, see Hacker's Delight 2nd Ed. pp 28. + uint64 aa = absl::bit_cast(a); + uint64 bb = absl::bit_cast(b); + int64 result = absl::bit_cast(aa + bb); + if (a >= 0 == b >= 0 && result >= 0 != a >= 0) { + return nullopt; + } + return result; +} + +// Computes a - b, returning nullopt if it overflows.
+optional CheckedSubtract(int64 a, int64 b) { + uint64 aa = absl::bit_cast(a); + uint64 bb = absl::bit_cast(b); + int64 result = absl::bit_cast(aa - bb); + // Overflow occurred iff `a` and `b` have different signs and the sign of + // `a - b` is the same as that of `b`, see Hacker's Delight 2nd Ed. pp 29. + if (a >= 0 != b >= 0 && result >= 0 == b >= 0) { + return nullopt; + } + return result; +} + +// Check if +// - `i` is initialized to a scalar constant K (namely, `indvar_init`), +// - the while condition does `i < N` or `i <= N`, and +// - the while body does `i++`. +// If so, it's trivial to compute the loop bound. +static optional PatternMatchLoopTripCount(HloInstruction* while_op, + int64 indvar_tuple_idx, + const Literal& indvar_init) { + // First, find the scalar constant K that `i` is initialized to. + optional indvar_init_val = LiteralAsScalarInt64(indvar_init); + if (!indvar_init_val) { + VLOG(2) << "Pattern-match failed: induction variable init is not a " + "constant scalar representable as an int64: " + << indvar_init.ToString(); + return nullopt; + } + + // Check that `i` goes as `i++` in the while body. + // + // TODO(jlebar): We could also handle i-- and other idioms. + auto* while_body = while_op->while_body(); + auto* while_body_indvar_update = + while_body->root_instruction()->operand(indvar_tuple_idx); + auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + if (!Match(while_body_indvar_update, + m::AddAnyOrder(m::Op().Is(while_body_indvar), + m::ConstantEffectiveScalar(1)))) { + VLOG(2) << "Pattern-match failed: induction variable does not go as i++: " + << while_body_indvar_update->ToString(); + return nullopt; + } + + // Check that we do op(i, N) or op(N, i) as the while condition. Capture the + // value N. 
+ auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + HloInstruction* while_cond_bound = nullptr; + if (!Match(while_cond_root, + m::Op().WithBinaryOperandsAnyOrder( + m::Op().Is(while_cond_indvar), + m::ConstantEffectiveScalar(&while_cond_bound)))) { + VLOG(2) << "Pattern-match failed: while condition is not of the form " + "op(i, N) or op(N, i)."; + return nullopt; + } + // Note: If this succeeds, the constant `N` is representable as an int64 -- + // that is, if it's an XLA U64, it fits within an int64. + optional while_cond_bound_val = + LiteralAsScalarInt64(while_cond_bound->literal()); + if (!while_cond_bound_val) { + VLOG(2) << "Pattern-match failed: while condition induction variable is " + "not a constant scalar representable as an int64."; + return nullopt; + } + + // Handle `i = K; i < N; ++i`. + if (Match(while_cond_root, + m::Op() + .WithComparisonDirection(ComparisonDirection::kLt) + .WithOperand(0, m::Op().Is(while_cond_indvar)))) { + VLOG(2) << "Pattern-match succeeded: loop condition is i < N: " + << while_cond_root->ToString(); + optional trips = + CheckedSubtract(*while_cond_bound_val, *indvar_init_val); + if (trips) { + return std::max(int64{0}, *trips); + } else { + VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX."; + return nullopt; + } + } + + // Handle `i = K; i <= N; ++i`. 
+ if (Match(while_cond_root, + m::Op() + .WithComparisonDirection(ComparisonDirection::kLe) + .WithOperand(0, m::Op().Is(while_cond_indvar)))) { + VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: " + << while_cond_root->ToString(); + optional trips = + CheckedSubtract(*while_cond_bound_val, *indvar_init_val); + if (!trips) { + VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX"; + return nullopt; + } + trips = CheckedAdd(*trips, 1); + if (!trips) { + VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX"; + return nullopt; + } + return std::max(0, *trips); + } + + VLOG(2) << "Pattern-match failed: while condition follows unknown pattern: " + << while_cond_root->ToString(); + return nullopt; +} + optional ComputeWhileLoopTripCount(HloInstruction* while_op, - int64 max_value_returned) { + int64 max_brute_force_iters) { VLOG(2) << "Getting trip count for loop " << while_op->ToString(); // The loop's induction variable is found at @@ -188,23 +348,30 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init: " - << indvar_init_result.status(); + VLOG(2) << "Couldn't evaluate induction variable init, " + << indvar_init_result.status() << ", " << indvar_init->ToString(); return nullopt; } + Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie(); + + // First, try to pattern-match. + if (auto trip_count = PatternMatchLoopTripCount(while_op, *indvar_tuple_idx, + indvar_iter_val)) { + return trip_count; + } + // If our pattern-match failed, try brute-forcing the loop trip count. 
auto* while_body = while_op->while_body(); auto* while_body_indvar_update = while_body->root_instruction()->operand(*indvar_tuple_idx); auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); - // The initial value of the induction variable. - Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie(); - for (int64 trip_count = 0; trip_count != max_value_returned + 1; + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + + for (int64 trip_count = 0; trip_count != max_brute_force_iters + 1; ++trip_count) { - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_indvar = NonConstantOperand(while_cond_root); StatusOr result = evaluator.EvaluateWithSubstitutions( while_cond_root, {{while_cond_indvar, &indvar_iter_val}}); if (!result.ok()) { diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h index ac69a727bd6b403672a676400993fb7d8afc0a55..10b644599742a17b9d7ea08284abbd1c03b3a08d 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.h +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -22,16 +22,24 @@ limitations under the License. namespace xla { // Returns the precise trip count of the loop if it's statically known, -// nullopt otherwise. max_value_returned limits the number of steps that are -// evaluated while trying to brute force a loop trip count, trip counts larger -// than max_value_returned result in nullopt. -absl::optional ComputeWhileLoopTripCount(HloInstruction *while_op, - int64 max_value_returned = 128); +// nullopt otherwise. +// +// max_brute_force_iters limits the number of steps that are evaluated while +// trying to brute force a loop trip count. 
trip counts larger than +// max_brute_force_iters may be returned if we can pattern-match the loop +// condition. +absl::optional ComputeWhileLoopTripCount( + HloInstruction *while_op, int64 max_brute_force_iters = 128); // Returns an upper bound on the trip count of the loop if it's statically // known, nullopt otherwise. absl::optional ComputeWhileLoopTripCountUpperBound( HloInstruction *while_op); + +// Returns the tuple index of the loop induction variable if there is such an +// induction variable detected. Otherwise returns nullopt. +absl::optional GetLoopInductionVarTupleIdx( + const HloInstruction *while_op); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc index 1da0fbeac89a93eaaef893e5f25dd3b87cc1d5d5..5a5dc742c0304d28d7fca3937fcdef5e7aea7ce9 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc @@ -40,7 +40,7 @@ TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] equal-to(gte, const) + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { @@ -71,7 +71,7 @@ TEST_F(WhileLoopAnalysisTest, NoUpperBound) { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] equal-to(gte, const) + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { @@ -104,7 +104,7 @@ TEST_F(WhileLoopAnalysisTest, ExactBound) { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] less-than(gte, const) + ROOT result = pred[] compare(gte, const), direction=LT } ENTRY entry { diff --git 
a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 3bcf5c38309a86e9e3cab3268f3f065005f7a923..8ab5e433e0f2893e93d7a83e4114352c8a3d82fd 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -260,7 +260,7 @@ condition { p_cond = (f32[],f32[]) parameter(0) p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0 p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1 - ROOT result = pred[] less-than(p_cond.0, p_cond.1) + ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT } ENTRY entry { @@ -300,7 +300,7 @@ condition { p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0 p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1 p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1 - ROOT result = pred[] less-than(p_c.0, p_c.1.1) + ROOT result = pred[] compare(p_c.0, p_c.1.1), direction=LT } ENTRY entry { @@ -342,7 +342,7 @@ condition { p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 - ROOT result = pred[] less-than(p_cond.0, p_cond.1) + ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT } ENTRY entry { @@ -389,10 +389,10 @@ condition { p_cond = (f32[],f32[],f32[]) parameter(0) p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 - lt.0 = pred[] less-than(p_cond.0, p_cond.2) + lt.0 = pred[] compare(p_cond.0, p_cond.2), direction=LT p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 - lt.1 = pred[] less-than(p_cond.1, p_cond.2.c) + lt.1 = pred[] 
compare(p_cond.1, p_cond.2.c), direction=LT ROOT result = pred[] and(lt.0, lt.1) } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 3587c016b4420163a607422b1acc838646fab83a..f0bb646d9c0e3f563641f459b9e30a5133191305 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -556,7 +556,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) { p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=3 const = s32[] constant(42) - ROOT result = pred[] equal-to(gte, const) + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 09d54095718029541a7a25aa62f9a2e9a177960d..386ffb995477ff1b4aef73080b6a6fd988dd1980 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -180,7 +180,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // replace the old instructions after we remove unused elements from the while // tuple. 
auto make_while_computation_replacements = [&](const HloComputation* comp) { - std::unordered_map> + absl::flat_hash_map> replacements; auto* param = comp->parameter_instruction(0); @@ -232,7 +232,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_cond->CloneWithReplacements( make_while_computation_replacements(while_cond)); - std::unordered_map> + absl::flat_hash_map> while_body_replacements = make_while_computation_replacements(while_body); std::vector new_while_body_root_elems; new_while_body_root_elems.reserve(new_to_old_tuple_idx.size()); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index ecca76b1e86d833c73fbb9bad6a341660a7d2669..65175fb6ab38a3b07ce81ec6b24a070529305f55 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -72,7 +72,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant({{LOOP_BOUND}}) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(42) @@ -107,7 +107,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2 - ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4) + ROOT less-than = pred[] compare(get-tuple-element.3, get-tuple-element.4), direction=LT } ENTRY SimpleLoopWithIndirectLoopBound { constant.3 = s32[] constant(42) @@ -237,7 +237,7 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) 
{ NonTupleShapedLoop.condition { loop_var = s32[] parameter(0) constant = s32[] constant(100) - ROOT less-than = pred[] less-than(s32[] loop_var, s32[] constant) + ROOT less-than = pred[] compare(s32[] loop_var, s32[] constant), direction=LT } ENTRY INonTupleShapedLoop { constant.2 = s32[] constant(42) @@ -387,7 +387,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { param0 = (s32[], s32[], s32[]) parameter(0) get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0), index=2 - ROOT equal-to = pred[] equal-to(s32[] constant.2, s32[] get-tuple-element) + ROOT equal-to = pred[] compare(s32[] constant.2, s32[] get-tuple-element), direction=EQ } ENTRY RemoveUnusedOperands { x = s32[] parameter(0) @@ -471,7 +471,7 @@ TEST_F(WhileLoopSimplifierTest, loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(44) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(42) @@ -503,7 +503,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0) get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(47) - ROOT less-than = pred[] less-than(get-tuple-element.4, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.4, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(42) @@ -679,7 +679,7 @@ const char* const kSimpleMergeInductionVariablesModule = R"( b = TYPE[] get-tuple-element(param), index=1 sum = TYPE[] power(a, b) ten = TYPE[] constant(10) - ROOT cond = pred[] less-than(sum, ten) + ROOT cond = pred[] compare(sum, ten), direction=LT } ENTRY Loop { a = TYPE[] constant(10) diff --git a/tensorflow/compiler/xla/service/while_loop_trip_count_annotator.cc 
b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator.cc new file mode 100644 index 0000000000000000000000000000000000000000..03bb6792fe74e9eb90278cbd4152e609a7904c80 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator.cc @@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +StatusOr WhileLoopTripCountAnnotator::Run(HloModule* module) { + bool changed = false; + for (const HloComputation* comp : module->computations()) { + for (HloInstruction* instr : comp->instructions()) { + if (instr->opcode() != HloOpcode::kWhile) { + continue; + } + if (auto trip_count = ComputeWhileLoopTripCount(instr)) { + WhileLoopBackendConfig config; + config.mutable_known_trip_count()->set_n(*trip_count); + TF_RETURN_IF_ERROR(instr->set_backend_config(config)); + changed = true; + } + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h new file mode 100644 index 0000000000000000000000000000000000000000..7cda2f10cefba821bccc1b5d3b5a33cd7a68e004 --- /dev/null 
+++ b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Pass that annotates `while` loops with known trip counts. +// +// The annotation is stored as a backend-config on the while loop node. +// +// This pass should run after all passes that might semantically modify a while +// loop, e.g. by unrolling it. Otherwise, a loop could end up with a +// backend-config that doesn't match its true trip-count. +// +// This pass does some pattern-matching on loop bodies and conditions, so it +// should run after most HLO simplifications and before fusion and layout +// assignment, which make pattern matching much more difficult by e.g. +// introducing `copy` nodes. 
+class WhileLoopTripCountAnnotator : public HloModulePass { + public: + ~WhileLoopTripCountAnnotator() override {} + absl::string_view name() const override { + return "while-loop-trip-count-annotator"; + } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a1e18bbdef6b84332e0187ae1bf73b67396eaf80 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc @@ -0,0 +1,207 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class TripCountAnnotatorTest : public HloTestBase {}; + +TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) { + const char* kModuleStr = R"( + HloModule test + Body { + param = (s32[]) parameter(0) + i = s32[] get-tuple-element(param), index=0 + one = s32[] constant(1) + i_plus_one = s32[] add(i, one) + ROOT tuple = (s32[]) tuple(i_plus_one) + } + + Cond { + param = (s32[]) parameter(0) + i = s32[] get-tuple-element(param), index=0 + trip_count = s32[] constant(10) + ROOT done = pred[] compare(i, trip_count), direction=LT + } + + ENTRY test { + i_start = s32[] constant(0) + initial_tuple = (s32[]) tuple(i_start) + ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + WhileLoopTripCountAnnotator pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get())); + ASSERT_TRUE(changed); + + TF_ASSERT_OK_AND_ASSIGN(auto config, + m->entry_computation() + ->root_instruction() + ->backend_config()); + EXPECT_EQ(10, config.known_trip_count().n()); +} + +TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) { + const char* kModuleStr = R"( + HloModule test + Body { + param = (s32[]) parameter(0) + i = s32[] get-tuple-element(param), index=0 + one = s32[] constant(1) + i_plus_one = s32[] add(i, one) + ROOT tuple = (s32[]) tuple(i_plus_one) + } + + Cond { + param = (s32[]) parameter(0) + i = s32[] get-tuple-element(param), index=0 + 
trip_count = s32[] constant(1000000) + ROOT done = pred[] compare(i, trip_count), direction=LT + } + + ENTRY test { + i_start = s32[] constant(0) + initial_tuple = (s32[]) tuple(i_start) + ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + WhileLoopTripCountAnnotator pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get())); + ASSERT_TRUE(changed); + + TF_ASSERT_OK_AND_ASSIGN(auto config, + m->entry_computation() + ->root_instruction() + ->backend_config()); + EXPECT_EQ(1000000, config.known_trip_count().n()); +} + +TEST_F(TripCountAnnotatorTest, NonzeroStart) { + const char* kModuleStr = R"( + HloModule test + Body { + param = (s32[]) parameter(0) + i = s32[] get-tuple-element(param), index=0 + one = s32[] constant(1) + i_plus_one = s32[] add(i, one) + ROOT tuple = (s32[]) tuple(i_plus_one) + } + + Cond { + param = (s32[]) parameter(0) + i = s32[] get-tuple-element(param), index=0 + trip_count = s32[] constant(1000000) + ROOT done = pred[] compare(i, trip_count), direction=LT + } + + ENTRY test { + i_start = s32[] constant(10) + initial_tuple = (s32[]) tuple(i_start) + ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + WhileLoopTripCountAnnotator pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get())); + ASSERT_TRUE(changed); + + TF_ASSERT_OK_AND_ASSIGN(auto config, + m->entry_computation() + ->root_instruction() + ->backend_config()); + EXPECT_EQ(999990, config.known_trip_count().n()); +} + +TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) { + const char* kModuleStr = R"( + HloModule test + Body { + param = (s32[]) parameter(0) + i = s32[] get-tuple-element(param), index=0 + one = s32[] constant(1) + i_plus_one = s32[] add(i, one) + ROOT tuple = (s32[]) tuple(i_plus_one) + } + + Cond { + param = (s32[]) parameter(0) + i 
= s32[] get-tuple-element(param), index=0 + trip_count = s32[] constant(1000000) + ROOT done = pred[] compare(i, trip_count), direction=LE + } + + ENTRY test { + i_start = s32[] constant(10) + initial_tuple = (s32[]) tuple(i_start) + ROOT while = (s32[]) while(initial_tuple), condition=Cond, body=Body + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + WhileLoopTripCountAnnotator pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get())); + ASSERT_TRUE(changed); + + TF_ASSERT_OK_AND_ASSIGN(auto config, + m->entry_computation() + ->root_instruction() + ->backend_config()); + EXPECT_EQ(999991, config.known_trip_count().n()); +} + +TEST_F(TripCountAnnotatorTest, Int64Overflow) { + // for(i = INT64_MIN; i < INT64_MAX; ++i) + // + // We store the trip count as an int64, so this loop is unanalyzable. + const char* kModuleStr = R"( + HloModule test + Body { + param = (s64[]) parameter(0) + i = s64[] get-tuple-element(param), index=0 + one = s64[] constant(1) + i_plus_one = s64[] add(i, one) + ROOT tuple = (s64[]) tuple(i_plus_one) + } + + Cond { + param = (s64[]) parameter(0) + i = s64[] get-tuple-element(param), index=0 + trip_count = s64[] constant(9223372036854775807) // 2^63-1 + ROOT done = pred[] compare(i, trip_count), direction=LE + } + + ENTRY test { + i_start = s64[] constant(-9223372036854775808) // -2^63 + initial_tuple = (s64[]) tuple(i_start) + ROOT while = (s64[]) while(initial_tuple), condition=Cond, body=Body + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + WhileLoopTripCountAnnotator pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get())); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index d77386497a14b3e52be2ea7f655fa330f60e4a97..b6f65c763ea3c489f385ffe3e49b53a36a2877ba 100644 --- 
a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -166,7 +166,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape, TF_ASSIGN_OR_RETURN( HloInstruction * compare, - MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant)); + MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant)); cond_computation->set_root_instruction(compare); return std::move(cond_computation); } diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index a76fbf3f66adae0a5e5357178bc576bbc74701c7..661b7aa7d99ca549da6a509812760a1665d60919 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -37,9 +37,15 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { } if (comp->IsRemovable(instruction) && ShapeUtil::IsZeroElementArray(instruction->shape())) { + // If the instruction doesn't have a layout, use a default layout for + // the literal. 
+ Shape shape = instruction->shape(); + if (!LayoutUtil::HasLayout(shape)) { + LayoutUtil::SetToDefaultLayout(&shape); + } TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( - instruction, HloInstruction::CreateConstant( - Literal::CreateFromShape(instruction->shape())))); + instruction, + HloInstruction::CreateConstant(Literal::CreateFromShape(shape)))); changed = true; } } diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index a546a6d39cc55d1f327b8449c7d26cd4c95dbf98..572a79609e7a912277af0fd2ba43f9a1e14a6f52 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -82,5 +82,18 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateConstant) { EXPECT_FALSE(changed); } +TEST_F(ZeroSizedHloEliminationTest, ZeroSizedInstructionWithoutLayoutFolded) { + Shape op_shape = ShapeUtil::MakeShape(F32, {4, 0}); + op_shape.clear_layout(); + HloInstruction* param1 = builder_.AddInstruction( + HloInstruction::CreateParameter(1, op_shape, "zero sized param 1")); + HloInstruction* param2 = builder_.AddInstruction( + HloInstruction::CreateParameter(2, op_shape, "zero sized param 2")); + builder_.AddInstruction( + HloInstruction::CreateBinary(op_shape, HloOpcode::kAdd, param1, param2)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_TRUE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 1a029efe8543b5433ef5fe7923e1e804019ba0c0..94854047e530babe2234381a615aeb805f0d5933 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -34,8 +34,12 @@ Shape::Shape(const ShapeProto& shape_proto) { // instead of a constructor. 
if (shape_proto.dimensions_size() != shape_proto.is_dynamic_dimension_size()) { - LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " - "fields does not match number of dimension fields"; + if (shape_proto.is_dynamic_dimension_size() != 0) { + LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " + "fields does not match number of dimension fields"; + } else { + LOG(WARNING) << "Malformed shape proto: is_dynamic_dimension is empty"; + } } int64 num_dynamic_dimension_fields = std::min( shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size()); @@ -112,6 +116,68 @@ void Shape::DeleteDimension(int64 dim_to_delete) { } } +bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { + if (lhs.IsTuple()) { + return rhs.IsTuple() && + absl::c_equal( + lhs.tuple_shapes(), rhs.tuple_shapes(), + [=](const Shape& l, const Shape& r) { return (*this)(l, r); }); + } else if (!lhs.IsArray()) { + // Non-tuple, non-array tupes such as opaque and token types are trivially + // the same. 
+ return lhs.element_type() == rhs.element_type(); + } + + if (!rhs.IsArray()) { + return false; + } + + if (!ignore_element_type_) { + if ((ignore_fp_precision_ && + !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || + (!ignore_fp_precision_ && !ShapeUtil::SameElementType(lhs, rhs))) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + } + + if (!ignore_layout_) { + if (lhs.layout().format() != rhs.layout().format()) { + VLOG(3) << "CompareShapes: lhs layout format != rhs layout format"; + return false; + } + if (LayoutUtil::IsDenseArray(lhs)) { + Layout::Equal equal; + if (ignore_tiles_in_layout_) { + equal.IgnoreTiles(); + } + if (ignore_element_size_in_layout_) { + equal.IgnoreElementSize(); + } + if (!equal(lhs.layout(), rhs.layout())) { + VLOG(3) << "CompareShapes: lhs layout != rhs layout"; + return false; + } + } + } + + if (!ShapeUtil::SameDimensions(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; + return false; + } + + if (!ignore_dynamic_dimension_) { + for (int i = 0; i < lhs.rank(); ++i) { + if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) { + VLOG(3) + << "CompareShapes: lhs and rhs have different dynamic dimensions."; + return false; + } + } + } + return true; +} + std::ostream& operator<<(std::ostream& out, const Shape& shape) { out << shape.ToString(/*print_layout=*/true); return out; diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index dc4cdc31a74d43471b72a71d9d436408e0e62deb..78cea83c6d71e5965f10cd3a917ffccabd630462 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -72,6 +72,10 @@ class Shape { dynamic_dimensions_[dimension] = is_dynamic; } + const std::vector& dynamic_dimensions() const { + return dynamic_dimensions_; + } + // Add dimension_upper_bound(). // Removes the given dimension form the shape. 
Layout, if it exists, is @@ -138,6 +142,59 @@ class Shape { string ShortDebugString() const { return ToProto().ShortDebugString(); } string DebugString() const { return ToProto().DebugString(); } + // Equal is a configurable functor to check the equality of two shapes. + // + // Examples: + // + // - Comparing two shapes ignoring their layout difference: + // Equal().IgnoreLayout()(shape1, shape2); + // + // - Comparing two shapes ignoring their layout and element type difference: + // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); + class Equal { + public: + Equal() = default; + + bool operator()(const Shape& lhs, const Shape& rhs); + + Equal& IgnoreLayout() { + ignore_layout_ = true; + return *this; + } + Equal& IgnoreTilesInLayout() { + ignore_tiles_in_layout_ = true; + return *this; + } + Equal& IgnoreElementSizeInLayout() { + ignore_element_size_in_layout_ = true; + return *this; + } + Equal& IgnoreElementType() { + ignore_element_type_ = true; + return *this; + } + Equal& IgnoreFpPrecision() { + ignore_fp_precision_ = true; + return *this; + } + Equal& IgnoreDynamicDimension() { + ignore_dynamic_dimension_ = true; + return *this; + } + + private: + bool ignore_layout_ = false; + bool ignore_tiles_in_layout_ = false; + bool ignore_element_size_in_layout_ = false; + bool ignore_element_type_ = false; + bool ignore_fp_precision_ = false; + bool ignore_dynamic_dimension_ = false; + }; + + // Test that all fields of the shape are the same, equivalent to Equal(). + bool operator==(const Shape& other) const { return Equal()(*this, other); } + bool operator!=(const Shape& other) const { return !(*this == other); } + private: // The element type of this shape (tuple, array, etc). 
PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc index 55ce5fe884e98e474253be9ef694f1b8137b4b01..dbdafcc0a1f7348af8394598363d570118cdd87e 100644 --- a/tensorflow/compiler/xla/shape_test.cc +++ b/tensorflow/compiler/xla/shape_test.cc @@ -35,6 +35,8 @@ class ShapeTest : public ::testing::Test { const Shape opaque_ = ShapeUtil::MakeOpaqueShape(); const Shape token_ = ShapeUtil::MakeTokenShape(); const Shape scalar_ = ShapeUtil::MakeShape(F32, {}); + const Shape scalar_with_tile_ = + ShapeUtil::MakeShapeWithLayout(F32, {}, {}, {Tile({256})}); const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2}); const Shape matrix2_ = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); const Shape tuple_ = @@ -66,6 +68,8 @@ TEST_F(ShapeTest, ShapeToString) { EXPECT_EQ("opaque[]", opaque_.ToString(/*print_layout=*/true)); EXPECT_EQ("f32[]", scalar_.ToString(/*print_layout=*/true)); + EXPECT_EQ("f32[]{:T(256)}", + scalar_with_tile_.ToString(/*print_layout=*/true)); EXPECT_EQ("u32[1,2]{1,0}", matrix_.ToString(/*print_layout=*/true)); EXPECT_EQ("s32[3,4]{0,1}", matrix2_.ToString(/*print_layout=*/true)); EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", @@ -85,6 +89,24 @@ TEST_F(ShapeTest, DynamicShapeToString) { EXPECT_EQ("f32[<=23,44,55]", array_shape.ToString()); } +TEST_F(ShapeTest, EqualityTest) { + // Different layouts. + EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {0, 1})); + + // Different dims. + EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {44, 23}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0})); + + // Different elements. + EXPECT_NE(ShapeUtil::MakeShapeWithLayout(S32, {44, 23}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0})); + + // Equal shapes. 
+ EXPECT_EQ(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0})); +} + TEST_F(ShapeTest, IsStatic) { EXPECT_TRUE(opaque_.is_static()); EXPECT_TRUE(token_.is_static()); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index f1ef202da0971926f608b4594099d8b4c3a6baae..acaa9cae7c2c2745a3ed413ca9f00b5bf0187a0c 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -85,82 +86,12 @@ bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { } namespace { - -// Recursive helper for comparing the equality of two shapes. Returns true if -// the shapes are the same. If compare_layouts is true, then layouts must also -// match. -bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, - bool ignore_fp_precision) { - if ((ignore_fp_precision && - !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || - (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } - - if (lhs.IsTuple()) { - return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts, - ignore_fp_precision); - }); - } else if (!lhs.IsArray()) { - // Non-tuple, non-array tupes such as opaque and token types are trivially - // the same. 
- return true; - } - - if (compare_layouts) { - if (lhs.layout().format() != rhs.layout().format()) { - VLOG(3) << "CompareShapes: lhs layout format != rhs layout format"; - return false; - } - if (LayoutUtil::IsDenseArray(lhs)) { - if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), - LayoutUtil::MinorToMajor(rhs))) { - VLOG(3) << "CompareShapes: lhs layout != rhs layout"; - return false; - } - - const auto& lhs_tiles = lhs.layout().tiles(); - const auto& rhs_tiles = rhs.layout().tiles(); - if (lhs_tiles.size() != rhs_tiles.size()) { - return false; - } - for (int64 i = 0; i < lhs_tiles.size(); i++) { - if (!absl::c_equal(lhs_tiles[i].dimensions(), - rhs_tiles[i].dimensions())) { - return false; - } - } - - if (lhs.layout().element_size_in_bits() != - rhs.layout().element_size_in_bits()) { - return false; - } - } - } - - if (!ShapeUtil::SameDimensions(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; - return false; - } - - for (int i = 0; i < lhs.rank(); ++i) { - if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) { - VLOG(3) - << "CompareShapes: lhs and rhs have different dynamic dimensions."; - return false; - } - } - return true; -} - // Constructs and returns the new shape with the given minor_to_major order in // its Layout. 
StatusOr MakeShapeWithLayoutInternal( PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major) { + absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits) { if (dimensions.size() != minor_to_major.size()) { return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", dimensions.size(), minor_to_major.size()); @@ -171,23 +102,19 @@ StatusOr MakeShapeWithLayoutInternal( } TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); - auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->clear(); - for (int64 value : minor_to_major) { - min2maj->push_back(value); - } + *shape.mutable_layout() = + LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits); if (!shape.has_layout()) { return InvalidArgument("Shape has no layout."); } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); return shape; } - } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, - /*ignore_fp_precision=*/false); + bool equal = Shape::Equal()(lhs, rhs); + if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -198,8 +125,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, - /*ignore_fp_precision=*/true); + bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs); if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -262,8 +188,10 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major) { - return 
MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) + absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits) { + return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major, + tiles, element_size_in_bits) .ValueOrDie(); } @@ -534,10 +462,14 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( shape.is_dynamic_dimension(i) ? "<=" : "", shape.dimensions(i)); } result += "]"; - if (!IsScalar(shape) && shape.IsArray()) { - if (LayoutUtil::HasLayout(shape)) { - StrAppend(&result, LayoutUtil::HumanString(shape.layout())); + if (IsScalar(shape)) { + string layout_str = LayoutUtil::HumanString(shape.layout()); + // Don't print "{}" as layout for scalars. + if (layout_str != "{}") { + StrAppend(&result, layout_str); } + } else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) { + StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } return result; } @@ -563,37 +495,17 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return CompareShapes(lhs, rhs, /*compare_layouts=*/false, - /*ignore_fp_precision=*/false); + return Shape::Equal().IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (lhs.IsArray()) { - return rhs.IsArray() && SameDimensions(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringElementType); - } else { - // Opaque, token, etc types are vacuously compatible. 
- return lhs.element_type() == rhs.element_type(); - } + return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (lhs.IsArray()) { - return rhs.IsArray() && SameElementTypeIgnoringFpPrecision(lhs, rhs) && - CompatibleIgnoringElementType(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringFpPrecision); - } else { - // Opaque, token, etc types are vacuously compatible. - return lhs.element_type() == rhs.element_type(); - } + return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs); } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -1022,6 +934,10 @@ Status ForEachMutableSubshapeHelper( for (auto dim : Permute(permutation, shape.dimensions())) { new_shape.add_dimensions(dim); } + for (int64 i = 0; i < shape.rank(); i++) { + new_shape.set_dynamic_dimension(permutation[i], + shape.is_dynamic_dimension(i)); + } // If `shape` has a layout, by contract we choose a new layout such that the // transpose defined by this permutation is a bitcast. @@ -1345,6 +1261,43 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& input_shape, const Shape& output_shape) { CHECK(input_shape.IsArray()); CHECK(output_shape.IsArray()); + // Removing trivial dimensions from the shape simplifies the alignment + // algorithm since ones can go in any position. 
+ if (HasDegenerateDimensions(input_shape) || + HasDegenerateDimensions(output_shape)) { + auto simple_output_shape = + AlignLayouts(DropDegenerateDimensions(input_shape), + DropDegenerateDimensions(output_shape)); + if (!simple_output_shape) { + return absl::nullopt; + } + + auto layout = simple_output_shape->layout().minor_to_major(); + // For each one sized dimension in the output, increment the dimension + // numbers in layout that are more minor than the one. + absl::InlinedVector dim_map; + dim_map.reserve(simple_output_shape->rank()); + for (int64 i = 0; i < output_shape.rank(); ++i) { + if (output_shape.dimensions(i) != 1) { + dim_map.push_back(i); + } + } + for (int64& d : layout) { + d = dim_map[d]; + } + + // Add the ones in descending order to the layout. Descending layouts tend + // to reduce the number of copies inserted in layout assignment. + for (int64 i = output_shape.rank() - 1; i >= 0; --i) { + if (output_shape.dimensions(i) == 1) { + layout.push_back(i); + } + } + Shape output_shape_with_layout = output_shape; + *output_shape_with_layout.mutable_layout()->mutable_minor_to_major() = + layout; + return output_shape_with_layout; + } int64 input_rank = input_shape.rank(); int64 output_rank = output_shape.rank(); @@ -1393,10 +1346,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (input_dimension_product != output_dimension_product) { return absl::nullopt; } + // We also need to store an end element so that we know where the last // alignment part ends. alignment.push_back({input_rank, output_rank}); - // Now check if the physical layout can potentially be aligned to the output // shape by changing the physical layout of the output shape. 
We need to check // that all dimension numbers that belong to the same alignment part appear @@ -1408,40 +1361,23 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, for (int64 i = 0; i < input_rank;) { int64 current_dimension_number = input_dimension_numbers[i]; - // Skip trivial dimensions with a bound of 1. - if (input_shape.dimensions(current_dimension_number) == 1) { - ++i; - continue; - } - - // Calculate the number of non-trivial dimension bounds in the input shape - // belonging to the current alignment part. + // Trivial dimensions are stripped. + CHECK_NE(input_shape.dimensions(current_dimension_number), 1); const int64 current_alignment_index = dimension_to_alignment_index[current_dimension_number]; // Because of the special end element that we added, we can be sure that // 'current_alignment_index' is < alignment.size() - 1. CHECK_LT(current_alignment_index, alignment.size() - 1); - int64 num_non_trivial_dimensions_in_alignment_part = 0; - for (int64 j = alignment[current_alignment_index].first; - j < alignment[current_alignment_index + 1].first; ++j) { - if (input_shape.dimensions(j) != 1) { - ++num_non_trivial_dimensions_in_alignment_part; - } - } // Check that the following 'num_non_trivial_dimensions_in_alignment_part' // dimension numbers (ignoring dimension numbers with dimension bound 1) are // in descending order and belong to the current alignment part. - for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; + for (int64 j = 0; j < alignment[current_alignment_index + 1].first - + alignment[current_alignment_index].first; ++i, ++j) { if (i == input_rank) { return absl::nullopt; } - // Skip trivial dimensions with a bound of 1. - if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { - --j; - continue; - } // If the current dimension number belongs to a different alignment part, // or the dimension numbers are not in descending order, we can return // early. 
@@ -1452,22 +1388,11 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } current_dimension_number = input_dimension_numbers[i]; } - // The output dimension numbers that belong to the current alignment part - // need to appear in the same descending order as in the input. Again, we - // can skip dimensions with a bound of 1. + // need to appear in the same descending order as in the input. for (int64 j = alignment[current_alignment_index + 1].second - 1; j >= alignment[current_alignment_index].second; --j) { - if (output_shape.dimensions(j) != 1) { - output_layout.push_back(j); - } - } - } - // Now add all the dimensions with dimension bound 1 at the end of - // 'output_layout'. - for (int64 i = 0; i < output_rank; ++i) { - if (output_shape.dimensions(i) == 1) { - output_layout.push_back(i); + output_layout.push_back(j); } } CHECK_EQ(output_layout.size(), output_rank); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index e98c6e024bec1f6db5c40d3cd3215ca44eb13698..7f610a6085d6fbe3d3143d5027cdc43d4b07bcbf 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -398,7 +398,9 @@ class ShapeUtil { // Returns a value shape such that shape.has_layout(). static Shape MakeShapeWithLayout(PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major); + absl::Span minor_to_major, + absl::Span tiles = {}, + int64 element_size_in_bits = 0); static Shape MakeShapeWithSparseLayout(PrimitiveType element_type, absl::Span dimensions, @@ -675,11 +677,9 @@ class ShapeUtil { template static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { - ForEachIndexWithStatus(shape, - [&](absl::Span indices) { - return StatusOr(visitor_function(indices)); - }) - .IgnoreError(); + ForEachIndexWithStatus(shape, [&](absl::Span indices) { + return StatusOr(visitor_function(indices)); + }).IgnoreError(); } // A parallel version of ForEachIndex(WithStatus). 
This can only be used if diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 61b4e73e060c18a3d0108e68d1117607d6c11c0f..020b062f6b1b032bab958772d3a6a1e35daee38b 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -710,6 +710,26 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) { } while (std::next_permutation(layout.begin(), layout.end())); } +TEST(ShapeUtilTest, PermuteDynamicDimensions) { + Shape shape = + ShapeUtil::MakeShape(F32, {10, 100, 1000}, + /*dynamic_dimensions*/ {false, true, true}); + SCOPED_TRACE(absl::StrCat("shape=", shape.ToString())); + + std::vector permutation(3); + std::iota(permutation.begin(), permutation.end(), 0); + do { + SCOPED_TRACE(absl::StrCat("permutation=", absl::StrJoin(permutation, ","))); + + auto permuted = ShapeUtil::PermuteDimensions(permutation, shape); + for (int i = 0; i < shape.rank(); i++) { + EXPECT_EQ(permuted.dimensions(permutation[i]), shape.dimensions(i)); + EXPECT_EQ(permuted.is_dynamic_dimension(permutation[i]), + shape.is_dynamic_dimension(i)); + } + } while (std::next_permutation(permutation.begin(), permutation.end())); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), @@ -741,8 +761,15 @@ TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) { auto aligned_shape = ShapeUtil::AlignLayouts( input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1})); EXPECT_TRUE(aligned_shape); - EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), - ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +TEST(AlignmentTest, AlignLayoutsWithAllTrivialDimensions) { + Shape input = + ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 1, 1, 1}, {0, 1, 3, 2}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, 
ShapeUtil::MakeShape(xla::F32, {1, 1, 1, 1, 1})); + EXPECT_TRUE(aligned_shape); EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); } diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc index b88fe367d7416a26c1147fd5e10fb20772814fe5..aa7238f07d432aabb44d2cbed66786217e6a846c 100644 --- a/tensorflow/compiler/xla/status_macros.cc +++ b/tensorflow/compiler/xla/status_macros.cc @@ -25,6 +25,13 @@ limitations under the License. namespace xla { namespace status_macros { +ABSL_CONST_INIT const char kPossibleAutoJitAlternative[] = + "This error might be occurring with the use of xla.compile. If it is not " + "necessary that every Op be compiled with XLA, an alternative is to use " + "auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment " + "variable TF_XLA_FLAGS=\"tf_xla_auto_jit=2\" which will attempt to use xla " + "to compile as much of the graph as the compiler is able to."; + static Status MakeStatus(tensorflow::error::Code code, const string& message) { return Status(code, message); } diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h index e51dd64e2a3dc7c359918cb08c6c94b2b4d9e91b..315136acc71670fa3ad48da4dc064e384ddadaa9 100644 --- a/tensorflow/compiler/xla/status_macros.h +++ b/tensorflow/compiler/xla/status_macros.h @@ -30,6 +30,10 @@ limitations under the License. namespace xla { namespace status_macros { +// This is a useful error message when encountering XLA Compiler errors that +// could be handled with the non-strict AutoJit mode. +extern const char kPossibleAutoJitAlternative[]; + // Stream object used to collect error messages in MAKE_ERROR macros // or append error messages with APPEND_ERROR. 
It accepts any // arguments with operator<< to build an error string, and then has an diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 4e7480c9af46beb95f49f3db4ff764326e5e9882..ab875c1c4bc7464e4b397838a6ad580eb18cdcc6 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -317,7 +317,12 @@ xla_test( name = "conv_depthwise_backprop_filter_test", timeout = "long", srcs = ["conv_depthwise_backprop_filter_test.cc"], - shard_count = 1, + # these backends do not natively handle batch group counts. + blacklisted_backends = [ + "gpu", + "cpu", + ], + shard_count = 6, deps = [ "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", @@ -547,6 +552,7 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], + shard_count = 2, deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -669,23 +675,21 @@ xla_test( ) xla_test( - name = "exhaustive_f32_elementwise_op_test", - size = "enormous", - srcs = ["exhaustive_f32_elementwise_op_test.cc"], - backends = [ - "cpu", - "gpu", - ], + name = "exhaustive_op_test", + srcs = ["exhaustive_op_test.cc"], + real_hardware_only = True, # Very slow on the interpreter. shard_count = 48, tags = [ - "broken", - "manual", - "notap", + "optonly", + # This is a big test that we skip for capacity reasons in OSS testing. 
+ "no_oss", ], deps = [ ":client_library_test_base", ":literal_test_util", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/base", @@ -730,6 +734,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -788,6 +793,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -832,6 +838,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1141,7 +1148,7 @@ xla_test( xla_test( name = "reduce_test", srcs = ["reduce_test.cc"], - shard_count = 40, + shard_count = 31, tags = [ "optonly", ], @@ -1389,8 +1396,8 @@ xla_test( ) xla_test( - name = "fmax_test", - srcs = ["fmax_test.cc"], + name = "fmax_fmin_test", + srcs = ["fmax_fmin_test.cc"], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1595,6 +1602,39 @@ xla_test( ], ) +xla_test( + name = "multi_device_all_reduce_test", + srcs = 
["multi_device_all_reduce_test.cc"], + backends = ["gpu"], + tags = [ + "manual", + "multi_gpu", + "no_oss", + "notap", + ], + deps = [ + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], @@ -1783,7 +1823,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", - "//tensorflow/compiler/xla/service/gpu:gpu_compiler", + "//tensorflow/compiler/xla/service/gpu:nvptx_compiler", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/stream_executor", @@ -2155,3 +2195,46 @@ xla_test( "//tensorflow/compiler/xla:test", ], ) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + tags = [ + "enable_for_xla_interpreter", + "noasan", # sometimes times out, http://b/78650012 + ], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + 
"//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "cholesky_test", + srcs = ["cholesky_test.cc"], + tags = ["optonly"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 7379fbcc22745f46f2a29732c4bda46f352d07e7..21458b40b10858599f15d1fcaf5acb617eafae0e 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -350,9 +349,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { error_spec_); } -// TODO(b/119692968): This test runs OOM on the GPU and CPU backend. 
-XLA_TEST_F(ArrayElementwiseOpTest, - DISABLED_ON_GPU(DISABLED_ON_CPU(DeeplyNestedAddWithSlices))) { +XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) { XlaBuilder builder(TestName()); std::vector values(30, 0.0); auto a_literal = LiteralUtil::CreateR1(values); diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 05d4d04034bf50c8bb840e59b28a590fce048c19..c14d279ac560db33066ae4fc68b6290f7499bb39 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -34,6 +34,7 @@ def xla_test( xla_test_library_deps = [], backends = [], blacklisted_backends = [], + real_hardware_only = False, args = [], tags = [], copts = [], @@ -108,6 +109,10 @@ def xla_test( use for that target. **kwargs: Additional keyword arguments to pass to native.cc_test. """ + + # All of the backends in all_backends are real hardware. + _ignore = [real_hardware_only] + test_names = [] if not backends: backends = all_backends diff --git a/tensorflow/compiler/xla/client/lib/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc similarity index 54% rename from tensorflow/compiler/xla/client/lib/cholesky_test.cc rename to tensorflow/compiler/xla/tests/cholesky_test.cc index 095dd4fbf8b7c90047c4428b50c626c16e9c1e94..272d5784362dd347061e7178ff48f9fab4ffd822 100644 --- a/tensorflow/compiler/xla/client/lib/cholesky_test.cc +++ b/tensorflow/compiler/xla/tests/cholesky_test.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/cholesky.h" - +#include #include #include #include @@ -32,27 +31,27 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" +namespace xla { namespace { -using xla::int64; - -using CholeskyTest = xla::ClientLibraryTestBase; +using CholeskyTest = ClientLibraryTestBase; -XLA_TEST_F(CholeskyTest, Simple) { - xla::XlaBuilder builder(TestName()); +XLA_TEST_F(CholeskyTest, Lower) { + XlaBuilder builder(TestName()); - xla::Array2D a_vals({ - {4, 6, 8, 10}, - {6, 45, 54, 63}, - {8, 54, 146, 166}, + float nan = std::numeric_limits::quiet_NaN(); + Array2D a_vals({ + {4, nan, nan, nan}, + {6, 45, nan, nan}, + {8, 54, 146, nan}, {10, 63, 166, 310}, }); - xla::XlaOp a; + XlaOp a; auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); - xla::Cholesky(a, /*block_size=*/2); + LowerTriangle(Cholesky(a, /*lower=*/true)); - xla::Array2D expected({ + Array2D expected({ {2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, @@ -60,34 +59,62 @@ XLA_TEST_F(CholeskyTest, Simple) { }); ComputeAndCompareR2(&builder, expected, {a_data.get()}, - xla::ErrorSpec(1e-4, 1e-4)); + ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, Upper) { + XlaBuilder builder(TestName()); + + float nan = std::numeric_limits::quiet_NaN(); + Array2D a_vals({ + {4, 6, 8, 10}, + {nan, 45, 54, 63}, + {nan, nan, 146, 166}, + {nan, nan, nan, 310}, + }); + + XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + UpperTriangle(Cholesky(a, /*lower=*/false)); + + Array2D expected({ + {2, 3, 4, 5}, + {0, 6, 7, 8}, + {0, 0, 9, 10}, + {0, 0, 0, 11}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + ErrorSpec(1e-4, 1e-4)); } XLA_TEST_F(CholeskyTest, Simple2) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::Array2D a_vals({ + Array2D a_vals({ {16, 24, 8, 12}, {24, 61, 82, 48}, {8, 82, 456, 106}, {12, 48, 106, 62}, }); - xla::XlaOp a; + XlaOp a; auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); - xla::Cholesky(a); + LowerTriangle(Cholesky(a, /*lower=*/true)); - 
xla::Array2D expected( - {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}); + Array2D expected({{4, 0, 0, 0}, // + {6, 5, 0, 0}, // + {2, 14, 16, 0}, // + {3, 6, 1, 4}}); ComputeAndCompareR2(&builder, expected, {a_data.get()}, - xla::ErrorSpec(1e-4, 1e-4)); + ErrorSpec(1e-4, 1e-4)); } XLA_TEST_F(CholeskyTest, SimpleBatched) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::Array3D a_vals({ + Array3D a_vals({ { {4, 6, 8, 10}, {6, 45, 54, 63}, @@ -102,65 +129,78 @@ XLA_TEST_F(CholeskyTest, SimpleBatched) { }, }); - xla::XlaOp a; + XlaOp a; auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); - xla::Cholesky(a); + LowerTriangle(Cholesky(a, /*lower=*/true)); - xla::Array3D expected({ + Array3D expected({ { {2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}, }, - {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}, + {{4, 0, 0, 0}, // + {6, 5, 0, 0}, // + {2, 14, 16, 0}, // + {3, 6, 1, 4}}, }); ComputeAndCompareR3(&builder, expected, {a_data.get()}, - xla::ErrorSpec(1e-4, 1e-4)); + ErrorSpec(1e-4, 1e-4)); } -using CholeskyTestCase = std::tuple; +using CholeskyTestCase = std::tuple; class RandomCholeskyTest - : public xla::ClientLibraryTestBase, + : public ClientLibraryTestBase, public ::testing::WithParamInterface {}; XLA_TEST_P(RandomCholeskyTest, Random) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); auto test_params = GetParam(); std::vector dimensions = {std::get<0>(test_params), std::get<1>(test_params), std::get<1>(test_params)}; - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions); + bool lower = std::get<2>(test_params); + Shape shape = ShapeUtil::MakeShape(F32, dimensions); TF_ASSERT_OK_AND_ASSIGN( - auto literal, - xla::LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + auto literal, LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); - auto input = xla::Parameter(&builder, 0, shape, "input"); + auto input = Parameter(&builder, 0, shape, "input"); // 
Form a random positive definite matrix. - auto matrix = xla::BatchDot(input, TransposeInMinorDims(input), - xla::PrecisionConfig::HIGHEST); + auto matrix = + BatchDot(input, TransposeInMinorDims(input), PrecisionConfig::HIGHEST); - auto cholesky = xla::Cholesky(matrix, /*block_size=*/4); + auto cholesky = Triangle(Cholesky(matrix, lower), lower); // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 - auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky), - xla::PrecisionConfig::HIGHEST); + XlaOp verification; + if (lower) { + verification = BatchDot(cholesky, TransposeInMinorDims(cholesky), + PrecisionConfig::HIGHEST); + } else { + verification = BatchDot(TransposeInMinorDims(cholesky), cholesky, + PrecisionConfig::HIGHEST); + } auto delta = matrix - verification; - xla::Reduce(delta * delta, xla::ConstantR0(&builder, 0.0), - CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2}); + Reduce(delta * delta, ConstantR0(&builder, 0.0), + CreateScalarAddComputation(F32, &builder), {0, 1, 2}); TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); ComputeAndCompareR0(&builder, 0.0, {input_data.get()}, - xla::ErrorSpec(1e-4, 1e-4)); + ErrorSpec(1e-4, 1e-4)); } INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest, - ::testing::Values(CholeskyTestCase{1, 1}, - CholeskyTestCase{1, 2}, - CholeskyTestCase{10, 5}, - CholeskyTestCase{2, 20})); + ::testing::Values(CholeskyTestCase{1, 1, true}, + CholeskyTestCase{1, 2, true}, + CholeskyTestCase{1, 50, true}, + CholeskyTestCase{1, 50, false}, + CholeskyTestCase{10, 5, true}, + CholeskyTestCase{5, 10, false}, + CholeskyTestCase{2, 20, true})); } // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index edb95c973b70e30702ed8490c15a48d4d5604170..0e99ede5d01fcfa88c54c9cbc5a6a85bf8f15ddf 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc 
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -41,8 +41,9 @@ constexpr char kInterpreter[] = "interpreter"; // Wrapper function that creates a nicer error message (than a bare // ValueOrDie()) if the platform we intend to test is not available. -Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { - StatusOr result = +LocalClient* GetOrCreateLocalClientOrDie( + const LocalClientOptions& client_options) { + StatusOr result = ClientLibrary::GetOrCreateLocalClient(client_options); TF_CHECK_OK(result.status()) << " could not create local client for testing"; return result.ValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 3f65ed7fce4ff4b5c3781ac2581935bfacc69ce1..d700437ed355c144639f76d683055e211975fde9 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -385,8 +385,8 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr> ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments); - Client* client_; - Client* ref_client_; // To compute reference result. + LocalClient* client_; + LocalClient* ref_client_; // To compute reference result. ExecutionOptions execution_options_; private: diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index 32cac499c7439af80bafb88ac61b0b078f589599..f75c3fb01e2c854475537ca4b413f381cf74355c 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -169,6 +170,11 @@ class ConditionalOpTest : public ClientLibraryTestBase { ErrorSpec error_spec_{0.001}; }; +// Test fixture to run indexed conditional (switch/case) tests with varying +// number of branches. +class CaseOpTest : public ConditionalOpTest, + public ::testing::WithParamInterface {}; + // Test true and false computations that do not take any parameters. XLA_TEST_F(ConditionalOpTest, Parameters0) { XlaBuilder builder(TestName()); @@ -182,6 +188,36 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) { ComputeAndCompareR0(&builder, 56.0f, {pred_arg.get()}, error_spec_); } +// Test branch computations that do not take any parameters. +XLA_TEST_P(CaseOpTest, Parameters0) { + int num_branches = GetParam(); + for (int bi = -1; bi <= num_branches; ++bi) { + SCOPED_TRACE(bi); + XlaBuilder builder(TestName()); + XlaOp branch_index; + auto branch_index_arg = CreateR0Parameter(bi, 0, "branch_index_arg", + &builder, &branch_index); + auto operand = Tuple(&builder, {}); + + std::vector operands(num_branches, operand); + std::vector branches; + branches.reserve(num_branches); + std::vector branches_p(num_branches); + for (int i = 0; i < num_branches; ++i) { + branches.emplace_back( + CreateR0ConstantComputation(static_cast(i) * 10)); + branches_p[i] = &branches[i]; + } + Conditional(branch_index, branches_p, operands); + + float expected = 10 * static_cast((bi < 0 || bi >= num_branches) + ? num_branches - 1 + : bi); + ComputeAndCompareR0(&builder, expected, {branch_index_arg.get()}, + error_spec_); + } +} + // Test true and false computations that take in 1 parameter. 
XLA_TEST_F(ConditionalOpTest, Parameters1) { XlaBuilder builder(TestName()); @@ -195,6 +231,45 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) { ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } +// Test branch computations that take in 1 parameter. +XLA_TEST_P(CaseOpTest, Parameters1) { + int num_branches = GetParam(); + for (int bi = -1; bi <= num_branches; ++bi) { + SCOPED_TRACE(bi); + XlaBuilder builder(TestName()); + XlaOp branch_index; + auto branch_index_arg = CreateR0Parameter(bi, 0, "branch_index_arg", + &builder, &branch_index); + + auto make_branch = [&builder, this](int i) { + auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i)); + Add(ConstantR0(sb.get(), static_cast(i)), + Parameter(sb.get(), 0, r0f32_, "p0")); + return sb->BuildAndNoteError(); + }; + std::vector branches; + branches.reserve(num_branches); + std::vector branches_p(num_branches); + std::vector operands; + operands.reserve(num_branches); + std::vector expecteds(num_branches); + for (int i = 0; i < num_branches; ++i) { + branches.emplace_back(make_branch(i)); + branches_p[i] = &branches[i]; + auto fi = static_cast(i); + operands.emplace_back(ConstantR0(&builder, 10 * fi + 7)); + expecteds[i] = 10 * fi + 7 + fi; + } + + Conditional(branch_index, branches_p, operands); + float expected = (bi < 0 || bi >= num_branches) + ? expecteds[num_branches - 1] + : expecteds[bi]; + ComputeAndCompareR0(&builder, expected, {branch_index_arg.get()}, + error_spec_); + } +} + // Test conditional with two different computations in the true and false cases // that take in different arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { @@ -331,6 +406,46 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { error_spec_); } +// Test branch computations that take in 2 array parameters. 
+XLA_TEST_P(CaseOpTest, Parameters2Array) { + int num_branches = GetParam(); + for (int bi = -1; bi <= num_branches; ++bi) { + SCOPED_TRACE(bi); + XlaBuilder builder(TestName()); + XlaOp branch_index; + auto branch_index_arg = + CreateR0Parameter(bi, 0, "pred", &builder, &branch_index); + auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); + auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); + auto operands = Tuple(&builder, {operand1, operand2}); + auto make_branch = [&builder, this](int i) { + auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i)); + auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0"); + Add(Mul(ConstantR0(sb.get(), static_cast(i)), + GetTupleElement(p, 0)), + GetTupleElement(p, 1)); + return sb->BuildAndNoteError(); + }; + std::vector branches; + branches.reserve(num_branches); + std::vector branches_p(num_branches); + for (int i = 0; i < num_branches; ++i) { + branches.emplace_back(make_branch(i)); + branches_p[i] = &branches[i]; + } + Conditional(branch_index, branches_p, + std::vector(num_branches, operands)); + auto modified_bi = static_cast( + (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi); + ComputeAndCompareR1( + &builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11}, + {branch_index_arg.get()}, error_spec_); + } +} + +INSTANTIATE_TEST_SUITE_P(CaseOpTest_Instantiation, CaseOpTest, + ::testing::Values(1, 2, 3, 4, 5)); + // Test true and false computations that take in 2 array parameters and // predicate is false. 
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { @@ -582,8 +697,8 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { auto result = builder.Build(); EXPECT_FALSE(result.ok()); EXPECT_THAT(result.status().error_message(), - ::testing::HasSubstr("true_operand must match the shape of the " - "only parameter of true_computation")); + ::testing::HasSubstr("operand 0 must match the shape of the " + "only parameter of branch computation 0")); } XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 9174f2651cb90b364f869364fe108cf208c11a84..6530007871ced1d0bbffe2b44ccc8cf9bddd79e1 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -180,6 +181,29 @@ TEST_F(ConstantsTest, Token) { TF_ASSERT_OK(Execute(&builder, {}).status()); } +TEST_F(ConstantsTest, FullLike) { + XlaBuilder b(TestName()); + auto val1 = Iota(&b, F32, 3); + auto val2 = FullLike(val1, 10); + val1 + val2; + ComputeAndCompareR1(&b, {10, 11, 12}, {}, error_spec_); +} + +TEST_F(ConstantsTest, IllegalFullLikeOnTuple) { + XlaBuilder b(TestName()); + auto tuple = Tuple(&b, {Iota(&b, F32, 3), Iota(&b, F32, 1)}); + FullLike(tuple, 10); // Illegal; can't do FullLike on a tuple. 
+ EXPECT_FALSE(b.Build().ok()); +} + +TEST_F(ConstantsTest, FullLikeScalar) { + XlaBuilder b(TestName()); + auto scalar1 = ConstantR0WithType(&b, F32, 1); + auto scalar2 = FullLike(scalar1, 2); + scalar1 - scalar2; + ComputeAndCompareR0(&b, -1, {}, error_spec_); +} + class ConstantsHloTest : public HloTestBase {}; // TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior. @@ -200,9 +224,7 @@ XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) { ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR0(1); auto result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal(param, result)); diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc index 90c197140359d0021d08931b73f221d659e71144..dfbf0478e62713635446d11557367cfac6ab0dce 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -32,25 +32,26 @@ string GetFloatDataType(bool use_bfloat16) { return use_bfloat16 ? 
"bf16" : "f32"; } -struct DepthwiseConvolution2DSpec { - int64 output_batch, window; +struct BatchGroupedConvolution2DSpec { + int64 output_batch, window, window_dilation; std::vector activation_dims; - std::vector activation_layout; std::vector kernel_dims; - std::vector kernel_layout; std::vector output_dims; + std::vector activation_and_kernel_layout; std::vector output_layout; }; -class DepthwiseConvolution2DTest +class BatchGroupedConvolution2DTest : public HloTestBase, public ::testing::WithParamInterface< - ::testing::tuple> {}; + ::testing::tuple> {}; -static std::vector GetConv2DTestCases() { - std::vector config_set; +static std::vector GetConv2DTestCases() { + std::vector config_set; std::vector> config_options = { - {16, 5, 5, 2}, {64, 4, 4, 16}, {2, 5, 5, 256}}; + {8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128}, + {16, 20, 20, 256}, {256, 7, 5, 4}, {256, 6, 6, 4}, + {256, 8, 8, 512}, {64, 7, 7, 960}, {64, 14, 14, 576}}; for (auto option : config_options) { int64 feature = option[3]; @@ -58,93 +59,120 @@ static std::vector GetConv2DTestCases() { int64 kernel_size = option[2]; int64 batch = option[0]; - std::vector kernel_layout = {3, 2, 1, 0}; - DepthwiseConvolution2DSpec config; + BatchGroupedConvolution2DSpec config; + config.window_dilation = 1; config.output_batch = feature; config.window = kernel_size; config.activation_dims = {batch, activation_size, activation_size, feature}; - config.activation_layout = {0, 3, 2, 1}; config.kernel_dims = {batch, kernel_size, kernel_size, feature}; - config.kernel_layout = {0, 2, 3, 1}; - config.output_dims = {3, 3, feature, 1}; + int64 output_space_size = 3 + activation_size - kernel_size; + config.output_dims = {output_space_size, output_space_size, feature, 1}; - // Try this layout for all kernel shapes. 
- config.output_layout = {3, 2, 0, 1}; + config.activation_and_kernel_layout = {0, 3, 1, 2}; + config.output_layout = {2, 3, 0, 1}; config_set.push_back(config); + + BatchGroupedConvolution2DSpec different_layout_config = config; + different_layout_config.activation_and_kernel_layout = {3, 0, 1, 2}; + config_set.push_back(different_layout_config); + + // Add configurations for window dilation cases. + if (activation_size % 2 == 0 && activation_size == kernel_size) { + BatchGroupedConvolution2DSpec config; + config.window_dilation = 2; + config.output_batch = feature; + config.window = kernel_size / 2; + config.activation_dims = {batch, activation_size, activation_size, + feature}; + config.kernel_dims = {batch, kernel_size / 2, kernel_size / 2, feature}; + config.activation_and_kernel_layout = {0, 3, 1, 2}; + config.output_layout = {2, 3, 0, 1}; + + int64 output_space_size = 5; + config.output_dims = {output_space_size, output_space_size, feature, 1}; + + config_set.push_back(config); + + BatchGroupedConvolution2DSpec different_layout_config = config; + different_layout_config.activation_and_kernel_layout = {3, 0, 1, 2}; + config_set.push_back(different_layout_config); + } } return config_set; } -string DepthwiseConvolution2DTestDataToString( +string BatchGroupedConvolution2DTestDataToString( const ::testing::TestParamInfo< - ::testing::tuple>& data) { + ::testing::tuple>& data) { const auto& spec = ::testing::get<0>(data.param); const string data_type = GetFloatDataType(::testing::get<1>(data.param)); string str = absl::StrCat( "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), - "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), - "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", - absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", - absl::StrJoin(spec.output_dims, "x"), "_output_layout_", - absl::StrJoin(spec.output_layout, "_"), data_type); + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), + 
"_activation_layout_", + absl::StrJoin(spec.activation_and_kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), data_type, "_output_layout_", + absl::StrJoin(spec.output_layout, "_")); // Test names are not allowed to contain the '-' character. absl::c_replace(str, '-', 'n'); return str; } -string BuildHloTextDepthwiseConvolution2D( - const DepthwiseConvolution2DSpec& spec, bool use_bfloat16) { +string BuildHloTextBatchGroupedConvolution2D( + const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16) { const string data_type = GetFloatDataType(use_bfloat16); return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv + HloModule TensorFlowDepthwiseConv, is_scheduled=true ENTRY main { activation = %s[%s]{%s} parameter(0) kernel = %s[%s]{%s} parameter(1) ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), - window={size=%dx%d pad=1_1x1_1}, dim_labels=f01b_i01o->01fb, + window={size=%dx%d pad=1_%dx1_%d rhs_dilate=%dx%d}, dim_labels=f01b_i01o->01fb, batch_group_count=%d } )", data_type, absl::StrJoin(spec.activation_dims, ","), - absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, absl::StrJoin(spec.kernel_dims, ","), - absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, absl::StrJoin(spec.output_dims, ","), absl::StrJoin(spec.output_layout, ","), data_type, absl::StrJoin(spec.activation_dims, ","), - absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, absl::StrJoin(spec.kernel_dims, ","), - absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, - spec.output_batch); + absl::StrJoin(spec.activation_and_kernel_layout, ","), spec.window, + spec.window, spec.window_dilation, spec.window_dilation, + spec.window_dilation, spec.window_dilation, spec.output_batch); } 
-XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { - const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam()); +XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) { + const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam()); bool use_bfloat16 = ::testing::get<1>(GetParam()); const string hlo_text = - BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, - [](HloModule* module) -> Status { - BFloat16MixedPrecisionRemoval remover; - TF_RETURN_IF_ERROR(remover.Run(module).status()); - Despecializer despecializer; - return despecializer.Run(module).status(); - })); + BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompareNoHloPasses( + hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); } INSTANTIATE_TEST_CASE_P( - DepthwiseConvolution2DTestWithRandomIndices, DepthwiseConvolution2DTest, + BatchGroupedConvolution2DTestWithRandomIndices, + BatchGroupedConvolution2DTest, ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), ::testing::Bool()), - DepthwiseConvolution2DTestDataToString); + BatchGroupedConvolution2DTestDataToString); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 9db9f2563b636c4f929585eb13a9c7f809833eda..cfee9c0f8a4c908d5dbdd5345ed7f839dfa4dee2 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -1945,7 +1945,7 @@ XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { class ConvolutionHloTest : public HloTestBase {}; -XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64Forward)) { +XLA_TEST_F(ConvolutionHloTest, ConvolveF64Forward) { constexpr 
char kHlo[] = R"( HloModule TestModule @@ -1957,7 +1957,7 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } -XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF32ForwardReversed)) { +XLA_TEST_F(ConvolutionHloTest, ConvolveF32ForwardReversed) { constexpr char kHlo[] = R"( HloModule TestModule @@ -1969,7 +1969,7 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } -XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { +XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardFilter) { constexpr char kHlo[] = R"( HloModule TestModule @@ -1981,7 +1981,7 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } -XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardInput)) { +XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardInput) { constexpr char kHlo[] = R"( HloModule TestModule diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index cad43d1b5547d74701760fa623e50466fc15c263..4687ed61a7de91bc1bce0efeadf1965ad7d52d55 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -172,8 +172,10 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) { const Shape& r2f32_dim0_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); - b.AddInstruction(HloInstruction::CreateCustomCall( + auto custom_call = b.AddInstruction(HloInstruction::CreateCustomCall( r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); + b.AddInstruction( + custom_call->CloneWithNewOperands(r2f32_dim0_major, {custom_call})); module->AddEntryComputation(b.Build()); ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); @@ -182,7 +184,7 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) { Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); Literal result = ExecuteAndTransfer(std::move(module), {&argument}); - LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 
5.f}}, result); + LiteralTestUtil::ExpectR2Equal({{3.f, 4.f}, {5.f, 6.f}}, result); } XLA_TEST_F(CustomCallTest, TupleOutput) { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 33d816292c5c391d2ca5a0ebaf4c80e9cc0dc88c..414d0b14a6b4f0307851fcc717c5e8a74a33782b 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -1157,6 +1158,53 @@ XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { ComputeAndCompareR2(&builder, expected, {}, error_spec_); } +using EinsumParamType = + std::tuple, std::vector, string>; +class EinsumTest : public DotOperationTest, + public ::testing::WithParamInterface {}; +XLA_TEST_P(EinsumTest, SimpleEinsumTest) { + XlaBuilder builder(TestName()); + auto x = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam()))) + .ValueOrDie(), + &builder); + auto y = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam()))) + .ValueOrDie(), + &builder); + Einsum(x, y, std::get<2>(GetParam())); + ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3}); +} + +std::vector GetEinsumTestCases() { + using v = std::vector; + using p = EinsumParamType; + std::vector

test_cases = { + p{v{5, 6}, v{6, 7}, "mk,kn->mn"}, + p{v{5, 6}, v{6, 7}, "mk,kn->nm"}, + p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn->nmB"}, + p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->nmB"}, + p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->Bnm"}, + p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"}, + p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, + p{v{6}, v{6, 7}, "b,bc->c"}, + p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"}, + p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"}, + p{v{77}, v{77}, "a,a->a"}, + p{v{77}, v{77, 55}, "a,ab->ba"}, + p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"}, + p{v{55}, v{}, "a,->a"}, + p{v{11, 111}, v{11}, "ab,a->ab"}, + p{v{16, 34}, v{16, 34}, "ab,ab->ab"}, + p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac->abc"}, + p{v{5, 19}, v{}, "ab,->ab"}, + }; + return test_cases; +} + +INSTANTIATE_TEST_CASE_P(Einsum, EinsumTest, + ::testing::ValuesIn(GetEinsumTestCases())); + class DotOperationTextTest : public HloTestBase {}; XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) { @@ -1189,5 +1237,129 @@ ENTRY %test { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); } +XLA_TEST_F(DotOperationTextTest, DotWithNoDnums) { + absl::string_view hlo_string = + R"( +HloModule DotWithNoDnums + +ENTRY %test { + %lhs = f32[2,3]{1,0} parameter(0) + %rhs = f32[4,5]{1,0} parameter(1) + ROOT %dot = f32[2,3,4,5]{3,2,1,0} dot(%lhs, %rhs) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, Einsum) { + absl::string_view hlo_string = + R"( +HloModule Einsum + +ENTRY %test { + %lhs = f32[8,64,96]{2,1,0} parameter(0) + %rhs = f32[96,32,4]{2,1,0} parameter(1) + ROOT %dot = f32[8,64,32,4]{3,2,1,0} dot(%lhs, %rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) { + // Tests for a caching bug in the XLA CPU backend. 
+ absl::string_view hlo_string = + R"( +HloModule CpuTiledDotEmitterCachingBug + +ENTRY main { + lhs = f32[20,40] parameter(0) + rhs_0 = f32[40,1] parameter(2) + rhs_1 = f32[1,40] parameter(1) + + dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + + ROOT result = f32[20,1] divide(dot_0, dot_1) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) { + // Tests for a caching bug in the XLA CPU backend. + absl::string_view hlo_string = + R"( +HloModule CpuTiledDotEmitterCachingBug + +ENTRY main { + lhs_0 = f32[20,40] parameter(0) + rhs_0 = f32[40,1] parameter(1) + lhs_1 = f32[1,40] parameter(2) + rhs_1 = f32[20,40] parameter(3) + + dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + + dot_0_reshaped = f32[20] reshape(dot_0) + dot_1_reshaped = f32[20] reshape(dot_1) + + ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(GpuIntegerDotCodegen)) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = s32[1,2,2] parameter(0) + arg1 = s32[1,2,1] parameter(1) + ROOT dot = s32[1,2,1] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(GpuTransposeOutput)) { + absl::string_view hlo_string = + R"( +HloModule TransposeOutput + +ENTRY TransposeOutput { + p0 = f32[32,32] parameter(0) + p1 = f32[32,64] parameter(1) + dot = f32[32,64] dot(p0, p1), lhs_contracting_dims={0}, 
rhs_contracting_dims={0} + ROOT tr = f32[64,32] transpose(dot), dimensions={1,0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, MatrixVectorComplex) { + absl::string_view hlo_string = + R"( +HloModule MatrixVectorComplex + +ENTRY MatrixVectorComplex { + p0 = c64[5,5] parameter(0) + p1 = c64[5,1] parameter(1) + p2 = c64[5,1] parameter(2) + dot = c64[5,1] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT add = c64[5,1] add(dot, p2) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc deleted file mode 100644 index c84973e17b234c24c84f02a369ce0185f5772cca..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "absl/base/casts.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" - -namespace xla { -namespace { -class ExhaustiveF32ElementwiseOpTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface> { - protected: - ErrorSpec error_spec_{0.0001, 0.0001, /*relaxed_nans=*/true}; - - template - void ExhaustivelyTestF32Op(EnqueueOpTy enqueue_op, - float (*evaluate_op)(float), - std::pair known_incorrect_range) { - int64 begin, end; - std::tie(begin, end) = GetParam(); - int64 input_size = end - begin; - LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; - - XlaBuilder builder(TestName()); - - Literal input_literal = - LiteralUtil::CreateFromDimensions(F32, {input_size}); - for (int64 i = begin; i < end; i++) { - if (i >= known_incorrect_range.first && - i < known_incorrect_range.second) { - // If the operation is known to be buggy on a specific input clamp that - // input to 0 under the assumption that the op is at least correct on 0. 
- input_literal.Set({i - begin}, 0.0f); - } else { - input_literal.Set({i - begin}, absl::bit_cast(i)); - } - } - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(input_literal)); - - auto input = Parameter(&builder, 0, input_literal.shape(), "input"); - enqueue_op(&builder, input); - - std::vector expected_result; - expected_result.reserve(input_size); - for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(evaluate_op(input_literal.Get({i}))); - } - - ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, - error_spec_); - } -}; - -XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { -#ifdef XLA_TEST_BACKEND_CPU - // TODO(b/73141998): The vectorized Log implementation gives results outside - // our error spec in this range (these numbers are bitwise representations of - // floats expressed as a zero extended int64). - std::pair known_incorrect_range = {1, 8388608}; -#else - std::pair known_incorrect_range = {0, 0}; -#endif - - ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { Log(input); }, std::log, - known_incorrect_range); -} - -XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { -#ifdef XLA_TEST_BACKEND_CPU - // TODO(b/73142289): The vectorized Exp implementation gives results outside - // our error spec in this range (these numbers are bitwise representations of - // floats expressed as a zero extended int64): - std::pair known_incorrect_range = {1107296256 + 11583654, - 1107296256 + 11629080}; -#else - std::pair known_incorrect_range = {0, 0}; -#endif - - ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { Exp(input); }, std::exp, - known_incorrect_range); -} - -XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { - ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { Tanh(input); }, std::tanh, - /*known_incorrect_range=*/{0, 0}); -} - -std::vector> CreateExhaustiveParameters() { - // We break up the 2^32-element space into small'ish 
chunks to keep peak - // memory usage low. - std::vector> result; - const int64 step = 1 << 25; - for (int64 i = 0; i < (1l << 32); i += step) { - result.push_back({i, i + step}); - } - return result; -} - -INSTANTIATE_TEST_CASE_P(ExhaustiveF32ElementwiseOpTestInstance, - ExhaustiveF32ElementwiseOpTest, - ::testing::ValuesIn(CreateExhaustiveParameters())); -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..58bb9a217b805a142869149c19d7bcfc91a1aee1 --- /dev/null +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc @@ -0,0 +1,646 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include "absl/base/casts.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +using Eigen::half; + +template +T EvaluatePolynomial(T x, const std::array& coeffs) { + T result = 0; + for (T c : coeffs) { + result = result * x + c; + } + return result; +} + +// There's no std::erfinv, so we have to implement it ourselves. This follows +// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a +// different implementation from that in math.cc. +float HostErfInv(float x) { + std::array kPolyA = { + 8.8709406962545514830200e2, 1.1819493347062294404278e4, + 2.3782041382114385731252e4, 1.6235862515167575384252e4, + 4.8548868893843886794648e3, 6.9706266534389598238465e2, + 4.7072688112383978012285e1, 1.1975323115670912564578e0, + }; + std::array kPolyB = { + 5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4, + 2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2, + 4.2313330701600911252e1, 1.0000000000000000000e0, + }; + std::array kPolyC = { + 7.74545014278341407640e-4, 2.27238449892691845833e-2, + 2.41780725177450611770e-1, 1.27045825245236838258e0, + 3.64784832476320460504e0, 5.76949722146069140550e0, + 4.63033784615654529590e0, 1.42343711074968357734e0, + }; + std::array kPolyD = { + 1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4, + 2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1, + 9.7547832001787427186894837e-1, 2.3707661626024532365971225e0, + 2.9036514445419946173133295e0, 1.4142135623730950488016887e0, + }; + std::array kPolyE = { + 
2.01033439929228813265e-7, 2.71155556874348757815e-5, + 1.24266094738807843860e-3, 2.65321895265761230930e-2, + 2.96560571828504891230e-1, 1.78482653991729133580e0, + 5.46378491116411436990e0, 6.65790464350110377720e0, + }; + std::array kPolyF = { + 2.891024605872965461538222e-15, 2.010321207683943062279931e-7, + 2.611088405080593625138020e-5, 1.112800997078859844711555e-3, + 2.103693768272068968719679e-2, 1.936480946950659106176712e-1, + 8.482908416595164588112026e-1, 1.414213562373095048801689e0, + }; + + if (std::abs(x) > 1 || std::isnan(x)) { + return std::numeric_limits::quiet_NaN(); + } + if (std::abs(x) == 1) { + return std::copysign(std::numeric_limits::infinity(), x); + } + + float unsigned_result = [&] { + float y = std::abs(x); + if (y <= 0.85) { + double r = 0.180625 - 0.25 * y * y; + return (y * EvaluatePolynomial(r, kPolyA)) / + EvaluatePolynomial(r, kPolyB); + } else { + double r = std::sqrt(std::log(2.0) - std::log1p(-y)); + if (r <= 5.0) { + r -= 1.6; + return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD); + } else { + r -= 5; + return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF); + } + } + }(); + return std::copysign(unsigned_result, x); +} + +// Digamma implementation using a polynomial from Cephes. Notably this is a +// different implementation from the one in math.cc. +float HostDigamma(float x) { + // Euler-Mascheroni constant + float kGamma = 0.57721566490153286061; + float kPi = M_PI; + + std::array kPoly = { + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + float reflection = 0; + if (x <= 0) { + float floor = std::floor(x); + if (x == floor) { + return std::numeric_limits::quiet_NaN(); + } + // Compute reflection term, pi * cot(pi * x). 
+ reflection = x - floor; + if (reflection == 0.5) { + reflection = 0; + } else { + if (reflection > 0.5) { + reflection = x - (floor + 1.0f); + } + reflection = kPi / std::tan(kPi * reflection); + } + x = 1 - x; + } + + float result = 0; + if (x <= 10 && x == std::floor(x)) { + // Special case for integers <= 10. + for (int i = 1; i < x; ++i) { + result += 1.0f / i; + } + result -= kGamma; + } else { + float w = 0; + for (; x < 10; ++x) { + w += 1.0f / x; + } + if (x < 1e8) { + float z = 1.0f / (x * x); + result = z * EvaluatePolynomial(z, kPoly); + } + result = std::log(x) - 0.5f / x - result - w; + } + + // Compute the final, reflected value. + return result - reflection; +} + +// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be +// guaranteed that we're printing the full number. +// +// (The general formula is, given a floating-point number with S significand +// bits, the number of decimal digits needed to print it to full precision is +// +// ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103). +// +// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.) +string StringifyNum(float x) { + return absl::StrFormat("%0.9g (0x%08x)", x, absl::bit_cast(x)); +} + +string StringifyNum(half x) { + return absl::StrFormat("%0.5g (0x%04x)", static_cast(x), + absl::bit_cast(x)); +} + +string StringifyNum(bfloat16 x) { + return absl::StrFormat("%0.4g (0x%04x)", static_cast(x), + absl::bit_cast(x)); +} + +// Test parameter is a tuple containing +// - primitive type under test, +// - (begin, end) range under test, as zero-extended int64s bitcast to the +// primtive type under test. +class ExhaustiveOpTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple>> { + public: + ExhaustiveOpTest() + : ty_(std::get<0>(GetParam())), platform_(client_->platform()->Name()) {} + + void Run(std::function enqueue_op, + float (*evaluate_op)(float)) { + SetFastMathDisabled(true); + + // Run all HLO passes. 
In particular, constant folding is disabled by + // default for tests, but we need to run it in order to tickle some bugs. + mutable_debug_options()->clear_xla_disable_hlo_passes(); + + PrimitiveType ty; + std::tie(ty, std::ignore) = GetParam(); + + switch (ty) { + case F32: + SetDefaultErrSpec(0.0001, 0.0001); + RunImpl(enqueue_op, evaluate_op); + break; + case F16: + SetDefaultErrSpec(0.001, 0.001); + RunImpl(enqueue_op, evaluate_op); + break; + case BF16: + SetDefaultErrSpec(0.001, 0.01); + RunImpl(enqueue_op, evaluate_op); + break; + default: + LOG(FATAL) << "Unhandled type."; + } + } + + void SetDefaultErrSpec(float abs_err, float rel_err) { + if (!abs_err_.has_value()) { + abs_err_ = abs_err; + } + if (!rel_err_.has_value()) { + rel_err_ = rel_err; + } + } + + template + void RunImpl(std::function enqueue_op, + float (*evaluate_op)(float)) { + static_assert( + sizeof(T) == sizeof(IntegralT), + "IntegralT must be an unsigned integer type of the same width as T."); + + PrimitiveType ty; + std::pair test_range; + std::tie(ty, test_range) = GetParam(); + int64 begin, end; + std::tie(begin, end) = test_range; + + if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) { + LOG(INFO) << absl::StreamFormat( + "Skipping this shard, as the range under test, [%d, %d), falls " + "entirely within the known-incorrect range [%d, %d).", + begin, end, known_incorrect_begin_, known_incorrect_end_); + return; + } + + LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; + + int64 input_size = end - begin; + Literal input_literal = LiteralUtil::CreateFromDimensions(ty, {input_size}); + absl::Span input_arr = input_literal.data(); + for (int64 i = 0; i < input_size; i++) { + IntegralT input_val = i + begin; + // If the operation is known to be buggy on a specific input clamp that + // input to 0 under the assumption that the op is at least correct on 0. 
+ if (input_val >= known_incorrect_begin_ && + input_val < known_incorrect_end_) { + input_arr[i] = T{0}; + } else { + input_arr[i] = absl::bit_cast(input_val); + } + } + + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + BuildAndRunComputation(enqueue_op, input_literal)); + ExpectNear(input_literal, result_literal, evaluate_op); + } + + StatusOr BuildAndRunComputation( + const std::function& enqueue_op, + const Literal& input_literal) { + XlaBuilder builder(TestName()); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); + enqueue_op(input); + TF_ASSIGN_OR_RETURN(XlaComputation comp, builder.Build()); + + // Build and run the computation using the LocalClient API, rather than the + // plain Client API, which is used by ClientLibraryTestBase. This is + // because the plain Client API results does more memcpys to/from Literals, + // and that's slow given that we're touching a lot of data here. + // + // Copy debug options from ClientLibraryTestBase. In particular, we're + // interested in disabling constant folding. 
+ ExecutableBuildOptions build_opts; + *build_opts.mutable_debug_options() = *mutable_debug_options(); + TF_ASSIGN_OR_RETURN( + auto executable, + client_->Compile(comp, {&input_literal.shape()}, build_opts)); + + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer input_data, + client_->LiteralToShapedBuffer(input_literal, /*device_ordinal=*/0)); + + ExecutableRunOptions run_opts; + run_opts.set_allocator(client_->backend().memory_allocator()); + run_opts.set_intra_op_thread_pool( + client_->backend().eigen_intra_op_thread_pool_device()); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + executable->Run({&input_data}, run_opts)); + + TF_ASSIGN_OR_RETURN(Literal result_literal, + client_->ShapedBufferToLiteral(result)); + return std::move(result_literal); + } + + template + bool IsClose(T expected, T actual) { + float expected_f32 = static_cast(expected); + float actual_f32 = static_cast(actual); + float abs_err = std::abs(expected_f32 - actual_f32); + float rel_err = abs_err / std::abs(expected_f32); + if (strict_signed_zeros_ && actual == T{0} && expected == T{0}) { + // Check sign of zero. + return std::signbit(actual_f32) == std::signbit(expected_f32); + } + return abs_err < *abs_err_ || rel_err < *rel_err_ || + (std::isnan(expected_f32) && std::isnan(actual_f32)) || + (std::isinf(expected_f32) && std::isinf(actual_f32) && + (expected_f32 > 0) == (actual_f32 > 0)); + } + + template + void ExpectNear(const Literal& input_literal, const Literal& result_literal, + float (*evaluate_op)(float)) { + // We essentially reimplement LiteralTestUtil::Near here because + // a) this streamlined implementation is much faster, and + // b) we can print out better error messages (namely, we can print out + // which floating-point value input failed, while LiteralTestUtil::Near + // can only print out the input index that failed). + // c) we need special handling of certain inputs. 
For example, we say that + // a denormal input has multiple correct outputs (namely, f(x) and f(0)) + // and just needs to be close to one of them. + absl::Span input_arr = input_literal.data(); + absl::Span result_arr = result_literal.data(); + ASSERT_EQ(result_arr.size(), input_arr.size()); + int64 mismatches = 0; + // Hoisting these out of the loop is a nice speedup on shards that have many + // denormals. + const T expected_at_pos_zero = static_cast(evaluate_op(0)); + const T expected_at_neg_zero = static_cast(evaluate_op(-0.0)); + for (int64 i = 0; i < input_arr.size(); ++i) { + T input = input_arr[i]; + float input_f32 = static_cast(input); + T actual = result_arr[i]; + T expected = static_cast(evaluate_op(input_f32)); + + if (IsClose(expected, actual)) { + continue; + } + + // Easy case: If `input` is not denormal and !IsClose(expected, actual), + // print an error. + // + // (This doesn't correctly detect f16 and bfloat16 denormals! This seems + // to be OK for now, but at some point we may need to implement fpclassify + // for half and bfloat.) + if (std::fpclassify(input_f32) != FP_SUBNORMAL) { + PrintMismatch(&mismatches, [&] { + return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.", + StringifyNum(input), StringifyNum(expected), + StringifyNum(actual)); + }); + continue; + } + + // Otherwise, `input` is denormal. For denormal inputs, we accept answers + // that are close to any of: + // + // - evaluate_op(input) + // - evaluate_op(+/-0), where the sign of 0 equal to the sign of + // `input`, + // - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of + // 0 is the opposite of `input`. + T sign_preserving_ftz_expected = + std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero; + T sign_nonpreserving_ftz_expected = + std::signbit(input_f32) ? 
expected_at_pos_zero : expected_at_neg_zero; + if (IsClose(sign_preserving_ftz_expected, actual) || + (relaxed_denormal_signs_ && + IsClose(sign_nonpreserving_ftz_expected, actual))) { + continue; + } + + if (relaxed_denormal_signs_) { + PrintMismatch(&mismatches, [&] { + return absl::StrFormat( + "Mismatch on denormal value %s. Expected one of:\n" + " %10s (evaluated at full-precision value)\n" + " %10s (evaluated after flushing to sign-preserving zero)\n" + " %10s (evaluated after flushing to non-sign-preserving " + "zero)\n" + "but got %s.", + StringifyNum(input), StringifyNum(expected), + StringifyNum(sign_preserving_ftz_expected), + StringifyNum(sign_nonpreserving_ftz_expected), + StringifyNum(actual)); + }); + } else { + PrintMismatch(&mismatches, [&] { + return absl::StrFormat( + "Mismatch on denormal value %s. Expected one of:\n" + " %10s (evaluated at full-precision value)\n" + " %10s (evaluated after flushing to sign-preserving zero)\n" + "but got %s.", + StringifyNum(input), StringifyNum(expected), + StringifyNum(sign_preserving_ftz_expected), StringifyNum(actual)); + }); + } + } + EXPECT_EQ(mismatches, 0); + } + + template + void PrintMismatch(int64* mismatches, const ErrorGenerator& err_generator) { + // We send a few mismatches to gunit so they show up nicely in test logs. + // Then we send more to LOG(ERROR). The remainder we squelch unless we're + // at vlog level 2. 
+ constexpr int64 kMaxMismatchesLoggedToGunit = 10; + constexpr int64 kMaxMismatchesLoggedToErr = 1000; + + (*mismatches)++; + if (*mismatches < kMaxMismatchesLoggedToGunit) { + FAIL() << err_generator(); + } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) { + LOG(ERROR) << err_generator(); + } else if (*mismatches == kMaxMismatchesLoggedToErr) { + LOG(ERROR) << "Not printing any more mismatches; pass " + "--vmodule=exhaustive_f32__op_test=2 to see " + "all of them."; + } + } + + // The following members are set during construction so testcases can read + // these values and use them e.g. to influence the values given to the mutable + // members below. + + // The primitive type under test. + const PrimitiveType ty_; + + // The platform under test. + const string platform_; + + // Tests can set the following variables for control over execution. This is + // safe because each XLA_TEST_P instantiates a new instance of this class. + + // Testing will ignore the given range (encoded as bitwise representations of + // the type under test zero-extended to int64). + int64 known_incorrect_begin_ = 0; + int64 known_incorrect_end_ = 0; + + // If unset, reasonable defaults will be used depending on the type under + // test. + absl::optional abs_err_; + absl::optional rel_err_; + + // If true, will consider -0 not near to +0 and vice versa. Note that + // +epsilon may still be considered close to -0, depending on the error spec; + // this only covers the case when both `expected` and `actual` are equal to 0. + bool strict_signed_zeros_ = false; + + // If true, allows denormals to be flushed to non-sign-preserving 0. + // + // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt of + // a negative number) or -inf (flush the denormal to sign-perserving zero, + // then sqrt(-0)). But with this as true, we'll also accept 0 (sqrt(0)). + // + // XLA:GPU preserves denormal signs, but other backends don't. 
+ bool relaxed_denormal_signs_ = platform_ != "CUDA"; +}; + +XLA_TEST_P(ExhaustiveOpTest, Log) { + if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { + abs_err_ = 0.001; + rel_err_ = 0.001; + } + + Run(Log, std::log); +} + +XLA_TEST_P(ExhaustiveOpTest, Log1p) { + if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { + abs_err_ = 0.001; + rel_err_ = 0.001; + } + + Run(Log1p, std::log1p); +} + +XLA_TEST_P(ExhaustiveOpTest, Exp) { + if (platform_ == "Host" && ty_ == F32) { + // TODO(b/73142289): The vectorized Exp implementation gives results outside + // our error spec in this range. + known_incorrect_begin_ = 1107296256 + 11583654; + known_incorrect_end_ = 1107296256 + 11629080; + } else if (platform_ == "Host" && ty_ == BF16) { + // TODO(jlebar): Is this a rounding error? Why doesn't it occur on XLA:GPU? + // + // Mismatch on 88.5 (0x42b1). + // Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80). + known_incorrect_begin_ = 0x42b1; + known_incorrect_end_ = 0x42b2; + } + + Run(Exp, std::exp); +} + +XLA_TEST_P(ExhaustiveOpTest, Expm1) { + // Expm1 has the same erroneous behavior on CPU as Exp. + if (platform_ == "Host" && ty_ == F32) { + // TODO(b/73142289): The vectorized Exp implementation gives results outside + // our error spec in this range. + known_incorrect_begin_ = 1107296256 + 11583654; + known_incorrect_end_ = 1107296256 + 11629080; + } else if (platform_ == "Host" && ty_ == BF16) { + // TODO(jlebar): Is this a rounding error? Why doesn't it occur on XLA:GPU? + // + // Mismatch on 88.5 (0x42b1). + // Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80). + known_incorrect_begin_ = 0x42b1; + known_incorrect_end_ = 0x42b2; + } + + Run(Expm1, std::expm1); +} + +// It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but +// this *did* find a bug, namely that some backends were assuming sqrt(x) == +// pow(x, 0.5), but this is not true for x == -inf. 
+XLA_TEST_P(ExhaustiveOpTest, PowOneHalf) { + Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, + +[](float x) { return std::pow(x, 0.5f); }); +} + +XLA_TEST_P(ExhaustiveOpTest, Rsqrt) { + Run( + Rsqrt, +[](float x) { return 1 / std::sqrt(x); }); +} + +XLA_TEST_P(ExhaustiveOpTest, Sqrt) { + if (platform_ == "Host" || platform_ == "CUDA") { + strict_signed_zeros_ = true; + } + + Run(Sqrt, std::sqrt); +} + +// TODO(jlebar): Add remaining trig functions. Don't forget Atan2! +// TODO(jlebar): Test trig functions over complex inputs. +XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); } + +XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); } +XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); } +XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); } +XLA_TEST_P(ExhaustiveOpTest, Digamma) { + if (platform_ != "Host" && platform_ != "CUDA") { + // TODO(b/123956399): This is a fairly high error, significantly higher than + // we see on CPU/GPU. + rel_err_ = 0.01; + abs_err_ = 0.01; + } + + if (platform_ == "CUDA") { + // On GPU we get a wrong answer for the denormal inputs +/-2.93873588e-39 + // (0x00200000 and 0x80200000). These should return -/+inf (at least + // according to our reference implementation!) but XLA:GPU returns + // -/+3.40282326e+38 (0xff7ffffe and 0x7f7ffffe). + // + // I deem this an acceptable result, as XLA:GPU flushes denormals, and as + // the results we get here are very close to MAX_FLOAT. We just hardcode + // these results, as this is better than ignoring these inputs altogether. 
+ auto host_digamma_with_gpu_ftz_errors = +[](float x) { + if (absl::bit_cast(x) == 0x00200000 || + absl::bit_cast(x) == 0x80200000) { + return std::copysign(std::numeric_limits::max(), -x); + } + return HostDigamma(x); + }; + Run(Digamma, host_digamma_with_gpu_ftz_errors); + } else { + Run(Digamma, HostDigamma); + } +} +XLA_TEST_P(ExhaustiveOpTest, Lgamma) { + // Our implementation gets within 0.0001 rel error except for ~20 denormal + // inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma. + if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) { + rel_err_ = 0.001; + } + if (platform_ != "Host" && platform_ != "CUDA") { + // TODO(b/123956399): This is a fairly high error, significantly higher than + // we see on CPU/GPU. + rel_err_ = 0.01; + abs_err_ = 0.01; + + // Overflows for to inf for input 4.08500343e+36 (0x7c44af8e). + if (ty_ == F32) { + known_incorrect_begin_ = 0x7c44af8e; + known_incorrect_end_ = 0x7c44af8e + 1; + } + } + Run(Lgamma, std::lgamma); +} + +XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); } + +std::vector> CreateExhaustiveF32Ranges() { + // We break up the 2^32-element space into small'ish chunks to keep peak + // memory usage low. 
+ std::vector> result; + const int64 step = 1 << 25; + for (int64 i = 0; i < (1l << 32); i += step) { + result.push_back({i, i + step}); + } + return result; +} + +INSTANTIATE_TEST_SUITE_P( + F32, ExhaustiveOpTest, + ::testing::Combine(::testing::Values(F32), + ::testing::ValuesIn(CreateExhaustiveF32Ranges()))); + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +INSTANTIATE_TEST_SUITE_P( + F16, ExhaustiveOpTest, + ::testing::Combine(::testing::Values(F16), + ::testing::Values(std::make_pair(0, 1 << 16)))); +#endif + +#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) +INSTANTIATE_TEST_SUITE_P( + BF16, ExhaustiveOpTest, + ::testing::Combine(::testing::Values(BF16), + ::testing::Values(std::make_pair(0, 1 << 16)))); +#endif + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index 1b0bebe2d03a9a153cd0c80329ed0c49c91333a3..5d91326aad0671b484341db3aa6d618aa646e8c3 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -47,8 +47,9 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { } tensorflow::SubProcess file_check_process; - file_check_process.SetProgram(file_check_path, - {file_check_path, "-v", pattern_path}); + file_check_process.SetProgram( + file_check_path, + {file_check_path, "-v", "-dump-input=always", pattern_path}); file_check_process.SetChannelAction(tensorflow::CHAN_STDIN, tensorflow::ACTION_PIPE); file_check_process.SetChannelAction(tensorflow::CHAN_STDERR, diff --git a/tensorflow/compiler/xla/tests/fmax_fmin_test.cc b/tensorflow/compiler/xla/tests/fmax_fmin_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7423ac0bcdb0bc305ee384fb98bd17413404ecef --- /dev/null +++ b/tensorflow/compiler/xla/tests/fmax_fmin_test.cc @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class FmaxSimpleTest : public ClientLibraryTestBase {}; + +TEST_F(FmaxSimpleTest, FmaxTenValues) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = ConstantR1( + &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + Max(x, y); + + std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(FmaxSimpleTest, FmaxEdgeCases) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + XlaOp param0, param1; + std::unique_ptr param0_data = CreateR1Parameter( + {INFINITY, INFINITY, INFINITY, -INFINITY, INFINITY, -INFINITY, NAN, + INFINITY, -INFINITY, NAN}, + /*parameter_number=*/0, /*name=*/"param0", + /*builder=*/&builder, /*data_handle=*/¶m0); + std::unique_ptr param1_data = CreateR1Parameter( + {INFINITY, -INFINITY, NAN, NAN, -4.0, -5.0, -6.0, 7.0, 8.0, 9.0}, + /*parameter_number=*/1, 
/*name=*/"param1", + /*builder=*/&builder, /*data_handle=*/¶m1); + + Max(param0, param1); + std::vector expected = {INFINITY, INFINITY, NAN, NAN, INFINITY, + -5, NAN, INFINITY, 8, NAN}; + ComputeAndCompareR1(&builder, expected, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.0001)); +} + +TEST_F(FmaxSimpleTest, FminEdgeCases) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + XlaOp param0, param1; + std::unique_ptr param0_data = CreateR1Parameter( + {INFINITY, INFINITY, INFINITY, -INFINITY, INFINITY, -INFINITY, NAN, + INFINITY, -INFINITY, NAN}, + /*parameter_number=*/0, /*name=*/"param0", + /*builder=*/&builder, /*data_handle=*/¶m0); + std::unique_ptr param1_data = CreateR1Parameter( + {INFINITY, -INFINITY, NAN, NAN, -4.0, -5.0, -6.0, 7.0, 8.0, 9.0}, + /*parameter_number=*/1, /*name=*/"param1", + /*builder=*/&builder, /*data_handle=*/¶m1); + + Min(param0, param1); + std::vector expected = {INFINITY, -INFINITY, NAN, NAN, -4, + -INFINITY, NAN, 7, -INFINITY, NAN}; + ComputeAndCompareR1(&builder, expected, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc deleted file mode 100644 index c5bbbe778df15d63a2586bd6291a7a33fc82aa52..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -class FmaxSimpleTest : public ClientLibraryTestBase {}; - -TEST_F(FmaxSimpleTest, FmaxTenValues) { - XlaBuilder builder(TestName()); - auto x = ConstantR1( - &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); - auto y = ConstantR1( - &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); - Max(x, y); - - std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, - 5.0, 6.0, 7.0, 8.0, 9.0}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 2178c9b3f3d39ac034c59585c6836d2bc59162c1..c410f1f6d8fbcd3be228c23614d1d69888c781b8 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -63,7 +63,11 @@ const float test_float_vals[3][test_width][test_height] = { class FusionTest : public HloTestBase { protected: template - void TestElementwise2D(HloOpcode opcode) { + void TestElementwise2D( + HloOpcode opcode, + absl::optional direction = absl::nullopt) { + // Create a variable for comparisons since they require the direction. 
+ bool is_compare = std::is_same::value; Array2D operand_data[Arity]; for (int i = 0; i < Arity; ++i) { new (&operand_data[i]) Array2D(test_width, test_height); @@ -76,12 +80,16 @@ class FusionTest : public HloTestBase { xs[k] = test_float_vals[k][i][j]; operand_data[k](i, j) = xs[k]; } - answer_data(i, j) = ComputeElementwiseAnswer(opcode, xs); + if (is_compare) { + answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs); + } else { + answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs); + } } } auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -98,8 +106,13 @@ class FusionTest : public HloTestBase { root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]); break; case 2: - root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1], - hlos[2]); + if (is_compare) { + root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1], + hlos[2], *direction); + } else { + root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1], + hlos[2]); + } break; case 3: root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1], @@ -124,13 +137,19 @@ class FusionTest : public HloTestBase { } private: - template - T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span xs); + float ComputeElementwiseAnswerFloat(HloOpcode opcode, + absl::Span xs); + bool ComputeElementwiseAnswerCompare(ComparisonDirection direction, + absl::Span xs); + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + return debug_options; + } }; -template <> -float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - absl::Span xs) { +float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode, + absl::Span xs) { switch (opcode) { case 
HloOpcode::kAdd: return xs[0] + xs[1]; @@ -153,24 +172,21 @@ float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, } } -template <> -bool FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - absl::Span xs) { - switch (opcode) { - case HloOpcode::kEq: +bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction, + absl::Span xs) { + switch (direction) { + case ComparisonDirection::kEq: return xs[0] == xs[1]; - case HloOpcode::kNe: + case ComparisonDirection::kNe: return xs[0] != xs[1]; - case HloOpcode::kGt: + case ComparisonDirection::kGt: return xs[0] > xs[1]; - case HloOpcode::kLt: + case ComparisonDirection::kLt: return xs[0] < xs[1]; - case HloOpcode::kGe: + case ComparisonDirection::kGe: return xs[0] >= xs[1]; - case HloOpcode::kLe: + case ComparisonDirection::kLe: return xs[0] <= xs[1]; - default: - LOG(FATAL) << "No comparatory opcode: " << opcode; } } @@ -183,7 +199,7 @@ XLA_TEST_F(FusionTest, Test) { // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -231,7 +247,7 @@ XLA_TEST_F(FusionTest, Parameter) { // Build a computation and fuse part of it so the fusion instruction has an // operand parameter. 
auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -266,7 +282,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); // Build simple fusion computation: y = x^2 (elementwise). auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto two = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); @@ -290,7 +306,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( @@ -314,7 +330,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto single_element_array = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( @@ -329,7 +345,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = 
builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -344,7 +360,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( @@ -359,7 +375,7 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( @@ -374,7 +390,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -389,7 +405,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( @@ -404,7 +420,7 @@ XLA_TEST_F(FusionTest, Reshape__) { XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = 
CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( @@ -419,7 +435,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -434,7 +450,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -449,7 +465,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -465,7 +481,7 @@ XLA_TEST_F(FusionTest, Reverse) { XLA_TEST_F(FusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -483,7 +499,7 @@ 
XLA_TEST_F(FusionTest, ReverseNegate) { XLA_TEST_F(FusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -501,7 +517,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { XLA_TEST_F(FusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( @@ -519,7 +535,7 @@ XLA_TEST_F(FusionTest, SliceNegate) { XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( @@ -541,7 +557,7 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { XLA_TEST_F(FusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto reshape1 = builder.AddInstruction( @@ -559,7 +575,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); auto transpose1 
= builder.AddInstruction(HloInstruction::CreateTranspose( @@ -587,7 +603,7 @@ std::unique_ptr MakeReduceTestComputation() { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -606,8 +622,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { - auto hlo_module = CreateNewUnverifiedModule(); +XLA_TEST_F(FusionTest, ReduceImplicitBroadcast) { + auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -630,7 +646,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( @@ -682,7 +698,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { // into a fusion, it should remain shared, rather than being duplicated // within the fusion. 
XLA_TEST_F(FusionTest, SharedConstant) { - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( @@ -740,64 +756,34 @@ XLA_TEST_F(FusionTest, Maximum2D) { TestElementwise2D(HloOpcode::kMaximum); } -XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D(HloOpcode::kEq); } +XLA_TEST_F(FusionTest, Equal2D) { + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kEq); +} XLA_TEST_F(FusionTest, Inequal2D) { - TestElementwise2D(HloOpcode::kNe); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kNe); } XLA_TEST_F(FusionTest, Greater2D) { - TestElementwise2D(HloOpcode::kGt); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGt); } -XLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D(HloOpcode::kLt); } +XLA_TEST_F(FusionTest, Lesser2D) { + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLt); +} XLA_TEST_F(FusionTest, GreaterOrEqual2D) { - TestElementwise2D(HloOpcode::kGe); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGe); } XLA_TEST_F(FusionTest, LesserOrEqual2D) { - TestElementwise2D(HloOpcode::kLe); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLe); } XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } -// TODO(b/117156505): Remove this test when the bug is fixed and the CPU backend -// should not generate layout changing elementwise operations. 
-#ifdef XLA_TEST_BACKEND_CPU -XLA_TEST_F(FusionTest, LayoutChangingElementWiseOp) { - const string hlo_text = R"( -HloModule Cluster - -fusion_c { - fusion.arg = f32[2,2]{1,0} parameter(0) - bitcast.0 = f32[2,2,1]{2,1,0} bitcast(fusion.arg) - tanh.0 = f32[2,2,1]{0,2,1} tanh(bitcast.0) - ROOT bitcast.2 = f32[2,2,1]{1,2,0} bitcast(tanh.0) -} - -ENTRY main { - arg = f32[2,2]{1,0} parameter(0) - ROOT fusion = f32[2,2,1]{1,2,0} fusion(arg), kind=kLoop, calls=fusion_c -} -)"; - - Literal operand = LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text, config)); - TF_ASSERT_OK_AND_ASSIGN(Literal result, - test_runner_.Execute(std::move(module), {&operand}, - /*run_hlo_passes=*/false)); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), - result)); -} -#endif - class FusionClientLibraryTest : public ClientLibraryTestBase {}; XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index d65b67a535d43553a3a94f76482ad4618f9b8aab..16a1371ec8da5377bb64a202988f118852aa12e2 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -598,6 +598,26 @@ ENTRY main { RunTest(hlo_text, &operand, &start_indices); } +XLA_TEST_F(GatherOperationTest, GatherFromScalar) { + const string hlo_text = R"( +HloModule GatherFromScalar + +ENTRY main { + operand = f32[] parameter(0) + indices = s32[0]{0} parameter(1) + ROOT gather = f32[] gather(operand, indices), + offset_dims={}, + collapsed_slice_dims={}, + start_index_map={}, + index_vector_dim=0, + slice_sizes={} +} +)"; + Literal operand = LiteralUtil::CreateR0(1); + Literal start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, 
&operand, &start_indices); +} + class GatherClientLibraryTest : public ClientLibraryTestBase {}; // Disabled on interpreter since ExectuteAsyncOnStream is not supported. diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 1115e50fe3120b7dbd891f07dedcacefa5ecf3ea..97b10083029c944d07422544277d8b52a2cea24b 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -73,7 +73,7 @@ half sign_imp(half value) { } half round_imp(half value) { - return half(round(static_cast(std::move(value)))); + return half(std::round(static_cast(std::move(value)))); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 66f72ba8d20b8ef1f436da4425b2bb6518ee9a94..a12fa04c01b2654c0970faff643d64c608dadf0d 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -205,6 +205,32 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } +StatusOr> HloTestBase::ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas, bool use_threads) { + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + for (auto argument : arguments) { + options.arguments.push_back(argument); + } + return test_runner_.ExecuteReplicated(std::move(module), options, + use_threads); +} + +StatusOr> HloTestBase::ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas, DeviceAssignment* device_assignment, + bool run_hlo_passes, bool use_threads) { + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + options.run_hlo_passes = run_hlo_passes; + for (auto argument : arguments) { + options.arguments.push_back(argument); + } + return 
test_runner_.ExecuteReplicated(std::move(module), options, + device_assignment, use_threads); +} + StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { @@ -313,7 +339,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } -::testing::AssertionResult HloTestBase::Run(string_view hlo_string) { +::testing::AssertionResult HloTestBase::Run(string_view hlo_string, + bool run_hlo_passes, + ExecutionProfile* profile, + string backend_config) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); if (!module_or_status.ok()) { @@ -321,19 +350,108 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( << "Error while parsing HLO text format: " << module_or_status.status().ToString(); } + + std::unique_ptr module = std::move(module_or_status.ValueOrDie()); const auto& fake_arguments = - MakeFakeArguments(module_or_status.ValueOrDie().get()) - .ConsumeValueOrDie(); + MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const Literal& literal) { return const_cast(&literal); }); - return test_runner_ - .Execute(std::move(module_or_status.ValueOrDie()), - fake_argument_ptrs, /*run_hlo_passes=*/true) - .ok() + + if (profile != nullptr) { + // We have to enable HLO profiling since otherwise currently the + // ExecutionProfile is not correct. + // + // TODO(b/119432044): Fix collection of the ExecutionProfile + // so that this is not necessary. + HloModuleConfig config = module->config(); + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_hlo_profile(true); + config.set_debug_options(debug_options); + module->set_config(config); + } + + if (!backend_config.empty()) { + // Set backend configuration if it is given. 
+ HloInstruction* instruction = + module->entry_computation()->root_instruction(); + instruction->set_raw_backend_config_string(backend_config); + } + + // return ::testing::AssertionSuccess(); + auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); + + return output.ok() ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure(); + : ::testing::AssertionFailure() << output.status().error_message(); +} + +::testing::AssertionResult HloTestBase::RunMultipleTimes( + string_view hlo_string, bool run_hlo_passes, + std::vector* profiles, string backend_config) { + int n = profiles->size(); + std::vector> fake_argument_ptrs(n); + std::vector> fake_arguments(n); + std::vector> executables(n); + + for (int i = 0; i < n; ++i) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + std::unique_ptr module = + std::move(module_or_status.ValueOrDie()); + + fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie(); + absl::c_transform( + fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]), + [](const Literal& literal) { return const_cast(&literal); }); + + if (profiles != nullptr) { + // We have to enable HLO profiling since otherwise currently the + // ExecutionProfile is not correct. + // + // TODO(b/119432044): Fix collection of the ExecutionProfile + // so that this is not necessary. + HloModuleConfig config = module->config(); + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_hlo_profile(true); + config.set_debug_options(debug_options); + module->set_config(config); + } + + if (!backend_config.empty()) { + // Set backend configuration if it is given. 
+ HloInstruction* instruction = + module->entry_computation()->root_instruction(); + instruction->set_raw_backend_config_string(backend_config); + } + + auto executable = + test_runner_.CreateExecutable(std::move(module), run_hlo_passes); + if (!executable.ok()) { + return ::testing::AssertionFailure() + << executable.status().error_message(); + } + executables[i] = std::move(executable.ValueOrDie()); + } + + for (int i = 0; i < n; ++i) { + auto output = + test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i], + /*profile=*/&((*profiles)[i])); + if (!output.ok()) { + return ::testing::AssertionFailure() << output.status().error_message(); + } + } + + return ::testing::AssertionSuccess(); } ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 69a4f96288c7285010e9adbdc33f1b394f58d8d2..6c6fe34d1cc67b83efc8f9ea9b8e05e56d58fb8c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -173,6 +173,21 @@ class HloTestBase : public ::testing::Test { Literal ExecuteAndTransfer(std::unique_ptr module, absl::Span arguments); + // Executes the given module on multiple replicas. + // + // use_threads indicates whether this replicated computation will be executed + // with a thread-per-replica, vs using an implicitly async call such as + // Executable::ExecuteOnStreams. + StatusOr> ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas, bool use_threads); + + // Same as above, but uses specified device assignment. + StatusOr> ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas, DeviceAssignment* device_assignment, + bool run_hlo_passes, bool use_threads); + // Executes the given hlo module on two backends and compares results. // // 'arguments': the input of the hlo module. 
@@ -221,8 +236,14 @@ class HloTestBase : public ::testing::Test { const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; - ::testing::AssertionResult Run(const absl::string_view hlo_string) - TF_MUST_USE_RESULT; + ::testing::AssertionResult Run(const absl::string_view hlo_string, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr, + string backend_config = "") TF_MUST_USE_RESULT; + ::testing::AssertionResult RunMultipleTimes( + const absl::string_view hlo_string, bool run_hlo_passes, + std::vector* profiles, + string backend_config = "") TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( const string& filename, const absl::optional& error, const std::function& reference_preprocessor = nullptr) diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 554eb24d44168caa7d7252015e3d99f2d567df9b..a2fd6070731943f15c773265f428b16f520d02ee 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -86,7 +86,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::Near( const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error_spec, bool detailed_message) { + const ErrorSpec& error_spec, absl::optional detailed_message) { return StatusToAssertion(literal_comparison::Near( expected, actual, error_spec, detailed_message, &OnMiscompare)); } @@ -97,7 +97,8 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, if (error.has_value()) { VLOG(1) << "Expects near"; return StatusToAssertion(literal_comparison::Near( - expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); + expected, actual, *error, /*detailed_message=*/absl::nullopt, + &OnMiscompare)); } VLOG(1) << "Expects equal"; return 
StatusToAssertion(literal_comparison::Equal(expected, actual)); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 43cca91f64b2c0fbfde5054a361cf0f95302c23d..d7cf9bed98a3eb7479b6deb6838dc388a0869360 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -93,7 +93,7 @@ class LiteralTestUtil { static ::testing::AssertionResult Near( const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error_spec, - bool detailed_message = false) TF_MUST_USE_RESULT; + absl::optional detailed_message = absl::nullopt) TF_MUST_USE_RESULT; // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 60eb21aafd23a8d724d1f08d5c87098b7c3dcd6b..f1779c856bb4fdb1c4de453d9270a75dfcb8682b 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -69,12 +69,12 @@ int main(int argc, char** argv) { } else if (target_cpu == "arm") { triple_string = "aarch64-none-linux-gnu"; } else if (target_cpu == "local") { - triple_string = xla::llvm_ir::AsString(llvm::sys::getDefaultTargetTriple()); + triple_string = llvm::sys::getDefaultTargetTriple(); } else { LOG(FATAL) << "unsupported TARGET_CPU: " << target_cpu; } - llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); + llvm::Triple triple(triple_string); xla::XlaComputation computation = builder.Build().ConsumeValueOrDie(); xla::CompileOnlyClient::AotXlaComputationInstance instance{ diff --git a/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc new file mode 100644 index 
0000000000000000000000000000000000000000..1513d89ba9c95b3097229b268d22832dee3e98cd --- /dev/null +++ b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc @@ -0,0 +1,56 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class MultiDeviceAllReduceTest : public HloTestBase {}; + +XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p = f32[3] parameter(0) + ROOT crs = f32[3] all-reduce(p), to_apply=add + })"; + auto config = GetModuleConfigForTest(); + config.set_replica_count(2); + auto module = ParseHloString(module_str, config).ValueOrDie(); + auto literal = LiteralUtil::CreateR1({1, 2, 3}); + auto expected = LiteralUtil::CreateR1({2, 4, 6}); + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + ExecuteReplicated(std::move(module), {&literal}, 2, + /*use_threads=*/true)); + 
EXPECT_EQ(expected, results[0]); + EXPECT_EQ(expected, results[1]); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 3f5135438fc59bea98527b1be30ee49339edd455..73c9d7ed4b09c2f78dc7226b5d755eb017ea664b 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -208,9 +208,7 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned( LiteralUtil::MakeTupleOwned( LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), @@ -229,7 +227,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { fused_computation { p = f32[4] parameter(0) multiply = f32[4] multiply(p, p) - less-than = pred[4] less-than(p, multiply) + less-than = pred[4] compare(p, multiply), direction=LT ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply) } @@ -241,9 +239,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { const = f32[4] constant({0, 0, 0, 0}) ROOT select = f32[4] select(gte0, gte1, const) })"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, result); @@ -256,7 +252,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { fused_computation { p = f32[] parameter(0) multiply = f32[] multiply(p, p) - less-than = pred[] less-than(p, multiply) + less-than = pred[] compare(p, multiply), 
direction=LT ROOT tuple = (pred[], f32[]) tuple(less-than, multiply) } @@ -273,9 +269,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { p1 = f32[3] parameter(0) ROOT map = f32[3] map(p1), to_apply=map_computation })"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, result); @@ -315,9 +309,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -346,9 +338,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -378,9 +368,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -410,9 +398,7 @@ 
XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -443,9 +429,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -478,9 +462,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -513,9 +495,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = LiteralUtil::CreateR0(5); @@ -549,9 +529,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p), kind=kInput, 
calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); diff --git a/tensorflow/compiler/xla/tests/plugin.bzl b/tensorflow/compiler/xla/tests/plugin.bzl index 8a5d91363b619c6b214a96ad96e92742e3052541..107869fe59d43d0a9a3e2b14af2c09e4906d9f15 100644 --- a/tensorflow/compiler/xla/tests/plugin.bzl +++ b/tensorflow/compiler/xla/tests/plugin.bzl @@ -33,4 +33,3 @@ # } plugins = {} - diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index f80d29b9de440b11c36e8c9bc65d4a93353a6267..e2cf4c0be289b52d5cc581ea07752ed6e98da76f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 18c99490a387923aaf68e06041cd11ed3b954aa5..6d2c2fc79cec82a0e37c772657d94fb7e2a27866 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include #include +#include #include #include #include @@ -455,7 +456,7 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { for (int64 colno = 0; colno < cols; ++colno) { float column_sum = 0; for (int64 rowno = 0; rowno < rows; ++rowno) { - column_sum += log(input_data(rowno, colno)); + column_sum += std::log(input_data(rowno, colno)); } expected.push_back(column_sum); } @@ -486,7 +487,7 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { for (int64 colno = 0; colno < cols; ++colno) { float column_sum = 0; for (int64 rowno = 0; rowno < rows; ++rowno) { - column_sum += log(input_data(rowno, colno)); + column_sum += std::log(input_data(rowno, colno)); } expected.push_back(column_sum); } @@ -533,7 +534,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { for (int64 colno = 0; colno < cols / 2; ++colno) { float column_sum = 0; for (int64 rowno = 0; rowno < rows; ++rowno) { - column_sum += tanh(input_data(rowno, major, colno)); + column_sum += std::tanh(input_data(rowno, major, colno)); } expected.push_back(column_sum); } diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 16c67d94c76bcf8984a2b3e4cb092026a6924aeb..352b59f248b86234c46a9b89e17a68e347d7af6c 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -611,6 +611,12 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, // values. (Technically, the requirement is that the iota length is // relatively prime to all of the dimensions involved in the reduce-window.) input.FillRepeatedIota(0, 137); + // Floating point sum reduction requires higher localized precision. We need + // the following normalization in order to enable testing of kAdd on large + // windows. 
+ input.Each([&](absl::Span /*indices*/, float* value) { + *value = *value / 10000000000.f; + }); Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; @@ -626,12 +632,6 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto reducer = param.reducer; - if (use_bfloat16()) { - // To avoid numerical issues, force the reducer to be kMax for bf16 - // inputs. - reducer = kMax; - } - auto computation = reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); @@ -697,15 +697,6 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, - // With non-1x1 window. - R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, - /*window_bounds=*/{2, 3, 1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{3, 2, 1, 0}, - /*reducer=*/kAdd}, - // With max instead of add. R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, /*window_bounds=*/{2, 3, 1, 1}, @@ -778,15 +769,6 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, - // With second minor dimension == 9. - R4ReduceWindowTestData{/*base_bounds=*/{2, 3, 9, 127}, - /*window_bounds=*/{1, 1, 1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{3, 2, 1, 0}, - /*reducer=*/kAdd}, - // With minor dimension == 129. 
R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129}, /*window_bounds=*/{1, 1, 1, 1}, @@ -814,7 +796,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, - R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3}, + R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3}, /*window_bounds=*/{1, 64, 64, 1}, /*strides=*/{1, 64, 64, 1}, /*pad_low=*/{0, 0, 0, 0}, @@ -828,6 +810,32 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kMax}, + + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 4, 5}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + + // With 0321 layout. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 4, 5}, + /*strides=*/{1, 2, 3, 4}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{0, 3, 2, 1}, + /*reducer=*/kAdd}, + + // With 0123 layout. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17}, + /*window_bounds=*/{2, 3, 7, 9}, + /*strides=*/{1, 2, 5, 8}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{0, 1, 2, 3}, /*reducer=*/kAdd}, }; @@ -866,58 +874,60 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { /*pad_high=*/{0, 0, 2, 0}, /*layout=*/{3, 2, 1, 0}, /*reducer=*/kMax}, -}; - -INSTANTIATE_TEST_CASE_P( - R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, - ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), - ::testing::ValuesIn(use_bfloat16_params)), - R4ReduceWindowTestDataToString); -class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {}; + // Patterns generated by cumsum/cumprod. 
+ R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16}, + /*window_bounds=*/{1021, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{1020, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, -// TODO(b/72234705): Fix the test cases failed on CPU and GPU. -XLA_TEST_P(R4ReduceWindowAnyDimsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) { - DoIt(); -} + R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16}, + /*window_bounds=*/{1, 1, 1021, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 1020, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, -const R4ReduceWindowTestData kR4ReduceWindowAnyDimsTestValues[] = { - R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, - /*window_bounds=*/{2, 3, 4, 5}, + R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021}, + /*window_bounds=*/{1, 1, 1, 1021}, /*strides=*/{1, 1, 1, 1}, - /*pad_low=*/{0, 0, 0, 0}, + /*pad_low=*/{0, 0, 0, 1020}, /*pad_high=*/{0, 0, 0, 0}, /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, - R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, - /*window_bounds=*/{2, 3, 1, 1}, + + R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16}, + /*window_bounds=*/{1021, 1, 1, 1}, /*strides=*/{1, 1, 1, 1}, - /*pad_low=*/{0, 0, 0, 0}, + /*pad_low=*/{1021, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, /*layout=*/{3, 2, 1, 0}, - /*reducer=*/kMax}, - // With 0321 layout. - R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, - /*window_bounds=*/{2, 3, 4, 5}, - /*strides=*/{1, 2, 3, 4}, - /*pad_low=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 1021, 16}, + /*window_bounds=*/{1, 1, 1021, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 1021, 0}, /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{0, 3, 2, 1}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, - // With 0123 layout. 
- R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 23}, - /*window_bounds=*/{2, 3, 7, 9}, - /*strides=*/{1, 2, 5, 8}, - /*pad_low=*/{0, 0, 0, 0}, + R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021}, + /*window_bounds=*/{1, 1, 1, 1021}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 1021}, /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{0, 1, 2, 3}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, }; INSTANTIATE_TEST_CASE_P( - R4ReduceWindowAnyDimsTestInstantiation, R4ReduceWindowAnyDimsTest, - ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowAnyDimsTestValues), + R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), ::testing::ValuesIn(use_bfloat16_params)), R4ReduceWindowTestDataToString); @@ -1113,6 +1123,11 @@ struct R2ReduceWindowTestData { {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4}, /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, + // Regression test for b/72234705: bf16 lacks precision to store incremental + // results on very large windows. Using smaller window with minor dim 128. + {/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, }; string R2ReduceWindowTestDataToString( @@ -1191,27 +1206,6 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(use_bfloat16_params)), R2ReduceWindowTestDataToString); -class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {}; - -// TODO(b/72234705): Fix the test cases failed on CPU and GPU. 
-XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test, - DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) { - DoIt(); -} - -const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = { - {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128}, - /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, - /*layout=*/{1, 0}, - /*reducer=*/Reducer::kAdd}, -}; - -INSTANTIATE_TEST_CASE_P( - R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test, - ::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test), - ::testing::ValuesIn(use_bfloat16_params)), - R2ReduceWindowTestDataToString); - struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1]; @@ -1321,9 +1315,9 @@ struct R1ReduceWindowTestData { /*reducer=*/Reducer::kMax}, // The pattern generated by exclusive scan (cumsum/cumprod). - {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, + {/*base_bounds=*/{4095}, /*window_bounds=*/{4095}, /*strides=*/{1}, - /*pad_low=*/{4096}, + /*pad_low=*/{4095}, /*pad_high=*/{0}, /*reducer=*/Reducer::kMax}, }; @@ -1532,6 +1526,25 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); } +XLA_TEST_F(HloTestBase, ReduceWindowS64) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] { + %param0 = s64[] parameter(0) + ROOT %param1 = s64[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] { + %parameter.0 = s64[81,8]{1,0} parameter(0) + %parameter.1 = s64[] parameter(1) + ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); +} + XLA_TEST_F(HloTestBase, ReduceWindowF16) { const string hlo_string = R"( HloModule reduce-window diff --git 
a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 32de0fdf78f9c442e17c55e1b951e39122dac5ef..86d9999b4a4844b51db7fb56529ea4580797b809 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -718,5 +718,32 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, ScatterIntoScalar) { + const char* hlo_text = R"( +HloModule ScatterIntoScalar + +update_s32 { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + parameter.1 = s32[] parameter(0) + parameter.2 = s32[0]{0} parameter(1) + parameter.3 = s32[] parameter(2) + ROOT scatter = s32[] scatter(parameter.1, parameter.2, parameter.3), + update_window_dims={}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={}, + index_vector_dim=0, + to_apply=update_s32 +} +)"; + Literal operand = LiteralUtil::CreateR0(1); + Literal scatter_indices = LiteralUtil::CreateR1({}); + Literal updates = LiteralUtil::CreateR0(2); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index f737b5158b3622d677aea5bf64a421a56e2c42dd..0dcb1c42db1b0884d80adeaf88d55eeda79fb7d8 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -146,6 +146,12 @@ INSTANTIATE_TEST_CASE_P( Padding::kValid, {3, 3, 1, 1}, {3, 3, 1, 1}}, + // Uncovered by b/126212776. 
+ SelectAndScatterTestParam{{15, 1, 1, 1}, + {2, 1, 1, 1}, + Padding::kValid, + {14, 1, 1, 1}, + {1, 1, 1, 1}}, SelectAndScatterTestParam{{7, 3, 4, 4}, {3, 1, 4, 4}, Padding::kValid, diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 95c89b0ba6f29c453abab88e29bca13ee006455a..4ac3dbd80cfaf2340d8f79cef3e9e02058cf919c 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -112,6 +112,31 @@ void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine, } } +template +void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine, + bool no_duplicates) { + using InnerFloatT = typename ComplexT::value_type; + CHECK(engine != nullptr); + CHECK_EQ(result->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + Shape floating_point_shape = ShapeUtil::ChangeElementType( + result->shape(), primitive_util::NativeToPrimitiveType()); + Literal real_lit(floating_point_shape); + Literal imaginary_lit(floating_point_shape); + + PopulateWithFloatingPointData(&real_lit, engine, no_duplicates); + PopulateWithFloatingPointData(&imaginary_lit, engine, + no_duplicates); + + absl::Span real_data = real_lit.data(); + absl::Span imaginary_data = + imaginary_lit.data(); + absl::Span result_data = result->data(); + for (int i = 0; i < real_lit.data().size(); i++) { + result_data[i] = ComplexT(real_data[i], imaginary_data[i]); + } +} + template <> void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine, @@ -220,6 +245,12 @@ StatusOr MakeFakeLiteralInternal(const Shape& shape, case U64: PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; + case C64: + PopulateWithComplexData(&literal, engine, no_duplicates); + break; + case C128: + PopulateWithComplexData(&literal, engine, no_duplicates); + break; case PRED: { std::uniform_int_distribution generator(0, 1); TF_CHECK_OK( @@ -238,6 +269,79 @@ StatusOr 
MakeFakeLiteralInternal(const Shape& shape, return std::move(literal); } +template +void PopulateWithRandomIntegralDataWithBounds(Literal* literal, + std::minstd_rand0* engine, + IntT min, IntT max) { + CHECK(engine != nullptr); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + std::uniform_int_distribution generator(min, max); + for (IntT& value : literal->data()) { + value = generator(*engine); + } +} + +// Same as MakeFakeLiteralInternal but generates random numbers in the given +// range [min, max]. Currently this works only for INT types. +StatusOr MakeFakeLiteralInternalWithBounds(const Shape& shape, + std::minstd_rand0* engine, + int64 min, int64 max) { + if (shape.IsTuple()) { + std::vector elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN( + Literal element, + MakeFakeLiteralInternalWithBounds(element_shape, engine, min, max)); + elements.push_back(std::move(element)); + } + return LiteralUtil::MakeTupleOwned(std::move(elements)); + } + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } + Literal literal(shape); + switch (shape.element_type()) { + case S8: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U8: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S16: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U16: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S32: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U32: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S64: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), 
static_cast(max)); + break; + case U64: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + default: + return Unimplemented( + "Unsupported type for fake random literal generation with bounds: %s", + ShapeUtil::HumanString(shape)); + } + return std::move(literal); +} + enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. @@ -297,6 +401,10 @@ std::vector FindConstrainedUses( if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) || (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) { constrained_uses.push_back(instruction); + } else if ((opcode == HloOpcode::kGather || + opcode == HloOpcode::kScatter) && + op_num == 1) { + constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kFusion) { const HloInstruction* const to_analyze = instruction->fused_parameter(op_num); @@ -356,6 +464,22 @@ StatusOr CreateLiteralForConstrainedUses( } break; } + case HloOpcode::kGather: + case HloOpcode::kScatter: { + const Shape& operand_shape = use->operand(0)->shape(); + if (use->operand(1) == ¶m) { + auto index_map = + use->opcode() == HloOpcode::kGather + ? 
use->gather_dimension_numbers().start_index_map() + : use->scatter_dimension_numbers() + .scatter_dims_to_operand_dims(); + for (const auto dim_in_operand : index_map) { + index_bound = + std::min(index_bound, operand_shape.dimensions(dim_in_operand)); + } + } + break; + } case HloOpcode::kReduce: case HloOpcode::kReduceWindow: needs_constant = true; @@ -385,8 +509,8 @@ StatusOr CreateLiteralForConstrainedUses( return Unimplemented("Conflicting operand generation constraints."); } if (index_bound != INT64_MAX) { - return MakeRandomIndex(index_bound, engine) - .Reshape(param.shape().dimensions()); + return MakeFakeLiteralInternalWithBounds(param.shape(), engine, -1, + index_bound); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 591d6c19228a313f530cdae18f4be37e7b517601..4337aa4bf9a071ecad268e1c4233fa705c6ed491 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -92,12 +92,13 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 5); - EXPECT_EQ(args[0].Get({}), 0); + EXPECT_GE(args[0].Get({}), -1); + EXPECT_LE(args[0].Get({}), 1); - EXPECT_GE(args[1].Get({}), 0); - EXPECT_LE(args[0].Get({}), 2); + EXPECT_GE(args[1].Get({}), -1); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(args[2].Get({}), 0); + EXPECT_GE(args[2].Get({}), -1); EXPECT_LE(args[2].Get({}), 3); } @@ -122,12 +123,13 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 7); - EXPECT_EQ(args[0].Get({}), 0); + EXPECT_GE(args[0].Get({}), -1); + EXPECT_LE(args[0].Get({}), 1); - EXPECT_GE(args[1].Get({}), 0); - EXPECT_LE(args[0].Get({}), 2); + EXPECT_GE(args[1].Get({}), -1); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(args[2].Get({}), 0); + 
EXPECT_GE(args[2].Get({}), -1); EXPECT_LE(args[2].Get({}), 3); } @@ -136,10 +138,18 @@ XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { auto module = ParseHloString(R"( HloModule sort.148.1589 +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) { %parameter.0 = f32[1048576]{0} parameter(0) %parameter.1 = s32[1048576]{0} parameter(1) - ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} + ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}, to_apply=compare } )") .ValueOrDie(); @@ -159,10 +169,18 @@ XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) { auto module = ParseHloString(R"( HloModule sort.148.1589 +compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) { %parameter.0 = s32[1048576]{0} parameter(0) %parameter.1 = s32[1048576]{0} parameter(1) - ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} + ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}, to_apply=compare } )") .ValueOrDie(); @@ -182,10 +200,18 @@ XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) { auto module = ParseHloString(R"( HloModule sort, is_scheduled=true +compare { + p.0.lhs = bf16[] parameter(0) + p.0.rhs = bf16[] parameter(1) + p.1.lhs = s32[] parameter(2) + 
p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) { %parameter.0 = bf16[2,1452]{1,0} parameter(0) %parameter.1 = s32[2,1452]{1,0} parameter(1) - ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1} + ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1}, to_apply=compare } )") .ValueOrDie(); @@ -228,5 +254,77 @@ ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { << ShapeUtil::HumanString(args[1].shape()); } +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) { + auto module = ParseHloString(R"( + HloModule Test + +ENTRY %module(paramater.0: f32[200,100,300], parameter.1: s32[10,2]) -> + f32[10,300] { + %parameter.0 = f32[200,100,300] parameter(0) + %parameter.1 = s32[10,2] parameter(1) + ROOT gather = f32[10,300] gather(f32[200,100,300] %parameter.0, + s32[10,2] %parameter.1), + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, + index_vector_dim=1, + slice_sizes={1,1,300} +} +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + + const Shape& indices_shape = args[1].shape(); + EXPECT_TRUE( + ShapeUtil::Equal(indices_shape, ShapeUtil::MakeShape(S32, {10, 2}))) + << ShapeUtil::HumanString(indices_shape); + auto indices = args[1].data(); + for (const auto index : indices) { + EXPECT_GE(index, -1); + EXPECT_LE(index, 100); + } +} + +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) { + auto module = ParseHloString(R"( + HloModule Test + +scatter_update (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + ROOT rhs = f32[] parameter(1) +} + +ENTRY main { + operand = f32[200,100,300] parameter(0) + indices = s32[10,2] parameter(1) + 
updates = f32[10,300] parameter(2) + ROOT scatter = f32[200,100,300] scatter(operand, indices, updates), + to_apply=scatter_update, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 + } +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + + const Shape& indices_shape = args[1].shape(); + EXPECT_TRUE( + ShapeUtil::Equal(indices_shape, ShapeUtil::MakeShape(S32, {10, 2}))) + << ShapeUtil::HumanString(indices_shape); + auto indices = args[1].data(); + for (const auto index : indices) { + EXPECT_GE(index, -1); + EXPECT_LE(index, 100); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index b77cf38ed8e29973985406015c0a3936916ad5e6..38a2a9b8fba280ed17f6c26688fba94b4ec6ff0c 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -129,7 +129,7 @@ HloModule TokenInWhileLoop %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %TokenInWhileLoop () -> s32[] { diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/tests/triangular_solve_test.cc similarity index 77% rename from tensorflow/compiler/xla/client/lib/triangular_solve_test.cc rename to tensorflow/compiler/xla/tests/triangular_solve_test.cc index 284a2e9d183a6a7923fb59ac134ce3b3a3a96e35..24ab12136ff396bd9ac37bb058311b0d2d6f2515 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/tests/triangular_solve_test.cc @@ -13,8 +13,6 @@ See the License for 
the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" - #include #include #include @@ -54,6 +52,20 @@ Array2D AValsUpper() { {kNan, kNan, kNan, 11}}; } +Array2D AValsLowerUnitDiagonal() { + return {{kNan, kNan, kNan, kNan}, + {3, kNan, kNan, kNan}, + {4, 7, kNan, kNan}, + {5, 8, 10, kNan}}; +} + +Array2D AValsUpperUnitDiagonal() { + return {{kNan, 3, 4, 5}, + {kNan, kNan, 7, 8}, + {kNan, kNan, kNan, 10}, + {kNan, kNan, kNan, kNan}}; +} + Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } @@ -96,8 +108,8 @@ XLA_TEST_F(TriangularSolveTest, EmptyArrays) { CreateR2Parameter(Array2D(0, 10), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); ComputeAndCompareR2(&builder, Array2D(0, 10), {a_data.get(), b_data.get()}); @@ -111,8 +123,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -132,8 +144,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -153,8 +165,8 @@ 
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -174,8 +186,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -195,8 +207,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -217,8 +229,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {0.5, 1.0, 1.5}, @@ -231,6 +243,25 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { ErrorSpec(1e-2, 1e-2)); } +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(AValsLowerUnitDiagonal(), 0, "a", &builder, &a); + auto b_data = 
CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*unit_diagonal=*/true, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected( + {{1., 2., 3.}, {1., -1., -3.}, {-4., 7., 18.}, {37., -61., -159.}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { XlaBuilder builder(TestName()); @@ -239,8 +270,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/3); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {0.5, 1.0, 1.5}, @@ -261,8 +292,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {0.5, 1.0, 1.5}, @@ -283,8 +314,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -297,6 +328,27 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { ErrorSpec(1e-2, 1e-2)); } +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + 
CreateR2Parameter(AValsUpperUnitDiagonal(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*unit_diagonal=*/true, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected({{-1402., -1538., -1674.}, + {575., 631., 687.}, + {-93., -102., -111.}, + {10., 11., 12.}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { XlaBuilder builder(TestName()); @@ -307,8 +359,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/true, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::ADJOINT); Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), @@ -333,8 +385,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {0.5, 1., 1.5}, @@ -368,11 +420,12 @@ XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) { XlaOp a, b; auto a_data = CreateR3Parameter(avals, 0, "a", &builder, &a); auto b_data = CreateR3Parameter(bvals, 1, "b", &builder, &b); - BatchDot(ConstantR3FromArray3D(&builder, avals), - TriangularSolve(a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2)); + BatchDot( + ConstantR3FromArray3D(&builder, avals), + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*unit_diagonal=*/false, + 
/*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE)); ComputeAndCompareR3(&builder, bvals, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); @@ -382,7 +435,7 @@ struct TriangularSolveTestSpec { int m, n; // A is mxm, B is mxn bool left_side; bool lower; - bool transpose_a; + TriangularSolveOptions::Transpose transpose_a; }; class TriangularSolveParametricTest @@ -408,11 +461,11 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) { XlaOp a, b; auto a_data = CreateR2Parameter(avals, 0, "a", &builder, &a); auto b_data = CreateR2Parameter(bvals, 1, "b", &builder, &b); - auto x = TriangularSolve(a, b, spec.left_side, spec.lower, spec.transpose_a, - /*conjugate_a=*/false, - /*block_size=*/3); + auto x = TriangularSolve(a, b, spec.left_side, spec.lower, + /*unit_diagonal=*/false, spec.transpose_a); auto a_tri = Triangle(a, spec.lower); - a_tri = MaybeTransposeInMinorDims(a_tri, spec.transpose_a); + a_tri = MaybeTransposeInMinorDims( + a_tri, spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE); if (spec.left_side) { BatchDot(a_tri, x); } else { @@ -429,7 +482,9 @@ std::vector TriangularSolveTests() { for (int n : {5, 10}) { for (bool left_side : {false, true}) { for (bool lower : {false, true}) { - for (bool transpose_a : {false, true}) { + for (TriangularSolveOptions::Transpose transpose_a : + {TriangularSolveOptions::NO_TRANSPOSE, + TriangularSolveOptions::TRANSPOSE}) { specs.push_back({m, n, left_side, lower, transpose_a}); } } diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 0def0577e09609ecb82a635ba2e1681a1cc27b1c..cdf2c34fcc3cc005e84626c39c8ab301a9040529 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -525,9 +525,7 @@ XLA_TEST_F(TupleHloTest, BitcastAfterGTE) { ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy) } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = 
ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -559,9 +557,7 @@ XLA_TEST_F(TupleHloTest, ROOT outfeed = token[] outfeed(tuple, token0) } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param0 = LiteralUtil::CreateR1({1, 2}); auto param1 = LiteralUtil::CreateR1({2, 3}); auto param4 = LiteralUtil::CreateR0(false); diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 4fbd7f2fb174ac899c1e3b23801986cb52db96a2..c51f30f3b5db95962a719ec226dd03f41142a782 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -64,7 +64,9 @@ class UnaryOpTest : public ClientLibraryTestBase { &builder, {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); Sign(arg); - ComputeAndCompareR1(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); + ComputeAndCompareR1( + &builder, + {-1, 1, static_cast(+0.0), static_cast(-0.0), -1, 1, -1}, {}); } template diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index c7337e8caae8f2ee25f4b25dc22439e08d2ecc25..7b7b8f5d02dc99607b30f898e18c5b448d421e07 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -40,8 +40,6 @@ limitations under the License. 
namespace xla { namespace { -namespace gtl = ::tensorflow::gtl; - class HloProfileTest : public ClientLibraryTestBase {}; struct ParsedProfileOutputLine { diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 7289ae7df65e56652eeeb67e536e4c721d97d999..fc7949d889dc8ed9fac425982cc555a6c42a7f1d 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 4412a6ec69ca3bb98f7b67e68802b262deaf9b8c..e9244ecf9f149ed439dc9beeb56a9442ebad6821 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -177,26 +177,6 @@ tf_cc_binary( ], ) -tf_cc_binary( - name = "dumped_computation_to_tf_graphdef", - srcs = ["dumped_computation_to_tf_graphdef.cc"], - deps = [ - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:hlo_graph_dumper", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - tf_cc_binary( name = "hlo_proto_to_json", srcs = ["hlo_proto_to_json.cc"], @@ -251,7 +231,14 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:platform_util", 
"//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], ) + +sh_test( + name = "interactive_graphviz_build_only_test", + srcs = ["interactive_graphviz_test.sh"], + data = [":interactive_graphviz"], +) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 4375e7c138c9e8d193feaa7a39d63946c4ea3086..df2d3d18b9ff86c0dd2047c2415527aeb1c1f154 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 723569862c7550387e95003e3a673743464b67b8..35bb82ca22f46d2cdeaac3b9a87b253efe9a07d9 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc deleted file mode 100644 index f8bb9a6b1e217fc4e6e15c8a3302be61ed339c82..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Usage: dumped_computation_to_tf_graph some_binary_snapshot_proto* -// -// Dumps a tensorflow GraphDef in text format for a snapshot computation. The -// dumped graph is an HLO computation with HLO instructions as nodes and can be -// visualized on Tensorboard. Upload the dumped files on Tensorboard. -// -// some_binary_snapshot_proto is obtained by serializing the SessionModule from -// ServiceInterface::SnapshotComputation to disk. 
- -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" - -using tensorflow::Env; - -namespace xla { -namespace tools { - -void RealMain(absl::Span args) { - Client* client = ClientLibrary::LocalClientOrDie(); - for (char* arg : args) { - HloSnapshot module; - TF_CHECK_OK( - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - XlaComputation computation = - client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - debug_options.set_xla_generate_hlo_graph(".*"); - debug_options.set_xla_hlo_dump_as_graphdef(true); - ComputationStats stats = - client->GetComputationStats(computation, debug_options) - .ConsumeValueOrDie(); - fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); - } -} - -} // namespace tools -} // namespace xla - -int main(int argc, char** argv) { - std::vector flag_list; - xla::AppendDebugOptionsFlags(&flag_list); - xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_result) { - LOG(ERROR) << "\n" << usage; - return 2; - } - - tensorflow::port::InitMain(argv[0], &argc, &argv); - - absl::Span args(argv, argc); - args.remove_prefix(1); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); - return 0; -} diff --git 
a/tensorflow/compiler/xla/tools/hlo_extractor_test.cc b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc index c187222a11ee721b006194a68620c58749707193..4beb099b330cadf4540944979f38681bae07103c 100644 --- a/tensorflow/compiler/xla/tools/hlo_extractor_test.cc +++ b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc @@ -36,9 +36,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); { auto extracted_module = @@ -75,9 +74,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); { auto extracted_module = @@ -120,9 +118,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); { auto extracted_module = diff --git a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc index 8460ae3e4991ee091af72d2553a8491f627c722e..88f3a8bdde244bb16d54d13f1022e9b4be1ef893 100644 --- a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc +++ b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc @@ -19,7 +19,9 @@ limitations under the License. // // Reads one serilized Hlo module, convert it into JSON format and dump into // some output directory. some_binaray_proto is obtained by serializing Hlo -// module to disk using --xla_dump_optimized_hlo_proto_to debug option. 
+// module to disk using the debug options +// +// --xla_dump_to=DIR --xla_dump_hlo_as_proto #include #include diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc index 6c90cde5a75a93837ee149fd9b5a60e6413c2ac4..5652d303f0256713a05331cfbc1a8c4f0009c3fe 100644 --- a/tensorflow/compiler/xla/tools/interactive_graphviz.cc +++ b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,8 +29,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "absl/strings/string_view_utils.h" -#include "absl/strings/util.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -39,6 +38,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/tools/hlo_extractor.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/subprocess.h" @@ -56,7 +57,8 @@ bool ReadLine(const char *prompt, string *line) { return util::ReadLine(prompt, line); #else std::cout << prompt; - return std::getline(std::cin, *line); + std::getline(std::cin, *line); + return std::cin.good(); #endif } @@ -139,9 +141,10 @@ HloComputation* FindComputation(const HloModule& module, // Print a help message describing the various available commands. void DoHelpCommand() { std::cout << R"(Commands: - [] - Renders a neighborhood of nodes around . If - is not provided, the default value is )" + [] [/ +] + Renders a neighborhood of nodes around , without going + beyond the optional boundary instructions. If is not provided, + the default value is )" << kDefaultWidth << R"(. allpaths [] Renders a subset of all paths from one instruction to the other. Either @@ -387,22 +390,18 @@ bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) { return false; } -void DisplayGraphHandle(const Options &opts, const string& handle) { - std::cout << handle << std::endl; +void OpenUrl(const Options& opts, absl::string_view url) { + std::cout << url << std::endl; // If it is a url, try to open it up in the user's browser too. - if (strings::StartsWithIgnoreCase(handle, "http://") || - strings::StartsWithIgnoreCase(handle, "https://") || - strings::StartsWithIgnoreCase(handle, "file://")) { + if (absl::StartsWithIgnoreCase(url, "http://") || + absl::StartsWithIgnoreCase(url, "https://") || + absl::StartsWithIgnoreCase(url, "file://")) { const char* browser_bin = opts.browser.empty() ? 
"/usr/bin/sensible-browser" : opts.browser.c_str(); tensorflow::SubProcess p; - p.SetProgram(browser_bin, {browser_bin, handle}); + p.SetProgram(browser_bin, {browser_bin, string(url)}); p.Start(); - } else if (handle.empty()) { - std::cerr << "Unable to render graph, perhaps due to graphviz server " - "timeout. Run with --logtostderr to see." - << std::endl; } else { std::cerr << "\nExpected a URL, but got strange graph result (dumped " "above). If this isn't what you expected, maybe file a bug?" @@ -410,6 +409,65 @@ void DisplayGraphHandle(const Options &opts, const string& handle) { } } +// Renders a graph by calling `renderer`, and then tries to open it. +// +// `renderer` is a callback so we can try various formats. In particular, the +// URL format doesn't work out of the box; it requires you to register a plugin. +void RenderAndDisplayGraph( + const Options& opts, + const std::function(RenderedGraphFormat)>& renderer) { + StatusOr url_result = renderer(RenderedGraphFormat::kUrl); + if (url_result.ok()) { + string url = url_result.ValueOrDie(); + OpenUrl(opts, url); + return; + } + + // Ignore UNAVAILABLE errors; these are expected when there's no URL renderer + // plugin registered. + if (url_result.status().code() != tensorflow::error::UNAVAILABLE) { + std::cerr << "Unable to render graph as URL: " << url_result.status() + << std::endl; + std::cerr << "Trying as HTML..." << std::endl; + } + + auto* env = tensorflow::Env::Default(); + StatusOr html_result = renderer(RenderedGraphFormat::kHtml); + if (!html_result.ok()) { + std::cerr << "Failed to render graph as HTML: " << html_result.status() + << std::endl; + return; + } + + std::vector temp_dirs; + env->GetLocalTempDirectories(&temp_dirs); + if (temp_dirs.empty()) { + std::cerr << "Can't render graph as HTML because we can't find a suitable " + "temp directory. Try setting $TMPDIR?" + << std::endl; + return; + } + + // Try to create a unique file inside of temp_dirs.front(). 
Notably, this + // file's name must end with ".html", otherwise web browsers will treat it as + // plain text, so we can't use Env::CreateUniqueFileName(). + string temp_file_path = tensorflow::io::JoinPath( + temp_dirs.front(), + absl::StrFormat("interactive_graphviz.%d.html", env->NowMicros())); + auto status = tensorflow::WriteStringToFile( + env, temp_file_path, std::move(html_result).ValueOrDie()); + if (status.ok()) { + OpenUrl(opts, absl::StrCat("file://", temp_file_path)); + return; + } + + std::cerr << "Failed to write rendered HTML graph to " << temp_file_path + << ": " << status; + + // We don't bother trying kDot, because kHTML should always work (or if it + // doesn't, we don't have any reason to believe kDot will work better). +} + void DoAllPathsCommand(const Options& opts, const HloModule& module, const std::vector& tokens) { if (tokens.size() > 4) { @@ -450,19 +508,15 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module, std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2]; return; } - DisplayGraphHandle(opts, hlo_graph_dumper::DumpAllPathsFromTo( - *from, *to, max_nodes, /*show_backend_config=*/show_backend_config)); + RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { + return RenderAllPathsFromTo(*from, *to, max_nodes, format, + /*show_backend_config=*/show_backend_config); + }); } // Plot a given instruction neighborhood or computation with graphviz. void DoPlotCommand(const Options& opts, const HloModule& module, const std::vector& tokens) { - if (tokens.size() > 2) { - std::cerr << R"(Illegal input. Enter e.g. "%fusion.1 42" or "%fusion.1".)" - << std::endl; - return; - } - string node_name = tokens[0]; // Find the node with the given name. 
@@ -475,28 +529,62 @@ void DoPlotCommand(const Options& opts, const HloModule& module, } uint64 graph_width = kDefaultWidth; - if (tokens.size() == 2) { + absl::flat_hash_set boundary; + if (tokens.size() >= 2) { if (comp) { std::cerr << "Can only use graph-size parameter with instructions, but " << node_name << " is a computation." << std::endl; return; } - if (!absl::SimpleAtoi(tokens[1], &graph_width)) { - std::cerr << "Can't parse '" << tokens[1] << "' as an integer." - << std::endl; - return; + + int bound_index = 1; + // Get the if present. + if (absl::SimpleAtoi(tokens[bound_index], &graph_width)) { + bound_index++; + } else { + // not found, need to reset graph_width. + graph_width = kDefaultWidth; + } + // Get the '/'. + if (bound_index < tokens.size()) { + // This token must be a '/'. + if (tokens[bound_index] != "/") { + std::cerr << "Expect a /, but get a '" << tokens[bound_index] << "'." + << std::endl; + return; + } + bound_index++; + } + // Get the boundary nodes. + while (bound_index < tokens.size()) { + string bnode_name = tokens[bound_index]; + const HloInstruction* binstr = FindInstruction(module, bnode_name); + if (!binstr) { + std::cerr << "Couldn't find HloInstruction named " << bnode_name << "." + << std::endl; + return; + } + boundary.insert(binstr); + bound_index++; } } // Generate the graph and print the resulting string, which should be a // graphviz url. 
if (comp) { - DisplayGraphHandle(opts, hlo_graph_dumper::DumpGraph( - *comp, "", comp->parent()->config().debug_options(), nullptr, - /*show_backend_config=*/show_backend_config)); + RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { + return RenderGraph(*comp, /*label=*/"", + comp->parent()->config().debug_options(), format, + /*hlo_execution_profile=*/nullptr, + /*show_backend_config=*/show_backend_config); + }); } else { - DisplayGraphHandle(opts, hlo_graph_dumper::DumpNeighborhoodAround( - *instr, graph_width, /*show_backend_config=*/show_backend_config)); + RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { + return RenderNeighborhoodAround( + *instr, graph_width, format, + /*show_backend_config=*/show_backend_config, + /*boundary=*/boundary); + }); } } @@ -515,7 +603,7 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) { << std::endl; continue; } - std::vector tokens = strings::Split(line, ' '); + std::vector tokens = absl::StrSplit(line, ' ', absl::SkipEmpty()); if (tokens[0] == "quit" || tokens[0] == "exit") { break; } else if (tokens[0] == "help") { diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh b/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..b3e43aa7da062547fb5f187b885e997fc44bbb65 --- /dev/null +++ b/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh @@ -0,0 +1,19 @@ +#! /bin/bash +# /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ==============================================================================*/ + +# This is a placeholder for a compile-only test for intractive_graphviz tool. + +exit 0 diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index c01a47b510c0e4252e350960b995643b39b70d4a..d66561315b4ad7a5e3f1f7b1bc1e557b71da6705 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -90,8 +90,8 @@ struct Options { int num_runs = 1; }; -std::unique_ptr CompileExecutable(const HloSnapshot& module, - LocalClient* client) { +StatusOr> CompileExecutable( + const HloSnapshot& module, LocalClient* client) { XlaComputation computation(module.hlo().hlo_module()); std::vector argument_layouts; argument_layouts.reserve( @@ -102,9 +102,9 @@ std::unique_ptr CompileExecutable(const HloSnapshot& module, argument_layouts.push_back(Shape(param)); argument_layout_ptrs.push_back(&argument_layouts.back()); } - return client - ->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions()) - .ValueOrDie(); + ExecutableBuildOptions exec_build_options; + *exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags(); + return client->Compile(computation, argument_layout_ptrs, exec_build_options); } absl::optional GetXfeedShape(bool is_infeed, @@ -329,7 +329,10 @@ StatusOr ParseInputFile(const string& filename, fprintf(stderr, "%s: is not HloProto. 
Trying HLO text.\n", filename.c_str()); string contents; TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); - StatusOr> module = ParseHloString(contents); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsFromFlags()); + StatusOr> module = + ParseHloString(contents, config); if (module.ok()) { *snapshot.mutable_hlo()->mutable_hlo_module() = module.ValueOrDie()->ToProto(); @@ -357,7 +360,7 @@ int RealMain(absl::Span args, const Options& opts) { // Compile all the modules in parallel. LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel."; - std::vector> executables; + std::vector>> executables; { // ThreadPool CHECK-fails if we give it 0 threads. tensorflow::thread::ThreadPool thread_pool( @@ -374,7 +377,12 @@ int RealMain(absl::Span args, const Options& opts) { LOG(INFO) << "Done compiling; now running the modules."; for (int64 i = 0; i < executables.size(); ++i) { - LocalExecutable* executable = executables[i].get(); + if (!executables[i].ok()) { + LOG(ERROR) << "Compilation failed: " << executables[i].status(); + exit_status = EXIT_FAILURE; + continue; + } + LocalExecutable* executable = executables[i].ValueOrDie().get(); LOG(ERROR) << "Running iteration " << i; StatusOr result_status = ReplayComputation(snapshots[i], executable, client, opts); diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index cdf306dfd1027cf6022c5d8ae844b4308f580e8d..b80d0db8d812380d8144713109d1c05168713c77 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 34b73b5206fa20d6dff7567afd78fd89897c8c33..bb8bbf57c4252b16836553334901a3c896a17f39 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -80,13 +81,9 @@ bool IsPermutation(absl::Span permutation, int64 rank) { if (rank != permutation.size()) { return false; } - std::vector output(permutation.size(), -1); - for (auto index : permutation) { - CHECK_GE(index, 0); - CHECK_LT(index, rank); - output[index] = 0; - } - return !absl::c_linear_search(output, -1); + absl::InlinedVector trivial_permutation(rank); + absl::c_iota(trivial_permutation, 0); + return absl::c_is_permutation(permutation, trivial_permutation); } std::vector InversePermutation( diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index f2fd17dc99455a921bf875aad2a3661b4d456823..1754ae0e44f3420bf7eb7cfb3b558dd476b31455 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -260,6 +260,16 @@ Status Unavailable(const absl::FormatSpec& format, return WithLogBacktrace( tensorflow::errors::Unavailable(absl::StrFormat(format, args...))); } +template +Status Unknown(const absl::FormatSpec& format, const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Unknown(absl::StrFormat(format, args...))); +} +template +Status Internal(const absl::FormatSpec& format, const Args&... 
args) { + return WithLogBacktrace( + tensorflow::errors::Internal(absl::StrFormat(format, args...))); +} template Status InvalidArgumentStrCat(Args&&... concat) { diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index e001cc35f9fcea2783b3952e825838af6bbece72..f2e183110393c359c421031417117b79976bdab4 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -204,6 +204,14 @@ bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) { window_dim.padding_low() == 0 && window_dim.padding_high() == 0; } +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} + int64 DilatedBound(int64 bound, int64 dilation) { CHECK_GE(bound, 0); CHECK_GE(dilation, 1); diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 099d7ecdd5c732ffc8c6ff6370288a2fc4144fa2..e7099285c340523c7d4e6240c7b039fd39443100 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -62,6 +62,10 @@ bool AllOrNoneReversed(const Window& window); // has window bound 1, no striding and no padding. bool IsInactiveWindowDimension(const Window& window, int64 logical_dim); +// Returns true if the provided window dimension is trivial (inactive and has no +// dilation) +bool IsTrivialWindowDimension(const WindowDimension& window_dimension); + // Returns the new bound after dilation. 
// // If a window with the given bound in some dimension is dilated with the given diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 1439f1bcc5cec39203a7cb4b1f8604e7349382c6..cda2d7c7c6b2403868f6d01a485753fa29a8d95f 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -1,30 +1,47 @@ """Wrapper around cc_proto_library used inside the XLA codebase.""" -load("//tensorflow/core:platform/default/build_config.bzl", - "cc_proto_library") -load("//tensorflow/core:platform/default/build_config_root.bzl", - "if_static") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "cc_proto_library", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", +) +load("//tensorflow:tensorflow.bzl", "if_cuda_is_configured") # xla_proto_library() is a convenience wrapper around cc_proto_library. -def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0, **kwargs): - if kwargs.get('use_grpc_plugin'): - kwargs['use_grpc_namespace'] = True - cc_proto_library(name=name, - srcs=srcs, - deps=deps, - cc_libs = if_static( - ["@protobuf_archive//:protobuf"], - otherwise=["@protobuf_archive//:protobuf_headers"], - ), - protoc="@protobuf_archive//:protoc", - testonly=testonly, - visibility=visibility, - **kwargs) +def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = 0, **kwargs): + if kwargs.get("use_grpc_plugin"): + kwargs["use_grpc_namespace"] = True + cc_proto_library( + name = name, + srcs = srcs, + # Append well-known proto dep. As far as I know this is the only way + # for xla_proto_library to access google.protobuf.{Any,Duration,...}. 
+ deps = deps + ["@protobuf_archive//:cc_wkt_protos"], + cc_libs = if_static( + ["@protobuf_archive//:protobuf"], + otherwise = ["@protobuf_archive//:protobuf_headers"], + ), + protoc = "@protobuf_archive//:protoc", + testonly = testonly, + visibility = visibility, + **kwargs + ) -def xla_py_grpc_library(**kwargs): - # Note: we don't currently define any special targets for Python GRPC in OSS. - _ignore = kwargs - pass +def xla_py_proto_library(**kwargs): + # Note: we don't currently define a proto library target for Python in OSS. + _ignore = kwargs + pass +def xla_py_grpc_library(**kwargs): + # Note: we don't currently define any special targets for Python GRPC in OSS. + _ignore = kwargs + pass ORC_JIT_MEMORY_MAPPER_TARGETS = [] + +# We link the GPU plugin into the XLA Python extension if CUDA is enabled. +def xla_python_default_plugins(): + return if_cuda_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"]) diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index e2d7b6ef4666c533951960fd3dcf6869ec2b52c5..6155f3698ebaa1f4ebdd4e5eeb062181169de71c 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -15,11 +15,11 @@ limitations under the License. syntax = "proto3"; -import "tensorflow/compiler/xla/xla_data.proto"; -import "tensorflow/compiler/xla/service/hlo.proto"; - package xla; +import "tensorflow/compiler/xla/service/hlo.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + // Options for the HLO insert-reduce-precision-operations pass. message HloReducePrecisionOptions { // Where and when the reduce-precision operations will be added. @@ -61,41 +61,12 @@ message HloReducePrecisionOptions { // Debugging options for XLA. These options may change at any time - there are // no guarantees about backward or forward compatibility for these fields. 
message DebugOptions { - // HLO modules matching this regex will be dumped to a .dot file throughout - // various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to - // dump *all* HLO modules. - string xla_generate_hlo_graph = 1; - // Show addresses of HLO ops in graph dump. bool xla_hlo_graph_addresses = 2; - // Path to dump HLO graphs to. - string xla_hlo_graph_path = 4; - - // Dump HLO graphs as TensorFlow GraphDefs. - bool xla_hlo_dump_as_graphdef = 5; - - // HLO modules matching this regex will be dumped to LOG(INFO). Set to ".*" to - // dump *all* HLO modules. - string xla_log_hlo_text = 6; - - // Dump all HLO modules as text into the provided directory path. - string xla_generate_hlo_text_to = 7; - - // Dump Hlo after all hlo passes are executed as proto binary into this - // directory. - string xla_dump_optimized_hlo_proto_to = 8; - // Instrument the computation to collect per-HLO cycle counts. bool xla_hlo_profile = 9; - // Dumps computations that XLA executes into the provided directory path. - string xla_dump_computations_to = 10; - - // Dumps parameters and results of computations that XLA executes into the - // provided directory path. - string xla_dump_executions_to = 11; - // List of HLO passes to disable. These names must exactly match the pass // names as specified by the HloPassInterface::name() method. repeated string xla_disable_hlo_passes = 30; @@ -115,9 +86,6 @@ message DebugOptions { // Embed the compiler IR as a string in the executable. bool xla_embed_ir_in_executable = 33; - // Dump the compiler IR into this directory as individual files. - string xla_dump_ir_to = 34; - // Eliminate implicit broadcasts when lowering user computations to HLO // instructions; use explicit broadcast instead. bool xla_eliminate_hlo_implicit_broadcast = 35; @@ -171,22 +139,12 @@ message DebugOptions { // HLO graph. 
bool xla_hlo_graph_sharding_color = 92; - // Prefix the name scopes of the TF graph exports with "devX" device - // assignments, if available. - bool xla_hlo_tfgraph_device_scopes = 93; + reserved 93; // Was xla_hlo_tfgraph_device_scopes // If true, the GPU backend is free to use cudnn for HLO batch normalization // ops. bool xla_gpu_use_cudnn_batchnorm = 94; - // Dump HLO before any hlo passes are executed as proto binary into this - // directory. - string xla_dump_unoptimized_hlo_proto_to = 95; - - // Dump HLO after each pass as an HloProto in binary file format into this - // directory. - string xla_dump_per_pass_hlo_proto_to = 96; - // Generate calls to MKL-DNN in the CPU backend. bool xla_cpu_use_mkl_dnn = 97; @@ -198,10 +156,21 @@ message DebugOptions { // // - Reducing the precision of operations (e.g. using an approximate sin // function, or transforming x/y into x * (1/y)). - // - Assuming that operations never produce or consume NaN or +/- Inf. + // - Assuming that operations never produce or consume NaN or +/- Inf (this + // behavior can be adjusted using xla_cpu_fast_math_honor_{nans|infs}). // - Assuming that +0 and -0 are indistinguishable. bool xla_cpu_enable_fast_math = 99; + // When xla_cpu_enable_fast_math is true then this controls whether we allow + // operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is + // false. + bool xla_cpu_fast_math_honor_nans = 120; + + // When xla_cpu_enable_fast_math is true then this controls whether we allow + // operations to produce infinities. Ignored when xla_cpu_enable_fast_math is + // false. + bool xla_cpu_fast_math_honor_infs = 121; + // When true we lower the Minimum and Maximum hlos in the GPU backend such // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag // this is true we don't propagate NaNs through Min and Max. @@ -224,9 +193,6 @@ message DebugOptions { // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3).
bool xla_gpu_disable_ptxas_optimizations = 103; - // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) - bool xla_hlo_dump_as_html = 105; - // Enable fast math with eigen in the HLO evaluator. bool xla_hlo_evaluator_use_fast_path = 106; @@ -234,11 +200,75 @@ message DebugOptions { // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing. bool xla_allow_scalar_index_dynamic_ops = 107; - // Next id: 108 + enum StepMarkerLocation { + // Generate step mark at program entry (the default). This handles the case + // where each step is done by one or multiple program executions. Only the + // first program will be tagged for generating step mark at program entry. + STEP_MARK_AT_ENTRY = 0; + // Generate step mark at each iteration of top level while loop, which + // is assumed to be a training loop. + STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1; + // No step mark. + STEP_MARK_NONE = 2; + } + // Option to emit a target-specific marker to indicate the start of a training + // step. The location of the marker (if any) is determined by the option + // value. + StepMarkerLocation xla_step_marker_location = 108; + + // + // BEGIN flags controlling dumping HLO modules for debugging. + // + // When dumping is enabled, HLO modules are dumped at the very beginning and + // end of compilation, and optionally also during the pass pipeline. + // + // In general, if you set one of these flags, we will try to infer reasonable + // defaults for the others. For example: + // + // * Setting --xla_dump_to=/tmp/foo without specifying a format + // with --xla_dump_hlo_as_* will turn on --xla_dump_hlo_as_text. + // + // * Setting --xla_dump_hlo_as_text without specifying --xla_dump_to will + // dump to stdout. + // + + // Directory to dump into. + string xla_dump_to = 109; + + // If specified, will only dump modules which match this regexp.
+ string xla_dump_hlo_module_re = 110; + + // If this flag is specified, will also dump HLO before and after passes that + // match this regular expression. Set to .* to dump before/after all passes. + string xla_dump_hlo_pass_re = 111; + + // Specifies the format that HLO is dumped in. Multiple of these may be + // specified. + bool xla_dump_hlo_as_text = 112; + bool xla_dump_hlo_as_proto = 113; + bool xla_dump_hlo_as_dot = 114; + bool xla_dump_hlo_as_url = 115; + + // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) + bool xla_dump_hlo_as_html = 116; + + // If true, every time an HLO module is run, we will dump an HloSnapshot + // (essentially, a serialized module plus its inputs) to the --xla_dump_to + // directory. + bool xla_dump_hlo_snapshots = 118; + + // + // END flags controlling dumping HLO modules. + // + + // Next id: 122 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; + + reserved 117; // was xla_dump_to + reserved 5; // Was xla_hlo_dump_as_graphdef } // These settings control how XLA compiles and/or runs code. Not all settings @@ -265,6 +295,14 @@ message ExecutionOptions { // computation on. The computation will be partitioned across these devices. // If not provided, the default device will be chosen. repeated DeviceHandle device_handles = 5; + + // Number of replicas of the computation to run. If zero, uses the default + // number of replicas for the XLA service. + int32 num_replicas = 6; + + // This optional field specifies the device assignment if known at compile + // time.
+ DeviceAssignmentProto device_assignment = 7; } message GetDeviceHandlesRequest { @@ -302,8 +340,7 @@ message TransferToInfeedRequest { DeviceHandle device_handle = 3; } -message TransferToInfeedResponse { -} +message TransferToInfeedResponse {} message TransferFromOutfeedRequest { // This optional field directs the service to return the literal in this @@ -322,8 +359,7 @@ message ResetDeviceRequest { DeviceHandle device_handle = 1; } -message ResetDeviceResponse { -} +message ResetDeviceResponse {} message ComputationGraphStatsRequest { HloModuleProto computation = 1; @@ -346,8 +382,7 @@ message UnregisterRequest { repeated GlobalDataHandle data = 1; } -message UnregisterResponse { -} +message UnregisterResponse {} message CompileRequest { // The graph to be compiled. diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index a64e2f5df5cacca05e83f31c941c57abd5ccf4de..6e5772a7396bae1674ec4e7393ba03506c9381e4 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -16,6 +16,7 @@ limitations under the License. syntax = "proto3"; package xla; + option cc_enable_arenas = true; // Primitive types are the individual values that can be held in rectangular @@ -55,7 +56,7 @@ enum PrimitiveType { F64 = 12; // Complex values of fixed width. - C64 = 15; // Paired F32 (real, imag), as in std::complex. + C64 = 15; // Paired F32 (real, imag), as in std::complex. C128 = 18; // Paired F64 (real, imag), as in std::complex. // A tuple is a polymorphic sequence; e.g. a shape that holds different @@ -199,7 +200,7 @@ message ShapeProto { // in this field represents an upper bound on the size of the dimension. repeated int64 dimensions = 3; - // For tuples only, the shapes of constitutent shapes in the tuple sequence. + // For tuples only, the shapes of constituent shapes in the tuple sequence. repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. 
@@ -367,7 +368,7 @@ message LiteralProto { repeated uint64 u64s = 7; repeated float f32s = 8; repeated double f64s = 9; - repeated float c64s = 12; // Stored as interleaved real, imag floats. + repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated double c128s = 18; // Stored as interleaved real, imag doubles. repeated LiteralProto tuple_literals = 10; // The F16s, BF16s, U16s and S16s are encoded in little endian byte order @@ -510,7 +511,7 @@ message ConvolutionDimensionNumbers { repeated int64 output_spatial_dimensions = 12; // Next = 13 -}; +} enum FftType { FFT = 0; // Forward FFT; complex in, complex out. @@ -529,7 +530,7 @@ message DotDimensionNumbers { repeated int64 lhs_batch_dimensions = 3; // The dimension numbers that represent the 'rhs' batch dimensions. repeated int64 rhs_batch_dimensions = 4; -}; +} enum RandomDistribution { RNG_INVALID = 0; @@ -545,6 +546,32 @@ enum RandomDistribution { // Next: 4 } +message TriangularSolveOptions { + // If true, solves ax = b. If false, solves xa = b. + bool left_side = 1; + + // If true, 'a' is lower triangular. If false, 'a' is upper triangular. + bool lower = 2; + + // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed. + bool unit_diagonal = 3; + + // Should we transpose or use the adjoint of 'a'? + enum Transpose { + TRANSPOSE_INVALID = 0; + NO_TRANSPOSE = 1; // Don't transpose 'a'. + TRANSPOSE = 2; // Transpose 'a'. + ADJOINT = 3; // Complex conjugate and transpose 'a'. + }; + Transpose transpose_a = 4; +} + +message CholeskyOptions { + // If true, uses the lower triangle of `a`. If false, uses the upper triangle + // of `a`. + bool lower = 1; +} + message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -604,3 +631,32 @@ message PrecisionConfig { // Next: 2 } + +// Describes whether all data-parallelism replicas will receive the same +// parameter data at each buffer. 
+message ParameterReplication { + // A list of boolean values for the flattened leaf buffers. Each value + // indicates whether the corresponding leaf buffer is replicated. + // + // If this field is empty, it means no buffer is replicated. Otherwise, the + // number of elements in this field must match the number of leaf buffers in + // the HLO instruction's shape. + repeated bool replicated_at_leaf_buffers = 1; +} + +// A backend-config for kWhile loops that stores the loop's trip count, if it is +// known. +// +// This is useful for backends that can implement a `for i in 0..N` loop more +// efficiently than a `while` loop. For example, on GPUs, we can implement a +// `for i in 0..N` loop by enqueueing the kernels for the loop body N times, +// whereas implementing a `while` loop requires a host-device sync on each +// iteration. +message WhileLoopBackendConfig { + message KnownTripCount { + int64 n = 1; + } + // This indirection lets us distinguish between known-trip-count == 0 and + // unknown-trip-count. 
+ KnownTripCount known_trip_count = 1; +} diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 2dae746d034a1bf52e84de74dfb0c6e23aaed4d1..b2718c5c283358d98da175a8d3b21bb1f2b01c75 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -11,9 +11,15 @@ package( load( "//tensorflow:tensorflow.bzl", + "tf_custom_op_py_library", "tf_gen_op_libs", + "tf_gen_op_wrapper_py", ) load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) xla_proto_library( name = "xrt_proto", @@ -27,6 +33,12 @@ xla_proto_library( ], ) +tf_proto_library_py( + name = "xrt_proto", # bzl adds a _py suffix + srcs = ["xrt.proto"], + visibility = ["//visibility:public"], +) + cc_library( name = "xrt_utils", srcs = [ @@ -78,6 +90,25 @@ tf_gen_op_libs( ], ) +tf_gen_op_wrapper_py( + name = "xrt_ops_wrapper_py", + out = "xrt_ops.py", + deps = [ + ":xrt_compile_ops_op_lib", + ":xrt_execute_op_op_lib", + ":xrt_state_ops_op_lib", + ], +) + +tf_custom_op_py_library( + name = "xrt_ops", + kernels = ["//tensorflow/compiler/xrt/kernels:xrt_ops"], + visibility = ["//visibility:public"], + deps = [ + ":xrt_ops_wrapper_py", + ], +) + cc_library( name = "xrt_server", visibility = ["//visibility:public"], diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index dc02fd272fd8700c7f8fa64adf7ab57c88bab706..1e325191bba828e3d5e4599f87dcf4f4d0674945 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -51,7 +51,10 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xrt:xrt_compile_ops_op_lib", + "//tensorflow/compiler/xrt:xrt_execute_op_op_lib", "//tensorflow/compiler/xrt:xrt_proto", + "//tensorflow/compiler/xrt:xrt_state_ops_op_lib", 
"//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 2ee1a6cd1aebcdbd65892b33e5044489070ab5c4..b791519c09758a4f4124c95add5351a9433ecb8f 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -68,9 +68,11 @@ class XRTCompileOp : public OpKernel { Status CompilationCacheKey(const xrt::XLAComputation& computation, string* key) { - string serialized; - TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized)); - uint64 fingerprint = Fingerprint64(serialized); + const size_t size = computation.ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK( + SerializeToBufferDeterministic(computation, serialized.get(), size)); + uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size)); *key = absl::StrCat(fingerprint); return Status::OK(); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 116c193cab65410a5a7c3058f98cc2be2cbe9e67..42ef88168af4b6f391ffc2e69ab4c4000d1cbee1 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/compiler/xrt/xrt_device.h" diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index 6a7f10652533920ba3fa48fba1d5161f7c4d4530..343f43b7159b55bad184eed2cada55c76085ffa0 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -122,6 +122,17 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") + .Device(DEVICE_XLA_GPU) + .HostMemory("handles") + .HostMemory("tensors"), + XRTReadToTensorOp); +REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") + .Device(DEVICE_XLA_CPU) + .HostMemory("handles") + .HostMemory("tensors"), + XRTReadToTensorOp); + REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") .Device(DEVICE_XLA_GPU) .HostMemory("handle"), diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index e2c223b3dbb2311d0f42e1a36e316fd9d5f66040..6af73ecc85351a9b38ba526db076e9176d1cb2f1 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -40,6 +41,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -215,27 +217,29 @@ class XRTAllocateFromTensorOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple)); + std::vector minor_to_major; if (ctx->HasAttr("layouts")) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major)); } OP_REQUIRES( ctx, tf_shapes_.size() == dtypes_.size(), errors::InvalidArgument("shapes and dtypes must be the same length")); std::vector xla_shapes; + xla_shapes.reserve(tf_shapes_.size()); for (int i = 0; i < tf_shapes_.size(); i++) { xla::Shape xla_shape; OP_REQUIRES_OK( ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape)); - xla_shapes.push_back(xla_shape); + xla_shapes.push_back(std::move(xla_shape)); } if (xla_shapes.size() > 1 || make_tuple) { shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes); } else { shape_.Swap(&xla_shapes.front()); } - if (!minor_to_major_.empty()) { + if (!minor_to_major.empty()) { xla::Shape shape_with_layouts; - OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major_, + OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major, /*layout_func=*/nullptr, &shape_with_layouts)); shape_.Swap(&shape_with_layouts); @@ -304,7 +308,6 @@ class XRTAllocateFromTensorOp : public OpKernel { private: std::vector tf_shapes_; DataTypeVector dtypes_; - std::vector minor_to_major_; xla::Shape shape_; }; @@ -487,7 +490,7 @@ class XRTReadLiteralOp : public OpKernel { OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( ctx, allocation->device_ordinal(), &device_ref)); - 
xla::Literal literal; + xla::Literal literal(allocation->on_host_shape()); OP_REQUIRES_OK( ctx, allocation->ToLiteral(device_ref.backend(), device_ref.device_ordinal(), &literal)); @@ -499,6 +502,96 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that reads a device-resident tuple to host memory and returns its leaf +// buffers as tensors. +template +class XRTReadToTensorOp : public OpKernel { + public: + explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); + } + ~XRTReadToTensorOp() override = default; + XRTReadToTensorOp(const XRTReadToTensorOp&) = delete; + XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTReadToTensorOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not + // just scalars.) + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + + if (discard_) { + VLOG(2) << "Releasing handle " << allocation_handle; + OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager( + rm, allocation_handle)); + } + + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + + xla::Shape shape = allocation->on_host_shape(); + int output = 0; + Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus( + &shape, + [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status { + if (subshape->IsTuple()) return Status::OK(); + + xla::PrimitiveType xla_type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType( + ctx->expected_output_dtype(output), &xla_type)); + if (xla_type != subshape->element_type()) { + return errors::InvalidArgument( + "Type mismatch between buffer type (", subshape->ToString(), + ") and tensor type (", + DataTypeString(ctx->expected_output_dtype(output)), + ") for output tensor ", output); + } + + TensorShape output_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape)); + + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + ctx->allocate_output(output, output_shape, &output_tensor)); + + XRTTupleAllocation* sub; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + allocation, index, &sub, /*alias_parent_allocation=*/true)); + core::ScopedUnref sub_unref(sub); + + xla::MutableBorrowingLiteral literal; + TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral( + xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor, + &literal)); + TF_RETURN_IF_ERROR(sub->ToLiteral( + device_ref.backend(), device_ref.device_ordinal(), &literal)); + + ++output; + return Status::OK(); + }); + OP_REQUIRES_OK(ctx, status); + } + bool discard_; + DataTypeVector dtypes_; +}; + // Op that writes a new literal value into device-resident memory. 
template class XRTWriteLiteralOp : public OpKernel { diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 2e743fec4963a52ee1abf64525f26e3d89479670..87546fce4e4e7e38ef934d32ff95a60a4ad5492a 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -53,7 +53,7 @@ The shapes can differ from the corresponding input one, as long as the total number of elements matches. In other words, it is possible to feed an input tensor with shape {8} and have a corresponding shape {2,2,2}. layouts: A vector holding the requested layout in minor-to-major sequence. -If empty, the default layout wil be used. +If empty, the default layout will be used. For a tuple, the layouts vector holds a linearized minor-to-major numbers for all the tuple leaves, in the order they appear within the tuple. The elements within the layouts sequence corresponding to a given tuple @@ -151,6 +151,27 @@ releases the handle. 'literal' is a serialized xla::LiteralProto proto. )"); +REGISTER_OP("XRTReadToTensor") + .Input("handles: int64") + .Attr("release_handles: bool = False") + .Attr("dtypes: list(type)") + .Output("tensors: dtypes") + .SetShapeFn(tensorflow::shape_inference::UnknownShape) + .Doc( + R"( +Copies allocated values from device memory and returns them as zero or more +Tensors. If a handle refers to a non-tuple buffer, a single tensor is returned. +In general, the tensors returned for a handle correspond to an in-order traversal +of the tuple-tree value referenced by the handle. + +'handles' contains ids returned from Ops that produced on-device allocations. +At present, only a single (scalar) handle is supported. +'dtypes' are the expected types for each `Tensor` to be returned. If the +expected and actual tensor types do not match, an error is returned. +'release_handles': if True, `handles` are released. +'tensors' are the output Tensors.
+)"); + REGISTER_OP("XRTReleaseAllocationHandle") .Input("handle: int64") .SetShapeFn(tensorflow::shape_inference::NoOutputs) diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 1e2a9584f88b73d7c92a929e93af60376a59170b..1b3bcbea4c1228944a6604fc923228024e74d700 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" @@ -221,7 +220,7 @@ XRTTupleAllocation::~XRTTupleAllocation() { } Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, - xla::Literal* literal) { + xla::MutableLiteralBase* literal) { auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); @@ -235,9 +234,8 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, " has been released"); } } - TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( - stream.get(), shaped_buffer)); - return Status::OK(); + return transfer_manager->TransferLiteralFromDevice(stream.get(), + shaped_buffer, *literal); } Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index ddf2656e6f51775024a6d1cd0d7a387605faae6f..6519da30d02e41da5a862cadd2133bd8dd8b42d7 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -147,7 +147,7 @@ class XRTTupleAllocation : public ResourceBase { // Copies the allocation from device to host and returns it in literal. 
Status ToLiteral(xla::Backend* backend, int device_ordinal, - xla::Literal* literal); + xla::MutableLiteralBase* literal); // Write a new literal value to the allocation. Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc index 3ef8bedc7324696cd255c72a851f0f2410e03848..8b7749b1919710296bb5b5ec2f7cb43b189830d2 100644 --- a/tensorflow/compiler/xrt/xrt_util.cc +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -55,21 +55,14 @@ xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { return ref_options; } xla::DebugOptions options = xla::GetDebugOptionsFromFlags(); - options.set_xla_generate_hlo_text_to( - SafeDebugPath(ref_options.xla_generate_hlo_text_to())); - options.set_xla_dump_optimized_hlo_proto_to( - SafeDebugPath(ref_options.xla_dump_optimized_hlo_proto_to())); - options.set_xla_dump_computations_to( - SafeDebugPath(ref_options.xla_dump_computations_to())); - options.set_xla_dump_executions_to( - SafeDebugPath(ref_options.xla_dump_executions_to())); + options.set_xla_dump_to(SafeDebugPath(ref_options.xla_dump_to())); + options.set_xla_dump_hlo_as_proto(ref_options.xla_dump_hlo_as_proto()); + options.set_xla_dump_hlo_as_text(ref_options.xla_dump_hlo_as_text()); + options.set_xla_dump_hlo_snapshots(ref_options.xla_dump_hlo_snapshots()); + options.set_xla_dump_hlo_pass_re(ref_options.xla_dump_hlo_pass_re()); for (auto& pass : ref_options.xla_disable_hlo_passes()) { options.add_xla_disable_hlo_passes(pass); } - options.set_xla_dump_unoptimized_hlo_proto_to( - SafeDebugPath(ref_options.xla_dump_unoptimized_hlo_proto_to())); - options.set_xla_dump_per_pass_hlo_proto_to( - SafeDebugPath(ref_options.xla_dump_per_pass_hlo_proto_to())); return options; } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index a4c3d9623adfe3133af0c6ea055586b9544e659b..fb8dc070ad2437f7ece9dd5037089f972b988a38 100644 --- a/tensorflow/contrib/BUILD +++ 
b/tensorflow/contrib/BUILD @@ -59,6 +59,7 @@ py_library( "//tensorflow/contrib/labeled_tensor", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", + "//tensorflow/contrib/learn:head_test_lib", "//tensorflow/contrib/legacy_seq2seq:seq2seq_py", "//tensorflow/contrib/libsvm", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", @@ -218,7 +219,6 @@ cc_library( "//tensorflow/contrib/tensor_forest:stats_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", - "//tensorflow/contrib/tpu:all_ops", ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], @@ -238,7 +238,7 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_ops_op_lib", ], }) + if_not_windows([ - "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", + "//tensorflow/compiler/tf2tensorrt:trt_op_libs", ]) + select({ "//tensorflow:android": [], "//tensorflow:ios": [], diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index f0b1c92cf7e4b760381da38febd9682ce2a4f27c..5608e7ddafa25757484d8c845c8c84a5691e143c 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -73,8 +73,7 @@ cc_binary( "-z defs", "-s", "-Wl,--gc-sections", - "-Wl,--version-script", # This line must be directly followed by LINKER_SCRIPT. 
- "$(location {})".format(LINKER_SCRIPT), + "-Wl,--version-script,$(location {})".format(LINKER_SCRIPT), ]), linkshared = 1, linkstatic = 1, diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 648f3ebb05646a66144bcb118347cbc391909409..5174afe0a63d37e3ea3e19ac9bab644d1d83ecf1 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -37,6 +37,7 @@ py_library( cc_library( name = "batch_ops_kernels", deps = [ + "//tensorflow/core:batch_ops_op_lib", "//tensorflow/core/kernels:batch_kernels", ], alwayslink = 1, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index 6138d7912601344ef7422fd50fb35c8401fd2e63..c6e1bc22baa4fe26621a0d31c9a7df64dfbe62fe 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { - namespace { class BigtableClientOp : public OpKernel { @@ -277,9 +276,7 @@ class ToBigtableOp : public AsyncOpKernel { LOG(ERROR) << "Failure applying mutation on row (" << failure.original_index() << "): " << failure.mutation().row_key() - << " - error: " << failure.status().error_message() - << " (Details: " << failure.status().error_details() - << ")."; + << " - error: " << failure.status().message() << "."; } } OP_REQUIRES_ASYNC( @@ -341,8 +338,8 @@ class ToBigtableOp : public AsyncOpKernel { } template - Status ParseScalarArgument(OpKernelContext* ctx, - const StringPiece& argument_name, T* output) { + Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { const Tensor* argument_t; TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); if (!TensorShapeUtils::IsScalar(argument_t->shape())) { @@ -360,5 +357,4 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), } // 
namespace } // namespace data - } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc index 98f906408c230a4382ffafe412ee9990d4384930..3a46e6e85d5dd12ed0adcec259e30d493add6232 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -27,11 +27,62 @@ Status GrpcStatusToTfStatus(const ::grpc::Status& status) { status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) { grpc_code = ::grpc::StatusCode::INTERNAL; } - return Status(static_cast<::tensorflow::error::Code>(status.error_code()), + return Status(static_cast<::tensorflow::error::Code>(grpc_code), strings::StrCat("Error reading from Cloud Bigtable: ", status.error_message())); } +namespace { +::tensorflow::error::Code GcpErrorCodeToTfErrorCode( + ::google::cloud::StatusCode code) { + switch (code) { + case ::google::cloud::StatusCode::kOk: + return ::tensorflow::error::OK; + case ::google::cloud::StatusCode::kCancelled: + return ::tensorflow::error::CANCELLED; + case ::google::cloud::StatusCode::kUnknown: + return ::tensorflow::error::UNKNOWN; + case ::google::cloud::StatusCode::kInvalidArgument: + return ::tensorflow::error::INVALID_ARGUMENT; + case ::google::cloud::StatusCode::kDeadlineExceeded: + return ::tensorflow::error::DEADLINE_EXCEEDED; + case ::google::cloud::StatusCode::kNotFound: + return ::tensorflow::error::NOT_FOUND; + case ::google::cloud::StatusCode::kAlreadyExists: + return ::tensorflow::error::ALREADY_EXISTS; + case ::google::cloud::StatusCode::kPermissionDenied: + return ::tensorflow::error::PERMISSION_DENIED; + case ::google::cloud::StatusCode::kUnauthenticated: + return ::tensorflow::error::UNAUTHENTICATED; + case ::google::cloud::StatusCode::kResourceExhausted: + return ::tensorflow::error::RESOURCE_EXHAUSTED; + case ::google::cloud::StatusCode::kFailedPrecondition: + return ::tensorflow::error::FAILED_PRECONDITION; + case 
::google::cloud::StatusCode::kAborted: + return ::tensorflow::error::ABORTED; + case ::google::cloud::StatusCode::kOutOfRange: + return ::tensorflow::error::OUT_OF_RANGE; + case ::google::cloud::StatusCode::kUnimplemented: + return ::tensorflow::error::UNIMPLEMENTED; + case ::google::cloud::StatusCode::kInternal: + return ::tensorflow::error::INTERNAL; + case ::google::cloud::StatusCode::kUnavailable: + return ::tensorflow::error::UNAVAILABLE; + case ::google::cloud::StatusCode::kDataLoss: + return ::tensorflow::error::DATA_LOSS; + } +} +} // namespace + +Status GcpStatusToTfStatus(const ::google::cloud::Status& status) { + if (status.ok()) { + return Status::OK(); + } + return Status( + GcpErrorCodeToTfErrorCode(status.code()), + strings::StrCat("Error reading from Cloud Bigtable: ", status.message())); +} + string RegexFromStringSet(const std::vector& strs) { CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty."; std::unordered_set uniq(strs.begin(), strs.end()); diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h index e3b4535bac4a01a1277290e0d1ea6d3c7613731c..f6aa67fb0b5dc95c510bbf799f4f47496d59b00e 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -26,6 +26,7 @@ limitations under the License. 
namespace tensorflow { Status GrpcStatusToTfStatus(const ::grpc::Status& status); +Status GcpStatusToTfStatus(const ::google::cloud::Status& status); string RegexFromStringSet(const std::vector& strs); @@ -89,22 +90,21 @@ class BigtableReaderDatasetIterator : public DatasetIterator { public: explicit BigtableReaderDatasetIterator( const typename DatasetIterator::Params& params) - : DatasetIterator(params), iterator_(nullptr, false) {} + : DatasetIterator(params) {} Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(EnsureIteratorInitialized()); if (iterator_ == reader_->end()) { - grpc::Status status = reader_->Finish(); - if (status.ok()) { - *end_of_sequence = true; - return Status::OK(); - } - return GrpcStatusToTfStatus(status); + *end_of_sequence = true; + return Status::OK(); + } + if (!*iterator_) { + return GcpStatusToTfStatus(iterator_->status()); } *end_of_sequence = false; - google::cloud::bigtable::Row& row = *iterator_; + google::cloud::bigtable::Row& row = **iterator_; Status s = ParseRow(ctx, row, out_tensors); // Ensure we always advance. ++iterator_; diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index 2c6317157d25908c1ff66fc10bd188d93f040521..22b711a73d671a6609c45a55e9f6b13e2894b49d 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -152,11 +152,11 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { } if (input_tensors[0].NumElements() == 1) { // Single key lookup. 
- ::grpc::Status status; + ::google::cloud::Status status; auto pair = dataset()->table_->table().ReadRow( input_tensors[0].scalar()(), dataset()->filter_, status); if (!status.ok()) { - return GrpcStatusToTfStatus(status); + return GcpStatusToTfStatus(status); } if (!pair.first) { return errors::DataLoss("Row key '", diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index e6fda9e61757f1441b3691c2a3d57c6f1a5a0d42..d9fce6e09f47ab05074f0b4c03dd8e672ed3d2ce 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -335,6 +335,17 @@ grpc::Status BigtableTestClient::ReadModifyWriteRow( return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "ReadModifyWriteRow not implemented."); } +std::unique_ptr> +BigtableTestClient::AsyncReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to AsyncReadModifyWriteRow:" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::unique_ptr< grpc::ClientReaderInterface> BigtableTestClient::ReadRows( diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index 8e1326f2ce841368ea81fc7194a0588e5d6cd637..63d59b32dd17a2f58d3413932b69f4d704c84e48 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -46,6 +46,13 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { google::bigtable::v2::ReadModifyWriteRowRequest const& request, google::bigtable::v2::ReadModifyWriteRowResponse* response) override; + std::unique_ptr> + AsyncReadModifyWriteRow( + 
grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + grpc::CompletionQueue* cq) override; + std::unique_ptr< grpc::ClientReaderInterface> ReadRows(grpc::ClientContext* context, diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc index 32611e2590d9a81f46d0b9dfc09fe7e0068e9671..cf6e619bfaf25101b7fea7ce59a31f7a688c0452 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc @@ -39,7 +39,6 @@ TEST(BigtableTestClientTest, EmptyRowRead) { ::google::cloud::bigtable::Filter::Latest(1)); auto rows = table.ReadRows(std::move(rowset), filter); EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!"; - EXPECT_TRUE(rows.Finish().ok()) << "Error reading rows."; } TEST(BigtableTestClientTest, SingleRowWriteAndRead) { @@ -55,15 +54,15 @@ TEST(BigtableTestClientTest, SingleRowWriteAndRead) { auto rows = table.ReadRows(std::move(rowset), filter); auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "No rows were returned in response!"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_EQ(itr, rows.end()); - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) { @@ -82,15 +81,15 @@ TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) { auto itr = 
rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, MultiRowWriteAndRead) { @@ -109,33 +108,35 @@ TEST(BigtableTestClientTest, MultiRowWriteAndRead) { auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r2"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v2"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r2"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), 
"v2"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r3"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v3"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r3"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v3"); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) { @@ -154,33 +155,35 @@ TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) { auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r2"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v2"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r2"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + 
EXPECT_EQ((*itr)->cells()[0].value(), "v2"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r3"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v3"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r3"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v3"); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, ColumnFiltering) { @@ -206,33 +209,35 @@ TEST(BigtableTestClientTest, ColumnFiltering) { auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r2"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v2"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r2"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + 
EXPECT_EQ((*itr)->cells()[0].value(), "v2"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r3"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v3"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r3"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v3"); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, RowKeys) { @@ -257,33 +262,35 @@ TEST(BigtableTestClientTest, RowKeys) { table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), ""); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), ""); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r2"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), ""); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r2"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + 
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), ""); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r3"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), ""); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r3"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), ""); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, SampleKeys) { diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc index 416b719e30aa5f2504449d151a48e95c9105c68b..39c2a2e775d5d5287b137bf33eef66251738e6d3 100644 --- a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc +++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc @@ -59,7 +59,7 @@ REGISTER_OP("BigtablePrefixKeyDataset") .Input("table: resource") .Input("prefix: string") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); @@ -68,14 +68,14 @@ REGISTER_OP("BigtableRangeKeyDataset") .Input("start_key: string") .Input("end_key: string") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. 
.SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("BigtableSampleKeysDataset") .Input("table: resource") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); @@ -85,7 +85,7 @@ REGISTER_OP("BigtableSampleKeyPairsDataset") .Input("start_key: string") .Input("end_key: string") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); @@ -100,7 +100,7 @@ REGISTER_OP("BigtableScanDataset") .Input("columns: string") .Input("probability: float") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index fa64055dfd65a134afdf46cebccb7f7d96106502..736cf3da49e934d49d0587d729cff6eaaed8f254 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -475,15 +475,17 @@ class BigtableTable(object): """ if timestamp is None: timestamp = -1 # Bigtable server provided timestamp. 
- for tensor_type in nest.flatten(dataset.output_types): + for tensor_type in nest.flatten( + dataset_ops.get_legacy_output_types(dataset)): if tensor_type != dtypes.string: raise ValueError("Not all elements of the dataset were `tf.string`") - for shape in nest.flatten(dataset.output_shapes): + for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)): if not shape.is_compatible_with(tensor_shape.scalar()): raise ValueError("Not all elements of the dataset were scalars") if len(column_families) != len(columns): raise ValueError("len(column_families) != len(columns)") - if len(nest.flatten(dataset.output_types)) != len(columns) + 1: + if len(nest.flatten( + dataset_ops.get_legacy_output_types(dataset))) != len(columns) + 1: raise ValueError("A column name must be specified for every component of " "the dataset elements. (e.g.: len(columns) != " "len(dataset.output_types))") diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index d3b23d949ee2c7674c3918d39e8b71d76eefcfec..64e4c4560ba3a1b177db12a09997ff7afe8775a3 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -193,8 +193,9 @@ py_test( py_test( name = "estimator_test", - size = "large", + size = "medium", srcs = ["estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "no_gpu", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index a178820841c4c8bcb7f5742babdb6d0f4825de31..5ffbb9067081d7440ab5e11290697b822051bee5 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -84,12 +84,10 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): output_leaf_index: whether to output leaf indices along with predictions during inference. 
The leaf node indexes are available in predictions dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is - [batch_size, num_trees]. - For example, - result_iter = classifier.predict(...) - for result_dict in result_iter: - # access leaf index list by result_dict["leaf_index"] - # which contains one leaf index per tree + [batch_size, num_trees]. For example, result_iter = + classifier.predict(...) + for result_dict in result_iter: # access leaf index list by + result_dict["leaf_index"] # which contains one leaf index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -179,8 +177,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): `[batch_size, label_dimension]`). num_trees: An int, number of trees to build. feature_columns: A list of feature columns. - label_name: String, name of the key in label dict. Can be null if label - is a tensor (single headed models). + label_name: String, name of the key in label dict. Can be null if label is + a tensor (single headed models). weight_column_name: Name of the column for weights, or None if not weighted. model_dir: Directory for model exports, etc. @@ -195,11 +193,11 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): opposed to contrib) version of tensorflow. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) 
+ for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -286,11 +284,11 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): opposed to contrib) version of tensorflow. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -353,10 +351,9 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. head: `Head` instance. - ranking_model_pair_keys: Keys to distinguish between features - for left and right part of the training pairs for ranking. For example, - for an Example with features "a.f1" and "b.f1", the keys would be - ("a", "b"). + ranking_model_pair_keys: Keys to distinguish between features for left and + right part of the training pairs for ranking. 
For example, for an + Example with features "a.f1" and "b.f1", the keys would be ("a", "b"). num_trees: An int, number of trees to build. feature_columns: A list of feature columns. weight_column_name: Name of the column for weights, or None if not @@ -376,12 +373,10 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is - [batch_size, num_trees]. - For example, - result_iter = classifier.predict(...) - for result_dict in result_iter: - # access leaf index list by result_dict["leaf_index"] - # which contains one leaf index per tree + [batch_size, num_trees]. For example, result_iter = + classifier.predict(...) + for result_dict in result_iter: # access leaf index list by + result_dict["leaf_index"] # which contains one leaf index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -417,12 +412,12 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) + # When using this estimator, make sure to regularize the hessian (at least l2, # min_node_weight)! # TODO(nponomareva): extend to take multiple quantiles in one go. class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): - """An estimator that does quantile regression and returns quantile estimates. - """ + """An estimator that does quantile regression and returns quantile estimates.""" def __init__(self, learner_config, @@ -449,8 +444,8 @@ class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. 
quantiles: a list of quantiles for the loss, each between 0 and 1. - label_dimension: Dimension of regression label. This is the size - of the last dimension of the labels `Tensor` (typically, this has shape + label_dimension: Dimension of regression label. This is the size of the + last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). When label_dimension>1, it is recommended to use multiclass strategy diagonal hessian or full hessian. num_trees: An int, number of trees to build. @@ -469,11 +464,11 @@ class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): opposed to contrib) version of tensorflow. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -519,6 +514,7 @@ class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) + # ================== New Estimator interface=================================== # The estimators below use new core Estimator interface and must be used with # new feature columns and heads. 
@@ -534,10 +530,8 @@ def core_multiclass_head( def loss_fn(labels, logits): result = losses.per_example_maxent_loss( - labels=labels, - logits=logits, - weights=weight_column, - num_classes=n_classes) + # Don't pass the weights: head already multiplies by them. + labels=labels, logits=logits, weights=None, num_classes=n_classes) return result[0] # pylint:disable=protected-access @@ -564,7 +558,8 @@ def core_quantile_regression_head( result = losses.per_example_quantile_regression_loss( labels=labels, predictions=logits, - weights=weight_column, + # Don't pass the weights: head already multiplies by them. + weights=None, quantile=quantiles) return result[0] @@ -623,11 +618,11 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): the bias. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree num_quantiles: Number of quantiles to build for numeric feature values. """ @@ -685,10 +680,9 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. head: `Head` instance. - ranking_model_pair_keys: Keys to distinguish between features - for left and right part of the training pairs for ranking. For example, - for an Example with features "a.f1" and "b.f1", the keys would be - ("a", "b"). 
+ ranking_model_pair_keys: Keys to distinguish between features for left and + right part of the training pairs for ranking. For example, for an + Example with features "a.f1" and "b.f1", the keys would be ("a", "b"). num_trees: An int, number of trees to build. feature_columns: A list of feature columns. weight_column_name: Name of the column for weights, or None if not @@ -703,12 +697,10 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is - [batch_size, num_trees]. - For example, - result_iter = classifier.predict(...) - for result_dict in result_iter: - # access leaf index list by result_dict["leaf_index"] - # which contains one leaf index per tree + [batch_size, num_trees]. For example, result_iter = + classifier.predict(...) + for result_dict in result_iter: # access leaf index list by + result_dict["leaf_index"] # which contains one leaf index per tree num_quantiles: Number of quantiles to build for numeric feature values. Raises: @@ -748,8 +740,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): # TODO(nponomareva): extend to take multiple quantiles in one go. class CoreGradientBoostedDecisionTreeQuantileRegressor( core_estimator.Estimator): - """An estimator that does quantile regression and returns quantile estimates. - """ + """An estimator that does quantile regression and returns quantile estimates.""" def __init__(self, learner_config, @@ -775,8 +766,8 @@ class CoreGradientBoostedDecisionTreeQuantileRegressor( layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. quantiles: a list of quantiles for the loss, each between 0 and 1. - label_dimension: Dimension of regression label. 
This is the size - of the last dimension of the labels `Tensor` (typically, this has shape + label_dimension: Dimension of regression label. This is the size of the + last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). When label_dimension>1, it is recommended to use multiclass strategy diagonal hessian or full hessian. num_trees: An int, number of trees to build. @@ -795,11 +786,11 @@ class CoreGradientBoostedDecisionTreeQuantileRegressor( the bias. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree num_quantiles: Number of quantiles to build for numeric feature values. 
""" if len(quantiles) > 1: @@ -814,7 +805,9 @@ class CoreGradientBoostedDecisionTreeQuantileRegressor( params={ 'head': core_quantile_regression_head( - quantiles[0], label_dimension=label_dimension), + quantiles[0], + label_dimension=label_dimension, + weight_column=weight_column_name), 'feature_columns': feature_columns, 'learner_config': diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 47d910d42a27db4b857eeb12209dfbb429dd1be2..5a8b2ba9caf0a9813cb5b3409b8a0dc3de0a45d7 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -399,8 +399,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def testQuantileRegression(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -413,7 +413,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, + num_trees=12, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -428,31 +428,12 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper >= 0.92) self.assertTrue(frac_below_upper <= 0.98) - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - 
quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - # Multi-dimensional quantile regression. def testQuantileRegressionMultiDimLabel(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -467,7 +448,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): quantiles=[0.95], learner_config=learner_config, label_dimension=2, - num_trees=100, + num_trees=18, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -490,35 +471,6 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_both_below_upper >= 0.91) self.assertTrue(frac_both_below_upper <= 0.99) - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - two_dimension=True) - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - label_dimension=2, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) 
- lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.91) - self.assertTrue(frac_both_above_lower <= 0.99) - class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -712,11 +664,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): est.evaluate(input_fn=input_fn, steps=1) est.predict(input_fn=input_fn) - # One dimensional quantile regression. - def testQuantileRegression(self): + # Quantile regression in core is the same as in non core estimator, so we + # just check that it does not fail. 
+ def testQuantileRegressionDoesNotThroughException(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 + learner_config.constraints.max_tree_depth = 1 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -731,112 +684,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_upper.train(input_fn=train_input_fn, steps=1000) - result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper >= 0.92) - self.assertTrue(frac_below_upper <= 0.98) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - - # Multi-dimensional quantile regression. 
- def testQuantileRegressionMultiDimLabel(self): - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.tree_complexity = ( - 1.0 / _QUANTILE_REGRESSION_SIZE) - - train_input_fn, test_input_fn, y = _quantile_regression_input_fns( - two_dimension=True) - y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2) - - # 95% percentile. - model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.95], - learner_config=learner_config, - num_trees=100, - label_dimension=2, + num_trees=1, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) model_upper.train(input_fn=train_input_fn, steps=1000) result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - count_below_upper = np.count_nonzero(upper > y, axis=0) - count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) - frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) - frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) - frac_both_below_upper = round(1. 
* count_both_below_upper / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper_0 >= 0.92) - self.assertTrue(frac_below_upper_0 <= 0.98) - self.assertTrue(frac_below_upper_1 >= 0.92) - self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.91) - self.assertTrue(frac_both_below_upper <= 0.99) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - two_dimension=True) - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - label_dimension=2, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.91) - self.assertTrue(frac_both_above_lower <= 0.99) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index a6e422847d3914188bca9e6dff797ba1ffb06749..eecf3c5aeb6c6785cae3fd5808954a73db6190d6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import ops from tensorflow.python.ops import state_ops from tensorflow.python.training import training_util @@ -88,6 +89,12 @@ def model_builder(features, if config is None: raise ValueError("Missing estimator RunConfig.") + if config.session_config is not None: + session_config = config.session_config + session_config.allow_soft_placement = True + else: + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + config = config.replace(session_config=session_config) center_bias = params["center_bias"] diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 6d78e27e8f69ea289b686af8402bd91967f997f4..65276242abaf96de8b1936365278b18f8bba93a9 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ 
b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -538,7 +538,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { partition_boundaries[non_empty_partitions[root_idx]]; float best_gain = std::numeric_limits::lowest(); - int32 best_dimension_idx = 0; bool default_right = false; int32 best_element_idx = 0; @@ -571,7 +570,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { // Iterate through dimensions. for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { const DimensionBoundary& dimension_and_start = dimension_boundaries[j]; - const int32 dimension_id = dimension_and_start.dimension_id; int start_index = dimension_and_start.start_index; // Even for the last dimension, we always have additional dummy @@ -630,7 +628,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { best_right_node_stats = right_stats_default_left; best_element_idx = element_idx; default_right = false; - best_dimension_idx = dimension_id; } } // Consider calculating the default direction only when there were @@ -648,7 +645,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { best_right_node_stats = right_stats_default_right; best_element_idx = element_idx; default_right = true; - best_dimension_idx = dimension_id; } } } diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index d26af58419752170bbc58bba757ac43349fc2cff..22ad181fc3fb6b0d1b36dd4bb916e63dfb8753ce 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -193,7 +193,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): num_minibatches = control_flow_ops.cond( ops.convert_to_tensor(self._loss_uses_sum_reduction), - lambda: math_ops.to_int64(1), lambda: num_minibatches) + lambda: math_ops.cast(1, dtypes.int64), + lambda: 
num_minibatches) partition_ids, gains, split_infos = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=num_minibatches, diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 0476bed2cd3f3ea5b47b10c51a819f17d6e37c74..0e6a9f8f3a0126ca9f14c9621c9f91bdbf66b338 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -312,9 +312,10 @@ def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, gen_stats_accumulator_ops.stats_accumulator_scalar_flush( stats_accumulator_handle, stamp_token, next_stamp_token)) # For sum_reduction, we don't need to divide by number of minibatches. - num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, - lambda: math_ops.to_int64(1), - lambda: num_minibatches) + num_minibatches = control_flow_ops.cond( + loss_uses_sum_reduction, + lambda: math_ops.cast(1, dtypes.int64), + lambda: num_minibatches) # Put quantile and stats accumulator flushing in the dependency path. with ops.control_dependencies([flush_quantiles, partition_ids]): are_splits_ready = array_ops.identity(are_splits_ready) @@ -488,9 +489,10 @@ def _make_sparse_split( num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( gen_stats_accumulator_ops.stats_accumulator_scalar_flush( stats_accumulator_handle, stamp_token, next_stamp_token)) - num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, - lambda: math_ops.to_int64(1), - lambda: num_minibatches) + num_minibatches = control_flow_ops.cond( + loss_uses_sum_reduction, + lambda: math_ops.cast(1, dtypes.int64), + lambda: num_minibatches) # Put quantile and stats accumulator flushing in the dependency path. 
with ops.control_dependencies([flush_quantiles, partition_ids]): are_splits_ready = array_ops.identity(are_splits_ready) diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py index 42d69645acaae063fcd46bd1f6c819ccb68f48bd..aa3f24f08a0f762507df83def72e7d595265221f 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py @@ -227,7 +227,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): tree_ensemble_config=tree_ensemble_config.SerializeToString(), name="restore_tree") resources.initialize_resources(resources.shared_resources()).run() - variables.initialize_all_variables().run() + variables.global_variables_initializer().run() my_saver = saver.Saver() # Add the second tree and replace the ensemble of the handle. diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index c3685b54e201f73039f6623443c67ba2b217a51e..f9945959812f030f76cb481cfcf91cba1f352fc1 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -33,7 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking ops.NotDifferentiable("TreeEnsembleVariable") ops.NotDifferentiable("TreeEnsembleSerialize") @@ -96,18 +96,18 @@ class TreeEnsembleVariable(tracking.TrackableResource): self._init_op = None super(TreeEnsembleVariable, self).__init__() - def create_resource(self): + def _create_resource(self): return gen_model_ops.decision_tree_ensemble_resource_handle_op( 
self._container, shared_name=self._name, name=self._name) - def initialize(self): + def _initialize(self): return gen_model_ops.create_tree_ensemble_variable( self.resource_handle, self._stamp_token, self._tree_ensemble_config) @property def initializer(self): if self._init_op is None: - self._init_op = self.initialize() + self._init_op = self._initialize() return self._init_op def is_initialized(self): diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 0c319cc9bd1f720eb404a9da05227c5807ec874f..82f9b17b3308d6a521c79ee7a6f48f6c3813a769 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking # Pattern to remove all non alpha numeric from a string. 
_PATTERN = re.compile(r"[\W_]+") @@ -120,8 +120,8 @@ class QuantileAccumulator(tracking.TrackableResource): name = _PATTERN.sub("", name) with ops.name_scope(name, "QuantileAccumulator") as name: self._name = name - self._resource_handle = self.create_resource() - self._init_op = self.initialize() + self._resource_handle = self._create_resource() + self._init_op = self._initialize() is_initialized_op = self.is_initialized() resources.register_resource(self.resource_handle, self._init_op, is_initialized_op) @@ -129,11 +129,11 @@ class QuantileAccumulator(tracking.TrackableResource): self._init_op, name) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) - def create_resource(self): + def _create_resource(self): return gen_quantile_ops.quantile_stream_resource_handle_op( container=self._container, shared_name=self._name, name=self._name) - def initialize(self): + def _initialize(self): return gen_quantile_ops.create_quantile_accumulator( self.resource_handle, self._init_stamp_token, @@ -145,7 +145,7 @@ class QuantileAccumulator(tracking.TrackableResource): @property def initializer(self): if self._init_op is None: - self._init_op = self.initialize() + self._init_op = self._initialize() return self._init_op def is_initialized(self): diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index ad1191d41236e71008bff8c8a7fbd42c16e3f9c5..1f6bbbf5740ec3c47697ea600eef030aa257707f 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking # 
Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") @@ -144,8 +144,8 @@ class StatsAccumulator(tracking.TrackableResource): name = _PATTERN.sub("", name) with ops.name_scope(name, "StatsAccumulator") as name: self._name = name - self._resource_handle = self.create_resource() - self._init_op = self.initialize() + self._resource_handle = self._create_resource() + self._init_op = self._initialize() is_initialized_op = self.is_initialized() resources.register_resource(self.resource_handle, self.initializer, is_initialized_op) @@ -153,7 +153,7 @@ class StatsAccumulator(tracking.TrackableResource): self.resource_handle, self.initializer, self._is_scalar, name) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) - def create_resource(self): + def _create_resource(self): if self._is_scalar: return ( gen_stats_accumulator_ops.stats_accumulator_scalar_resource_handle_op( @@ -163,7 +163,7 @@ class StatsAccumulator(tracking.TrackableResource): gen_stats_accumulator_ops.stats_accumulator_tensor_resource_handle_op( self._container, self._name, name=self._name)) - def initialize(self): + def _initialize(self): if self._is_scalar: return gen_stats_accumulator_ops.create_stats_accumulator_scalar( self.resource_handle, self._stamp_token) @@ -175,7 +175,7 @@ class StatsAccumulator(tracking.TrackableResource): @property def initializer(self): if self._init_op is None: - self._init_op = self.initialize() + self._init_op = self._initialize() return self._init_op def is_initialized(self): diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index e78ec476ab3b43e5eb56a2502008bb8020ae97e0..bca850514be943c0fad4a980092dd9ffe313d746 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -228,7 +228,7 @@ def 
extract_features(features, feature_columns, use_core_columns): indices = array_ops.concat([ array_ops.slice(categorical_tensor.indices, [0, 0], [-1, 1]), array_ops.expand_dims( - math_ops.to_int64(categorical_tensor.values), -1) + math_ops.cast(categorical_tensor.values, dtypes.int64), -1) ], 1) tensor = sparse_tensor.SparseTensor( indices=indices, values=weight_tensor.values, dense_shape=shape) @@ -611,8 +611,9 @@ class GradientBoostedDecisionTreeModel(object): learner_pb2.LearnerConfig.TREE_PER_CLASS and self._logits_dimension != 1): # Choose the class for which the tree is built (one vs rest). - return math_ops.to_int32( - predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) + return math_ops.cast( + predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension, + dtypes.int32) return constant_op.constant(-1, dtype=dtypes.int32) def update_stats(self, loss, predictions_dict, gradients=None, hessians=None): diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 92068e88a76cb8bfdd394c1093347a8fb8a63449..61441b22e908110da60765d53a6968afdeecc502 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -43,7 +43,7 @@ from tensorflow.python.platform import googletest def _squared_loss(label, unused_weights, predictions): """Unweighted loss implementation.""" loss = math_ops.reduce_sum( - math_ops.square(predictions - label), 1, keepdims=True) + math_ops.squared_difference(predictions, label), 1, keepdims=True) return loss @@ -1149,9 +1149,9 @@ class GbdtTest(test_util.TensorFlowTestCase): expected_leaf_1 = [-3.4480, -3.4429, 13.8490, -3.45, -3.4508] expected_leaf_2 = [-1.2547, -1.3145, 1.52, 2.3875, -1.3264] self.assertArrayNear(expected_leaf_1, - output.trees[0].nodes[1].leaf.vector.value, 1e-3) + 
output.trees[0].nodes[1].leaf.vector.value, 3e-3) self.assertArrayNear(expected_leaf_2, - output.trees[0].nodes[2].leaf.vector.value, 1e-3) + output.trees[0].nodes[2].leaf.vector.value, 3e-3) def testTrainFnMulticlassDiagonalHessian(self): """Tests the GBDT train for multiclass diagonal hessian.""" diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 220e981618b7c0bfb1e4e98c087d83b451b9b3cf..40fdfcf45ac79ffcbab6ba4fbf8f9077a179b16f 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -43,7 +44,7 @@ def per_example_logistic_loss(labels, weights, predictions): loss: A Rank 2 (N, 1) tensor of per-example logistic loss. update_op: An update operation to update the loss's internal state. """ - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) unweighted_loss = nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() @@ -74,7 +75,7 @@ def per_example_quantile_regression_loss(labels, weights, predictions, loss: A Rank 2 (N, 1) tensor of per-example quantile loss. update_op: An update operation to update the loss's internal state. 
""" - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) error = labels - predictions square_loss_right = array_ops.where(error * quantile < 1.0, math_ops.square(quantile * error), @@ -112,7 +113,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): loss: A Rank 2 (N, 1) tensor of per-example maxent loss update_op: An update operation to update the loss's internal state. """ - labels = math_ops.to_int64(labels) + labels = math_ops.cast(labels, dtypes.int64) # If labels are of rank 1, make them rank 2. labels_shape = labels.get_shape() if len(labels_shape) != 2: @@ -120,7 +121,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): # Labels are indices of classes, convert them to one hot encodings. target_one_hot = array_ops.one_hot(indices=labels, depth=num_classes) labels = math_ops.reduce_sum(input_tensor=target_one_hot, axis=[1]) - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) # Calculate softmax probabilities for each class. unnormalized_probs = math_ops.exp(logits) @@ -166,7 +167,7 @@ def per_example_squared_loss(labels, weights, predictions): update_op: An update operation to update the loss's internal state. 
""" unweighted_loss = math_ops.reduce_sum( - math_ops.square(predictions - labels), 1, keepdims=True) + math_ops.squared_difference(predictions, labels), 1, keepdims=True) return unweighted_loss * weights, control_flow_ops.no_op() @@ -253,7 +254,7 @@ def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1): preds_converted = min_res return math_ops.exp(-preds_converted * labels_converted) - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) unweighted_loss = exp_with_logits( name=name, eps=eps, labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() @@ -312,7 +313,7 @@ def per_example_full_exp_loss(labels, weights, predictions, name=None): return math_ops.exp(-1.0 * logits * labels_converted) - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) unweighted_loss = full_exp_with_logits( name=name, labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 99ed4959fad9699f265183d71a1f3b609d7e6d30..a416588691f580143aa4e5ee53ca1e5cab9c42e0 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -27,7 +27,7 @@ Managing dependencies: @@NoDependency @@split_dependency -Checkpointable data structures: +Trackable data structures: @@List @@Mapping @@UniqueNameTracker @@ -46,20 +46,20 @@ from __future__ import print_function from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.python_state import NumpyState -from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint -from 
tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.core.protobuf.trackable_object_graph_pb2 import TrackableObjectGraph as CheckpointableObjectGraph from tensorflow.python.training.checkpoint_management import CheckpointManager -from tensorflow.python.training.checkpointable.base import Checkpointable as CheckpointableBase -from tensorflow.python.training.checkpointable.data_structures import List -from tensorflow.python.training.checkpointable.data_structures import Mapping -from tensorflow.python.training.checkpointable.data_structures import NoDependency -from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable -from tensorflow.python.training.checkpointable.util import capture_dependencies -from tensorflow.python.training.checkpointable.util import list_objects -from tensorflow.python.training.checkpointable.util import object_metadata - +from tensorflow.python.training.tracking.base import Trackable as CheckpointableBase +from tensorflow.python.training.tracking.data_structures import List +from tensorflow.python.training.tracking.data_structures import Mapping +from tensorflow.python.training.tracking.data_structures import NoDependency +from tensorflow.python.training.tracking.python_state import PythonState as PythonStateWrapper +from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable +from tensorflow.python.training.tracking.util import capture_dependencies +from tensorflow.python.training.tracking.util import list_objects +from tensorflow.python.training.tracking.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) + diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index 4e529322c7c76797938468b405cd175609dc0a73..caedf5b2d1d93dcbc40b0d07607c59597a38131a 100644 --- 
a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -12,7 +12,7 @@ py_library( ":python_state", ":split_dependency", ":visualize", - "//tensorflow/python/training/checkpointable:data_structures", + "//tensorflow/python/training/tracking:data_structures", ], ) @@ -22,8 +22,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:data_structures", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:data_structures", ], ) @@ -36,8 +36,8 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -47,7 +47,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", "//third_party/py/numpy", "@six_archive//:six", ], @@ -64,7 +64,7 @@ tf_py_test( "//tensorflow/python:session", "//tensorflow/python:variables", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:util", ], ) @@ -76,7 +76,7 @@ py_library( deps = [ "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", ], ) @@ -89,8 +89,8 @@ tf_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + 
"//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -101,8 +101,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -118,6 +118,6 @@ tf_py_test( "//tensorflow/python/eager:test", "//tensorflow/python/keras:engine", "//tensorflow/python/keras:layers", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:util", ], ) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 97936d9e9dfd5d6e62fdf8312707a276b63e1267..a25d51980ea760dfb7f323497a397fbd94fd5f23 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -1,4 +1,4 @@ -"""Checkpointable data structures.""" +"""Trackable data structures.""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.training.checkpointable import base as checkpointable_lib -from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.tracking import base as trackable_lib +from tensorflow.python.training.tracking import data_structures -class UniqueNameTracker(data_structures.CheckpointableDataStructure): - """Adds dependencies on checkpointable objects with name hints. +class UniqueNameTracker(data_structures.TrackableDataStructure): + """Adds dependencies on trackable objects with name hints. Useful for creating dependencies with locally unique names. 
@@ -43,30 +43,30 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): def __init__(self): super(UniqueNameTracker, self).__init__() - self._maybe_initialize_checkpointable() + self._maybe_initialize_trackable() self._name_counts = {} @property def _values(self): return [dep.ref for dep in self._checkpoint_dependencies] - def track(self, checkpointable, base_name): - """Add a dependency on `checkpointable`. + def track(self, trackable, base_name): + """Add a dependency on `trackable`. Args: - checkpointable: An object to add a checkpoint dependency on. + trackable: An object to add a checkpoint dependency on. base_name: A name hint, which is uniquified to determine the dependency name. Returns: - `checkpointable`, for chaining. + `trackable`, for chaining. Raises: - ValueError: If `checkpointable` is not a checkpointable object. + ValueError: If `trackable` is not a trackable object. """ - if not isinstance(checkpointable, checkpointable_lib.Checkpointable): + if not isinstance(trackable, trackable_lib.Trackable): raise ValueError( - ("Expected a checkpointable value, got %s which does not inherit " - "from CheckpointableBase.") % (checkpointable,)) + ("Expected a trackable value, got %s which does not inherit " + "from tf.track.Trackable.") % (trackable,)) def _format_name(prefix, number): if number > 0: @@ -80,5 +80,5 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): count += 1 candidate = _format_name(base_name, count) self._name_counts[base_name] = count + 1 - self._track_value(checkpointable, name=candidate) - return checkpointable + self._track_value(trackable, name=candidate) + return trackable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index a2d453ec6eb3dcf9aba4c52fe866756a92673c63..bace21939602666aa48a05d2abfe05ae6aae41e2 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ 
b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -26,9 +26,9 @@ from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import data_structures -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import data_structures +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util class UniqueNameTrackerTests(test.TestCase): @@ -52,7 +52,7 @@ class UniqueNameTrackerTests(test.TestCase): save_root = util.Checkpoint(slots=slots) save_path = save_root.save(checkpoint_prefix) - restore_slots = tracking.AutoCheckpointable() + restore_slots = tracking.AutoTrackable() restore_root = util.Checkpoint( slots=restore_slots) status = restore_root.restore(save_path) @@ -68,7 +68,7 @@ class UniqueNameTrackerTests(test.TestCase): @test_util.run_in_graph_and_eager_modes def testExample(self): - class SlotManager(tracking.AutoCheckpointable): + class SlotManager(tracking.AutoTrackable): def __init__(self): self.slotdeps = containers.UniqueNameTracker() diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 969c90c78871ebff02b360f8f09623df56c9c077..1ada05227ba566cd3dfbff406e8fed80dccde684 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -17,13 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import abc -import functools -import six - import numpy -from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.tracking import base +from tensorflow.python.training.tracking import python_state 
as core_python_state # pylint: disable=g-import-not-at-top try: @@ -34,8 +31,8 @@ except ImportError: # pylint: enable=g-import-not-at-top -class NumpyState(base.Checkpointable): - """A checkpointable object whose NumPy array attributes are saved/restored. +class NumpyState(base.Trackable): + """A trackable object whose NumPy array attributes are saved/restored. Example usage: @@ -72,7 +69,7 @@ class NumpyState(base.Checkpointable): """Create placeholder NumPy arrays for to-be-restored attributes. Typically `_lookup_dependency` is used to check by name whether a dependency - exists. We cheat slightly by creating a checkpointable object for `name` if + exists. We cheat slightly by creating a trackable object for `name` if we don't already have one, giving us attribute re-creation behavior when loading a checkpoint. @@ -85,7 +82,7 @@ class NumpyState(base.Checkpointable): value = super(NumpyState, self)._lookup_dependency(name) if value is None: value = _NumpyWrapper(numpy.array([])) - new_reference = base.CheckpointableReference(name=name, ref=value) + new_reference = base.TrackableReference(name=name, ref=value) self._unconditional_checkpoint_dependencies.append(new_reference) self._unconditional_dependency_names[name] = value super(NumpyState, self).__setattr__(name, value) @@ -101,7 +98,7 @@ class NumpyState(base.Checkpointable): def __setattr__(self, name, value): """Automatically wrap NumPy arrays assigned to attributes.""" # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making - # ndarrays checkpointable natively and using standard checkpointable list + # ndarrays trackable natively and using standard trackable list # tracking. 
if isinstance(value, (numpy.ndarray, numpy.generic)): try: @@ -110,48 +107,26 @@ class NumpyState(base.Checkpointable): return except AttributeError: value = _NumpyWrapper(value) - self._track_checkpointable(value, name=name, overwrite=True) + self._track_trackable(value, name=name, overwrite=True) elif (name not in ("_setattr_tracking", "_update_uid") and getattr(self, "_setattr_tracking", True)): - # Mixing restore()-created attributes with user-added checkpointable + # Mixing restore()-created attributes with user-added trackable # objects is tricky, since we can't use the `_lookup_dependency` trick to # re-create attributes (we might accidentally steal the restoration for - # another checkpointable object). For now `NumpyState` objects must be + # another trackable object). For now `NumpyState` objects must be # leaf nodes. Theoretically we could add some extra arguments to # `_lookup_dependency` to figure out whether we should create a NumPy # array for the attribute or not. raise NotImplementedError( ("Assigned %s to the %s property of %s, which is not a NumPy array. " - "Currently mixing NumPy arrays and other checkpointable objects is " + "Currently mixing NumPy arrays and other trackable objects is " "not supported. 
File a feature request if this limitation bothers " "you.") % (value, name, self)) super(NumpyState, self).__setattr__(name, value) -@six.add_metaclass(abc.ABCMeta) -class PythonStateWrapper(base.Checkpointable): - """Wraps a Python object for storage in an object-based checkpoint.""" - - @abc.abstractmethod - def _serialize(self): - """Callback for `PythonStringStateSaveable` to serialize the object.""" - - @abc.abstractmethod - def _deserialize(self, string_value): - """Callback for `PythonStringStateSaveable` to deserialize the object.""" - - def _gather_saveables_for_checkpoint(self): - """Specify callbacks for saving and restoring `array`.""" - return { - "py_state": functools.partial( - base.PythonStringStateSaveable, - state_callback=self._serialize, - restore_callback=self._deserialize) - } - - -class _NumpyWrapper(PythonStateWrapper): +class _NumpyWrapper(core_python_state.PythonState): """Wraps a NumPy array for storage in an object-based checkpoint.""" def __init__(self, array): @@ -162,7 +137,7 @@ class _NumpyWrapper(PythonStateWrapper): """ self.array = array - def _serialize(self): + def serialize(self): """Callback to serialize the array.""" string_file = BytesIO() try: @@ -172,7 +147,7 @@ class _NumpyWrapper(PythonStateWrapper): string_file.close() return serialized - def _deserialize(self, string_value): + def deserialize(self, string_value): """Callback to deserialize the array.""" string_file = BytesIO(string_value) try: diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py index 45494351ff4e6c8c75634d8563c3fb63c6089036..40d8fe836402c8b6c8240ef9f665b753c54ede0d 100644 --- a/tensorflow/contrib/checkpoint/python/python_state_test.py +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -26,7 +26,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops 
import variables -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import util class NumpyStateTests(test.TestCase): diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index 3e9700ad74618e24843181d169f3fb39ac96bff6..d7b02b538909305b14e638761bd8ba67a948d2b4 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -21,7 +21,7 @@ import functools from tensorflow.python.ops import control_flow_ops from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): @@ -43,7 +43,7 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return self._restore_callback(tensor) -class _SplitDependency(checkpointable.Checkpointable): +class _SplitDependency(trackable.Trackable): """Looks like a regular variable while synchronizing save/restores.""" def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, @@ -81,9 +81,9 @@ class _SplitDependency(checkpointable.Checkpointable): return control_flow_ops.no_op() def _gather_saveables_for_checkpoint(self): - """Looks to Checkpointable like a regular variable.""" + """Looks to Trackable like a regular variable.""" return { - checkpointable.VARIABLE_VALUE_KEY: + trackable.VARIABLE_VALUE_KEY: functools.partial(_CallbackSaveable, dtype=self._dtype, save_callback=self._save, @@ -117,7 +117,7 @@ def split_dependency(component_names, component_dtypes, may return `None`). Returns: - A dictionary mapping from names to Checkpointable objects. If one is + A dictionary mapping from names to Trackable objects. 
If one is reachable from an object as a dependency, the others should be too; adding dependencies on some but not all of the objects will result in errors. """ diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 664a4e76ab31bf31c7a57924e4af866f2d746804..9bc01059481ff69064e3f9c682a764146b79a250 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,9 +23,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training.checkpointable import base -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import base +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util def _split_variable_closure(variable): @@ -44,7 +44,7 @@ def _combine_variable_closure(variable): return _consume_restore_buffer_fn -class SaveTensorSlicesAsDeps(base.Checkpointable): +class SaveTensorSlicesAsDeps(base.Trackable): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) @@ -56,17 +56,17 @@ class SaveTensorSlicesAsDeps(base.Checkpointable): consume_restore_buffer_fn=_combine_variable_closure( self.combined)) for name, dep in split_dependencies.items(): - self._track_checkpointable(dep, name=name) + self._track_trackable(dep, name=name) -class HasRegularDeps(tracking.AutoCheckpointable): +class HasRegularDeps(tracking.AutoTrackable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) -class OnlyOneDep(tracking.AutoCheckpointable): +class 
OnlyOneDep(tracking.AutoTrackable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py index bac071c4cff383f60b707b6e42c13faf5e0ac948..faf90f018476b3c70a7bfa1346a5b590edbbddcd 100644 --- a/tensorflow/contrib/checkpoint/python/visualize.py +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.training.tracking import util as trackable_utils def dot_graph_from_checkpoint(save_path): @@ -51,7 +51,7 @@ def dot_graph_from_checkpoint(save_path): A graph in DOT format as a string. 
""" reader = pywrap_tensorflow.NewCheckpointReader(save_path) - object_graph = checkpointable_utils.object_metadata(save_path) + object_graph = trackable_utils.object_metadata(save_path) shape_map = reader.get_variable_to_shape_map() dtype_map = reader.get_variable_to_dtype_map() graph = 'digraph {\n' @@ -63,7 +63,7 @@ def dot_graph_from_checkpoint(save_path): slot_ids.add(slot_reference.slot_variable_node_id) for node_id, node in enumerate(object_graph.nodes): if (len(node.attributes) == 1 - and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY): + and node.attributes[0].name == trackable.VARIABLE_VALUE_KEY): if node_id in slot_ids: color = 'orange' tooltip_prefix = 'Slot variable' diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py index 583e3bc442893d825c337d73fb999d1e586738a1..98a22d573fdb6172cde100df461d9ae520c2c483 100644 --- a/tensorflow/contrib/checkpoint/python/visualize_test.py +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -28,7 +28,7 @@ from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import adam -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils try: import pydot # pylint: disable=g-import-not-at-top @@ -57,7 +57,7 @@ class DotGraphTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = resource_variable_ops.ResourceVariable(12) - save_checkpoint = checkpointable_utils.Checkpoint( + save_checkpoint = trackable_utils.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) optimizer.minimize(functools.partial(model, input_value)) checkpoint_directory = self.get_temp_dir() diff --git a/tensorflow/contrib/cluster_resolver/__init__.py 
b/tensorflow/contrib/cluster_resolver/__init__.py index 390b3e7550b3d991269bb84707c3500f2fa33290..a4dea85efd98893c881abbd3f7ebda78755b8189 100644 --- a/tensorflow/contrib/cluster_resolver/__init__.py +++ b/tensorflow/contrib/cluster_resolver/__init__.py @@ -23,7 +23,7 @@ from __future__ import print_function from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver -from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver @@ -36,7 +36,7 @@ _allowed_symbols = [ 'ClusterResolver', 'SimpleClusterResolver', 'UnionClusterResolver', - 'GceClusterResolver', + 'GCEClusterResolver', 'KubernetesClusterResolver', 'TFConfigClusterResolver', 'TPUClusterResolver', diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py index 10d93549ebbd4f7e900796d0516b0af1744224af..ef1e9f11a07a5be6c0b181f5e0b80e0e2214f972 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -25,7 +25,7 @@ from __future__ import print_function from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver from 
tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver -from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver @@ -43,7 +43,7 @@ _allowed_symbols = [ 'ClusterResolver', 'SimpleClusterResolver', 'UnionClusterResolver', - 'GceClusterResolver', + 'GCEClusterResolver', 'KubernetesClusterResolver', 'TFConfigClusterResolver', 'TPUClusterResolver', diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index 55e61155c683c928efab9bb018868faec3e3df8c..5b49116ff6a4e17a774ea79b33ae1b948ba9f187 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Stub file for GceClusterResolver to maintain backwards compatibility.""" +"""Stub file for GCEClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division @@ -23,13 +23,14 @@ from __future__ import print_function # existing OSS code will not be broken. 
# pylint: disable=unused-import -from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented + _allowed_symbols = [ - 'GceClusterResolver', + 'GCEClusterResolver', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index b4f4b028f6f5c363e1c791985ad4ba3bcb62e0c6..9e9d85def83850fe61ed9dc36a60d828d46bd10a 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -154,12 +154,12 @@ suitable interface for project configuration and dependency setting. 5. Click on `Configure`, a new window will be prompted out, specify the generator mode for the project generation. For Windows, choose `Visual Studio Win64`, for Linux, choose `Unix Makefiles`, then - press `Finish`. Wait for a moment, the default project dependecy would + press `Finish`. Wait for a moment, the default project dependency would automatically generate. 6. There are a few options that you can customize your own build. **The setting here is crucial for a successful build, please check all items carefully.** - * `tensorflow_BUILD_ALL_KERNELS` should alway be `on` + * `tensorflow_BUILD_ALL_KERNELS` should always be `on` * `tensorflow_BUILD_CC_EXAMPLE` is default to be `on`. This can help you to test build (optional) * `tensorflow_BUILD_CONTRIB_KERNELS` is default to be `on`, but it won't @@ -278,7 +278,7 @@ suitable interface for project configuration and dependency setting. `make -sj install` Where `` is the threads used for the compilation, change - to any integer less or equal to your computer's maxiumum thread number. + to any integer less or equal to your computer's maximum thread number. Headers are discretely located in the build folders. 
Tensorflow library can be found at ``, namely `tensorflow.so` (Linux) or @@ -294,11 +294,12 @@ Here we assume that you have basic knowledge on gathering dependency with `CMakeLists.txt` and the c++ file `main.cxx` 2. Fill in the `main.cxx` with the code provided in [official c++ api basic](https://www.tensorflow.org/api_guides/cc/guide). -3. Fill in the `CMakeLists.txt` with following code: ``` cmake +3. Fill in the `CMakeLists.txt` with following code: + + ```cmake cmake_minimum_required (VERSION 2.6) project (tf_hello) # Tensorflow - find_package(Tensorflow REQUIRED) include_directories(${TENSORFLOW_INCLUDE_DIRS}) @@ -314,7 +315,8 @@ Here we assume that you have basic knowledge on gathering dependency with this CMakeList.txt, under development") endif() add_executable(tf_hello main.cxx) target_link_libraries(tf_hello - ${TENSORFLOW_LIBRARIES}) ``` + ${TENSORFLOW_LIBRARIES}) + ``` 4. Configure the folder with cmake-gui, an error should be prompted out, requesting you to locate the folder containing `TensorflowConfig.cmake`. 
diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index e570c09ecb5e64130ed6f3375a51d74850cc3989..30b4e2dbdee1117df12ae7ab8ce902e667234fb0 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 69b6c047bc767b4d80e7af4d00ccb7c45b683dae) +set(GRPC_TAG 62688b6a05cc85b47fb77dd408611734253e47e2) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 21ae9a08a6bb8f71e5935ddde2d7bb3ed0cd8bbc..fd205a4b9b065a4756fbe3985694bb64b93b85e6 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -13,6 +13,7 @@ tensorflow/core/lib tensorflow/core/lib/core tensorflow/core/profiler tensorflow/core/protobuf +tensorflow/core/protobuf/tpu tensorflow/core/util tensorflow/examples tensorflow/examples/tutorials @@ -70,8 +71,9 @@ tensorflow/python/summary/writer tensorflow/python/tools tensorflow/python/tools/api tensorflow/python/tools/api/generator +tensorflow/python/tpu tensorflow/python/training -tensorflow/python/training/checkpointable +tensorflow/python/training/tracking tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf @@ -437,7 +439,6 @@ tensorflow/contrib/timeseries/python/timeseries/state_space_models tensorflow/contrib/tpu tensorflow/contrib/tpu/ops tensorflow/contrib/tpu/profiler -tensorflow/contrib/tpu/proto tensorflow/contrib/tpu/python tensorflow/contrib/tpu/python/ops tensorflow/contrib/tpu/python/profiler diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index 
013180c89083748b240ad061b342300e886d3568..b4603206da419f44af0857b9b933eb7df1b255ff 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -1,6 +1,7 @@ tensorflow/core tensorflow/core/kernels/boosted_trees tensorflow/core/profiler +tensorflow/core/protobuf/tpu tensorflow/python tensorflow/contrib/boosted_trees/proto tensorflow/contrib/cloud/kernels @@ -12,7 +13,6 @@ tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle tensorflow/contrib/tensor_forest/proto tensorflow/contrib/tensorboard/plugins/projector -tensorflow/contrib/tpu/proto tensorflow/contrib/tpu/profiler tensorflow/contrib/training/python/training tensorflow/contrib/verbs diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index d8d1cc3aa2ca4fff3c950654b7cbd7085c76010c..cc263d7995c01100f1c51436bcb584b600c8c161 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -125,9 +125,9 @@ endfunction() file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" + "${tensorflow_source_dir}/tensorflow/core/protobuf/tpu/*.proto" "${tensorflow_source_dir}/tensorflow/compiler/xla/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" ) RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index e32097ceddfec95b8677fc762d641d09078e5343..839682afdc6284b9fea53405f094106c25485e79 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -23,6 +23,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":xla", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", ], @@ -63,9 +64,9 @@ py_library( 
"//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", - "//tensorflow/python:summary_op_util", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python/distribute:summary_op_util", "//tensorflow/python/estimator:estimator_py", ], ) diff --git a/tensorflow/contrib/compiler/__init__.py b/tensorflow/contrib/compiler/__init__.py index c4937dadfb8be3211377f0ae7017b95e7642dab0..797e5e8164e231e8b3806d40b32774711879b050 100644 --- a/tensorflow/contrib/compiler/__init__.py +++ b/tensorflow/contrib/compiler/__init__.py @@ -19,3 +19,4 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.compiler import jit +from tensorflow.contrib.compiler import xla diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 0f1be500f499ebba7e1907de663f8bbfa889bb17..2ccb27da12fe5692dd4360d5d52eb9950159c484 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -25,11 +25,11 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.jit.ops import xla_ops from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.distribute import summary_op_util from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import summary_op_util from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat @@ -144,6 +144,30 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext): logging.warning('... 
and %d more', len(self._unsupported_ops) - _MAX_WARNING_LINES) + def _RemoveExternalControlEdges(self, op): + """Remove any external control dependency on this op.""" + internal_control_inputs = [] + external_control_inputs = [] + for x in op.control_inputs: + # pylint: disable=protected-access + is_internal_op = False + ctxt = x._get_control_flow_context() + while ctxt is not None: + if ctxt == self: + is_internal_op = True + break + ctxt = ctxt._outer_context + if is_internal_op: + internal_control_inputs.append(x) + else: + external_control_inputs.append(x) + # pylint: enable=protected-access + # pylint: disable=protected-access + op._remove_all_control_inputs() + op._add_control_inputs(internal_control_inputs) + # pylint: enable=protected-access + return internal_control_inputs, external_control_inputs + def AddOp(self, op): """Create op in XLACompileContext and notifies outer context recursively.""" # pylint: disable=protected-access @@ -193,11 +217,14 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext): if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. 
- external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] + with ops.control_dependencies(None): + self.Enter() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] + self.Exit() # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD index eee4329acbeb38c9f37f79227aeb3acd46dce5e7..619153df67c90cea5a5082a411972948bac5fe90 100644 --- a/tensorflow/contrib/constrained_optimization/BUILD +++ b/tensorflow/contrib/constrained_optimization/BUILD @@ -42,11 +42,6 @@ py_test( name = "candidates_test", srcs = ["python/candidates_test.py"], srcs_version = "PY2AND3", - tags = [ - # TODO(b/121223093): Re-enable this test after fixing "Distribution - # should match known solution" errors. - "no_mac", - ], deps = [ ":constrained_optimization", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md index cb1dd7d836ae11700b2ffaaff4fda5b7f943f87d..7ffb6894d37444fd78015b6c124c46f2855c1cde 100644 --- a/tensorflow/contrib/constrained_optimization/README.md +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -1,5 +1,10 @@ +**NOTE: As tensorflow.contrib is being +[deprecated](https://github.com/tensorflow/community/pull/18), TFCO is moving to +its own repository on +[github](https://github.com/google-research/tensorflow_constrained_optimization).** + # ConstrainedOptimization (TFCO) TFCO is a library for optimizing inequality-constrained problems in TensorFlow. 
diff --git a/tensorflow/contrib/constrained_optimization/python/candidates_test.py b/tensorflow/contrib/constrained_optimization/python/candidates_test.py index a4c49d48bc5c763489215261a909573af0f19055..280e9acd88638a9385bfd9128ba6d3739879aab2 100644 --- a/tensorflow/contrib/constrained_optimization/python/candidates_test.py +++ b/tensorflow/contrib/constrained_optimization/python/candidates_test.py @@ -52,12 +52,12 @@ class CandidatesTest(test.TestCase): distribution = candidates.find_best_candidate_distribution( objective_vector, constraints_matrix) # Verify that the solution is a probability distribution. - self.assertTrue(np.all(distribution >= 0)) + self.assertTrue(np.all(distribution >= -1e-6)) self.assertAlmostEqual(np.sum(distribution), 1.0) # Verify that the solution satisfies the constraints. maximum_constraint_violation = np.amax( np.dot(constraints_matrix, distribution)) - self.assertLessEqual(maximum_constraint_violation, 0) + self.assertLessEqual(maximum_constraint_violation, 1e-6) # Verify that the solution matches that which we expect. expected_distribution = np.array([0.37872711, 0.62127289, 0, 0]) self.assertAllClose(expected_distribution, distribution, rtol=0, atol=1e-6) diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 40e159b8fcbd1864284e208cb15d9ed96119f840..5c4c52766604b19864b0849fa23e7a7cafcc56e0 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -283,7 +283,7 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0) # Use int32 or int64 based on tag_indices' dtype. 
if tag_indices.dtype == dtypes.int64: - offsets = math_ops.to_int64(offsets) + offsets = math_ops.cast(offsets, dtypes.int64) flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1]) unary_scores = array_ops.reshape( diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index f5219eb134d07c09b16a544f71d4c18986c19681..5ed80953a4930f0e9de75c2a0c96ce9fb4e2ab57 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -69,6 +69,8 @@ def RunLSTM(sess, time, num_layers=1, variable_seq_lengths=False, + time_major=True, + dynamic_shape_input=False, is_training=True, dropout=0., num_dirs=True, @@ -84,11 +86,14 @@ def RunLSTM(sess, random_seed.set_random_seed(0) np.random.seed(0) - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(time, batch_size, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) + shape = ([time, batch_size, input_size] + if time_major else [batch_size, time, input_size]) + inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype) + inputs_static = variable_scope.get_variable( + "inputs", initializer=inputs_np, dtype=dtype) + inputs_dynamic = array_ops.placeholder( + dtype, shape=[None, None, None], name="inputs") + inputs = inputs_dynamic if dynamic_shape_input else inputs_static initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, @@ -122,12 +127,12 @@ def RunLSTM(sess, cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True) outputs_op, state_tuple_op = rnn.dynamic_rnn( cell, - inputs, + inputs_static, sequence_length=lengths, initial_state=rnn_cell_impl.LSTMStateTuple( h=initial_h_op, c=initial_c_op), dtype=dtype, - time_major=True, + time_major=time_major, scope=None) # Convert to cudnn opaque param. 
@@ -135,35 +140,38 @@ def RunLSTM(sess, num_layers, num_units, input_size) opaque_params = format_converter.tf_canonical_to_opaque([w, b]) - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) - cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0) + cu_initial_h_op = array_ops.expand_dims( + initial_h_op, axis=(0 if time_major else 1)) + cu_initial_c_op = array_ops.expand_dims( + initial_c_op, axis=(0 if time_major else 1)) cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, cu_initial_c_op, opaque_params, sequence_lengths=lengths, + time_major=time_major, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) # Remove the trivial 1st dimension. cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( - c=array_ops.squeeze(cu_c_op, axis=0), - h=array_ops.squeeze(cu_h_op, axis=0)) + c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1), + h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1)) if is_training: (inp_grad_op, hgrad_op, cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( - outputs_op, [inputs, initial_h_op, initial_c_op, w, b]) + outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b]) (cu_inp_grad_op, cu_hgrad_op, cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients( cu_outputs_op, [inputs, cu_initial_h_op, cu_initial_c_op, opaque_params]) # Remove the trivial 1st dimension - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1) # Remove the trivial 1st dimension - cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0) + cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1) cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( opaque_grad_op) @@ -183,10 +191,12 @@ def RunLSTM(sess, (hgrad_op, cgrad_op), wgrad_op, bgrad_op ]) (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, - cu_bgrad) = sess.run([ - cu_outputs_op, cu_state_tuple_op, 
cu_inp_grad_op, - (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op - ]) + cu_bgrad) = sess.run( + [ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op + ], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -205,7 +215,10 @@ def RunLSTM(sess, cu_bgrad) else: outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) - cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op]) + cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op], + feed_dict=({ + inputs: inputs_np + } if dynamic_shape_input else None)) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -336,6 +349,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtype, variable_seq_lengths, + time_major, + dynamic_shape_input=False, rtol=3e-6, atol=3e-6): with self.session(use_gpu=True) as sess: @@ -347,7 +362,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) for s, cu_s in zip(state_tuple, cu_state_tuple): @@ -359,13 +376,16 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training(self, num_units, 
input_size, batch_size, time, num_layers, - variable_seq_lengths): + variable_seq_lengths, time_major, dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -375,16 +395,22 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtypes.float32, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -396,16 +422,21 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): dtypes.float16, rtol=5e-3, atol=5e-4, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): + variable_seq_lengths, time_major, dynamic_shape_input): 
if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -417,7 +448,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs) # h @@ -426,13 +459,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(state_tuple.c, cu_state_tuple.c) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -445,7 +482,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dtype=dtypes.float16, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -457,13 +496,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + 
"time_major": [True, False], + "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): """Validates that dropout does not affect Cudnn Rnn inference.""" if not context.context().num_gpus(): self.skipTest("No GPUs found") @@ -480,7 +523,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=0., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -493,7 +538,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=1., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(cu_outputs, cu_outputs2) # h @@ -510,6 +557,8 @@ def RunGRU(sess, num_layers=1, is_training=True, variable_seq_lengths=False, + time_major=True, + dynamic_shape_input=False, dropout=0., num_dirs=True, dtype=dtypes.float32): @@ -524,11 +573,14 @@ def RunGRU(sess, random_seed.set_random_seed(0) np.random.seed(0) - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(time, batch_size, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) + shape = ([time, batch_size, input_size] + if time_major else [batch_size, time, input_size]) + inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype) + inputs_static = variable_scope.get_variable( + "inputs", initializer=inputs_np, dtype=dtype) + inputs_dynamic = array_ops.placeholder( + dtype, shape=[None, None, None], name="inputs") + inputs = 
inputs_dynamic if dynamic_shape_input else inputs_static initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, @@ -573,11 +625,11 @@ def RunGRU(sess, cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units, reuse=True) outputs_op, h_op = rnn.dynamic_rnn( cell, - inputs, + inputs_static, sequence_length=lengths, initial_state=initial_h_op, dtype=dtype, - time_major=True, + time_major=time_major, scope=None) ws = [gate_kernel, candidate_inp_kernel, candidate_hid_kernel] @@ -588,13 +640,15 @@ def RunGRU(sess, opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_initial_h_op = array_ops.expand_dims( + initial_h_op, axis=(0 if time_major else 1)) cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, array_ops.zeros_like(cu_initial_h_op), # not used opaque_params, sequence_lengths=lengths, + time_major=time_major, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_GRU) @@ -602,12 +656,12 @@ def RunGRU(sess, if is_training: (inp_grad_op, hgrad_op, gk_grad_op, cik_grad_op, chk_grad_op, gb_grad_op, cib_grad_op, chb_grad_op) = gradients_impl.gradients( - outputs_op, [inputs, initial_h_op] + ws + bs) + outputs_op, [inputs_static, initial_h_op] + ws + bs) (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients( cu_outputs_op, [inputs, cu_initial_h_op, opaque_params]) # Remove the trivial 1st dimension - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1) cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( opaque_grad_op) @@ -627,13 +681,15 @@ def RunGRU(sess, (gk_grad_op, cik_grad_op, chk_grad_op), (gb_grad_op, cib_grad_op, chb_grad_op) ]) - (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run([ - cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op, - (cu_gk_grad_op, cu_cik_grad_op, 
cu_chk_grad_op), - (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) - ]) + (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run( + [ + cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op, + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op), + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) + ], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) # Remove the trivial 1st dimension - cu_h = np.squeeze(cu_h, axis=0) + cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -651,9 +707,12 @@ def RunGRU(sess, cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) else: outputs, h = sess.run([outputs_op, h_op]) - cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op]) + cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op], + feed_dict=({ + inputs: inputs_np + } if dynamic_shape_input else None)) # Remove the trivial 1st dimension. - cu_h = np.squeeze(cu_h, axis=0) + cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -672,6 +731,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtype, variable_seq_lengths, + time_major, + dynamic_shape_input=False, rtol=3e-6, atol=3e-6): with self.session(use_gpu=True) as sess: @@ -683,7 +744,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) @@ -695,13 +758,16 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], 
- })) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): + variable_seq_lengths, time_major, dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -711,16 +777,22 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtypes.float32, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -732,16 +804,21 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): dtypes.float16, rtol=5e-3, atol=5e-4, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], 
+ "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): + variable_seq_lengths, time_major, dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -753,18 +830,24 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs) self.assertAllClose(h, cu_h) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -777,20 +860,26 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dtype=dtypes.float16, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - 
})) + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): """Validates that dropout does not affect Cudnn Rnn inference.""" # Hand-picked dropouts are used below (0. and 1.) if not context.context().num_gpus(): @@ -807,7 +896,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=0., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -820,7 +911,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=1., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(cu_outputs, cu_outputs2) self.assertAllClose(cu_h[0], cu_h2[0]) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index ca92c31236a7a3882415834eb32a994a120b6d2d..403f30909520dc5cd5f5919af843291fe1400b91 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -58,7 +58,7 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib -from 
tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -709,7 +709,7 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): self._TestSaveRestoreHelper(CUDNN_RNN_RELU) -class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): +class CudnnRNNTestSaveRestoreTrackable(test_util.TensorFlowTestCase): def _VerifyCheckpoint( self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn, @@ -718,7 +718,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") with ops.device("gpu:0"): cudnn_layer = cudnn_cell_fn() - cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer) + cudnn_checkpoint = trackable_utils.Checkpoint(cell=cudnn_layer) status = cudnn_checkpoint.restore(checkpoint_path) inputs = 3. * array_ops.ones([num_applications, num_layers, input_size], dtype=dtypes.float32) @@ -726,7 +726,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): status.run_restore_ops() second_save_path = cudnn_checkpoint.save(checkpoint_prefix) restore_layer = compatible_cell_fn() - restore_layer_checkpoint = checkpointable_utils.Checkpoint( + restore_layer_checkpoint = trackable_utils.Checkpoint( cell=restore_layer) status = restore_layer_checkpoint.restore(second_save_path) current_state = restore_layer.zero_state(1, dtypes.float32) @@ -742,7 +742,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): self.assertAllClose(self.evaluate(restore_layer_output), self.evaluate(cudnn_output)[-1, -1:, ...]) - def _CheckpointableSingleCellUnidirectionalTestTemplate( + def _TrackableSingleCellUnidirectionalTestTemplate( self, single_cell_fn, cudnn_cell_fn): # Single-layer cuDNN cells with object-based checkpointing should be # checkpoint compatible with either single CudnnCompatible cells or 
@@ -759,7 +759,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): value = np.random.normal(size=variable.shape) expected_values.append(value) self.evaluate(variable.assign(value)) - save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer) + save_checkpoint = trackable_utils.Checkpoint(cell=save_cell_layer) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") first_save_path = save_checkpoint.save(checkpoint_prefix) @@ -775,10 +775,10 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @test_util.run_in_graph_and_eager_modes - def testLSTMCheckpointableSingleLayer(self): + def testLSTMTrackableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION - self._CheckpointableSingleCellUnidirectionalTestTemplate( + self._TrackableSingleCellUnidirectionalTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), cudnn_cell_fn=functools.partial( @@ -788,19 +788,19 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @test_util.run_in_graph_and_eager_modes - def testGRUCheckpointableSingleLayer(self): + def testGRUTrackableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION with self.assertRaises(NotImplementedError): # TODO(allenl): Implement object-based saving for GRUs and other cells. 
- self._CheckpointableSingleCellUnidirectionalTestTemplate( + self._TrackableSingleCellUnidirectionalTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units), cudnn_cell_fn=functools.partial( cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units, direction=direction, name="awesome_gru")) - def _CheckpointableMultiLayerTestTemplate( + def _TrackableMultiLayerTestTemplate( self, single_cell_fn, cudnn_cell_fn, num_layers): def _MultiCellFn(): @@ -819,7 +819,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): value = np.random.normal(size=variable.shape) expected_values.append(value) self.evaluate(variable.assign(value)) - save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer) + save_checkpoint = trackable_utils.Checkpoint(cell=save_layer) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") first_save_path = save_checkpoint.save(checkpoint_prefix) @@ -837,7 +837,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): num_units = 2 num_layers = 3 direction = CUDNN_RNN_UNIDIRECTION - self._CheckpointableMultiLayerTestTemplate( + self._TrackableMultiLayerTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), cudnn_cell_fn=functools.partial( diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 86ad8ae8073714657c78badb1e0b4a6d8c8ed5f0..c6ab003d6cca6a88a5e9e7971ed9211cab59cee0 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -104,7 +104,7 @@ class _CudnnRNN(base_layer.Layer): # Inference subgraph for unidirectional RNN on, e.g., CPU or mobile. 
with tf.Graph().as_default(): - single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units) # NOTE: Even if there's only one layer, the cell needs to be wrapped in # MultiRNNCell. @@ -124,7 +124,7 @@ class _CudnnRNN(base_layer.Layer): # Inference subgraph for bidirectional RNN with tf.Graph().as_default(): - single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units) cells_fw = [single_cell() for _ in range(num_layers)] cells_bw = [single_cell() for _ in range(num_layers)] @@ -378,20 +378,33 @@ class _CudnnRNN(base_layer.Layer): inputs, initial_state=None, sequence_lengths=None, + time_major=True, training=True): """Runs the forward step for the RNN model. Args: - inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`. + inputs: `3-D` tensor. If `time_major` is True (default), the Tensor shape + is [time_len, batch_size, input_size]. If `time_major` is False, the + shape is [batch_size, time_len, input_size]. initial_state: a tuple of tensor(s) of shape - `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use + `[num_layers * num_dirs, batch_size, num_units]` if + `time_major` is True (default) or `[batch_size, num_layers * num_dirs, + num_units]` if `time_major` is False. If not provided, use zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. If not provided, the same sequence length will be assumed. + time_major: The shape format of the `inputs` and `outputs` Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. 
By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' + is used. training: whether this operation will be used in training or inference. Returns: - output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. + output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]` + if `time_major` is True (default) or `[batch_size, time_len, + num_dirs * num_units]` if `time_major` is False. It is a `concat([fwd_output, bak_output], axis=2)`. output_states: a tuple of tensor(s) of the same shape and structure as `initial_state`. @@ -417,8 +430,8 @@ class _CudnnRNN(base_layer.Layer): else: # For model that doesn't take input_c, replace with a dummy tensor. c = array_ops.constant([], dtype=dtype) - outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, - sequence_lengths, training) + outputs, (output_h, output_c) = self._forward( + inputs, h, c, self.kernel, sequence_lengths, time_major, training) if self._rnn_mode == CUDNN_LSTM: return outputs, (output_h, output_c) else: @@ -482,7 +495,8 @@ class _CudnnRNN(base_layer.Layer): dropout=self._dropout, direction=self._direction) - def _forward(self, inputs, h, c, opaque_params, sequence_lengths, training): + def _forward(self, inputs, h, c, opaque_params, sequence_lengths, time_major, + training): output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access inputs, h, @@ -491,6 +505,7 @@ class _CudnnRNN(base_layer.Layer): training, self._rnn_mode, sequence_lengths=sequence_lengths, + time_major=time_major, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, @@ -518,8 +533,8 @@ class _CudnnRNN(base_layer.Layer): direction=self.direction, scope=vs.get_variable_scope(), name="%s_saveable" % self.trainable_variables[0].name.split(":")[0]) - self._saveable._add_checkpointable_dependencies( # pylint: disable=protected-access - checkpointable=self, 
dtype=self._plain_dtype) + self._saveable._add_trackable_dependencies( # pylint: disable=protected-access + trackable=self, dtype=self._plain_dtype) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index f36e8d5022bc7e3f8268a161089153e5510dffc6..77afbeec0c57280fdd09ffe9019c6fe17d7b0de8 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking as checkpointable_lib +from tensorflow.python.training.tracking import tracking as trackable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -737,13 +737,13 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return state_ops.assign( self._variables, opaque_params, validate_shape=False) - def _checkpointable_save(self, save_buffer): + def _trackable_save(self, save_buffer): weights, biases = self.format_converter.opaque_to_tf_canonical( self._variables) for name, tensor in zip(self._param_names, weights + biases): save_buffer[name] = array_ops.identity(tensor) - def _checkpointable_restore(self, restore_buffer): + def _trackable_restore(self, restore_buffer): tensors = [ array_ops.identity(restore_buffer[name]) for name in self._param_names ] @@ -752,26 +752,26 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): restored_shapes=None # Unused ) - def _add_checkpointable_dependencies(self, checkpointable, dtype): - """Add canonical weight dependencies to `checkpointable`. 
+ def _add_trackable_dependencies(self, trackable, dtype): + """Add canonical weight dependencies to `trackable`. When saving or restoring, converts to or from the opaque buffer format. Weights are saved and loaded in the configuration expected by cuDNN-compatible cells. Args: - checkpointable: An object inheriting from `CheckpointableBase` to add + trackable: An object inheriting from `Trackable` to add dependencies too (typically the cuDNN `Layer`). dtype: The dtype for the canonical parameter Tensors. """ split_dependencies = split_dependency.split_dependency( component_names=self._param_names, component_dtypes=(dtype,) * len(self._param_names), - fill_save_buffer_fn=self._checkpointable_save, - consume_restore_buffer_fn=self._checkpointable_restore) - self._checkpointable_track_params(checkpointable, split_dependencies) + fill_save_buffer_fn=self._trackable_save, + consume_restore_buffer_fn=self._trackable_restore) + self._trackable_track_params(trackable, split_dependencies) - def _checkpointable_track_params(self, checkpointable, params): + def _trackable_track_params(self, trackable, params): """Tracks parameters in a canonical configuration.""" return # NotImplementedError raised by the Layer. @@ -819,7 +819,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): tf_weights_names.append(prefix + "/kernel") tf_bias_names.append(prefix + "/bias") - def _checkpointable_track_params(self, checkpointable, params): + def _trackable_track_params(self, trackable, params): """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" biases = [] weights = [] @@ -833,12 +833,12 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): # wrapping. 
kernel, = weights # pylint: disable=unbalanced-tuple-unpacking bias, = biases # pylint: disable=unbalanced-tuple-unpacking - checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access - checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + trackable._track_trackable(kernel, name="kernel") # pylint: disable=protected-access + trackable._track_trackable(bias, name="bias") # pylint: disable=protected-access assert len(biases) == len(weights) for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.AutoCheckpointable() - checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell = trackable_lib.AutoTrackable() + trackable._track_trackable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access cell.bias = bias cell.kernel = kernel @@ -956,6 +956,7 @@ def _cudnn_rnn(inputs, is_training, rnn_mode, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -964,10 +965,12 @@ def _cudnn_rnn(inputs, """Cudnn RNN. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. input_c: the initial hidden state for c. This is only relevant for LSTM. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. 
@@ -977,6 +980,11 @@ def _cudnn_rnn(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. @@ -1017,6 +1025,14 @@ def _cudnn_rnn(inputs, } if sequence_lengths is not None: args["sequence_lengths"] = sequence_lengths + args["time_major"] = time_major + outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) + elif time_major is False: + batch_size = array_ops.shape(inputs)[0] + max_time = array_ops.shape(inputs)[1] + sequence_lengths = array_ops.fill([batch_size], max_time) + args["sequence_lengths"] = sequence_lengths + args["time_major"] = time_major outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) elif use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) @@ -1031,6 +1047,7 @@ def cudnn_lstm(inputs, params, is_training, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1039,15 +1056,26 @@ def cudnn_lstm(inputs, """Cudnn LSTM. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. 
If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. input_c: the initial hidden state for c. This is only relevant for LSTM. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference - input_mode: indicate whether there is a linear projection between the + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. + input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) always applies a linear projection of input @@ -1060,17 +1088,13 @@ def cudnn_lstm(inputs, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. - sequence_lengths: an int32 array representing the variable sequence lengths - in a batch. The size of the array has to equal the batch_size. 
Default to - None, in which case sequences in the batch are assumed to have the same - length, which is inferred from inputs. name: name of the operation. Returns: outputs, output_h, output_c """ return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM, - sequence_lengths, input_mode, direction, dropout, seed, - name) + sequence_lengths, time_major, input_mode, direction, + dropout, seed, name) def _cudnn_rnn_no_input_c(inputs, @@ -1079,6 +1103,7 @@ def _cudnn_rnn_no_input_c(inputs, is_training, rnn_mode, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1087,10 +1112,12 @@ def _cudnn_rnn_no_input_c(inputs, """Cudnn RNN w/o input_c. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). @@ -1098,6 +1125,11 @@ def _cudnn_rnn_no_input_c(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. 
If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. @@ -1116,9 +1148,9 @@ def _cudnn_rnn_no_input_c(inputs, outputs, output_h """ input_c = array_ops.constant([], dtype=input_h.dtype) - outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params, - is_training, rnn_mode, sequence_lengths, - input_mode, direction, dropout, seed, name) + outputs, output_h, _ = _cudnn_rnn( + inputs, input_h, input_c, params, is_training, rnn_mode, sequence_lengths, + time_major, input_mode, direction, dropout, seed, name) return outputs, output_h @@ -1127,6 +1159,7 @@ def cudnn_gru(inputs, params, is_training, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1135,10 +1168,12 @@ def cudnn_gru(inputs, """Cudnn GRU. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. 
is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1153,6 +1188,11 @@ def cudnn_gru(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1163,8 +1203,8 @@ def cudnn_gru(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU, - sequence_lengths, input_mode, direction, dropout, - seed, name) + sequence_lengths, time_major, input_mode, + direction, dropout, seed, name) def cudnn_rnn_relu(inputs, @@ -1176,14 +1216,17 @@ def cudnn_rnn_relu(inputs, dropout=0., seed=0, sequence_lengths=None, + time_major=True, name=None): """Cudnn RNN Relu. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. 
If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1201,14 +1244,19 @@ def cudnn_rnn_relu(inputs, sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. If not provided, the same sequence length will be assumed. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. name: name of the operation. Returns: outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_RELU, sequence_lengths, input_mode, - direction, dropout, seed, name) + CUDNN_RNN_RELU, sequence_lengths, time_major, + input_mode, direction, dropout, seed, name) def cudnn_rnn_tanh(inputs, @@ -1216,6 +1264,7 @@ def cudnn_rnn_tanh(inputs, params, is_training, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1224,10 +1273,12 @@ def cudnn_rnn_tanh(inputs, """Cudnn RNN Tanh. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. 
If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1242,6 +1293,11 @@ def cudnn_rnn_tanh(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1252,8 +1308,8 @@ def cudnn_rnn_tanh(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_TANH, sequence_lengths, input_mode, - direction, dropout, seed, name) + CUDNN_RNN_TANH, sequence_lengths, time_major, + input_mode, direction, dropout, seed, name) def cudnn_rnn_opaque_params_to_canonical(rnn_mode, @@ -1537,22 +1593,32 @@ class _CudnnRNN(object): input_c, params, is_training=True, - sequence_lengths=None): + sequence_lengths=None, + time_major=True): """Runs the forward step for the RNN model. Args: - input_data: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. - input_c: the initial hidden state for c. This is only relevant for LSTM. 
- A Tensor of the same shape as input_h. + input_data: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. + input_c: the initial hidden state for c. This is only relevant for LSTM. A + Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' is + used. + Returns: output: the output sequence. output_h: the final state for h. @@ -1566,6 +1632,7 @@ class _CudnnRNN(object): is_training, self._rnn_mode, sequence_lengths=sequence_lengths, + time_major=time_major, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, @@ -1666,14 +1733,17 @@ class CudnnLSTM(_CudnnRNN): input_c, params, sequence_lengths=None, + time_major=True, is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the LSTM model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. 
A Tensor of shape [num_layers, - batch_size, num_units]. + input_data: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. input_c: the initial hidden state for c. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. @@ -1681,6 +1751,12 @@ class CudnnLSTM(_CudnnRNN): lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' + is used. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. @@ -1693,6 +1769,7 @@ class CudnnLSTM(_CudnnRNN): input_c, params, sequence_lengths=sequence_lengths, + time_major=time_major, is_training=is_training) return (output, output_h, output_c) @@ -1752,19 +1829,28 @@ class _CudnnRNNNoInputC(_CudnnRNN): input_h, params, sequence_lengths=None, + time_major=True, is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + input_data: the input sequence to the RNN model. 
If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' + is used. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. 
@@ -1777,6 +1863,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): is_training, self._rnn_mode, sequence_lengths=sequence_lengths, + time_major=time_major, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py index 6c5f8c6b00975b3fba041271309a93cecd9f5057..077571fcd2091b3b7216c57627a11989f3db1fdf 100644 --- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -25,11 +25,13 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import script_ops from tensorflow.python.platform import test +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class AssertElementShapeTest(test_base.DatasetTestBase): def test_assert_element_shape(self): @@ -41,10 +43,12 @@ class AssertElementShapeTest(test_base.DatasetTestBase): dataset = dataset_ops.Dataset.range(5).map(create_dataset) expected_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 4))) - self.assertEqual(expected_shapes, dataset.output_shapes) + self.assertEqual(expected_shapes, + dataset_ops.get_legacy_output_shapes(dataset)) result = dataset.apply(batching.assert_element_shape(expected_shapes)) - self.assertEqual(expected_shapes, result.output_shapes) + self.assertEqual(expected_shapes, + dataset_ops.get_legacy_output_shapes(result)) iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer @@ -81,12 +85,14 @@ class AssertElementShapeTest(test_base.DatasetTestBase): dataset = 
dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) unknown_shapes = (tensor_shape.TensorShape(None), tensor_shape.TensorShape(None)) - self.assertEqual(unknown_shapes, dataset.output_shapes) + self.assertEqual(unknown_shapes, + dataset_ops.get_legacy_output_shapes(dataset)) expected_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 4))) result = dataset.apply(batching.assert_element_shape(expected_shapes)) - self.assertEqual(expected_shapes, result.output_shapes) + self.assertEqual(expected_shapes, + dataset_ops.get_legacy_output_shapes(result)) iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer @@ -111,7 +117,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) unknown_shapes = (tensor_shape.TensorShape(None), tensor_shape.TensorShape(None)) - self.assertEqual(unknown_shapes, dataset.output_shapes) + self.assertEqual(unknown_shapes, + dataset_ops.get_legacy_output_shapes(dataset)) wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 10))) @@ -139,7 +146,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): # Partial shapes are merged with actual shapes: actual_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 4))) - self.assertEqual(actual_shapes, result.output_shapes) + self.assertEqual(actual_shapes, + dataset_ops.get_legacy_output_shapes(result)) iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer @@ -176,12 +184,14 @@ class AssertElementShapeTest(test_base.DatasetTestBase): dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) unknown_shapes = (tensor_shape.TensorShape(None), tensor_shape.TensorShape(None)) - self.assertEqual(unknown_shapes, dataset.output_shapes) + self.assertEqual(unknown_shapes, + dataset_ops.get_legacy_output_shapes(dataset)) expected_shapes = (tensor_shape.TensorShape(2), 
tensor_shape.TensorShape((None, 4))) result = dataset.apply(batching.assert_element_shape(expected_shapes)) - self.assertEqual(expected_shapes, result.output_shapes) + self.assertEqual(expected_shapes, + dataset_ops.get_legacy_output_shapes(result)) iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer @@ -206,7 +216,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) unknown_shapes = (tensor_shape.TensorShape(None), tensor_shape.TensorShape(None)) - self.assertEqual(unknown_shapes, dataset.output_shapes) + self.assertEqual(unknown_shapes, + dataset_ops.get_legacy_output_shapes(dataset)) wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((None, 10))) diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index b9840b1ff1a3df5a05db0e64f436637220f49f80..220f9934b67d1d2a97f6c0fd4ba7779f011e1b09 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -27,12 +27,14 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.platform import test from tensorflow.python.util import compat prefix_path = "tensorflow/core/lib" +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class LMDBDatasetTest(test_base.DatasetTestBase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py index e7281d531870c75c638b5c48fa3fc6dc606a3623..78019fcc7d810da444f1407f3885d54e76a741c6 100644 --- 
a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py @@ -25,10 +25,12 @@ from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 2527706709fae8e459aca3489324d4db3c784be6..95cf659a84ba560fb9e06ad47477dd69b17cae3a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -26,11 +26,13 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( @@ -230,7 +232,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=array_ops.expand_dims( math_ops.range(i, dtype=dtypes.int64), 1), - values=array_ops.fill([math_ops.to_int32(i)], i), + values=array_ops.fill([math_ops.cast(i, dtypes.int32)], i), 
dense_shape=[i]) iterator = dataset_ops.make_initializable_iterator( diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 8c60459ca81cd7a7e08d90339011c54275ea9c0b..f8bb942c0a54d0892f382b1779ff830ab04b8258 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib.framework import with_shape from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.util import deprecation @@ -215,14 +216,14 @@ def assert_element_shape(expected_shapes): return nest.pack_sequence_as(elements, checked_tensors) def _apply_fn(dataset): - output_shapes = _merge_output_shapes(dataset.output_shapes, - expected_shapes) + output_shapes = _merge_output_shapes( + dataset_ops.get_legacy_output_shapes(dataset), expected_shapes) # pylint: disable=protected-access return batching._RestructuredDataset( dataset.map(_check_shape), - dataset.output_types, + dataset_ops.get_legacy_output_types(dataset), output_shapes=output_shapes, - output_classes=dataset.output_classes) + output_classes=dataset_ops.get_legacy_output_classes(dataset)) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 6708e01d08135a132b797e317cd2a241c3428f40..b3c2c984a9d9920cd1501bd1612757b23e92b7de 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_experimental_dataset_ops 
as ged_ops @@ -38,9 +37,7 @@ class _SlideDataset(dataset_ops.UnaryDataset): self._window_shift = ops.convert_to_tensor( window_shift, dtype=dtypes.int64, name="window_shift") - input_structure = structure.convert_legacy_structure( - input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) + input_structure = dataset_ops.get_structure(input_dataset) self._structure = input_structure._batch(None) # pylint: disable=protected-access variant_tensor = ged_ops.experimental_sliding_window_dataset( self._input_dataset._variant_tensor, # pylint: disable=protected-access @@ -59,7 +56,7 @@ class _SlideDataset(dataset_ops.UnaryDataset): None, "stride is deprecated, use window_shift instead", "stride") @deprecation.deprecated( None, "Use `tf.data.Dataset.window(size=window_size, shift=window_shift, " - "stride=window_stride).flat_map(lambda x: x.batch(window.size))` " + "stride=window_stride).flat_map(lambda x: x.batch(window_size))` " "instead.") def sliding_window_batch(window_size, stride=None, diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 6ae3ec7fb0892db5434d2364064ade574dc21e38..4260cfbc40d416ea654dfcd661b358db2b81db32 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -15,16 +15,26 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -# TODO(priyag): Figure out testonly issues that are preventing us from -# including our tests in pip for now. 
+py_library( + name = "distribute_test_lib_pip", + visibility = ["//tensorflow:internal"], + deps = [ + ":combinations", + ":keras_correctness_test_lib", + ":keras_test_lib", + ":multi_worker_test_base", + ":single_loss_example", + ":strategy_test_lib", + ], +) -cuda_py_test( +distribute_py_test( name = "values_test", srcs = ["values_test.py"], - additional_deps = [ + main = "values_test.py", + deps = [ ":combinations", ":mirrored_strategy", - "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", @@ -37,9 +47,7 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", - ], - tags = [ - "no_pip", + "@absl_py//absl/testing:parameterized", ], ) @@ -51,18 +59,13 @@ cuda_py_test( ":mirrored_strategy", ":multi_worker_test_base", "@absl_py//absl/testing:parameterized", - "//tensorflow/core:protos_all_py", "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - tags = [ - "no_pip", - ], ) py_library( @@ -117,7 +120,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -126,17 +128,17 @@ py_library( srcs = ["one_device_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:input_lib", - "//tensorflow/python/distribute:numpy_dataset", - "//tensorflow/python/distribute:reduce_util", - "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", - "@six_archive//:six", + "//tensorflow/python/distribute:one_device_strategy", + 
], +) + +cuda_py_test( + name = "one_device_strategy_test", + srcs = ["one_device_strategy_test.py"], + additional_deps = [ + ":strategy_test_lib", + ":combinations", + "//tensorflow/python/eager:test", ], ) @@ -145,39 +147,34 @@ py_library( srcs = ["collective_all_reduce_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":mirrored_strategy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:collective_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:cross_device_utils", - "//tensorflow/python/distribute:input_lib", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/distribute:numpy_dataset", - "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", + "//tensorflow/python/distribute:collective_all_reduce_strategy", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", ], ) py_library( name = "strategy_test_lib", - testonly = 1, srcs = ["strategy_test_lib.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients_impl", + "//tensorflow/python:init_ops", "//tensorflow/python:layers", "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -187,15 +184,12 @@ 
py_library( py_library( name = "combinations", - testonly = 1, srcs = ["combinations.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ ":mirrored_strategy", ":one_device_strategy", + ":parameter_server_strategy", ":tpu_strategy", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", @@ -204,6 +198,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:context", + "//tensorflow/python/keras/optimizer_v2", "@absl_py//absl/testing:parameterized", ], ) @@ -211,30 +206,12 @@ py_library( py_test( name = "combinations_test", srcs = ["combinations_test.py"], - tags = [ - "no_pip", - ], deps = [ ":combinations", "//tensorflow/python/eager:test", ], ) -py_test( - name = "one_device_strategy_test", - srcs = ["one_device_strategy_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ - ":one_device_strategy", - ":strategy_test_lib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python/eager:test", - ], -) - # TODO(priyag): Rename this test to mirrored_strategy_test cuda_py_test( name = "mirrored_strategy_multigpu_test", @@ -260,18 +237,13 @@ cuda_py_test( tags = [ "guitar", "multi_and_single_gpu", - "no_pip", ], ) py_library( name = "multi_worker_test_base", - testonly = 1, srcs = ["multi_worker_test_base.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -283,6 +255,33 @@ py_library( ], ) +cuda_py_test( + name = "keras_multi_worker_test", + srcs = ["keras_multi_worker_test.py"], + additional_deps = [ + "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", + "//tensorflow/contrib/distribute/python:combinations", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:multi_worker_test_base", + 
"//tensorflow/contrib/distribute/python:parameter_server_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:distribute_config", + "//tensorflow/python/distribute:distribute_coordinator", + "//tensorflow/python/eager:context", + "//tensorflow/python/keras", + "//tensorflow/python/keras:engine", + ], + shard_count = 3, + tags = [ + # TODO(b/124344198): Add "multi_and_single_gpu", + ], +) + py_library( name = "step_fn", srcs = ["step_fn.py"], @@ -298,18 +297,8 @@ py_library( srcs = ["tpu_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":one_device_strategy", "//tensorflow/contrib/tpu:tpu_lib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:tensor_util", - "//tensorflow/python:util", - "//tensorflow/python/distribute:input_lib", - "//tensorflow/python/distribute:numpy_dataset", - "//tensorflow/python/distribute:reduce_util", - "//tensorflow/python/distribute:values", + "//tensorflow/python/distribute:tpu_strategy", ], ) @@ -340,7 +329,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -350,7 +338,6 @@ distribute_py_test( main = "minimize_loss_test.py", tags = [ "multi_and_single_gpu", - "no_pip", ], deps = [ ":combinations", @@ -384,9 +371,6 @@ cuda_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], - tags = [ - "no_pip", - ], ) cuda_py_test( @@ -404,7 +388,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -427,7 +410,6 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_oss", # http://b/119349471 - "no_pip", "tf_integration_test", ], ) @@ -438,10 +420,10 @@ cuda_py_test( additional_deps = [ ":keras_test_lib", ], + shard_count = 4, tags = [ 
"multi_and_single_gpu", "no_oss", # http://b/119349471 - "no_pip", "tf_integration_test", ], ) @@ -471,7 +453,6 @@ cuda_py_test( shard_count = 48, tags = [ "multi_and_single_gpu", - "no_pip", # TODO(b/118768923): Re-enable {a,m,t}san test. "noasan", "nomsan", @@ -499,7 +480,6 @@ distribute_py_test( main = "step_fn_test.py", tags = [ "multi_and_single_gpu", - "no_pip", ], deps = [ ":combinations", @@ -529,10 +509,10 @@ cuda_py_test( additional_deps = [ ":combinations", ":monitor", - ":one_device_strategy", ":single_loss_example", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", + "//tensorflow/python/distribute:one_device_strategy", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python:framework_ops", @@ -540,7 +520,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -557,15 +536,13 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - tags = [ - "no_pip", - ], ) cuda_py_test( name = "cross_device_ops_test", srcs = ["cross_device_ops_test.py"], additional_deps = [ + ":collective_all_reduce_strategy", ":combinations", ":multi_worker_test_base", ":mirrored_strategy", @@ -581,16 +558,15 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) py_library( name = "keras_test_lib", - testonly = 1, srcs = [ "keras_backward_compat_test.py", "keras_test.py", + "keras_utils_test.py", ], deps = [ ":combinations", @@ -609,16 +585,33 @@ py_library( distribute_py_test( name = "keras_test", srcs = ["keras_test.py"], + full_precision = True, main = "keras_test.py", - shard_count = 16, + shard_count = 32, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883) + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_test_lib", + ], +) + +distribute_py_test( + name = "keras_utils_test", + srcs = ["keras_utils_test.py"], + full_precision = True, + main = "keras_utils_test.py", + shard_count = 32, tags = [ "multi_and_single_gpu", - "no_oss", # 
TODO(b/117919883): Fix python error. - "no_pip", "no_windows_gpu", "notsan", ], deps = [ + ":keras_test", ":keras_test_lib", ], ) @@ -629,11 +622,9 @@ distribute_py_test( srcs = ["keras_backward_compat_test.py"], full_precision = True, main = "keras_backward_compat_test.py", - shard_count = 16, + shard_count = 31, tags = [ "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_pip", "no_windows_gpu", "notsan", ], @@ -644,13 +635,13 @@ distribute_py_test( py_library( name = "keras_correctness_test_lib", - testonly = 1, srcs = [ "keras_correctness_test_base.py", "keras_dnn_correctness_test.py", "keras_embedding_model_correctness_test.py", "keras_image_model_correctness_test.py", "keras_lstm_model_correctness_test.py", + "keras_stateful_lstm_model_correctness_test.py", ], deps = [ ":combinations", @@ -677,8 +668,6 @@ distribute_py_test( shard_count = 19, tags = [ "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_pip", "no_windows_gpu", "notsan", ], @@ -698,8 +687,6 @@ distribute_py_test( shard_count = 31, tags = [ "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_pip", "no_windows_gpu", "notsan", ], @@ -719,8 +706,6 @@ distribute_py_test( shard_count = 31, tags = [ "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_pip", "no_windows_gpu", "notsan", ], @@ -740,7 +725,25 @@ distribute_py_test( shard_count = 31, tags = [ "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_stateful_lstm_model_correctness_test", + size = "medium", + srcs = ["keras_stateful_lstm_model_correctness_test.py"], + full_precision = True, + main = "keras_stateful_lstm_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. 
+ shard_count = 31, + tags = [ + "multi_and_single_gpu", "no_pip", "no_windows_gpu", "notsan", @@ -756,7 +759,6 @@ distribute_py_test( main = "metrics_v1_test.py", tags = [ "multi_and_single_gpu", - "no_pip", ], deps = [ ":combinations", @@ -783,7 +785,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -794,7 +795,6 @@ cuda_py_test( additional_deps = [ ":combinations", "//tensorflow/python:client_testlib", - "//tensorflow/python:checkpoint_utils_test", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -802,7 +802,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -817,12 +816,11 @@ tf_xla_py_test( ], tags = [ "no_oss", - "no_pip", ], deps = [ ":tpu_strategy", "//tensorflow/compiler/tests:xla_test", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:util", ], ) diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index 3ef8b9574a36730dcc1d8fd42b6c7f364d84bbed..79369fc6b93b4491c9744653d8d64c5c8a4de30d 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -25,6 +25,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations @@ -33,12 +34,34 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import checkpoint_utils_test +from tensorflow.python.training import saver as saver_lib + + +def _create_checkpoints(sess, checkpoint_dir): + checkpoint_prefix = 
os.path.join(checkpoint_dir, "model") + checkpoint_state_name = "checkpoint" + v1 = variable_scope.get_variable("var1", [1, 10]) + v2 = variable_scope.get_variable("var2", [10, 10]) + sess.run(variables.global_variables_initializer()) + v1_value, v2_value = sess.run([v1, v2]) + saver = saver_lib.Saver() + saver.save( + sess, + checkpoint_prefix, + global_step=0, + latest_filename=checkpoint_state_name) + return v1_value, v2_value class CheckpointUtilsWithDistributionStrategyTest( test.TestCase, parameterized.TestCase): + def _get_test_object(self): + checkpoint_dir = self.get_temp_dir() + with self.cached_session() as session: + v1, v2 = _create_checkpoints(session, checkpoint_dir) + return checkpoint_dir, v1, v2 + @combinations.generate(combinations.combine( distribution=[combinations.default_strategy, combinations.one_device_strategy, @@ -49,10 +72,7 @@ class CheckpointUtilsWithDistributionStrategyTest( in_replica_mode=[True, False], mode=["graph"])) def testInitFromCheckpoint(self, distribution, in_replica_mode): - checkpoint_dir = self.get_temp_dir() - with self.cached_session() as session: - v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints( - session, checkpoint_dir) + checkpoint_dir, v1_value, v2_value = self._get_test_object() def init_and_verify(g): v1 = variable_scope.get_variable("new_var1", [1, 10]) @@ -75,6 +95,39 @@ class CheckpointUtilsWithDistributionStrategyTest( else: init_and_verify(g) + @combinations.generate( + combinations.combine( + distribution=[ + combinations.default_strategy, combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus + ], + in_replica_mode=[True, False], + mode=["graph"])) + def testInitFromDifferentNameObject(self, distribution, in_replica_mode): + checkpoint_dir, v1_value, _ = self._get_test_object() + + def 
init_and_verify(g): + v1 = variable_scope.get_variable("new_var1", [1, 10]) + # Use string add to create new object in each replica + prefix = "new_" + suffix = "var1" + new_var1 = prefix + suffix + checkpoint_utils.init_from_checkpoint(checkpoint_dir, { + "var1": new_var1, + }) + with self.test_session(graph=g) as session: + session.run(variables.global_variables_initializer()) + self.assertAllEqual(v1_value, self.evaluate(v1)) + + with ops.Graph().as_default() as g, distribution.scope(): + if in_replica_mode: + distribution.extended.call_for_each_replica(init_and_verify, [g]) + else: + init_and_verify(g) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/checkpointing_test.py b/tensorflow/contrib/distribute/python/checkpointing_test.py index aa5b9f57b8a5bc12ee94399ec1fc5a55177a5b5d..eadf7233f2ae5ee50b71836ebfcc895163124ac2 100644 --- a/tensorflow/contrib/distribute/python/checkpointing_test.py +++ b/tensorflow/contrib/distribute/python/checkpointing_test.py @@ -30,15 +30,15 @@ from tensorflow.python.platform import test from tensorflow.python.training import adam as adam_v1 from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util as trackable_utils -class NonLayerCheckpointable(tracking.AutoCheckpointable): +class NonLayerTrackable(tracking.AutoTrackable): def __init__(self): - super(NonLayerCheckpointable, self).__init__() - self.a_variable = checkpointable_utils.add_variable( + super(NonLayerTrackable, self).__init__() + self.a_variable = trackable_utils.add_variable( self, name="a_variable", shape=[]) @@ -49,8 +49,8 @@ class Subclassed(training.Model): super(Subclassed, self).__init__() self._named_dense = 
core.Dense(1, use_bias=True) self._second = core.Dense(1, use_bias=False) - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() + # We can still track Trackables which aren't Layers. + self._non_layer = NonLayerTrackable() def call(self, values): ret = self._second(self._named_dense(values)) @@ -76,7 +76,7 @@ class TrainingCheckpointTests(xla_test.XLATestCase): with strategy.scope(): model = Subclassed() optimizer = adam_v1.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = trackable_utils.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) root.restore(checkpoint_management.latest_checkpoint( diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index eee07543251321ae0c9eef57851431cf97c65643..d4f76e3e7b937798c978f740e080f44a4a1cb418 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -18,30 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib -from tensorflow.python.distribute import cross_device_utils -from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import input_lib -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.distribute import numpy_dataset -from tensorflow.python.distribute import values -from tensorflow.python.eager import context -from 
tensorflow.python.eager import tape -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import collective_ops -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver # TODO(yuefengz): support in-graph replication. class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): """Distribution strategy that uses collective ops for all-reduce. + *** contrib version *** + It is similar to the MirroredStrategy but it uses collective ops for reduction. @@ -53,331 +42,45 @@ class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): distributed environment. """ - def __init__(self, num_gpus_per_worker=0): + def __init__(self, + num_gpus_per_worker=0, + communication=cross_device_ops_lib.CollectiveCommunication.AUTO): """Initializes the object. Args: num_gpus_per_worker: number of local GPUs or GPUs per worker, the default is 0 meaning CPU only. + communication: optional Enum of type + `distribute.experimental.CollectiveCommunication`. This provides a way + for the user to override the choice of collective op communication. + Possible values include `AUTO`, `RING`, and `NCCL`. 
""" super(CollectiveAllReduceStrategy, self).__init__( - CollectiveAllReduceExtended(self, num_gpus_per_worker)) + CollectiveAllReduceExtended( + self, + num_gpus_per_worker=num_gpus_per_worker, + communication=communication)) -class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): +class CollectiveAllReduceExtended( + collective_all_reduce_strategy.CollectiveAllReduceExtended): """Implementation of CollectiveAllReduceStrategy.""" - def __init__(self, container_strategy, num_gpus_per_worker): - distribute_lib.DistributionStrategyExtended.__init__( - self, container_strategy) - self._cross_device_ops = None - self._num_gpus_per_worker = num_gpus_per_worker - self._initialize_local_worker(num_gpus_per_worker) - assert isinstance(self._get_cross_device_ops(), - cross_device_ops_lib.CollectiveAllReduce) - - def _initialize_local_worker(self, num_gpus_per_worker): - """Initializes the object for local training.""" - self._is_chief = True - self._num_workers = 1 - - if num_gpus_per_worker: - local_devices = tuple( - "/device:GPU:%d" % i for i in range(num_gpus_per_worker) - ) - else: - local_devices = ("/device:CPU:0",) - self._worker_device = device_util.canonicalize("/device:CPU:0") - self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) - - self._collective_keys = cross_device_utils.CollectiveKeys() - self._initialize_local(local_devices) - # TODO(yuefengz): remove num_gpus_per_worker from CollectiveAllReduce. 
- self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( - num_workers=self._num_workers, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys) - - self._cluster_spec = None - self._task_type = None - self._task_id = None - - logging.info("CollectiveAllReduceStrategy with local_devices = %r", - local_devices) - - def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, - task_type, task_id): - """Initializes the object for multi-worker training.""" - if task_type is None or task_id is None: - raise ValueError("When `cluster_spec` is given, you must also specify " - "`task_type` and `task_id`") - if task_type not in ("chief", "worker"): - raise ValueError( - "Unrecognized task_type: %r, valid task types are: \"chief\", " - "\"worker\"." % task_type) - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) - if not self._num_workers: - raise ValueError("No `worker` or `chief` tasks can be found in " - "`cluster_spec`.") - - self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, - task_id) - - self._worker_device = "/job:%s/task:%d" % (task_type, task_id) - self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) - if num_gpus_per_worker: - local_devices = tuple( - "%s/device:GPU:%d" % (self._worker_device, i) - for i in range(num_gpus_per_worker) - ) - else: - local_devices = (self._worker_device,) - - self._collective_keys = cross_device_utils.CollectiveKeys() - self._initialize_local(local_devices) - self._input_workers = input_lib.InputWorkers( - self._device_map, [(self._worker_device, self.worker_devices)]) - self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( - num_workers=self._num_workers, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys) - - # Add a default device so that ops without specified devices will not end up - # on other 
workers. - self._default_device = "/job:%s/task:%d" % (task_type, task_id) - - self._cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._task_type = task_type - self._task_id = task_id - - logging.info( - "Multi-worker CollectiveAllReduceStrategy with " - "cluster_spec = %r, task_type = %r, task_id = %r, " - "num_workers = %r, local_devices = %r", cluster_spec.as_dict(), - task_type, task_id, self._num_workers, local_devices) - - def _create_variable(self, next_creator, *args, **kwargs): - colocate_with = kwargs.pop("colocate_with", None) - if colocate_with is None: - device_map = self._device_map - logical_device = 0 # TODO(josh11b): Get logical device from scope here. - elif isinstance(colocate_with, numpy_dataset.SingleDevice): - with ops.device(colocate_with.device): - return next_creator(*args, **kwargs) - else: - device_map = colocate_with.device_map - logical_device = colocate_with.logical_device - - def _real_mirrored_creator(devices, *args, **kwargs): - """Creates one MirroredVariable on the current worker.""" - unique_var_name = ops.get_default_graph().unique_name( - kwargs["name"], mark_as_used=False).rstrip("/") - # pylint: disable=protected-access - collective_instance_key = self._collective_keys.get_instance_key( - key_id=unique_var_name) - # Only the first device participles in the broadcast of initial values. - group_key = self._collective_keys.get_group_key([devices[0]]) - group_size = self._num_workers - if "initial_value" not in kwargs: - raise ValueError("Initial value must be specified.") - initial_value = kwargs["initial_value"] - if callable(initial_value): - initial_value_fn = initial_value - else: - initial_value_fn = lambda: initial_value - - value_list = [] - for i, d in enumerate(devices): - with ops.init_scope(), ops.device(d): - if i == 0: - # The initial value fn makes sure variables all initialized to - # same values. The first device of the chief worker will send their - # variable values to other workers. 
- def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring - with ops.device(device): - initial_value = initial_value_fn() - assert not callable(initial_value) - initial_value = ops.convert_to_tensor(initial_value) - - assert index == 0, index - if self._num_workers > 1: - if self._is_chief: - bcast_send = collective_ops.broadcast_send( - initial_value, initial_value.shape, initial_value.dtype, - group_size, group_key, collective_instance_key) - with ops.control_dependencies([bcast_send]): - return array_ops.identity(initial_value) - else: - return collective_ops.broadcast_recv( - initial_value.shape, initial_value.dtype, group_size, - group_key, collective_instance_key) - return initial_value - else: - # Give replicas meaningful distinct names: - var0name = value_list[0].name.split(":")[0] - # We append a / to variable names created on replicas with id > 0 to - # ensure that we ignore the name scope and instead use the given - # name as the absolute name of the variable. - kwargs["name"] = "%s/replica_%d/" % (var0name, i) - - # Variables on non-first replica get initial values from the - # variables created on the first device of each worker. - def _overridden_initial_value_fn(device=d, index=i): - assert index > 0 - with ops.device(device): - if context.executing_eagerly(): - return array_ops.identity(value_list[0].value()) - else: - return array_ops.identity(value_list[0].initial_value) - - kwargs["initial_value"] = _overridden_initial_value_fn - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - # Don't record operations (e.g. other variable reads) during - # variable creation. 
- with tape.stop_recording(): - v = next_creator(*args, **kwargs) - - if i == 0: - actual_var_name = v.name.split(":")[0] - assert unique_var_name == actual_var_name, "%r vs %r" % ( - unique_var_name, actual_var_name) - assert not isinstance(v, values.DistributedVariable) - value_list.append(v) - return value_list - - # pylint: disable=protected-access - return mirrored_strategy._create_mirrored_variable( - self._container_strategy(), device_map, logical_device, - _real_mirrored_creator, *args, **kwargs) - - def _distribute_dataset(self, dataset_fn): - """Distributes the dataset to each local GPU.""" - # TODO(yuefengz): shard the dataset. - worker_index = 0 - return input_lib.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._input_workers, worker_index, - prefetch_on_device=True) - - def _make_dataset_iterator(self, dataset): - return input_lib.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - """Distributes the dataset to each local GPU.""" - if self._cluster_spec is None: - input_pipeline_id = 0 - else: - input_pipeline_id = multi_worker_util.id_in_cluster( - self._cluster_spec, self._task_type, self._task_id) - input_context = distribute_lib.InputContext( - num_input_pipelines=self._num_workers, - input_pipeline_id=input_pipeline_id, - num_replicas_in_sync=self._num_replicas_in_sync) - - return input_lib.InputFunctionIterator( - input_fn, self._input_workers, [input_context]) - - def _configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - """Configures the object. - - Args: - session_config: a `tf.ConfigProto` - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type, such as "worker". - task_id: the current task id. - - Raises: - ValueError: if `task_type` is not in the `cluster_spec`. 
- """ - if not self._cluster_spec and cluster_spec: - # If a `cluster_spec` is already passed in, do nothing here. - # TODO(yuefengz): check `cluster_spec` is the same if this object has - # already been initialized with a `cluster_spec`. - self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec, - task_type, task_id) - assert isinstance(self._get_cross_device_ops(), - cross_device_ops_lib.CollectiveAllReduce) - - if session_config: - session_config.CopyFrom(self._update_config_proto(session_config)) - - def _update_config_proto(self, config_proto): - updated_config = copy.deepcopy(config_proto) - # Enable the scoped allocator optimization for CollectiveOps. This - # optimization converts many small all-reduces into fewer larger - # all-reduces. - rewrite_options = updated_config.graph_options.rewrite_options - rewrite_options.scoped_allocator_optimization = ( - rewriter_config_pb2.RewriterConfig.ON) - # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op = - # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we - # clear and then append. - del rewrite_options.scoped_allocator_opts.enable_op[:] - rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") - - if not self._cluster_spec: - return updated_config - - assert self._task_type - assert self._task_id is not None - - # Collective group leader is needed for collective ops to coordinate - # workers. - if "chief" in self._cluster_spec.jobs: - updated_config.experimental.collective_group_leader = ( - "/job:chief/replica:0/task:0") - else: - if "worker" not in self._cluster_spec.jobs: - raise ValueError( - "You must have `chief` or `worker` jobs in the `cluster_spec`.") - updated_config.experimental.collective_group_leader = ( - "/job:worker/replica:0/task:0") - - # The device filters prevent communication between workers. 
- del updated_config.device_filters[:] - updated_config.device_filters.append( - "/job:%s/task:%d" % (self._task_type, self._task_id)) - - return updated_config - - @property - def experimental_between_graph(self): - return True - - @property - def experimental_should_init(self): - return True - - @property - def should_checkpoint(self): - return self._is_chief - - @property - def should_save_summary(self): - return self._is_chief - - @property - def _num_replicas_in_sync(self): - return len(self.worker_devices) * self._num_workers - - # TODO(priyag): Delete this once all strategies use global batch size. - @property - def _global_batch_size(self): - """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. - - `distribute_dataset` and `make_input_fn_iterator` assume per-replica - batching. - - Returns: - Boolean. - """ - return True + def __init__(self, + container_strategy, + num_gpus_per_worker, + communication): + # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change + # the constructor's interface to allow customized cluster resolver. Use + # SimpleClusterResolver to override num_accelerators. 
+ tfconfig = TFConfigClusterResolver() + cluster_resolver = SimpleClusterResolver( + cluster_spec=tfconfig.cluster_spec(), + task_type=tfconfig.task_type, + task_id=tfconfig.task_id, + num_accelerators={"GPU": num_gpus_per_worker}, + rpc_layer=tfconfig.rpc_layer) + super(CollectiveAllReduceExtended, self).__init__( + container_strategy, + communication=communication, + cluster_resolver=cluster_resolver) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 9b6236fd9f89ec30c1234c846930a05d9c32e99d..2dab18791ca3e67c7b2494eaf90295b9f85c0cbb 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -29,9 +29,14 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import collective_all_reduce_strategy as core_collective_all_reduce_strategy +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -49,6 +54,57 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.training import adam from tensorflow.python.training import training_util +from tensorflow.python.training.server_lib 
import ClusterSpec + + +class MockCollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): + """Mock the strategy to allow cluster resolver as an argument.""" + + def __init__(self, cluster_resolver): + super(MockCollectiveAllReduceStrategy, self).__init__( + core_collective_all_reduce_strategy.CollectiveAllReduceExtended( + self, + communication=cross_device_ops_lib.CollectiveCommunication.AUTO, + cluster_resolver=cluster_resolver)) + + +def create_test_objects(cluster_spec=None, + task_type=None, + task_id=None, + num_gpus=None, + use_core_strategy=False): + sess_config = config_pb2.ConfigProto() + if num_gpus is None: + num_gpus = context.num_gpus() + if use_core_strategy: + if cluster_spec and task_type and task_id is not None: + cluster_resolver = SimpleClusterResolver( + cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), + task_type=task_type, + task_id=task_id, + num_accelerators={'GPU': num_gpus}) + target = 'grpc://' + cluster_spec[task_type][task_id] + else: + cluster_resolver = SimpleClusterResolver( + ClusterSpec({}), num_accelerators={'GPU': num_gpus}) + target = '' + + strategy = MockCollectiveAllReduceStrategy(cluster_resolver) + sess_config = strategy.update_config_proto(sess_config) + else: + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + if task_type and task_id is not None: + strategy.configure( + session_config=sess_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + target = 'grpc://' + cluster_spec[task_type][task_id] + else: + target = '' + + return strategy, target, sess_config class CollectiveAllReduceStrategyTestBase( @@ -64,16 +120,18 @@ class CollectiveAllReduceStrategyTestBase( CollectiveAllReduceStrategyTestBase.collective_key_base += 100000 super(CollectiveAllReduceStrategyTestBase, self).setUp() - def _get_test_object(self, task_type, task_id, num_gpus=0): - distribution = 
collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=num_gpus) - session_config = config_pb2.ConfigProto() - if task_type and task_id is not None: - distribution.configure( - session_config=session_config, - cluster_spec=self._cluster_spec, - task_type=task_type, - task_id=task_id) + def _get_test_object(self, + task_type, + task_id, + num_gpus=0, + use_core_strategy=False): + strategy, target, session_config = create_test_objects( + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id, + num_gpus=num_gpus, + use_core_strategy=use_core_strategy) + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + CollectiveAllReduceStrategyTestBase.collective_key_base, @@ -81,16 +139,16 @@ class CollectiveAllReduceStrategyTestBase( CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) - distribution.extended._collective_keys = collective_keys - distribution.extended._cross_device_ops._collective_keys = ( - collective_keys) - if task_type and task_id is not None: - return distribution, 'grpc://' + self._cluster_spec[task_type][ - task_id], session_config - else: - return distribution, '', session_config + strategy.extended._collective_keys = collective_keys + strategy.extended._cross_device_ops._collective_keys = (collective_keys) - def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + return strategy, target, session_config + + def _test_minimize_loss_graph(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, config = self._get_test_object(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ @@ -118,7 +176,7 @@ class CollectiveAllReduceStrategyTestBase( def update(v, g): return v.assign_sub(0.05 * g, use_locking=True) - one = d.broadcast(constant_op.constant([[1.]])) + one = constant_op.constant([[1.]]) def step(): """Perform one 
optimization step.""" @@ -158,7 +216,11 @@ class CollectiveAllReduceStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before - def _test_complex_model(self, task_type, task_id, num_gpus): + def _test_complex_model(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, config = self._get_test_object(task_type, task_id, num_gpus) @@ -204,13 +266,17 @@ class CollectiveAllReduceStrategyTestBase( target=master_target) as sess: with d.scope(): train_op = d.extended.call_for_each_replica(model_fn) - train_op = d.group(d.unwrap(train_op)) + train_op = d.group(d.experimental_local_results(train_op)) sess.run(variables.global_variables_initializer()) sess.run(train_op) return True - def _test_variable_initialization(self, task_type, task_id, num_gpus): + def _test_variable_initialization(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): distribution, master_target, config = self._get_test_object( task_type, task_id, num_gpus) with ops.Graph().as_default(), \ @@ -228,7 +294,7 @@ class CollectiveAllReduceStrategyTestBase( x = distribution.extended.call_for_each_replica(model_fn) reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x) - x = distribution.unwrap(x)[0] + x = distribution.experimental_local_results(x)[0] sess.run(variables.global_variables_initializer()) @@ -239,8 +305,14 @@ class CollectiveAllReduceStrategyTestBase( reduced_x_value))) return np.allclose(x_value, reduced_x_value, atol=1e-5) - def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, - expected_values): + def _test_input_fn_iterator(self, + task_type, + task_id, + num_gpus, + input_fn, + expected_values, + test_reinitialize=True, + use_core_strategy=False): distribution, master_target, config = self._get_test_object( task_type, task_id, num_gpus) devices = distribution.extended.worker_devices @@ -263,13 +335,14 @@ class CollectiveAllReduceStrategyTestBase( for r in range(len(devices))]) # 
After re-initializing the iterator, should be able to iterate again. - sess.run(iterator.initialize()) + if test_reinitialize: + sess.run(iterator.initialize()) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = sess.run([values.select_replica(r, next_element) - for r in range(len(devices))]) - self.assertEqual(expected_value, computed_value) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) + self.assertEqual(expected_value, computed_value) class DistributedCollectiveAllReduceStrategyTest( @@ -283,71 +356,116 @@ class DistributedCollectiveAllReduceStrategyTest( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0) - def test_num_replicas_in_sync(self): - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=2) - distribution.configure(cluster_spec=self._cluster_spec, task_type='worker', - task_id=0) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def test_num_replicas_in_sync(self, use_core_strategy): + distribution, _, _ = create_test_objects( + cluster_spec=self._cluster_spec, + task_type='worker', + task_id=0, + num_gpus=2, + use_core_strategy=use_core_strategy) num_workers = len(self._cluster_spec.get('chief', []) + self._cluster_spec.get('worker', [])) self.assertEqual(2 * num_workers, distribution.num_replicas_in_sync) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testMinimizeLossGraph(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testMinimizeLossGraph(self, num_gpus, use_core_strategy): + 
self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testVariableInitialization(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testVariableInitialization(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_variable_initialization, self._cluster_spec, - num_gpus=num_gpus) + num_gpus=num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testComplexModel(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testComplexModel(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._run_between_graph_clients( - self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + self._test_complex_model, + self._cluster_spec, + num_gpus=num_gpus, + use_core_strategy=use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. 
# TODO(yuefengz): Update how we use num_gpus and required_gpus @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testMakeInputFnIterator(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_dataset=[True, False], + use_core_strategy=[True, False])) + def DISABLED_testMakeInputFnIterator(self, num_gpus, use_dataset, + use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(100) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(100) + else: + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next # We use CPU as the device when num_gpus = 0 devices_per_worker = max(1, num_gpus) expected_values = [[i+j for j in range(devices_per_worker)] for i in range(0, 100, devices_per_worker)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=3*devices_per_worker, expected_num_input_pipelines=3, expected_input_pipeline_id=1) # because task_id = 1 - self._test_input_fn_iterator('worker', 1, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + 'worker', + 1, + num_gpus, + input_fn, + expected_values, + test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) - def testUpdateConfigProto(self): - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=2) - distribution.configure( - cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProto(self, use_core_strategy): + strategy, _, _ = self._get_test_object( + task_type='worker', + task_id=1, + num_gpus=2, + use_core_strategy=use_core_strategy) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) rewrite_options 
= config_proto.graph_options.rewrite_options rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed') - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify group leader self.assertEqual('/job:worker/replica:0/task:0', @@ -363,6 +481,24 @@ class DistributedCollectiveAllReduceStrategyTest( self.assertEqual(['CollectiveReduce'], new_rewrite_options.scoped_allocator_opts.enable_op) + @combinations.generate(combinations.combine(mode=['eager'])) + def testEnableCollectiveOps(self): + mock_called = [False] + + # pylint: disable=dangerous-default-value + def mock_enable_collective_ops(server_def, mock_called=mock_called): + self.assertEqual('worker', server_def.job_name) + self.assertEqual(1, server_def.task_index) + self.assertEqual('grpc', server_def.protocol) + mock_called[0] = True + + with test.mock.patch.object(context.context(), 'enable_collective_ops', + mock_enable_collective_ops): + strategy, _, _ = self._get_test_object( + task_type='worker', task_id=1, num_gpus=2, use_core_strategy=True) + self.assertTrue(strategy.extended._std_server_started) + self.assertTrue(mock_called[0]) + class DistributedCollectiveAllReduceStrategyTestWithChief( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -406,80 +542,127 @@ class LocalCollectiveAllReduceStrategy( @combinations.generate( combinations.combine( - mode=['graph', 'eager'], num_gpus=[2, 4], required_gpus=2)) - def testMinimizeLoss(self, num_gpus): + mode=['graph', 'eager'], + num_gpus=[2, 4], + required_gpus=2, + use_core_strategy=[True, False])) + def testMinimizeLoss(self, num_gpus, use_core_strategy): # Collective ops doesn't support strategy with one device. 
if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if context.executing_eagerly(): - strategy, _, _ = self._get_test_object(None, None, num_gpus) + strategy, _, _ = self._get_test_object( + None, None, num_gpus, use_core_strategy=use_core_strategy) self._test_minimize_loss_eager(strategy) else: - self._test_minimize_loss_graph(None, None, num_gpus) + self._test_minimize_loss_graph( + None, None, num_gpus, use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[2, 4], required_gpus=2)) - def testComplexModel(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[2, 4], + required_gpus=2, + use_core_strategy=[True, False])) + def testComplexModel(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - self._test_complex_model(None, None, num_gpus) + self._test_complex_model( + None, None, num_gpus, use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph', 'eager'], required_gpus=2)) - def testMakeInputFnIterator(self): + combinations.combine( + mode=['graph', 'eager'], + required_gpus=2, + use_dataset=[True, False], + use_core_strategy=[True, False])) + def DISABLED_testMakeInputFnIterator(self, use_dataset, use_core_strategy): num_gpus = 2 - dataset_fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) + else: + def fn(): + dataset = dataset_ops.Dataset.range(5 * num_gpus) + it = dataset.make_one_shot_iterator() + return it.get_next expected_values = [range(i, i + num_gpus) for i in range(0, 10, num_gpus)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=1, expected_input_pipeline_id=0) - self._test_input_fn_iterator(None, None, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + None, + None, + num_gpus, + 
input_fn, + expected_values, + test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) - def testAllReduceSum(self): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceSum(self, use_core_strategy): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object(None, None, num_gpus=2) + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) with self.cached_session(config=config, target=target): self._test_all_reduce_sum(distribution) - def testAllReduceSumGradients(self): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceSumGradients(self, use_core_strategy): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object(None, None, num_gpus=2) + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) with self.cached_session(config=config, target=target): self._test_all_reduce_sum_gradients(distribution) - def testAllReduceSumGradientTape(self): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceSumGradientTape(self, use_core_strategy): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object(None, None, num_gpus=2) + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) with self.cached_session(config=config, target=target): self._test_all_reduce_sum_gradient_tape(distribution) - def testAllReduceMean(self): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceMean(self, use_core_strategy): if context.num_gpus() < 2: self.skipTest('Not enough 
GPUs') - distribution, target, config = self._get_test_object(None, None, num_gpus=2) + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) with self.cached_session(config=config, target=target): self._test_all_reduce_mean(distribution) - def testAllReduceMeanGradients(self): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceMeanGradients(self, use_core_strategy): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object(None, None, num_gpus=2) + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) with self.cached_session(config=config, target=target): self._test_all_reduce_mean_gradients(distribution) - def testAllReduceMeanGradientTape(self): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceMeanGradientTape(self, use_core_strategy): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object(None, None, num_gpus=2) + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) with self.cached_session(config=config, target=target): self._test_all_reduce_mean_gradient_tape(distribution) - def testNumpyIterator(self): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testNumpyIterator(self, use_core_strategy): num_gpus = 2 if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - strategy, _, _ = self._get_test_object(None, None, num_gpus) + strategy, _, _ = self._get_test_object( + None, None, num_gpus=num_gpus, use_core_strategy=use_core_strategy) self._test_numpy_iterator(strategy) diff --git a/tensorflow/contrib/distribute/python/combinations.py 
b/tensorflow/contrib/distribute/python/combinations.py index 304762e69e383ae51b914eb78e84cb3571cb12ed..250339fad7afb456869568f454f21ff470f7bc9d 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -45,17 +45,23 @@ import types import unittest from absl.testing import parameterized import six - -from tensorflow.contrib import cluster_resolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib -from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib -from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib +from tensorflow.contrib.distribute.python import parameter_server_strategy from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import one_device_strategy as one_device_lib +from tensorflow.python.distribute import tpu_strategy as tpu_lib +from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2 +from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2 +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras_v2 +from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_keras_v2 +from tensorflow.python.tpu import device_assignment as device_assignment_lib +from tensorflow.python.tpu import tpu_strategy_util from tensorflow.python.training import adagrad from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent @@ -226,7 +232,7 @@ def combine(**kwargs): if not kwargs: 
return [OrderedDict()] - sort_by_key = lambda k: k[0][0] + sort_by_key = lambda k: k[0] kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) first = list(kwargs.items())[0] @@ -321,11 +327,18 @@ class NamedDistribution(object): return self._required_tpu -def _get_tpu_strategy_creator(steps_per_run, **kwargs): +def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs): def _create_tpu_strategy(): - resolver = cluster_resolver.TPUClusterResolver("") - tpu_lib.initialize_tpu_system(resolver) + resolver = tpu_cluster_resolver.TPUClusterResolver("") + topology = tpu_strategy_util.initialize_tpu_system(resolver) + device_assignment = None + if use_single_core: + device_assignment = device_assignment_lib.DeviceAssignment( + topology, core_assignment=device_assignment_lib. + SINGLE_CORE_ASSIGNMENT) + strategy = tpu_lib.TPUStrategy(resolver, steps_per_run=steps_per_run, + device_assignment=device_assignment, **kwargs) return strategy return _create_tpu_strategy @@ -339,20 +352,22 @@ default_strategy = NamedDistribution( one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) +one_device_strategy_gpu = NamedDistribution( + "OneDeviceGPU", lambda: one_device_lib.OneDeviceStrategy("/gpu:0"), + required_gpus=1) tpu_strategy = NamedDistribution( "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True) tpu_strategy_one_step = NamedDistribution( "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True) -# TODO(b/122327153): Remove below two NamedDistributions. 
-tpu_strategy_loop_on_device = NamedDistribution( - "TPULoopOnDevice", _get_tpu_strategy_creator( - steps_per_run=2, _disable_training_loop_on_host=True), +tpu_strategy_one_core = NamedDistribution( + "TPUOneCore", _get_tpu_strategy_creator( + steps_per_run=2, use_single_core=True), required_tpu=True) -tpu_strategy_one_step_loop_on_device = NamedDistribution( - "TPUOneStepLoopOnDevice", _get_tpu_strategy_creator( - steps_per_run=1, _disable_training_loop_on_host=True), +tpu_strategy_one_step_one_core = NamedDistribution( + "TPUOneStepOneCore", _get_tpu_strategy_creator( + steps_per_run=1, use_single_core=True), required_tpu=True) mirrored_strategy_with_one_cpu = NamedDistribution( @@ -385,6 +400,11 @@ core_mirrored_strategy_with_two_gpus = NamedDistribution( "CoreMirrored2GPUs", lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) +parameter_server_strategy_with_two_gpus = NamedDistribution( + "ParameterServer2GPUs", + lambda: parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2), + required_gpus=2) gradient_descent_optimizer_v1_fn = NamedObject( @@ -404,10 +424,20 @@ gradient_descent_optimizer_v2_fn = NamedObject( adagrad_optimizer_v2_fn = NamedObject( "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) adam_optimizer_v2_fn = NamedObject( - "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) + "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1.0)) optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn] +gradient_descent_optimizer_keras_v2_fn = NamedObject( + "GradientDescentKerasV2", + lambda: gradient_descent_keras_v2.SGD(0.2)) +adagrad_optimizer_keras_v2_fn = NamedObject( + "AdagradKerasV2", lambda: adagrad_keras_v2.Adagrad(0.001)) +adam_optimizer_keras_v2_fn = NamedObject( + "AdamKerasV2", lambda: adam_keras_v2.Adam(0.001, epsilon=1.0)) +rmsprop_optimizer_keras_v2_fn = NamedObject( + "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001)) + graph_and_eager_modes = 
["graph", "eager"] diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py index 86aa48cea889c6c2ce169b18bcabb6d08890fbed..9f3deadbec98c4f66061ca29b4d29a74b8de40b1 100644 --- a/tensorflow/contrib/distribute/python/combinations_test.py +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -42,6 +42,14 @@ class TestingCombinationsTest(test.TestCase): "b": 3 }], combinations.combine(a=[1, 2], b=[2, 3])) + def test_arguments_sorted(self): + self.assertEqual([ + OrderedDict([("aa", 1), ("ab", 2)]), + OrderedDict([("aa", 1), ("ab", 3)]), + OrderedDict([("aa", 2), ("ab", 2)]), + OrderedDict([("aa", 2), ("ab", 3)]) + ], combinations.combine(ab=[2, 3], aa=[1, 2])) + def test_combine_single_parameter(self): self.assertEqual([{ "a": 1, diff --git a/tensorflow/contrib/distribute/python/cross_device_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py index 54cce2988383fcf5e063726948fbbf62c7094ce5..2b8e0197961ae37b67dc8958054a03e164242dcd 100644 --- a/tensorflow/contrib/distribute/python/cross_device_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -23,6 +23,7 @@ import itertools from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base @@ -204,15 +205,15 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): reduction_to_one_combinations = combinations.combine( cross_device_ops=[ combinations.NamedObject( - "DefaultReductionToOneDeviceCrossDeviceOps", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + "DefaultReductionToOneDevice", + cross_device_ops_lib.ReductionToOneDevice()), combinations.NamedObject( 
"ReductionToCPUDeviceCrossDeviceOps", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDevice( reduce_to_device=_cpu_device)), combinations.NamedObject( "AccumulateNCrossDeviceOp", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDevice( accumulation_fn=math_ops.accumulate_n)), ], distribution=[ @@ -228,20 +229,23 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): combinations.NamedObject( "AllReduce", cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), - combinations.NamedObject( - "HierarchicalCopy", - cross_device_ops_lib.AllReduceCrossDeviceOps( - "hierarchical_copy", 8, 0, 0)), combinations.NamedObject( "AllReduceNoGradientRepacking", cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), + combinations.NamedObject("NcclAllReduce", + cross_device_ops_lib.NcclAllReduce()), + combinations.NamedObject( + "HierarchicalCopy", + cross_device_ops_lib.HierarchicalCopyAllReduce(8)), combinations.NamedObject( "HierarchicalCopyAggregateSmallTensors", cross_device_ops_lib.AllReduceCrossDeviceOps( "hierarchical_copy", 0, 100, 10)) ], - distribution=[combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus + ], mode=["graph", "eager"]) @combinations.generate(reduction_to_one_combinations + allreduce_combinations) @@ -306,8 +310,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): combinations.combine( cross_device_ops_instance=[ combinations.NamedObject( - "ReductionToOneDeviceCrossDeviceOps", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + "ReductionToOneDevice", + cross_device_ops_lib.ReductionToOneDevice()), combinations.NamedObject( "AllReduceCrossDeviceOps", cross_device_ops_lib.AllReduceCrossDeviceOps()) @@ -426,6 +430,9 @@ class 
MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, self._testReductionAndBroadcast(cross_device_ops, distribution) +NUM_WORKERS = 3 + + class MultiWorkerCollectiveAllReduceTest( multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): @@ -433,9 +440,9 @@ class MultiWorkerCollectiveAllReduceTest( @classmethod def setUpClass(cls): - """Create a local cluster with 2 workers.""" + """Create a local cluster with 3 workers.""" cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( - num_workers=3, num_ps=0) + num_workers=NUM_WORKERS, num_ps=0) def setUp(self): super(MultiWorkerCollectiveAllReduceTest, self).setUp() @@ -443,7 +450,12 @@ class MultiWorkerCollectiveAllReduceTest( # collective key base for different tests. MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 - def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): + def _get_test_objects(self, + task_type, + task_id, + num_gpus=0, + use_strategy_object=False, + local_mode=False): collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + MultiWorkerCollectiveAllReduceTest.collective_key_base, @@ -452,16 +464,24 @@ class MultiWorkerCollectiveAllReduceTest( instance_key_with_id_start=num_gpus * 10000 + MultiWorkerCollectiveAllReduceTest.collective_key_base) if local_mode: - collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( - 1, num_gpus, collective_keys=collective_keys) if num_gpus: devices = ["/device:GPU:%d" % i for i in range(num_gpus)] else: devices = ["/device:CPU:0"] - return collective_all_reduce_ops, devices, "" + + if use_strategy_object: + # Still using contrib CollectiveAllReduceStrategy because we can specify + # num_gpus in its constructor. 
+ strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + strategy.extended._collective_keys = collective_keys + strategy.extended._cross_device_ops._collective_keys = collective_keys + return strategy, devices, "" + else: + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( + 1, num_gpus, collective_keys=collective_keys) + return collective_all_reduce_ops, devices, "" else: - collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( - 3, num_gpus, collective_keys=collective_keys) if num_gpus: devices = [ "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) @@ -469,8 +489,23 @@ class MultiWorkerCollectiveAllReduceTest( ] else: devices = ["/job:%s/task:%d" % (task_type, task_id)] - return (collective_all_reduce_ops, devices, - "grpc://" + self._cluster_spec[task_type][task_id]) + + if use_strategy_object: + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + strategy.configure( + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id) + strategy.extended._collective_keys = collective_keys + strategy.extended._cross_device_ops._collective_keys = collective_keys + return (strategy, devices, + "grpc://" + self._cluster_spec[task_type][task_id]) + else: + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( + NUM_WORKERS, num_gpus, collective_keys=collective_keys) + return (collective_all_reduce_ops, devices, + "grpc://" + self._cluster_spec[task_type][task_id]) def _assert_values_equal(self, left, right, sess): if isinstance(left, list): @@ -490,9 +525,18 @@ class MultiWorkerCollectiveAllReduceTest( for l, r in zip(left_values, right_values): self.assertEqual(l, r) - def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): + def _test_reduction(self, + task_type, + task_id, + num_gpus, + use_strategy_object=False, + local_mode=False): collective_all_reduce, devices, master_target 
= self._get_test_objects( - task_type, task_id, num_gpus, local_mode=local_mode) + task_type, + task_id, + num_gpus, + use_strategy_object=use_strategy_object, + local_mode=local_mode) if local_mode: num_workers = 1 worker_device = None @@ -500,6 +544,27 @@ class MultiWorkerCollectiveAllReduceTest( num_workers = len(self._cluster_spec.get("chief", [])) + len( self._cluster_spec.get("worker", [])) worker_device = "/job:%s/task:%d" % (task_type, task_id) + + def _reduce(test_object, reduce_op, per_replica, destinations): + if use_strategy_object: + with test_object.scope(): + # Mimic the behavior that distribution strategy usually strips the + # wrapper if there is only one value. + if len(per_replica.values) == 1: + per_replica = per_replica.values[0] + return test_object.extended.reduce_to(reduce_op, per_replica, + destinations) + else: + return test_object.reduce(reduce_op, per_replica, destinations) + + def _batch_reduce(test_object, reduce_op, value_destination_pairs): + if use_strategy_object: + with test_object.scope(): + return test_object.extended.batch_reduce_to(reduce_op, + value_destination_pairs) + else: + return test_object.batch_reduce(reduce_op, value_destination_pairs) + with ops.Graph().as_default(), \ ops.device(worker_device), \ self.cached_session(target=master_target) as sess: @@ -524,26 +589,30 @@ class MultiWorkerCollectiveAllReduceTest( # test reduce() for destinations in all_destinations: self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.MEAN, per_replica, - destinations=destinations), - _fake_mirrored(mean, destinations), sess) + destinations=destinations), _fake_mirrored(mean, destinations), + sess) self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.MEAN, per_replica_2, - destinations=destinations), - _fake_mirrored(mean_2, destinations), sess) + destinations=destinations), _fake_mirrored( + mean_2, 
destinations), sess) self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices) * num_workers, destinations), sess) self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.SUM, per_replica_2, destinations=destinations), @@ -553,17 +622,13 @@ class MultiWorkerCollectiveAllReduceTest( # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN, - [(per_replica, d1), - (per_replica_2, d2)]), - [ - _fake_mirrored(mean, d1), - _fake_mirrored(mean_2, d2) - ], sess) + _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.MEAN, + [(per_replica, d1), (per_replica_2, d2)]), + [_fake_mirrored(mean, d1), + _fake_mirrored(mean_2, d2)], sess) self._assert_values_equal( - collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM, - [(per_replica, d1), - (per_replica_2, d2)]), + _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.SUM, + [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices) * num_workers, d1), _fake_mirrored(mean_2 * len(devices) * num_workers, d2) @@ -572,18 +637,36 @@ class MultiWorkerCollectiveAllReduceTest( return True @combinations.generate( - combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1)) - def testReductionDistributed(self, num_gpus): + combinations.combine( + mode=["graph"], + num_gpus=[0, 1, 2], + required_gpus=1, + use_strategy_object=[True, False])) + def testReductionDistributed(self, num_gpus, use_strategy_object): if context.num_gpus() < num_gpus: return - self._run_between_graph_clients(self._test_reduction, self._cluster_spec, - num_gpus) + self._run_between_graph_clients( + self._test_reduction, + self._cluster_spec, + num_gpus, + use_strategy_object=use_strategy_object) 
# Collective ops doesn't support strategy with one device. - def testReductionLocal(self, num_gpus=2): + @combinations.generate( + combinations.combine( + mode=["graph"], + num_gpus=[2], + required_gpus=2, + use_strategy_object=[True, False])) + def testReductionLocal(self, num_gpus, use_strategy_object): if context.num_gpus() < num_gpus: return - self._test_reduction(None, None, num_gpus, local_mode=True) + self._test_reduction( + None, + None, + num_gpus, + use_strategy_object=use_strategy_object, + local_mode=True) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index e17085628ba6d1dfc79839fd824801723f07a518..1ff1e7c1d255492e0535175dae7594d2ceb4010b 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -22,7 +22,6 @@ import shutil import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.optimizer_v2 import adagrad @@ -117,7 +116,7 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, scores = estimator.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn('loss', scores) predictions = np.array([ x[prediction_keys.PredictionKeys.PREDICTIONS] diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 3f55a8a1c8b88d1b8e4031547fa3fbe519983630..e8513943e8d3bb0afa7b468528bfc524fc6a5504 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -34,6 +34,7 @@ from tensorflow.contrib.distribute.python import multi_worker_test_base from 
tensorflow.contrib.distribute.python import parameter_server_strategy from tensorflow.contrib.optimizer_v2 import adagrad from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import estimator_training as dc_training from tensorflow.python.distribute.distribute_config import DistributeConfig @@ -287,6 +288,34 @@ class DistributeCoordinatorIntegrationTest( cluster_spec) self._inspect_train_and_eval_events(estimator) + @combinations.generate( + combinations.combine( + mode=["graph"], + eval_distribute_class=[ + None, + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + parameter_server_strategy.ParameterServerStrategy, + ], + required_gpus=[0, 1])) + def test_complete_flow_standalone_client_collective_nccl( + self, eval_distribute_class): + train_distribute = ( + collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=context.num_gpus(), + communication=cross_device_ops_lib.CollectiveCommunication.NCCL)) + + if eval_distribute_class: + eval_distribute = self._get_strategy_object(eval_distribute_class) + else: + eval_distribute = None + + cluster_spec = copy.deepcopy(self._cluster_spec) + cluster_spec.pop("ps", None) + estimator = self._complete_flow(train_distribute, eval_distribute, + cluster_spec) + self._inspect_train_and_eval_events(estimator) + @combinations.generate( combinations.combine( mode=["graph"], @@ -347,7 +376,7 @@ class DistributeCoordinatorIntegrationTest( parameter_server_strategy.ParameterServerStrategy, ], required_gpus=[0, 1])) - def test_complete_flow_indepedent_worker_between_graph( + def test_complete_flow_independent_worker_between_graph( self, train_distribute_cls, eval_distribute_cls): if (context.num_gpus() < 2 and eval_distribute_cls == collective_all_reduce_strategy.CollectiveAllReduceStrategy): @@ 
-399,8 +428,8 @@ class DistributeCoordinatorIntegrationTest( mirrored_strategy.CoreMirroredStrategy ], required_gpus=[0, 1])) - def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, - eval_distribute_cls): + def test_complete_flow_independent_worker_in_graph(self, train_distribute_cls, + eval_distribute_cls): train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD index 84b106545e1326fddd3ed299462534af982dc102..58bede801ff13bb60ed4ada4810eb8ce2dbcb0a3 100644 --- a/tensorflow/contrib/distribute/python/examples/BUILD +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -31,6 +31,12 @@ py_binary( py_binary( name = "keras_mnist", + srcs = ["keras_mnist.py"], + deps = [":keras_mnist_lib"], +) + +py_library( + name = "keras_mnist_lib", srcs = [ "keras_mnist.py", ], @@ -39,3 +45,25 @@ py_binary( "//third_party/py/numpy", ], ) + +py_binary( + name = "mnist_eager_multigpu", + srcs = [ + "mnist_eager_multigpu.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + +py_binary( + name = "mnist_tf1_tpu", + srcs = [ + "mnist_tf1_tpu.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py new file mode 100644 index 0000000000000000000000000000000000000000..c045a5586b9dad371d8c505f9cac4b792dd157fd --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py @@ -0,0 +1,169 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Run MNIST on multiple GPUs using MirroredStrategy with eager execution. + +By default, runs on all available GPUs, or CPU if no GPUs are available. + +NOTE: Currently, this takes more time than when running MNIST in eager without +MirroredStrategy because of a number of overheads. Therefore, this is just a +proof of concept right now and cannot be used to actually scale up training. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import app +from absl import flags +import numpy as np +import tensorflow.compat.v2 as tf + +flags.DEFINE_integer("num_gpus", None, "How many GPUs should we run on?" + "Defaults to all available GPUs, otherwise CPU.") +flags.DEFINE_integer("batch_size", 64, + "What should be the size of each batch?") +flags.DEFINE_integer("num_epochs", 10, "How many epochs to run?") +flags.DEFINE_float("learning_rate", 0.01, "Learning Rate") +flags.DEFINE_float("momentum", 0.5, "SGD momentum") +flags.DEFINE_boolean("use_function", False, + "Should we wrap the step in a tf.function.") + +FLAGS = flags.FLAGS +NUM_TRAIN_IMAGES = 60000 + + +def create_model(): + max_pool = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding="same") + # The model consists of a sequential chain of layers, so tf.keras.Sequential + # (a subclass of tf.keras.Model) makes for a compact description.
+ return tf.keras.Sequential([ + tf.keras.layers.Reshape( + target_shape=[28, 28, 1], + input_shape=(28, 28,)), + tf.keras.layers.Conv2D(2, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Conv2D(4, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(32, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.4), + tf.keras.layers.Dense(10)]) + + +def compute_loss(logits, labels): + loss = tf.reduce_sum( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels)) + # Scale loss by global batch size. + return loss * (1. / FLAGS.batch_size) + + +def mnist_datasets(): + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + # Numpy defaults to dtype=float64; TF defaults to float32. Stick with float32. + x_train, x_test = x_train / np.float32(255), x_test / np.float32(255) + y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64) + # TODO(priyag): `strategy.make_numpy_iterator` can be used directly instead of + # converting to datasets. 
+ train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + return train_dataset, test_dataset + + +def main(unused_argv): + """Run a CNN model on MNIST data to demonstrate DistributedStrategies.""" + + tf.enable_v2_behavior() + + num_gpus = FLAGS.num_gpus + if num_gpus is None: + devices = None + elif num_gpus == 0: + devices = ["/device:CPU:0"] + else: + devices = ["/device:GPU:{}".format(i) for i in range(num_gpus)] + strategy = tf.distribute.MirroredStrategy(devices) + + with strategy.scope(): + train_ds, test_ds = mnist_datasets() + train_ds = train_ds.shuffle(NUM_TRAIN_IMAGES).batch(FLAGS.batch_size) + test_ds = test_ds.batch(FLAGS.batch_size) + + model = create_model() + optimizer = tf.keras.optimizers.SGD(FLAGS.learning_rate, FLAGS.momentum) + training_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32) + training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "training_accuracy", dtype=tf.float32) + test_loss = tf.keras.metrics.Mean("test_loss", dtype=tf.float32) + test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "test_accuracy", dtype=tf.float32) + + def train_step(inputs): + images, labels = inputs + with tf.GradientTape() as tape: + logits = model(images, training=True) + loss = compute_loss(logits, labels) + grads = tape.gradient(loss, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) + training_loss.update_state(loss) + training_accuracy.update_state(labels, logits) + + def test_step(inputs): + images, labels = inputs + logits = model(images, training=False) + loss = compute_loss(logits, labels) + test_loss.update_state(loss) + test_accuracy.update_state(labels, logits) + + train_iterator = strategy.make_dataset_iterator(train_ds) + test_iterator = strategy.make_dataset_iterator(test_ds) + + for epoch in range(0, FLAGS.num_epochs): + # TODO(b/123315763): Create the tf.function outside this loop once we are + # 
able to initialize iterator in eager mode. + dist_train = lambda it: strategy.experimental_run(train_step, it) + dist_test = lambda it: strategy.experimental_run(test_step, it) + if FLAGS.use_function: + dist_train = tf.function(dist_train) + dist_test = tf.function(dist_test) + + # Train + print("Starting epoch {}".format(epoch)) + train_iterator.initialize() + while True: + try: + dist_train(train_iterator) + except tf.errors.OutOfRangeError: + break + print("Training loss: {:0.4f}, accuracy: {:0.2f}%".format( + training_loss.result(), training_accuracy.result() * 100)) + training_loss.reset_states() + training_accuracy.reset_states() + + # Test + test_iterator.initialize() + while True: + try: + dist_test(test_iterator) + except tf.errors.OutOfRangeError: + break + print("Test loss: {:0.4f}, accuracy: {:0.2f}%".format( + test_loss.result(), test_accuracy.result() * 100)) + test_loss.reset_states() + test_accuracy.reset_states() + + +if __name__ == "__main__": + app.run(main) diff --git a/tensorflow/contrib/distribute/python/examples/mnist_tf1_tpu.py b/tensorflow/contrib/distribute/python/examples/mnist_tf1_tpu.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8194c576e67f7ba864f63885c9b028e4136e61 --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/mnist_tf1_tpu.py @@ -0,0 +1,188 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Run MNIST on a TPU using TPUStrategy with TF 1.x graph-mode sessions. + +Runs on the TPU named by the --tpu flag. + +NOTE: Currently, this is just a proof of concept that shows how to drive +TPUStrategy from a tf.Session after disabling v2 behavior; it cannot be used +as-is to actually scale up training. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import app +from absl import flags +import numpy as np +import tensorflow as tf + + +flags.DEFINE_string("tpu", None, "Name of the TPU to use.") +flags.DEFINE_integer("batch_size", 64, + "What should be the size of each batch?") +flags.DEFINE_integer("num_epochs", 10, "How many epochs to run?") +flags.DEFINE_float("learning_rate", 0.01, "Learning Rate") +flags.DEFINE_float("momentum", 0.5, "SGD momentum") + +FLAGS = flags.FLAGS +NUM_TRAIN_IMAGES = 60000 + + +def create_model(): + max_pool = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding="same") + # The model consists of a sequential chain of layers, so tf.keras.Sequential + # (a subclass of tf.keras.Model) makes for a compact description. + return tf.keras.Sequential([ + tf.keras.layers.Reshape( + target_shape=[28, 28, 1], + input_shape=(28, 28,)), + tf.keras.layers.Conv2D(2, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Conv2D(4, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(32, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.4), + tf.keras.layers.Dense(10)]) + + +def compute_loss(logits, labels): + loss = tf.reduce_sum( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels)) + # Scale loss by global batch size. + return loss * (1.
/ FLAGS.batch_size) + + +def mnist_datasets(): + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + # Numpy defaults to dtype=float64; TF defaults to float32. Stick with float32. + x_train, x_test = x_train / np.float32(255), x_test / np.float32(255) + y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64) + # TODO(priyag): `strategy.make_numpy_iterator` can be used directly instead of + # converting to datasets. + train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + return train_dataset, test_dataset + + +def main(argv): + """Run a CNN model on MNIST data to demonstrate DistributedStrategies.""" + del argv # Unused. + tf.disable_v2_behavior() + + cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + tpu=FLAGS.tpu) + strategy = tf.contrib.distribute.TPUStrategy(cluster_resolver) + + with strategy.scope(): + train_ds, test_ds = mnist_datasets() + train_ds = train_ds.shuffle(NUM_TRAIN_IMAGES).batch(FLAGS.batch_size) + test_ds = test_ds.batch(FLAGS.batch_size) + + model = create_model() + optimizer = tf.keras.optimizers.SGD(FLAGS.learning_rate, FLAGS.momentum) + training_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32) + training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "training_accuracy", dtype=tf.float32) + test_loss = tf.keras.metrics.Mean("test_loss", dtype=tf.float32) + test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "test_accuracy", dtype=tf.float32) + + def train_step(inputs): # pylint: disable=missing-docstring + images, labels = inputs + with tf.GradientTape() as tape: + logits = model(images, training=True) + loss = compute_loss(logits, labels) + grads = tape.gradient(loss, model.variables) + update_vars = optimizer.apply_gradients(zip(grads, model.variables)) + update_loss = training_loss.update_state(loss) + update_accuracy = training_accuracy.update_state(labels, 
logits) + + with tf.control_dependencies([update_vars, update_loss, update_accuracy]): + return tf.identity(loss) + + def test_step(inputs): + images, labels = inputs + logits = model(images, training=False) + loss = compute_loss(logits, labels) + update_loss = test_loss.update_state(loss) + update_accuracy = test_accuracy.update_state(labels, logits) + + with tf.control_dependencies([update_loss, update_accuracy]): + return tf.identity(loss) + + train_iterator = strategy.make_dataset_iterator(train_ds) + test_iterator = strategy.make_dataset_iterator(test_ds) + + dist_train = strategy.unwrap( + strategy.experimental_run(train_step, train_iterator)) + dist_test = strategy.unwrap( + strategy.experimental_run(test_step, test_iterator)) + + training_loss_result = training_loss.result() + training_accuracy_result = training_accuracy.result() + test_loss_result = test_loss.result() + test_accuracy_result = test_accuracy.result() + + tf.contrib.distribute.initialize_tpu_system(cluster_resolver) + + train_iterator_init = train_iterator.initialize() + test_iterator_init = test_iterator.initialize() + + all_variables = ( + tf.global_variables() + + training_loss.variables + + training_accuracy.variables + + test_loss.variables + + test_accuracy.variables) + + with tf.Session(cluster_resolver.master()) as session: + session.run([v.initializer for v in all_variables]) + + for epoch in range(0, FLAGS.num_epochs): + # Train + print("Starting epoch {}".format(epoch)) + session.run(train_iterator_init) + while True: + try: + session.run(dist_train) + except tf.errors.OutOfRangeError: + break + print("Training loss: {:0.4f}, accuracy: {:0.2f}%".format( + session.run(training_loss_result), + session.run(training_accuracy_result) * 100)) + training_loss.reset_states() + training_accuracy.reset_states() + + # Test + session.run(test_iterator_init) + while True: + try: + session.run(dist_test) + except tf.errors.OutOfRangeError: + break + print("Test loss: {:0.4f}, accuracy: 
{:0.2f}%".format( + session.run(test_loss_result), + session.run(test_accuracy_result) * 100)) + test_loss.reset_states() + test_accuracy.reset_states() + + +if __name__ == "__main__": + flags.mark_flag_as_required("tpu") + app.run(main) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 79a9803d75a35445280c006fa023637c9b01fdcc..cfaee03a2003089366a506168be2942c279f45bf 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -91,9 +91,10 @@ def main(_): predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) return predict_features - predictions = estimator.predict(input_fn=predict_input_fn) - # TODO(anjalsridhar): This returns a generator object, figure out how to get - # meaningful results here. + prediction_iterable = estimator.predict(input_fn=predict_input_fn) + # Create a list containing each of the prediction dictionaries that map + # the key 'logits' to an array of model outputs. 
+ predictions = [next(prediction_iterable) for _ in range(10)] print("Prediction results: {}".format(predictions)) diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py index f589cd6ad54ea8f33002cb067ef8d83d3d33036a..80a1c7bae8f34aaf6cfd9357da2b071c200adf8b 100644 --- a/tensorflow/contrib/distribute/python/input_lib_test.py +++ b/tensorflow/contrib/distribute/python/input_lib_test.py @@ -22,8 +22,6 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib @@ -31,275 +29,60 @@ from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import random_ops from tensorflow.python.util import nest -class PerReplicaDatasetTest(test.TestCase): - - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - - def _test_iterator(self, devices, dataset, expected_values): - device_map = values.ReplicaDeviceMap(devices) - input_workers = input_lib.InputWorkers(device_map) - per_replica_dataset = input_lib.PerReplicaDataset(dataset, input_workers, 0) - if context.executing_eagerly(): - iterator = per_replica_dataset.make_one_shot_iterator() - else: - iterator = per_replica_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - for expected_value in expected_values: - next_element = iterator.get_next_as_list() - computed_value = 
self.evaluate(next_element) - self.assertEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next_as_list() - self.evaluate(next_element) - - @test_util.run_in_graph_and_eager_modes - def testOneDevice(self): - devices = ["/device:CPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleDevices(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testTupleDataset(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnevenDatasetBatches(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(devices, dataset, expected_values) - - def testInitializableIterator(self): - with context.graph_mode(): - devices = ["/device:CPU:0"] - # 
Using random input since that is only allowed with initializable - # iterator. - dataset = dataset_ops.Dataset.from_tensor_slices( - random_ops.random_uniform((10,))) - - device_map = values.ReplicaDeviceMap(devices) - input_workers = input_lib.InputWorkers(device_map) - per_replica_dataset = input_lib.PerReplicaDataset( - dataset, input_workers, 0) - iterator = per_replica_dataset.make_initializable_iterator() - - self.evaluate(iterator.initializer) - next_element = iterator.get_next_as_list() - for _ in range(10): - self.evaluate(next_element) - - # Should fail after the input is finished. - with self.assertRaises(errors.OutOfRangeError): - self.evaluate(next_element) - - # After re-initializing the iterator, should be able to iterate again. - self.evaluate(iterator.initializer) - for _ in range(10): - self.evaluate(next_element) - - -class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - - def _test_iterator(self, sess, iterator, devices, expected_values): - next_element = iterator.get_next() - for r, device in enumerate(devices): - v = values.select_replica(r, next_element) - # The `v` here can be a tuple. 
- for element in nest.flatten(v): - self.assertTrue(element.device in device) - - for expected_value in expected_values: - t = [values.select_replica(r, next_element) for r in range(len(devices))] - actual = sess.run(t) - self.assertEqual(expected_value, actual) - - with self.assertRaises(errors.OutOfRangeError): - sess.run([values.select_replica(r, next_element) - for r in range(len(devices))]) - - def _test_dataset(self, dataset_fn, worker_devices, devices, - expected_values): - device_map = values.ReplicaDeviceMap(devices) - input_workers = input_lib.InputWorkers(device_map, worker_devices) - multi_worker_dataset = input_lib.MultiWorkerDataset( - dataset_fn, input_workers) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - with self.cached_session() as sess: - sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, expected_values) - - def _cpu_devices(self): - worker_devices = ( - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"]) - ) - devices = [ - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def _cpu_and_one_gpu_devices(self): - worker_devices = ( - ("/job:worker/replica:0/task:0", ( - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - )), - ("/job:worker/replica:0/task:1", ( - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - )) - ) - devices = [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def testDataDistributionOneDevicePerWorker(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - 
dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset( - dataset_fn, worker_devices, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - def testDataDistributionTwoDevicePerWorker(self): - if context.num_gpus() < 1: - self.skipTest("A GPU is not available for this test.") - worker_devices, devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset( - dataset_fn, worker_devices, devices, - [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]]) - - def testTupleDataset(self): - worker_devices, devices = self._cpu_devices() - - with context.graph_mode(): - - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(8) - dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(8)] - self._test_dataset(dataset_fn, worker_devices, devices, - expected_values) - - def testInitializableIterator(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(8) - device_map = values.ReplicaDeviceMap(devices) - input_workers = input_lib.InputWorkers(device_map, worker_devices) - multi_worker_dataset = input_lib.MultiWorkerDataset( - dataset_fn, input_workers) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - - sess.run(multi_worker_iterator.initializer) - self._test_iterator( - sess, multi_worker_iterator, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - # After re-initializing the iterator, should be able to iterate again. - sess.run(multi_worker_iterator.initializer) - self._test_iterator( - sess, multi_worker_iterator, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - def testValueErrorForIterator(self): - # Incompatiable arguments. 
- d1 = "/device:GPU:0" - d2 = "/device:GPU:1" - device_map = values.ReplicaDeviceMap([d1, d2]) - input_workers = input_lib.InputWorkers( - device_map, (("w1", (d1,)), ("w2", (d2,)))) - with self.assertRaises(ValueError): - input_lib.MultiWorkerDataIterator([("w1", None)], input_workers) - - def testDuplicateDevices(self): - _, devices = self._cpu_devices() - devices.append("/job:worker/replica:0/task:0/device:CPU:0") - with self.assertRaises(ValueError): - _ = values.ReplicaDeviceMap(devices) - - class InputIteratorTestBase(test.TestCase): - def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, split_batch_by=None): - devices = nest.flatten([ds for _, ds in worker_device_pairs]) + def _create_iterator(self, input_type, dataset_fn, worker_device_pairs, + devices, split_batch_by): device_map = values.ReplicaDeviceMap(devices) input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) if input_type == "input_fn": - input_contexts = [ - distribute_lib.InputContext() for _ in worker_device_pairs] - input_fn = lambda _: dataset_fn() - iterator = input_lib.InputFunctionIterator( - input_fn, input_workers, input_contexts) + input_contexts = [] + for i in range(input_workers.num_workers): + input_contexts.append( + distribute_lib.InputContext( + num_input_pipelines=input_workers.num_workers, + input_pipeline_id=i, + num_replicas_in_sync=len(devices))) + + iterator = input_lib.InputFunctionIterator(dataset_fn, input_workers, + input_contexts) else: iterator = input_lib.DatasetIterator( - dataset_fn(), input_workers, split_batch_by) + dataset_fn(distribute_lib.InputContext()), input_workers, + split_batch_by) + return iterator + + def _test_iterator(self, + input_type, + dataset_fn, + worker_device_pairs, + expected_values, + sess=None, + split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + iterator = self._create_iterator( + input_type, dataset_fn, worker_device_pairs, devices, 
split_batch_by) evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) - evaluate(control_flow_ops.group(iterator.initialize())) for expected_value in expected_values: next_element = iterator.get_next() computed_value = evaluate( [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) + self.assertEqual(len(expected_value), len(computed_value)) + for i in range(len(expected_value)): + self.assertAllEqual(expected_value[i], computed_value[i]) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - evaluate([values.select_replica(r, next_element) - for r in range(len(devices))]) + evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. evaluate(control_flow_ops.group(iterator.initialize())) @@ -308,7 +91,9 @@ class InputIteratorTestBase(test.TestCase): next_element = iterator.get_next() computed_value = evaluate( [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) + self.assertEqual(len(expected_value), len(computed_value)) + for i in range(len(expected_value)): + self.assertAllEqual(expected_value[i], computed_value[i]) class InputIteratorSingleWorkerTest(InputIteratorTestBase, @@ -319,7 +104,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase, input_type=["input_fn", "dataset"])) def testOneDeviceCPU(self, input_type): worker_device_pairs = [("", ["/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) + dataset_fn = lambda _: dataset_ops.Dataset.range(10) expected_values = [[i] for i in range(10)] @@ -332,7 +117,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase, required_gpus=1)) def testTwoDevicesOneGPUOneCPU(self, input_type): worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) + 
dataset_fn = lambda _: dataset_ops.Dataset.range(10) expected_values = [[i, i+1] for i in range(0, 10, 2)] @@ -345,7 +130,9 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase, required_gpus=1)) def testTupleDataset(self, input_type): worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - def dataset_fn(): + + def dataset_fn(ctx): + del ctx dataset1 = dataset_ops.Dataset.range(10) dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) return dataset_ops.Dataset.zip((dataset1, dataset2)) @@ -355,15 +142,17 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase, self._test_iterator(input_type, dataset_fn, worker_device_pairs, expected_values) - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) + @combinations.generate( + combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) def testUnevenDatasetBatches(self, input_type): worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(11) + dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2) - expected_values = [[i, i+1] for i in range(0, 10, 2)] + # The last global batch only contains data for one replica. 
+ expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]] self._test_iterator(input_type, dataset_fn, worker_device_pairs, expected_values) @@ -375,7 +164,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase, def testBatchSplitting(self, input_type, split_batch_by): worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] batch_size = 10 - dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size) updated_batch_size = ( batch_size // split_batch_by if split_batch_by else batch_size) @@ -417,7 +206,7 @@ class InputIteratorMultiWorkerTest( def testOneDevicePerWorker(self, input_type): worker_devices = self._cpu_devices() with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) + dataset_fn = lambda _: dataset_ops.Dataset.range(4) self._test_iterator(input_type, dataset_fn, worker_devices, [[0, 0], [1, 1], [2, 2], [3, 3]], sess) @@ -428,7 +217,7 @@ class InputIteratorMultiWorkerTest( def testTwoDevicesPerWorker(self, input_type): worker_devices = self._cpu_and_one_gpu_devices() with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) + dataset_fn = lambda _: dataset_ops.Dataset.range(4) self._test_iterator(input_type, dataset_fn, worker_devices, [[0, 1, 0, 1], [2, 3, 2, 3]], sess) @@ -438,7 +227,9 @@ class InputIteratorMultiWorkerTest( def testTupleDataset(self, input_type): worker_devices = self._cpu_devices() with context.graph_mode(), self.cached_session() as sess: - def dataset_fn(): + + def dataset_fn(ctx): + del ctx dataset1 = dataset_ops.Dataset.range(4) dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) return dataset_ops.Dataset.zip((dataset1, dataset2)) @@ -447,33 +238,35 @@ class InputIteratorMultiWorkerTest( self._test_iterator(input_type, dataset_fn, worker_devices, expected_values, sess) + @combinations.generate( + 
combinations.combine( + mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1)) + def testUnevenDatasetBatches(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2) + expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]], + [[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + @combinations.generate( + combinations.combine( + mode=["graph"], input_type=["input_fn"], required_gpus=1)) + def testDifferentDatasets(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + + def dataset_fn(ctx): + if ctx.input_pipeline_id == 0: + return dataset_ops.Dataset.range(8).batch(2) + else: + return dataset_ops.Dataset.range(9).batch(2) -class SplitDatasetBatchTest(test.TestCase): - - def testBatchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testMapAndBatchDataset(self): - dataset = dataset_ops.Dataset.range(100) - dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testPrefetchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in 
range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) + expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]], + [[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py index 92de8e643e7588365c23dc8513e197c0869c9ecf..2788f342c740f88bf8869fe1b5da7d48b4b31a61 100644 --- a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py +++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py @@ -28,9 +28,11 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribute_lib from tensorflow.python.eager import test from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop @@ -316,15 +318,19 @@ def all_strategy_combinations(): return strategy_minus_tpu_combinations() + tpu_strategy_combinations() -# TODO(priyag): Add v2 optimizers here. 
def strategy_and_optimizer_combinations(): return combinations.times( all_strategy_combinations(), - combinations.combine( - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn])) + combinations.combine(optimizer=[ + combinations.adagrad_optimizer_v1_fn, + combinations.adagrad_optimizer_keras_v2_fn, + combinations.adam_optimizer_v1_fn, + combinations.adam_optimizer_keras_v2_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_keras_v2_fn, + combinations.rmsprop_optimizer_v1_fn, + combinations.rmsprop_optimizer_keras_v2_fn + ])) def strategy_and_input_combinations(): @@ -350,6 +356,7 @@ def strategy_for_numpy_input_combinations(): mode=['graph']) +@test_util.run_v1_only('model.compile(..distribute=..) only works in TF v1') class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): @@ -458,6 +465,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, self.assertAllEqual([6, 7], outs[1].shape) +@test_util.run_v1_only('model.compile(..distribute=..) only works in TF v1') class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): @@ -741,13 +749,16 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) - grouped_models = distribution.unwrap(model._distributed_model) + grouped_models = distribution.experimental_local_results( + distributed_training_utils.get_distributed_model( + model, ModeKeys.TRAIN)) with distribution.scope(): for m in grouped_models: self.assertAllClose(0.001, keras.backend.get_value( m.optimizer.lr), atol=1e-05, rtol=1e-05) +@test_util.run_v1_only('model.compile(..distribute=..) 
only works in TF v1') class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine( @@ -787,16 +798,21 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): verbose=0, sample_weight=sample_weight) - # Test with not specifying the `steps` argument. - with self.assertRaisesRegexp( - ValueError, 'the `steps_per_epoch` argument'): + # Test with not specifying the `steps` argument for dataset with + # infinite cardinality. + dataset = dataset.repeat() + with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps_per_epoch` argument'): model.fit(dataset, epochs=1, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'the `steps` argument'): + with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): model.evaluate(dataset, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'the `steps` argument'): + with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): model.predict(dataset, verbose=0) @combinations.generate(combinations.combine( @@ -830,6 +846,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): callbacks=[keras.callbacks.ReduceLROnPlateau()]) +@test_util.run_v1_only('model.compile(..distribute=..) only works in TF v1') class TestDistributionStrategyWithLossMasking(test.TestCase, parameterized.TestCase): @@ -860,6 +877,7 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, self.assertEqual(hist.history['loss'][0], 0) +@test_util.run_v1_only('model.compile(..distribute=..) 
only works in TF v1') class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): @@ -892,6 +910,7 @@ class TestDistributionStrategyWithNormalizationLayer( np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) +@test_util.run_v1_only('model.compile(..distribute=..) only works in TF v1') class TestDistributionStrategyCorrectness(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/contrib/distribute/python/keras_correctness_test_base.py b/tensorflow/contrib/distribute/python/keras_correctness_test_base.py index 08ed933f2986b76bc0f8363a5a81682aa8a24493..c2d840788ca82e05bc8d03eb6d9e7b9f3608d966 100644 --- a/tensorflow/contrib/distribute/python/keras_correctness_test_base.py +++ b/tensorflow/contrib/distribute/python/keras_correctness_test_base.py @@ -17,8 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools from absl.testing import parameterized import numpy as np +import six from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy @@ -77,11 +79,7 @@ def strategies_for_embedding_models(): and DefaultStrategy in order to prevent testing timeouts. 
""" - strategies = [s for s in all_strategies - if not s.required_tpu and s.required_gpus is not None] - strategies.append(combinations.tpu_strategy_loop_on_device) - strategies.append(combinations.tpu_strategy_one_step_loop_on_device) - return strategies + return [s for s in all_strategies if s.required_tpu or s.required_gpus] def test_combinations_for_embedding_model(): @@ -93,6 +91,16 @@ def test_combinations_for_embedding_model(): eager_mode_test_configuration()))) +def test_combinations_with_tpu_strategies(): + tpu_strategies = [combinations.tpu_strategy, + combinations.tpu_strategy_one_step] + + return ( + combinations.times( + combinations.combine(distribution=tpu_strategies), + graph_mode_test_configuration())) + + class MaybeDistributionScope(object): """Provides a context allowing no distribution strategy.""" @@ -133,6 +141,19 @@ def get_batch_size(global_batch_size, distribution): return batch_size +def get_data_size(data): + """Gets the size of data in list, tuple, dict, or a numpy array.""" + assert isinstance(data, (np.ndarray, list, dict, tuple)) + + if isinstance(data, np.ndarray): + return len(data) + + if isinstance(data, (list, tuple)): + return len(data[0]) + + return len(six.next(six.itervalues(data))) + + def get_correctness_test_inputs(use_numpy, use_validation_data, with_distribution, x_train, y_train, x_predict): """Generates the inputs for correctness check when enable Keras with DS.""" @@ -159,11 +180,12 @@ def get_correctness_test_inputs(use_numpy, use_validation_data, 'y': y_train, } predict_inputs = { - 'x': np.array(x_predict, dtype=np.float32), + 'x': x_predict } else: - if len(x_train) < _GLOBAL_BATCH_SIZE * _EVAL_STEPS: - # Currently, we cannot detech the size of a dataset. So, the eval steps is + training_data_size = get_data_size(x_train) + if training_data_size < _GLOBAL_BATCH_SIZE * _EVAL_STEPS: + # Currently, we cannot detect the size of a dataset. So, the eval steps is # hard coded. 
raise ValueError('x_train must have at least ' '_GLOBAL_BATCH_SIZE * _EVAL_STEPS samples') @@ -179,7 +201,7 @@ def get_correctness_test_inputs(use_numpy, use_validation_data, 'y': None, 'epochs': training_epochs, 'shuffle': False, - 'steps_per_epoch': len(x_train) // global_batch_size, + 'steps_per_epoch': training_data_size // global_batch_size, } if use_validation_data: eval_inputs = None # Remove the eval_inputs @@ -195,7 +217,8 @@ def get_correctness_test_inputs(use_numpy, use_validation_data, 'steps': _EVAL_STEPS, } - predict_batch_size = get_batch_size(len(x_predict), with_distribution) + predict_batch_size = get_batch_size(get_data_size(x_predict), + with_distribution) predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) predict_dataset = batch_wrapper(predict_dataset, predict_batch_size, with_distribution) @@ -207,11 +230,11 @@ def get_correctness_test_inputs(use_numpy, use_validation_data, return training_inputs, eval_inputs, predict_inputs -def fit_eval_and_predict( - initial_weights, input_fn, model_fn, distribution=None): +def fit_eval_and_predict(initial_weights, input_fn, model_fn, + distribution=None, is_stateful_model=False): """Generates results for fit/predict/evaluate for given model.""" model = model_fn(initial_weights=initial_weights, distribution=distribution) - training_inputs, eval_inputs, predict_inputs = input_fn(distribution) + training_inputs, eval_inputs, predict_inputs = input_fn() result = {} result['training_history_1'] = model.fit(**training_inputs).history @@ -222,7 +245,15 @@ def fit_eval_and_predict( result['weights_1'] = model.get_weights() if predict_inputs is not None: - result['predict_result_1'] = model.predict(**predict_inputs) + # Check correctness of the result of predict() invoked + # multiple times -- as for stateful models, result of + # predict may differ for each batch. 
+ predict_length = 1 + if is_stateful_model: + predict_length = 3 + for i in range(predict_length): + result_key = 'predict_result_{}'.format(i) + result[result_key] = model.predict(**predict_inputs) # Train and eval again to mimic user's flow. @@ -241,19 +272,20 @@ def compare_results(results_with_ds, results_without_ds, distribution, """Compares results of model compiled with/without distribution strategy.""" default_tolerance = 1e-5 - tol_table = {} - - if isinstance(distribution, ( - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy, - distribute_lib._DefaultDistributionStrategy)): # pylint: disable=protected-access - # TODO(b/119257215): Weights are not exactly the same, so use larger - # tolerance for now. Predict should be related to weights. - tol_table = { - 'weights_1': 1e-4, - 'weights_2': 1e-4, - 'predict_result_1': 1e-4, - } + relaxed_tolerance = 1e-4 + + def _get_compare_result_tolerance(key): + """Returns tolerance to compare results.""" + # TODO(b/119257215): For MirroredStrategy, weights are not exactly the same, + # so use larger tolerance for now. Predict should be related to weights. + if (isinstance(distribution, ( + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + distribute_lib._DefaultDistributionStrategy)) and # pylint: disable=protected-access + key.startswith(('weights_1', 'weights_2', 'predict_result'))): + return relaxed_tolerance + + return default_tolerance for key in results_with_ds: if (key.startswith('training_history') and @@ -263,8 +295,7 @@ def compare_results(results_with_ds, results_without_ds, distribution, # underlying bug is fixed. 
continue - tolerance = tol_table.get(key, default_tolerance) - + tolerance = _get_compare_result_tolerance(key) testcase.assertAllClose( results_with_ds[key], results_without_ds[key], @@ -315,6 +346,22 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase, y_train = x_train return (x_train.astype('float32'), y_train.astype('float32'), None) + def get_input_for_correctness_test(self, **kwargs): + """Generates inputs that are dictionaries. + + We only provide a default implementation of this method here. If you need + more customized way of providing input to your model, overwrite this method. + + Arguments: + **kwargs: key word arguments about how to create the input dictionaries + + Returns: + Three dictionaries representing the input for fit(), evalutate() and + predict() + """ + + return get_correctness_test_inputs(**kwargs) + def get_model(self, distribution=None): raise NotImplementedError @@ -334,7 +381,8 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase, distribution, use_numpy, use_validation_data, - with_batch_norm=False): + with_batch_norm=False, + is_stateful_model=False): with self.cached_session(): self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm) self.skip_unsupported_test_configuration(distribution) @@ -342,23 +390,42 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase, # Train, eval, and predict datasets are created with the same input numpy # arrays. x_train, y_train, x_predict = self.get_data() - # The model is built once and the initial weights are saved. # This is used to initialize the model for both the distribution and # non-distribution run. 
model = self.get_model() initial_weights = model.get_weights() - def input_fn(dist): - return get_correctness_test_inputs( - use_numpy, use_validation_data, dist, x_train, y_train, x_predict) + ds_input_fn = functools.partial( + self.get_input_for_correctness_test, + use_numpy=use_numpy, + use_validation_data=use_validation_data, + with_distribution=distribution, + x_train=x_train, + y_train=y_train, + x_predict=x_predict) + + nods_input_fn = functools.partial( + self.get_input_for_correctness_test, + use_numpy=use_numpy, + use_validation_data=use_validation_data, + with_distribution=None, + x_train=x_train, + y_train=y_train, + x_predict=x_predict) results_with_ds = fit_eval_and_predict( - initial_weights, input_fn=input_fn, model_fn=self.get_model, - distribution=distribution) + initial_weights, + input_fn=ds_input_fn, + model_fn=self.get_model, + distribution=distribution, + is_stateful_model=is_stateful_model) results_without_ds = fit_eval_and_predict( - initial_weights, input_fn=input_fn, model_fn=self.get_model, - distribution=None) + initial_weights, + input_fn=nods_input_fn, + model_fn=self.get_model, + distribution=None, + is_stateful_model=is_stateful_model) # First, special case, for multi-replica distributed training, batch norm # is not aggregated globally. So it is expected to have different weights. @@ -371,6 +438,23 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase, compare_results(results_with_ds, results_without_ds, distribution, testcase=self) + def get_input_for_dynamic_lr_test(self, **kwargs): + """Generates inputs that are dictionaries. + + We only provide a default implementation of this method here. If you need + more customized way of providing input to your model, overwrite this method. 
+ + Arguments: + **kwargs: key word arguments about how to create the input dictionaries + + Returns: + Three dictionaries representing the input for fit(), evalutate() and + predict() + """ + + training_input = kwargs + return training_input, None, None + def run_dynamic_lr_test(self, distribution): with self.cached_session(): self.set_up_test_config() @@ -388,30 +472,41 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase, # same as TPU. update_freq = distribution.extended.steps_per_run - def input_fn(dist): - """Generates training test given test configuration.""" - training_epochs = 2 - global_batch_size = 64 - batch_size = get_batch_size(global_batch_size, dist) - - training_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - 'epochs': training_epochs, - 'shuffle': False, - 'callbacks': [LearningRateBatchScheduler(update_freq)], - 'validation_data': (x_train, y_train) - } - # In this test case, we do not care eval and predict. - eval_inputs, predict_inputs = None, None - return training_inputs, eval_inputs, predict_inputs + training_epochs = 2 + global_batch_size = 64 + + ds_batch_size = get_batch_size(global_batch_size, distribution) + nods_batch_size = get_batch_size(global_batch_size, None) + + ds_input_fn = functools.partial( + self.get_input_for_dynamic_lr_test, + x=x_train, + y=y_train, + batch_size=ds_batch_size, + shuffle=False, + epochs=training_epochs, + callbacks=[LearningRateBatchScheduler(update_freq)], + validation_data=(x_train, y_train)) + + nods_input_fn = functools.partial( + self.get_input_for_dynamic_lr_test, + x=x_train, + y=y_train, + batch_size=nods_batch_size, + shuffle=False, + epochs=training_epochs, + callbacks=[LearningRateBatchScheduler(update_freq)], + validation_data=(x_train, y_train)) results_with_ds = fit_eval_and_predict( - initial_weights, input_fn=input_fn, model_fn=self.get_model, + initial_weights, + input_fn=ds_input_fn, + model_fn=self.get_model, distribution=distribution) 
results_without_ds = fit_eval_and_predict( - initial_weights, input_fn=input_fn, model_fn=self.get_model, + initial_weights, + input_fn=nods_input_fn, + model_fn=self.get_model, distribution=None) compare_results(results_with_ds, results_without_ds, distribution, testcase=self) @@ -448,7 +543,7 @@ class TestDistributionStrategyEmbeddingModelCorrectnessBase( features, maxlen=max_words) x_train = np.asarray(features, dtype=np.float32) y_train = np.asarray(labels, dtype=np.int32).reshape((count, 1)) - x_predict = x_train + x_predict = x_train[:_GLOBAL_BATCH_SIZE] return x_train, y_train, x_predict diff --git a/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py index 7afacab0ddbed8d5b448c2ed2b983bfa18d11b80..61202e30c4f33892d2675080fae07cc4d7102337 100644 --- a/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py +++ b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py @@ -47,7 +47,9 @@ class TestDistributionStrategyDnnCorrectness( # We add few non-linear layers to make it non-trivial. model = keras.Sequential() model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense( + 10, activation='relu', + kernel_regularizer=keras.regularizers.l2(1e-4))) model.add(keras.layers.Dense(10, activation='relu')) model.add(keras.layers.Dense(1)) @@ -68,7 +70,7 @@ class TestDistributionStrategyDnnCorrectness( y_train = 3 * x_train x_train = x_train.astype('float32') y_train = y_train.astype('float32') - x_predict = [[1.], [2.], [3.], [4.]] + x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32) return x_train, y_train, x_predict @combinations.generate(keras_correctness_test_base. 
diff --git a/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py index 3913f9bc0cdfff6b562d5727ec33eb4d83f4a619..e881bb70ecc428e3f972cde5f19c1b61b1dc0f0b 100644 --- a/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py +++ b/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py @@ -71,5 +71,80 @@ class DistributionStrategyEmbeddingModelCorrectnessTest( self.run_correctness_test(distribution, use_numpy, use_validation_data) +class DistributionStrategySiameseEmbeddingModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids_a = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words_a') + word_ids_b = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words_b') + + def submodel(embedding, word_ids): + word_embed = embedding(word_ids) + rep = keras.layers.GlobalAveragePooling1D()(word_embed) + return keras.Model(inputs=[word_ids], outputs=[rep]) + + word_embed = keras.layers.Embedding( + input_dim=20, + output_dim=10, + input_length=max_words, + embeddings_initializer=keras.initializers.RandomUniform(0, 1)) + + a_rep = submodel(word_embed, word_ids_a).outputs[0] + b_rep = submodel(word_embed, word_ids_b).outputs[0] + sim = keras.layers.Dot(axes=1, normalize=True)([a_rep, b_rep]) + + model = keras.Model(inputs=[word_ids_a, word_ids_b], outputs=[sim]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='mse', + metrics=['mse']) + return model + + def get_data(self, + count=(keras_correctness_test_base._GLOBAL_BATCH_SIZE * + 
keras_correctness_test_base._EVAL_STEPS), + min_words=5, + max_words=10, + max_word_id=19, + num_classes=2): + features_a, labels_a, _ = (super( + DistributionStrategySiameseEmbeddingModelCorrectnessTest, self). + get_data(count, min_words, max_words, + max_word_id, num_classes)) + + features_b, labels_b, _ = (super( + DistributionStrategySiameseEmbeddingModelCorrectnessTest, self). + get_data(count, min_words, max_words, + max_word_id, num_classes)) + + y_train = np.zeros((count, 1), dtype=np.float32) + y_train[labels_a == labels_b] = 1.0 + y_train[labels_a != labels_b] = -1.0 + # TODO(b/123360757): Add tests for using list as inputs for multi-input + # models. + x_train = { + 'words_a': features_a, + 'words_b': features_b, + } + x_predict = x_train + + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + test_combinations_for_embedding_model()) + def test_siamese_embedding_model_correctness(self, distribution, use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py index f625664372dfb6814ccbe9539f6abe018d2a4447..3c2961456b2eede9570ce29f7a8900834f2ccfb7 100644 --- a/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py +++ b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py @@ -23,7 +23,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import keras_correctness_test_base from tensorflow.python import keras from tensorflow.python.eager import test -from tensorflow.python.training import gradient_descent +from tensorflow.python.keras.optimizer_v2 import gradient_descent class DistributionStrategyCnnCorrectnessTest( @@ -33,7 +33,8 @@ class 
DistributionStrategyCnnCorrectnessTest( with keras_correctness_test_base.MaybeDistributionScope(distribution): image = keras.layers.Input(shape=(28, 28, 3), name='image') c1 = keras.layers.Conv2D( - name='conv1', filters=16, kernel_size=(3, 3), strides=(4, 4))( + name='conv1', filters=16, kernel_size=(3, 3), strides=(4, 4), + kernel_regularizer=keras.regularizers.l2(1e-4))( image) if self.with_batch_norm: c1 = keras.layers.BatchNormalization(name='bn1')(c1) @@ -47,7 +48,7 @@ class DistributionStrategyCnnCorrectnessTest( model.set_weights(initial_weights) model.compile( - optimizer=gradient_descent.GradientDescentOptimizer( + optimizer=gradient_descent.SGD( learning_rate=0.1), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) diff --git a/tensorflow/contrib/distribute/python/keras_multi_worker_test.py b/tensorflow/contrib/distribute/python/keras_multi_worker_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5a0625a0e4ee928ea49a345b263a4f596b2a3957 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_multi_worker_test.py @@ -0,0 +1,460 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test multi-worker Keras. + +TODO(b/123845258): Move this to tensorflow core. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import copy +import functools +import os +import sys +import threading + +from absl.testing import parameterized + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy as collective_strategy +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base as test_base +from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_coordinator as dc +from tensorflow.python.distribute import distribute_coordinator_context as dc_context +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras import backend +from tensorflow.python.keras import callbacks +from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.keras import models +from tensorflow.python.keras import optimizers +from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + +def _mnist_synthetic_dataset(batch_size, steps_per_epoch): + # train dataset + x_train = array_ops.ones([batch_size * steps_per_epoch, 28, 28, 1], + dtype=dtypes.float32) + y_train = array_ops.ones([batch_size * steps_per_epoch, 1], + dtype=dtypes.int32) + train_ds = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + train_ds = train_ds.repeat() + # train_ds = train_ds.shuffle(100) + train_ds = 
train_ds.batch(64, drop_remainder=True) + + # eval dataset + x_test = random_ops.random_uniform([10000, 28, 28, 1], dtype=dtypes.float32) + y_test = random_ops.random_uniform([10000, 1], + minval=0, + maxval=9, + dtype=dtypes.int32) + eval_ds = dataset_ops.Dataset.from_tensor_slices((x_test, y_test)) + eval_ds = eval_ds.repeat() + eval_ds = eval_ds.batch(64, drop_remainder=True) + + return train_ds, eval_ds + + +def _get_model(input_shape): + # Define a deterministically-initialized CNN model to recognize MNIST digits, + # commented out several layers to simplify it. + model = keras.models.Sequential() + model.add( + keras.layers.Conv2D( + 32, + kernel_size=(3, 3), + activation='relu', + input_shape=input_shape, + kernel_initializer=keras.initializers.TruncatedNormal(seed=99))) + # model.add(keras.layers.Conv2D(64, (3, 3), activation='relu')) + # model.add(keras.layers.MaxPooling2D(pool_size=(2, 2))) + # model.add(keras.layers.Dropout(0.25)) + model.add(keras.layers.Flatten()) + # model.add(keras.layers.Dense(128, activation='relu')) + # model.add(keras.layers.Dropout(0.5)) + model.add( + keras.layers.Dense( + 10, + activation='softmax', + kernel_initializer=keras.initializers.TruncatedNormal(seed=99))) + + # TODO(yuefengz): optimizer with slot variables doesn't work because of + # optimizer's bug. + # TODO(yuefengz): we should not allow non-v2 optimizer. + model.compile( + loss=keras.losses.sparse_categorical_crossentropy, + optimizer=gradient_descent.SGD(learning_rate=0.001), + metrics=['accuracy']) + return model + + +def _clone_and_build_model(model, strategy): + # The new "original" model in worker 0. + with strategy.scope(): + cloned_model = models.clone_model(model) + + # Compile and build model. + if isinstance(model.optimizer, optimizers.TFOptimizer): + optimizer = model.optimizer + # TODO(yuefengz): figure out why the optimizer here is still a + # TFOptimizer. 
+ while isinstance(optimizer, optimizers.TFOptimizer): + optimizer = optimizer.optimizer + optimizer = copy.deepcopy(optimizer) + else: + optimizer_config = model.optimizer.get_config() + optimizer = type(model.optimizer).from_config(optimizer_config) + + cloned_model.compile( + optimizer, + model.loss, + metrics=metrics_module.clone_metrics(model._compile_metrics), + loss_weights=model.loss_weights, + sample_weight_mode=model.sample_weight_mode, + weighted_metrics=metrics_module.clone_metrics( + model._compile_weighted_metrics)) + return cloned_model + + +# TODO(b/123918215): Possibly merge this Callback with keras_test.Counter. +class MultiWorkerVerificationCallback(callbacks.Callback): + """MultiWorkerVerificationCallback verifies the callbacks in multi-worker scheme. + + This Callback is intended to be used for verifying the callback is indeed + called the correct number of times in various task types. + + Attributes: + _task_dict: A nested dictionary storing the number of times a callback has + been called in specific task type, task index, and method name. + Look up structure is + task_name -> task_id -> tracking_method_name -> invoke_count + For example, a _task_dict of + { + 'ps': { + 0: { + 'on_epoch_begin': 2 + }, + 1: { + 'on_epoch_begin': 2 + } + }, + 'worker': { + 0: { + 'on_epoch_begin': 2 + }, + 1: { + 'on_epoch_begin': 2 + } + } + } + indicates the ps task has 'on_epoch_begin' called twice on each + of the two indices, and likewise for worker task. + """ + + # TODO(rchao): Add other method calls to verify. + METHODS_TO_VERIFY = ['on_epoch_begin'] + + def __init__(self, num_epoch, num_worker): + """Initialize a MultiWorkerVerificationCallback. + + Args: + num_epoch: Number of epochs this Callback is expected to be called for. + num_worker: Number of workers this Callback is expected to be called from. 
+ """ + super(MultiWorkerVerificationCallback, self).__init__() + self._num_epoch = num_epoch + self._num_worker = num_worker + self._task_dict = { + key: collections.defaultdict(lambda: collections.defaultdict(int)) + for key in ['ps', 'worker'] + } + self._lock = threading.Lock() + self._is_between_graph = None + self.wrap_methods(self.METHODS_TO_VERIFY) + + @property + def is_between_graph(self): + return self._is_between_graph + + @is_between_graph.setter + def is_between_graph(self, is_between_graph): + self._is_between_graph = is_between_graph + + def wrap_methods(self, method_names): + """Wrap methods so that the counts of calls are tracked. + + Args: + method_names: A list of names of methods to track calls. + """ + for method_name in method_names: + method = getattr(self, method_name) + + def wrapped_method(method_to_wrap, name, *arg, **kwargs): + # Use lock to ensure += operation is thread-safe. + with self._lock: + self._task_dict[test_base.get_task_type()][ + test_base.get_task_index()][name] += 1 + method_to_wrap(*arg, **kwargs) + + setattr(self, method_name, + functools.partial(wrapped_method, method, method_name)) + + def verify(self, test_case): + method_count_dict = { + method_name: self._num_epoch for method_name in self.METHODS_TO_VERIFY + } + assert self._is_between_graph is not None + if self._is_between_graph: + # TODO(b/124171024): In between-graph replication, by default only the + # chief calls callback. Fix this test to cover that, as well as the rare + # cases where all workers call. + worker_call_count = { + i: method_count_dict for i in range(0, self._num_worker) + } + else: + # If in-graph, only the first worker calls callback methods. + worker_call_count = {0: method_count_dict} + test_case.assertDictEqual( + self._task_dict, + { + # PS' callback is not supposed to be called. + 'ps': {}, + # Each of the Worker should be called num_epoch of times. 
+ 'worker': worker_call_count + }) + + +# TODO(yuefengz): right now, fit or evaluate has to be called under distribution +# strategy's scope. +def _run_standalone_client(test_obj, strategy, cluster_spec): + input_shape = (28, 28, 1) + with strategy.scope(): + orig_model = _get_model(input_shape) + + def worker_fn(strategy): + with ops.Graph().as_default(): + batch_size = 64 + steps = 10 + + with strategy.scope(): + train_ds, _ = _mnist_synthetic_dataset(batch_size, steps) + model = _clone_and_build_model(orig_model, strategy) + + orig_loss, orig_acc = model.evaluate(train_ds, steps=steps) + + # Workaround for the metrics issue (b/122928955) in async training. This + # can only be used in standalone client mode. + dc_context.get_current_worker_context().wait_for_other_workers() + + model.fit(x=train_ds, epochs=2, steps_per_epoch=steps) + + dc_context.get_current_worker_context().wait_for_other_workers() + + trained_loss, trained_acc = model.evaluate(train_ds, steps=steps) + + test_obj.assertLessEqual(trained_loss, orig_loss) + test_obj.assertGreaterEqual(trained_acc, orig_acc) + + dc.run_distribute_coordinator( + worker_fn, + strategy, + mode=dc.CoordinatorMode.STANDALONE_CLIENT, + cluster_spec=cluster_spec) + + +def get_strategy_object(strategy_cls): + if (strategy_cls == mirrored_strategy.MirroredStrategy or + strategy_cls == mirrored_strategy.CoreMirroredStrategy): + return strategy_cls(mirrored_strategy.all_local_devices()) + else: + return strategy_cls(num_gpus_per_worker=context.num_gpus()) + + +class KerasMultiWorkerTestStandaloneClient(test.TestCase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + super(KerasMultiWorkerTestStandaloneClient, cls).setUpClass() + cls._cluster_spec = test_base.create_in_process_cluster( + num_workers=2, num_ps=1, has_eval=False) + + @combinations.generate( + combinations.combine( + mode=['graph'], + strategy_cls=[ + mirrored_strategy.MirroredStrategy, + 
mirrored_strategy.CoreMirroredStrategy, + parameter_server_strategy.ParameterServerStrategy, + collective_strategy.CollectiveAllReduceStrategy, + ], + required_gpus=[0, 1])) + def testSimpleModelStandaloneClient(self, strategy_cls): + # With standalone client, training_utils.should_run_multi_worker returns + # False which means the distribute coordinator won't be called again in + # `fit`. This is still correct and intended since session is still + # configured under distribute coordinator's worker context and distribution + # strategy object is already configured by distribute coordinator for + # multi-worker training. + # The logic should be much clearer once standalone client is merged into + # core Keras as well. + strategy = get_strategy_object(strategy_cls) + + _run_standalone_client(self, strategy, self._cluster_spec) + + +class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase, + parameterized.TestCase): + + @combinations.generate( + combinations.combine( + mode=['graph'], + strategy_cls=[ + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + collective_strategy.CollectiveAllReduceStrategy, + ], + required_gpus=[0, 1])) + def testSimpleModelIndependentWorkerSync(self, strategy_cls): + num_workers = 2 + num_epoch = 2 + + cluster_spec = test_base.create_cluster_spec(num_workers=num_workers) + self._barrier = dc._Barrier(2) + + # The verification callback will be shared by multiple threads. 
+ verification_callback = MultiWorkerVerificationCallback( + num_epoch=num_epoch, num_worker=num_workers) + + def _independent_worker_fn(*args, **kwargs): # pylint: disable=unused-argument + """Simulates an Independent Worker inside of a thread.""" + with test.mock.patch.object(dc, '_run_std_server', + self._make_mock_run_std_server()): + strategy = get_strategy_object(strategy_cls) + verification_callback.is_between_graph = \ + strategy.extended.experimental_between_graph + batch_size = 64 + steps = 10 + train_ds, _ = _mnist_synthetic_dataset(batch_size, steps) + with strategy.scope(): + model = _get_model((28, 28, 1)) + orig_loss, _ = model.evaluate(train_ds, steps=steps) + callbacks_for_fit = nest.flatten( + kwargs.get('verification_callback', [])) + history = model.fit( + x=train_ds, + epochs=num_epoch, + steps_per_epoch=steps, + callbacks=callbacks_for_fit) + self.assertIsInstance(history, keras.callbacks.History) + trained_loss, _ = model.evaluate(train_ds, steps=steps) + self.assertLess(trained_loss, orig_loss) + + threads = self.run_multiple_tasks_in_threads( + _independent_worker_fn, + cluster_spec, + verification_callback=verification_callback) + + threads_to_join = [] + strategy = get_strategy_object(strategy_cls) + if strategy.extended.experimental_between_graph: + for ts in threads.values(): + threads_to_join.extend(ts) + else: + threads_to_join = [threads['worker'][0]] + self.join_independent_workers(threads_to_join) + verification_callback.verify(self) + + @combinations.generate( + combinations.combine( + mode=['graph'], + strategy_cls=[parameter_server_strategy.ParameterServerStrategy], + required_gpus=[0, 1])) + def testSimpleModelIndependentWorkerAsync(self, strategy_cls): + num_workers = 2 + num_epoch = 2 + cluster_spec = test_base.create_cluster_spec( + num_workers=num_workers, num_ps=2) + self._barrier = dc._Barrier(4) + + # The verification callback will be shared by multiple threads. 
+ verification_callback = MultiWorkerVerificationCallback( + num_epoch=num_epoch, num_worker=num_workers) + + def _independent_worker_fn(*args, **kwargs): # pylint: disable=unused-argument + """Simulates an Independent Worker inside of a thread.""" + # TODO(rchao/yuefengz): The following is run by both worker and ps + # threads. The distribute coordinator should run std server immediately + # without configuring the session (or building the graph) on PS. + with test.mock.patch.object(dc, '_run_std_server', + self._make_mock_run_std_server()): + batch_size = 64 + steps = 10 + strategy = strategy_cls(num_gpus_per_worker=context.num_gpus()) + verification_callback.is_between_graph = \ + strategy.extended.experimental_between_graph + + train_ds, _ = _mnist_synthetic_dataset(batch_size, steps) + val_ds, _ = _mnist_synthetic_dataset(batch_size, steps) + with strategy.scope(): + model = _get_model((28, 28, 1)) + + # TODO(b/123868066): Verify callback for model.evaluate(). + callbacks_for_fit = nest.flatten( + kwargs.get('verification_callback', [])) + history = model.fit( + x=train_ds, + epochs=num_epoch, + steps_per_epoch=steps, + validation_data=val_ds, + validation_steps=steps, + callbacks=callbacks_for_fit) + self.assertIsInstance(history, keras.callbacks.History) + + threads = self.run_multiple_tasks_in_threads( + _independent_worker_fn, + cluster_spec, + verification_callback=verification_callback) + + threads_to_join = [] + for task_type, ts in threads.items(): + # This test can finish once the worker threads complete, and thus + # the ps threads don't need to be joined. + if task_type == 'ps': + continue + threads_to_join.extend(ts) + self.join_independent_workers(threads_to_join) + verification_callback.verify(self) + + +if __name__ == '__main__': + # Enable manual variable initialization to make sure variables are initialized + # by `init_restore_or_wait_for_variables`. 
+ backend.manual_variable_initialization(True) + with test.mock.patch.object(sys, 'exit', os._exit): + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index 5349794334b7f6ea3d718343fa84c693dd3d7a3c..c93d7afa7ceef2c9c272e91997e2871655cea079 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.python import keras from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -44,108 +45,71 @@ def get_model(): class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + @combinations.generate( + combinations.combine( + distribution=[ + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus, + combinations.parameter_server_strategy_with_two_gpus, + ], + mode=['graph', 'eager'])) def testKerasOptimizerWithUnequalInput(self, distribution): - def create_fn(): + with distribution.scope(): var = variables.Variable( 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) - # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. 
- def loss(): - return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var - optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) - train_op = optimizer.minimize(loss, var_list=[var]) - m = optimizer.get_slot(var, 'm') - v = optimizer.get_slot(var, 'v') - return (var, m, v, train_op, optimizer.iterations) + all_vars = [] - devices = ['/device:GPU:0', '/device:CPU:0'] - with distribution.scope(): - (var, m, v, op, - counter) = distribution.extended.call_for_each_replica(create_fn) + def model_fn(): + + def loss_fn(): + replica_id = _replica_id() + return math_ops.cast(replica_id + 1, dtype=dtypes.float32) * 0.5 * var + + train_op = optimizer.minimize(loss_fn, var_list=[var]) + + return train_op, optimizer + + def train_fn(): + train_op, optimizer = distribution.extended.call_for_each_replica( + model_fn) + if not all_vars: + all_vars.append(var) + all_vars.append(optimizer.get_slot(var, 'm')) + all_vars.append(optimizer.get_slot(var, 'v')) + return distribution.group(train_op) + + if not context.executing_eagerly(): + with self.cached_session() as sess: + train_fn = sess.make_callable(train_fn()) self.evaluate(variables.global_variables_initializer()) - var_val = [2.0, 2.0, 2.0] - self.assertAllClose( - var_val, - self.evaluate( - [distribution.extended.read_var(var), - var.get(devices[0]), - var.get(devices[1])])) - self.assertAllClose([0, 0, 0], - self.evaluate([ - distribution.extended.read_var(counter), - counter.get(devices[0]), - counter.get(devices[1]) - ])) - - train_op = distribution.unwrap(op) - self.evaluate(train_op) - # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 - m_val = [1.2, 1.2, 1.2] - # assert slot variables in both replicas are the same. 
- self.assertAllClose( - m_val, - self.evaluate( - [distribution.extended.read_var(m), - m.get(devices[0]), - m.get(devices[1])])) - # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 - v_val = [1.8, 1.8, 1.8] - self.assertAllClose( - v_val, - self.evaluate( - [distribution.extended.read_var(v), - v.get(devices[0]), - v.get(devices[1])])) + + # first step. + train_fn() # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1) # = 2.0 - 0.01 * 1.2 * sqrt(0.8) / sqrt(1.8) / 0.8 - var_val = [1.99, 1.99, 1.99] - self.assertAllClose( - var_val, - self.evaluate( - [distribution.extended.read_var(var), - var.get(devices[0]), - var.get(devices[1])])) - self.assertAllClose([1, 1, 1], - self.evaluate([ - distribution.extended.read_var(counter), - counter.get(devices[0]), - counter.get(devices[1]) - ])) - - self.evaluate(train_op) + self.assertAllClose(1.99, self.evaluate(all_vars[0])) + # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 + self.assertAllClose(1.2, self.evaluate(all_vars[1])) + # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 + self.assertAllClose(1.8, self.evaluate(all_vars[2])) + + # second step. 
+ train_fn() + # var(2) = var(0) - lr * 2 = 1.98 + self.assertAllClose(1.98, self.evaluate(all_vars[0])) # m(2) = beta1 * m(1) + (1-beta1) * grad = 0.2 * 1.2 + 0.8 * 1.5 - m_val = [1.44, 1.44, 1.44] - self.assertAllClose( - m_val, - self.evaluate( - [distribution.extended.read_var(m), - m.get(devices[0]), - m.get(devices[1])])) + self.assertAllClose(1.44, self.evaluate(all_vars[1])) # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 - v_val = [2.16, 2.16, 2.16] - self.assertAllClose( - v_val, - self.evaluate( - [distribution.extended.read_var(v), - v.get(devices[0]), - v.get(devices[1])])) - self.assertAllClose([2, 2, 2], - self.evaluate([ - distribution.extended.read_var(counter), - counter.get(devices[0]), - counter.get(devices[1]) - ])) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + self.assertAllClose(2.16, self.evaluate(all_vars[2])) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.parameter_server_strategy_with_two_gpus, + ], + mode=['graph', 'eager'])) def testOptimizerWithKerasModelAndNumpyArrays(self, distribution): with self.cached_session(): diff --git a/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f5faf6c36b880a72bafc8d082cff2816f3b11a76 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py @@ -0,0 +1,99 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for stateful tf.keras LSTM models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +def strategies_for_stateful_embedding_model(): + """Returns TPUStrategy with single core device assignment.""" + + return [combinations.tpu_strategy_one_core, + combinations.tpu_strategy_one_step_one_core] + + +def test_combinations_for_stateful_embedding_model(): + return ( + combinations.combine( + distribution=strategies_for_stateful_embedding_model(), + mode='graph', + use_numpy=False, + use_validation_data=False + )) + + +class DistributionStrategyStatefulLstmModelCorrectnessTest( + keras_correctness_test_base. 
+ TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE + + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), + batch_size=batch_size, + dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + lstm_embed = keras.layers.LSTM(units=4, + return_sequences=False, + stateful=True)(word_embed) + + preds = keras.layers.Dense(2, activation='softmax')(lstm_embed) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(test_combinations_for_stateful_embedding_model()) + def test_stateful_lstm_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + is_stateful_model=True) + + @combinations.generate(keras_correctness_test_base. + test_combinations_with_tpu_strategies()) + def test_incorrectly_use_multiple_cores_for_stateful_lstm_model( + self, distribution, use_numpy, use_validation_data): + with self.assertRaisesRegexp(ValueError, + 'Single core must be used for computation ' + 'on stateful models. 
Consider adding ' + '`device_assignment` parameter to ' + 'TPUStrategy'): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + is_stateful_model=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 8a607dd070f859aca69ee857d5c5f091f107e0ca..2b733c2cfb1609998be413254c2ffa377d4acb73 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -25,17 +25,17 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.python import keras +from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import values from tensorflow.python.eager import test from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.summary.writer import writer_cache @@ -70,6 +70,32 @@ def simple_functional_model(): return model +def simple_subclassed_model(num_labels=_NUM_CLASS): + + class _SimpleMLP(keras.Model): + + def __init__(self, num_labels): + super(_SimpleMLP, self).__init__() + self.dense = 
keras.layers.Dense(num_labels) + + def call(self, inputs): + return self.dense(inputs) + + return _SimpleMLP(num_labels) + + +def simple_multi_inputs_multi_outputs_model(): + input_a = keras.layers.Input(shape=(16,), name='input_a') + input_b = keras.layers.Input(shape=(16,), name='input_b') + + merged = keras.layers.concatenate([input_a, input_b], name='merge') + output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged) + output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged) + model = keras.models.Model( + inputs=[input_a, input_b], outputs=[output_c, output_d]) + return model + + def multi_inputs_multi_outputs_model(): input_a = keras.layers.Input(shape=(16,), name='input_a') input_b = keras.layers.Input(shape=(16,), name='input_b') @@ -202,6 +228,22 @@ def get_predict_dataset(distribution): return dataset +def convert_numpy_to_dataset_with_unknown_cardinality(inputs, + targets=None): + if targets is not None: + input_slices = (inputs, targets) + dummy_op = (lambda inp, target: True) + else: + input_slices = inputs + dummy_op = (lambda inp: True) + + original_dataset = (dataset_ops.Dataset.from_tensor_slices( + input_slices)) + ds_with_unknown_cardinality = (original_dataset.filter(dummy_op). + batch(10, drop_remainder=True)) + return ds_with_unknown_cardinality + + def multi_input_output_model(): a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(5,), name='input_b') @@ -216,9 +258,12 @@ def multi_input_output_model(): return model +# TODO(josh11b): Add combinations.one_device_strategy_gpu once it works with +# TestDistributionStrategyWithCallbacks.test_callbacks_in_predict. 
strategies_minus_tpu = [ combinations.default_strategy, combinations.one_device_strategy, + combinations.one_device_strategy_gpu, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, combinations.core_mirrored_strategy_with_gpu_and_cpu, @@ -230,54 +275,56 @@ tpu_strategies = [ def strategy_minus_tpu_combinations(): - return combinations.combine( - distribution=strategies_minus_tpu, - mode=['graph', 'eager']) + return combinations.combine(distribution=strategies_minus_tpu, + mode=['graph', 'eager']) def tpu_strategy_combinations(): - return combinations.combine( - distribution=tpu_strategies, - mode=['graph']) + return combinations.combine(distribution=tpu_strategies, + mode=['graph']) def all_strategy_combinations(): return strategy_minus_tpu_combinations() + tpu_strategy_combinations() -def all_strategy_combinations_minus_default(): - strategy_minus_default_combinations = combinations.combine( +def all_strategy_minus_default_and_tpu_combinations(): + return combinations.combine( distribution=[ combinations.one_device_strategy, + combinations.one_device_strategy_gpu, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph', 'eager']) - return strategy_minus_default_combinations + tpu_strategy_combinations() -# TODO(priyag): Add v2 optimizers here. 
+def all_strategy_combinations_minus_default(): + return (all_strategy_minus_default_and_tpu_combinations() + + tpu_strategy_combinations()) + + def strategy_and_optimizer_combinations(): return combinations.times( all_strategy_combinations(), - combinations.combine( - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn])) - - -def strategy_for_numpy_input_combinations(): - return combinations.combine( - distribution=strategies_minus_tpu + tpu_strategies, - mode=['graph']) + combinations.combine(optimizer=[ + combinations.adagrad_optimizer_v1_fn, + combinations.adagrad_optimizer_keras_v2_fn, + combinations.adam_optimizer_v1_fn, + combinations.adam_optimizer_keras_v2_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_keras_v2_fn, + combinations.rmsprop_optimizer_v1_fn, + combinations.rmsprop_optimizer_keras_v2_fn + ])) class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, parameterized.TestCase): def setUp(self): + super(TestEstimatorDistributionStrategy, self).setUp() self._base_dir = os.path.join(self.get_temp_dir(), 'keras_mirrored_strategy_test') gfile.MakeDirs(self._base_dir) @@ -285,6 +332,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) def tearDown(self): + super(TestEstimatorDistributionStrategy, self).tearDown() writer_cache.FileWriterCache.clear() if os.path.isdir(self._base_dir): gfile.DeleteRecursively(self._base_dir) @@ -429,7 +477,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_no_steps_no_batch_size(self, distribution): # 
Calculate the per_replica_batch_size scaling factor for strategies # that use per_core_batch_size @@ -460,7 +508,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=None, batch_size=None) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_with_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -506,7 +554,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=1, batch_size=None) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_no_steps_with_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -540,7 +588,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=None, batch_size=3) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_with_steps_with_batch_size(self, distribution): with self.cached_session(): @@ -557,7 +605,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=10, batch_size=13) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): with distribution.scope(): @@ -588,7 +636,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) - 
@combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): with self.cached_session(): with distribution.scope(): @@ -640,7 +688,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_flatten_predict_outputs(self, distribution): with self.cached_session(): with distribution.scope(): @@ -665,6 +713,69 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, self.assertAllEqual([6, 7], outs[0].shape) self.assertAllEqual([6, 7], outs[1].shape) + @combinations.generate(tpu_strategy_combinations()) + def test_predict_with_partial_batch(self, distribution): + with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + + with distribution.scope(): + model_with_ds_strategy = get_model() + model_with_ds_strategy.compile(optimizer, loss) + + cpu_model = get_model() + cpu_model.compile(optimizer, loss) + + inputs = np.zeros((10, 3), dtype=np.float32) + + # As sample size is 10, we batch by 4 so that the last batch is + # a partial batch. Also `predict()` using numpy array as inputs without + # distribution strategy uses entire sample as a single batch. As so, + # we remove parameters `batch_size` and `steps`. + predict_ground_truth = cpu_model.predict(inputs) + cpu_model.set_weights(model_with_ds_strategy.get_weights()) + self.assertAllClose( + model_with_ds_strategy.predict(inputs, batch_size=4, steps=3), + predict_ground_truth, + atol=1e-5, + rtol=1e-5) + # Test that `steps` is inferred correctly when final partial batch exists. 
+ self.assertAllClose( + model_with_ds_strategy.predict(inputs, batch_size=4), + predict_ground_truth, + atol=1e-5, + rtol=1e-5) + + @combinations.generate(tpu_strategy_combinations()) + def test_predict_multi_output_model_with_partial_batch( + self, distribution): + with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + + with distribution.scope(): + model_with_ds_strategy = simple_multi_inputs_multi_outputs_model() + model_with_ds_strategy.compile(optimizer, loss) + + cpu_model = simple_multi_inputs_multi_outputs_model() + cpu_model.compile(optimizer, loss) + + input_data, _ = get_multi_inputs_multi_outputs_data() + input_dict = { + 'input_a': input_data['input_a'], + 'input_b': input_data['input_b'], + } + + # As sample size is 200, we batch by 18 so that the last batch is + # a partial batch. Also `predict()` using numpy array as inputs without + # distribution strategy uses entire sample as a single batch. As so, + # we remove parameters `batch_size` and `steps`. 
+ cpu_model.set_weights(model_with_ds_strategy.get_weights()) + self.assertAllClose( + model_with_ds_strategy.predict(input_dict, batch_size=18, steps=12), + cpu_model.predict(input_dict), + atol=1e-4, rtol=1e-4) + class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): @@ -722,7 +833,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, self.assertEqual(interleaved_output.history['val_loss'], [x[0] for x in user_controlled_output]) - self.assertEqual(interleaved_output.history['val_mean_absolute_error'], + val_mean_absolute_error = interleaved_output.history.get( + 'val_mean_absolute_error') + if not val_mean_absolute_error: + # The name of the metric changed in TF2.0 + val_mean_absolute_error = interleaved_output.history['val_mae'] + self.assertEqual(val_mean_absolute_error, [x[1] for x in user_controlled_output]) self.assertEqual(interleaved_output.history['val_categorical_accuracy'], [x[2] for x in user_controlled_output]) @@ -768,6 +884,95 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) + @combinations.generate(all_strategy_combinations()) + def test_fit_eval_and_predict_methods_on_dataset_without_steps( + self, distribution): + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((1000, 3), dtype=np.float32) + targets = np.zeros((1000, 4), dtype=np.float32) + # steps/steps_per_epoch are calculated when using numpy arrays as + # input data. 
+ fit_with_numpy = model.fit(inputs, targets, epochs=1, + batch_size=10).history + eval_with_numpy = model.evaluate(inputs, targets, batch_size=10) + predict_with_numpy = model.predict(inputs, batch_size=10) + + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.batch(10, drop_remainder=True) + fit_with_ds = model.fit(dataset, epochs=1).history + eval_with_ds = model.evaluate(dataset) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + predict_dataset = predict_dataset.batch(10, drop_remainder=True) + predict_with_ds = model.predict(predict_dataset) + self.assertAllClose( + fit_with_numpy, fit_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + eval_with_numpy, eval_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) + + @combinations.generate(all_strategy_combinations()) + def test_on_dataset_with_unknown_cardinality_without_steps( + self, distribution): + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((1000, 3), dtype=np.float32) + targets = np.zeros((1000, 4), dtype=np.float32) + # steps/steps_per_epoch are calculated when using numpy arrays as + # input data. 
+ fit_with_numpy = model.fit(inputs, targets, epochs=1, + batch_size=10).history + fit_with_numpy_multiple_epochs = model.fit( + inputs, targets, epochs=2, batch_size=10).history + eval_with_numpy = model.evaluate(inputs, targets, batch_size=10) + predict_with_numpy = model.predict(inputs, batch_size=10) + + dataset = convert_numpy_to_dataset_with_unknown_cardinality( + inputs, targets) + predict_dataset = convert_numpy_to_dataset_with_unknown_cardinality( + inputs) + + self.assertEqual(keras.backend.get_value(cardinality.cardinality( + dataset)), cardinality.UNKNOWN) + self.assertEqual(keras.backend.get_value(cardinality.cardinality( + predict_dataset)), cardinality.UNKNOWN) + + eval_with_ds = model.evaluate(dataset) + predict_with_ds = model.predict(predict_dataset) + self.assertAllClose( + eval_with_numpy, eval_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) + + if (distributed_training_utils.is_tpu_strategy(distribution) and + distribution.extended.steps_per_run != 1): + with self.assertRaisesRegexp(ValueError, '`steps_per_epoch` ' + 'should be specified'): + fit_with_ds = model.fit(dataset, epochs=1) + else: + fit_with_ds = model.fit(dataset, + epochs=1).history + fit_with_ds_multiple_epochs = model.fit(dataset, + epochs=2).history + self.assertAllClose( + fit_with_numpy, fit_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + fit_with_numpy_multiple_epochs, + fit_with_ds_multiple_epochs, atol=1e-4, rtol=1e-4) + @combinations.generate(all_strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): @@ -955,230 +1160,163 @@ class TestDistributionStrategyWithDatasets(test.TestCase, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr)) + @combinations.generate(tpu_strategy_combinations()) + def test_predict_with_dataset_with_partial_batch(self, 
distribution): + with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' -class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): + with distribution.scope(): + model_with_ds_strategy = get_model() + model_with_ds_strategy.compile(optimizer, loss) - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_validating_dataset_input_tensors_with_shape_mismatch(self, - distribution): - with self.cached_session(): - a = constant_op.constant([1, 2], shape=(1, 2)) - b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) - device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) - x = values.DistributedValues(device_map, (a, b)) - y = values.DistributedValues(device_map, (a, a)) - # Removed device and input tensor shape details from the error message - # since the order of the device and the corresponding input tensor shape - # is not deterministic over different runs. 
- with self.assertRaisesRegexp(ValueError, - 'Input tensor shapes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - with distribution.scope(): - distributed_training_utils.validate_distributed_dataset_inputs( - distribution, x, y) + cpu_model = get_model() + cpu_model.compile(optimizer, loss) - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_validating_dataset_input_tensors_with_dtype_mismatch(self, - distribution): - with self.cached_session(): - a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) - b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) - device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) - x = values.DistributedValues(device_map, (a, b)) - y = values.DistributedValues(device_map, (a, a)) - # Removed device and input tensor dtype details from the error message - # since the order of the device and the corresponding input tensor dtype - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor dtypes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - with distribution.scope(): - distributed_training_utils.validate_distributed_dataset_inputs( - distribution, x, y) + inputs = np.zeros((10, 3), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs)) - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_unsupported_features(self, distribution): + # As sample size is 10, we batch by 4 so that the last batch is + # a partial batch. 
+ dataset_with_partial_batch = dataset.batch(4) + cpu_model.set_weights(model_with_ds_strategy.get_weights()) + + self.assertAllClose( + model_with_ds_strategy.predict(dataset_with_partial_batch, steps=3), + cpu_model.predict(dataset_with_partial_batch, steps=3), + atol=1e-5, rtol=1e-5) + + @combinations.generate(tpu_strategy_combinations()) + def test_predict_multi_output_model_with_dataset_with_partial_batch( + self, distribution): with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + with distribution.scope(): - model = get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics) + model_with_ds_strategy = simple_multi_inputs_multi_outputs_model() + model_with_ds_strategy.compile(optimizer, loss) - dataset = get_dataset(distribution) + cpu_model = simple_multi_inputs_multi_outputs_model() + cpu_model.compile(optimizer, loss) - # Test with validation split - with self.assertRaisesRegexp( - ValueError, '`validation_split` argument is not ' - 'supported when input `x` is a dataset or a ' - 'dataset iterator.+'): - model.fit(dataset, - epochs=1, steps_per_epoch=2, verbose=0, - validation_split=0.5, validation_steps=2) - - # Test with sample weight. - sample_weight = np.random.random((10,)) - with self.assertRaisesRegexp( - ValueError, '`sample_weight` argument is not supported when input ' - '`x` is a dataset or a dataset iterator.'): - model.fit( - dataset, - epochs=1, - steps_per_epoch=2, - verbose=0, - sample_weight=sample_weight) - - # Test with not specifying the `steps` argument. 
- with self.assertRaisesRegexp( - ValueError, 'the `steps_per_epoch` argument'): - model.fit(dataset, epochs=1, verbose=0) - with self.assertRaisesRegexp(ValueError, 'the `steps` argument'): - model.evaluate(dataset, verbose=0) - - with self.assertRaisesRegexp(ValueError, 'the `steps` argument'): - model.predict(dataset, verbose=0) + input_data, _ = get_multi_inputs_multi_outputs_data() + input_dict = { + 'input_a': input_data['input_a'], + 'input_b': input_data['input_b'], + } - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_calling_with_unsupported_predefined_callbacks(self, distribution): - with self.cached_session(): - with distribution.scope(): - model = get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics) + dataset = dataset_ops.Dataset.from_tensor_slices(input_dict) - dataset = get_dataset(distribution) + # As sample size is 200, we batch by 18 using 12 steps per epoch so + # that the last batch is a partial batch. 
+ dataset_with_partial_batch = dataset.batch(18) + cpu_model.set_weights(model_with_ds_strategy.get_weights()) - def schedule(_): - return 0.001 - with self.assertRaisesRegexp(ValueError, - 'You must specify a Keras Optimizer V2 when ' - 'using'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + self.assertAllClose( + model_with_ds_strategy.predict(dataset_with_partial_batch, steps=12), + cpu_model.predict(dataset_with_partial_batch, steps=12), + atol=1e-4, rtol=1e-4) - with self.assertRaisesRegexp(ValueError, - 'You must specify a Keras Optimizer V2 when ' - 'using'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.ReduceLROnPlateau()]) +class TestRegularizerLoss(test.TestCase, parameterized.TestCase): + class IdentityRegularizer(keras.regularizers.Regularizer): -class TestDistributionStrategyWithLossMasking(test.TestCase, - parameterized.TestCase): + def __call__(self, x): + return array_ops.identity(x) - # TODO(priyag): Enable all strategies for this test. Currently it does not - # work for TPU due to some invalid datatype. 
- @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_masking(self, distribution): - with self.cached_session(): - np.random.seed(1337) - x = np.array([[[1], [1]], [[0], [0]]]) - with distribution.scope(): - model = keras.models.Sequential() - model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(1, kernel_initializer='one'))) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01)) - y = np.array([[[1], [1]], [[1], [1]]]) - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) - hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) - self.assertEqual(hist.history['loss'][0], 0) + class AddLayer(keras.layers.Layer): + def build(self, _): + self.v = self.add_weight( + 'v', (), initializer='ones', + regularizer=TestRegularizerLoss.IdentityRegularizer()) -class TestDistributionStrategyWithNormalizationLayer( - test.TestCase, parameterized.TestCase): + def call(self, inputs): + return inputs + self.v - @combinations.generate(combinations.times( - all_strategy_combinations(), - combinations.combine(fused=[True, False]))) - def test_batchnorm_correctness(self, distribution, fused): - with self.cached_session(): - with distribution.scope(): - model = keras.models.Sequential() - norm = keras.layers.BatchNormalization( - input_shape=(10,), momentum=0.8, fused=fused) - model.add(norm) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01)) - - # centered on 5.0, variance 10.0 - x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) - x = x.astype('float32') - dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) - dataset = dataset.repeat(100) - dataset = batch_wrapper(dataset, 32, distribution) + 
@staticmethod + def loss_fn(_, y_pred): + return math_ops.reduce_mean(y_pred) + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_regularizer_loss(self, distribution): + batch_size = 2 + if not distributed_training_utils.global_batch_size_supported(distribution): + batch_size //= distribution.num_replicas_in_sync + + # Given an input x, which is always 1, and variable v, this model computes + # Loss=x+v+regularizer_loss, where regularizer_loss=v and the variable is + # initialized to 1. Therefore, this model computes Loss=1+2v, and so the + # gradient dLoss/dv = 2. This gradient of 2 is averaged over all examples + # in a batch and then multiplied by the learning rate of 1. As a result, + # the model update for one batch should subtract 2 from v, resulting in v + # being -1. If the regularizer loss is not scaled correctly by number of + # replicas, the variable value will be incorrect when number of replicas + # >1. For e.g. it will be -2 if num replicas = 2. + with distribution.scope(): + x = keras.layers.Input(shape=(), batch_size=batch_size) + y = TestRegularizerLoss.AddLayer()(x) + model = keras.models.Model(inputs=x, outputs=y) + opt = gradient_descent_keras.SGD(1.) 
+ model.compile(opt, loss=TestRegularizerLoss.loss_fn) + model.fit( + x=np.array([[1.], [1.]], dtype=np.float32), + y=np.array([[1.], [1.]], dtype=np.float32), + batch_size=batch_size) + v = model.get_weights()[0] + self.assertEqual(-1.0, v) + + +class TestDistributionStrategyWithKerasModels(test.TestCase, + parameterized.TestCase): - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) - predict_dataset = predict_dataset.repeat(100) - predict_dataset = batch_wrapper(predict_dataset, 32, distribution) + @combinations.generate(all_strategy_combinations()) + def test_distribution_strategy_on_sequential_model(self, distribution): + with distribution.scope(): + model = simple_sequential_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) - out = model.predict(predict_dataset, steps=2) - out -= keras.backend.eval(norm.beta) - out /= keras.backend.eval(norm.gamma) - np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) - np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + inputs = np.zeros((20, 10), np.float32) + targets = np.zeros((20, 2), np.float32) + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) -class TestDistributionStrategyValidation(test.TestCase, - parameterized.TestCase): + @combinations.generate(all_strategy_combinations()) + def test_distribution_strategy_on_functional_model(self, distribution): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - @combinations.generate(all_strategy_combinations_minus_default()) - def test_layer_outside_scope(self, distribution): - with self.cached_session(): - with self.assertRaisesRegexp( - ValueError, 'was not created in the distribution strategy'): - x = keras.layers.Input(shape=(3,), 
name='input') - y = keras.layers.Dense(4, name='dense')(x) - with distribution.scope(): - model = keras.Model(x, y) - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics) + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) @combinations.generate(all_strategy_combinations_minus_default()) - def test_model_outside_scope(self, distribution): - with self.cached_session(): - with self.assertRaisesRegexp( - ValueError, 'was not created in the distribution strategy'): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) - with distribution.scope(): - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics) + def test_distribution_strategy_one_dimensional(self, distribution): + with distribution.scope(): + inp = keras.layers.Input(shape=(10,)) + out = keras.layers.Dense(3, activation='softmax')(inp) + model = keras.Model(inputs=[inp], outputs=[out]) + model.compile( + optimizer='rmsprop', + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy'], + ) + + x = np.random.random((64, 10)).astype('float32') + y = np.random.randint(3, size=64) + + model.fit(x, y, epochs=1, steps_per_epoch=2) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/keras_utils_test.py b/tensorflow/contrib/distribute/python/keras_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..da17722f4be46479d2614d81fb6b9af2f7e77e77 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_utils_test.py @@ -0,0 +1,572 @@ +# 
Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.keras models with callbacks, checkpointing with dist strategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import tempfile +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_test as keras_test_lib +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import values +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import rmsprop as rms_prop_keras +from tensorflow.python.ops import math_ops +from tensorflow.python.training import gradient_descent + + +class Counter(keras.callbacks.Callback): + """Counts the number of times each callback method was run. + + Attributes: + method_counts: dict. Contains the counts of time each callback method was + run. 
+ """ + + def __init__(self): + self.method_counts = collections.defaultdict(int) + methods_to_count = [ + 'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end', + 'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin', + 'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end', + 'on_test_begin', 'on_test_end', 'on_train_batch_begin', + 'on_train_batch_end', 'on_train_begin', 'on_train_end' + ] + for method_name in methods_to_count: + setattr(self, method_name, + self.wrap_with_counts(method_name, getattr(self, method_name))) + + def wrap_with_counts(self, method_name, method): + + def _call_and_count(*args, **kwargs): + self.method_counts[method_name] += 1 + return method(*args, **kwargs) + + return _call_and_count + + +class TestDistributionStrategyWithCallbacks(test.TestCase, + parameterized.TestCase): + + @combinations.generate(keras_test_lib.all_strategy_combinations()) + def test_callbacks_in_fit(self, distribution): + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = keras_test_lib.get_dataset(distribution) + counter = Counter() + + epochs = 2 + steps_per_epoch = 5 + validation_steps = 3 + + model.fit( + dataset, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + verbose=0, + validation_data=dataset, + validation_steps=validation_steps, + callbacks=[counter]) + + if isinstance(distribution, tpu_strategy.TPUStrategy): + # TPU Strategy can have multi step training, from extended.steps_per_run + # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch + steps_per_run = distribution.extended.steps_per_run + num_batch_call_per_epoch = steps_per_epoch // steps_per_run + if steps_per_epoch % steps_per_run: + num_batch_call_per_epoch += 1 + else: + num_batch_call_per_epoch = steps_per_epoch + + self.assertDictEqual( + counter.method_counts, { + 'on_batch_begin': epochs * num_batch_call_per_epoch, + 'on_batch_end': epochs * 
num_batch_call_per_epoch, + 'on_epoch_begin': epochs, + 'on_epoch_end': epochs, + 'on_test_batch_begin': epochs * validation_steps, + 'on_test_batch_end': epochs * validation_steps, + 'on_test_begin': epochs, + 'on_test_end': epochs, + 'on_train_batch_begin': epochs * num_batch_call_per_epoch, + 'on_train_batch_end': epochs * num_batch_call_per_epoch, + 'on_train_begin': 1, + 'on_train_end': 1 + }) + + @combinations.generate(keras_test_lib.all_strategy_combinations()) + def test_callbacks_in_eval(self, distribution): + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = keras_test_lib.get_dataset(distribution) + counter = Counter() + + model.evaluate(dataset, steps=5, callbacks=[counter]) + + self.assertDictEqual( + counter.method_counts, { + 'on_test_batch_begin': 5, + 'on_test_batch_end': 5, + 'on_test_begin': 1, + 'on_test_end': 1 + }) + + @combinations.generate(keras_test_lib.all_strategy_combinations()) + def test_callbacks_in_predict(self, distribution): + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = keras_test_lib.get_dataset(distribution) + counter = Counter() + + model.predict( + keras_test_lib.get_predict_dataset(dataset), + steps=5, + callbacks=[counter]) + + self.assertDictEqual( + counter.method_counts, { + 'on_predict_batch_begin': 5, + 'on_predict_batch_end': 5, + 'on_predict_begin': 1, + 'on_predict_end': 1 + }) + + +class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_shape_mismatch( + self, distribution): + with self.cached_session(): + a = constant_op.constant([1, 2], shape=(1, 2)) + b = 
constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp( + ValueError, 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): + distributed_training_utils.validate_distributed_dataset_inputs( + distribution, x, y) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_dtype_mismatch( + self, distribution): + with self.cached_session(): + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. 
+ with self.assertRaisesRegexp( + ValueError, 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): + distributed_training_utils.validate_distributed_dataset_inputs( + distribution, x, y) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_unsupported_features(self, distribution): + with self.cached_session(): + with distribution.scope(): + model = keras_test_lib.get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + dataset = keras_test_lib.get_dataset(distribution) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a dataset or a ' + 'dataset iterator.+'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + validation_split=0.5, + validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + ValueError, '`sample_weight` argument is not supported when input ' + '`x` is a dataset or a dataset iterator.'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test with not specifying the `steps` argument for dataset with infinite + # cardinality. 
+ dataset = dataset.repeat() + with self.assertRaisesRegexp( + ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp( + ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): + model.evaluate(dataset, verbose=0) + + with self.assertRaisesRegexp( + ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): + model.predict(dataset, verbose=0) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_calling_with_unsupported_predefined_callbacks(self, distribution): + with self.cached_session(): + with distribution.scope(): + model = keras_test_lib.get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + dataset = keras_test_lib.get_dataset(distribution) + + def schedule(_): + return 0.001 + + with self.assertRaisesRegexp( + ValueError, 'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + + with self.assertRaisesRegexp( + ValueError, 'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + callbacks=[keras.callbacks.ReduceLROnPlateau()]) + + @combinations.generate( + combinations.combine( + distribution=[combinations.one_device_strategy], mode=['graph'])) + def test_distribution_strategy_with_add_metric_add_loss(self, distribution): + with distribution.scope(): + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(1, kernel_initializer='ones')(x) + 
+ err_msg = ( + 'We currently do not support compiling the model with distribution ' + r'strategy if `model.add_loss\(tensor\)` or ' + r'`model.add_metric\(tensor\)` has been called.') + + # Test with add_metric. + model = keras.models.Model(x, y) + model.add_metric( + math_ops.reduce_sum(y), name='metric_1', aggregation='mean') + with self.assertRaisesRegex(ValueError, err_msg): + model.compile('sgd',) + + # Test with add_loss. + model = keras.models.Model(x, y) + model.add_loss(math_ops.reduce_mean(y)) + with self.assertRaisesRegex(ValueError, err_msg): + model.compile('sgd',) + + @combinations.generate( + combinations.combine( + distribution=[combinations.one_device_strategy], mode=['eager'])) + def test_distribution_strategy_with_run_eagerly(self, distribution): + with distribution.scope(): + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(1, kernel_initializer='ones')(x) + model = keras.models.Model(x, y) + + err_msg = ('We currently do not support enabling `run_eagerly` with ' + 'distribution strategy.') + with self.assertRaisesRegex(ValueError, err_msg): + model.compile('sgd', run_eagerly=True) + + # TODO(b/124377929): Remove error assertions once subclassed models + # are supported in DistributedStrategy. 
+ @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_distribution_strategy_on_subclassed_model(self, distribution): + with distribution.scope(): + + class _SimpleMLP(keras.Model): + + def __init__(self, num_labels): + super(_SimpleMLP, self).__init__() + self.dense = keras.layers.Dense(num_labels) + + def call(self, inputs): + return self.dense(inputs) + + model = _SimpleMLP(3) + + with self.assertRaisesRegexp( + ValueError, + 'We currently do not support distribution strategy with a ' + '`Sequential` model that is created without ' + '`input_shape`/`input_dim` set in its first layer or ' + 'a subclassed model.'): + model.compile('sgd') + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_distribution_strategy_on_deferred_sequential_model( + self, distribution): + with distribution.scope(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(16, activation='relu')) + model.add(keras.layers.Dense(3, activation='softmax')) + + with self.assertRaisesRegexp( + ValueError, + 'We currently do not support distribution strategy with a ' + '`Sequential` model that is created without ' + '`input_shape`/`input_dim` set in its first layer or ' + 'a subclassed model.'): + model.compile('sgd') + + +class TestDistributionStrategyWithLossMasking(test.TestCase, + parameterized.TestCase): + + # TODO(priyag): Enable all strategies for this test. Currently it does not + # work for TPU due to some invalid datatype. 
+ @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_masking(self, distribution): + with self.cached_session(): + np.random.seed(1337) + x = np.array([[[1], [1]], [[0], [0]]]) + with distribution.scope(): + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + model.compile( + loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) + y = np.array([[[1], [1]], [[1], [1]]]) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) + self.assertEqual(hist.history['loss'][0], 0) + + +class TestDistributionStrategyWithNormalizationLayer(test.TestCase, + parameterized.TestCase): + + @combinations.generate( + combinations.times(keras_test_lib.all_strategy_combinations(), + combinations.combine(fused=[True, False]))) + def test_batchnorm_correctness(self, distribution, fused): + with self.cached_session(): + with distribution.scope(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization( + input_shape=( + 10, + 20, + 30, + ), momentum=0.8, fused=fused) + model.add(norm) + model.compile( + loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) + + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 20, 30)) + x = x.astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) + dataset = dataset.repeat(100) + dataset = keras_test_lib.batch_wrapper(dataset, 32, distribution) + + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) + predict_dataset = predict_dataset.repeat(100) + predict_dataset = 
keras_test_lib.batch_wrapper(predict_dataset, 32, + distribution) + + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) + out = model.predict(predict_dataset, steps=2) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + + +class TestDistributionStrategySaveLoadWeights(test.TestCase, + parameterized.TestCase): + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_save_load_h5(self, distribution): + with self.cached_session(): + dataset = keras_test_lib.get_dataset(distribution) + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(rms_prop_keras.RMSprop(learning_rate=0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file = tempfile.mktemp('.h5') + model.save_weights(weights_file) + + model_2 = keras_test_lib.get_model() + model_2.compile(rms_prop_keras.RMSprop(learning_rate=0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict( + keras_test_lib.get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_save_load_trackable(self, distribution): + # TODO(b/123533246): Enable the test for TPU once bug is fixed + if (isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + self.skipTest('MultiStep TPU Strategy deadlocks with optimizer restore.') + with self.cached_session(): + dataset = keras_test_lib.get_dataset(distribution) + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(rms_prop_keras.RMSprop(learning_rate=0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file = tempfile.mktemp() + model.save_weights(weights_file) + + model_2 = keras_test_lib.get_model() + 
model_2.compile(rms_prop_keras.RMSprop(learning_rate=0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict( + keras_test_lib.get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) + + +class TestDistributionStrategyValidation(test.TestCase, parameterized.TestCase): + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_layer_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + with distribution.scope(): + model = keras.Model(x, y) + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_model_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + with distribution.scope(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index f06c9b75644b2890b7657f75e74e4e20a6f15705..16541c7a1e342f5636e238a301d0946d3e0c4bc4 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -220,7 +220,7 @@ class MinimizeLossStepTest(test.TestCase, 
parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused - fetches = distribution.unwrap( + fetches = distribution.experimental_local_results( distribution.extended.call_for_each_replica( model_fn, args=(inputs,))) if update_ops_in_cross_replica_mode: @@ -418,13 +418,15 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # tensors. But for non reduced losses, we need to have initial # values that are of the same structure as non reduced losses. In # MirroredStrategy, this will be a list of losses, in TPUStrategy - # it will be single tensor. Using `broadcast` followed by `unwrap` - # gives us the desired initial value structure. + # it will be single tensor. Using `call_for_each_replica` followed + # by `experimental_local_results` gives us the desired initial + # value structure. + not_reduced = distribution.experimental_local_results( + distribution.extended.call_for_each_replica(initial_loss)) initial_loop_values = { "replica_loss_reduced": initial_loss(), "cross_replica_loss_reduced": initial_loss(), - "cross_replica_loss_not_reduced": - distribution.unwrap(distribution.broadcast(initial_loss())) + "cross_replica_loss_not_reduced": not_reduced, } ctx = distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=2, @@ -468,11 +470,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def _verify_loss_output(self, initial_loss, loss_output, reduced, distribution): if not reduced: - self.assertLen(distribution.unwrap(loss_output), + self.assertLen(distribution.experimental_local_results(loss_output), distribution.num_replicas_in_sync) loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output) else: - unwrapped_output = distribution.unwrap(loss_output) + unwrapped_output = distribution.experimental_local_results(loss_output) self.assertLen(unwrapped_output, 1) loss_tensor = unwrapped_output[0] self.assertEqual(initial_loss.dtype, loss_tensor.dtype) diff --git 
a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 2e23a51ee56ed1388a4387a51342aabce6d24bed..5391e083fc9b3ed99cc64bbed11bdeb8dea07f93 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools - from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import mirrored_strategy @@ -192,16 +190,6 @@ class MirroredExtended(CoreMirroredExtended): """ return input_lib.DatasetIterator(dataset, self._input_workers) - def _distribute_dataset(self, dataset_fn): - if self._local_mode: - return input_lib.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._input_workers, 0) - else: - return input_lib.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), - self._input_workers, - auto_shard=self._auto_shard_dataset) - # TODO(priyag): Delete this once all strategies use global batch size. 
@property def _global_batch_size(self): diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index d6337d106fced921b8bda0a2faac99c2a77fab8e..8a1772b7f225bce18a96876f2585eb120c71a979 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import json import sys from absl.testing import parameterized @@ -29,17 +30,19 @@ from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.framework import func_graph from tensorflow.python.framework import dtypes +from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.layers import core as keras_core @@ -103,7 +106,7 @@ class MirroredTwoDeviceDistributionTest( expected = sum(range(distribution.num_replicas_in_sync)) self.assertEqual(expected, self.evaluate(reduced)) - def 
testMakeInputFnIterator(self, distribution): + def testMakeInputFnIteratorWithDataset(self, distribution): dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i, i+1] for i in range(0, 10, 2)] @@ -116,12 +119,33 @@ class MirroredTwoDeviceDistributionTest( self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, expected_values) + # TODO(b/124344198): Re-enable after fixing this flaky test. + def DISABLED_testMakeInputFnIteratorWithCallable(self, distribution): + def fn(): + dataset = dataset_ops.Dataset.range(2).interleave( + (lambda _: dataset_ops.Dataset.range(10)), cycle_length=2) + it = dataset.make_one_shot_iterator() + return it.get_next + expected_values = [[i, i] for i in range(0, 10)] + + input_fn = self._input_fn_to_test_input_context( + fn, + expected_num_replicas_in_sync=2, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, + expected_values, test_reinitialize=False) + def testNumpyIterator(self, distribution): self._test_numpy_iterator(distribution) def testGlobalStepUpdate(self, distribution): self._test_global_step_update(distribution) + def testRun(self, distribution): + self._test_run(distribution) + def testAllReduceSum(self, distribution): self._test_all_reduce_sum(distribution) @@ -140,6 +164,9 @@ class MirroredTwoDeviceDistributionTest( def testAllReduceMeanGradientTape(self, distribution): self._test_all_reduce_mean_gradient_tape(distribution) + def testSummaryForReplicaZeroOnly(self, distribution): + self._test_summary_for_replica_zero_only(distribution) + def one_device_combinations(): return combinations.combine( @@ -169,6 +196,9 @@ class MirroredOneDeviceDistributionTest( def testCallAndMergeExceptions(self, distribution): self._test_call_and_merge_exceptions(distribution) + def testRun(self, distribution): + self._test_run(distribution) + def 
testAllReduceSum(self, distribution): self._test_all_reduce_sum(distribution) @@ -219,7 +249,7 @@ class MirroredStrategyVariableCreatorStackTest( distribution.scope(), \ variable_scope.variable_creator_scope(main_thread_creator): result = distribution.extended.call_for_each_replica(model_fn) - result = distribution.unwrap(result) + result = distribution.experimental_local_results(result) expected = ("main_thread:thread_0", "main_thread:thread_1") self.assertEqual(expected, result) @@ -239,7 +269,7 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase): with distribution.scope(): in_scope = ops.executing_eagerly_outside_functions() in_model_fn = distribution.extended.call_for_each_replica(model_fn) - unwrapped = distribution.unwrap(in_model_fn) + unwrapped = distribution.experimental_local_results(in_model_fn) self.assertEqual(in_scope, unwrapped[0]) self.assertEqual(in_scope, originally) @@ -247,10 +277,32 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase): with func_graph.FuncGraph("fg").as_default(), distribution.scope(): in_scope = ops.executing_eagerly_outside_functions() in_model_fn = distribution.extended.call_for_each_replica(model_fn) - unwrapped = distribution.unwrap(in_model_fn) + unwrapped = distribution.experimental_local_results(in_model_fn) self.assertEqual(in_scope, unwrapped[0]) self.assertEqual(in_scope, originally) + def testFunctionInCallForEachReplicaNoMergeCall(self, distribution): + @def_function.function + def model_fn(): + return 0. + + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual((0., 0.), self.evaluate(result.values)) + + def testFunctionInCallForEachReplicaWithMergeCall(self, distribution): + def merge_fn(_): + pass + + @def_function.function + def model_fn(): + ds_context.get_replica_context().merge_call(merge_fn) + return 0. 
+ + with distribution.scope(): + with self.assertRaisesRegexp( + RuntimeError, "`merge_call` called while defining a new graph."): + distribution.extended.call_for_each_replica(model_fn) @combinations.generate(combinations.combine( distribution=[ @@ -412,7 +464,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEqual("var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) self.assertEqual("common/var1:0", v1.name) - self.assertIsInstance(v2, values.ReplicaLocalVariable) + self.assertIsInstance(v2, values.SyncOnReadVariable) self.assertEqual("common/var2:0", v2.name) self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation) self.assertIsInstance(v3, values.MirroredVariable) @@ -449,7 +501,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEqual("main/var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) self.assertEqual("main/common/var1:0", v1.name) - self.assertIsInstance(v2, values.ReplicaLocalVariable) + self.assertIsInstance(v2, values.SyncOnReadVariable) self.assertEqual("main/common/var2:0", v2.name) self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation) @@ -589,6 +641,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): aggregation="invalid") def testNonMatchingVariableCreation(self, distribution): + self.skipTest("b/123075960") def model_fn(name): v = variable_scope.variable(1.0, name=name) ds_context.get_replica_context().merge_call(lambda _: _) @@ -600,7 +653,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with self.assertRaises(RuntimeError): _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) - def testReplicaLocalVariable(self, distribution): + def testSyncOnReadVariable(self, distribution): all_v_sum = {} all_v_mean = {} components_sum = {} @@ -616,8 +669,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): 4.0, synchronization=variable_scope.VariableSynchronization.ON_READ, 
aggregation=variable_scope.VariableAggregation.MEAN) - self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) - self.assertTrue(isinstance(v_mean, values.ReplicaLocalVariable)) + self.assertIsInstance(v_sum, values.SyncOnReadVariable) + self.assertIsInstance(v_mean, values.SyncOnReadVariable) updates = [v_sum.assign_add(2.0 + replica_id), v_mean.assign(6.0 * replica_id)] all_v_sum[replica_id] = v_sum @@ -631,7 +684,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return updates, v_sum, v_mean, c_sum, c_mean with distribution.scope(): - # Create "sum" and "mean" versions of ReplicaLocalVariables. + # Create "sum" and "mean" versions of SyncOnReadVariables. ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( distribution.extended.call_for_each_replica(model_fn)) # Should see the same wrapping instance in all replicas. @@ -648,7 +701,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Apply updates self.evaluate(variables.global_variables_initializer()) - self.evaluate([y for x in ret_ops for y in distribution.unwrap(x)]) + self.evaluate([y for x in ret_ops # pylint: disable=g-complex-comprehension + for y in distribution.experimental_local_results(x)]) expected_sum = 0.0 expected_mean = 0.0 for i, d in enumerate(distribution.extended.worker_devices): @@ -694,16 +748,16 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEqual(2, len(result)) for v in result: self.assertIsInstance(v, values.DistributedValues) - _, v1 = distribution.unwrap(v) + _, v1 = distribution.experimental_local_results(v) self.assertStartsWith(v1._op.name, "replica_1/") - def testReplicaLocalVariableUpdate(self, distribution): + def testSyncOnReadVariableUpdate(self, distribution): def model_fn(): v_sum = variable_scope.variable( 1.0, synchronization=variable_scope.VariableSynchronization.ON_READ, aggregation=variable_scope.VariableAggregation.SUM) - self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) + 
self.assertIsInstance(v_sum, values.SyncOnReadVariable) return v_sum def update(var, value): @@ -714,7 +768,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Initialize variables. self.evaluate(variables.global_variables_initializer()) - # Assert that the aggregated value of the replica local vars is the sum + # Assert that the aggregated value of the sync on read var is the sum # of the individual values before running the update ops. self.assertEqual(1.0, self.evaluate(ret_v_sum.get( distribution.extended.worker_devices[0]).read_value())) @@ -724,7 +778,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): update_ops = distribution.extended.update( ret_v_sum, update, args=(5.0,), group=False) self.evaluate(update_ops) - # Assert that the aggregated value of the replica local vars is the sum + # Assert that the aggregated value of the sync on read vars is the sum # of the individual values after running the update ops. self.assertEqual(5.0, self.evaluate(ret_v_sum.get( distribution.extended.worker_devices[0]).read_value())) @@ -733,11 +787,11 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def testVarDistributeStrategy(self, distribution): with distribution.scope(): mirrored = variable_scope.variable(1.0) - replica_local = variable_scope.variable( + sync_on_read = variable_scope.variable( 1.0, synchronization=variable_scope.VariableSynchronization.ON_READ) self.assertIs(distribution, mirrored.distribute_strategy) - self.assertIs(distribution, replica_local.distribute_strategy) + self.assertIs(distribution, sync_on_read.distribute_strategy) @combinations.generate(combinations.combine( @@ -763,7 +817,7 @@ class MirroredStrategyNameScopeTest(test.TestCase): self.assertEqual(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) - v0, v1 = distribution.unwrap(v) + v0, v1 = distribution.experimental_local_results(v) self.assertEqual("main/foo/" + name + ":0", v0.name) 
self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name) @@ -780,7 +834,7 @@ class MirroredStrategyNameScopeTest(test.TestCase): self.assertEqual(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) - v0, v1 = distribution.unwrap(v) + v0, v1 = distribution.experimental_local_results(v) self.assertEqual("foo/" + name + ":0", v0.name) self.assertEqual("replica_1/foo/" + name + ":0", v1.name) @@ -807,9 +861,9 @@ class MirroredStrategyNameScopeTest(test.TestCase): result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = distribution.unwrap(a) - b0, b1 = distribution.unwrap(result_b) - c0, c1 = distribution.unwrap(result_c) + a0, a1 = distribution.experimental_local_results(a) + b0, b1 = distribution.experimental_local_results(result_b) + c0, c1 = distribution.experimental_local_results(result_c) self.assertEqual("main/a:0", a0.name) self.assertEqual("main/a/replica_1:0", a1.name) self.assertEqual("main/b:0", b0.name) @@ -836,9 +890,9 @@ class MirroredStrategyNameScopeTest(test.TestCase): result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = distribution.unwrap(a) - b0, b1 = distribution.unwrap(result_b) - c0, c1 = distribution.unwrap(result_c) + a0, a1 = distribution.experimental_local_results(a) + b0, b1 = distribution.experimental_local_results(result_b) + c0, c1 = distribution.experimental_local_results(result_c) self.assertEqual("a:0", a0.name) self.assertEqual("a/replica_1:0", a1.name) self.assertEqual("b:0", b0.name) @@ -908,7 +962,7 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "You must specify an aggregation method to update a " "MirroredVariable in Replica Context."): - self.evaluate(distribution.unwrap( + self.evaluate(distribution.experimental_local_results( 
distribution.extended.call_for_each_replica(model_fn))) def testAssignMirroredVarReplicaContextWithSum(self, distribution): @@ -930,7 +984,7 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "A non-DistributedValues value 5.0 cannot be reduced " "with the given reduce op ReduceOp.SUM."): - self.evaluate(distribution.unwrap( + self.evaluate(distribution.experimental_local_results( distribution.extended.call_for_each_replica(model_fn))) def testAssignMirroredVarCrossDeviceContext(self, distribution): @@ -962,7 +1016,7 @@ class MirroredVariableUpdateTest(test.TestCase): mirrored_var.dtype) return mirrored_var.assign(value) - self.evaluate(distribution.unwrap( + self.evaluate(distribution.experimental_local_results( distribution.extended.call_for_each_replica(model_fn))) self.assertEqual(0.5, self.evaluate(mirrored_var)) @@ -980,7 +1034,7 @@ class MirroredVariableUpdateTest(test.TestCase): def model_fn(): return mirrored_var.assign(5.0) - self.evaluate(distribution.unwrap( + self.evaluate(distribution.experimental_local_results( distribution.extended.call_for_each_replica(model_fn))) self.assertEqual(5.0, self.evaluate(mirrored_var)) @@ -1023,7 +1077,7 @@ class MirroredVariableUpdateTest(test.TestCase): mirrored_var.dtype) return mirrored_var.assign_add(value) - self.evaluate(distribution.unwrap( + self.evaluate(distribution.experimental_local_results( distribution.extended.call_for_each_replica(model_fn))) self.assertEqual(1.5, self.evaluate(mirrored_var)) @@ -1041,7 +1095,7 @@ class MirroredVariableUpdateTest(test.TestCase): def model_fn(): return mirrored_var.assign_add(5.0) - self.evaluate(distribution.unwrap( + self.evaluate(distribution.experimental_local_results( distribution.extended.call_for_each_replica(model_fn))) self.assertEqual(6.0, self.evaluate(mirrored_var)) @@ -1076,7 +1130,7 @@ class MirroredVariableUpdateTest(test.TestCase): mirrored_var.dtype) return mirrored_var.assign_sub(value) - 
self.evaluate(distribution.unwrap( + self.evaluate(distribution.experimental_local_results( distribution.extended.call_for_each_replica(model_fn))) self.assertEqual(4.5, self.evaluate(mirrored_var)) @@ -1094,7 +1148,7 @@ class MirroredVariableUpdateTest(test.TestCase): def model_fn(): return mirrored_var.assign_sub(1.0) - self.evaluate(distribution.unwrap( + self.evaluate(distribution.experimental_local_results( distribution.extended.call_for_each_replica(model_fn))) self.assertEqual(4.0, self.evaluate(mirrored_var)) @@ -1104,7 +1158,7 @@ class MirroredVariableUpdateTest(test.TestCase): combinations.mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=["graph", "eager"])) -class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): +class MirroredAndSyncOnReadVariableInitializerTest(test.TestCase): def testAssignMirroredVarInitializer(self, distribution): # This test is not eager compatible since in eager variables are initialized @@ -1130,17 +1184,16 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): 1.0, synchronization=variable_scope.VariableSynchronization.ON_READ, aggregation=variable_scope.VariableAggregation.SUM) - self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) + self.assertIsInstance(v_sum, values.SyncOnReadVariable) return v_sum with distribution.scope(): - replica_local_var = distribution.extended.call_for_each_replica( + sync_on_read_var = distribution.extended.call_for_each_replica( model_fn) - self.assertTrue(isinstance(replica_local_var, - values.ReplicaLocalVariable)) - self.assertFalse(self.evaluate(replica_local_var.is_initialized())) - self.evaluate(replica_local_var.initializer) - self.assertTrue(self.evaluate(replica_local_var.is_initialized())) + self.assertIsInstance(sync_on_read_var, values.SyncOnReadVariable) + self.assertFalse(self.evaluate(sync_on_read_var.is_initialized())) + self.evaluate(sync_on_read_var.initializer) + 
self.assertTrue(self.evaluate(sync_on_read_var.is_initialized())) @combinations.generate(combinations.combine( @@ -1148,7 +1201,7 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): combinations.mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=["graph", "eager"])) -class ReplicaLocalVariableAssignTest(test.TestCase): +class SyncOnReadVariableAssignTest(test.TestCase): def testAssignReplicaLocalVarSumAggregation(self, distribution): def model_fn(): @@ -1159,24 +1212,23 @@ class ReplicaLocalVariableAssignTest(test.TestCase): return v_sum with distribution.scope(): - replica_local_var = distribution.extended.call_for_each_replica(model_fn) - self.assertTrue(isinstance(replica_local_var, - values.ReplicaLocalVariable)) + sync_on_read_var = distribution.extended.call_for_each_replica(model_fn) + self.assertIsInstance(sync_on_read_var, values.SyncOnReadVariable) self.evaluate(variables.global_variables_initializer()) # Each replica has a value of 1.0 assigned to it in replica context. # When we read the value using `read_var` we should see the SUM of each of # values on each of the replicas. self.assertEqual(2.0, self.evaluate( - distribution.extended.read_var(replica_local_var))) + distribution.extended.read_var(sync_on_read_var))) # Assigning 6.0 in cross replica context will assign a value of # 6.0/num_replicas to each replica. - tlv_ops = replica_local_var.assign(6.0) + tlv_ops = sync_on_read_var.assign(6.0) self.evaluate(tlv_ops) - # On reading the replica local var we should get the assigned value back. + # On reading the sync on read var we should get the assigned value back. # The value on all the replicas are added before being returned by # `read_var`. 
self.assertEqual(6.0, self.evaluate( - distribution.extended.read_var(replica_local_var))) + distribution.extended.read_var(sync_on_read_var))) def testAssignReplicaLocalVarMeanAggregation(self, distribution): def model_fn(): @@ -1187,21 +1239,20 @@ class ReplicaLocalVariableAssignTest(test.TestCase): return v_sum with distribution.scope(): - replica_local_var = distribution.extended.call_for_each_replica(model_fn) - self.assertTrue(isinstance(replica_local_var, - values.ReplicaLocalVariable)) + sync_on_read_var = distribution.extended.call_for_each_replica(model_fn) + self.assertIsInstance(sync_on_read_var, values.SyncOnReadVariable) self.evaluate(variables.global_variables_initializer()) # Each replica has a value of 1.0 assigned to it in replica context. # When we read the value using `read_var` we should see the MEAN of values # on all replicas which is the value assigned in replica context. self.assertEqual(1.0, self.evaluate( - distribution.extended.read_var(replica_local_var))) - tlv_ops = replica_local_var.assign(6.0) + distribution.extended.read_var(sync_on_read_var))) + tlv_ops = sync_on_read_var.assign(6.0) self.evaluate(tlv_ops) - # On reading the replica local var we should get the MEAN of all values + # On reading the sync on read var we should get the MEAN of all values # which is equal to the value assigned. 
self.assertEqual(6.0, self.evaluate( - distribution.extended.read_var(replica_local_var))) + distribution.extended.read_var(sync_on_read_var))) class MockModel(object): @@ -1416,7 +1467,7 @@ class MultiWorkerMirroredStrategyTest( self.assertEqual(a.device, "/job:worker/task:0") self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") - def testMakeInputFnIterator(self, distribution): + def testMakeInputFnIteratorWithDataset(self, distribution): self._configure_distribution_strategy(distribution) dataset_fn = lambda: dataset_ops.Dataset.range(100) num_gpus = context.num_gpus() @@ -1437,6 +1488,32 @@ class MultiWorkerMirroredStrategyTest( self._test_input_fn_iterator( iterator, distribution.extended.worker_devices, expected_values, sess) + def DISABLED_testMakeInputFnIteratorWithCallable(self, distribution): + self._configure_distribution_strategy(distribution) + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next + num_gpus = context.num_gpus() + num_workers = 2 + + expected_values = [] + for i in range(0, 100, num_gpus): + expected_values.append([i+j for j in range(num_gpus)] * num_workers) + + with context.graph_mode(), self.cached_session() as sess: + # `expected_input_pipeline_id` is None because the input_fn will be called + # multiple times, each with a different input_pipeline_id. 
+ input_fn = self._input_fn_to_test_input_context( + fn, + expected_num_replicas_in_sync=num_workers*num_gpus, + expected_num_input_pipelines=num_workers, + expected_input_pipeline_id=None) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, distribution.extended.worker_devices, expected_values, sess, + test_reinitialize=False) + def testUpdateConfigProto(self, distribution): distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]}) @@ -1470,6 +1547,31 @@ class MultiWorkerMirroredStrategyTestWithChief( strategy.configure(cluster_spec=self._cluster_spec) self._test_minimize_loss_graph(strategy, learning_rate=0.05) + def testMinimizeLossGraphCoreMirroredStrategyWithOneNode(self): + cluster_spec = {} + cluster_spec["chief"] = self._cluster_spec["chief"] + tf_config = {"cluster": cluster_spec} + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(tf_config)}): + strategy = mirrored_strategy.CoreMirroredStrategy() + self.assertIsInstance(strategy.extended._inferred_cross_device_ops, + cross_device_ops_lib.NcclAllReduce) + self._test_minimize_loss_graph(strategy, learning_rate=0.05) + + def testInitializeFromTFConfig(self): + tf_config = {"cluster": self._cluster_spec} + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(tf_config)}): + strategy = mirrored_strategy.CoreMirroredStrategy() + self.assertEqual( + max(context.num_gpus(), 1) * 3, strategy.num_replicas_in_sync) + + def testSummaryForReplicaZeroOnly(self): + strategy = mirrored_strategy.CoreMirroredStrategy( + mirrored_strategy.all_local_devices()) + strategy.configure(cluster_spec=self._cluster_spec) + self._test_summary_for_replica_zero_only(strategy) + def _replica_id(): replica_id = ds_context.get_replica_context().replica_id_in_sync_group diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index 
16be839e1d155003b9490fbe3da6ab85b7d2d78a..c0651610cafc06a6d5f4206f4e64d27020fae30b 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -23,9 +23,9 @@ import numpy from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import monitor as monitor_lib -from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example from tensorflow.python.client import session +from tensorflow.python.distribute import one_device_strategy from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index c4622cdd2af2f6a9c936fe554bcc2eb76f805fdc..23f976f8fa4705578da2be414c05f1f14e13c9fa 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -56,7 +56,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): var, assign = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([10.0, 11.0], var.eval()) - sess.run(distribution.unwrap(assign)) + sess.run(distribution.experimental_local_results(assign)) # Mean of val across calls to replica_fn(). 
average_val = [1.0 + 0.5 * (replica_id[0] - 1), 2.0 - 0.5 * (replica_id[0] - 1)] @@ -82,7 +82,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): var, assign_op = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([0.0, 0.0], var.eval()) - sess.run(distribution.unwrap(assign_op)) + sess.run(distribution.experimental_local_results(assign_op)) # Mean of val across calls to replica_fn(). average_val = [1.0 + 0.5 * (replica_id[0] - 1), 2.0 - 0.5 * (replica_id[0] - 1)] @@ -155,7 +155,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): var, assign = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([10.0, 11.0], var.eval()) - sess.run(distribution.unwrap(assign)) + sess.run(distribution.experimental_local_results(assign)) self.assertAllClose( [10 * 0.25 + 1. * (1 - 0.25), 11 * 0.25 + 2. * (1 - 0.25)], var.eval()) diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index b05aac431f65b4281d9ed9c2fa95c210d55f4008..ce448840f14e3816f1d40328239256fd5acd51bf 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -37,12 +37,16 @@ except ImportError as _error: # pylint: disable=invalid-name from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session +from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.estimator import run_config from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import coordinator from tensorflow.python.training import server_lib + +original_run_std_server = dc._run_std_server # 
pylint: disable=protected-access + ASSIGNED_PORTS = set() lock = threading.Lock() @@ -343,9 +347,9 @@ class MockOsEnv(collections.Mapping): def __iter__(self): if not hasattr(self._thread_local, 'dict'): self._thread_local.dict = dict() - for x in self._thread_local.dict.items(): + for x in self._thread_local.dict: yield x - for x in self._dict.items(): + for x in self._dict: yield x def __len__(self): @@ -357,6 +361,22 @@ class MockOsEnv(collections.Mapping): class IndependentWorkerTestBase(test.TestCase): """Testing infra for independent workers.""" + def _make_mock_run_std_server(self): + thread_local = threading.local() + + def _mock_run_std_server(*args, **kwargs): + ret = original_run_std_server(*args, **kwargs) + # Wait for all std servers to be brought up in order to reduce the chance + # of remote sessions taking local ports that have been assigned to std + # servers. Only call this barrier the first time this function is run for + # each thread. + if not getattr(thread_local, 'server_started', False): + self._barrier.wait() + thread_local.server_started = True + return ret + + return _mock_run_std_server + def setUp(self): self._mock_os_env = MockOsEnv() self._mock_context = test.mock.patch.object(os, 'environ', @@ -409,3 +429,25 @@ class IndependentWorkerTestBase(test.TestCase): def join_independent_workers(self, worker_threads): self._coord.join(worker_threads) + + +def get_tf_config_task(): + return json.loads(os.environ['TF_CONFIG'])['task'] + + +def get_tf_config_cluster_spec(): + return json.loads(os.environ['TF_CONFIG'])['cluster'] + + +def get_task_type(): + return get_tf_config_task()['type'] + + +def get_task_index(): + return get_tf_config_task()['index'] + + +def is_chief(): + return ('chief' not in get_tf_config_cluster_spec() + and get_task_type() == 'worker' + and get_task_index() == 0) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 
836cb7cc41b62352fd69a4a209d483ccf0fc498e..13a501394ee1fec2dfc1427f6d16d3a4624d7747 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -18,199 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import input_lib -from tensorflow.python.distribute import numpy_dataset -from tensorflow.python.distribute import values -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.util import nest +from tensorflow.python.distribute import one_device_strategy - -# TODO(josh11b): Replace asserts in this file with if ...: raise ... - - -class OneDeviceStrategy(distribute_lib.DistributionStrategy): - """A distribution strategy for running on a single device.""" - # TODO(josh11b): Do we wrap values in types to generate errors if you are - # doing something that won't work with other DistributionStrategy - # implementations? 
- - def __init__(self, device): - super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device)) - - -class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): - """Implementation of OneDeviceStrategy.""" - - def __init__(self, container_strategy, device): - super(OneDeviceExtended, self).__init__(container_strategy) - self._device = device - self._default_device = device - self._input_device = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(self._input_device, [self._device])] - device_map = values.SingleDeviceMap(device) - self._input_workers = input_lib.InputWorkers( - device_map, worker_device_pairs) - - def _create_variable(self, next_creator, *args, **kwargs): - colocate_with = kwargs.pop("colocate_with", None) - if colocate_with is None: - with ops.device(self._device): - return next_creator(*args, **kwargs) - with ops.colocate_with(colocate_with): - return next_creator(*args, **kwargs) - - def _validate_colocate_with_variable(self, colocate_with_variable): - values.validate_colocate(colocate_with_variable, self) - - def _make_dataset_iterator(self, dataset): - """Make iterator from dataset without splitting the batch.""" - return input_lib.DatasetIterator(dataset, self._input_workers) - - def _distribute_dataset(self, dataset_fn): - return input_lib.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._input_workers, 0) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - return input_lib.InputFunctionIterator( - input_fn, self._input_workers, [distribute_lib.InputContext()]) - - def _experimental_make_numpy_dataset(self, numpy_input, session): - return numpy_dataset.one_host_numpy_dataset( - numpy_input, self._input_device, session) - - def _broadcast_to(self, tensor, destinations): - del destinations - return tensor - - # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 
- def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, - initial_loop_values=None): - if initial_loop_values is None: - initial_loop_values = {} - initial_loop_values = nest.flatten(initial_loop_values) - - ctx = input_lib.MultiStepContext() - def body(i, *args): - """A wrapper around `fn` to create the while loop body.""" - del args - fn_result = fn(ctx, iterator.get_next()) - flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) - with ops.control_dependencies([fn_result]): - return [i + 1] + flat_last_step_outputs - - # We capture the control_flow_context at this point, before we run `fn` - # inside a while_loop. This is useful in cases where we might need to exit - # these contexts and get back to the outer context to do some things, for - # e.g. create an op which should be evaluated only once at the end of the - # loop on the host. One such usage is in creating metrics' value op. - self._outer_control_flow_context = ( - ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - - # TODO(priyag): Use max_iterations instead of an explicit counter. - cond = lambda i, *args: i < iterations - i = constant_op.constant(0) - loop_result = control_flow_ops.while_loop( - cond, body, [i] + initial_loop_values, name="", - parallel_iterations=1, back_prop=False, swap_memory=False, - return_same_structure=True) - del self._outer_control_flow_context - - ctx.run_op = control_flow_ops.group(loop_result) - - # Convert the last_step_outputs from a list to the original dict structure - # of last_step_outputs. 
- last_step_tensor_outputs = loop_result[1:] - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access - return ctx - - def _call_for_each_replica(self, fn, args, kwargs): - strategy = self._container_strategy() - with ops.device(self._device), _OneDeviceReplicaContext(strategy): - return fn(*args, **kwargs) - - def _reduce_to(self, reduce_op, value, destinations): - del reduce_op, destinations - return value - - def _update(self, var, fn, args, kwargs, group): - # The implementations of _update() and _update_non_slot() are identical - # except _update() passes `var` as the first argument to `fn()`. - return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) - - def _update_non_slot(self, colocate_with, fn, args, kwargs, group): - del colocate_with - with ops.device(self._device), distribute_lib.UpdateContext(self._device): - result = fn(*args, **kwargs) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - def read_var(self, replica_local_var): - """Read the aggregate value of a replica-local variable.""" - return array_ops.identity(replica_local_var) - - def _unwrap(self, value): - return (value,) - - def value_container(self, value): - return value - - @property - def _num_replicas_in_sync(self): - return 1 - - @property - def worker_devices(self): - return (self._device,) - - @property - def parameter_devices(self): - return (self._device,) - - def non_slot_devices(self, var_list): - del var_list - return (self._device,) - - @property - def experimental_should_init(self): - return True - - @property - def should_checkpoint(self): - return True - - @property - def should_save_summary(self): - return True - - # TODO(priyag): Delete this once all strategies use global batch size. 
- @property - def _global_batch_size(self): - """Global and per-replica batching are equivalent for OneDeviceStrategy.""" - return True - - -class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): - """ReplicaContext for OneDeviceStrategy.""" - - def __init__(self, strategy): - zero = constant_op.constant(0, dtypes.int32) - distribute_lib.ReplicaContext.__init__( - self, strategy, replica_id_in_sync_group=zero) - - @property - def devices(self): - return self._strategy.extended.worker_devices +OneDeviceStrategy = one_device_strategy.OneDeviceStrategy diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index f81466a6c75f1cf287cdb00917872f77383c615e..e7eab1f7d772f66554a9d4646d97bffb0e0d578c 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -18,36 +18,35 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.eager import test -from tensorflow.python.framework import test_util +@combinations.generate(combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.one_device_strategy_gpu], + mode=["eager", "graph"])) class OneDeviceStrategyTest( strategy_test_lib.DistributionTestBase, strategy_test_lib.OneDeviceDistributionTestBase): - def _get_distribution_strategy(self): - return one_device_strategy.OneDeviceStrategy("/device:CPU:0") + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + self._test_minimize_loss_eager(distribution) + else: + 
self._test_minimize_loss_graph(distribution) - def testMinimizeLossEager(self): - self._test_minimize_loss_eager(self._get_distribution_strategy()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) + def testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) - def testReplicaId(self): - self._test_replica_id(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testMakeInputFnIterator(self): - d = one_device_strategy.OneDeviceStrategy("/device:CPU:0") + def testMakeInputFnIteratorWithDataset(self, distribution): dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i] for i in range(10)] input_fn = self._input_fn_to_test_input_context( @@ -55,31 +54,49 @@ class OneDeviceStrategyTest( expected_num_replicas_in_sync=1, expected_num_input_pipelines=1, expected_input_pipeline_id=0) - iterator = d.make_input_fn_iterator(input_fn) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, distribution.extended.worker_devices, expected_values) + + def testMakeInputFnIteratorWithCallable(self, distribution): + def fn(): + dataset = dataset_ops.Dataset.range(10) + it = dataset.make_one_shot_iterator() + return it.get_next + expected_values = [[i] for i in range(10)] + input_fn = self._input_fn_to_test_input_context( + fn, + expected_num_replicas_in_sync=1, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = distribution.make_input_fn_iterator(input_fn) self._test_input_fn_iterator( - iterator, d.extended.worker_devices, expected_values) + iterator, distribution.extended.worker_devices, expected_values, + test_reinitialize=False) 
+ + def testNumpyIterator(self, distribution): + self._test_numpy_iterator(distribution) - @test_util.run_in_graph_and_eager_modes - def testNumpyIterator(self): - self._test_numpy_iterator(self._get_distribution_strategy()) + def testRun(self, distribution): + self._test_run(distribution) - def testAllReduceSum(self): - self._test_all_reduce_sum(self._get_distribution_strategy()) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) - def testAllReduceSumGradients(self): - self._test_all_reduce_sum_gradients(self._get_distribution_strategy()) + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) - def testAllReduceSumGradientTape(self): - self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy()) + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) - def testAllReduceMean(self): - self._test_all_reduce_mean(self._get_distribution_strategy()) + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) - def testAllReduceMeanGradients(self): - self._test_all_reduce_mean_gradients(self._get_distribution_strategy()) + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) - def testAllReduceMeanGradientTape(self): - self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy()) + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index e388061b17a9b92dedbbf9839049b13c8575a22c..01bb7fbedc645cf8892f83445d0635916fd4dcca 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -45,7 +45,7 @@ class 
MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): def run_step(): return control_flow_ops.group( - distribution.unwrap( + distribution.experimental_local_results( distribution.extended.call_for_each_replica( model_fn, args=(iterator.get_next(),)))) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 0da269204fcf7d8217867ade7ab42a9690661a42..be863322256f7b5b93d91fa2e7ae1754b2494e3d 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -157,8 +157,8 @@ class ParameterServerExtended(CoreParameterServerExtended): cluster_resolver = SimpleClusterResolver( cluster_spec=tfconfig.cluster_spec(), task_type=tfconfig.task_type, - task_index=tfconfig.task_index, - num_accelerators=num_gpus_per_worker) + task_id=tfconfig.task_id, + num_accelerators={'GPU': num_gpus_per_worker}) super(ParameterServerExtended, self).__init__( container_strategy, cluster_resolver=cluster_resolver) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 9e7e201519e5c08adf3b89973d88244a3d0f07ab..9966f90a84bf986462a008468cd267f801990f73 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -90,12 +90,12 @@ def create_test_objects(cluster_spec=None, cluster_resolver = SimpleClusterResolver( cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), task_type=task_type, - task_index=task_id, - num_accelerators=num_gpus) + task_id=task_id, + num_accelerators={'GPU': num_gpus}) target = 'grpc://' + cluster_spec[WORKER][task_id] else: cluster_resolver = SimpleClusterResolver( - ClusterSpec({}), num_accelerators=num_gpus) + ClusterSpec({}), num_accelerators={'GPU': 
num_gpus}) target = '' distribution = MockCoreParameterServerStrategy(cluster_resolver) @@ -514,7 +514,7 @@ class ParameterServerStrategyTestBase( def update(v, g): return v.assign_sub(0.05 * g, use_locking=True) - one = d.broadcast(constant_op.constant([[1.]])) + one = constant_op.constant([[1.]]) def step(): """Perform one optimization step.""" @@ -571,6 +571,7 @@ class ParameterServerStrategyTestBase( num_gpus, input_fn, expected_values, + test_reinitialize=True, use_core_strategy=False): distribution, master_target, config = self._get_test_objects( task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) @@ -594,13 +595,14 @@ class ParameterServerStrategyTestBase( for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. - sess.run(iterator.initialize()) + if test_reinitialize: + sess.run(iterator.initialize()) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = sess.run([values.select_replica(r, next_element) - for r in range(len(devices))]) - self.assertEqual(expected_value, computed_value) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) + self.assertEqual(expected_value, computed_value) class ParameterServerStrategyTest( @@ -694,22 +696,31 @@ class ParameterServerStrategyTest( def testMinimizeLossGraphLocal(self, num_gpus, use_core_strategy): self._test_minimize_loss_graph(None, None, num_gpus, use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. # TODO(priyag): Refactor this and other multi worker tests. 
@combinations.generate( combinations.combine( mode=['graph'], num_gpus=[1, 2], required_gpus=1, - use_core_strategy=[True, False])) - def testMakeInputFnIteratorDistributed(self, num_gpus, use_core_strategy): + use_core_strategy=[True, False], + use_dataset=[True, False])) + def DISABLED_testMakeInputFnIteratorDistributed( + self, num_gpus, use_core_strategy, use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(100) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(100) + else: + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next expected_values = [[i+j for j in range(num_gpus)] for i in range(0, 100, num_gpus)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=3, expected_input_pipeline_id=1) # because task_id = 1 @@ -719,23 +730,33 @@ class ParameterServerStrategyTest( num_gpus, input_fn, expected_values, + test_reinitialize=use_dataset, use_core_strategy=use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. 
@combinations.generate( combinations.combine( mode=['graph'], num_gpus=[1, 2], required_gpus=1, - use_core_strategy=[True, False])) - def testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy): + use_core_strategy=[True, False], + use_dataset=[True, False])) + def DISABLED_testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, + use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(100) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(100) + else: + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next expected_values = [[i+j for j in range(num_gpus)] for i in range(0, 100, num_gpus)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=1, expected_input_pipeline_id=0) # only one worker and pipeline for local. @@ -745,6 +766,7 @@ class ParameterServerStrategyTest( num_gpus, input_fn, expected_values, + test_reinitialize=use_dataset, use_core_strategy=use_core_strategy) @combinations.generate( @@ -784,31 +806,37 @@ class ParameterServerStrategyTest( # Verify isolate_session_state self.assertTrue(new_config.isolate_session_state) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceSum(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_all_reduce_sum(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceSumGradients(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_all_reduce_sum_gradients(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceSumGradientTape(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) 
self._test_all_reduce_sum_gradient_tape(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceMean(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_all_reduce_mean(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceMeanGradients(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_all_reduce_mean_gradients(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceMeanGradientTape(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 2e2ee92b6e20471f367895ea53c0864bb3d1dae7..a562a1758d4b90b0f8d1f35a79e2c6415534fb33 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -18,9 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import tempfile + import numpy as np from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.util import event_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util @@ -33,12 +37,17 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core +from tensorflow.python.lib.io import tf_record from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import summary_ops_v2 as summary_ops from 
tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.training import optimizer +from tensorflow.python.training import training_util +from tensorflow.python.util import nest class _TestException(Exception): @@ -86,6 +95,20 @@ def _merge_call_merge_raises_fn(): ds_context.get_replica_context().merge_call(_call_merge_raises_fn) +def _events_from_logdir(test_case, logdir): + """Reads summary events from log directory.""" + test_case.assertTrue(gfile.Exists(logdir)) + files = gfile.ListDirectory(logdir) + test_case.assertLen(files, 1) + records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + result = [] + for r in records: + event = event_pb2.Event() + event.ParseFromString(r) + result.append(event) + return result + + class DistributionTestBase(test.TestCase): """Some tests that should work with any DistributionStrategy.""" @@ -94,9 +117,6 @@ class DistributionTestBase(test.TestCase): l = core.Dense(1, use_bias=False) def loss(x): - # TODO(josh11b): What if this constant was instead a captured - # value? Would it need to be a value that has been passed - # through d.broadcast()? y = array_ops.reshape(l(x), []) - constant_op.constant(1.) return y * y # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a @@ -107,7 +127,7 @@ class DistributionTestBase(test.TestCase): def update(v, g): return v.assign_sub(0.2 * g) - one = d.broadcast(constant_op.constant([[1.]])) + one = constant_op.constant([[1.]]) def step(): """Perform one optimization step.""" @@ -152,9 +172,6 @@ class DistributionTestBase(test.TestCase): l = core.Dense(1, use_bias=False) def loss(x): - # TODO(josh11b): What if this constant was instead a captured - # value? Would it need to be a value that has been passed - # through d.broadcast()? y = array_ops.reshape(l(x), []) - constant_op.constant(1.) 
return y * y @@ -163,7 +180,7 @@ class DistributionTestBase(test.TestCase): def update(v, g): return v.assign_sub(learning_rate * g) - one = d.broadcast(constant_op.constant([[1.]])) + one = constant_op.constant([[1.]]) def step(): """Perform one optimization step.""" @@ -197,6 +214,39 @@ class DistributionTestBase(test.TestCase): # Error should go down self.assertLess(error_after, error_before) + def _test_summary_for_replica_zero_only(self, d): + logdir = tempfile.mkdtemp() + + def run_fn(): + """Function executed for each replica.""" + with summary_writer.as_default(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + return summary_ops.scalar("a", replica_id) + + with self.cached_session() as sess, d.scope(), \ + summary_ops.always_record_summaries(): + # We need global_step because summary writing op *always* has global_step + # as input, even when we always record summary or never record summary. + global_step = training_util.get_or_create_global_step() + if not context.executing_eagerly(): + # When executing eagerly, variables are initialized immediately after + # creation, and its initializer will be None. + global_step.initializer.run() + summary_writer = summary_ops.create_file_writer(logdir) + output = d.extended.call_for_each_replica(run_fn) + unwrapped = d.unwrap(output) + if not context.executing_eagerly(): + sess.run(summary_writer.init()) + sess.run(unwrapped) + sess.run(summary_writer.close()) + + events = _events_from_logdir(self, logdir) + # There will be 2 entries: 1 summary file header entry, and 1 entry + # written by replica 0. 
+ self.assertLen(events, 2) + self.assertEqual(events[1].summary.value[0].tag, "a") + self.assertEqual(events[1].summary.value[0].simple_value, 0.0) + def _test_replica_id(self, d): with d.scope(): expected_devices = [False] * len(d.extended.worker_devices) @@ -224,7 +274,7 @@ class DistributionTestBase(test.TestCase): dist.extended.call_for_each_replica(_merge_call_merge_raises_fn) def _input_fn_to_test_input_context(self, - dataset_fn, + dataset_or_callable_fn, expected_num_replicas_in_sync, expected_num_input_pipelines, expected_input_pipeline_id): @@ -248,12 +298,12 @@ class DistributionTestBase(test.TestCase): self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id) worker_id_counter[0] += 1 - return dataset_fn() + return dataset_or_callable_fn() return _input_fn def _test_input_fn_iterator(self, iterator, devices, expected_values, - sess=None): + sess=None, test_reinitialize=True): evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) evaluate(iterator.initialize()) @@ -269,13 +319,14 @@ class DistributionTestBase(test.TestCase): [values.select_replica(r, next_element) for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. 
- evaluate(iterator.initialize()) + if test_reinitialize: + evaluate(iterator.initialize()) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertEqual(expected_value, computed_value) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate([values.select_replica(r, next_element) + for r in range(len(devices))]) + self.assertEqual(expected_value, computed_value) def _test_global_step_update(self, strategy): with strategy.scope(): @@ -295,7 +346,7 @@ class DistributionTestBase(test.TestCase): train_ops, value = strategy.extended.call_for_each_replica(model_fn) self.evaluate(strategy.group(train_ops)) - global_step_tensors = strategy.unwrap(value) + global_step_tensors = strategy.experimental_local_results(value) global_step_values = self.evaluate(global_step_tensors) self.assertEqual((1,) * len(global_step_tensors), global_step_values) @@ -314,7 +365,8 @@ class DistributionTestBase(test.TestCase): def run_and_concatenate(strategy, i): x, y = strategy.experimental_run(lambda z: z, i) - x, y = self.evaluate((strategy.unwrap(x), strategy.unwrap(y))) + x, y = self.evaluate((strategy.experimental_local_results(x), + strategy.experimental_local_results(y))) return np.concatenate(x), np.concatenate(y) x_1, y_1 = run_and_concatenate(strategy, i) @@ -330,6 +382,19 @@ class DistributionTestBase(test.TestCase): class OneDeviceDistributionTestBase(test.TestCase): """Some tests that should work with any one-device DistributionStrategy.""" + def _test_run(self, strategy): + out1 = strategy.experimental_run_v2(lambda: constant_op.constant(4.)) + self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1))) + + out2 = strategy.experimental_run_v2( + lambda x: {"a": x * 2, "b": x * x}, args=(out1,)) + out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2)) + self.assertAllEqual([8.], 
out2_vals["a"]) + self.assertAllEqual([16.], out2_vals["b"]) + + out3 = strategy.experimental_run_v2(lambda b, a: a + 2 * b + 2, kwargs=out2) + self.assertAllEqual([42.], self.evaluate(strategy.unwrap(out3))) + def _test_all_reduce_sum(self, strategy): self._test_collective_comms( strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.])) @@ -360,7 +425,8 @@ class OneDeviceDistributionTestBase(test.TestCase): self.evaluate(inputs.initialize()) outputs = self.evaluate( - list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + list(map(strategy.experimental_local_results, + strategy.experimental_run(comm_fn, inputs)))) self.assertAllEqual([expected[0]], outputs[0]) self.assertAllEqual([expected[1]], outputs[1]) @@ -380,7 +446,8 @@ class OneDeviceDistributionTestBase(test.TestCase): self.evaluate(inputs.initialize()) self.assertAllEqual( expected_grads, - self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + self.evaluate(strategy.experimental_local_results( + strategy.experimental_run(step, inputs)))) def _test_collective_comms_gradient_tape( self, strategy, comm_fn, inputs, expected_grads): @@ -397,12 +464,27 @@ class OneDeviceDistributionTestBase(test.TestCase): self.evaluate(inputs.initialize()) self.assertAllEqual( expected_grads, - self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + self.evaluate(strategy.experimental_local_results( + strategy.experimental_run(step, inputs)))) class TwoDeviceDistributionTestBase(test.TestCase): """Some tests that should work with any two-device DistributionStrategy.""" + def _test_run(self, strategy): + out1 = strategy.experimental_run_v2( + lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1) + self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1))) + + out2 = strategy.experimental_run_v2( + lambda x: {"a": x * 2, "b": x * x}, args=(out1,)) + out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2)) + 
self.assertAllEqual([2, 4], out2_vals["a"]) + self.assertAllEqual([1, 4], out2_vals["b"]) + + out3 = strategy.experimental_run_v2(lambda b, a: a + 2 * b + 2, kwargs=out2) + self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3))) + def _test_all_reduce_sum(self, strategy): self._test_collective_comms( strategy, _all_sum, @@ -437,7 +519,8 @@ class TwoDeviceDistributionTestBase(test.TestCase): self.evaluate(inputs.initialize()) outputs = self.evaluate( - list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + list(map(strategy.experimental_local_results, + strategy.experimental_run(comm_fn, inputs)))) self.assertAllEqual([expected[0], expected[0]], outputs[0]) self.assertAllEqual([expected[1], expected[1]], outputs[1]) @@ -457,7 +540,8 @@ class TwoDeviceDistributionTestBase(test.TestCase): self.evaluate(inputs.initialize()) self.assertAllEqual( expected_grads, - self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + self.evaluate(strategy.experimental_local_results( + strategy.experimental_run(step, inputs)))) def _test_collective_comms_gradient_tape( self, strategy, comm_fn, inputs, expected_grads): @@ -474,7 +558,8 @@ class TwoDeviceDistributionTestBase(test.TestCase): self.evaluate(inputs.initialize()) self.assertAllEqual( expected_grads, - self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + self.evaluate(strategy.experimental_local_results( + strategy.experimental_run(step, inputs)))) def _all_sum(value): diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index f55e6f0b2088f24e74400d330951a0471425f0eb..04e0af767bfaf94ed6a53ba9f8ed71ae4f9cdc4a 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,733 +21,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy -import functools - 
-from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import device_assignment as device_assignment_lib -from tensorflow.contrib.tpu.python.tpu import topology -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib -from tensorflow.contrib.tpu.python.tpu import training_loop -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session as session_lib -from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib -from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import input_lib -from tensorflow.python.distribute import numpy_dataset -from tensorflow.python.distribute import reduce_util -from tensorflow.python.distribute import values -from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import tape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import nest - - -def initialize_tpu_system(cluster_resolver=None): - """Initialize the TPU devices in a separate session and graph. - - Args: - cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, - which provides information about the TPU cluster. - Returns: - The tf.contrib.tpu.Topology object for the topology of the TPU cluster. 
- """ - if cluster_resolver is None: - cluster_resolver = resolver_lib.TPUClusterResolver("") - master = cluster_resolver.master() - - logging.info("Initializing the TPU system.") - session_config = config_pb2.ConfigProto(allow_soft_placement=True) - - with ops.Graph().as_default(): - with session_lib.Session(config=session_config, target=master) as sess: - serialized_topology = sess.run(tpu.initialize_system()) - logging.info("Finished initializing TPU system.") - return topology.Topology(serialized=serialized_topology) - - -def get_tpu_system_metadata(tpu_cluster_resolver): - """Retrieves TPU system metadata given a TPUClusterResolver.""" - master = tpu_cluster_resolver.master() - - # pylint: disable=protected-access - cluster_spec = tpu_cluster_resolver.cluster_spec() - cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None - tpu_system_metadata = ( - tpu_system_metadata_lib._query_tpu_system_metadata( - master, - cluster_def=cluster_def, - query_topology=False)) - - return tpu_system_metadata - - -# TODO(jhseu): Deduplicate with MirroredStrategy? -def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring - strategy, device_map, logical_device, real_mirrored_creator, - *args, **kwargs): - # Figure out what collections this variable should be added to. - # We'll add the TPUMirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - - # TODO(jhseu): Should we have different behavior for different - # synchronization settings? - - # Get aggregation value - # TODO(jhseu): Support aggregation in a replica context. 
- aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) - if aggregation not in [ - vs.VariableAggregation.NONE, - vs.VariableAggregation.SUM, - vs.VariableAggregation.MEAN, - vs.VariableAggregation.ONLY_FIRST_REPLICA, - ]: - raise ValueError("Invalid variable aggregation mode: {} for variable: {}" - .format(aggregation, kwargs["name"])) - - # Ignore user-specified caching device, not needed for mirrored variables. - kwargs.pop("caching_device", None) - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. - with tape.stop_recording(): - devices = device_map.logical_to_actual_devices(logical_device) - value_list = real_mirrored_creator(devices, *args, **kwargs) - result = values.TPUMirroredVariable( - strategy, device_map, value_list, aggregation, - logical_device=logical_device) - - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in value_list: - l.remove(v) - g.add_to_collections(collections, result) - return result - - -class TPUStrategy(distribute_lib.DistributionStrategy): - """TPU distribution strategy implementation.""" - - def __init__(self, - tpu_cluster_resolver=None, - steps_per_run=None, - device_assignment=None, - **kwargs): - """Initializes the TPUStrategy object. - - Args: - tpu_cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, - which provides information about the TPU cluster. 
- steps_per_run: Number of steps to run on device before returning to the - host. Note that this can have side-effects on performance, hooks, - metrics, summaries etc. - This parameter is only used when Distribution Strategy is used with - estimator or keras. - device_assignment: Optional `tf.contrib.tpu.DeviceAssignment` to specify - the placement of replicas on the TPU cluster. Currently only supports - the usecase of using a single core within a TPU cluster. - **kwargs: Additional experimental flags. Will be removed in future. - """ - super(TPUStrategy, self).__init__(TPUExtended( - self, tpu_cluster_resolver, steps_per_run, device_assignment)) - - self._disable_training_loop_on_host = False - if len(kwargs) > 1: - raise ValueError("TPUStrategy constructor only takes one experimental " - "flag now") - if len(kwargs) == 1: - if "_disable_training_loop_on_host" not in kwargs: - raise ValueError("TPUStrategy constructor does not support arguments: " - "{}".format(kwargs)) - self._disable_training_loop_on_host = ( - kwargs["_disable_training_loop_on_host"]) - - @property - def steps_per_run(self): - """DEPRECATED: use .extended.steps_per_run instead.""" - return self._extended.steps_per_run - - -class TPUExtended(distribute_lib.DistributionStrategyExtended): - """Implementation of TPUStrategy.""" - - def __init__(self, - container_strategy, - tpu_cluster_resolver=None, - steps_per_run=None, - device_assignment=None): - super(TPUExtended, self).__init__(container_strategy) - - if tpu_cluster_resolver is None: - tpu_cluster_resolver = resolver_lib.TPUClusterResolver("") - - if steps_per_run is None: - # TODO(frankchn): Warn when we are being used by DS/Keras and this is - # not specified. - steps_per_run = 1 - - self._tpu_cluster_resolver = tpu_cluster_resolver - self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) - self._device_assignment = device_assignment - - # Device assignment is currently only supported for 1 core case. 
- if self._device_assignment: - assert isinstance(self._device_assignment, - device_assignment_lib.DeviceAssignment) - if self._device_assignment.num_replicas != 1: - raise ValueError("Device assignment is only supported for a single " - "core single replica case currently.") - if self._device_assignment.num_cores_per_replica != 1: - raise ValueError("Device assignment is only supported for a single " - "core single replica case currently.") - if not all(self._device_assignment.core_assignment[0][0] == [0, 0, 0]): - raise ValueError("Device assignment is only supported for a single " - "core single replica case currently.") - - # TODO(jhseu): Switch to DeviceAssignment to support pods and model - # parallelism. - self._device_index = { - d.name: i for i, d in enumerate(self._tpu_metadata.devices) - if "device:TPU:" in d.name - } - self._host_device = self.get_host_cpu_device(0) - self._tpu_devices = tuple(sorted(self._device_index.keys())) - # Only create variables for the number of replicas we're running. - self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] - self._device_map = values.ReplicaDeviceMap(self._tpu_devices) - - # For input: - input_device_map = values.ReplicaDeviceMap(tuple( - self.get_host_cpu_device(hid) for hid in range(self.num_hosts))) - worker_devices = [ - (self.get_host(hid), [self.get_host_cpu_device(hid)]) - for hid in range(self.num_hosts) - ] - self._input_workers = input_lib.InputWorkers( - input_device_map, worker_devices) - - # TODO(sourabhbajaj): Remove this once performance of running one step - # at a time is comparable to multiple steps. - self.steps_per_run = steps_per_run - self._require_static_shapes = True - - def _validate_colocate_with_variable(self, colocate_with_variable): - values.validate_colocate_tpu_variable(colocate_with_variable, self) - - def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator, - input_shapes, iterations): - """Create an enqueue op for a single host identified using host_id. 
- - The while_loop op returned will run `iterations` times and in each run - enqueue batches for each shard. - - Args: - host_id: integer, id of the host to run the enqueue ops on. - multi_worker_iterator: MultiWorkerDataIterator to read the input data. - input_shapes: shape of inputs to be enqueue on the queue. This is same as - the value of `nest.flatten(iterator.output_shapes)`. - iterations: integer, number of iterations to be run; determines the - number of batches to be enqueued. - - Returns: - while_loop_op running `iterations` times; in each run we enqueue a batch - on the infeed queue from the host with id `host_id` for each device shard. - """ - host = self.get_host_cpu_device(host_id) - # TODO(sourabhbajaj): Possibly make changes to MultiWorkerDataset - # to work with TPU Prefetch so clean up this code. - iterator = ( - multi_worker_iterator.get_iterator(self.get_host(host_id))._iterator) # pylint: disable=protected-access - - def _infeed_enqueue_ops_fn(): - """Enqueue ops for one iteration.""" - control_deps = [] - sharded_inputs = [] - enqueue_ops = [] - - with ops.device(host): - for _ in range(self.num_replicas_per_host): - # Use control dependencies to ensure a deterministic ordering. 
- with ops.control_dependencies(control_deps): - inputs = nest.flatten(iterator.get_next()) - control_deps.extend(inputs) - sharded_inputs.append(inputs) - - for core_id, shard_input in enumerate(sharded_inputs): - enqueue_ops.append( - tpu_ops.infeed_enqueue_tuple( - inputs=shard_input, - shapes=input_shapes, - device_ordinal=core_id)) - return enqueue_ops - - def enqueue_ops_loop_body(i): - """Callable for the loop body of the while_loop instantiated below.""" - with ops.control_dependencies(_infeed_enqueue_ops_fn()): - return i + 1 - - with ops.device(host): - enqueue_op_per_host = control_flow_ops.while_loop( - lambda i: i < iterations, - enqueue_ops_loop_body, - [constant_op.constant(0)], - parallel_iterations=1) - - return enqueue_op_per_host - - def _make_dataset_iterator(self, dataset): - """Make iterators for each of the TPU hosts.""" - return input_lib.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) - - def _distribute_dataset(self, dataset_fn): - return input_lib.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), - self._input_workers) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - input_contexts = [] - num_workers = self._input_workers.num_workers - for i in range(num_workers): - input_contexts.append(distribute_lib.InputContext( - num_input_pipelines=num_workers, - input_pipeline_id=i, - num_replicas_in_sync=self._num_replicas_in_sync)) - return input_lib.InputFunctionIterator( - input_fn, self._input_workers, input_contexts) - - def _experimental_make_numpy_dataset(self, numpy_input, session): - return numpy_dataset.one_host_numpy_dataset( - numpy_input, numpy_dataset.SingleDevice(self.get_host_cpu_device(0)), - session) - - # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. - # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have - # a mechanism to infer the outputs of `fn`. 
Pending b/110550782. - def _experimental_run_steps_on_iterator( - self, fn, multi_worker_iterator, iterations, initial_loop_values=None): - output_shapes = multi_worker_iterator.output_shapes - shapes = nest.flatten(output_shapes) - if any(not s.is_fully_defined() for s in shapes): - raise ValueError( - "TPU currently requires fully defined shapes. Either use " - "set_shape() on the input tensors or use " - "dataset.batch(..., drop_remainder=True).") - types = nest.flatten(multi_worker_iterator.output_types) - - enqueue_ops = [ - self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes, - iterations) - for host_id in range(self.num_hosts)] - - def dequeue_fn(): - dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - return nest.pack_sequence_as(output_shapes, dequeued) - - # Wrap `fn` for repeat. - if initial_loop_values is None: - initial_loop_values = {} - initial_loop_values = nest.flatten(initial_loop_values) - ctx = input_lib.MultiStepContext() - - def run_fn(*args, **kwargs): - """Single step on the TPU device.""" - del args, kwargs - fn_result = fn(ctx, dequeue_fn()) - flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) - if flat_last_step_outputs: - with ops.control_dependencies([fn_result]): - return [array_ops.identity(f) for f in flat_last_step_outputs] - else: - return fn_result - - def iterate_on_tpu(): - return training_loop.repeat(iterations, run_fn, initial_loop_values) - - # We capture the control_flow_context at this point, before we run `fn` - # inside a while_loop and TPU replicate context. This is useful in cases - # where we might need to exit these contexts and get back to the outer - # context to do some things, for e.g. create an op which should be - # evaluated only once at the end of the loop on the host. One such usage - # is in creating metrics' value op. 
- self._outer_control_flow_context = ( - ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - - # pylint: disable=protected-access - if self._container_strategy()._disable_training_loop_on_host: - replicate_inputs = [[]] * self._num_replicas_in_sync - replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) - else: - def rewrite_fn(*args): - """The rewritten step fn running on TPU.""" - del args - replicate_inputs = [[]] * self._num_replicas_in_sync - replicate_outputs = tpu.replicate(run_fn, replicate_inputs) - - # If run_fn has tensor outputs, tpu.replicate returns a list of list. We - # will flatten it in this case. If run_fn has no tensor outputs, - # tpu.replicate returns a list of no_ops, we will keep the output as it - # is. - if isinstance(replicate_outputs[0], list): - replicate_outputs = nest.flatten(replicate_outputs) - - return replicate_outputs - - # TODO(sourabhbajaj): The input to while loop should be based on the - # output type of the step_fn - assert isinstance(initial_loop_values, list) - initial_loop_values = initial_loop_values * self._num_replicas_in_sync - - # Put the while loop op on host 0. - with ops.device(self.get_host_cpu_device(0)): - replicate_outputs = training_loop.repeat(iterations, rewrite_fn, - initial_loop_values) - - del self._outer_control_flow_context - ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) - - if self._container_strategy()._disable_training_loop_on_host: - # Filter out any ops from the outputs, typically this would be the case - # when there were no tensor outputs. 
- last_step_tensor_outputs = [x for x in replicate_outputs - if not isinstance(x, ops.Operation)] - - # Outputs are currently of the structure (grouped by device) - # [[output0_device0, output1_device0, output2_device0], - # [output0_device1, output1_device1, output2_device1]] - # Convert this to the following structure instead: (grouped by output) - # [[output0_device0, output0_device1], - # [output1_device0, output1_device1], - # [output2_device0, output2_device1]] - last_step_tensor_outputs = [list(x) for x in - zip(*last_step_tensor_outputs)] - else: - if isinstance(replicate_outputs, list): - # Filter out any ops from the outputs, typically this would be the case - # when there were no tensor outputs. - last_step_tensor_outputs = [ - x for x in replicate_outputs if not isinstance(x, ops.Operation) - ] - - # Outputs are currently of the structure (flattened) - # [output0_device0, output1_device0, output2_device0, - # output0_device1, output1_device1, output2_device1, - # ...] - # Convert this to the following structure instead: (grouped by output) - # [[output0_device0, output0_device1], - # [output1_device0, output1_device1], - # [output2_device0, output2_device1]] - output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync - last_step_tensor_outputs = [ - last_step_tensor_outputs[i::output_num] for i in range(output_num) - ] - else: - # no tensors returned. - last_step_tensor_outputs = [] - - # Convert replicate_outputs to the original dict structure of - # last_step_outputs. - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access - output = last_step_tensor_outputs_dict[name] - # For outputs that have already been reduced, take the first value - # from the list as each value should be the same. Else return the full - # list of values. 
- # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica - # value. - if reduce_op is not None: - # TODO(priyag): Should this return the element or a list with 1 element - last_step_tensor_outputs_dict[name] = output[0] - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access - - return ctx - - def _call_for_each_replica(self, fn, args, kwargs): - # TODO(jhseu): Consider making it so call_for_each_replica implies that - # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. - with _TPUReplicaContext(self._container_strategy()): - return fn(*args, **kwargs) - - def _experimental_initialize_system(self): - """Experimental method added to be used by Estimator. - - This is a private method only to be used by Estimator. Other frameworks - should directly be calling `tf.contrib.distribute.initialize_tpu_system` - """ - initialize_tpu_system(self._tpu_cluster_resolver) - - def _create_variable(self, next_creator, *args, **kwargs): - """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" - colocate_with = kwargs.pop("colocate_with", None) - if colocate_with is None: - device_map = self._device_map - logical_device = 0 # TODO(josh11b): Get logical device from scope here. - elif isinstance(colocate_with, numpy_dataset.SingleDevice): - with ops.device(colocate_with.device): - return next_creator(*args, **kwargs) - else: - device_map = colocate_with.device_map - logical_device = colocate_with.logical_device - - def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring - value_list = [] - for i, d in enumerate(devices): - with ops.device(d): - if i > 0: - # Give replicas meaningful distinct names: - var0name = value_list[0].name.split(":")[0] - # We append a / to variable names created on replicas with id > 0 to - # ensure that we ignore the name scope and instead use the given - # name as the absolute name of the variable. 
- kwargs["name"] = "%s/replica_%d/" % (var0name, i) - # Initialize replicas with the same value: - if context.executing_eagerly(): - kwargs["initial_value"] = array_ops.identity( - value_list[0].value()) - else: - def initial_value_fn(device=d): - with ops.device(device): - return array_ops.identity(value_list[0].initial_value) - kwargs["initial_value"] = initial_value_fn - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - v = next_creator(*args, **kwargs) - assert not isinstance(v, values.TPUMirroredVariable) - value_list.append(v) - return value_list - - return _create_tpu_mirrored_variable( - self._container_strategy(), device_map, logical_device, - _real_mirrored_creator, *args, **kwargs) - - def _reduce_to(self, reduce_op, value, destinations): - if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if reduce_op == reduce_util.ReduceOp.MEAN: - # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self._num_replicas_in_sync) - elif reduce_op != reduce_util.ReduceOp.SUM: - raise NotImplementedError( - "Currently only support sum & mean in TPUStrategy.") - return tpu_ops.cross_replica_sum(value) - - if not isinstance(value, values.DistributedValues): - # This function handles reducing values that are not PerReplica or - # Mirrored values. For example, the same value could be present on all - # replicas in which case `value` would be a single value or value could - # be 0. - return cross_device_ops_lib.reduce_non_distributed_value( - reduce_op, self._device_map, value, destinations) - - # Validate that the destination is same as the host device - # Note we don't do this when in replicate context as the reduction is - # performed on the TPU device itself. 
- devices = cross_device_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - assert device_util.canonicalize(devices[0]) == device_util.canonicalize( - self._host_device) - else: - raise ValueError("Multiple devices are not supported for TPUStrategy") - - output = math_ops.add_n(value) - if reduce_op == reduce_util.ReduceOp.MEAN: - return output * (1. / len(value)) - return output - - def _update(self, var, fn, args, kwargs, group): - assert isinstance(var, values.TPUMirroredVariable) - if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if group: - return fn(var, *args, **kwargs) - else: - return (fn(var, *args, **kwargs),) - - # Otherwise, we revert to MirroredStrategy behavior and update each variable - # directly. - updates = [] - for i, (d, v) in enumerate(zip(var.devices, var.values)): - name = "update_%d" % i - with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): - # If args and kwargs are not mirrored, the value is returned as is. - updates.append(fn(v, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs))) - return values.update_regroup(self, self._device_map, updates, group) - - def read_var(self, var): - assert isinstance(var, values.TPUMirroredVariable) - return var.read_value() - - def _unwrap(self, val): - if isinstance(val, values.DistributedValues): - # Return in a deterministic order. - return tuple(val.get(device=d) for d in sorted(val.devices)) - elif isinstance(val, list): - # TODO(josh11b): We need to remove this case; per device values should - # be represented using a PerReplica wrapper instead of a list with - # one entry per device. 
- return tuple(val) - elif isinstance(val, values.TPUMirroredVariable): - # pylint: disable=protected-access - if values._enclosing_tpu_context() is not None: - return (val,) - return val.values - return (val,) - - def value_container(self, value): - return value - - def _broadcast_to(self, tensor, destinations): - del destinations - return tensor - - @property - def num_hosts(self): - if self._device_assignment is None: - return self._tpu_metadata.num_hosts - - return len(set([self._device_assignment.host_device(r) - for r in range(self._device_assignment.num_replicas)])) - - @property - def num_replicas_per_host(self): - if self._device_assignment is None: - return self._tpu_metadata.num_of_cores_per_host - - # TODO(sourabhbajaj): Remove this method we use inputs and remove infeed - # as the computation of num_replicas_per_host is not a constant - # when using device_assignment. This is a temporary workaround to support - # StatefulRNN as everything is 1 in that case. - # This method needs to take host_id as input for correct computation. 
- max_models_per_host = (self._tpu_metadata.num_of_cores_per_host // - self._device_assignment.num_cores_per_replica) - models_per_host = min(self._device_assignment.num_replicas, - max_models_per_host) - return models_per_host * self._device_assignment.num_cores_per_replica - - @property - def _num_replicas_in_sync(self): - if self._device_assignment is None: - return self._tpu_metadata.num_cores - return (self._device_assignment.num_replicas * - self._device_assignment.num_cores_per_replica) - - @property - def experimental_between_graph(self): - return False - - @property - def experimental_should_init(self): - return True - - @property - def should_checkpoint(self): - return True - - @property - def should_save_summary(self): - return True - - @property - def worker_devices(self): - return self._tpu_devices - - @property - def parameter_devices(self): - return self._tpu_devices - - def non_slot_devices(self, var_list): - return self._host_device - - def _update_non_slot(self, colocate_with, fn, args, kwargs, group): - del colocate_with - with ops.device(self._host_device), distribute_lib.UpdateContext( - self._host_device): - result = fn(*args, **kwargs) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - def get_host(self, host_id): - if self._tpu_cluster_resolver.get_master() in ("", "local"): - return "/replica:0/task:0" - job_name = self._tpu_cluster_resolver.get_job_name() or "tpu_worker" - return "/job:%s/task:%d" % (job_name, host_id) - - def get_host_cpu_device(self, host_id): - return self.get_host(host_id) + "/device:CPU:0" - - def _configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - del cluster_spec, task_type, task_id - if session_config: - session_config.CopyFrom(self._update_config_proto(session_config)) - - def _update_config_proto(self, config_proto): - updated_config = copy.deepcopy(config_proto) - updated_config.isolate_session_state = True - cluster_spec = 
self._tpu_cluster_resolver.cluster_spec() - if cluster_spec: - updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) - return updated_config - - # TODO(priyag): Delete this once all strategies use global batch size. - @property - def _global_batch_size(self): - """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. - - `distribute_dataset` and `make_input_fn_iterator` assume per-replica - batching. - - Returns: - Boolean. - """ - return True - - -class _TPUReplicaContext(distribute_lib.ReplicaContext): - """Replication Context class for TPU Strategy.""" - - # TODO(sourabhbajaj): Call for each replica should be updating this. - def __init__(self, strategy): - # TODO(b/118385803): properly initialize replica_id, instead of always 0 - replica_id = constant_op.constant(0, dtypes.int32) - distribute_lib.ReplicaContext.__init__( - self, strategy, replica_id_in_sync_group=replica_id) - - @property - def devices(self): - distribute_lib.require_replica_context(self) - ds = self._strategy - replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) - return (ds.extended.worker_devices[replica_id],) +# pylint: disable=unused-import +from tensorflow.python.distribute.tpu_strategy import TPUStrategy +from tensorflow.python.tpu.tpu_strategy_util import initialize_tpu_system diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 51c58b0b2f3dc2ab63e22718825a471b8657f892..101c76ed0af4f178a98204b661f8f0a0f62a52fb 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -511,6 +511,34 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): sess.run(variables_lib.global_variables_initializer()) sess.run({"complicated": mirrored}) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_one_cpu, + combinations.mirrored_strategy_with_gpu_and_cpu, 
+ combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.tpu_strategy, + ], + mode=["graph", "eager"])) + def testAssignOutOfScope_mirrored(self, distribution): + with distribution.scope(): + mirrored = variables_lib.Variable(1.) + if not isinstance(mirrored, values.MirroredVariable): + self.assertIsInstance(mirrored, values.TPUMirroredVariable) + self.evaluate(mirrored.assign(3.)) + self.assertEqual(self.evaluate(mirrored.read_value()), 3.) + for component in mirrored.values: + self.assertEqual(self.evaluate(component.read_value()), 3.) + + @combinations.generate(combinations.combine( + distribution=[combinations.parameter_server_strategy_with_two_gpus], + mode=["graph", "eager"])) + def testAssignOutOfScope_aggregating(self, distribution): + with distribution.scope(): + aggregating = variables_lib.Variable(1.) + self.assertIsInstance(aggregating, values.AggregatingVariable) + self.evaluate(aggregating.assign(3.)) + self.assertEqual(self.evaluate(aggregating.read_value()), 3.) + self.assertEqual(self.evaluate(aggregating._v.read_value()), 3.) 
_devices = ("/device:GPU:0", "/device:CPU:0") @@ -522,11 +550,11 @@ def _make_replica_local(method, strategy=None): with ops.device(d): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) - replica_local = values.ReplicaLocalVariable(strategy, device_map, v, method) + replica_local = values.SyncOnReadVariable(strategy, device_map, v, method) return v, replica_local -class ReplicaLocalVariablePropertiesTest(test.TestCase): +class SyncOnReadVariablePropertiesTest(test.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True @@ -549,7 +577,7 @@ class ReplicaLocalVariablePropertiesTest(test.TestCase): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)) - replica_local = values.ReplicaLocalVariable( + replica_local = values.SyncOnReadVariable( None, device_map, (v,), variable_scope.VariableAggregation.MEAN) self.assertEqual(v.name, replica_local.name) @@ -577,7 +605,7 @@ class ReplicaLocalVariablePropertiesTest(test.TestCase): combinations.mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=["graph", "eager"])) -class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): +class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): def _assign_replica_local(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -656,7 +684,8 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, replica_local = _make_replica_local("sum", distribution) + v, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. 
self._assign_replica_local(_devices, v, [1.5, 2.]) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 3079175015a9aee1625404902070df8f13b2089c..c2300286d3be4bb757dac588623c47044a1a9db5 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -822,7 +822,7 @@ cuda_py_test( cuda_py_test( name = "affine_test", - size = "large", + size = "medium", srcs = ["python/kernel_tests/bijectors/affine_test.py"], additional_deps = [ ":bijectors_py", @@ -837,7 +837,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], - shard_count = 5, + shard_count = 10, tags = ["noasan"], # times out b/63678675 ) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py index 7ae98878986eb10570b5e93a4a57d6bad6b38c0c..daab24e4333e1f0524f3016cfe367c20c7d1470a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -158,10 +159,13 @@ def vector_size_to_square_matrix_size(d, validate_args, name=None): return int(n) else: with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: - n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2. + n = (-1. + math_ops.sqrt(1 + 8. * math_ops.cast(d, dtypes.float32))) / 2. 
if validate_args: - with ops.control_dependencies([check_ops.assert_equal( - math_ops.to_float(math_ops.to_int32(n)), n, - message="Vector length is not a triangular number")]): + with ops.control_dependencies([ + check_ops.assert_equal( + math_ops.cast(math_ops.cast(n, dtypes.int32), dtypes.float32), + n, + message="Vector length is not a triangular number") + ]): n = array_ops.identity(n) return math_ops.cast(n, d.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 452628257ea96713453bf2aa32b5baa9d6d0cb86..1006dfac49f36baa7cf5136f6f2982e3fd965298 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -249,9 +249,9 @@ class InverseGamma(distribution.Distribution): `self.allow_nan_stats` is `False`, an exception will be raised rather than returning `NaN`.""") def _variance(self): - var = (math_ops.square(self.rate) - / math_ops.square(self.concentration - 1.) - / (self.concentration - 2.)) + var = ( + math_ops.square(self.rate) / math_ops.squared_difference( + self.concentration, 1.) / (self.concentration - 2.)) if self.allow_nan_stats: nan = array_ops.fill( self.batch_shape_tensor(), diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index 19e99e03803e7f4cdfdb023feb04daaba68eceed..ad0f2317c99fc482d52d1bb7a3b3f4779d2ca439 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -304,14 +304,14 @@ def percentile(x, x = ops.convert_to_tensor(x, name="x") # Double is needed here and below, else we get the wrong index if the array # is huge along axis. 
- q = math_ops.to_double(q, name="q") + q = math_ops.cast(q, dtypes.float64, name="q") _get_static_ndims(q, expect_ndims=0) if validate_args: q = control_flow_ops.with_dependencies([ check_ops.assert_rank(q, 0), - check_ops.assert_greater_equal(q, math_ops.to_double(0.)), - check_ops.assert_less_equal(q, math_ops.to_double(100.)) + check_ops.assert_greater_equal(q, math_ops.cast(0., dtypes.float64)), + check_ops.assert_less_equal(q, math_ops.cast(100., dtypes.float64)) ], q) if axis is None: @@ -336,7 +336,7 @@ def percentile(x, y = _move_dims_to_flat_end(x, axis, x_ndims) frac_at_q_or_above = 1. - q / 100. - d = math_ops.to_double(array_ops.shape(y)[-1]) + d = math_ops.cast(array_ops.shape(y)[-1], dtypes.float64) if interpolation == "lower": index = math_ops.ceil((d - 1) * frac_at_q_or_above) @@ -349,7 +349,7 @@ def percentile(x, # let's use max/min to avoid out of bounds errors. d = array_ops.shape(y)[-1] # d - 1 will be distinct from d in int32. - index = clip_ops.clip_by_value(math_ops.to_int32(index), 0, d - 1) + index = clip_ops.clip_by_value(math_ops.cast(index, dtypes.int32), 0, d - 1) # Sort everything, not just the top 'k' entries, which allows multiple calls # to sort only once (under the hood) and use CSE. 
diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py index 15b0820cbdf560e04a304c40a47e541006523b6d..b22ae1eb1543f86319dfcf2d841eb722516b1b34 100644 --- a/tensorflow/contrib/distributions/python/ops/test_util.py +++ b/tensorflow/contrib/distributions/python/ops/test_util.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import histogram_ops @@ -125,7 +126,7 @@ class DiscreteScalarDistributionTestHelpers(object): atol: Python `float`-type indicating the admissible absolute error between analytical and sample statistics. """ - x = math_ops.to_float(dist.sample(num_samples, seed=seed)) + x = math_ops.cast(dist.sample(num_samples, seed=seed), dtypes.float32) sample_mean = math_ops.reduce_mean(x, axis=0) sample_variance = math_ops.reduce_mean( math_ops.square(x - sample_mean), axis=0) @@ -180,7 +181,7 @@ class DiscreteScalarDistributionTestHelpers(object): lo = value_range[0] hi = value_range[1] if nbins is None: - nbins = math_ops.to_int32(hi - lo) + nbins = math_ops.cast(hi - lo, dtypes.int32) delta = (hi - lo) / math_ops.cast( nbins, dtype=value_range.dtype.base_dtype) edges = math_ops.range( diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 8966a9befcd3db4a3f397b319e80f37f84ad236b..d441e4735b64fe1176e77a978d281d46a7b287ab 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -144,7 +144,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", ], ) diff --git a/tensorflow/contrib/eager/python/datasets_test.py 
b/tensorflow/contrib/eager/python/datasets_test.py index 78ab155896cfeda4dd259a8529f4b1f77a12cf0b..48925b1bfacc6b59c210b2fb4b53a9a1a851673f 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,7 +37,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.training import checkpoint_management -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils class IteratorTest(test.TestCase): @@ -238,7 +238,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) dataset = dataset.map(math_ops.square).batch(2) iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) self.assertAllEqual([1, 4], iterator.get_next().numpy()) save_path = checkpoint.save(checkpoint_prefix) self.assertAllEqual([9, 16], iterator.get_next().numpy()) @@ -257,7 +257,7 @@ class IteratorTest(test.TestCase): dataset_2 = Dataset.range(10) iterator_3 = datasets.Iterator(dataset_2) - checkpoint = checkpointable_utils.Checkpoint( + checkpoint = trackable_utils.Checkpoint( iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) self.assertAllEqual([1, 4], iterator_1.get_next().numpy()) self.assertEqual(0, iterator_3.get_next().numpy()) @@ -279,7 +279,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.range(3) iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) self.assertEqual(0, iterator.get_next().numpy()) self.assertEqual(1, iterator.get_next().numpy()) save_path = checkpoint.save(checkpoint_prefix) @@ -293,7 +293,7 @@ class 
IteratorTest(test.TestCase): dataset = Dataset.range(10) for i in range(5): iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) checkpoint.restore(checkpoint_management.latest_checkpoint( checkpoint_directory)) for j in range(2): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 97c299a911c9180bf69faa0fa46527e80eada790..3e0881754c750f4d36e2e4dd8b80835b031c658c 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -6,16 +6,16 @@ package(default_visibility = ["//tensorflow:internal"]) py_library( name = "examples_pip", deps = [ - "//tensorflow/contrib/eager/python/examples/densenet", - "//tensorflow/contrib/eager/python/examples/gan:mnist", + "//tensorflow/contrib/eager/python/examples/densenet:densenet_lib", + "//tensorflow/contrib/eager/python/examples/gan:mnist_lib", "//tensorflow/contrib/eager/python/examples/l2hmc", "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", - "//tensorflow/contrib/eager/python/examples/linear_regression", + "//tensorflow/contrib/eager/python/examples/linear_regression:linear_regression_lib", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/revnet", "//tensorflow/contrib/eager/python/examples/revnet:config", - "//tensorflow/contrib/eager/python/examples/rnn_colorbot", - "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/rnn_colorbot:rnn_colorbot_lib", + "//tensorflow/contrib/eager/python/examples/rnn_ptb:rnn_ptb_lib", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index 56a682ec55a0bb07a16fe4d7d65a62f3c67c4292..fbb5daf230bb79f08a3d071062ddc0e8507ab324 
100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -9,6 +9,13 @@ py_binary( name = "densenet", srcs = ["densenet.py"], srcs_version = "PY2AND3", + deps = [":densenet_lib"], +) + +py_library( + name = "densenet_lib", + srcs = ["densenet.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -17,13 +24,14 @@ py_binary( cuda_py_test( name = "densenet_test", - size = "large", + size = "medium", srcs = ["densenet_test.py"], additional_deps = [ ":densenet", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", "optonly", @@ -33,13 +41,14 @@ cuda_py_test( cuda_py_test( name = "densenet_graph_test", - size = "large", + size = "medium", srcs = ["densenet_graph_test.py"], additional_deps = [ ":densenet", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", "noasan", diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD index d64c8eb9ce122fa277567b2fbc632abfbc72df64..d99a519112787bad664232983208279cfb4d0036 100644 --- a/tensorflow/contrib/eager/python/examples/gan/BUILD +++ b/tensorflow/contrib/eager/python/examples/gan/BUILD @@ -9,6 +9,13 @@ py_binary( name = "mnist", srcs = ["mnist.py"], srcs_version = "PY2AND3", + deps = [":mnist_lib"], +) + +py_library( + name = "mnist_lib", + srcs = ["mnist.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -20,7 +27,7 @@ cuda_py_test( name = "mnist_test", srcs = ["mnist_test.py"], additional_deps = [ - ":mnist", + ":mnist_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], @@ -30,7 +37,7 @@ cuda_py_test( name = "mnist_graph_test", srcs = ["mnist_graph_test.py"], additional_deps = [ - ":mnist", + ":mnist_lib", "//third_party/py/numpy", 
"//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb index 1a08cc0fd06516be4af5c2b0b46a3ffcf9101e95..e1a02db76f705414a34d232022f50124a5a6a3ed 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -13,11 +13,13 @@ "\n", "# Convolutional VAE: An example with tf.keras and eager\n", "\n", + "This example has moved:\n", + "\n", "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cvae.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { @@ -28,604 +30,14 @@ }, "source": [ "![evolution of output during training](https://tensorflow.org/images/autoencoders/cvae.gif)\n", - "\n", - "This notebook demonstrates how to generate images of 
handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. (VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n", "\n" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "P-JuIu2N_SQf" - }, - "outputs": [], - "source": [ - "# to generate gifs\n", - "!pip install imageio" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "source": [ - "## Import TensorFlow and enable Eager execution" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "YfIk2es3hJEd" - }, - "outputs": [], - "source": [ - "from __future__ import absolute_import, division, print_function\n", - "\n", - "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", - "import tensorflow as tf\n", - "tfe = tf.contrib.eager\n", - "tf.enable_eager_execution()\n", - "\n", - "import os\n", - "import time\n", - "import numpy as np\n", - "import glob\n", - "import matplotlib.pyplot as plt\n", - "import PIL\n", - "import imageio\n", - "from IPython import display" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "source": [ - "## Load the MNIST dataset\n", - "Each MNIST image is originally a vector of 784 integers, each of which is between 0-255 and represents the intensity of a pixel. We model each pixel with a Bernoulli distribution in our model, and we statically binarize the dataset." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "a4fYMGxGhrna" - }, - "outputs": [], - "source": [ - "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "NFC2ghIdiZYE" - }, - "outputs": [], - "source": [ - "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')\n", - "\n", - "# Normalizing the images to the range of [0., 1.]\n", - "train_images /= 255.\n", - "test_images /= 255.\n", - "\n", - "# Binarization\n", - "train_images[train_images \u003e= .5] = 1.\n", - "train_images[train_images \u003c .5] = 0.\n", - "test_images[test_images \u003e= .5] = 1.\n", - "test_images[test_images \u003c .5] = 0." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "S4PIDhoDLbsZ" - }, - "outputs": [], - "source": [ - "TRAIN_BUF = 60000\n", - "BATCH_SIZE = 100\n", - "\n", - "TEST_BUF = 10000" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "source": [ - "## Use *tf.data* to create batches and shuffle the dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "-yKCCQOoJ7cn" - }, - "outputs": [], - "source": [ - "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n", - "test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "source": [ - "## Wire up the generative and inference network with *tf.keras.Sequential*\n", - "\n", - "In our VAE example, we use two small ConvNets for the generative and inference network. Since these neural nets are small, we use `tf.keras.Sequential` to simplify our code. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions. \n", - "\n", - "### Generative Network\n", - "This defines the generative model which takes a latent encoding as input, and outputs the parameters for a conditional distribution of the observation, i.e. $p(x|z)$. Additionally, we use a unit Gaussian prior $p(z)$ for the latent variable.\n", - "\n", - "### Inference Network\n", - "This defines an approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for the conditional distribution of the latent representation. 
In this example, we simply model this distribution as a diagonal Gaussian. In this case, the inference network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance instead of the variance directly is for numerical stability).\n", - "\n", - "### Reparameterization Trick\n", - "During optimization, we can sample from $q(z|x)$ by first sampling from a unit Gaussian, and then multiplying by the standard deviation and adding the mean. This ensures the gradients could pass through the sample to the inference network parameters.\n", - "\n", - "### Network architecture\n", - "For the inference network, we use two convolutional layers followed by a fully-connected layer. In the generative network, we mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "VGLbvBEmjK0a" - }, - "outputs": [], - "source": [ - "class CVAE(tf.keras.Model):\n", - " def __init__(self, latent_dim):\n", - " super(CVAE, self).__init__()\n", - " self.latent_dim = latent_dim\n", - " self.inference_net = tf.keras.Sequential(\n", - " [\n", - " tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n", - " tf.keras.layers.Conv2D(\n", - " filters=32, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", - " tf.keras.layers.Conv2D(\n", - " filters=64, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", - " tf.keras.layers.Flatten(),\n", - " # No activation\n", - " tf.keras.layers.Dense(latent_dim + latent_dim),\n", - " ]\n", - " )\n", - "\n", - " self.generative_net = tf.keras.Sequential(\n", - " [\n", - " tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n", - " tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n", - " tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=64,\n", - " kernel_size=3,\n", - " strides=(2, 2),\n", - " padding=\"SAME\",\n", - " activation=tf.nn.relu),\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=32,\n", - " kernel_size=3,\n", - " strides=(2, 2),\n", - " padding=\"SAME\",\n", - " activation=tf.nn.relu),\n", - " # No activation\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\"),\n", - " ]\n", - " )\n", - "\n", - " def sample(self, eps=None):\n", - " if eps is None:\n", - " eps = tf.random_normal(shape=(100, self.latent_dim))\n", - " return self.decode(eps, apply_sigmoid=True)\n", - "\n", - " def encode(self, x):\n", - " mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)\n", - " return mean, logvar\n", - "\n", - " def reparameterize(self, mean, logvar):\n", - " eps = 
tf.random_normal(shape=mean.shape)\n", - " return eps * tf.exp(logvar * .5) + mean\n", - "\n", - " def decode(self, z, apply_sigmoid=False):\n", - " logits = self.generative_net(z)\n", - " if apply_sigmoid:\n", - " probs = tf.sigmoid(logits)\n", - " return probs\n", - "\n", - " return logits" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "source": [ - "## Define the loss function and the optimizer\n", - "\n", - "VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:\n", - "\n", - "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n", - "\n", - "In practice, we optimize the single sample Monte Carlo estimate of this expectation:\n", - "\n", - "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$\n", - "where $z$ is sampled from $q(z|x)$.\n", - "\n", - "**Note**: we could also analytically compute the KL term, but here we incorporate all three terms in the Monte Carlo estimator for simplicity." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "iWCn_PVdEJZ7" - }, - "outputs": [], - "source": [ - "def log_normal_pdf(sample, mean, logvar, raxis=1):\n", - " log2pi = tf.log(2. * np.pi)\n", - " return tf.reduce_sum(\n", - " -.5 * ((sample - mean) ** 2. 
* tf.exp(-logvar) + logvar + log2pi),\n", - " axis=raxis)\n", - "\n", - "def compute_loss(model, x):\n", - " mean, logvar = model.encode(x)\n", - " z = model.reparameterize(mean, logvar)\n", - " x_logit = model.decode(z)\n", - "\n", - " cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n", - " logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n", - " logpz = log_normal_pdf(z, 0., 0.)\n", - " logqz_x = log_normal_pdf(z, mean, logvar)\n", - " return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n", - "\n", - "def compute_gradients(model, x):\n", - " with tf.GradientTape() as tape:\n", - " loss = compute_loss(model, x)\n", - " return tape.gradient(loss, model.trainable_variables), loss\n", - "\n", - "optimizer = tf.train.AdamOptimizer(1e-4)\n", - "def apply_gradients(optimizer, gradients, variables, global_step=None):\n", - " optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "source": [ - "## Training\n", - "\n", - "* We start by iterating over the dataset\n", - "* During each iteration, we pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$\n", - "* We then apply the *reparameterization trick* to sample from $q(z|x)$\n", - "* Finally, we pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$\n", - "* **Note:** Since we use the dataset loaded by keras with 60k datapoints in the training set and 10k datapoints in the test set, our resulting ELBO on the test set is slightly higher than reported results in the literature which uses dynamic binarization of Larochelle's MNIST.\n", - "\n", - "## Generate Images\n", - "\n", - "* After training, it is time to generate some images\n", - "* We start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$\n", - "* The 
generator will then convert the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$\n", - "* Here we plot the probabilities of Bernoulli distributions\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "NS2GWywBbAWo" - }, - "outputs": [], - "source": [ - "epochs = 100\n", - "latent_dim = 50\n", - "num_examples_to_generate = 16\n", - "\n", - "# keeping the random vector constant for generation (prediction) so\n", - "# it will be easier to see the improvement.\n", - "random_vector_for_generation = tf.random_normal(\n", - " shape=[num_examples_to_generate, latent_dim])\n", - "model = CVAE(latent_dim)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RmdVsmvhPxyy" - }, - "outputs": [], - "source": [ - "def generate_and_save_images(model, epoch, test_input):\n", - " predictions = model.sample(test_input)\n", - " fig = plt.figure(figsize=(4,4))\n", - "\n", - " for i in range(predictions.shape[0]):\n", - " plt.subplot(4, 4, i+1)\n", - " plt.imshow(predictions[i, :, :, 0], cmap='gray')\n", - " plt.axis('off')\n", - "\n", - " # tight_layout minimizes the overlap between 2 sub-plots\n", - " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "2M7LmLtGEMQJ" - }, - "outputs": [], - "source": [ - "generate_and_save_images(model, 0, random_vector_for_generation)\n", - "\n", - "for epoch in range(1, epochs + 1):\n", - " start_time = time.time()\n", - " for train_x in train_dataset:\n", - " gradients, loss = compute_gradients(model, train_x)\n", - " apply_gradients(optimizer, 
gradients, model.trainable_variables)\n", - " end_time = time.time()\n", - "\n", - " if epoch % 1 == 0:\n", - " loss = tfe.metrics.Mean()\n", - " for test_x in test_dataset:\n", - " loss(compute_loss(model, test_x))\n", - " elbo = -loss.result()\n", - " display.clear_output(wait=False)\n", - " print('Epoch: {}, Test set ELBO: {}, '\n", - " 'time elapse for current epoch {}'.format(epoch,\n", - " elbo,\n", - " end_time - start_time))\n", - " generate_and_save_images(\n", - " model, epoch, random_vector_for_generation)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "P4M_vIbUi7c0" - }, - "source": [ - "### Display an image using the epoch number" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "WfO5wCdclHGL" - }, - "outputs": [], - "source": [ - "def display_image(epoch_no):\n", - " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "5x3q9_Oe5q0A" - }, - "outputs": [], - "source": [ - "display_image(epochs) # Display images" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "NywiH3nL8guF" - }, - "source": [ - "### Generate a GIF of all the saved images." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "IGKQgENQ8lEI" - }, - "outputs": [], - "source": [ - "with imageio.get_writer('cvae.gif', mode='I') as writer:\n", - " filenames = glob.glob('image*.png')\n", - " filenames = sorted(filenames)\n", - " last = -1\n", - " for i,filename in enumerate(filenames):\n", - " frame = 2*(i**0.5)\n", - " if round(frame) \u003e round(last):\n", - " last = frame\n", - " else:\n", - " continue\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " \n", - "# this is a hack to display the gif inside the notebook\n", - "os.system('cp cvae.gif cvae.gif.png')" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "uV0yiKpzNP1b" - }, - "outputs": [], - "source": [ - "display.Image(filename=\"cvae.gif.png\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "yQXO_dlXkKsT" - }, - "source": [ - "To downlod the animation from Colab uncomment the code below:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "4fSJS3m5HLFM" - }, - "outputs": [], - "source": [ - "#from google.colab import files\n", - "#files.download('cvae.gif')" - ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], - "default_view": {}, "name": "cvae.ipynb", "private_outputs": true, "provenance": [ @@ -635,8 +47,7 @@ } ], "toc_visible": true, - "version": "0.3.2", - "views": {} + "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", diff --git 
a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb index 78fcd397087fd1fd64aebed7ac3b5c6b2f45c450..53767058838459e56215d286e9f8f8eb66287147 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -1,26 +1,11 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "dcgan.ipynb", - "version": "0.3.2", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python2", - "display_name": "Python 2" - }, - "accelerator": "GPU" - }, "cells": [ { + "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0TD5ZrvEMbhZ" }, - "cell_type": "markdown", "source": [ "**Copyright 2018 The TensorFlow Authors**.\n", "\n", @@ -28,851 +13,39 @@ "\n", "# Generating Handwritten Digits with DCGAN\n", "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "ITZuApL56Mny" - }, - "cell_type": "markdown", - "source": [ - "This tutorial demonstrates how to generate images of handwritten digits using a Deep Convolutional Generative Adversarial Network ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)). The code is written in [tf.keras](https://www.tensorflow.org/programmers_guide/keras) with [eager execution](https://www.tensorflow.org/programmers_guide/eager) enabled. " - ] - }, - { - "metadata": { - "colab_type": "toc", - "id": "x2McrO9bMyLN" - }, - "cell_type": "markdown", - "source": [ - ">[Generating Handwritten Digits with DCGAN](#scrollTo=0TD5ZrvEMbhZ)\n", - "\n", - ">>[What are GANs?](#scrollTo=2MbKJY38Puy9)\n", - "\n", - ">>>[Import TensorFlow and enable eager execution](#scrollTo=e1_Y75QXJS6h)\n", - "\n", - ">>>[Load the dataset](#scrollTo=iYn4MdZnKCey)\n", - "\n", - ">>>[Use tf.data to create batches and shuffle the dataset](#scrollTo=PIGN6ouoQxt3)\n", - "\n", - ">>[Create the models](#scrollTo=THY-sZMiQ4UV)\n", - "\n", - ">>>[The Generator Model](#scrollTo=-tEyxE-GMC48)\n", - "\n", - ">>>[The Discriminator model](#scrollTo=D0IKnaCtg6WE)\n", - "\n", - ">>[Define the loss functions and the optimizer](#scrollTo=0FMYgY_mPfTi)\n", - "\n", - ">>>[Generator loss](#scrollTo=Jd-3GCUEiKtv)\n", - "\n", - ">>>[Discriminator loss](#scrollTo=PKY_iPSPNWoj)\n", - "\n", - ">>[Set up GANs for Training](#scrollTo=Rw1fkAczTQYh)\n", - "\n", - ">>[Train the GANs](#scrollTo=dZrd4CdjR-Fp)\n", - "\n", - ">>[Generated images](#scrollTo=P4M_vIbUi7c0)\n", + "This example has moved.\n", "\n", - ">>[Learn more about GANs](#scrollTo=k6qC-SbjK0yW)\n", - "\n" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/dcgan.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" 
/\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { + "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "2MbKJY38Puy9" }, - "cell_type": "markdown", "source": [ - "## What are GANs?\n", - "GANs, or [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661), are a framework for estimating generative models. Two models are trained simultaneously by an adversarial process: a Generator, which is responsible for generating data (say, images), and a Discriminator, which is responsible for estimating the probability that an image was drawn from the training data (the image is real), or was produced by the Generator (the image is fake). During training, the Generator becomes progressively better at generating images, until the Discriminator is no longer able to distinguish real images from fake. \n", - "\n", - "![alt text](https://github.com/margaretmz/tensorflow/blob/margaret-dcgan/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png?raw=1)\n", - "\n", - "We will demonstrate this process end-to-end on MNIST. Below is an animation that shows a series of images produced by the Generator as it was trained for 50 epochs. Overtime, the generated images become increasingly difficult to distinguish from the training set.\n", - "\n", - "To learn more about GANs, we recommend MIT's [Intro to Deep Learning](http://introtodeeplearning.com/) course, which includes a lecture on Deep Generative Models ([video](https://youtu.be/JVb54xhEw6Y) | [slides](http://introtodeeplearning.com/materials/2018_6S191_Lecture4.pdf)).
Now, let's head to the code!\n", - "\n", "![sample output](https://tensorflow.org/images/gan/dcgan.gif)" ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "dcgan.ipynb", + "provenance": [], + "version": "0.3.2" }, - { - "metadata": { - "colab_type": "code", - "id": "u_2z-B3piVsw", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Install imgeio in order to generate an animated gif showing the image generating process\n", - "!pip install imageio" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "cell_type": "markdown", - "source": [ - "### Import TensorFlow and enable eager execution" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "YfIk2es3hJEd", - "colab": {} - }, - "cell_type": "code", - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "import glob\n", - "import imageio\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import os\n", - "import PIL\n", - "import time\n", - "\n", - "from IPython import display" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "cell_type": "markdown", - "source": [ - "### Load the dataset\n", - "\n", - "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will generate handwritten digits resembling the MNIST data." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "a4fYMGxGhrna", - "colab": {} - }, - "cell_type": "code", - "source": [ - "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "NFC2ghIdiZYE", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "S4PIDhoDLbsZ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "BUFFER_SIZE = 60000\n", - "BATCH_SIZE = 256" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "cell_type": "markdown", - "source": [ - "### Use tf.data to create batches and shuffle the dataset" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "-yKCCQOoJ7cn", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "cell_type": "markdown", - "source": [ - "## Create the models\n", - "\n", - "We will use tf.keras [Sequential API](https://www.tensorflow.org/guide/keras#sequential_model) to define the generator and discriminator models." - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "-tEyxE-GMC48" - }, - "cell_type": "markdown", - "source": [ - "### The Generator Model\n", - "\n", - "The generator is responsible for creating convincing images that are good enough to fool the discriminator. 
The network architecture for the generator consists of [Conv2DTranspose](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2DTranspose) (Upsampling) layers. We start with a fully connected layer and upsample the image two times in order to reach the desired image size of 28x28x1. We increase the width and height, and reduce the depth as we move through the layers in the network. We use [Leaky ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LeakyReLU) activation for each layer except for the last one where we use a tanh activation." - ] - }, - { - "metadata": { - "id": "6bpTcDqoLWjY", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def make_generator_model():\n", - " model = tf.keras.Sequential()\n", - " model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))\n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " \n", - " model.add(tf.keras.layers.Reshape((7, 7, 256)))\n", - " assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size\n", - " \n", - " model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))\n", - " assert model.output_shape == (None, 7, 7, 128) \n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - "\n", - " model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))\n", - " assert model.output_shape == (None, 14, 14, 64) \n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - "\n", - " model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))\n", - " assert model.output_shape == (None, 28, 28, 1)\n", - " \n", - " return model" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": 
"D0IKnaCtg6WE" - }, - "cell_type": "markdown", - "source": [ - "### The Discriminator model\n", - "\n", - "The discriminator is responsible for distinguishing fake images from real images. It's similar to a regular CNN-based image classifier." - ] - }, - { - "metadata": { - "id": "dw2tPLmk2pEP", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def make_discriminator_model():\n", - " model = tf.keras.Sequential()\n", - " model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " model.add(tf.keras.layers.Dropout(0.3))\n", - " \n", - " model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " model.add(tf.keras.layers.Dropout(0.3))\n", - " \n", - " model.add(tf.keras.layers.Flatten())\n", - " model.add(tf.keras.layers.Dense(1))\n", - " \n", - " return model" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "gDkA05NE6QMs", - "colab": {} - }, - "cell_type": "code", - "source": [ - "generator = make_generator_model()\n", - "discriminator = make_discriminator_model()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "cell_type": "markdown", - "source": [ - "## Define the loss functions and the optimizer\n", - "\n", - "Let's define the loss functions and the optimizers for the generator and the discriminator.\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "Jd-3GCUEiKtv" - }, - "cell_type": "markdown", - "source": [ - "### Generator loss\n", - "The generator loss is a sigmoid cross entropy loss of the generated images and an array of ones, since the generator is trying to generate fake images that resemble the real images." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "90BIcCKcDMxz", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def generator_loss(generated_output):\n", - " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "PKY_iPSPNWoj" - }, - "cell_type": "markdown", - "source": [ - "### Discriminator loss\n", - "\n", - "The discriminator loss function takes two inputs: real images, and generated images. Here is how to calculate the discriminator loss:\n", - "1. Calculate real_loss which is a sigmoid cross entropy loss of the real images and an array of ones (since these are the real images).\n", - "2. Calculate generated_loss which is a sigmoid cross entropy loss of the generated images and an array of zeros (since these are the fake images).\n", - "3. Calculate the total_loss as the sum of real_loss and generated_loss." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "wkMNfBWlT-PV", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def discriminator_loss(real_output, generated_output):\n", - " # [1,1,...,1] with real output since it is true and we want our generated examples to look like it\n", - " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n", - "\n", - " # [0,0,...,0] with generated images since they are fake\n", - " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)\n", - "\n", - " total_loss = real_loss + generated_loss\n", - "\n", - " return total_loss" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "MgIc7i0th_Iu" - }, - "cell_type": "markdown", - "source": [ - "The discriminator and the generator optimizers are different since we will train two networks separately." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "iWCn_PVdEJZ7", - "colab": {} - }, - "cell_type": "code", - "source": [ - "generator_optimizer = tf.train.AdamOptimizer(1e-4)\n", - "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "mWtinsGDPJlV" - }, - "cell_type": "markdown", - "source": [ - "**Checkpoints (Object-based saving)**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "CA1w-7s2POEy", - "colab": {} - }, - "cell_type": "code", - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", - " discriminator_optimizer=discriminator_optimizer,\n", - " generator=generator,\n", - " discriminator=discriminator)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "cell_type": "markdown", - "source": [ - "## Set up GANs for Training\n", - "\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "5QC5BABamh_c" - }, - "cell_type": "markdown", - "source": [ - "Now it's time to put together the generator and discriminator to set up the Generative Adversarial Networks, as you see in the diagam at the beginning of the tutorial." 
- ] - }, - { - "metadata": { - "colab_type": "text", - "id": "Ff6oN6PZX27n" - }, - "cell_type": "markdown", - "source": [ - "**Define training parameters**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "NS2GWywBbAWo", - "colab": {} - }, - "cell_type": "code", - "source": [ - "EPOCHS = 50\n", - "noise_dim = 100\n", - "num_examples_to_generate = 16\n", - "\n", - "# We'll re-use this random vector used to seed the generator so\n", - "# it will be easier to see the improvement over time.\n", - "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n", - " noise_dim])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "jylSonrqSWfi" - }, - "cell_type": "markdown", - "source": [ - "**Define training method**\n", - "\n", - "We start by iterating over the dataset. The generator is given a random vector as an input which is processed to output an image looking like a handwritten digit. The discriminator is then shown the real MNIST images as well as the generated images.\n", - "\n", - "Next, we calculate the generator and the discriminator loss. Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables." 
- ] - }, - { - "metadata": { - "id": "3t5ibNo05jCB", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def train_step(images):\n", - " # generating noise from a normal distribution\n", - " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", - " \n", - " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", - " generated_images = generator(noise, training=True)\n", - " \n", - " real_output = discriminator(images, training=True)\n", - " generated_output = discriminator(generated_images, training=True)\n", - " \n", - " gen_loss = generator_loss(generated_output)\n", - " disc_loss = discriminator_loss(real_output, generated_output)\n", - " \n", - " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n", - " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", - " \n", - " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", - " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "6TSZgwc2BUQ-" - }, - "cell_type": "markdown", - "source": [ - "\n", - "This model takes about ~30 seconds per epoch to train on a single Tesla K80 on Colab, as of October 2018. \n", - "\n", - "Eager execution can be slower than executing the equivalent graph as it can't benefit from whole-program optimizations on the graph, and also incurs overheads of interpreting Python code. By using [tf.contrib.eager.defun](https://www.tensorflow.org/api_docs/python/tf/contrib/eager/defun) to create graph functions, we get a ~20 secs/epoch performance boost (from ~50 secs/epoch down to ~30 secs/epoch). This way we get the best of both eager execution (easier for debugging) and graph mode (better performance)." 
- ] - }, - { - "metadata": { - "id": "Iwya07_j5p2A", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_step = tf.contrib.eager.defun(train_step)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "2M7LmLtGEMQJ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def train(dataset, epochs): \n", - " for epoch in range(epochs):\n", - " start = time.time()\n", - " \n", - " for images in dataset:\n", - " train_step(images)\n", - "\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", - " epoch + 1,\n", - " random_vector_for_generation)\n", - " \n", - " # saving (checkpoint) the model every 15 epochs\n", - " if (epoch + 1) % 15 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - " \n", - " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", - " time.time()-start))\n", - " # generating after the final epoch\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", - " epochs,\n", - " random_vector_for_generation)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "2aFF7Hk3XdeW" - }, - "cell_type": "markdown", - "source": [ - "**Generate and save images**\n", - "\n" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "RmdVsmvhPxyy", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def generate_and_save_images(model, epoch, test_input):\n", - " # make sure the training parameter is set to False because we\n", - " # don't want to train the batchnorm layer when doing inference.\n", - " predictions = model(test_input, training=False)\n", - "\n", - " fig = plt.figure(figsize=(4,4))\n", - " \n", - " for i in range(predictions.shape[0]):\n", - " plt.subplot(4, 4, i+1)\n", - " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", - " plt.axis('off')\n", - " \n", - " 
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "dZrd4CdjR-Fp" - }, - "cell_type": "markdown", - "source": [ - "## Train the GANs\n", - "We will call the train() method defined above to train the generator and discriminator simultaneously. Note, training GANs can be tricky. It's important that the generator and discriminator do not overpower each other (e.g., that they train at a similar rate).\n", - "\n", - "At the beginning of the training, the generated images look like random noise. As training progresses, you can see the generated digits look increasingly real. After 50 epochs, they look very much like the MNIST digits." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "Ly3UN0SLLY2l", - "colab": {} - }, - "cell_type": "code", - "source": [ - "%%time\n", - "train(train_dataset, EPOCHS)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "rfM4YcPVPkNO" - }, - "cell_type": "markdown", - "source": [ - "**Restore the latest checkpoint**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "XhXsd0srPo8c", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "P4M_vIbUi7c0" - }, - "cell_type": "markdown", - "source": [ - "## Generated images \n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "mLskt7EfXAjr" - }, - "cell_type": "markdown", - "source": [ - "\n", - "After training, its time to generate some images! 
\n", - "The last step is to plot the generated images and voila!\n" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "WfO5wCdclHGL", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Display a single image using the epoch number\n", - "def display_image(epoch_no):\n", - " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "5x3q9_Oe5q0A", - "colab": {} - }, - "cell_type": "code", - "source": [ - "display_image(EPOCHS)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "NywiH3nL8guF" - }, - "cell_type": "markdown", - "source": [ - "**Generate a GIF of all the saved images**\n", - "\n", - "We will use imageio to create an animated gif using all the images saved during training." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "IGKQgENQ8lEI", - "colab": {} - }, - "cell_type": "code", - "source": [ - "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", - " filenames = glob.glob('image*.png')\n", - " filenames = sorted(filenames)\n", - " last = -1\n", - " for i,filename in enumerate(filenames):\n", - " frame = 2*(i**0.5)\n", - " if round(frame) > round(last):\n", - " last = frame\n", - " else:\n", - " continue\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " \n", - "# this is a hack to display the gif inside the notebook\n", - "os.system('cp dcgan.gif dcgan.gif.png')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "cGhC3-fMWSwl" - }, - "cell_type": "markdown", - "source": [ - "Display the animated gif with all the mages generated during the training of GANs." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "uV0yiKpzNP1b", - "colab": {} - }, - "cell_type": "code", - "source": [ - "display.Image(filename=\"dcgan.gif.png\")" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "6EEG-wePkmJQ" - }, - "cell_type": "markdown", - "source": [ - "**Download the animated gif**\n", - "\n", - "Uncomment the code below to download an animated gif from Colab." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "4UJjSnIMOzOJ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "#from google.colab import files\n", - "#files.download('dcgan.gif')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "k6qC-SbjK0yW" - }, - "cell_type": "markdown", - "source": [ - "## Learn more about GANs\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "xjjkT9KAK6H7" - }, - "cell_type": "markdown", - "source": [ - "We hope this tutorial was helpful! As a next step, you might like to experiment with a different dataset, for example the Large-scale Celeb Faces Attributes (CelebA) dataset [available on Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset/home).\n", - "\n", - "To learn more about GANs:\n", - "\n", - "* Check out MIT's lecture (linked above), or [this](http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture12.pdf) lecture form Stanford's CS231n. 
\n", - "\n", - "* We also recommend the [CVPR 2018 Tutorial on GANs](https://sites.google.com/view/cvpr2018tutorialongans/), and the [NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/abs/1701.00160).\n" - ] + "kernelspec": { + "display_name": "Python 2", + "name": "python2" } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png b/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png deleted file mode 100644 index b715bd83ef117641c6429e0ac173dbe9b8d5fd88..0000000000000000000000000000000000000000 Binary files a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png and /dev/null differ diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 12c5eff2b4aa901bdab52bf545e95b1e4dce7468..979772acd3f823a8cc53ab5e026946ad3bb19353 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1,1174 +1,71 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "K2s1A9eLRPEj" - }, - "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\").\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Cffg2i257iMS" - }, - "source": [ - "# Image Captioning with Attention\n", - "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "QASbY_HGo4Lq" - }, - "source": [ - "Image captioning is the task of generating a caption for an image. Given an image like this:\n", - "\n", - "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", - "\n", - "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", - "\n", - "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", - "\n", - "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", - "\n", - "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", - "\n", - "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", - "\n", - "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", - "\n", - "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", - "\n", - "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "U8l4RJ0XRPEm" - }, - "outputs": [], - "source": [ - "# Import TensorFlow and enable eager execution\n", - "# This code requires TensorFlow version >=1.9\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "# We'll generate plots of attention in order to see which parts of an image\n", - "# our model focuses on during captioning\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Scikit-learn includes many helpful utilities\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.utils import shuffle\n", - "\n", - "import re\n", - "import numpy as np\n", - "import os\n", - "import time\n", - "import json\n", - "from glob import glob\n", - "from PIL import Image\n", - "import pickle" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "b6qbGw8MRPE5" - }, - "source": [ - "## Download and prepare the MS-COCO dataset\n", - "\n", - "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", - "\n", - "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "krQuPYTtRPE7" - }, - "outputs": [], - "source": [ - "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", - " extract = True)\n", - "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", - "\n", - "name_of_zip = 'train2014.zip'\n", - "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", - " image_zip = tf.keras.utils.get_file(name_of_zip, \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", - " extract = True)\n", - " PATH = os.path.dirname(image_zip)+'/train2014/'\n", - "else:\n", - " PATH = os.path.abspath('.')+'/train2014/'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "aANEzb5WwSzg" - }, - "source": [ - "## Optionally, limit the size of the training set for faster training\n", - "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "4G3b8x8_RPFD" - }, - "outputs": [], - "source": [ - "# read the json file\n", - "with open(annotation_file, 'r') as f:\n", - " annotations = json.load(f)\n", - "\n", - "# storing the captions and the image name in vectors\n", - "all_captions = []\n", - "all_img_name_vector = []\n", - "\n", - "for annot in annotations['annotations']:\n", - " caption = ' ' + annot['caption'] + ' '\n", - " image_id = annot['image_id']\n", - " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", - " \n", - " all_img_name_vector.append(full_coco_image_path)\n", - " all_captions.append(caption)\n", - "\n", - "# shuffling the captions and image_names together\n", - "# setting a random state\n", - "train_captions, img_name_vector = shuffle(all_captions,\n", - " all_img_name_vector,\n", - " random_state=1)\n", - "\n", - "# selecting the first 30000 captions from the shuffled set\n", - "num_examples = 30000\n", - "train_captions = train_captions[:num_examples]\n", - "img_name_vector = img_name_vector[:num_examples]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "mPBMgK34RPFL" - }, - "outputs": [], - "source": [ - "len(train_captions), len(all_captions)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "8cSW4u-ORPFQ" - }, - "source": [ - "## Preprocess the images using InceptionV3\n", - "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. 
\n", - "\n", - "First, we will need to convert the images into the format inceptionV3 expects by:\n", - "* Resizing the image to (299, 299)\n", - "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "zXR0217aRPFR" - }, - "outputs": [], - "source": [ - "def load_image(image_path):\n", - " img = tf.read_file(image_path)\n", - " img = tf.image.decode_jpeg(img, channels=3)\n", - " img = tf.image.resize_images(img, (299, 299))\n", - " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", - " return img, image_path" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "MDvIu4sXRPFV" - }, - "source": [ - "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", - "\n", - "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", - "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", - "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", - "* We avoid doing this during training so it does not become a bottleneck. \n", - "* After all the images are passed through the network, we pickle the dictionary and save it to disk." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RD3vW4SsRPFW" - }, - "outputs": [], - "source": [ - "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", - " weights='imagenet')\n", - "new_input = image_model.input\n", - "hidden_layer = image_model.layers[-1].output\n", - "\n", - "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "rERqlR3WRPGO" - }, - "source": [ - "## Caching the features extracted from InceptionV3\n", - "\n", - "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", - "\n", - "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", - "\n", - "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", - "\n", - "```for img, path in image_dataset:``` \n", - "\n", - "to:\n", - "\n", - "```for img, path in tqdm(image_dataset):```." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Dx_fvbVgRPGQ" - }, - "outputs": [], - "source": [ - "# getting the unique images\n", - "encode_train = sorted(set(img_name_vector))\n", - "\n", - "# feel free to change the batch_size according to your system configuration\n", - "image_dataset = tf.data.Dataset.from_tensor_slices(\n", - " encode_train).map(load_image).batch(16)\n", - "\n", - "for img, path in image_dataset:\n", - " batch_features = image_features_extract_model(img)\n", - " batch_features = tf.reshape(batch_features, \n", - " (batch_features.shape[0], -1, batch_features.shape[3]))\n", - "\n", - " for bf, p in zip(batch_features, path):\n", - " path_of_feature = p.numpy().decode(\"utf-8\")\n", - " np.save(path_of_feature, bf.numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "nyqH3zFwRPFi" - }, - "source": [ - "## Preprocess and tokenize the captions\n", - "\n", - "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", - "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", - "* Finally, we create a word --> index mapping and vice-versa.\n", - "* We will then pad all sequences to the be same length as the longest one. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "HZfK8RhQRPFj" - }, - "outputs": [], - "source": [ - "# This will find the maximum length of any caption in our dataset\n", - "def calc_max_length(tensor):\n", - " return max(len(t) for t in tensor)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "oJGE34aiRPFo" - }, - "outputs": [], - "source": [ - "# The steps above is a general process of dealing with text processing\n", - "\n", - "# choosing the top 5000 words from the vocabulary\n", - "top_k = 5000\n", - "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", - " oov_token=\"\", \n", - " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", - "tokenizer.fit_on_texts(train_captions)\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "8Q44tNQVRPFt" - }, - "outputs": [], - "source": [ - "tokenizer.word_index[''] = 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "0fpJb5ojRPFv" - }, - "outputs": [], - "source": [ - "# creating the tokenized vectors\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "AidglIZVRPF4" - }, - "outputs": [], - "source": [ - "# padding each vector to the max_length of the captions\n", - "# if the max_length 
parameter is not provided, pad_sequences calculates that automatically\n", - "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "gL0wkttkRPGA" - }, - "outputs": [], - "source": [ - "# calculating the max_length \n", - "# used to store the attention weights\n", - "max_length = calc_max_length(train_seqs)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "M3CD75nDpvTI" - }, - "source": [ - "## Split the data into training and testing" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "iS7DDMszRPGF" - }, - "outputs": [], - "source": [ - "# Create training and validation sets using 80-20 split\n", - "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", - " cap_vector, \n", - " test_size=0.2, \n", - " random_state=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "XmViPkRFRPGH" - }, - "outputs": [], - "source": [ - "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "uEWM9xrYcg45" - }, - "source": [ - "## Our images and captions are ready! 
Next, let's create a tf.data dataset to use for training our model.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Q3TnZ1ToRPGV" - }, - "outputs": [], - "source": [ - "# feel free to change these parameters according to your system's configuration\n", - "\n", - "BATCH_SIZE = 64\n", - "BUFFER_SIZE = 1000\n", - "embedding_dim = 256\n", - "units = 512\n", - "vocab_size = len(tokenizer.word_index)\n", - "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", - "# these two variables represent that\n", - "features_shape = 2048\n", - "attention_features_shape = 64" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "SmZS2N0bXG3T" - }, - "outputs": [], - "source": [ - "# loading the numpy files \n", - "def map_func(img_name, cap):\n", - " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", - " return img_tensor, cap" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "FDF_Nm3tRPGZ" - }, - "outputs": [], - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", - "\n", - "# using map to load the numpy files in parallel\n", - "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", - "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", - "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", - " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", - "\n", - "# shuffling and batching\n", - "dataset = dataset.shuffle(BUFFER_SIZE)\n", - "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", - "dataset 
= dataset.batch(BATCH_SIZE)\n", - "dataset = dataset.prefetch(1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "nrvoDphgRPGd" - }, - "source": [ - "## Model\n", - "\n", - "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", - "\n", - "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", - "* We squash that to a shape of (64, 2048).\n", - "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", - "* The RNN(here GRU) attends over the image to predict the next word." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "AAppCGLKRPGd" - }, - "outputs": [], - "source": [ - "def gru(units):\n", - " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", - " # significant speedup).\n", - " if tf.test.is_gpu_available():\n", - " return tf.keras.layers.CuDNNGRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " return tf.keras.layers.GRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "ja2LFTMSdeV3" - }, - "outputs": [], - 
"source": [ - "class BahdanauAttention(tf.keras.Model):\n", - " def __init__(self, units):\n", - " super(BahdanauAttention, self).__init__()\n", - " self.W1 = tf.keras.layers.Dense(units)\n", - " self.W2 = tf.keras.layers.Dense(units)\n", - " self.V = tf.keras.layers.Dense(1)\n", - " \n", - " def call(self, features, hidden):\n", - " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", - " \n", - " # hidden shape == (batch_size, hidden_size)\n", - " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", - " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", - " \n", - " # score shape == (batch_size, 64, hidden_size)\n", - " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", - " \n", - " # attention_weights shape == (batch_size, 64, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", - " \n", - " # context_vector shape after sum == (batch_size, hidden_size)\n", - " context_vector = attention_weights * features\n", - " context_vector = tf.reduce_sum(context_vector, axis=1)\n", - " \n", - " return context_vector, attention_weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "AZ7R1RxHRPGf" - }, - "outputs": [], - "source": [ - "class CNN_Encoder(tf.keras.Model):\n", - " # Since we have already extracted the features and dumped it using pickle\n", - " # This encoder passes those features through a Fully connected layer\n", - " def __init__(self, embedding_dim):\n", - " super(CNN_Encoder, self).__init__()\n", - " # shape after fc == (batch_size, 64, embedding_dim)\n", - " self.fc = tf.keras.layers.Dense(embedding_dim)\n", - " \n", - " def call(self, x):\n", - " x = self.fc(x)\n", - " x = tf.nn.relu(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "V9UbGQmERPGi" - }, - "outputs": [], - "source": [ - "class RNN_Decoder(tf.keras.Model):\n", - " def __init__(self, embedding_dim, units, vocab_size):\n", - " super(RNN_Decoder, self).__init__()\n", - " self.units = units\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - " self.gru = gru(self.units)\n", - " self.fc1 = tf.keras.layers.Dense(self.units)\n", - " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", - " \n", - " self.attention = BahdanauAttention(self.units)\n", - " \n", - " def call(self, x, features, hidden):\n", - " # defining attention as a separate model\n", - " context_vector, attention_weights = self.attention(features, hidden)\n", - " \n", - " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", - " x = self.embedding(x)\n", - " \n", - " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", - " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", - " \n", - " # passing the concatenated vector to the GRU\n", - " output, state = self.gru(x)\n", - " \n", - " # shape == (batch_size, max_length, hidden_size)\n", - " x = self.fc1(output)\n", - " \n", - " # x shape == (batch_size * max_length, hidden_size)\n", - " x = tf.reshape(x, (-1, x.shape[2]))\n", - " \n", - " # output shape == (batch_size * max_length, vocab)\n", - " x = self.fc2(x)\n", - "\n", - " return x, state, attention_weights\n", - "\n", - " def reset_state(self, batch_size):\n", - " return tf.zeros((batch_size, self.units))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Qs_Sr03wRPGk" - }, - "outputs": [], - "source": [ - "encoder = CNN_Encoder(embedding_dim)\n", - "decoder = 
RNN_Decoder(embedding_dim, units, vocab_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "-bYN7xA0RPGl" - }, - "outputs": [], - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# We are masking the loss calculated for padding\n", - "def loss_function(real, pred):\n", - " mask = 1 - np.equal(real, 0)\n", - " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", - " return tf.reduce_mean(loss_)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PHod7t72RPGn" - }, - "source": [ - "## Training\n", - "\n", - "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", - "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", - "* The decoder returns the predictions and the decoder hidden state.\n", - "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", - "* Use teacher forcing to decide the next input to the decoder.\n", - "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", - "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Vt4WZ5mhJE-E" - }, - "outputs": [], - "source": [ - "# adding this in a separate cell because if you run the training cell \n", - "# many times, the loss_plot array will be reset\n", - "loss_plot = []" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": 
false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "UlA4VIQpRPGo" - }, - "outputs": [], - "source": [ - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " total_loss = 0\n", - " \n", - " for (batch, (img_tensor, target)) in enumerate(dataset):\n", - " loss = 0\n", - " \n", - " # initializing the hidden state for each batch\n", - " # because the captions are not related from image to image\n", - " hidden = decoder.reset_state(batch_size=target.shape[0])\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", - " \n", - " with tf.GradientTape() as tape:\n", - " features = encoder(img_tensor)\n", - " \n", - " for i in range(1, target.shape[1]):\n", - " # passing the features through the decoder\n", - " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", - "\n", - " loss += loss_function(target[:, i], predictions)\n", - " \n", - " # using teacher forcing\n", - " dec_input = tf.expand_dims(target[:, i], 1)\n", - " \n", - " total_loss += (loss / int(target.shape[1]))\n", - " \n", - " variables = encoder.variables + decoder.variables\n", - " \n", - " gradients = tape.gradient(loss, variables) \n", - " \n", - " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - " \n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", - " batch, \n", - " loss.numpy() / int(target.shape[1])))\n", - " # storing the epoch end loss value to plot later\n", - " loss_plot.append(total_loss / len(cap_vector))\n", - " \n", - " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", - " total_loss/len(cap_vector)))\n", - " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "1Wm83G-ZBPcC" - }, - 
"outputs": [], - "source": [ - "plt.plot(loss_plot)\n", - "plt.xlabel('Epochs')\n", - "plt.ylabel('Loss')\n", - "plt.title('Loss Plot')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "xGvOcLQKghXN" - }, - "source": [ - "## Caption!\n", - "\n", - "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", - "* Stop predicting when the model predicts the end token.\n", - "* And store the attention weights for every time step." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RCWpDtyNRPGs" - }, - "outputs": [], - "source": [ - "def evaluate(image):\n", - " attention_plot = np.zeros((max_length, attention_features_shape))\n", - "\n", - " hidden = decoder.reset_state(batch_size=1)\n", - "\n", - " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", - " img_tensor_val = image_features_extract_model(temp_input)\n", - " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", - "\n", - " features = encoder(img_tensor_val)\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", - " result = []\n", - "\n", - " for i in range(max_length):\n", - " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", - "\n", - " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", - "\n", - " predicted_id = tf.argmax(predictions[0]).numpy()\n", - " result.append(tokenizer.index_word[predicted_id])\n", - "\n", - " if tokenizer.index_word[predicted_id] == '':\n", - " return result, attention_plot\n", - "\n", - " dec_input = tf.expand_dims([predicted_id], 0)\n", - "\n", - " attention_plot = attention_plot[:len(result), 
:]\n", - " return result, attention_plot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "fD_y7PD6RPGt" - }, - "outputs": [], - "source": [ - "def plot_attention(image, result, attention_plot):\n", - " temp_image = np.array(Image.open(image))\n", - "\n", - " fig = plt.figure(figsize=(10, 10))\n", - " \n", - " len_result = len(result)\n", - " for l in range(len_result):\n", - " temp_att = np.resize(attention_plot[l], (8, 8))\n", - " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", - " ax.set_title(result[l])\n", - " img = ax.imshow(temp_image)\n", - " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", - "\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K2s1A9eLRPEj" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n" + ] }, - "colab_type": "code", - "id": "io7ws3ReRPGv" - }, - "outputs": [], - "source": [ - "# captions on the validation set\n", - "rid = np.random.randint(0, len(img_name_val))\n", - "image = img_name_val[rid]\n", - "real_caption = ' '.join([tokenizer.index_word[i] for i in cap_val[rid] if i not in [0]])\n", - "result, attention_plot = evaluate(image)\n", - "\n", - "print ('Real Caption:', real_caption)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image, result, attention_plot)\n", - "# opening the image\n", - "Image.open(img_name_val[rid])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rprk3HEvZuxb" - }, - "source": [ - "## Try it on your own images\n", - "For fun, 
below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Cffg2i257iMS" + }, + "source": [ + "# Image Captioning with Attention\n", + "\n", + "This example has moved:\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/image_captioning.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/image_captioning.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] }, - "colab_type": "code", - "id": "9Psd1quzaAWg" - }, - "outputs": [], - "source": [ - "image_url = 'https://tensorflow.org/images/surf.jpg'\n", - "image_extension = image_url[-4:]\n", - "image_path = tf.keras.utils.get_file('image'+image_extension, \n", - " origin=image_url)\n", - "\n", - "result, attention_plot = evaluate(image_path)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image_path, result, attention_plot)\n", - "# opening the image\n", - "Image.open(image_path)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "VJZXyJco6uLO" - }, - "source": [ - "# Next steps\n", - 
"\n", - "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "image_captioning_with_attention.ipynb", - "private_outputs": true, - "provenance": [ { - "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", - "timestamp": 1530222436922 + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QASbY_HGo4Lq" + }, + "source": [ + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", + "\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "image_captioning_with_attention.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 + } + ], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" } - ], - "toc_visible": true, - "version": "0.3.2", - "views": {} - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": 
"3.6.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index bda9e77085e45ae31a228142135425e22a1c6780..c945c753b3ba36d16aa6985d23a5849f8f552304 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -13,633 +13,13 @@ "\n", "# Text Generation using a RNN\n", "\n", + "This example has moved.\n", + "\n", "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on Github\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "BwpJ5IffzRG6" - }, - "source": [ - "This notebook demonstrates how to generate text using an RNN using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). 
If you like, you can write a similar [model](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.1-text-generation-with-lstm.ipynb) using less code. Here, we show a lower-level impementation that's useful to understand as prework before diving in to deeper examples in a similar, like [Neural Machine Translation with Attention](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "This notebook is an end-to-end example. When you run it, it will download a dataset of Shakespeare's writing. We'll use a collection of plays, borrowed from Andrej Karpathy's excellent [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). The notebook will train a model, and use it to generate sample output.\n", - " \n", - "Here is the output(with start string='w') after training a single layer GRU for 30 epochs with the default settings below:\n", - "\n", - "```\n", - "were to the death of him\n", - "And nothing of the field in the view of hell,\n", - "When I said, banish him, I will not burn thee that would live.\n", - "\n", - "HENRY BOLINGBROKE:\n", - "My gracious uncle--\n", - "\n", - "DUKE OF YORK:\n", - "As much disgraced to the court, the gods them speak,\n", - "And now in peace himself excuse thee in the world.\n", - "\n", - "HORTENSIO:\n", - "Madam, 'tis not the cause of the counterfeit of the earth,\n", - "And leave me to the sun that set them on the earth\n", - "And leave the world and are revenged for thee.\n", - "\n", - "GLOUCESTER:\n", - "I would they were talking with the very name of means\n", - "To make a puppet of a guest, and therefore, good Grumio,\n", - "Nor arm'd to prison, o' the clouds, of the whole field,\n", - "With the admire\n", - "With the feeding of thy chair, and we have heard it so,\n", - "I thank you, sir, he is a visor friendship with your silly 
your bed.\n", - "\n", - "SAMPSON:\n", - "I do desire to live, I pray: some stand of the minds, make thee remedies\n", - "With the enemies of my soul.\n", - "\n", - "MENENIUS:\n", - "I'll keep the cause of my mistress.\n", - "\n", - "POLIXENES:\n", - "My brother Marcius!\n", - "\n", - "Second Servant:\n", - "Will't ple\n", - "```\n", - "\n", - "Of course, while some of the sentences are grammatical, most do not make sense. But, consider:\n", - "\n", - "* Our model is character based (when we began training, it did not yet know how to spell a valid English word, or that words were even a unit of text).\n", - "\n", - "* The structure of the output resembles a play (blocks begin with a speaker name, in all caps similar to the original text). Sentences generally end with a period. If you look at the text from a distance (or don't read the invididual words too closely, it appears as if it's an excerpt from a play).\n", - "\n", - "As a next step, you can experiment training the model on a different dataset - any large text file(ASCII) will do, and you can modify a single line of code below to make that change. Have fun!\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "R3p22DBDsaCA" - }, - "source": [ - "## Install unidecode library\n", - "A helpful library to convert unicode to ASCII." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "wZ6LOM12wKGH" - }, - "outputs": [], - "source": [ - "!pip install unidecode" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "WGyKZj3bzf9p" - }, - "source": [ - "## Import tensorflow and enable eager execution." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "yG_n40gFzf9s" - }, - "outputs": [], - "source": [ - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", - "import tensorflow as tf\n", - "\n", - "# Note: Once you enable eager execution, it cannot be disabled. \n", - "tf.enable_eager_execution()\n", - "\n", - "import numpy as np\n", - "import os\n", - "import re\n", - "import random\n", - "import unidecode\n", - "import time" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "EHDoRoc5PKWz" - }, - "source": [ - "## Download the dataset\n", - "\n", - "In this example, we will use the [shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt). You can use any other dataset that you like.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "pD_55cOxLkAb" - }, - "outputs": [], - "source": [ - "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "UHjdCjDuSvX_" - }, - "source": [ - "## Read the dataset\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "-E5JvY3wzf94" - }, - "outputs": [], - "source": [ - "text = unidecode.unidecode(open(path_to_file).read())\n", - "# length of text is the number of characters in it\n", - "print (len(text))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Il9ww98izf-D" - }, - "source": [ - "Creating dictionaries to map from characters to their indices and vice-versa, which will be used to vectorize the inputs" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - 
"colab": {}, - "colab_type": "code", - "id": "IalZLbvOzf-F" - }, - "outputs": [], - "source": [ - "# unique contains all the unique characters in the file\n", - "unique = sorted(set(text))\n", - "\n", - "# creating a mapping from unique characters to indices\n", - "char2idx = {u:i for i, u in enumerate(unique)}\n", - "idx2char = {i:u for i, u in enumerate(unique)}" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "1v_qUYfAzf-I" - }, - "outputs": [], - "source": [ - "# setting the maximum length sentence we want for a single input in characters\n", - "max_length = 100\n", - "\n", - "# length of the vocabulary in chars\n", - "vocab_size = len(unique)\n", - "\n", - "# the embedding dimension \n", - "embedding_dim = 256\n", - "\n", - "# number of RNN (here GRU) units\n", - "units = 1024\n", - "\n", - "# batch size \n", - "BATCH_SIZE = 64\n", - "\n", - "# buffer size to shuffle our dataset\n", - "BUFFER_SIZE = 10000" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "LFjSVAlWzf-N" - }, - "source": [ - "## Creating the input and output tensors\n", - "\n", - "Vectorizing the input and the target text because our model cannot understand strings only numbers.\n", - "\n", - "But first, we need to create the input and output vectors.\n", - "Remember the max_length we set above, we will use it here. We are creating **max_length** chunks of input, where each input vector is all the characters in that chunk except the last and the target vector is all the characters in that chunk except the first.\n", - "\n", - "For example, consider that the string = 'tensorflow' and the max_length is 9\n", - "\n", - "So, the `input = 'tensorflo'` and `output = 'ensorflow'`\n", - "\n", - "After creating the vectors, we convert each character into numbers using the **char2idx** dictionary we created above." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "0UHJDA39zf-O" - }, - "outputs": [], - "source": [ - "input_text = []\n", - "target_text = []\n", - "\n", - "for f in range(0, len(text)-max_length, max_length):\n", - " inps = text[f:f+max_length]\n", - " targ = text[f+1:f+1+max_length]\n", - "\n", - " input_text.append([char2idx[i] for i in inps])\n", - " target_text.append([char2idx[t] for t in targ])\n", - " \n", - "print (np.array(input_text).shape)\n", - "print (np.array(target_text).shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "MJdfPmdqzf-R" - }, - "source": [ - "## Creating batches and shuffling them using tf.data" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "p2pGotuNzf-S" - }, - "outputs": [], - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "m8gPwEjRzf-Z" - }, - "source": [ - "## Creating the model\n", - "\n", - "We use the Model Subclassing API which gives us full flexibility to create the model and change it however we like. 
We use 3 layers to define our model.\n", - "\n", - "* Embedding layer\n", - "* GRU layer (you can use an LSTM layer here)\n", - "* Fully connected layer" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "P3KTiiInzf-a" - }, - "outputs": [], - "source": [ - "class Model(tf.keras.Model):\n", - " def __init__(self, vocab_size, embedding_dim, units, batch_size):\n", - " super(Model, self).__init__()\n", - " self.units = units\n", - " self.batch_sz = batch_size\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - "\n", - " if tf.test.is_gpu_available():\n", - " self.gru = tf.keras.layers.CuDNNGRU(self.units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " self.gru = tf.keras.layers.GRU(self.units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')\n", - "\n", - " self.fc = tf.keras.layers.Dense(vocab_size)\n", - " \n", - " def call(self, x, hidden):\n", - " x = self.embedding(x)\n", - "\n", - " # output shape == (batch_size, max_length, hidden_size) \n", - " # states shape == (batch_size, hidden_size)\n", - "\n", - " # states variable to preserve the state of the model\n", - " # this will be used to pass at every step to the model while training\n", - " output, states = self.gru(x, initial_state=hidden)\n", - "\n", - "\n", - " # reshaping the output so that we can pass it to the Dense layer\n", - " # after reshaping the shape is (batch_size * max_length, hidden_size)\n", - " output = tf.reshape(output, (-1, output.shape[2]))\n", - "\n", - " # The dense layer will output predictions for every time_steps(max_length)\n", - " # output shape after the dense layer == (max_length * batch_size, vocab_size)\n", - " x = self.fc(output)\n", - "\n", - " return x, states" - ] - }, - { - 
"cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "trpqTWyvk0nr" - }, - "source": [ - "## Call the model and set the optimizer and the loss function" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "7t2XrzEOzf-e" - }, - "outputs": [], - "source": [ - "model = Model(vocab_size, embedding_dim, units, BATCH_SIZE)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "dkjWIATszf-h" - }, - "outputs": [], - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors\n", - "def loss_function(real, preds):\n", - " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "3K6s6F79P7za" - }, - "source": [ - "## Checkpoints (Object-based saving)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "oAGisDdfP9rL" - }, - "outputs": [], - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", - " model=model)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "lPrP0XMUzf-p" - }, - "source": [ - "## Train the model\n", - "\n", - "Here we will use a custom training loop with the help of GradientTape()\n", - "\n", - "* We initialize the hidden state of the model with zeros and shape == (batch_size, number of rnn units). 
We do this by calling the function defined while creating the model.\n", - "\n", - "* Next, we iterate over the dataset(batch by batch) and calculate the **predictions and the hidden states** associated with that input.\n", - "\n", - "* There are a lot of interesting things happening here.\n", - " * The model gets hidden state(initialized with 0), lets call that **H0** and the first batch of input, lets call that **I0**.\n", - " * The model then returns the predictions **P1** and **H1**.\n", - " * For the next batch of input, the model receives **I1** and **H1**.\n", - " * The interesting thing here is that we pass **H1** to the model with **I1** which is how the model learns. The context learned from batch to batch is contained in the **hidden state**.\n", - " * We continue doing this until the dataset is exhausted and then we start a new epoch and repeat this.\n", - "\n", - "* After calculating the predictions, we calculate the **loss** using the loss function defined above. Then we calculate the gradients of the loss with respect to the model variables(input)\n", - "\n", - "* Finally, we take a step in that direction with the help of the optimizer using the apply_gradients function.\n", - "\n", - "Note:- If you are running this notebook in Colab which has a **Tesla K80 GPU** it takes about 23 seconds per epoch.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "d4tSNwymzf-q" - }, - "outputs": [], - "source": [ - "# Training step\n", - "\n", - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " \n", - " # initializing the hidden state at the start of every epoch\n", - " hidden = model.reset_states()\n", - " \n", - " for (batch, (inp, target)) in enumerate(dataset):\n", - " with tf.GradientTape() as tape:\n", - " # feeding the hidden state back into the model\n", - " # This is the interesting step\n", - " predictions, hidden = model(inp, hidden)\n", 
- " \n", - " # reshaping the target because that's how the \n", - " # loss function expects it\n", - " target = tf.reshape(target, (-1,))\n", - " loss = loss_function(target, predictions)\n", - " \n", - " grads = tape.gradient(loss, model.variables)\n", - " optimizer.apply_gradients(zip(grads, model.variables))\n", - "\n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n", - " batch,\n", - " loss))\n", - " # saving (checkpoint) the model every 5 epochs\n", - " if (epoch + 1) % 5 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - "\n", - " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", - " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "01AR9vpNQMFF" - }, - "source": [ - "## Restore the latest checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "tyvpYomYQQkF" - }, - "outputs": [], - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "DjGz1tDkzf-u" - }, - "source": [ - "## Predicting using our trained model\n", - "\n", - "The below code block is used to generated the text\n", - "\n", - "* We start by choosing a start string and initializing the hidden state and setting the number of characters we want to generate.\n", - "\n", - "* We get predictions using the start_string and the hidden state\n", - "\n", - "* Then we use argmax to calculate the index of the predicted word. 
**We use this predicted word as our next input to the model**\n", - "\n", - "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n", - "\n", - "* If you see the predictions, the model knows when to capitalize, make paragraphs and the text follows a shakespeare style of writing which is pretty awesome!" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "WvuwZBX5Ogfd" - }, - "outputs": [], - "source": [ - "# Evaluation step(generating text using the model learned)\n", - "\n", - "# number of characters to generate\n", - "num_generate = 1000\n", - "\n", - "# You can change the start string to experiment\n", - "start_string = 'Q'\n", - "# converting our start string to numbers(vectorizing!) 
\n", - "input_eval = [char2idx[s] for s in start_string]\n", - "input_eval = tf.expand_dims(input_eval, 0)\n", - "\n", - "# empty string to store our results\n", - "text_generated = ''\n", - "\n", - "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n", - "hidden = [tf.zeros((1, units))]\n", - "for i in range(num_generate):\n", - " predictions, hidden = model(input_eval, hidden)\n", - "\n", - " # using argmax to predict the word returned by the model\n", - " predicted_id = tf.argmax(predictions[-1]).numpy()\n", - " \n", - " # We pass the predicted word as the next input to the model\n", - " # along with the previous hidden state\n", - " input_eval = tf.expand_dims([predicted_id], 0)\n", - " \n", - " text_generated += idx2char[predicted_id]\n", - "\n", - "print (start_string + text_generated)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "AM2Uma_-yVIq" - }, - "source": [ - "## Next steps\n", - "\n", - "* Change the start string to a different character, or the start of a sentence.\n", - "* Experiment with training on a different, or with different parameters. 
[Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n", - "* Add another RNN layer.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "gtEd86sX5cB2" - }, - "outputs": [], - "source": [ - "" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] } ], diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD index 78548c51c90912ee5094aa88904d818916bcd688..35d509904211d98f124d2555fc48166e75cb0dd9 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -28,7 +28,7 @@ py_library( cuda_py_test( name = "l2hmc_test", - size = "large", + size = "medium", srcs = ["l2hmc_test.py"], additional_deps = [ ":l2hmc", @@ -36,6 +36,7 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//third_party/py/numpy", ], + shard_count = 4, tags = [ "oss_serial", ], diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index 6a178ddcecc6b271fa3f074cb6768165b1d1177d..30afef83bc5c6c164c8456ed472f4d6064068a25 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -9,6 +9,13 @@ py_binary( name = "linear_regression", srcs = ["linear_regression.py"], srcs_version = "PY2AND3", + deps = [":linear_regression_lib"], +) + +py_library( + name = "linear_regression_lib", + srcs = ["linear_regression.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", 
"//tensorflow/contrib/eager/python:tfe", @@ -20,7 +27,7 @@ cuda_py_test( size = "small", srcs = ["linear_regression_test.py"], additional_deps = [ - ":linear_regression", + ":linear_regression_lib", "//tensorflow:tensorflow_py", ], tags = [ @@ -34,7 +41,7 @@ cuda_py_test( size = "small", srcs = ["linear_regression_graph_test.py"], additional_deps = [ - ":linear_regression", + ":linear_regression_lib", "//tensorflow:tensorflow_py", ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 099b712fc06d1d3eb9ab4095f8db7283690bda76..206ef9409df7b1dc21de42ba919d2ba97f334a8c 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -56,7 +56,7 @@ class LinearModel(tf.keras.Model): def mean_square_loss(model, xs, ys): - return tf.reduce_mean(tf.square(tf.subtract(model(xs), ys))) + return tf.reduce_mean(tf.squared_difference(model(xs), ys)) def fit(model, dataset, optimizer, verbose=False, logdir=None): diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 66d52a74943d0d81fde05ce51b019558b327978d..512605a17eb77a85a5ec98197f4ed8fda6863932 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -1,11 +1,28 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "nmt_with_attention.ipynb", + "version": "0.3.2", + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "accelerator": "GPU" + }, "cells": [ { - 
"cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AOpGoE2T-YXS" }, + "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors.\n", "\n", @@ -13,19 +30,19 @@ "\n", "# Neural Machine Translation with Attention\n", "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\n", - " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", - "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "CiwtNgENbx2g" }, + "cell_type": "markdown", "source": [ "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n", "\n", @@ -33,24 +50,22 @@ "\n", "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n", "\n", - "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n", + "\"spanish-english\n", "\n", "Note: This example takes approximately 10 mintues to run on a single P100 GPU." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "tnxXKDjq3jEL" + "id": "tnxXKDjq3jEL", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", + "# Import TensorFlow >= 1.10 and enable eager execution\n", "import tensorflow as tf\n", "\n", "tf.enable_eager_execution()\n", @@ -65,14 +80,16 @@ "import time\n", "\n", "print(tf.__version__)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "wfodePkj3jEa" }, + "cell_type": "markdown", "source": [ "## Download and prepare the dataset\n", "\n", @@ -91,14 +108,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "kRVATYOgJs1b" + "id": "kRVATYOgJs1b", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Download the 
file\n", "path_to_zip = tf.keras.utils.get_file(\n", @@ -106,17 +121,17 @@ " extract=True)\n", "\n", "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "rd0jw-eC3jEh" + "id": "rd0jw-eC3jEh", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Converts the unicode file to ascii\n", "def unicode_to_ascii(s):\n", @@ -128,7 +143,7 @@ " w = unicode_to_ascii(w.lower().strip())\n", " \n", " # creating a space between a word and the punctuation following it\n", - " # eg: \"he is a boy.\" =\u003e \"he is a boy .\" \n", + " # eg: \"he is a boy.\" => \"he is a boy .\" \n", " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", " w = re.sub(r'[\" \"]+', \" \", w)\n", @@ -140,19 +155,19 @@ " \n", " # adding a start and an end token to the sentence\n", " # so that the model know when to start and stop predicting.\n", - " w = '\u003cstart\u003e ' + w + ' \u003cend\u003e'\n", + " w = ' ' + w + ' '\n", " return w" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "OHn4Dct23jEm" + "id": "OHn4Dct23jEm", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# 1. Remove the accents\n", "# 2. 
Clean the sentences\n", @@ -163,20 +178,20 @@ " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", " \n", " return word_pairs" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "9xbqO7Iie9bb" + "id": "9xbqO7Iie9bb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "# This class creates a word -\u003e index mapping (e.g,. \"dad\" -\u003e 5) and vice-versa \n", - "# (e.g., 5 -\u003e \"dad\") for each language,\n", + "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n", + "# (e.g., 5 -> \"dad\") for each language,\n", "class LanguageIndex():\n", " def __init__(self, lang):\n", " self.lang = lang\n", @@ -192,23 +207,23 @@ " \n", " self.vocab = sorted(self.vocab)\n", " \n", - " self.word2idx['\u003cpad\u003e'] = 0\n", + " self.word2idx[''] = 0\n", " for index, word in enumerate(self.vocab):\n", " self.word2idx[word] = index + 1\n", " \n", " for word, index in self.word2idx.items():\n", " self.idx2word[index] = word" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "eAY9k49G3jE_" + "id": "eAY9k49G3jE_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def max_length(tensor):\n", " return max(len(t) for t in tensor)\n", @@ -244,71 +259,71 @@ " padding='post')\n", " \n", " return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "GOi42V79Ydlr" }, + "cell_type": "markdown", "source": [ "### Limit the size of the dataset to experiment faster (optional)\n", "\n", - "Training on the complete dataset of \u003e100,000 sentences will take a long time. 
To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" + "Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "cnxC7q-j3jFD" + "id": "cnxC7q-j3jFD", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Try experimenting with the size of that dataset\n", "num_examples = 30000\n", "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "4QILQkOs3jFG" + "id": "4QILQkOs3jFG", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Creating training and validation sets using an 80-20 split\n", "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", "\n", "# Show length\n", "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rgCLkfv5uO3d" }, + "cell_type": "markdown", "source": [ "### Create a tf.data dataset" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "TqHsArVZ3jFS" + "id": "TqHsArVZ3jFS", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "BUFFER_SIZE = len(input_tensor_train)\n", "BATCH_SIZE = 64\n", @@ -320,27 +335,29 @@ "\n", "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, 
target_tensor_train)).shuffle(BUFFER_SIZE)\n", "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "TNfHIF71ulLu" }, + "cell_type": "markdown", "source": [ "## Write the encoder and decoder model\n", "\n", - "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", + "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://github.com/tensorflow/nmt). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://github.com/tensorflow/nmt#background-on-the-attention-mechanism) from the seq2seq tutorial. The following diagram shows that each input word is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", "\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\"\u003e\n", + "\"attention\n", "\n", "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. 
\n", "\n", "Here are the equations that are implemented:\n", "\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\"\u003e\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\"\u003e\n", + "\"attention\n", + "\"attention\n", "\n", "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n", "\n", @@ -362,14 +379,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "avyJ_4VIUoHb" + "id": "avyJ_4VIUoHb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def gru(units):\n", " # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n", @@ -385,17 +400,17 @@ " return_state=True, \n", " recurrent_activation='sigmoid', \n", " recurrent_initializer='glorot_uniform')" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "nZ2rI24i3jFg" + "id": "nZ2rI24i3jFg", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "class Encoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", @@ -412,17 +427,17 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.enc_units))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "yJ_B3mhW3jFk" + "id": "yJ_B3mhW3jFk", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "class Decoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", @@ -476,41 +491,41 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.dec_units))" - ] + ], + 
"execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "P5UY8wko3jFp" + "id": "P5UY8wko3jFp", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_ch_71VbIRfK" }, + "cell_type": "markdown", "source": [ "## Define the optimizer and the loss function" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "WmTHr5iV3jFr" + "id": "WmTHr5iV3jFr", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "optimizer = tf.train.AdamOptimizer()\n", "\n", @@ -519,41 +534,43 @@ " mask = 1 - np.equal(real, 0)\n", " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", " return tf.reduce_mean(loss_)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DMVWzzsfNl4e" }, + "cell_type": "markdown", "source": [ "## Checkpoints (Object-based saving)" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "Zj8bXQTgNwrF" + "id": "Zj8bXQTgNwrF", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", " encoder=encoder,\n", " decoder=decoder)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hpObfY22IddU" }, + "cell_type": "markdown", "source": [ "## Training\n", "\n", @@ -567,14 +584,12 @@ ] }, { - "cell_type": "code", - 
"execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "ddefjBMa3jF0" + "id": "ddefjBMa3jF0", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "EPOCHS = 10\n", "\n", @@ -592,7 +607,7 @@ " \n", " dec_hidden = enc_hidden\n", " \n", - " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']] * BATCH_SIZE, 1) \n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']] * BATCH_SIZE, 1) \n", " \n", " # Teacher forcing - feeding the target as the next input\n", " for t in range(1, targ.shape[1]):\n", @@ -625,14 +640,16 @@ " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", " total_loss / N_BATCH))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mU3Ce8M6I3rz" }, + "cell_type": "markdown", "source": [ "## Translate\n", "\n", @@ -644,14 +661,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "EbQpyYs13jF_" + "id": "EbQpyYs13jF_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", @@ -668,12 +683,12 @@ " enc_out, enc_hidden = encoder(inputs, hidden)\n", "\n", " dec_hidden = enc_hidden\n", - " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']], 0)\n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']], 0)\n", "\n", " for t in range(max_length_targ):\n", " predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n", " \n", - " # storing the attention weigths to plot later on\n", + " # storing the attention weights to plot later on\n", " attention_weights = tf.reshape(attention_weights, (-1, ))\n", " attention_plot[t] = attention_weights.numpy()\n", "\n", @@ -681,24 
+696,24 @@ "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", - " if targ_lang.idx2word[predicted_id] == '\u003cend\u003e':\n", + " if targ_lang.idx2word[predicted_id] == '':\n", " return result, sentence, attention_plot\n", " \n", " # the predicted ID is fed back into the model\n", " dec_input = tf.expand_dims([predicted_id], 0)\n", "\n", " return result, sentence, attention_plot" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "s5hQWlbN3jGF" + "id": "s5hQWlbN3jGF", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# function for plotting the attention weights\n", "def plot_attention(attention, sentence, predicted_sentence):\n", @@ -712,17 +727,17 @@ " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", "\n", " plt.show()" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "sl9zUHzg3jGI" + "id": "sl9zUHzg3jGI", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n", @@ -732,91 +747,93 @@ " \n", " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "n250XbnjOaqP" }, + "cell_type": "markdown", "source": [ "## Restore the latest checkpoint and test" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "UJpT9D5_OgP6" + "id": "UJpT9D5_OgP6", + "colab": {} }, - "outputs": 
[], + "cell_type": "code", "source": [ "# restoring the latest checkpoint in checkpoint_dir\n", "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "WrAM0FDomq3E" + "id": "WrAM0FDomq3E", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "zSx2iM36EZQZ" + "id": "zSx2iM36EZQZ", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "A3LLCx3ZE0Ls" + "id": "A3LLCx3ZE0Ls", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "DUQVLVqUE1YW" + "id": "DUQVLVqUE1YW", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# wrong translation\n", "translate(u'trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RTe5P5ioMJwN" }, + "cell_type": "markdown", "source": [ "## Next steps\n", "\n", @@ -824,31 +841,5 @@ "* Experiment with training on a larger dataset, or using more epochs\n" ] } - 
], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "nmt_with_attention.ipynb", - "private_outputs": true, - "provenance": [ - { - "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U", - "timestamp": 1527858391290 - }, - { - "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv", - "timestamp": 1527776041613 - } - ], - "toc_visible": true, - "version": "0.3.2" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + ] } diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb index d60ee18586196614c9c0f73fc88dfb8b758725ea..57bd18d7529d28e0914d6c32b93881336002b9a5 100644 --- a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb +++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb @@ -11,777 +11,17 @@ "\n", "Licensed under the Apache License, Version 2.0 (the \"License\").\n", "\n", - "# Pix2Pix: An example with tf.keras and eager\n", - "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\n", - " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", - "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "ITZuApL56Mny" - }, - "source": [ - "This notebook demonstrates 
image to image translation using conditional GAN's, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004). Using this technique we can colorize black and white photos, convert google maps to google earth, etc. Here, we convert building facades to real buildings. We use [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) to achieve this.\n", - "\n", - "In example, we will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.\n", - "\n", - "Each epoch takes around 58 seconds on a single P100 GPU.\n", - "\n", - "Below is the output generated after training the model for 200 epochs.\n", - "\n", - "\n", - "![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)\n", - "![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "source": [ - "## Import TensorFlow and enable eager execution" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "YfIk2es3hJEd" - }, - "outputs": [], - "source": [ - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "import os\n", - "import time\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import PIL\n", - "from IPython.display import clear_output" - ] - }, - { - "cell_type": 
"markdown", - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "source": [ - "## Load the dataset\n", - "\n", - "You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.\n", - "* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`\n", - "* In random mirroring, the image is randomly flipped horizontally i.e left to right." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Kn-k8kTXuAlv" - }, - "outputs": [], - "source": [ - "path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n", - " cache_subdir=os.path.abspath('.'),\n", - " origin='https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz', \n", - " extract=True)\n", - "\n", - "PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "2CbTEt448b4R" - }, - "outputs": [], - "source": [ - "BUFFER_SIZE = 400\n", - "BATCH_SIZE = 1\n", - "IMG_WIDTH = 256\n", - "IMG_HEIGHT = 256" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "tyaP4hLJ8b4W" - }, - "outputs": [], - "source": [ - "def load_image(image_file, is_train):\n", - " image = tf.read_file(image_file)\n", - " image = tf.image.decode_jpeg(image)\n", - "\n", - " w = tf.shape(image)[1]\n", - "\n", - " w = w // 2\n", - " real_image = image[:, :w, :]\n", - " input_image = image[:, w:, :]\n", - "\n", - " input_image = tf.cast(input_image, tf.float32)\n", - " real_image = tf.cast(real_image, tf.float32)\n", - "\n", - " if is_train:\n", - " # random jittering\n", - " \n", - " # resizing to 286 x 286 x 3\n", 
- " input_image = tf.image.resize_images(input_image, [286, 286], \n", - " align_corners=True, \n", - " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", - " real_image = tf.image.resize_images(real_image, [286, 286], \n", - " align_corners=True, \n", - " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", - " \n", - " # randomly cropping to 256 x 256 x 3\n", - " stacked_image = tf.stack([input_image, real_image], axis=0)\n", - " cropped_image = tf.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n", - " input_image, real_image = cropped_image[0], cropped_image[1]\n", - "\n", - " if np.random.random() \u003e 0.5:\n", - " # random mirroring\n", - " input_image = tf.image.flip_left_right(input_image)\n", - " real_image = tf.image.flip_left_right(real_image)\n", - " else:\n", - " input_image = tf.image.resize_images(input_image, size=[IMG_HEIGHT, IMG_WIDTH], \n", - " align_corners=True, method=2)\n", - " real_image = tf.image.resize_images(real_image, size=[IMG_HEIGHT, IMG_WIDTH], \n", - " align_corners=True, method=2)\n", - " \n", - " # normalizing the images to [-1, 1]\n", - " input_image = (input_image / 127.5) - 1\n", - " real_image = (real_image / 127.5) - 1\n", - "\n", - " return input_image, real_image" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "source": [ - "## Use tf.data to create batches, map(do preprocessing) and shuffle the dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "SQHmYSmk8b4b" - }, - "outputs": [], - "source": [ - "train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n", - "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n", - "train_dataset = train_dataset.map(lambda x: load_image(x, True))\n", - "train_dataset = train_dataset.batch(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": 
"MS9J0yA58b4g" - }, - "outputs": [], - "source": [ - "test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n", - "test_dataset = test_dataset.map(lambda x: load_image(x, False))\n", - "test_dataset = test_dataset.batch(1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "source": [ - "## Write the generator and discriminator models\n", - "\n", - "* **Generator** \n", - " * The architecture of generator is a modified U-Net.\n", - " * Each block in the encoder is (Conv -\u003e Batchnorm -\u003e Leaky ReLU)\n", - " * Each block in the decoder is (Transposed Conv -\u003e Batchnorm -\u003e Dropout(applied to the first 3 blocks) -\u003e ReLU)\n", - " * There are skip connections between the encoder and decoder (as in U-Net).\n", - " \n", - "* **Discriminator**\n", - " * The Discriminator is a PatchGAN.\n", - " * Each block in the discriminator is (Conv -\u003e BatchNorm -\u003e Leaky ReLU)\n", - " * The shape of the output after the last layer is (batch_size, 30, 30, 1)\n", - " * Each 30x30 patch of the output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).\n", - " * Discriminator receives 2 inputs.\n", - " * Input image and the target image, which it should classify as real.\n", - " * Input image and the generated image (output of generator), which it should classify as fake. 
\n", - " * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`)\n", - "\n", - "* Shape of the input travelling through the generator and the discriminator is in the comments in the code.\n", - "\n", - "To learn more about the architecture and the hyperparameters you can refer the [paper](https://arxiv.org/abs/1611.07004).\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "tqqvWxlw8b4l" - }, - "outputs": [], - "source": [ - "OUTPUT_CHANNELS = 3" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "lFPI4Nu-8b4q" - }, - "outputs": [], - "source": [ - "class Downsample(tf.keras.Model):\n", - " \n", - " def __init__(self, filters, size, apply_batchnorm=True):\n", - " super(Downsample, self).__init__()\n", - " self.apply_batchnorm = apply_batchnorm\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - "\n", - " self.conv1 = tf.keras.layers.Conv2D(filters, \n", - " (size, size), \n", - " strides=2, \n", - " padding='same',\n", - " kernel_initializer=initializer,\n", - " use_bias=False)\n", - " if self.apply_batchnorm:\n", - " self.batchnorm = tf.keras.layers.BatchNormalization()\n", - " \n", - " def call(self, x, training):\n", - " x = self.conv1(x)\n", - " if self.apply_batchnorm:\n", - " x = self.batchnorm(x, training=training)\n", - " x = tf.nn.leaky_relu(x)\n", - " return x \n", - "\n", - "\n", - "class Upsample(tf.keras.Model):\n", - " \n", - " def __init__(self, filters, size, apply_dropout=False):\n", - " super(Upsample, self).__init__()\n", - " self.apply_dropout = apply_dropout\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - "\n", - " self.up_conv = tf.keras.layers.Conv2DTranspose(filters, \n", - " (size, size), \n", - " strides=2, \n", - " padding='same',\n", - " kernel_initializer=initializer,\n", - " use_bias=False)\n", - " self.batchnorm = 
tf.keras.layers.BatchNormalization()\n", - " if self.apply_dropout:\n", - " self.dropout = tf.keras.layers.Dropout(0.5)\n", - "\n", - " def call(self, x1, x2, training):\n", - " x = self.up_conv(x1)\n", - " x = self.batchnorm(x, training=training)\n", - " if self.apply_dropout:\n", - " x = self.dropout(x, training=training)\n", - " x = tf.nn.relu(x)\n", - " x = tf.concat([x, x2], axis=-1)\n", - " return x\n", - "\n", - "\n", - "class Generator(tf.keras.Model):\n", - " \n", - " def __init__(self):\n", - " super(Generator, self).__init__()\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - " \n", - " self.down1 = Downsample(64, 4, apply_batchnorm=False)\n", - " self.down2 = Downsample(128, 4)\n", - " self.down3 = Downsample(256, 4)\n", - " self.down4 = Downsample(512, 4)\n", - " self.down5 = Downsample(512, 4)\n", - " self.down6 = Downsample(512, 4)\n", - " self.down7 = Downsample(512, 4)\n", - " self.down8 = Downsample(512, 4)\n", - "\n", - " self.up1 = Upsample(512, 4, apply_dropout=True)\n", - " self.up2 = Upsample(512, 4, apply_dropout=True)\n", - " self.up3 = Upsample(512, 4, apply_dropout=True)\n", - " self.up4 = Upsample(512, 4)\n", - " self.up5 = Upsample(256, 4)\n", - " self.up6 = Upsample(128, 4)\n", - " self.up7 = Upsample(64, 4)\n", - "\n", - " self.last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, \n", - " (4, 4), \n", - " strides=2, \n", - " padding='same',\n", - " kernel_initializer=initializer)\n", - " \n", - " @tf.contrib.eager.defun\n", - " def call(self, x, training):\n", - " # x shape == (bs, 256, 256, 3) \n", - " x1 = self.down1(x, training=training) # (bs, 128, 128, 64)\n", - " x2 = self.down2(x1, training=training) # (bs, 64, 64, 128)\n", - " x3 = self.down3(x2, training=training) # (bs, 32, 32, 256)\n", - " x4 = self.down4(x3, training=training) # (bs, 16, 16, 512)\n", - " x5 = self.down5(x4, training=training) # (bs, 8, 8, 512)\n", - " x6 = self.down6(x5, training=training) # (bs, 4, 4, 512)\n", - " x7 = self.down7(x6, 
training=training) # (bs, 2, 2, 512)\n", - " x8 = self.down8(x7, training=training) # (bs, 1, 1, 512)\n", - "\n", - " x9 = self.up1(x8, x7, training=training) # (bs, 2, 2, 1024)\n", - " x10 = self.up2(x9, x6, training=training) # (bs, 4, 4, 1024)\n", - " x11 = self.up3(x10, x5, training=training) # (bs, 8, 8, 1024)\n", - " x12 = self.up4(x11, x4, training=training) # (bs, 16, 16, 1024)\n", - " x13 = self.up5(x12, x3, training=training) # (bs, 32, 32, 512)\n", - " x14 = self.up6(x13, x2, training=training) # (bs, 64, 64, 256)\n", - " x15 = self.up7(x14, x1, training=training) # (bs, 128, 128, 128)\n", - "\n", - " x16 = self.last(x15) # (bs, 256, 256, 3)\n", - " x16 = tf.nn.tanh(x16)\n", - "\n", - " return x16" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "ll6aNeQx8b4v" - }, - "outputs": [], - "source": [ - "class DiscDownsample(tf.keras.Model):\n", - " \n", - " def __init__(self, filters, size, apply_batchnorm=True):\n", - " super(DiscDownsample, self).__init__()\n", - " self.apply_batchnorm = apply_batchnorm\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - "\n", - " self.conv1 = tf.keras.layers.Conv2D(filters, \n", - " (size, size), \n", - " strides=2, \n", - " padding='same',\n", - " kernel_initializer=initializer,\n", - " use_bias=False)\n", - " if self.apply_batchnorm:\n", - " self.batchnorm = tf.keras.layers.BatchNormalization()\n", - " \n", - " def call(self, x, training):\n", - " x = self.conv1(x)\n", - " if self.apply_batchnorm:\n", - " x = self.batchnorm(x, training=training)\n", - " x = tf.nn.leaky_relu(x)\n", - " return x \n", - "\n", - "class Discriminator(tf.keras.Model):\n", - " \n", - " def __init__(self):\n", - " super(Discriminator, self).__init__()\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - " \n", - " self.down1 = DiscDownsample(64, 4, False)\n", - " self.down2 = DiscDownsample(128, 4)\n", - " self.down3 = DiscDownsample(256, 
4)\n", - " \n", - " # we are zero padding here with 1 because we need our shape to \n", - " # go from (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512)\n", - " self.zero_pad1 = tf.keras.layers.ZeroPadding2D()\n", - " self.conv = tf.keras.layers.Conv2D(512, \n", - " (4, 4), \n", - " strides=1, \n", - " kernel_initializer=initializer, \n", - " use_bias=False)\n", - " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n", - " \n", - " # shape change from (batch_size, 31, 31, 512) to (batch_size, 30, 30, 1)\n", - " self.zero_pad2 = tf.keras.layers.ZeroPadding2D()\n", - " self.last = tf.keras.layers.Conv2D(1, \n", - " (4, 4), \n", - " strides=1,\n", - " kernel_initializer=initializer)\n", - " \n", - " @tf.contrib.eager.defun\n", - " def call(self, inp, tar, training):\n", - " # concatenating the input and the target\n", - " x = tf.concat([inp, tar], axis=-1) # (bs, 256, 256, channels*2)\n", - " x = self.down1(x, training=training) # (bs, 128, 128, 64)\n", - " x = self.down2(x, training=training) # (bs, 64, 64, 128)\n", - " x = self.down3(x, training=training) # (bs, 32, 32, 256)\n", - "\n", - " x = self.zero_pad1(x) # (bs, 34, 34, 256)\n", - " x = self.conv(x) # (bs, 31, 31, 512)\n", - " x = self.batchnorm1(x, training=training)\n", - " x = tf.nn.leaky_relu(x)\n", - " \n", - " x = self.zero_pad2(x) # (bs, 33, 33, 512)\n", - " # don't add a sigmoid activation here since\n", - " # the loss function expects raw logits.\n", - " x = self.last(x) # (bs, 30, 30, 1)\n", - "\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "gDkA05NE6QMs" - }, - "outputs": [], - "source": [ - "# The call function of Generator and Discriminator have been decorated\n", - "# with tf.contrib.eager.defun()\n", - "# We get a performance speedup if defun is used (~25 seconds per epoch)\n", - "generator = Generator()\n", - "discriminator = Discriminator()" - ] - }, - { - "cell_type": "markdown", - 
"metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "source": [ - "## Define the loss functions and the optimizer\n", - "\n", - "* **Discriminator loss**\n", - " * The discriminator loss function takes 2 inputs; **real images, generated images**\n", - " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**\n", - " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**\n", - " * Then the total_loss is the sum of real_loss and the generated_loss\n", - " \n", - "* **Generator loss**\n", - " * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.\n", - " * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.\n", - " * This allows the generated image to become structurally similar to the target image.\n", - " * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "cyhxTuvJyIHV" - }, - "outputs": [], - "source": [ - "LAMBDA = 100" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "wkMNfBWlT-PV" - }, - "outputs": [], - "source": [ - "def discriminator_loss(disc_real_output, disc_generated_output):\n", - " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_real_output), \n", - " logits = disc_real_output)\n", - " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.zeros_like(disc_generated_output), \n", - " logits = disc_generated_output)\n", - "\n", - " total_disc_loss = real_loss + generated_loss\n", - "\n", - " return total_disc_loss" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "90BIcCKcDMxz" - }, - "outputs": [], - "source": [ - "def generator_loss(disc_generated_output, gen_output, target):\n", - " gan_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_generated_output),\n", - " logits = disc_generated_output) \n", - " # mean absolute error\n", - " l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", - "\n", - " total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n", - "\n", - " return total_gen_loss" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "iWCn_PVdEJZ7" - }, - "outputs": [], - "source": [ - "generator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)\n", - "discriminator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "aKUZnDiqQrAh" - }, - "source": [ - "## Checkpoints (Object-based saving)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": 
"code", - "id": "WJnftd5sQsv6" - }, - "outputs": [], - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", - " discriminator_optimizer=discriminator_optimizer,\n", - " generator=generator,\n", - " discriminator=discriminator)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "source": [ - "## Training\n", - "\n", - "* We start by iterating over the dataset\n", - "* The generator gets the input image and we get a generated output.\n", - "* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.\n", - "* Next, we calculate the generator and the discriminator loss.\n", - "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n", - "\n", - "## Generate Images\n", - "\n", - "* After training, its time to generate some images!\n", - "* We pass images from the test dataset to the generator.\n", - "* The generator will then translate the input image into the output we expect.\n", - "* Last step is to plot the predictions and **voila!**" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "NS2GWywBbAWo" - }, - "outputs": [], - "source": [ - "EPOCHS = 200" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "RmdVsmvhPxyy" - }, - "outputs": [], - "source": [ - "def generate_images(model, test_input, tar):\n", - " # the training=True is intentional here since\n", - " # we want the batch statistics while running the model\n", - " # on the test dataset. 
If we use training=False, we will get \n", - " # the accumulated statistics learned from the training dataset\n", - " # (which we don't want)\n", - " prediction = model(test_input, training=True)\n", - " plt.figure(figsize=(15,15))\n", - "\n", - " display_list = [test_input[0], tar[0], prediction[0]]\n", - " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", - "\n", - " for i in range(3):\n", - " plt.subplot(1, 3, i+1)\n", - " plt.title(title[i])\n", - " # getting the pixel values between [0, 1] to plot it.\n", - " plt.imshow(display_list[i] * 0.5 + 0.5)\n", - " plt.axis('off')\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "2M7LmLtGEMQJ" - }, - "outputs": [], - "source": [ - "def train(dataset, epochs): \n", - " for epoch in range(epochs):\n", - " start = time.time()\n", - "\n", - " for input_image, target in dataset:\n", - "\n", - " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", - " gen_output = generator(input_image, training=True)\n", - "\n", - " disc_real_output = discriminator(input_image, target, training=True)\n", - " disc_generated_output = discriminator(input_image, gen_output, training=True)\n", - "\n", - " gen_loss = generator_loss(disc_generated_output, gen_output, target)\n", - " disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n", - "\n", - " generator_gradients = gen_tape.gradient(gen_loss, \n", - " generator.variables)\n", - " discriminator_gradients = disc_tape.gradient(disc_loss, \n", - " discriminator.variables)\n", - "\n", - " generator_optimizer.apply_gradients(zip(generator_gradients, \n", - " generator.variables))\n", - " discriminator_optimizer.apply_gradients(zip(discriminator_gradients, \n", - " discriminator.variables))\n", - "\n", - " if epoch % 1 == 0:\n", - " clear_output(wait=True)\n", - " for inp, tar in test_dataset.take(1):\n", - " generate_images(generator, inp, tar)\n", - " 
\n", - " # saving (checkpoint) the model every 20 epochs\n", - " if (epoch + 1) % 20 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - "\n", - " print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n", - " time.time()-start))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "a1zZmKmvOH85" - }, - "outputs": [], - "source": [ - "train(train_dataset, EPOCHS)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "kz80bY3aQ1VZ" - }, - "source": [ - "## Restore the latest checkpoint and test" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "4t4x69adQ5xb" - }, - "outputs": [], - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" + "# Pix2Pix" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "1RGysMU_BZhx" - }, - "source": [ - "## Testing on the entire test dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "KUgSnmy2nqSP" - }, - "outputs": [], - "source": [ - "# Run the trained model on the entire test dataset\n", - "for inp, tar in test_dataset:\n", - " generate_images(generator, inp, tar)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "3AJXOByaZVOf" + "id": "c7W3j96p219v" }, - "outputs": [], "source": [ - "" + "This notebook has been moved to [https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/pix2pix.ipynb](https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/pix2pix.ipynb)" ] } ], diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index 
04ac78a2d3753a07f067f6f2abe0cfc02f5245ba..f2851d97223e483da11120f1fe3f0a2f641dfb81 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -27,7 +27,7 @@ py_library( cuda_py_test( name = "resnet50_test", - size = "large", + size = "medium", srcs = ["resnet50_test.py"], additional_deps = [ ":resnet50", @@ -35,6 +35,7 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "noasan", # Fix b/118130911 "nomsan", # Fix b/118130911 @@ -46,7 +47,7 @@ cuda_py_test( cuda_py_test( name = "resnet50_graph_test", - size = "large", + size = "medium", srcs = ["resnet50_graph_test.py"], additional_deps = [ ":resnet50", @@ -54,6 +55,7 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "noasan", "nomsan", diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index 0d85bf63ad28d6ad15b93c86f8236bbdc7c2f4be..cb207b8ddf3641a68a114386f6a95a26ce2b74d6 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -67,12 +67,13 @@ py_library( # Tests cuda_py_test( name = "ops_test", - size = "large", + size = "medium", srcs = ["ops_test.py"], additional_deps = [ ":ops", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "oss_serial", ], @@ -80,20 +81,22 @@ cuda_py_test( cuda_py_test( name = "blocks_test", - size = "large", + size = "medium", srcs = ["blocks_test.py"], additional_deps = [ ":blocks", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ + "no_oss", # b/123045964 "optonly", ], ) cuda_py_test( name = "revnet_test", - size = "large", + size = "medium", srcs = ["revnet_test.py"], additional_deps = [ ":blocks_test", @@ -101,6 +104,7 @@ cuda_py_test( ":revnet", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", # depends 
on blocks_test, which is not available in pip package "optonly", @@ -131,6 +135,13 @@ py_binary( name = "main", srcs = ["main.py"], srcs_version = "PY2AND3", + deps = [":main_lib"], +) + +py_library( + name = "main_lib", + srcs = ["main.py"], + srcs_version = "PY2AND3", deps = [ ":cifar_input", ":config", @@ -145,7 +156,7 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], @@ -157,7 +168,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], @@ -169,7 +180,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index 576f60396ef3de21655920f1376ac96b3fbf5c5f..f4dbe7ac16f734f7bee045bc71e9559b630adf81 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -9,6 +9,13 @@ py_binary( name = "rnn_colorbot", srcs = ["rnn_colorbot.py"], srcs_version = "PY2AND3", + deps = [":rnn_colorbot_lib"], +) + +py_library( + name = "rnn_colorbot_lib", + srcs = ["rnn_colorbot.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -21,7 +28,7 @@ cuda_py_test( name = "rnn_colorbot_test", srcs = ["rnn_colorbot_test.py"], additional_deps = [ - ":rnn_colorbot", + ":rnn_colorbot_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 74ebb1ec77131a560b1ebfd062c690920c35e261..1c718a5ce3d8e1541656d92fd5e8dad6c6683c4c 100644 --- 
a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -207,7 +207,7 @@ class RNNColorbot(tf.keras.Model): def loss(labels, predictions): """Computes mean squared loss.""" - return tf.reduce_mean(tf.square(predictions - labels)) + return tf.reduce_mean(tf.squared_difference(predictions, labels)) def test(model, eval_data): diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index f9bf82a7d88c46e13bd12ce3ee90f0334d473f10..43a6ca526d3a0aecda2c8df865a0487ac28758ab 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -9,6 +9,13 @@ py_binary( name = "rnn_ptb", srcs = ["rnn_ptb.py"], srcs_version = "PY2AND3", + deps = [":rnn_ptb_lib"], +) + +py_library( + name = "rnn_ptb_lib", + srcs = ["rnn_ptb.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", @@ -21,17 +28,18 @@ cuda_py_test( name = "rnn_ptb_test", srcs = ["rnn_ptb_test.py"], additional_deps = [ - ":rnn_ptb", + ":rnn_ptb_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + tags = ["no_oss"], # b/123045964 ) cuda_py_test( name = "rnn_ptb_graph_test", srcs = ["rnn_ptb_graph_test.py"], additional_deps = [ - ":rnn_ptb", + ":rnn_ptb_lib", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index d18a097063c7d25947af3e2e2959ce574edd553f..3143270ccfe4f670428c80bdc1e09fa452584207 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -37,7 +37,7 @@ from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from 
tensorflow.python.framework import test_util from tensorflow.python.training import checkpoint_management -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils # pylint: enable=g-bad-import-order @@ -421,7 +421,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - object_graph = checkpointable_utils.object_metadata( + object_graph = trackable_utils.object_metadata( checkpoint_management.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index c8d9266672a8b87d32338ea7c4f74fb40d41c767..b32501c2e804838af9d4c77663be131b77bd30b4 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -32,12 +32,12 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable _to_replace = re.compile("[^A-Za-z0-9.]") -class Metric(checkpointable.Checkpointable): +class Metric(trackable.Trackable): """A metric holds state for aggregating statistics over an evaluation run. Example use with eager execution: @@ -269,7 +269,7 @@ class Metric(checkpointable.Checkpointable): else: collections = [ops.GraphKeys.LOCAL_VARIABLES] collections += [ops.GraphKeys.METRIC_VARIABLES] - # Variables are Checkpointable dependencies of Metrics regardless of the + # Variables are Trackable dependencies of Metrics regardless of the # global/local distinction. 
Users can avoid saving variables by not adding a # dependency on the Metric. v = self._add_variable_with_custom_getter( @@ -282,7 +282,7 @@ class Metric(checkpointable.Checkpointable): use_resource=True, getter=variable_scope.get_variable, # Raise duplicate variable exceptions from get_variable rather than - # Checkpointable. + # Trackable. overwrite=True) self._vars.append(v) if context.executing_eagerly(): diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 39e5957f5d1760613f2c33607c0bdb163040efb4..c56d1956fde35b562e60496015e666efe9ebc8f6 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -35,7 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils class MetricsTest(test.TestCase): @@ -314,7 +314,7 @@ class MetricsTest(test.TestCase): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") mean = metrics.Mean() - checkpoint = checkpointable_utils.Checkpoint(mean=mean) + checkpoint = trackable_utils.Checkpoint(mean=mean) mean.build() mean._built = True self.evaluate(mean.init_variables()) @@ -327,7 +327,7 @@ class MetricsTest(test.TestCase): self.assertAllEqual(200., self.evaluate(mean.value())) restore_mean = metrics.Mean() - restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + restore_checkpoint = trackable_utils.Checkpoint(mean=restore_mean) status = restore_checkpoint.restore(save_path) restore_update = restore_mean(300.) 
status.assert_consumed().run_restore_ops() diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 240f213c602395b8589d39c3ecd90b602ffa9848..b3e8daddaf2369e9e33179fde2aab1469e97ea47 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils # pylint: disable=not-callable @@ -65,7 +65,7 @@ class NetworkTest(test.TestCase): def test_checkpointing_not_implemented(self): checkpoint_directory = self.get_temp_dir() - checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork()) + checkpoint = trackable_utils.Checkpoint(net=MyNetwork()) with self.assertRaises(NotImplementedError): checkpoint.save(checkpoint_directory) diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py index 7803a6799bb64441fab881bf6ca986d5cf3851a8..258f0a19309235dcd99b31b4de3d35ef8d89b15b 100644 --- a/tensorflow/contrib/eager/python/parameter_server.py +++ b/tensorflow/contrib/eager/python/parameter_server.py @@ -30,7 +30,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): @@ -129,8 +129,8 @@ class SharedVariable(resource_variable_ops.ResourceVariable): if constraint is not None and not 
callable(constraint): raise ValueError("The `constraint` argument must be a callable.") - if isinstance(initial_value, checkpointable.CheckpointInitialValue): - self._maybe_initialize_checkpointable() + if isinstance(initial_value, trackable.CheckpointInitialValue): + self._maybe_initialize_trackable() self._update_uid = initial_value.checkpoint_position.restore_uid initial_value = initial_value.wrapped_value diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index b82e1bb71bce9a28d7bbbf961cc6d5e25dd18acf..05830c9c1c3ebab82a545938d62094772ccf11b1 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -62,7 +62,6 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@Checkpoint @@Checkpointable -@@CheckpointableSaver @@executing_eagerly @@in_eager_mode @@ -138,9 +137,8 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable -from tensorflow.python.training.checkpointable.util import CheckpointableSaver -from tensorflow.python.training.checkpointable.util import Checkpoint +from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable +from tensorflow.python.training.tracking.util import CheckpointV1 as Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 8c35dddb5a515aa09cc70c173a9f0605e8567e82..6881fabdc09e3275c29f3013283999c96e283770 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -20,6 +20,7 @@ from __future__ import print_function 
import tempfile from tensorflow.contrib.eager.python import tfe +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -40,6 +41,9 @@ class TFETest(test_util.TensorFlowTestCase): self.assertAllEqual([[4.]], y.numpy()) def testInstantError(self): + if context.num_gpus(): + # TODO(nareshmodi): make this test better + self.skipTest("Gather doesn't do index checking on GPUs") with self.assertRaisesRegexp(errors.InvalidArgumentError, r'indices = 7 is not in \[0, 3\)'): array_ops.gather([0, 1, 2], 7) diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py index 11856ece38bf08dfdf16e8b0d9890bbfb0033216..47f568ed3d3e1b94e74c1423f774352df5c30f45 100644 --- a/tensorflow/contrib/estimator/python/estimator/early_stopping.py +++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py @@ -23,7 +23,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow_estimator.contrib.estimator.python.estimator import early_stopping +from tensorflow_estimator.python.estimator import early_stopping # Include attrs that start with single underscore. 
_HAS_DYNAMIC_ATTRIBUTES = True @@ -31,4 +31,4 @@ early_stopping.__all__ = [ s for s in dir(early_stopping) if not s.startswith('__') ] -from tensorflow_estimator.contrib.estimator.python.estimator.early_stopping import * +from tensorflow_estimator.python.estimator.early_stopping import * diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index cb86efb8da72f168b54f04773289a6fe421282b1..da2479a0b7b029561136903c82cabed9aae622b8 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -109,7 +109,7 @@ tf_gen_op_wrapper_py( # Ops tests tf_py_test( name = "gmm_test", - size = "large", + size = "medium", srcs = [ "python/ops/gmm_test.py", ], @@ -130,6 +130,7 @@ tf_py_test( "//tensorflow/python:random_seed", "//tensorflow/python:training", ], + shard_count = 4, tags = [ "no_pip", # b/38283730 "notsan", # Flaky: b/30756419 @@ -202,10 +203,7 @@ py_test( srcs = ["python/ops/kmeans_test.py"], shard_count = 4, srcs_version = "PY2AND3", - tags = [ - "nomac", # b/73741358 - "notsan", # b/67512932 - ], + tags = ["notsan"], deps = [ ":factorization_py", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", @@ -227,7 +225,7 @@ py_test( tf_py_test( name = "wals_test", - size = "large", + size = "medium", srcs = ["python/ops/wals_test.py"], additional_deps = [ ":factorization_py", @@ -250,8 +248,8 @@ tf_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], + shard_count = 4, tags = [ - "manual", "noasan", # times out b/63678675 "nomsan", ], diff --git a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc index a8c5d0763c28ba2b54f217405f0da65533f26b91..68078ba8bbb07b4344c19d554012d214229f9c4f 100644 --- a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc +++ b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc @@ -19,12 +19,12 @@ #include #include +#include 
"tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index d48b89cbacce34781819010addbcbd0ba66f9873..505d8d731fa9f3d0b004ccacd724576b7ac0ceee 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -613,7 +613,8 @@ class _InitializeClustersOpFactory(object): inp = nn_impl.l2_normalize(inp, dim=1) return gen_clustering_ops.kmeans_plus_plus_initialization( inp, - math_ops.to_int64(self._num_remaining), self._random_seed, + math_ops.cast(self._num_remaining, dtypes.int64), + self._random_seed, self._kmeans_plus_plus_num_retries) def _kmc2_multiple_centers(self): diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index d365ad111760247fc18b730657390f07ba6b865e..000b9832aa4d9ec645a86a6946fbf2665f9fd71d 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -53,7 +53,7 @@ def _covariance(x, diag): A Tensor representing the covariance of x. In the case of diagonal matrix just the diagonal is returned. 
""" - num_points = math_ops.to_float(array_ops.shape(x)[0]) + num_points = math_ops.cast(array_ops.shape(x)[0], dtypes.float32) x -= math_ops.reduce_mean(x, 0, keepdims=True) if diag: cov = math_ops.reduce_sum( @@ -297,8 +297,9 @@ class GmmAlgorithm(object): cholesky, array_ops.transpose( diff, perm=[0, 2, 1]), lower=True)) diag_m = array_ops.transpose(math_ops.reduce_sum(x_mu_cov, 1)) - self._probs[shard_id] = -0.5 * (diag_m + math_ops.to_float(self._dimensions) - * math_ops.log(2 * np.pi) + log_det_covs) + self._probs[shard_id] = ( + -0.5 * (diag_m + math_ops.cast(self._dimensions, dtypes.float32) * + math_ops.log(2 * np.pi) + log_det_covs)) def _define_diag_covariance_probs(self, shard_id, shard): """Defines the diagonal covariance probabilities per example in a class. @@ -314,14 +315,14 @@ class GmmAlgorithm(object): # reparametrization of variance parameters. det_expanded = math_ops.reduce_sum( math_ops.log(self._covs + 1e-3), 1, keepdims=True) - diff = shard - self._means - x2 = math_ops.square(diff) + x2 = math_ops.squared_difference(shard, self._means) cov_expanded = array_ops.expand_dims(1.0 / (self._covs + 1e-3), 2) # num_classes X num_examples x2_cov = math_ops.matmul(x2, cov_expanded) x2_cov = array_ops.transpose(array_ops.squeeze(x2_cov, [2])) self._probs[shard_id] = -0.5 * ( - math_ops.to_float(self._dimensions) * math_ops.log(2.0 * np.pi) + + math_ops.cast(self._dimensions, dtypes.float32) * + math_ops.log(2.0 * np.pi) + array_ops.transpose(det_expanded) + x2_cov) def _define_log_prob_operation(self, shard_id, shard): @@ -401,7 +402,8 @@ class GmmAlgorithm(object): # Update alpha. 
if 'w' in self._params: final_points_in_k = points_in_k / num_batches - num_examples = math_ops.to_float(math_ops.reduce_sum(final_points_in_k)) + num_examples = math_ops.cast(math_ops.reduce_sum(final_points_in_k), + dtypes.float32) self._alpha_op = self._alpha.assign(final_points_in_k / (num_examples + MEPS)) else: diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 4e29e2559986012d8eeeaec807f14181226363aa..edd6f36e07c246eb4d8a5176a74943b461830cc3 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -14,7 +14,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":sequence_feature_column", - ":sequence_feature_column_v2", "//tensorflow/python:util", ], ) @@ -34,6 +33,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:variable_scope", "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/feature_column:utils", ], ) @@ -72,44 +72,3 @@ tf_py_test( ], tags = ["no_pip"], ) - -py_library( - name = "sequence_feature_column_v2", - srcs = ["python/feature_column/sequence_feature_column_v2.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_py", - ], -) - -tf_py_test( - name = "sequence_feature_column_v2_test", - srcs = ["python/feature_column/sequence_feature_column_v2_test.py"], - additional_deps = [ - ":sequence_feature_column", - ":sequence_feature_column_v2", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - 
"//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python/feature_column:feature_column_py", - "//tensorflow/python/feature_column:feature_column_v2_test", - ], - tags = ["no_pip"], -) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 9b3a5c58aaa9498257fc971ac60b97f31d5185d8..64df44fe4360cb30de89f06e6e88d85e6dbaf182 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -23,6 +23,7 @@ import collections from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import utils as fc_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -506,7 +507,7 @@ class _SequenceNumericColumn( # sequence length is not affected. 
num_elements = (self._variable_shape.num_elements() if sp_tensor.shape.ndims == 2 else 1) - seq_length = fc._sequence_length_from_sparse_tensor( + seq_length = fc_utils.sequence_length_from_sparse_tensor( sp_tensor, num_elements=num_elements) return fc._SequenceDenseColumn.TensorSequenceLengthPair( diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 88a14a2a94cc683f021d032ea11358e0cfb63faa..8fd2b5f39bc88b76fe5583f8d18389e232ea9f40 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -32,7 +32,6 @@ tf_custom_op_py_library( "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", - "python/ops/critical_section_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", "python/ops/script_ops.py", @@ -51,6 +50,7 @@ tf_custom_op_py_library( "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", "//tensorflow_estimator:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", "//video/youtube/personalization:__subpackages__", ], deps = [ @@ -171,26 +171,6 @@ py_test( ], ) -cuda_py_test( - name = "critical_section_test", - size = "medium", - srcs = ["python/ops/critical_section_test.py"], - additional_deps = [ - "//tensorflow/python:client_testlib", - ":framework_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:context", - ], -) - py_test( name = "ops_test", size = "small", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index e72e50585a3861d4527b66f89e1659d76c85960a..063717f08aa88f4de9470d8392db2b7c95b3e4bf 100644 --- 
a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -94,8 +94,6 @@ @@smart_constant_value @@smart_case -@@CriticalSection - @@BoundedTensorSpec @@TensorSpec @@ -129,18 +127,24 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', + 'is_nested', 'is_sequence', + 'is_sequence_or_composite', 'flatten', 'flatten_dict_items', 'pack_sequence_as', 'map_structure', 'map_structure_with_paths', + 'map_structure_with_tuple_paths', 'assert_shallow_structure', 'flatten_up_to', + 'flatten_with_tuple_paths_up_to', 'map_structure_up_to', + 'map_structure_with_tuple_paths_up_to', 'get_traverse_shallow_structure', 'yield_flat_paths', 'flatten_with_joined_string_paths', + 'flatten_with_tuple_paths', ] remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols) diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index c4976497f5fa95d82e492153b117681f693eaa13..8113bf7c095bd0817e40cfd08bdf1ef7275ba55b 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -22,7 +22,6 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.ops.arg_scope import * from tensorflow.contrib.framework.python.ops.checkpoint_ops import * -from tensorflow.contrib.framework.python.ops.critical_section_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.prettyprint_ops import * from tensorflow.contrib.framework.python.ops.script_ops import * diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 57a5bfbf43c915775c6b0ef05baac19581213a09..5c254436e630393341945cb4546c20f2b24be031 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ 
-68,6 +68,8 @@ tf_kernel_library( prefix = "fused_conv2d_bias_activation_op", visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core:conv_autotuning_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", @@ -92,6 +94,8 @@ tf_custom_op_library( "ops/fused_conv2d_bias_activation_op.cc", ], deps = [ + "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core:conv_autotuning_proto_cc", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/kernels:bounds_check_lib", "//tensorflow/core/kernels:conv_2d_hdrs", @@ -171,6 +175,7 @@ cuda_py_test( main = "python/ops/fused_conv2d_bias_activation_benchmark.py", tags = [ "manual", # TODO(b/117128481): re-enable after fixing OSS build + "nogpu", "requires-gpu-sm70", ], ) diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index c541c71f996c7a1b36cf28ae9a1783f8dca0a72c..0b3e37107213281abca1f7abdb58f820c85ac1b0 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -19,13 +19,13 @@ limitations under the License. #include "tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" @@ -34,9 +34,16 @@ limitations under the License. 
#include "tensorflow/core/util/use_cudnn.h" #if GOOGLE_CUDA +#include "google/protobuf/duration.pb.h" +#include "absl/time/time.h" #include "cuda/include/cudnn.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/platform/logger.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/protobuf/autotuning.pb.h" +#include "tensorflow/core/protobuf/conv_autotuning.pb.h" #include "tensorflow/core/util/activation_mode.h" #endif // GOOGLE_CUDA @@ -252,6 +259,131 @@ class FusedConv2DBiasActivationOp : public OpKernel { #if GOOGLE_CUDA namespace dnn = se::dnn; +// Several functions are copyed over from tensorflow/core/kernels/gpu_utils, +// since this file may be compiled down to a tf_custom_op_library .so file, +// which can't depend on basic dependencies like tensorflow/core:lib. Instead, +// the code has to depend on whatever is the same in libtensorflow_framework.so. +// +// In theory, we can lift the dependencies of gpu_utils by turning it into a +// template library that provides duck typing, but I think duplication is the +// lesser of two evils. +namespace internal { +namespace { + +tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { + tensorflow::CudnnVersion cudnn_version; + if (auto* dnn = stream_executor->AsDnn()) { + se::port::StatusOr version_or = dnn->GetVersion(); + if (version_or.ok()) { + const auto& version = version_or.ValueOrDie(); + cudnn_version.set_major(version.major_version()); + cudnn_version.set_minor(version.minor_version()); + cudnn_version.set_patch(version.patch()); + } + } + return cudnn_version; +} + +// Converts an absl::Duration to a google::protobuf::Duration. 
+inline google::protobuf::Duration ToDurationProto(absl::Duration duration) { + google::protobuf::Duration proto; + proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); + proto.set_nanos( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); + return proto; +} + +// Converts a google::protobuf::Duration to an absl::Duration. +inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + +tensorflow::ComputeCapability GetComputeCapability( + se::StreamExecutor* stream_executor) { + tensorflow::ComputeCapability cc; + int cc_major, cc_minor; + stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + cc.set_major(cc_major); + cc.set_minor(cc_minor); + return cc; +} + +void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input, + const Tensor& filter, const Tensor& output, + const Tensor& bias, const Tensor* side_input, + se::StreamExecutor* stream_exec, + absl::Span results) { + AutotuningLog log; + ConvNodeDef instr; + *instr.mutable_conv() = node; + input.shape().AsProto(instr.mutable_input()->mutable_tensor_shape()); + instr.mutable_input()->set_dtype(input.dtype()); + filter.shape().AsProto(instr.mutable_filter()->mutable_tensor_shape()); + instr.mutable_filter()->set_dtype(filter.dtype()); + output.shape().AsProto(instr.mutable_output()->mutable_tensor_shape()); + instr.mutable_output()->set_dtype(output.dtype()); + bias.shape().AsProto(instr.mutable_bias()->mutable_tensor_shape()); + instr.mutable_bias()->set_dtype(bias.dtype()); + if (side_input) { + side_input->shape().AsProto( + instr.mutable_side_input()->mutable_tensor_shape()); + instr.mutable_side_input()->set_dtype(side_input->dtype()); + } + log.mutable_instr()->PackFrom(std::move(instr)); + *log.mutable_cudnn_version() = internal::GetCudnnVersion(stream_exec); + *log.mutable_compute_capability() = + 
internal::GetComputeCapability(stream_exec); + for (const auto& result : results) { + *log.add_results() = result; + } + Logger::Singleton()->LogProto(log); +} + +Status BestCudnnConvAlgorithm(absl::Span results, + se::dnn::AlgorithmConfig* algo) { + // For the "!xhs.has_success()" below, this is because we want successful ones + // to order first, therefore they need a smaller key per "min_element". + const AutotuneResult* best_result = std::min_element( + results.begin(), results.end(), + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return std::make_tuple( + !lhs.has_success(), + internal::FromDurationProto(lhs.success().run_time())) < + std::make_tuple( + !rhs.has_success(), + internal::FromDurationProto(rhs.success().run_time())); + }); + + const AutotuneResult* best_result_no_scratch = std::min_element( + results.begin(), results.end(), + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return std::make_tuple( + !lhs.has_success(), lhs.success().scratch_bytes(), + internal::FromDurationProto(lhs.success().run_time())) < + std::make_tuple( + !rhs.has_success(), rhs.success().scratch_bytes(), + internal::FromDurationProto(rhs.success().run_time())); + }); + + if (best_result == results.end() || !best_result->has_success()) { + return errors::NotFound("No algorithm worked!"); + } + algo->set_algorithm({best_result->conv().algorithm(), + best_result->conv().tensor_ops_enabled()}); + if (best_result_no_scratch != results.end() && + best_result_no_scratch->has_success() && + best_result_no_scratch->success().scratch_bytes() == 0) { + algo->set_algorithm_no_scratch( + {best_result_no_scratch->conv().algorithm(), + best_result_no_scratch->conv().tensor_ops_enabled()}); + } + return Status::OK(); +} + +} // namespace +} // namespace internal + // A dummy type to group forward convolution autotune results together. 
struct ConvBiasActivationAutoTuneGroup { static string name() { return "ConvBiasActivation"; } @@ -565,8 +697,21 @@ void LaunchFusedConv2DBiasActivationOp:: fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo( stream->parent()), &algorithms)); - dnn::ProfileResult best_result; - dnn::ProfileResult best_result_no_scratch; + if (activation_mode == ActivationMode::NONE) { + // Only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM is supported for + // identity activation, other algs seem to quietly do Relu. + // See + // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward + algorithms.erase( + std::remove_if( + algorithms.begin(), algorithms.end(), + [](dnn::AlgorithmDesc alg) { + return alg.algo_id() != + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + }), + algorithms.end()); + } + std::vector results; for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. 
@@ -583,28 +728,24 @@ void LaunchFusedConv2DBiasActivationOp:: .ok(); if (cudnn_launch_status) { if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; - } + results.emplace_back(); + auto& result = results.back(); + result.mutable_conv()->set_algorithm(profile_algorithm.algo_id()); + result.mutable_conv()->set_tensor_ops_enabled( + profile_algorithm.tensor_ops_enabled()); + result.mutable_success()->set_scratch_bytes( + scratch_allocator.TotalByteSize()); + *result.mutable_success()->mutable_run_time() = + internal::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); } } } - OP_REQUIRES(ctx, - best_result.is_valid() || best_result_no_scratch.is_valid(), - errors::NotFound("No algorithm worked!")); - if (best_result.is_valid()) { - algorithm_config.set_algorithm(best_result.algorithm()); - } - if (best_result_no_scratch.is_valid()) { - algorithm_config.set_algorithm_no_scratch( - best_result_no_scratch.algorithm()); - } + internal::LogFusedConvAutotuneResults(ctx->op_kernel().def(), *conv_input, + *filter, *output, bias, side_input, + stream->parent(), results); + OP_REQUIRES_OK( + ctx, internal::BestCudnnConvAlgorithm(results, &algorithm_config)); AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters, algorithm_config); } diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index db0868fb2c43464a811b3d6dfcd96480ba2463ee..386e4cf69b7aa118a85fb25bcb809a879c5c1bd8 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -377,7 +377,10 @@ py_test( name = "classifier_metrics_test", srcs = ["python/eval/python/classifier_metrics_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs 
investigation on Windows + tags = [ + "no_pip", + "no_windows", + ], deps = [ ":classifier_metrics", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index db7dc51daa78ecee12ecb7f6d33df4511e068243..3c1d814e70f7fdad4083583c9d89450a60bc2e20 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -9,8 +9,9 @@ explicitly model the distribution and without writing an explicit loss. For example, the generator could learn to draw samples from the distribution of natural images. For more details on this technique, see ['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](http://https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an -introduction. +Goodfellow et al. See +[tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) +for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an introduction. #### Usage ```python @@ -57,11 +58,10 @@ These include the following main pieces (explained in detail below). generative models. * [examples](https://github.com/tensorflow/models/tree/master/research/gan/) - and [tutorial](http://https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TF-GAN - to make GAN training easier, or use the more complicated examples to - jumpstart your own project. These include unconditional and conditional - GANs, InfoGANs, adversarial losses on existing networks, and image-to-image - translation. + and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TF-GAN to make + GAN training easier, or use the more complicated examples to jump-start your + own project. 
These include unconditional and conditional GANs, InfoGANs, + adversarial losses on existing networks, and image-to-image translation. ## Training a GAN model diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 5b9c54e43a16adf457d5ed0e7e73dcd168ab0d67..66af79d1e81bbc450141673dd54d865e5c7932d5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -23,7 +23,6 @@ import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples @@ -238,10 +237,10 @@ class GANEstimatorIntegrationTest(test.TestCase): # Evaluate. scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn('loss', scores) self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], scores['loss']) - self.assertIn('mse_custom_metric', six.iterkeys(scores)) + self.assertIn('mse_custom_metric', scores) # Predict. 
predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py index c00ff4399748a77f88d9753df7592bf3859d754e..0fcd1b7924eb02f5d617b45af16852baf2e2bb48 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py @@ -23,7 +23,6 @@ import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples @@ -235,10 +234,10 @@ class StarGANEstimatorIntegrationTest(test.TestCase): # EVALUTE scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn('loss', scores) self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], scores['loss']) - self.assertIn('mse_custom_metric', six.iterkeys(scores)) + self.assertIn('mse_custom_metric', scores) # PREDICT predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py index 9fdcc08334d50b4ddf3a0bc9bc755e55d51b0bd8..baf2c28df4b63cff525dcf3ff880730768ad000a 100644 --- a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py @@ -23,7 +23,6 @@ import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples @@ -184,12 +183,11 @@ class TPUGANEstimatorIntegrationTest(test.TestCase, parameterized.TestCase): # Evaluate. 
num_steps_eval = 2 scores = est.evaluate(eval_input_fn, steps=num_steps_eval) - self.assertEqual(num_steps_train + num_steps_eval, - scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn(ops.GraphKeys.GLOBAL_STEP, scores) + self.assertIn('loss', scores) self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], scores['loss']) - self.assertIn('mse_custom_metric', six.iterkeys(scores)) + self.assertIn('mse_custom_metric', scores) # Predict. predictions = np.array([x['generated_data'] for x in diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 31f0d34ed68a6adc25cca102236079d0f66615cb..efbdb1152d665509ae8b4444097dd5091a0f5312 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -41,9 +41,9 @@ from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops @@ -140,7 +140,7 @@ def preprocess_image(images, is_single = images.shape.ndims == 3 with ops.name_scope(scope, 'preprocess', [images, height, width]): if not images.dtype.is_floating: - images = math_ops.to_float(images) + images = math_ops.cast(images, dtypes.float32) if is_single: images = array_ops.expand_dims(images, axis=0) resized = image_ops.resize_bilinear(images, [height, width]) @@ -346,7 +346,7 @@ def classifier_score(images, classifier_fn, num_batches=1): images, num_or_size_splits=num_batches) # Compute the 
classifier splits using the memory-efficient `map_fn`. - logits = functional_ops.map_fn( + logits = map_fn.map_fn( fn=classifier_fn, elems=array_ops.stack(generated_images_list), parallel_iterations=1, @@ -505,12 +505,12 @@ def frechet_classifier_distance(real_images, # Compute the activations using the memory-efficient `map_fn`. def compute_activations(elems): - return functional_ops.map_fn(fn=classifier_fn, - elems=elems, - parallel_iterations=1, - back_prop=False, - swap_memory=True, - name='RunClassifier') + return map_fn.map_fn(fn=classifier_fn, + elems=elems, + parallel_iterations=1, + back_prop=False, + swap_memory=True, + name='RunClassifier') real_a = compute_activations(real_imgs) gen_a = compute_activations(generated_imgs) @@ -895,7 +895,7 @@ def kernel_classifier_distance_and_std(real_images, # Compute the activations using the memory-efficient `map_fn`. def compute_activations(elems): - return functional_ops.map_fn( + return map_fn.map_fn( fn=classifier_fn, elems=elems, parallel_iterations=1, @@ -1057,7 +1057,8 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, n_g = array_ops.shape(generated_activations)[0] n_bigger = math_ops.maximum(n_r, n_g) - n_blocks = math_ops.to_int32(math_ops.ceil(n_bigger / max_block_size)) + n_blocks = math_ops.cast(math_ops.ceil(n_bigger / max_block_size), + dtypes.int32) v_r = n_r // n_blocks v_g = n_g // n_blocks @@ -1099,7 +1100,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, (math_ops.reduce_sum(k_rr) - math_ops.trace(k_rr)) / (m * (m - 1)) + (math_ops.reduce_sum(k_gg) - math_ops.trace(k_gg)) / (n * (n - 1))) - ests = functional_ops.map_fn( + ests = map_fn.map_fn( compute_kid_block, math_ops.range(n_blocks), dtype=dtype, back_prop=False) mn = math_ops.reduce_mean(ests) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 
bd17571a0535a3c8e9dfee24a8da16eb2e72f165..bc7c1057b478fe2656898e68c1a14013b5a71d12 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -365,7 +365,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): unused_image = array_ops.zeros([2, 299, 299, 3]) incscore = _run_with_mock(classifier_metrics.inception_score, unused_image) - with self.test_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True) as sess: incscore_np = sess.run(incscore, {'concat:0': logits}) self.assertAllClose(_expected_inception_score(logits), incscore_np) @@ -473,7 +473,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_fn=lambda x: x, max_block_size=600) - with self.test_session() as sess: + with self.cached_session() as sess: actual_kid, actual_std = sess.run(kid_op) expected_kid, expected_std = _expected_kid_and_std(test_pool_real_a, @@ -500,7 +500,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): max_block_size=max_block_size) for block_size in [50, 512, 1000]: - with self.test_session() as sess: + with self.cached_session() as sess: actual_kid, actual_std = sess.run(kid_op, {max_block_size: block_size}) expected_kid, expected_std = _expected_kid_and_std( diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py index 4b1105f6bd4f21a0da02338b0fc9db87a41b145f..9657d4e3d0cb60376b3f1dd23d0138a200ce5e5c 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -28,6 +28,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import 
linalg_ops from tensorflow.python.ops import math_ops @@ -74,7 +75,7 @@ def _laplacian_pyramid(batch, num_levels): res = spatial_conv(res, 4) return res - pyramid = [math_ops.to_float(batch)] + pyramid = [math_ops.cast(batch, dtypes.float32)] for _ in range(1, num_levels): pyramid.append(pyr_down(pyramid[-1])) pyramid[-2] -= pyr_up(pyramid[-1]) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 9f448d3a1602c503093214201bdc96fc9bee85b5..3eb4f5db0c841af584a672f100509a3d455a8b75 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -20,9 +20,10 @@ from __future__ import print_function from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import eval_utils +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import util as loss_util @@ -171,8 +172,10 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, gan_model.generated_data[:num_comparisons]) real_list = array_ops.unstack(gan_model.real_data[:num_comparisons]) diffs = [ - math_ops.abs(math_ops.to_float(generated) - math_ops.to_float(real)) for - generated, real in zip(generated_list, real_list)] + math_ops.abs(math_ops.cast(generated, dtypes.float32) - + math_ops.cast(real, dtypes.float32)) + for generated, real in zip(generated_list, real_list) + ] image_list.extend(diffs) # Reshape image and display. 
@@ -261,7 +264,7 @@ def add_stargan_image_summaries(stargan_model, summary.image( 'stargan_image_generation', - functional_ops.map_fn( + map_fn.map_fn( _build_image, stargan_model.input_data[:num_images], parallel_iterations=num_images, diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py index 0cc653f0a7907f407e66add5537d1e0a5adb6d8b..3764c43cdfc8f6515e0376cd6aa1d244b21e2e89 100644 --- a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py @@ -53,7 +53,7 @@ def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None): Args: w_tensor: The weight matrix whose spectral norm should be computed. power_iteration_rounds: The number of iterations of the power method to - perform. A higher number yeilds a better approximation. + perform. A higher number yields a better approximation. name: An optional scope name. Returns: @@ -105,7 +105,7 @@ def spectral_normalize(w, power_iteration_rounds=1, name=None): Args: w: The weight matrix to be normalized. power_iteration_rounds: The number of iterations of the power method to - perform. A higher number yeilds a better approximation. + perform. A higher number yields a better approximation. name: An optional scope name. Returns: @@ -126,7 +126,7 @@ def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None): Args: scale: A scalar multiplier. 0.0 disables the regularizer. power_iteration_rounds: The number of iterations of the power method to - perform. A higher number yeilds a better approximation. + perform. A higher number yields a better approximation. scope: An optional scope name. 
Returns: @@ -221,7 +221,7 @@ def spectral_normalization_custom_getter(name_filter=_default_name_filter, name_filter: Optionally, a method that takes a Variable name as input and returns whether this Variable should be normalized. power_iteration_rounds: The number of iterations of the power method to - perform per step. A higher number yeilds a better approximation of the + perform per step. A higher number yields a better approximation of the true spectral norm. Returns: @@ -294,7 +294,7 @@ def keras_spectral_normalization(name_filter=_default_name_filter, name_filter: Optionally, a method that takes a Variable name as input and returns whether this Variable should be normalized. power_iteration_rounds: The number of iterations of the power method to - perform per step. A higher number yeilds a better approximation of the + perform per step. A higher number yields a better approximation of the true spectral norm. Yields: diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py index f5c448db41c67adb4edd2634dd63a1840180df70..80fae0356f66f9d98969171cdfe18110bd21f7bd 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py @@ -224,7 +224,8 @@ class VBN(object): # statistics and the reference batch statistics. ref_batch_size = _static_or_dynamic_batch_size( self._reference_batch, self._batch_axis) - self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.) + self._example_weight = 1. / ( + math_ops.cast(ref_batch_size, dtypes.float32) + 1.) self._ref_weight = 1. - self._example_weight # Make the variables, if necessary. 
diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py index ecfbb8a432e3308863edd6f1343be55c1fe5753c..9848f654badafea3f08ef48207eb84973c62cb16 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py @@ -112,7 +112,7 @@ class VirtualBatchnormTest(test.TestCase): batch, axis, training=True) # Get VBN's batch normalization on reference batch. - batch_axis = 0 if axis is not 0 else 1 # axis and batch_axis can't same + batch_axis = 0 if axis != 0 else 1 # axis and batch_axis can't same vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis) vbn_normalized = vbn.reference_batch_normalization() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index e3c780ac1a0f0ef15ff993bd3a9bf9730dcb45b8..44ee0f52696dc1cdcd91286a80b2d4b42be93a4d 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -403,7 +403,9 @@ class _PenaltyTest(object): def test_all_correct(self): loss = self._penalty_fn(**self._kwargs) self.assertEqual(self._expected_dtype, loss.dtype) - self.assertEqual(self._expected_op_name, loss.op.name) + # NOTE: Op names will change, it is inappropriate to include them in tests. + # See go/tf-breaking-change. 
+ # self.assertEqual(self._expected_op_name, loss.op.name) with self.cached_session(): variables.global_variables_initializer().run() self.assertAlmostEqual(self._expected_loss, loss.eval(), 6) diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index f36a5d346e0f27fbbc480e876380db51ed559c09..9bff8090d93d3ad7def69726073accfb234ef301 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -757,7 +757,9 @@ def cyclegan_loss( return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) - +# Begin google-internal +# The four major parts can be found here: http://screen/tMRMBAohDYG. +# End google-internal def stargan_loss( model, generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper( @@ -776,8 +778,6 @@ def stargan_loss( add_summaries=True): """StarGAN Loss. - The four major part can be found here: http://screen/tMRMBAohDYG. - Args: model: (StarGAN) Model output of the stargan_model() function call. generator_loss_fn: The loss function on the generator. 
Takes a diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index 704be917b3680a1b5712f4f1dc5059b354db8610..bf8b66dcfa5e44a03107cdf1ef8b04e1dbff4a9c 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -17,11 +17,6 @@ filegroup( ]), ) -load( - "//tensorflow:tensorflow.bzl", - "tf_cuda_library", -) - # For platform specific build config load( "//tensorflow/core:platform/default/build_config.bzl", @@ -66,7 +61,6 @@ cc_library( ":gdr_memory_manager", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:graph_mgr", @@ -100,15 +94,37 @@ cc_library( ], ) +cc_library( + name = "gdr_collective_executor_mgr", + srcs = ["gdr_collective_executor_mgr.cc"], + hdrs = ["gdr_collective_executor_mgr.h"], + deps = [ + ":gdr_memory_manager", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:cancellable_call", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", + "//tensorflow/core/distributed_runtime:request_id", + "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", + "//tensorflow/core/distributed_runtime:worker_cache", + ], +) + cc_library( name = "gdr_server_lib", srcs = ["gdr_server_lib.cc"], hdrs = ["gdr_server_lib.h"], linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel deps = [ + ":gdr_collective_executor_mgr", ":gdr_memory_manager", ":gdr_rendezvous_mgr", ":gdr_worker", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], alwayslink = 1, diff 
--git a/tensorflow/contrib/gdr/README.md b/tensorflow/contrib/gdr/README.md index 8242d93f129904828a11b61d48f2df8fb0f88bc3..711adc865f37fc84550e4b45d9f0c7fff421a0dc 100644 --- a/tensorflow/contrib/gdr/README.md +++ b/tensorflow/contrib/gdr/README.md @@ -114,7 +114,16 @@ Caveats In current implementation, only tensors that reside in host memory or in GPU memory such that the GPU is adjacent to an RDMA capable NIC will use direct RDMA as its transport. When RDMA is available but not GDR, a temporary tensor copy on host memory will be used as RDMA source/destination (and copied from/to the target device). When there is no RDMA device present, it can even fallback to the original gRPC runtime. While it is theoretically possible to mix GDR enabled TF with non-GDR deployments in the same job, make sure the environment is properly setup so the GDR mode is enabled whenever possible (i.e. do not fall back to gRPC when it is not absolutely necessary). -In the original design (as in the reference), tensor buffers are only registered to NIC when we could determine that the tensor will be either a source of Send or a sink of Recv across physical machine boundary. However, to implement the precise allocations, we need to change all the devices to possibly return a NIC compatible allocator. As GDR is currently in contrib, we would like to avoid the unnecessary code disruption to the TF core, so we allocate all tensors from NIC-registered buffers using a BFC allocator. This behaviour is similar to the effect of enabling the extra GPU option `force_gpu_compatible`, which allocate all host tensors in GPU-registered buffers no matter they will be transferred from/to GPUs or not. +In the original design (as in the reference), tensor buffers are only registered +to NIC when we could determine that the tensor will be either a source of Send +or a sink of Recv across physical machine boundary. 
However, to implement the +precise allocations, we need to change all the devices to possibly return a NIC +compatible allocator. As GDR is currently in contrib, we would like to avoid the +unnecessary code disruption to the TF core, so we allocate all tensors from +NIC-registered buffers using a BFC allocator. This behavior is similar to the +effect of enabling the extra GPU option `force_gpu_compatible`, which allocate +all host tensors in GPU-registered buffers no matter they will be transferred +from/to GPUs or not. Reference === diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc new file mode 100644 index 0000000000000000000000000000000000000000..755cbdff31cd7ca31579e0d64399d681dc24ad81 --- /dev/null +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h" + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/distributed_runtime/cancellable_call.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/request_id.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { + +class WorkerCacheInterface; + +namespace { + +class RecvBufCall : public CancellableCall { + public: + RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task, + const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + const DeviceLocality& server_locality, + CancellationManager* cancel_mgr, WorkerCacheInterface* wc) + : CancellableCall(cancel_mgr, peer_task, wc) { + req_.set_step_id(step_id); + req_.set_buf_rendezvous_key(key); + *req_.mutable_client_locality() = client_locality; + *req_.mutable_server_locality() = server_locality; + req_.set_num_bytes(to_tensor->TotalBytes()); + req_.set_buf_ptr(reinterpret_cast(DMAHelper::base(to_tensor))); + req_.set_src_device(peer_device); + req_.set_dst_device(to_device->name()); + req_.set_request_id(GetUniqueRequestId()); + } + + ~RecvBufCall() override {} + + void IssueCall(const StatusCallback& done) override { + wi_->RecvBufAsync(&opts_, &req_, &resp_, done); + } + + RecvBufRequest req_; + RecvBufResponse resp_; +}; + +class 
CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { + public: + CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr, + DeviceResolverInterface* dev_resolver, + WorkerCacheInterface* worker_cache, + int64 step_id, + RemoteMemoryManager* remote_memory_manager) + : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id), + worker_cache_(worker_cache), + remote_memory_manager_(remote_memory_manager) {} + + ~CollectiveRemoteAccessDistributed() override {} + + void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + int dev_to_dev_stream_index, + const StatusCallback& done) override { + if (peer_is_local) { + CollectiveRemoteAccessLocal::RecvFromPeer( + peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index, + done); + return; + } + + // State that needs to be threaded through a couple of async calls + // in order to make this function completely non-blocking. + struct State { + DeviceLocality server_locality; + std::unique_ptr call; + }; + State* state = new State; + + // Logic to be executed on the RecvBufAsync callback. + auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr, + to_device_ctx, to_tensor, done](const Status& s) { + if (s.ok()) { + remote_memory_manager_->TensorFromTransportOptions( + to_tensor, state->call->resp_.transport_options(), to_device, + to_device_ctx, to_alloc_attr.on_host(), done); + } + if (!s.ok() && errors::IsFailedPrecondition(s)) { + dev_resolver_->ClearTask(peer_task); + } + + delete state; + }; + + // Logic to execute once we have the device locality for the server-side + // device. 
+ auto dev_locality_callback = [this, state, peer_device, peer_task, key, + to_device, to_device_ctx, to_alloc_attr, + to_tensor, client_locality, + recv_buf_callback](const Status& s) { + if (!s.ok()) { + recv_buf_callback(s); + } else { + state->call.reset(new RecvBufCall( + step_id_, peer_device, peer_task, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, state->server_locality, + &cancel_mgr_, worker_cache_)); + state->call->Start(recv_buf_callback); + } + }; + + dev_resolver_->GetLocalityAsync( + peer_device, peer_task, &state->server_locality, dev_locality_callback); + } + + void StartAbort(const Status& s) override { + CollectiveRemoteAccessLocal::StartAbort(s); + cancel_mgr_.StartCancel(); + } + + protected: + WorkerCacheInterface* worker_cache_; // Not owned + CancellationManager cancel_mgr_; + RemoteMemoryManager* remote_memory_manager_; +}; + +} // namespace + +CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) { + CollectiveRemoteAccessDistributed* rma = + new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(), + worker_cache_, step_id, + remote_memory_manager_); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, + &gpu_ring_order_); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h new file mode 100644 index 0000000000000000000000000000000000000000..1417e51e82c31035f058e8e9b546e04fb0ad97b8 --- /dev/null +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ +#define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ + +#include "tensorflow/contrib/gdr/gdr_memory_manager.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class ConfigProto; +class DeviceMgr; +class WorkerCacheInterface; +class StepSequenceRequest; +class StepSequenceResponse; + +// An implementation of CollectiveExecutorMgr for a distributed environment +// that uses WorkerInterface::RecvBufAsync to route data transfers over RDMA. +class GdrCollectiveExecutorMgr : public RpcCollectiveExecutorMgr { + public: + GdrCollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* dev_mgr, + std::unique_ptr dev_resolver, + std::unique_ptr param_resolver, + WorkerCacheInterface* worker_cache, const string& task_name, + RemoteMemoryManager* remote_memory_manager) + : RpcCollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver), + std::move(param_resolver), worker_cache, + task_name), + remote_memory_manager_(remote_memory_manager) {} + + ~GdrCollectiveExecutorMgr() override {} + + protected: + virtual CollectiveExecutor* Create(int64 step_id) override; + + private: + RemoteMemoryManager* remote_memory_manager_; // Not owned. 
+}; + +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index ce1875151597f926aeb6392e7fc8307312da123f..9b8e832fd96c898d11fe817dcf3472b92293270f 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -73,7 +73,10 @@ int TryToReadNumaNode(ibv_device* device) { std::ifstream ifs(filename.c_str()); string content; - CHECK(std::getline(ifs, content)); + const auto& ret = std::getline(ifs, content); + if (!ret) { + return port::kNUMANoAffinity; + } int32 value; if (strings::safe_strto32(content, &value)) { @@ -247,10 +250,9 @@ Status GdrMemoryManager::Init() { LOG(INFO) << "Instrumenting CPU allocator(s)"; for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { - GPUProcessState::singleton()->AddCUDAHostAllocVisitor(numa_idx, - alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostFreeVisitor(numa_idx, - free_visitor); + GPUProcessState::singleton()->AddGpuHostAllocVisitor(numa_idx, + alloc_visitor); + GPUProcessState::singleton()->AddGpuHostFreeVisitor(numa_idx, free_visitor); } if (IsGDRAvailable()) { diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index 5f8c300155770ed03ad12a9fa5ac74456edaf024..1124dff741309d8fd04954e70c5ebaaf164b940a 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -167,8 +167,11 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { // RendezvousMgr already aborted, shouldn't send RPC call any more if (!call->status().ok()) { - done(call->status(), Args(), Args(), Tensor(), false); + // NOTE: `*session()` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. 
session()->worker_cache->ReleaseWorker(src_worker, rwi); + done(call->status(), Args(), Args(), Tensor(), false); delete call; return; } @@ -181,8 +184,11 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { // If StartAbort was called prior to DeregisterCall, then the // current status should be bad. Status s = call->status(); - done(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); + // NOTE: `*session()` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. session()->worker_cache->ReleaseWorker(src_worker, rwi); + done(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); delete call; Unref(); }); diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index dc0d5d548b80d36409778ef34e63171441f10142..c39cc0f9bcecc26aedfaf9707113210acf670244 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -16,11 +16,13 @@ limitations under the License. 
#include "tensorflow/contrib/gdr/gdr_server_lib.h" #include "grpc/support/alloc.h" +#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h" #include "tensorflow/contrib/gdr/gdr_memory_manager.h" #include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h" #include "tensorflow/contrib/gdr/gdr_worker.h" - -#include "grpc/support/alloc.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" namespace tensorflow { @@ -57,10 +59,34 @@ Status GdrServer::Init() { return std::unique_ptr( new GdrWorker(env, config, remote_memory_manager_.get())); }; - + CollectiveMgrCreationFunction collective_mgr_func = + [this](const ConfigProto& config, const WorkerEnv* env, + WorkerCacheInterface* worker_cache) { + string unused; + string default_worker_name; + DeviceNameUtils::SplitDeviceName( + env->device_mgr->ListDevices()[0]->name(), &default_worker_name, + &unused); + + std::unique_ptr dev_resolver( + new DeviceResolverDistributed(env->device_mgr, worker_cache, + default_worker_name)); + std::unique_ptr param_resolver( + new CollectiveParamResolverDistributed( + config, env->device_mgr, dev_resolver.get(), worker_cache, + default_worker_name)); + return new GdrCollectiveExecutorMgr( + config, env->device_mgr, std::move(dev_resolver), + std::move(param_resolver), worker_cache, default_worker_name, + remote_memory_manager_.get()); + }; TF_RETURN_IF_ERROR(remote_memory_manager_->Init()); - return GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func); + GrpcServerOptions opts; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + opts.collective_mgr_func = collective_mgr_func; + opts.worker_func = worker_func; + return GrpcServer::Init(opts); } Status GdrServer::Start() { diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 
016e5ea27b397830c69b6e1761b5994ebcfa9c3d..1204b8ca501a8f99ea6abd6c047ab2d91350bae1 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/gdr/gdr_worker.h" +#include "tensorflow/core/common_runtime/buf_rendezvous.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_session.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -40,13 +42,13 @@ GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config, RemoteMemoryManager* remote_memory_manager) : GrpcWorker(worker_env, config), remote_memory_manager_(remote_memory_manager), - recv_tensor_recent_request_ids_(100000) {} + recent_request_ids_(100000) {} void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, ::grpc::ByteBuffer* response, StatusCallback done) { - Status s = recv_tensor_recent_request_ids_.TrackUnique( + Status s = recent_request_ids_.TrackUnique( request->request_id(), "RecvTensor (GdrWorker)", *request); if (!s.ok()) { done(s); @@ -145,4 +147,41 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, }); } +void GdrWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) { + // This is an RDMA enabled implementation augmenting grpc. 
+ Status s = recent_request_ids_.TrackUnique(request->request_id(), + "RecvBuf (GdrWorker)", *request); + if (!s.ok()) { + done(s); + return; + } + CollectiveExecutor::Handle ce_handle( + env_->collective_executor_mgr->FindOrCreate(request->step_id()), true); + CollectiveRemoteAccess* rma = ce_handle.get()->remote_access(); + rma->buf_rendezvous()->ConsumeBuf( + request->buf_rendezvous_key(), + [this, request, response, done](const Status& status, + BufRendezvous::Hook* hook) { + Status s = status; + if (s.ok()) { + if (!DMAHelper::CanUseDMA(hook->prod_value)) { + s = errors::Internal("Tensor value for key ", + request->buf_rendezvous_key(), + " is not of a type supported by RecvBuf"); + } + } + if (s.ok()) { + remote_memory_manager_->TransportOptionsFromTensor( + response->mutable_transport_options(), *hook->prod_value, + hook->prod_dev, hook->prod_ctx, hook->prod_attr.on_host(), + [this, response, done, hook](const Status& s) { + response->set_send_start_micros(env_->env->NowMicros()); + done(s); + BufRendezvous::DoneWithHook(hook); + }); + } + }); +} + } // namespace tensorflow diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h index 39f11e6bde5a1ca7ae91ead02279d22d70af027b..9a85cfd4263ad86f6579eedce95969c2829ff62c 100644 --- a/tensorflow/contrib/gdr/gdr_worker.h +++ b/tensorflow/contrib/gdr/gdr_worker.h @@ -38,9 +38,13 @@ class GdrWorker : public GrpcWorker { ::grpc::ByteBuffer* response, StatusCallback done) override; + virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, + StatusCallback done) override; + private: RemoteMemoryManager* remote_memory_manager_; // Not owned - RecentRequestIds recv_tensor_recent_request_ids_; + RecentRequestIds recent_request_ids_; }; } // namespace tensorflow diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 
e79ccd8da1f8952758ae322d3a92dec34910a9db..5b37239665d46db38fc249e9004d2200abb3d610 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -22,7 +22,6 @@ from __future__ import print_function from copy import deepcopy from functools import partial from six import iteritems -from six import iterkeys from six import string_types from six import StringIO from tensorflow.contrib.graph_editor import reroute @@ -735,9 +734,8 @@ def graph_replace(target_ts, replacement_ts, dst_scope="", # control dependencies. graph = util.get_unique_graph(flatten_target_ts, check_types=(tf_ops.Tensor)) control_ios = util.ControlOutputs(graph) - ops = select.get_walks_intersection_ops(list(iterkeys(replacement_ts)), - flatten_target_ts, - control_ios=control_ios) + ops = select.get_walks_intersection_ops( + list(replacement_ts), flatten_target_ts, control_ios=control_ios) if not ops: raise ValueError("Targets and replacements are not connected!") diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 0081fb61770075a2c36e92f65e01126f657edeb4..92016e6a83975a9b15a39a15125e0eabc111912e 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -16,9 +16,31 @@ tf_cc_binary( srcs = ["hvx_ops_support_checker_main.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:candidate_sampling_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", + "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", + "//tensorflow/core:io_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:logging_ops_op_lib", + "//tensorflow/core:lookup_ops_op_lib", + "//tensorflow/core:manip_ops_op_lib", + "//tensorflow/core:math_ops_op_lib", + 
"//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", + "//tensorflow/core:parsing_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:random_ops_op_lib", + "//tensorflow/core:remote_fused_graph_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:sparse_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", + "//tensorflow/core:string_ops_op_lib", + "//tensorflow/core:training_ops_op_lib", + "//tensorflow/core:user_ops_op_lib", "//tensorflow/core/kernels:remote_fused_graph_execute_utils", "//tensorflow/core/kernels/hexagon:graph_transferer", "//tensorflow/tools/graph_transforms:file_utils", diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index 7b30e1776ffc4e8dc552bdadcb3d7017ff77bf57..c1f6cac4942436d32f9867d4b5557c6b9e376c69 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -98,6 +98,7 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset +>>> tf.enable_eager_execution() >>> >>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) >>> @@ -117,7 +118,15 @@ Using this ability we can calculate gradients on the nodes the data is stored on Apache Ignite uses horizontal partitioning to store data in distributed cluster. When we create Apache Ignite cache (or table in terms of SQL), we can specify the number of partitions the data will be partitioned on. For example, if an Apache Ignite cluster consists of 10 machines and we create cache with 10 partitions, then every machine will maintain approximately one data partition. -Ignite Dataset allows using these two aspects of distributed neural network training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that can be performed on a remote worker. 
The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting correstondent environment variables for worker process (such as `IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach, we can assign a specific partition to every worker so that one worker handles one partition and, at the same time, transparently work with single dataset. +Ignite Dataset allows using these two aspects of distributed neural network +training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a +computation graph operation that can be performed on a remote worker. The remote +worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) +by setting the corresponding environment variables for the worker process (such as +`IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using +this overriding approach, we can assign a specific partition to every worker so +that one worker handles one partition and, at the same time, transparently work +with a single dataset. ```python >>> import tensorflow as tf @@ -150,23 +159,31 @@ system called [IGFS](https://ignite.apache.org/features/igfs.html). IGFS delivers a similar functionality to Hadoop HDFS, but only in-memory. In fact, in addition to its own APIs, IGFS implements Hadoop FileSystem API and can be transparently plugged into Hadoop or Spark deployments. This contrib package -contains an integration between IGFS and TensorFlow. The integration is based -on [custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) -from TensorFlow side and +contains an integration between IGFS and TensorFlow. The integration is based on +[custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) from +TensorFlow side and [IGFS Native API](https://ignite.apache.org/features/igfs.html) from Apache -Ignite side.
It has numerous uses, for example: * Checkpoints of state can be -saved to IGFS for reliability and fault-tolerance. * Training processes -communicate with TensorBoard by writing event files to a directory, which -TensorBoard watches. IGFS allows this communication to work even when -TensorBoard runs in a different process or machine. +Ignite side. It has numerous uses, for example: + +* Checkpoints of state can be saved to IGFS for reliability and + fault-tolerance. +* Training processes communicate with TensorBoard by writing event files to a + directory, which TensorBoard watches. IGFS allows this communication to work + even when TensorBoard runs in a different process or machine. ### SSL Connection -Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation. +Apache Ignite allows protecting data transfer channels with +[SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and +authentication. Ignite Dataset supports SSL connections both with and without +authentication. For more information, please refer to the +[Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) +documentation. ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset +>>> tf.enable_eager_execution() >>> >>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", @@ -187,7 +204,7 @@ Following examples will help you to easily start working with this module.
The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded -[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interruct with +[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interact with it using Ignite Dataset. Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine: @@ -198,13 +215,13 @@ docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist After that you will be able to work with it following way: -![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist") +![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist-2.png "Ignite Dataset Mnist") ### IGFS The simplest way to try IGFS with TensorFlow is to run [Docker](https://www.docker.com/) container with Apache Ignite and enabled IGFS -and then interruct with it using TensorFlow +and then interact with it using TensorFlow [tf.gfile](https://www.tensorflow.org/api_docs/python/tf/gfile). Such container is available on Docker Hub: [dmitrievanthony/ignite-with-igfs](https://hub.docker.com/r/dmitrievanthony/ignite-with-igfs/). diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index 66e654ca636a5a051c6f9cd35bf9001dfbcbf7f4..3ffceef8070e0fc3b3cebae2522f89fe98ce4413 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -735,8 +735,6 @@ class IgniteDataset(dataset_ops.DatasetSource): cert_password: Password to be used if the private key is encrypted and a password is necessary.
""" - super(IgniteDataset, self).__init__() - with IgniteClient(host, port, username, password, certfile, keyfile, cert_password) as client: client.handshake() @@ -760,6 +758,8 @@ class IgniteDataset(dataset_ops.DatasetSource): self.cache_type.to_output_types(), self.cache_type.to_output_shapes(), self.cache_type.to_output_classes()) + super(IgniteDataset, self).__init__(self._as_variant_tensor()) + def _as_variant_tensor(self): return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port, self.local, self.part, self.page_size, diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py index ff5d4c458c859fd8e5e3ae65ee41a454d55d6538..89b74fbfdc38c9f42795d5c778889210baf6387f 100644 --- a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py +++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py @@ -19,9 +19,9 @@ from __future__ import print_function import os +from tensorflow import compat from tensorflow.contrib.ignite import IgniteDataset from tensorflow.python.client import session -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -66,7 +66,7 @@ class IgniteDatasetTest(test.TestCase): self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"]) self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"]) - it = dataset_ops.make_one_shot_iterator(dataset) + it = compat.v1.data.make_one_shot_iterator(dataset) ne = it.get_next() with session.Session() as sess: diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc index bbb3a3b18fd7bfdc68e8b8532568985245154794..f97e790b56c511ffb7859b4120b7a4220b75c506 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc +++ 
b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -55,9 +55,10 @@ void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, &tranformation_matrix)); // TODO(huangyp): It takes about 3.5 us to compute tranformation_matrix // with one thread. Improve its performance if necessary. - internal::compute_tranformation_matrix_cuda<<<1, 1, 0, cu_stream>>>( - delta_h, scale_s, scale_v, tranformation_matrix.flat().data(), - tranformation_matrix.flat().size()); + TF_CHECK_OK(CudaLaunchKernel(internal::compute_tranformation_matrix_cuda, 1, + 1, 0, cu_stream, delta_h, scale_s, scale_v, + tranformation_matrix.flat().data(), + tranformation_matrix.flat().size())); // Call cuBlas C = A * B directly. auto no_transpose = se::blas::Transpose::kNoTranspose; auto a_ptr = diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index b25a6f7b5742917a032946fe03a0dab20e7dc1ad..05ba9155c401b538a130958504ee919574480d75 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -518,7 +518,7 @@ def connected_components(images): def has_zero(): # Insert a zero in the consecutive ids where zero appears in unique_ids. # id_is_zero has length 1. - zero_id_ind = math_ops.to_int32(id_is_zero[0]) + zero_id_ind = math_ops.cast(id_is_zero[0], dtypes.int32) ids_before = nonzero_consecutive_ids[:zero_id_ind] ids_after = nonzero_consecutive_ids[zero_id_ind:] return array_ops.concat([ids_before, [0], ids_after], axis=0) diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/kafka_ops.cc deleted file mode 100644 index 8cdf16103bab2b22d51c144d21a589e1e39f2f0b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kafka/ops/kafka_ops.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -REGISTER_OP("KafkaDataset") - .Input("topics: string") - .Input("servers: string") - .Input("group: string") - .Input("eof: bool") - .Input("timeout: int64") - .Output("handle: variant") - .SetIsStateful() - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that emits the messages of one or more Kafka topics. - -topics: A `tf.string` tensor containing one or more subscriptions, - in the format of [topic:partition:offset:length], - by default length is -1 for unlimited. -servers: A list of bootstrap servers. -group: The consumer group id. -eof: If True, the kafka reader will stop on EOF. -timeout: The timeout value for the Kafka Consumer to wait - (in millisecond). 
-)doc"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py index 08ebcdb544645d3585a1af25c86c6182a1589dcb..3651275f935b50ac9d21bb831fd257eb22a6b793 100644 --- a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -49,7 +50,8 @@ class KafkaDatasetTest(test.TestCase): topics, group="test", eof=True).repeat(num_epochs) batch_dataset = repeat_dataset.batch(batch_size) - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + iterator = iterator_ops.Iterator.from_structure( + dataset_ops.get_legacy_output_types(batch_dataset)) init_op = iterator.make_initializer(repeat_dataset) init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh old mode 100644 new mode 100755 diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index b399e1b6c2ac47db205b5d8bbc81875ef5c08a31..5591c3b0cc8c8bf196bb4821c018cbf155cba4ce 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -52,7 +52,6 @@ class KafkaDataset(dataset_ops.DatasetSource): timeout: The timeout value for the Kafka Consumer to wait (in millisecond). 
""" - super(KafkaDataset, self).__init__() self._topics = ops.convert_to_tensor( topics, dtype=dtypes.string, name="topics") self._servers = ops.convert_to_tensor( @@ -63,6 +62,8 @@ class KafkaDataset(dataset_ops.DatasetSource): self._timeout = ops.convert_to_tensor( timeout, dtype=dtypes.int64, name="timeout") + super(KafkaDataset, self).__init__(self._as_variant_tensor()) + def _as_variant_tensor(self): return gen_dataset_ops.kafka_dataset(self._topics, self._servers, self._group, self._eof, self._timeout) diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index c4476a7bbd5056fa898468a46031bf3d8b1e44cf..b12832d2e2a3cccb4948d9e3bf3d226030121ac2 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -22,7 +22,7 @@ from __future__ import print_function from tensorflow.python.keras.losses import binary_crossentropy from tensorflow.python.keras.losses import categorical_crossentropy from tensorflow.python.keras.losses import categorical_hinge -from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import cosine_similarity from tensorflow.python.keras.losses import hinge from tensorflow.python.keras.losses import kullback_leibler_divergence from tensorflow.python.keras.losses import logcosh diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 7317fdb52c5b79e787a49d71be49f5261d6b1fff..095b5d798df9ac9038fa1088cdd402dff304e87e 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -23,7 +23,7 @@ from tensorflow.python.keras.metrics import binary_accuracy from tensorflow.python.keras.metrics import binary_crossentropy from tensorflow.python.keras.metrics import categorical_accuracy from tensorflow.python.keras.metrics import 
categorical_crossentropy -from tensorflow.python.keras.metrics import cosine_proximity +from tensorflow.python.keras.metrics import cosine_similarity from tensorflow.python.keras.metrics import hinge from tensorflow.python.keras.metrics import kullback_leibler_divergence from tensorflow.python.keras.metrics import mean_absolute_error diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index 294a7d69a704b3c06ab9e30489af116929ab6c2a..0d43bc2101bc3d189bb48f35b0a801f4e77030bb 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -80,7 +80,7 @@ def sparse_multiclass_hinge_loss( ' {}'.format(logits_rank)) logits_shape = array_ops.shape(logits) batch_size, num_classes = logits_shape[0], logits_shape[1] - logits = math_ops.to_float(logits) + logits = math_ops.cast(logits, dtypes.float32) # Check labels have valid type. if labels.dtype != dtypes.int32 and labels.dtype != dtypes.int64: diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py index bf89922318b9b9a569e4bd1d71fe6283810cadda..af7018f8368116172511b3f78c42caf3fc215632 100644 --- a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py +++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py @@ -29,6 +29,7 @@ from __future__ import print_function import boto3 from tensorflow.contrib.kinesis.python.ops import kinesis_dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -59,7 +60,8 @@ class KinesisDatasetTest(test.TestCase): stream, read_indefinitely=False).repeat(num_epochs) batch_dataset = repeat_dataset.batch(batch_size) - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + iterator = 
iterator_ops.Iterator.from_structure( + dataset_ops.get_legacy_output_types(batch_dataset)) init_op = iterator.make_initializer(repeat_dataset) init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() @@ -102,7 +104,8 @@ class KinesisDatasetTest(test.TestCase): stream, shard, read_indefinitely=False).repeat(num_epochs) batch_dataset = repeat_dataset.batch(batch_size) - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + iterator = iterator_ops.Iterator.from_structure( + dataset_ops.get_legacy_output_types(batch_dataset)) init_op = iterator.make_initializer(repeat_dataset) init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 588f15b867c1fedbadd5a5d945d870a356549468..7e19ae7c13df421ec5bb9cb0e07dff0d00fb9548 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -155,7 +155,7 @@ py_library( ":core", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:functional_ops", + "//tensorflow/python:map_fn", "//tensorflow/python:math_ops", "//tensorflow/python:numerics", "//tensorflow/python:random_ops", diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py index 0c6bba758b429a8c4112bc6abb2fae542b5dfc14..8ee554ffa7ab6bbcc2d36c525ad68e03bacb594b 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py @@ -321,8 +321,8 @@ class LabeledTensor(object): for (d, axis) in zip(shape, unvalidated_axes.values()): if d != axis.size: raise ValueError( - 'Provided axis size %d does not match tensor dimension size %d' % - (axis.size, d)) + 'Provided axis size %d does not match tensor dimension size %d' + 'in tensor %r' % (axis.size, d, tensor)) self._axes = 
unvalidated_axes diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index 2ede5daee74223e812cc29e9708b1989b698fb4e..a65f045cc886f4d4f351423858d92412baa3a622 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import map_fn as map_fn_lib from tensorflow.python.ops import math_ops from tensorflow.python.ops import numerics from tensorflow.python.ops import random_ops @@ -629,7 +630,7 @@ def map_fn(fn, labeled_tensor, name=None): # TODO(ericmc): Fix this upstream. if labeled_tensor.dtype == dtypes.string: - # We must construct the full graph here, because functional_ops.map_fn + # We must construct the full graph here, because map_fn_lib.map_fn # doesn't work for string-valued tensors. # Constructing the full graph may be slow. 
map_lts = [fn(t) for t in unpack_lts] @@ -652,7 +653,7 @@ def map_fn(fn, labeled_tensor, name=None): tensor_lt = core.LabeledTensor(tensor, original_axes) return fn(tensor_lt).tensor - map_op = functional_ops.map_fn( + map_op = map_fn_lib.map_fn( tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype) map_lt = core.LabeledTensor(map_op, final_axes) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 9ca6f8df5dbe3c236c4cd85095176ce69ad9deaa..69d5496f8aebb9b89c5d79f80a1a439f556093d7 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -81,6 +81,7 @@ tf_custom_op_py_library( visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", "//video/youtube/personalization:__subpackages__", ], deps = [ diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 00d819ed0e9fe3a5644105a571beda100204631e..f52aaaf7b7983b4ee0f779a8ed8e163781643222 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -840,7 +840,7 @@ class _WeightedSparseColumn( # The weight tensor can be a regular Tensor. In such case, sparsify it. weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor) if not self.dtype.is_floating: - weight_tensor = math_ops.to_float(weight_tensor) + weight_tensor = math_ops.cast(weight_tensor, dtypes.float32) return tuple([id_tensor, weight_tensor]) def insert_transformed_feature(self, columns_to_tensors): @@ -1731,7 +1731,7 @@ class _RealValuedVarLenColumn(_FeatureColumn, collections.namedtuple( """ # Transform the input tensor according to the normalizer function. 
input_tensor = self._normalized_input_tensor(columns_to_tensors[self.name]) - columns_to_tensors[self] = math_ops.to_float(input_tensor) + columns_to_tensors[self] = math_ops.cast(input_tensor, dtypes.float32) # pylint: disable=unused-argument def _to_dnn_input_layer(self, @@ -1871,7 +1871,7 @@ class _RealValuedColumn( """ # Transform the input tensor according to the normalizer function. input_tensor = self._normalized_input_tensor(columns_to_tensors[self.name]) - columns_to_tensors[self] = math_ops.to_float(input_tensor) + columns_to_tensors[self] = math_ops.cast(input_tensor, dtypes.float32) # pylint: disable=unused-argument def _to_dnn_input_layer(self, @@ -1881,7 +1881,7 @@ class _RealValuedColumn( output_rank=2): input_tensor = self._to_dense_tensor(input_tensor) if input_tensor.dtype != dtypes.float32: - input_tensor = math_ops.to_float(input_tensor) + input_tensor = math_ops.cast(input_tensor, dtypes.float32) return _reshape_real_valued_tensor(input_tensor, output_rank, self.name) def _to_dense_tensor(self, input_tensor): @@ -1897,8 +1897,8 @@ class _RealValuedColumn( return inputs.get(self) def _transform_feature(self, inputs): - return math_ops.to_float( - self._normalized_input_tensor(inputs.get(self.name))) + return math_ops.cast( + self._normalized_input_tensor(inputs.get(self.name)), dtypes.float32) @property def _parse_example_spec(self): @@ -2104,7 +2104,7 @@ class _BucketizedColumn( raise ValueError("BucketizedColumn currently only supports output_rank=2") return array_ops.reshape( array_ops.one_hot( - math_ops.to_int64(input_tensor), + math_ops.cast(input_tensor, dtypes.int64), self.length, 1., 0., @@ -2136,8 +2136,10 @@ class _BucketizedColumn( i2 = array_ops.zeros([batch_size], dtype=dtypes.int32, name="zeros") bucket_indices = array_ops.reshape(input_tensor, [-1], name="reshape") - indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2)))) - shape = math_ops.to_int64(array_ops.stack([batch_size, dimension])) + indices = 
math_ops.cast(array_ops.transpose(array_ops.stack((i1, i2))), + dtypes.int64) + shape = math_ops.cast(array_ops.stack([batch_size, dimension]), + dtypes.int64) sparse_id_values = sparse_tensor_py.SparseTensor( indices, bucket_indices, shape) @@ -2527,7 +2529,7 @@ class DataFrameColumn(_FeatureColumn, trainable=True, output_rank=2): if input_tensor.dtype != dtypes.float32: - input_tensor = math_ops.to_float(input_tensor) + input_tensor = math_ops.cast(input_tensor, dtypes.float32) return _reshape_real_valued_tensor(input_tensor, output_rank, self.name) def _to_dense_tensor(self, input_tensor): diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 7e6eafaa0d6f60cfc28a4c422abac0b6d5a991fb..00e41026d0038409ace178e6affd2c1cdc812122 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -1757,7 +1757,7 @@ class WeightedSumTest(test.TestCase): logits_core = fc_core.linear_model(features, [movies]) with self.cached_session() as sess: - variables_lib.initialize_all_variables().run() + variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 403b522ce45ac6ad98a321378626b87aaa7738aa..1d959b3c78445977b4fe74ee6c20c86aaf7f86da 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2308,11 +2308,13 @@ def layer_norm(inputs, initializer=init_ops.ones_initializer(), collections=gamma_collections, trainable=trainable) - # Calculate the moments on the last axis (layer activations). + # By default, compute the moments across all the dimensions except the one with index 0. 
norm_axes = list(range(begin_norm_axis, inputs_rank)) mean, variance = nn.moments(inputs, norm_axes, keep_dims=True) # Compute layer normalization using the batch_normalization function. - variance_epsilon = 1e-12 + # Note that epsilon must be increased for float16 due to the limited + # representable range. + variance_epsilon = 1e-12 if dtype != dtypes.float16 else 1e-3 outputs = nn.batch_normalization( inputs, mean, diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 1c0088186c030437454c0f764decab9e5a276adc..90fd55cf3898586dec3313c238df8b3952b8b349 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1399,9 +1399,10 @@ class DropoutTest(test.TestCase): with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') - num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) + num_elem_initial = math_ops.reduce_mean( + math_ops.cast(images > 0, dtypes.float32)) output = _layers.dropout(images) - num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem = math_ops.reduce_mean(math_ops.cast(output > 0, dtypes.float32)) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertLess(num_elem, num_elem_initial / 2 + 0.1) self.assertGreater(num_elem, num_elem_initial / 2 - 0.1) @@ -1421,9 +1422,10 @@ class DropoutTest(test.TestCase): with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') - num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) + num_elem_initial = math_ops.reduce_mean( + math_ops.cast(images > 0, dtypes.float32)) output = _layers.dropout(images, is_training=False) - num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem = math_ops.reduce_mean(math_ops.cast(output > 0, dtypes.float32)) num_elem, 
num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertEqual(num_elem, num_elem_initial) outputs, inputs = sess.run([output, images]) @@ -1435,9 +1437,10 @@ class DropoutTest(test.TestCase): images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') output = _layers.fully_connected(images, 50) - num_elem_initial = math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem_initial = math_ops.reduce_mean( + math_ops.cast(output > 0, dtypes.float32)) output = _layers.dropout(output) - num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem = math_ops.reduce_mean(math_ops.cast(output > 0, dtypes.float32)) sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertLess(num_elem, num_elem_initial / 2 + 0.1) @@ -1450,7 +1453,7 @@ class DropoutTest(test.TestCase): (5, height, width, 3), seed=1, name='images') output = _layers.fully_connected( images, 50, normalizer_fn=_layers.dropout) - num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem = math_ops.reduce_mean(math_ops.cast(output > 0, dtypes.float32)) sess.run(variables_lib.global_variables_initializer()) num_elem = sess.run(num_elem) self.assertLess(num_elem, 0.5) @@ -2869,10 +2872,19 @@ class LayerNormTest(test.TestCase): tol=1e-5, begin_norm_axis=1, dtype=dtypes.float64): + eps = 1e-12 if dtype != dtypes.float16 else 1e-3 expected_mean = np.zeros(input_shape[:begin_norm_axis]) - expected_var = np.ones(input_shape[:begin_norm_axis]) - for mu in [0.0, 1e2]: - for sigma in [1.0, 0.1]: + expected_var_uncorrected = np.ones(input_shape[:begin_norm_axis]) + sigma_list = [1.0, 0.1] + if dtype == dtypes.float16: + # This causes the variance to underflow in float16, and requires that + # variance_epsilon be set appropriately to avoid NaNs in the output. + sigma_list.append(1e-4) + # Note that the mean:variance ratio must be limited to the representable + # range for float16. 
+ for mu in [0.0, 1e2 if dtype != dtypes.float16 else 1e1]: + for sigma in sigma_list: + expected_var = expected_var_uncorrected / (1.0 + eps / sigma**2) input_values = np.random.randn(*input_shape) * sigma + mu with ops.Graph().as_default() as g: with self.session(graph=g) as sess: @@ -2893,10 +2905,13 @@ class LayerNormTest(test.TestCase): outputs, beta, gamma = sess.run((output_t, beta_var, gamma_var)) # Make sure that there are no NaNs self.assertFalse(np.isnan(outputs).any()) + if outputs.dtype != np.float64: + # Cast to float64 before computing mean/variance to avoid + # overflow and precision issues. + outputs = outputs.astype(np.float64) mean = np.mean(outputs, axis=moments_axis) var = np.var(outputs, axis=moments_axis) # Layer-norm implemented in numpy - eps = 1e-12 expected_out = ( (gamma * (input_values - np.mean( input_values, axis=moments_axis, keepdims=True)) / @@ -2933,6 +2948,12 @@ class LayerNormTest(test.TestCase): def testOutputBigInput(self): self.doOutputTest((1, 100, 100, 1)) + def testOutputBigInputFloat32(self): + self.doOutputTest((1, 100, 1000, 1), tol=1e-4, dtype=dtypes.float32) + + def testOutputBigInputFloat16(self): + self.doOutputTest((1, 100, 1000, 1), tol=5e-2, dtype=dtypes.float16) + class GDNTest(test.TestCase): diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index 2fdcd849b026d52ed4aff724838f6c71e3a315d0..1ccc8f012f90e3240f1156ff1970321c7c4510f0 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -21,6 +21,7 @@ from __future__ import print_function import six from tensorflow.contrib import framework as contrib_framework +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -109,11 +110,12 @@ def optimize_loss(loss, gradient_multipliers: dict of variables 
or variable names to floats. If present, gradients for specified variables will be multiplied by given constant. - clip_gradients: float, callable or `None`. If float, is provided, a global - clipping is applied to prevent the norm of the gradient to exceed this - value. Alternatively, a callable can be provided e.g.: adaptive_clipping. - This callable takes a `list` of `(gradients, variables)` `tuple`s and - returns the same thing with the gradients modified. + clip_gradients: float, callable or `None`. If a float is provided, a global + clipping is applied to prevent the norm of the gradient from exceeding + this value. Alternatively, a callable can be provided, e.g., + `adaptive_clipping_fn()`. This callable takes a list of + `(gradients, variables)` tuples and returns the same thing with the + gradients modified. learning_rate_decay_fn: function, takes `learning_rate` and `global_step` `Tensor`s, returns `Tensor`. Can be used to implement any learning rate decay @@ -324,7 +326,7 @@ def _adaptive_max_norm(norm, std_factor, decay, global_step, epsilon, name): # quicker adaptation at the beginning if global_step is not None: - n = math_ops.to_float(global_step) + n = math_ops.cast(global_step, dtypes.float32) decay = math_ops.minimum(decay, n / (n + 1.)) # update averages diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 8a6b4f68a8b33d497ddb16614a7e3cdf32f2c422..131b1e0dba28f2498cd11254dad1d5790f5b7c04 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -23,6 +23,7 @@ import six from tensorflow.contrib.framework import deprecated from tensorflow.contrib.losses.python.losses import loss_ops from tensorflow.contrib.metrics.python.ops import metric_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from 
tensorflow.python.ops import control_flow_ops @@ -185,7 +186,8 @@ class _TargetColumn(object): return None else: return array_ops.reshape( - math_ops.to_float(features[self._weight_column_name]), shape=(-1,)) + math_ops.cast(features[self._weight_column_name], dtypes.float32), + shape=(-1,)) @property def problem_type(self): @@ -252,9 +254,10 @@ class _TargetColumn(object): if weight_tensor is None: return math_ops.reduce_mean(loss_unweighted, name="loss") loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor) - return math_ops.div(math_ops.reduce_sum(loss_weighted), - math_ops.to_float(math_ops.reduce_sum(weight_tensor)), - name="loss") + return math_ops.div( + math_ops.reduce_sum(loss_weighted), + math_ops.cast(math_ops.reduce_sum(weight_tensor), dtypes.float32), + name="loss") class _RegressionTargetColumn(_TargetColumn): @@ -323,7 +326,7 @@ class _MultiClassTargetColumn(_TargetColumn): metrics = {("accuracy", "classes"): metric_ops.streaming_accuracy} predictions = math_ops.sigmoid(logits) - labels_float = math_ops.to_float(labels) + labels_float = math_ops.cast(labels, dtypes.float32) default_metrics = self._default_eval_metrics() for metric_name, metric_op in default_metrics.items(): @@ -399,7 +402,8 @@ def _mean_squared_loss(logits, target): target = array_ops.expand_dims(target, axis=1) logits.get_shape().assert_is_compatible_with(target.get_shape()) - return math_ops.square(logits - math_ops.to_float(target)) + return math_ops.squared_difference(logits, + math_ops.cast(target, dtypes.float32)) def _log_loss_with_two_classes(logits, target): @@ -407,7 +411,7 @@ def _log_loss_with_two_classes(logits, target): if len(target.get_shape()) == 1: target = array_ops.expand_dims(target, axis=1) loss_vec = nn.sigmoid_cross_entropy_with_logits( - labels=math_ops.to_float(target), logits=logits) + labels=math_ops.cast(target, dtypes.float32), logits=logits) return loss_vec @@ -475,7 +479,7 @@ def get_default_binary_metrics_for_eval(thresholds): def 
_float_weights_or_none(weights): if weights is None: return None - return math_ops.to_float(weights) + return math_ops.cast(weights, dtypes.float32) def _labels_streaming_mean(unused_predictions, labels, weights=None): @@ -494,8 +498,8 @@ def _streaming_auc(predictions, labels, weights=None): def _accuracy_at_threshold(threshold): def _accuracy_metric(predictions, labels, weights=None): - threshold_predictions = math_ops.to_float( - math_ops.greater_equal(predictions, threshold)) + threshold_predictions = math_ops.cast( + math_ops.greater_equal(predictions, threshold), dtypes.float32) return metric_ops.streaming_accuracy( predictions=threshold_predictions, labels=labels, weights=weights) diff --git a/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py b/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py index 91684dc61e40efe3f2408c3d5f1f4eb2764bc558..934a7f06069cb85c74ded8237bcfefce0771aa02 100644 --- a/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py +++ b/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py @@ -86,11 +86,11 @@ def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0, internal_type = dtypes.string for i in range(len(values)): if values[i].dtype != dtypes.string: - values[i] = math_ops.to_int64(values[i]) + values[i] = math_ops.cast(values[i], dtypes.int64) internal_type = dtypes.int64 for i in range(len(dense_inputs)): if dense_inputs[i].dtype != dtypes.string: - dense_inputs[i] = math_ops.to_int64(dense_inputs[i]) + dense_inputs[i] = math_ops.cast(dense_inputs[i], dtypes.int64) internal_type = dtypes.int64 if hash_key: diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 14065fcee51c014a1af227504eaaca1fa39941e1..3f0a91ccdc91ed0d8b3e383cf167180fbeb5c8bf 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -357,9 +357,9 @@ py_test( py_test( name = "dnn_linear_combined_test", - size = "large", + size = "medium", srcs = 
["python/learn/estimators/dnn_linear_combined_test.py"], - shard_count = 4, + shard_count = 8, srcs_version = "PY2AND3", tags = ["no_oss"], # flaky b/70524820 deps = [ @@ -387,6 +387,13 @@ py_test( shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 + deps = [":head_test_lib"], +) + +py_library( + name = "head_test_lib", + srcs = ["python/learn/estimators/head_test.py"], + srcs_version = "PY2AND3", deps = [ ":learn", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/learn/README.md b/tensorflow/contrib/learn/README.md index b0bff915a993c9a01e2e6d9ef9f71c14d2f29a73..b2d3a6273abba7e3a893f30bbdd4f8b2662bd54a 100644 --- a/tensorflow/contrib/learn/README.md +++ b/tensorflow/contrib/learn/README.md @@ -111,18 +111,17 @@ Some arguments are renamed, please refer to documentation. In addition: Switch to `tf.estimator.train_and_evaluate`. Some differences: -* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, - should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. -* Remove the `experiment_fn`. Instead, create the `Estimator`, - `train_spec` and `eval_spec`, then call `tf.estimator.train_and_evaluate` - directly. -* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement - for `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the - replacement for `tf.contrib.learn.make_export_strategy`. If you want to export - only at the end of training use `tf.estimator.FinalExporter`. -* If the `TF_CONFIG` environment variable is constructed manually, please read - the `train_and_evaluate` documentation for the new requirementds (in - particular, the chief node and evaluator node). +* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, + should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. +* Remove the `experiment_fn`. 
Instead, create the `Estimator`, `train_spec` + and `eval_spec`, then call `tf.estimator.train_and_evaluate` directly. +* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement for + `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the + replacement for `tf.contrib.learn.make_export_strategy`. If you want to + export only at the end of training use `tf.estimator.FinalExporter`. +* If the `TF_CONFIG` environment variable is constructed manually, please read + the `train_and_evaluate` documentation for the new requirements (in + particular, the chief node and evaluator node). ## Others Classes and Functions diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug_test.py b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py index b968aeed1b7a11d522b531783f04f0104b37904f..ab0ce6d581a9d65f91ace0f8453911e2ddf3e8b8 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/debug_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py @@ -474,7 +474,7 @@ class DebugClassifierTest(test.TestCase): def _my_metric_op(predictions, labels): # For the case of binary classification, the 2nd column of "predictions" # denotes the model predictions. 
- labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) predictions = array_ops.strided_slice( predictions, [0, 1], [-1, 2], end_mask=1) labels = math_ops.cast(labels, predictions.dtype) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index d46a873bfaa297e7f6242aa56e9d0bf0eb551867..4f636ce69ddc7dbd3bb083e5eaead384b5875e8b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -807,7 +807,7 @@ class DNNLinearCombinedClassifierTest(test.TestCase): def _my_metric_op(predictions, labels): # For the case of binary classification, the 2nd column of "predictions" # denotes the model predictions. - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) predictions = array_ops.strided_slice( predictions, [0, 1], [-1, 2], end_mask=1) return math_ops.reduce_sum(math_ops.multiply(predictions, labels)) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index ee25cebd484f1e831fe8b6d3aa7290da7558adee..d779495720b1c8fd4c3c7f4b631dab8b49e4f3bd 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -815,7 +815,7 @@ class DNNClassifierTest(test.TestCase): def _my_metric_op(predictions, labels): # For the case of binary classification, the 2nd column of "predictions" # denotes the model predictions. 
- labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) predictions = array_ops.strided_slice( predictions, [0, 1], [-1, 2], end_mask=1) labels = math_ops.cast(labels, predictions.dtype) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 28c4964527bb034c8c6b1642366c6c82c1a72201..7a96f6d3ea41ce7a672c3ba8b8a818500012945e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -37,8 +37,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell @@ -372,9 +372,10 @@ class DynamicRnnEstimatorTest(test.TestCase): labels = array_ops.slice(random_sequence, [0, 0], [batch_size, sequence_length]) inputs = array_ops.expand_dims( - math_ops.to_float( + math_ops.cast( array_ops.slice(random_sequence, [0, 1], - [batch_size, sequence_length])), 2) + [batch_size, sequence_length]), + dtypes.float32), 2) input_dict = { dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform( [batch_size, cell_size], seed=((i + 1) * seed)) @@ -430,9 +431,10 @@ class DynamicRnnEstimatorTest(test.TestCase): labels = array_ops.slice(sequence, [0, 0], [batch_size, sequence_length]) inputs = array_ops.expand_dims( - math_ops.to_float( + math_ops.cast( array_ops.slice(sequence, [0, 1], [batch_size, sequence_length - ])), 2) + ]), + dtypes.float32), 2) input_dict = state_dict input_dict['inputs'] = inputs return 
input_dict, labels @@ -524,7 +526,7 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): def input_fn(): starts = random_ops.random_uniform( [batch_size], maxval=(2 * np.pi), seed=seed) - sin_curves = functional_ops.map_fn( + sin_curves = map_fn.map_fn( _sin_fn, (starts,), dtype=dtypes.float32) inputs = array_ops.expand_dims( array_ops.slice(sin_curves, [0, 0], [batch_size, sequence_length]), @@ -587,9 +589,11 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): labels = array_ops.slice(random_sequence, [0, 0], [batch_size, sequence_length]) inputs = array_ops.expand_dims( - math_ops.to_float( + math_ops.cast( array_ops.slice(random_sequence, [0, 1], - [batch_size, sequence_length])), 2) + [batch_size, sequence_length]), + dtypes.float32), + 2) return {'inputs': inputs}, labels return input_fn @@ -719,11 +723,13 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): def input_fn(): random_sequence = random_ops.random_uniform( [batch_size, sequence_length], 0, 2, dtype=dtypes.int32, seed=seed) - inputs = array_ops.expand_dims(math_ops.to_float(random_sequence), 2) - labels = math_ops.to_int32( + inputs = array_ops.expand_dims( + math_ops.cast(random_sequence, dtypes.float32), 2) + labels = math_ops.cast( array_ops.squeeze( math_ops.reduce_sum(inputs, axis=[1]) > ( - sequence_length / 2.0))) + sequence_length / 2.0)), + dtypes.int32) return {'inputs': inputs}, labels return input_fn diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 8a461a0bd7ba457fcf830769f23c6ca2860a2732..153d4867961ae5115157bb5a246b5819387d91cf 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -220,7 +220,7 @@ def _build_estimator_for_export_tests(tmpdir): hashtable = lookup.HashTable( lookup.TextFileStringTableInitializer(vocab_file_name), 'x') features['bogus_lookup'] = 
hashtable.lookup( - math_ops.to_int64(features['feature'])) + math_ops.cast(features['feature'], dtypes.int64)) return input_fn_utils.InputFnOps(features, labels, inputs) @@ -1181,14 +1181,14 @@ class EstimatorTest(test.TestCase): ] self.assertItemsEqual([expected_vocab_file], assets) graph_ops = [x.name for x in graph.get_operations()] - self.assertTrue('input_example_tensor' in graph_ops) - self.assertTrue('ParseExample/ParseExample' in graph_ops) - self.assertTrue('linear/linear/feature/matmul' in graph_ops) + self.assertIn('input_example_tensor', graph_ops) + self.assertIn('ParseExample/ParseExample', graph_ops) + self.assertIn('linear/linear/feature/matmul', graph_ops) # Since there were no transforms, both save ops are still present. - self.assertTrue('save/SaveV2/tensor_names' in graph_ops) - self.assertTrue('save_1/SaveV2/tensor_names' in graph_ops) + self.assertIn('save/SaveV2/tensor_names', graph_ops) + self.assertIn('save_1/SaveV2/tensor_names', graph_ops) # Since there were no transforms, the hash table lookup is still there. - self.assertTrue('hash_table_Lookup' in graph_ops) + self.assertIn('hash_table_Lookup/LookupTableFindV2', graph_ops) # Restore, to validate that the export was well-formed. # tag_2, tag_3 was subjected to strip_unused_nodes. 
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index c1b97d8b49613ea49d9813954da3b7a63d3ba04c..2458652f8eb966d1ff5578ac5231249c572579aa 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -567,7 +567,8 @@ def _mean_squared_loss(labels, logits, weights=None): if len(logits.get_shape()) == 1: logits = array_ops.expand_dims(logits, axis=1) logits.get_shape().assert_is_compatible_with(labels.get_shape()) - loss = math_ops.square(logits - math_ops.to_float(labels), name=name) + loss = math_ops.squared_difference( + logits, math_ops.cast(labels, dtypes.float32), name=name) return _compute_weighted_loss(loss, weights) @@ -792,7 +793,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None): with ops.name_scope(None, "log_loss_with_two_classes", (logits, labels)) as name: logits = ops.convert_to_tensor(logits) - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) # TODO(ptucker): This will break for dynamic shapes. # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. 
if len(labels.get_shape()) == 1: @@ -1213,8 +1214,8 @@ def _sparse_labels_to_indicator(labels, num_classes): if num_classes < 2: raise ValueError("Must set num_classes >= 2 when passing labels as a " "SparseTensor.") - return math_ops.to_int64( - sparse_ops.sparse_to_indicator(labels, num_classes)) + return math_ops.cast( + sparse_ops.sparse_to_indicator(labels, num_classes), dtypes.int64) return labels @@ -1399,8 +1400,9 @@ class _MultiLabelHead(_SingleHead): math_ops.sigmoid( logits, name=prediction_key.PredictionKey.PROBABILITIES), prediction_key.PredictionKey.CLASSES: - math_ops.to_int64( + math_ops.cast( math_ops.greater(logits, 0), + dtypes.int64, name=prediction_key.PredictionKey.CLASSES) } @@ -1782,7 +1784,7 @@ def _weight_tensor(features, weight_column_name): raise ValueError("Weights {} missing from features.".format( weight_column_name)) with ops.name_scope(None, "weight_tensor", tuple(six.itervalues(features))): - weight_tensor = math_ops.to_float(features[weight_column_name]) + weight_tensor = math_ops.cast(features[weight_column_name], dtypes.float32) shape = weight_tensor.get_shape() rank = shape.ndims # We don't bother with expanding dims of non-staticly shaped tensors or @@ -1832,7 +1834,7 @@ def _compute_weighted_loss(loss_unweighted, weight, name="loss"): weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope) weighted_loss_normalized = math_ops.div( math_ops.reduce_sum(weighted_loss), - math_ops.to_float(math_ops.reduce_sum(weight)), + math_ops.cast(math_ops.reduce_sum(weight), dtypes.float32), name="weighted_average_loss") return weighted_loss_mean, weighted_loss_normalized @@ -1951,7 +1953,7 @@ def _sigmoid_cross_entropy_loss(labels, logits, weights=None): (logits, labels)) as name: # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels. 
loss = nn.sigmoid_cross_entropy_with_logits( - labels=math_ops.to_float(labels), logits=logits, name=name) + labels=math_ops.cast(labels, dtypes.float32), logits=logits, name=name) return _compute_weighted_loss(loss, weights) @@ -1959,11 +1961,11 @@ def _float_weights_or_none(weights): if weights is None: return None with ops.name_scope(None, "float_weights", (weights,)) as name: - return math_ops.to_float(weights, name=name) + return math_ops.cast(weights, dtypes.float32, name=name) def _indicator_labels_streaming_mean(labels, weights=None, class_id=None): - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) weights = _float_weights_or_none(weights) if weights is not None: weights = weights_broadcast_ops.broadcast_weights(weights, labels) @@ -1977,7 +1979,7 @@ def _indicator_labels_streaming_mean(labels, weights=None, class_id=None): def _predictions_streaming_mean(predictions, weights=None, class_id=None): - predictions = math_ops.to_float(predictions) + predictions = math_ops.cast(predictions, dtypes.float32) weights = _float_weights_or_none(weights) if weights is not None: weights = weights_broadcast_ops.broadcast_weights(weights, predictions) @@ -2001,9 +2003,9 @@ def _class_predictions_streaming_mean(predictions, weights, class_id): return metrics_lib.mean( array_ops.where( math_ops.equal( - math_ops.to_int32(class_id), math_ops.to_int32(predictions)), - array_ops.ones_like(predictions), - array_ops.zeros_like(predictions)), + math_ops.cast(class_id, dtypes.int32), + math_ops.cast(predictions, dtypes.int32)), + array_ops.ones_like(predictions), array_ops.zeros_like(predictions)), weights=weights) @@ -2011,15 +2013,16 @@ def _class_labels_streaming_mean(labels, weights, class_id): return metrics_lib.mean( array_ops.where( math_ops.equal( - math_ops.to_int32(class_id), math_ops.to_int32(labels)), - array_ops.ones_like(labels), array_ops.zeros_like(labels)), + math_ops.cast(class_id, dtypes.int32), + math_ops.cast(labels, 
dtypes.int32)), array_ops.ones_like(labels), + array_ops.zeros_like(labels)), weights=weights) def _streaming_auc(predictions, labels, weights=None, class_id=None, curve="ROC"): # pylint: disable=missing-docstring - predictions = math_ops.to_float(predictions) + predictions = math_ops.cast(predictions, dtypes.float32) if labels.dtype.base_dtype != dtypes.bool: logging.warning("Casting %s labels to bool.", labels.dtype) labels = math_ops.cast(labels, dtypes.bool) @@ -2046,8 +2049,8 @@ def _assert_class_id(class_id, num_classes=None): def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold): - threshold_predictions = math_ops.to_float( - math_ops.greater_equal(predictions, threshold)) + threshold_predictions = math_ops.cast( + math_ops.greater_equal(predictions, threshold), dtypes.float32) return metrics_lib.accuracy(labels, threshold_predictions, weights) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 7c2d9bb0767cb979dae9c84b5342d129225677ed..a52d25acf402bdda46771e9146a40cfb71e99d53 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -62,8 +62,8 @@ def _assert_no_variables(test_case): def _assert_metrics(test_case, expected_loss, expected_eval_metrics, model_fn_ops): test_case.assertAlmostEqual(expected_loss, model_fn_ops.loss.eval(), places=4) - for k in six.iterkeys(expected_eval_metrics): - test_case.assertIn(k, six.iterkeys(model_fn_ops.eval_metric_ops)) + for k in expected_eval_metrics: + test_case.assertIn(k, model_fn_ops.eval_metric_ops) variables.initialize_local_variables().run() for key, expected_value in six.iteritems(expected_eval_metrics): value_tensor, update_tensor = model_fn_ops.eval_metric_ops[key] @@ -545,19 +545,19 @@ class MultiLabelHeadTest(test.TestCase): with session.Session(): self.assertListEqual( [1, 0, 0], 
model_fn_ops.predictions["classes"].eval().tolist()[0]) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.CLASSIFICATION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) self.assertAllEqual( [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], predictions_for_serving["classes"].eval()) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) self.assertAllClose( [[0.731059, 0.5, 0.5], [0.5, 0.5, 0.731059,]], @@ -850,18 +850,18 @@ class BinaryClassificationHeadTest(test.TestCase): with session.Session(): self.assertListEqual( [1, 1], list(model_fn_ops.predictions["classes"].eval())) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.LOGISTIC_REGRESSION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) predicted_classes = predictions_for_serving["classes"].eval().tolist() self.assertListEqual( [b"0", b"1"], predicted_classes[0]) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) def testBinaryClassificationInferMode_withWeightColumn(self): n_classes = 2 @@ -1349,18 +1349,18 @@ class MultiClassHeadTest(test.TestCase): self.assertAllEqual( [0, 2], model_fn_ops.predictions["classes"].eval()) - self.assertItemsEqual( - 
["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.CLASSIFICATION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) self.assertAllEqual( [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], predictions_for_serving["classes"].eval()) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) self.assertAllClose( [[0.576117, 0.2119416, 0.2119416], [0.2119416, 0.2119416, 0.576117]], @@ -1401,18 +1401,18 @@ class MultiClassHeadTest(test.TestCase): self.assertAllEqual( [b"key0", b"key2"], model_fn_ops.predictions["classes"].eval()) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.CLASSIFICATION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) self.assertAllEqual( [[b"key0", b"key1", b"key2"], [b"key0", b"key1", b"key2"]], predictions_for_serving["classes"].eval()) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) self.assertAllClose( [[0.576117, 0.2119416, 0.2119416], [0.2119416, 0.2119416, 0.576117]], diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py index 3cbcc6e98de1c915c302617e4591c9baa33adeaf..8981432f7f27f4c773824efed48d2b3320c1b340 100644 --- 
a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py @@ -31,6 +31,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import metric_key from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib +from tensorflow.python.framework import dtypes from tensorflow.python.ops import math_ops @@ -160,8 +161,9 @@ def _make_logistic_eval_metric_ops(labels, predictions, thresholds): labels=labels_tensor, predictions=predictions) for threshold in thresholds: - predictions_at_threshold = math_ops.to_float( + predictions_at_threshold = math_ops.cast( math_ops.greater_equal(predictions, threshold), + dtypes.float32, name='predictions_at_threshold_%f' % threshold) metrics[metric_key.MetricKey.ACCURACY_MEAN % threshold] = ( metrics_lib.streaming_accuracy(labels=labels_tensor, diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index dcb161180c99ce71195c820217e8bdaf79d70901..96adc8b83b5bec912460dbb54899ce5f168b8f25 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -219,7 +219,7 @@ class ModelFnOps( used if a Servo request does not explicitly mention which head to infer on. Pass the key of the output alternative here that you want to designate as default. A separate ExportOutpout for this default head - wil be added to the export_outputs dict with the special key + will be added to the export_outputs dict with the special key signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is already an enry in output_alternatives with this special key. 
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py index 06c61554fa2fa9b563652e7555fbe436ee102638..0689be88c5efa3f85ad981bb630f2f897e8663ff 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py @@ -396,8 +396,9 @@ class StateSavingRnnEstimatorTest(test.TestCase): random_sequence = random_ops.random_uniform( [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) labels = array_ops.slice(random_sequence, [0], [sequence_length]) - inputs = math_ops.to_float( - array_ops.slice(random_sequence, [1], [sequence_length])) + inputs = math_ops.cast( + array_ops.slice(random_sequence, [1], [sequence_length]), + dtypes.float32) features = {'inputs': inputs} if mode == model_fn_lib.ModeKeys.INFER: @@ -450,8 +451,9 @@ class LegacyConstructorTest(test.TestCase): random_sequence = random_ops.random_uniform( [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) labels = array_ops.slice(random_sequence, [0], [sequence_length]) - inputs = math_ops.to_float( - array_ops.slice(random_sequence, [1], [sequence_length])) + inputs = math_ops.cast( + array_ops.slice(random_sequence, [1], [sequence_length]), + dtypes.float32) return {'inputs': inputs}, labels return input_fn @@ -537,8 +539,9 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): random_sequence = random_ops.random_uniform( [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) labels = array_ops.slice(random_sequence, [0], [sequence_length]) - inputs = math_ops.to_float( - array_ops.slice(random_sequence, [1], [sequence_length])) + inputs = math_ops.cast( + array_ops.slice(random_sequence, [1], [sequence_length]), + dtypes.float32) return {'inputs': inputs}, labels return input_fn diff --git 
a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index c056a12fa5307a7e9ac4cf30e1386ddfd5cd7d75..950840c6b77d771a5a051870d7986b00de3e2902 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -624,7 +624,7 @@ class SdcaModel(object): # Note that we need double precision to get accurate results. with ops.control_dependencies(shard_sums): shard_sums.append( - math_ops.reduce_sum(math_ops.to_double(values), 0)) + math_ops.reduce_sum(math_ops.cast(values, dtypes.float64), 0)) summed_values = math_ops.add_n(shard_sums) primal_loss = summed_values[1] diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index a28394964a12013c43d85701b5a0ab5c559afd62..8fda828e994bc2436eaba4475077020436703631 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation -# TODO(rohanj): This should subclass Checkpointable and implement +# TODO(rohanj): This should subclass Trackable and implement # _gather_saveables_for_checkpoint. class ShardedMutableDenseHashTable(object): """A sharded version of MutableDenseHashTable. 
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 8ebe45d8510f4b78cded997916dd9d6b96d22579..58ab3aec6648ca61db996be458178e90d6c47353 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -135,7 +135,7 @@ class SDCAOptimizer(object): array_ops.reshape( array_ops.split( value=sparse_indices, num_or_size_splits=2, axis=1)[1], [-1]), - array_ops.reshape(math_ops.to_float(sparse_values), [-1])) + array_ops.reshape(math_ops.cast(sparse_values, dtypes.float32), [-1])) def _training_examples_and_variables(): """Returns dictionaries for training examples and variables.""" @@ -254,8 +254,8 @@ class SDCAOptimizer(object): examples = dict( sparse_features=sparse_feature_with_values, dense_features=dense_features, - example_labels=math_ops.to_float( - array_ops.reshape(targets, shape=[-1])), + example_labels=math_ops.cast( + array_ops.reshape(targets, shape=[-1]), dtypes.float32), example_weights=example_weights, example_ids=example_ids) sdca_variables = dict( diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD deleted file mode 100644 index 893ddd78231c8a0d819cbe5776e6873bdab57355..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/python/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -licenses(["notice"]) - -# DO NOT USE THIS TARGET. TensorFlow Lite has moved to tensorflow/lite. 
-py_library( - name = "lite", - srcs = ["__init__.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/lite/python:lite", - ], -) diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 229a72a780d5ccce8263444ffeae7700f6ac8613..20e86e56bbe911eca2bba661aff7165e53fa159e 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools - from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,11 +26,11 @@ from tensorflow.python.ops import lookup_ops # pylint: disable=unused-import from tensorflow.python.ops.lookup_ops import FastHashSpec from tensorflow.python.ops.lookup_ops import HasherSpec -from tensorflow.python.ops.lookup_ops import HashTable from tensorflow.python.ops.lookup_ops import IdTableWithHashBuckets from tensorflow.python.ops.lookup_ops import index_table_from_file from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file from tensorflow.python.ops.lookup_ops import InitializableLookupTableBase +from tensorflow.python.ops.lookup_ops import InitializableLookupTableBaseV1 from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer from tensorflow.python.ops.lookup_ops import LookupInterface from tensorflow.python.ops.lookup_ops import StrongHashSpec @@ -42,7 +40,6 @@ from tensorflow.python.ops.lookup_ops import TextFileIndex from tensorflow.python.ops.lookup_ops import TextFileInitializer from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer # pylint: enable=unused-import -from tensorflow.python.training.saver import BaseSaverBuilder from tensorflow.python.util.deprecation import deprecated @@ -288,353 +285,52 @@ def index_to_string(tensor, mapping, 
default_value="UNK", name=None): return table.lookup(tensor) -class MutableHashTable(LookupInterface): - """A generic mutable hash table implementation. - - Data can be inserted by calling the insert method and removed by calling the - remove method. It does not support initialization via the init method. +class HashTable(InitializableLookupTableBaseV1): + """A generic hash table implementation. Example usage: ```python - table = tf.contrib.lookup.MutableHashTable(key_dtype=tf.string, - value_dtype=tf.int64, - default_value=-1) - sess.run(table.insert(keys, values)) - out = table.lookup(query_keys) + table = tf.HashTable( + tf.KeyValueTensorInitializer(keys, values), -1) + out = table.lookup(input_tensor) + table.init.run() print(out.eval()) ``` """ - def __init__(self, - key_dtype, - value_dtype, - default_value, - shared_name=None, - name="MutableHashTable", - checkpoint=True): - """Creates an empty `MutableHashTable` object. + def __init__(self, initializer, default_value, shared_name=None, name=None): + """Creates a non-initialized `HashTable` object. - Creates a table, the type of its keys and values are specified by key_dtype - and value_dtype, respectively. + Creates a table, the type of its keys and values are specified by the + initializer. + Before using the table you will have to initialize it. After initialization + the table will be immutable. Args: - key_dtype: the type of the key tensors. - value_dtype: the type of the value tensors. + initializer: The table initializer to use. See `HashTable` kernel for + supported key and value types. default_value: The value to use if a key is missing in the table. - shared_name: If non-empty, this table will be shared under - the given name across multiple sessions. + shared_name: If non-empty, this table will be shared under the given name + across multiple sessions. name: A name for the operation (optional). - checkpoint: if True, the contents of the table are saved to and restored - from checkpoints. 
If `shared_name` is empty for a checkpointed table, it - is shared using the table node name. Returns: - A `MutableHashTable` object. - - Raises: - ValueError: If checkpoint is True and no name was specified. + A `HashTable` object. """ - self._default_value = ops.convert_to_tensor(default_value, - dtype=value_dtype) - self._value_shape = self._default_value.get_shape() - self._checkpoint = checkpoint - self._key_dtype = key_dtype - self._value_dtype = value_dtype - self._name = name - - if context.executing_eagerly() and shared_name is None: - # TODO(allenl): This will leak memory due to kernel caching by the - # shared_name attribute value (but is better than the alternative of - # sharing everything by default when executing eagerly; hopefully creating - # tables in a loop is uncommon). - shared_name = "table_%d" % (ops.uid(),) + self._initializer = initializer + self._default_value = default_value self._shared_name = shared_name - super(MutableHashTable, self).__init__(key_dtype, value_dtype) - - self._resource_handle = self.create_resource() - if checkpoint: - saveable = MutableHashTable._Saveable(self, name) - if not context.executing_eagerly(): - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - - def create_resource(self): - # The table must be shared if checkpointing is requested for multi-worker - # training to work correctly. Use the node name if no shared_name has been - # explicitly specified. 
- use_node_name_sharing = self._checkpoint and self._shared_name is None - if self._default_value.get_shape().ndims == 0: - table_ref = gen_lookup_ops.mutable_hash_table_v2( - shared_name=self._shared_name, - use_node_name_sharing=use_node_name_sharing, - key_dtype=self._key_dtype, - value_dtype=self._value_dtype, - name=self._name) - else: - table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2( - shared_name=self._shared_name, - use_node_name_sharing=use_node_name_sharing, - key_dtype=self._key_dtype, - value_dtype=self._value_dtype, - value_shape=self._default_value.get_shape(), - name=self._name) - - if context.executing_eagerly(): - self._table_name = None - else: - self._table_name = table_ref.op.name.split("/")[-1] - return table_ref - - @property - def name(self): - return self._table_name - - def size(self, name=None): - """Compute the number of elements in this table. - - Args: - name: A name for the operation (optional). - - Returns: - A scalar tensor containing the number of elements in this table. - """ - with ops.name_scope(name, "%s_Size" % self.name, - [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - return gen_lookup_ops.lookup_table_size_v2( - self.resource_handle, name=name) - - def remove(self, keys, name=None): - """Removes `keys` and its associated values from the table. - - If a key is not present in the table, it is silently ignored. - - Args: - keys: Keys to remove. Can be a tensor of any shape. Must match the table's - key type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` do not match the table data types. - """ - if keys.dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." 
% - (self._key_dtype, keys.dtype)) - - with ops.name_scope( - name, "%s_lookup_table_remove" % self.name, - (self.resource_handle, keys, self._default_value)) as name: - # pylint: disable=protected-access - op = gen_lookup_ops.lookup_table_remove_v2( - self.resource_handle, keys, name=name) - - return op - - def lookup(self, keys, name=None): - """Looks up `keys` in a table, outputs the corresponding values. - - The `default_value` is used for keys not present in the table. - - Args: - keys: Keys to look up. Can be a tensor of any shape. Must match the - table's key_dtype. - name: A name for the operation (optional). - - Returns: - A tensor containing the values in the same shape as `keys` using the - table's value type. - - Raises: - TypeError: when `keys` do not match the table data types. - """ - with ops.name_scope( - name, "%s_lookup_table_find" % self.name, - (self.resource_handle, keys, self._default_value)) as name: - keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - with ops.colocate_with(self.resource_handle): - values = gen_lookup_ops.lookup_table_find_v2( - self.resource_handle, keys, self._default_value, name=name) - return values - - def insert(self, keys, values, name=None): - """Associates `keys` with `values`. - - Args: - keys: Keys to insert. Can be a tensor of any shape. Must match the - table's key type. - values: Values to be associated with keys. Must be a tensor of the same - shape as `keys` and match the table's value type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` or `values` doesn't match the table data - types. 
- """ - with ops.name_scope(name, "%s_lookup_table_insert" % self.name, - [self.resource_handle, keys, values]) as name: - keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys") - values = ops.convert_to_tensor(values, self._value_dtype, name="values") - with ops.colocate_with(self.resource_handle): - # pylint: disable=protected-access - op = gen_lookup_ops.lookup_table_insert_v2( - self.resource_handle, keys, values, name=name) - return op - - def export(self, name=None): - """Returns tensors of all keys and values in the table. - - Args: - name: A name for the operation (optional). - - Returns: - A pair of tensors with the first tensor containing all keys and the - second tensors containing all values in the table. - """ - with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, - [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( - self.resource_handle, self._key_dtype, self._value_dtype, name=name) - return exported_keys, exported_values - - def _gather_saveables_for_checkpoint(self): - """For object-based checkpointing.""" - return {"table": functools.partial(MutableHashTable._Saveable, table=self)} - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject implementation for MutableHashTable.""" - - def __init__(self, table, name): - tensors = table.export() - specs = [ - BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), - BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values") - ] - # pylint: disable=protected-access - super(MutableHashTable._Saveable, self).__init__(table, specs, name) - - def restore(self, restored_tensors, restored_shapes): - del restored_shapes # unused - # pylint: disable=protected-access - with ops.colocate_with(self.op.resource_handle): - return gen_lookup_ops.lookup_table_import_v2( - self.op.resource_handle, restored_tensors[0], restored_tensors[1]) - - -class 
MutableDenseHashTable(LookupInterface): - """A generic mutable hash table implementation using tensors as backing store. - - Data can be inserted by calling the insert method and removed by calling the - remove method. It does not support initialization via the init method. - - It uses "open addressing" with quadratic reprobing to resolve collisions. - Compared to `MutableHashTable` the insert, remove and lookup operations in a - `MutableDenseHashTable` are typically faster, but memory usage can be higher. - However, `MutableDenseHashTable` does not require additional memory for - temporary tensors created during checkpointing and restore operations. - - Example usage: - - ```python - table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64, - value_dtype=tf.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - sess.run(table.insert(keys, values)) - out = table.lookup(query_keys) - print(out.eval()) - ``` - """ - - # TODO(andreasst): consider extracting common code with MutableHashTable into - # a common superclass. - def __init__(self, - key_dtype, - value_dtype, - default_value, - empty_key, - deleted_key, - initial_num_buckets=None, - shared_name=None, - name="MutableDenseHashTable", - checkpoint=True): - """Creates an empty `MutableDenseHashTable` object. - - Creates a table, the type of its keys and values are specified by key_dtype - and value_dtype, respectively. - - Args: - key_dtype: the type of the key tensors. - value_dtype: the type of the value tensors. - default_value: The value to use if a key is missing in the table. - empty_key: the key to use to represent empty buckets internally. Must not - be used in insert, remove or lookup operations. - initial_num_buckets: the initial number of buckets. - shared_name: If non-empty, this table will be shared under - the given name across multiple sessions. - name: A name for the operation (optional). 
- checkpoint: if True, the contents of the table are saved to and restored - from checkpoints. If `shared_name` is empty for a checkpointed table, it - is shared using the table node name. - deleted_key: the key to use to represent deleted buckets internally. Must - not be used in insert, remove or lookup operations and be different from - the empty_key. - - Returns: - A `MutableDenseHashTable` object. - - Raises: - ValueError: If checkpoint is True and no name was specified. - """ - self._default_value = ops.convert_to_tensor( - default_value, dtype=value_dtype, name="default_value") - self._key_dtype = key_dtype - self._value_dtype = value_dtype - self._initial_num_buckets = initial_num_buckets + self._name = name or "hash_table" + self._table_name = None + super(HashTable, self).__init__(default_value, initializer) self._value_shape = self._default_value.get_shape() - self._checkpoint = checkpoint - self._name = name - - self._empty_key = ops.convert_to_tensor( - empty_key, dtype=key_dtype, name="empty_key") - self._deleted_key = ops.convert_to_tensor( - deleted_key, dtype=key_dtype, name="deleted_key") - if context.executing_eagerly() and shared_name is None: - # TODO(allenl): This will leak memory due to kernel caching by the - # shared_name attribute value (but is better than the alternative of - # sharing everything by default when executing eagerly; hopefully creating - # tables in a loop is uncommon). - shared_name = "table_%d" % (ops.uid(),) - self._shared_name = shared_name - super(MutableDenseHashTable, self).__init__(key_dtype, value_dtype) - - self._resource_handle = self.create_resource() - if checkpoint: - saveable = MutableDenseHashTable._Saveable(self, name) - if not context.executing_eagerly(): - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - - def create_resource(self): - # The table must be shared if checkpointing is requested for multi-worker - # training to work correctly. 
Use the node name if no shared_name has been - # explicitly specified. - use_node_name_sharing = self._checkpoint and self._shared_name is None - table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( - empty_key=self._empty_key, - deleted_key=self._deleted_key, + + def _create_resource(self): + table_ref = gen_lookup_ops.hash_table_v2( shared_name=self._shared_name, - use_node_name_sharing=use_node_name_sharing, - value_dtype=self._value_dtype, - value_shape=self._value_shape, - initial_num_buckets=self._initial_num_buckets, + key_dtype=self._initializer.key_dtype, + value_dtype=self._initializer.value_dtype, name=self._name) if context.executing_eagerly(): self._table_name = None @@ -642,107 +338,14 @@ class MutableDenseHashTable(LookupInterface): self._table_name = table_ref.op.name.split("/")[-1] return table_ref + @property + def init(self): + return self.initializer + @property def name(self): return self._table_name - def size(self, name=None): - """Compute the number of elements in this table. - - Args: - name: A name for the operation (optional). - - Returns: - A scalar tensor containing the number of elements in this table. - """ - with ops.name_scope(name, "%s_Size" % self.name, - [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - return gen_lookup_ops.lookup_table_size_v2( - self.resource_handle, name=name) - - def lookup(self, keys, name=None): - """Looks up `keys` in a table, outputs the corresponding values. - - The `default_value` is used for keys not present in the table. - - Args: - keys: Keys to look up. Can be a tensor of any shape. Must match the - table's key_dtype. - name: A name for the operation (optional). - - Returns: - A tensor containing the values in the same shape as `keys` using the - table's value type. - - Raises: - TypeError: when `keys` do not match the table data types. 
- """ - with ops.name_scope(name, "%s_lookup_table_find" % self.name, - [self.resource_handle, keys]) as name: - keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - with ops.colocate_with(self.resource_handle): - values = gen_lookup_ops.lookup_table_find_v2( - self.resource_handle, keys, self._default_value, name=name) - - return values - - def insert(self, keys, values, name=None): - """Associates `keys` with `values`. - - Args: - keys: Keys to insert. Can be a tensor of any shape. Must match the - table's key type. - values: Values to be associated with keys. Must be a tensor of the same - shape as `keys` and match the table's value type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` or `values` doesn't match the table data - types. - """ - with ops.name_scope(name, "%s_lookup_table_insert" % self.name, - [self.resource_handle, keys, values]) as name: - keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - values = ops.convert_to_tensor( - values, dtype=self._value_dtype, name="values") - with ops.colocate_with(self.resource_handle): - op = gen_lookup_ops.lookup_table_insert_v2( - self.resource_handle, keys, values, name=name) - return op - - def remove(self, keys, name=None): - """Removes `keys` and its associated values from the table. - - If a key is not present in the table, it is silently ignored. - - Args: - keys: Keys to remove. Can be a tensor of any shape. Must match the table's - key type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` do not match the table data types. - """ - if keys.dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." 
% - (self._key_dtype, keys.dtype)) - - with ops.name_scope( - name, "%s_lookup_table_remove" % self.name, - (self.resource_handle, keys, self._default_value)) as name: - # pylint: disable=protected-access - op = gen_lookup_ops.lookup_table_remove_v2( - self.resource_handle, keys, name=name) - - return op - def export(self, name=None): """Returns tensors of all keys and values in the table. @@ -753,34 +356,15 @@ class MutableDenseHashTable(LookupInterface): A pair of tensors with the first tensor containing all keys and the second tensors containing all values in the table. """ - with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, + with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( - self.resource_handle, self._key_dtype, self._value_dtype, name=name) + exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( + self.resource_handle, self._key_dtype, self._value_dtype, name=name) + exported_values.set_shape(exported_keys.get_shape().concatenate( + self._value_shape)) return exported_keys, exported_values - def _gather_saveables_for_checkpoint(self): - """For object-based checkpointing.""" - return {"table": functools.partial( - MutableDenseHashTable._Saveable, table=self)} - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject implementation for MutableDenseHashTable.""" - - def __init__(self, table, name): - tensors = table.export() - specs = [ - BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), - BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values") - ] - # pylint: disable=protected-access - super(MutableDenseHashTable._Saveable, self).__init__(table, specs, name) - - def restore(self, restored_tensors, restored_shapes): - del restored_shapes # unused - # pylint: disable=protected-access - with ops.colocate_with(self.op.resource_handle): - 
return gen_lookup_ops.lookup_table_import_v2( - self.op.resource_handle, restored_tensors[0], restored_tensors[1]) + +MutableHashTable = lookup_ops.MutableHashTable +MutableDenseHashTable = lookup_ops.DenseHashTable diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 9b2c2dd87cc8a92fbb6b45504939be3788b60839..9fe8dafcc8edd6b80625c61a4a0e783e65b44720 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -18,14 +18,10 @@ from __future__ import division from __future__ import print_function import os -import tempfile import numpy as np -import six from tensorflow.contrib import lookup from tensorflow.python.client import session -from tensorflow.python.data.experimental.ops import counter -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -37,9 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import saver from tensorflow.python.training import server_lib -from tensorflow.python.training.checkpointable import util as checkpointable class HashTableOpTest(test.TestCase): @@ -299,1240 +293,6 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([b"brain", b"salad", b"n/a"], result) -class MutableHashTableOpTest(test.TestCase): - - def testMutableHashTable(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, 
table.size().eval()) - - remove_string = constant_op.constant(["tarkus", "tank"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list()) - self.assertAllEqual([None], exported_values.get_shape().as_list()) - - # exported data is in the order of the internal map, i.e. undefined - sorted_keys = np.sort(exported_keys.eval()) - sorted_values = np.sort(exported_values.eval()) - self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) - self.assertAllEqual([0, 1, 2], sorted_values) - - def testSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - v0 = variables.Variable(10.0, name="v0") - v1 = variables.Variable(20.0, name="v1") - - default_val = -1 - keys = constant_op.constant(["b", "c", "d"], dtypes.string) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - - save = saver.Saver() - variables.global_variables_initializer().run() - - # Check that the parameter nodes have been initialized. 
- self.assertEqual(10.0, v0.eval()) - self.assertEqual(20.0, v1.eval()) - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - v0 = variables.Variable(-1.0, name="v0") - v1 = variables.Variable(-1.0, name="v1") - default_val = -1 - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - table.insert( - constant_op.constant(["a", "c"], dtypes.string), - constant_op.constant([12, 24], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. - save.restore(sess, save_path) - # Check that the parameter nodes have been restored. - self.assertEqual(10.0, v0.eval()) - self.assertEqual(20.0, v1.eval()) - - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["a", "b", "c", "d", "e"], - dtypes.string) - output = table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], output.eval()) - - @test_util.run_in_graph_and_eager_modes - def testObjectSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - v0 = variables.Variable(10.0, name="v0") - v1 = variables.Variable(20.0, name="v1") - - default_val = -1 - keys = constant_op.constant(["b", "c", "d"], dtypes.string) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - - checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1) - self.evaluate([v0.initializer, v1.initializer]) - - # Check that the parameter nodes have been initialized. 
- self.assertEqual(10.0, self.evaluate(v0)) - self.assertEqual(20.0, self.evaluate(v1)) - - self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) - - save_path = checkpoint.save(save_prefix) - del table, checkpoint, v0, v1 - - v0 = variables.Variable(-1.0, name="v0") - v1 = variables.Variable(-1.0, name="v1") - default_val = -1 - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - self.evaluate(table.insert( - constant_op.constant(["a", "c"], dtypes.string), - constant_op.constant([12, 24], dtypes.int64))) - self.assertAllEqual(2, self.evaluate(table.size())) - - checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1) - - # Restore the saved values in the parameter nodes. - checkpoint.restore(save_path).run_restore_ops() - # Check that the parameter nodes have been restored. - self.assertEqual(10.0, self.evaluate(v0)) - self.assertEqual(20.0, self.evaluate(v1)) - - self.assertAllEqual(3, self.evaluate(table.size())) - - input_string = constant_op.constant(["a", "b", "c", "d", "e"], - dtypes.string) - output = table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) - - def testSharing(self): - # Start a server to store the table state - server = server_lib.Server( - { - "local0": ["localhost:0"] - }, protocol="grpc", start=True) - # Create two sessions sharing the same state - session1 = session.Session(server.target) - session2 = session.Session(server.target) - - table = lookup.MutableHashTable( - dtypes.int64, dtypes.string, "-", name="t1") - - # Populate the table in the first session - with session1: - self.assertAllEqual(0, table.size().eval()) - - keys = constant_op.constant([11, 12], dtypes.int64) - values = constant_op.constant(["a", "b"]) - table.insert(keys, values).run() - self.assertAllEqual(2, table.size().eval()) - - output = 
table.lookup(constant_op.constant([11, 12, 13], dtypes.int64)) - self.assertAllEqual([b"a", b"b", b"-"], output.eval()) - - # Verify that we can access the shared data from the second session - with session2: - self.assertAllEqual(2, table.size().eval()) - - output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64)) - self.assertAllEqual([b"-", b"a", b"b"], output.eval()) - - def testMutableHashTableOfTensors(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]], - dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant(["tarkus", "tank"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - self.assertAllEqual([3, 2], output.get_shape()) - - result = output.eval() - self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) - - exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list(), - msg="Saw shape %s" % exported_keys.shape) - self.assertAllEqual([None, 2], exported_values.get_shape().as_list(), - msg="Saw shape %s" % exported_values.shape) - # exported data is in the order of the internal map, i.e. 
undefined - sorted_keys = np.sort(exported_keys.eval()) - sorted_values = np.sort(exported_values.eval()) - self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) - self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values) - - def testMutableHashTableExportInsert(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table1.size().eval()) - table1.insert(keys, values).run() - self.assertAllEqual(3, table1.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - expected_output = [[0, 1], [2, 3], [-1, -1]] - output1 = table1.lookup(input_string) - self.assertAllEqual(expected_output, output1.eval()) - - exported_keys, exported_values = table1.export() - self.assertAllEqual(3, exported_keys.eval().size) - self.assertAllEqual(6, exported_values.eval().size) - - # Populate a second table from the exported data - table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table2.size().eval()) - table2.insert(exported_keys, exported_values).run() - self.assertAllEqual(3, table2.size().eval()) - - # Verify lookup result is still the same - output2 = table2.lookup(input_string) - self.assertAllEqual(expected_output, output2.eval()) - - def testMutableHashTableOfTensorsInvalidShape(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - # Shape [6] instead of [3, 2] - values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Shape [2,3] 
instead of [3, 2] - values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Shape [2, 2] instead of [3, 2] - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Shape [3, 1] instead of [3, 2] - values = constant_op.constant([[0], [2], [4]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Valid Insert - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - def testMutableHashTableInvalidDefaultValue(self): - with self.cached_session(): - default_val = constant_op.constant([[-1, -1]], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - with self.assertRaisesOpError("Default value must be a vector"): - self.assertAllEqual(0, table.size().eval()) - - def testMutableHashTableDuplicateInsert(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery", "brain"]) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([3, 1, -1], result) - - def testMutableHashTableFindHighRank(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - table.insert(keys, values).run() 
- self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant( - [["brain", "salad"], ["tank", "tarkus"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2], output.get_shape()) - - result = output.eval() - self.assertAllEqual([[0, 1], [-1, -1]], result) - - def testMutableHashTableInsertHighRank(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, 3, -1], result) - - def testMutableHashTableRemoveHighRank(self): - with self.test_session(): - default_val = -1 - keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant(["salad", "tarkus"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, -1, 3, -1], result) - - def testMutableHashTableOfTensorsFindHighRank(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], - dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - 
table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant( - [["brain", "salad"], ["tank", "tarkus"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2, 3], output.get_shape()) - - result = output.eval() - self.assertAllEqual( - [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) - - def testMutableHashTableOfTensorsRemoveHighRank(self): - with self.test_session(): - default_val = constant_op.constant([-1, -1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], - dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - remove_string = constant_op.constant([["brain", "tank"]]) - table.remove(remove_string).run() - self.assertAllEqual(2, table.size().eval()) - - input_string = constant_op.constant([["brain", "salad"], - ["surgery", "tank"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2, 3], output.get_shape()) - - result = output.eval() - self.assertAllEqual( - [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result) - - def testMultipleMutableHashTables(self): - with self.cached_session() as sess: - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - - table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table3 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table1.insert(keys, values).run() - table2.insert(keys, values).run() - table3.insert(keys, values).run() - - self.assertAllEqual(3, table1.size().eval()) - self.assertAllEqual(3, table2.size().eval()) - self.assertAllEqual(3, table3.size().eval()) - - input_string = 
constant_op.constant(["brain", "salad", "tank"]) - output1 = table1.lookup(input_string) - output2 = table2.lookup(input_string) - output3 = table3.lookup(input_string) - - out1, out2, out3 = sess.run([output1, output2, output3]) - self.assertAllEqual([0, 1, -1], out1) - self.assertAllEqual([0, 1, -1], out2) - self.assertAllEqual([0, 1, -1], out3) - - def testMutableHashTableWithTensorDefault(self): - with self.cached_session(): - default_val = constant_op.constant(-1, dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testSignatureMismatch(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - # insert with keys of the wrong type - with self.assertRaises(ValueError): - table.insert(constant_op.constant([4, 5, 6]), values).run() - - # insert with values of the wrong type - with self.assertRaises(ValueError): - table.insert(keys, constant_op.constant(["a", "b", "c"])).run() - - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string_ref = variables.Variable("brain") - input_int64_ref = variables.Variable(-1, dtype=dtypes.int64) - variables.global_variables_initializer().run() - - # Ref types do not produce an insert signature mismatch. 
- table.insert(input_string_ref, input_int64_ref).run() - self.assertAllEqual(3, table.size().eval()) - - # Ref types do not produce a lookup signature mismatch. - self.assertEqual(-1, table.lookup(input_string_ref).eval()) - - # lookup with keys of the wrong type - input_string = constant_op.constant([1, 2, 3], dtypes.int64) - with self.assertRaises(ValueError): - table.lookup(input_string).eval() - - # default value of the wrong type - with self.assertRaises(TypeError): - lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK") - - def testMutableHashTableStringFloat(self): - with self.cached_session(): - default_val = -1.5 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1.1, 2.2], dtypes.float32) - table = lookup.MutableHashTable(dtypes.string, dtypes.float32, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllClose([0, 1.1, default_val], result) - - def testMutableHashTableIntFloat(self): - with self.cached_session(): - default_val = -1.0 - keys = constant_op.constant([3, 7, 0], dtypes.int64) - values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32) - table = lookup.MutableHashTable(dtypes.int64, dtypes.float32, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([7, 0, 11], dtypes.int64) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllClose([-1.2, 9.9, default_val], result) - - def testMutableHashTableInt64String(self): - with self.cached_session(): - default_val = "n/a" - keys = constant_op.constant([0, 1, 2], dtypes.int64) - values = constant_op.constant(["brain", "salad", "surgery"]) - table = 
lookup.MutableHashTable(dtypes.int64, dtypes.string, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([0, 1, 3], dtypes.int64) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual((b"brain", b"salad", b"n/a"), result) - - -class MutableDenseHashTableOpTest(test.TestCase): - - def testBasic(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant([12, 15], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 12, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = output.eval() - self.assertAllEqual([0, -1, -1], result) - - def testBasicBool(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([True, True, True, True], dtypes.bool) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.bool, - default_value=False, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant([11, 15], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 12, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - 
result = output.eval() - self.assertAllEqual([False, True, False], result) - - def testSameEmptyAndDeletedKey(self): - with self.cached_session(): - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "deleted_key"): - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=42) - self.assertAllEqual(0, table.size().eval()) - - def testLookupUnknownShape(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - placeholder_keys = array_ops.placeholder(dtypes.int64) - output = table.lookup(placeholder_keys) - self.assertAllEqual(None, output.get_shape()) - result = output.eval({placeholder_keys: [11, 12, 15]}) - self.assertAllEqual([0, 1, -1], result) - - def testMapStringToFloat(self): - with self.cached_session(): - - keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string) - values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32) - default_value = constant_op.constant(-1.5, dtypes.float32) - table = lookup.MutableDenseHashTable( - dtypes.string, - dtypes.float32, - default_value=default_value, - empty_key="", - deleted_key="$") - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant(["b", "e"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = output.eval() - self.assertAllClose([0, -1.5, 3.3, -1.5], result) - - def testMapInt64ToFloat(self): - for 
float_dtype in [dtypes.float32, dtypes.float64]: - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype) - default_value = constant_op.constant(-1.5, float_dtype) - table = lookup.MutableDenseHashTable( - dtypes.int64, - float_dtype, - default_value=default_value, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant([12, 15], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = output.eval() - self.assertAllClose([0, -1.5, 3.3, -1.5], result) - - def testVectorValues(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]], - dtypes.int64) - default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=0, - deleted_key=-1, - initial_num_buckets=4) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(4, len(table.export()[0].eval())) - - table.insert( - constant_op.constant([14], dtypes.int64), - constant_op.constant([[2, 3, 4, 5]], dtypes.int64)).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - remove_string = constant_op.constant([12, 16], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - input_string = 
constant_op.constant([11, 12, 14, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4, 4], - output.shape, - msg="Saw shape: %s" % output.shape) - - result = output.eval() - self.assertAllEqual( - [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]], - result) - - def testVectorKeys(self): - with self.cached_session(): - keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64) - values = constant_op.constant([10, 11, 12], dtypes.int64) - empty_key = constant_op.constant([0, 3], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - initial_num_buckets=8) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - table.insert( - constant_op.constant([[0, 0]], dtypes.int64), - constant_op.constant([13], dtypes.int64)).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]], - dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = output.eval() - self.assertAllEqual([10, -1, 12, -1], result) - - def testResize(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1, - initial_num_buckets=4) - self.assertAllEqual(0, 
table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(4, len(table.export()[0].eval())) - - keys2 = constant_op.constant([12, 99], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(4, len(table.export()[0].eval())) - - keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) - values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) - - table.insert(keys3, values3).run() - self.assertAllEqual(6, table.size().eval()) - self.assertAllEqual(16, len(table.export()[0].eval())) - - keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], - dtypes.int64) - output = table.lookup(keys4) - self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], output.eval()) - - def testExport(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([1, 2, 3, 4], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=100, - deleted_key=200, - initial_num_buckets=8) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - keys2 = constant_op.constant([12, 15], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - - exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list()) - self.assertAllEqual([None], exported_values.get_shape().as_list()) - - np_keys = exported_keys.eval() - np_values = exported_values.eval() - - self.assertAllEqual(8, len(np_keys)) - self.assertAllEqual(8, len(np_values)) - - # pair up keys and values, drop extra added dimension - pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0] - # sort by key - pairs = pairs[pairs[:, 0].argsort()] - self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0], - [100, 
0], [100, 0], [200, 2]], pairs) - - def testSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - default_value = -1 - empty_key = 0 - deleted_key = -1 - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=32) - - save = saver.Saver() - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - keys2 = constant_op.constant([12, 15], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=64) - table.insert( - constant_op.constant([11, 14], dtypes.int64), - constant_op.constant([12, 24], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(64, len(table.export()[0].eval())) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. 
- save.restore(sess, save_path) - - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([-1, 0, -1, 2, 3], output.eval()) - - @test_util.run_in_graph_and_eager_modes - def testObjectSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - default_value = -1 - empty_key = 0 - deleted_key = -1 - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - save_table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=32) - - save_checkpoint = checkpointable.Checkpoint(table=save_table) - - self.assertAllEqual(0, self.evaluate(save_table.size())) - self.evaluate(save_table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(save_table.size())) - self.assertAllEqual(32, len(self.evaluate(save_table.export()[0]))) - - save_path = save_checkpoint.save(save_prefix) - del save_table, save_checkpoint - - load_table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=64) - self.evaluate(load_table.insert( - constant_op.constant([11, 14], dtypes.int64), - constant_op.constant([12, 24], dtypes.int64))) - self.assertAllEqual(2, self.evaluate(load_table.size())) - self.assertAllEqual(64, len(self.evaluate(load_table.export()[0]))) - - restore_checkpoint = checkpointable.Checkpoint(table=load_table) - - # Restore the saved values in the parameter nodes. 
- restore_checkpoint.restore(save_path).run_restore_ops() - - self.assertAllEqual(3, self.evaluate(load_table.size())) - self.assertAllEqual(32, len(self.evaluate(load_table.export()[0]))) - - input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) - output = load_table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) - - def testVectorSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-2, -3], dtypes.int64) - default_value = constant_op.constant([-1, -2], dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], - dtypes.int64) - values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]], - dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=32) - - save = saver.Saver() - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-2, -3], dtypes.int64) - default_value = constant_op.constant([-1, -2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - 
dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=64) - table.insert( - constant_op.constant([[11, 12], [13, 15]], dtypes.int64), - constant_op.constant([[21, 22], [23, 24]], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(64, len(table.export()[0].eval())) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. - save.restore(sess, save_path) - - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - input_string = constant_op.constant( - [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]], - output.eval()) - - def testVectorScalarSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], - dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t2", - checkpoint=True, - initial_num_buckets=32) - - save = saver.Saver() - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - 
self.assertAllEqual(32, len(table.export()[0].eval())) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t2", - checkpoint=True, - initial_num_buckets=64) - table.insert( - constant_op.constant([[11, 12], [13, 15]], dtypes.int64), - constant_op.constant([3, 4], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(64, len(table.export()[0].eval())) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. - save.restore(sess, save_path) - - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - input_string = constant_op.constant( - [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([0, 1, -1, 3, -1], output.eval()) - - def testReprobe(self): - with self.cached_session(): - # Insert 6 keys into a table with 8 buckets. 
- # The values are chosen to make sure collisions occur when using GCC STL - keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) - values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1, - initial_num_buckets=8) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(6, table.size().eval()) - - input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22], - dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([9], output.get_shape()) - - result = output.eval() - self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result) - - def testCustomEmptyKey(self): - with self.cached_session(): - keys = constant_op.constant([11, 0, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=12, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 0, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testErrors(self): - with self.cached_session(): - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - # Inserting the empty key returns an error - keys1 = constant_op.constant([11, 0], dtypes.int64) - values1 = constant_op.constant([0, 1], dtypes.int64) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "empty_key"): - table.insert(keys1, values1).run() - - # Looking up the empty key returns an error - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - 
"empty_key"): - table.lookup(keys1).eval() - - # Inserting the deleted key returns an error - keys2 = constant_op.constant([11, -1], dtypes.int64) - values2 = constant_op.constant([0, 1], dtypes.int64) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "deleted_key"): - table.insert(keys2, values2).run() - - # Looking up the empty key returns an error - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "deleted_key"): - table.lookup(keys2).eval() - - # Arbitrary tensors of keys are not supported - keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) - values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Expected key shape"): - table.lookup(keys).eval() - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Expected key shape"): - table.insert(keys, values).run() - - table2 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=17, - deleted_key=-1, - initial_num_buckets=12) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Number of buckets must be"): - self.assertAllEqual(0, table2.size().eval()) - - with self.assertRaisesRegexp( - errors_impl.InvalidArgumentError, - "Empty and deleted keys must have same shape"): - table3 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=[1, 2]) - self.assertAllEqual(0, table3.size().eval()) - - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Empty and deleted keys cannot be equal"): - table4 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=42) - self.assertAllEqual(0, table4.size().eval()) - - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Empty and deleted keys cannot be equal"): - table5 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - 
empty_key=[1, 2, 3], - deleted_key=[1, 2, 3]) - self.assertAllEqual(0, table5.size().eval()) - - class IndexTableFromFile(test.TestCase): def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): @@ -2721,64 +1481,6 @@ class IdTableWithHashBucketsTest(test.TestCase): hasher_spec=lookup.StrongHashSpec([None, 2])) -class MutableHashTableBenchmark(test.Benchmark): - - def _create_table(self): - return lookup.MutableHashTable(dtypes.int64, dtypes.float32, 0.0) - - def benchmark_single_repeated_scalar_insert_scalar(self): - table = self._create_table() - value = variables.Variable(1.0) - insert = table.insert(0, value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) - assert sess.run(size) == 1 - - def benchmark_many_repeated_scalar_insert_scalar(self): - table = self._create_table() - c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() - value = variables.Variable(1.0) - insert = table.insert(c, value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) - assert sess.run(size) >= 10000 - - def benchmark_single_repeated_batch_32_insert_scalar(self): - table = self._create_table() - value = variables.Variable([1.0] * 32) - insert = table.insert(list(range(32)), value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000) - assert sess.run(size) == 32 - - def benchmark_many_repeated_batch_32_insert_scalar(self): - table = self._create_table() - c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() - value = variables.Variable([1.0] * 32) - insert = table.insert(32 * c + list(range(32)), value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, 
burn_iters=10, min_iters=1000) - assert sess.run(size) >= 1000*32 - - -class MutableDenseHashTableBenchmark(MutableHashTableBenchmark): - - def _create_table(self): - return lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.float32, - default_value=0.0, - empty_key=-1, - deleted_key=-2) - - if __name__ == "__main__": test.main() + diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD index 728f75f8ef1eb3b107dbd0ab4ffbecd63787bf3e..f4ebbdeee883ddeef0d47cb561901c16e2195bb2 100644 --- a/tensorflow/contrib/losses/BUILD +++ b/tensorflow/contrib/losses/BUILD @@ -82,10 +82,11 @@ py_library( py_test( name = "metric_loss_ops_test", - size = "large", + size = "medium", srcs = [ "python/metric_learning/metric_loss_ops_test.py", ], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":metric_learning_py", diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 709a042bbcefb89125f7e4cd14a0d7ecd2b53281..dea111f9a0f734a19758a59fb5838f742573560d 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import add_arg_scope +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -100,8 +101,8 @@ def compute_weighted_loss(losses, weights=1.0, scope=None): with ops.name_scope(scope, "weighted_loss", [losses, weights]): losses = ops.convert_to_tensor(losses) input_dtype = losses.dtype - losses = math_ops.to_float(losses) - weights = math_ops.to_float(ops.convert_to_tensor(weights)) + losses = math_ops.cast(losses, dtypes.float32) + weights = math_ops.cast(ops.convert_to_tensor(weights), dtypes.float32) if losses.get_shape().ndims is None: raise 
ValueError("losses.get_shape().ndims cannot be None") @@ -147,8 +148,8 @@ def _num_present(losses, weights, per_batch=False): batch_size = array_ops.reshape( array_ops.slice(array_ops.shape(losses), [0], [1]), []) num_per_batch = math_ops.div( - math_ops.to_float(array_ops.size(losses)), - math_ops.to_float(batch_size)) + math_ops.cast(array_ops.size(losses), dtypes.float32), + math_ops.cast(batch_size, dtypes.float32)) num_per_batch = array_ops.where( math_ops.equal(weights, 0), 0.0, num_per_batch) num_per_batch = math_ops.multiply( @@ -159,12 +160,14 @@ def _num_present(losses, weights, per_batch=False): if weights.get_shape().ndims >= 1: axis = list(range(1, weights.get_shape().ndims)) num_nonzero_per_batch = math_ops.reduce_sum( - math_ops.to_float(math_ops.not_equal(weights, 0)), axis=axis) + math_ops.cast(math_ops.not_equal(weights, 0), dtypes.float32), + axis=axis) # Next, determine the number of elements that weights would broadcast to: broadcast_dims = array_ops.slice( array_ops.shape(losses), [weights.get_shape().ndims], [-1]) - num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims)) + num_to_broadcast = math_ops.cast(math_ops.reduce_prod(broadcast_dims), + dtypes.float32) num_per_batch = math_ops.multiply(num_nonzero_per_batch, num_to_broadcast) return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) @@ -262,8 +265,8 @@ def absolute_difference(predictions, labels=None, weights=1.0, scope=None): with ops.name_scope(scope, "absolute_difference", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) losses = math_ops.abs(math_ops.subtract(predictions, labels)) return compute_weighted_loss(losses, weights, scope=scope) @@ -438,8 +441,8 @@ def log_loss(predictions, labels=None, 
weights=1.0, epsilon=1e-7, scope=None): with ops.name_scope(scope, "log_loss", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) losses = -math_ops.multiply( labels, math_ops.log(predictions + epsilon)) - math_ops.multiply( (1 - labels), math_ops.log(1 - predictions + epsilon)) @@ -473,7 +476,7 @@ def hinge_loss(logits, labels=None, scope=None): with ops.name_scope(scope, "hinge_loss", [logits, labels]) as scope: logits.get_shape().assert_is_compatible_with(labels.get_shape()) # We first need to convert binary labels to -1/1 labels (as floats). - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) all_ones = array_ops.ones_like(labels) labels = math_ops.subtract(2 * labels, all_ones) return nn_ops.relu( @@ -509,9 +512,9 @@ def mean_squared_error(predictions, labels=None, weights=1.0, scope=None): with ops.name_scope(scope, "mean_squared_error", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) - losses = math_ops.square(math_ops.subtract(predictions, labels)) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) + losses = math_ops.squared_difference(predictions, labels) return compute_weighted_loss(losses, weights, scope=scope) @@ -563,9 +566,9 @@ def mean_pairwise_squared_error(predictions, with ops.name_scope(scope, "mean_pairwise_squared_error", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) - weights = 
math_ops.to_float(ops.convert_to_tensor(weights)) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) + weights = math_ops.cast(ops.convert_to_tensor(weights), dtypes.float32) diffs = math_ops.subtract(predictions, labels) @@ -638,8 +641,8 @@ def cosine_distance(predictions, [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) radial_diffs = math_ops.multiply(predictions, labels) losses = 1 - math_ops.reduce_sum( diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py index de76acb51ffe985162a66c617b266f47c5216b19..226527a49c73834cb3ccb0cc3255f981fda64e84 100644 --- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ -67,11 +67,13 @@ def pairwise_distance(feature, squared=False): pairwise_distances = pairwise_distances_squared else: pairwise_distances = math_ops.sqrt( - pairwise_distances_squared + math_ops.to_float(error_mask) * 1e-16) + pairwise_distances_squared + + math_ops.cast(error_mask, dtypes.float32) * 1e-16) # Undo conditionally adding 1e-16. pairwise_distances = math_ops.multiply( - pairwise_distances, math_ops.to_float(math_ops.logical_not(error_mask))) + pairwise_distances, + math_ops.cast(math_ops.logical_not(error_mask), dtypes.float32)) num_data = array_ops.shape(feature)[0] # Explicitly set diagonals to zero. 
@@ -105,13 +107,14 @@ def contrastive_loss(labels, embeddings_anchor, embeddings_positive, # Get per pair distances distances = math_ops.sqrt( math_ops.reduce_sum( - math_ops.square(embeddings_anchor - embeddings_positive), 1)) + math_ops.squared_difference(embeddings_anchor, embeddings_positive), + 1)) # Add contrastive loss for the siamese network. # label here is {0,1} for neg, pos. return math_ops.reduce_mean( - math_ops.to_float(labels) * math_ops.square(distances) + - (1. - math_ops.to_float(labels)) * + math_ops.cast(labels, dtypes.float32) * math_ops.square(distances) + + (1. - math_ops.cast(labels, dtypes.float32)) * math_ops.square(math_ops.maximum(margin - distances, 0.)), name='contrastive_loss') @@ -283,8 +286,8 @@ def npairs_loss(labels, embeddings_anchor, embeddings_positive, assert lshape.shape == 1 labels = array_ops.reshape(labels, [lshape[0], 1]) - labels_remapped = math_ops.to_float( - math_ops.equal(labels, array_ops.transpose(labels))) + labels_remapped = math_ops.cast( + math_ops.equal(labels, array_ops.transpose(labels)), dtypes.float32) labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True) # Add the softmax loss. @@ -317,9 +320,10 @@ def _build_multilabel_adjacency(sparse_labels): adjacency_matrix = array_ops.zeros([num_pairs, num_pairs]) for i in range(num_pairs): for j in range(num_pairs): - sparse_dot_product = math_ops.to_float( + sparse_dot_product = math_ops.cast( sparse_ops.sparse_reduce_sum(sparse_ops.sparse_minimum( - sparse_labels[i], sparse_labels[j]))) + sparse_labels[i], sparse_labels[j])), + dtypes.float32) sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 0) sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 1) one_hot_matrix = array_ops.pad(sparse_dot_product, @@ -389,7 +393,7 @@ def npairs_loss_multilabel(sparse_labels, embeddings_anchor, # TODO(coreylynch): are composed only of 0's and 1's. 
multilabel_adjacency_matrix = _build_multilabel_adjacency(sparse_labels) - labels_remapped = math_ops.to_float(multilabel_adjacency_matrix) + labels_remapped = math_ops.cast(multilabel_adjacency_matrix, dtypes.float32) labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True) # Add the softmax loss. @@ -541,7 +545,8 @@ def get_cluster_assignment(pairwise_distances, centroid_ids): array_ops.constant(0, dtype=dtypes.int64), axis=0, dtype=dtypes.int64), - math_ops.to_int64(math_ops.range(array_ops.shape(centroid_ids)[0]))) + math_ops.cast(math_ops.range(array_ops.shape(centroid_ids)[0]), + dtypes.int64)) constraint_vect = math_ops.reduce_sum( array_ops.transpose(constraint_one_hot), axis=0) @@ -605,46 +610,51 @@ def compute_clustering_score(labels, predictions, margin_type): def _compute_nmi_score(labels, predictions): - return math_ops.to_float( + return math_ops.cast( script_ops.py_func( metrics.normalized_mutual_info_score, [labels, predictions], [dtypes.float64], - name='nmi')) + name='nmi'), + dtypes.float32) def _compute_ami_score(labels, predictions): - ami_score = math_ops.to_float( + ami_score = math_ops.cast( script_ops.py_func( metrics.adjusted_mutual_info_score, [labels, predictions], [dtypes.float64], - name='ami')) + name='ami'), + dtypes.float32) return math_ops.maximum(0.0, ami_score) def _compute_ari_score(labels, predictions): - ari_score = math_ops.to_float( + ari_score = math_ops.cast( script_ops.py_func( metrics.adjusted_rand_score, [labels, predictions], [dtypes.float64], - name='ari')) + name='ari'), + dtypes.float32) # ari score can go below 0 # http://scikit-learn.org/stable/modules/clustering.html#adjusted-rand-score return math_ops.maximum(0.0, ari_score) def _compute_vmeasure_score(labels, predictions): - vmeasure_score = math_ops.to_float( + vmeasure_score = math_ops.cast( script_ops.py_func( metrics.v_measure_score, [labels, predictions], [dtypes.float64], - name='vmeasure')) + name='vmeasure'), + dtypes.float32) return 
math_ops.maximum(0.0, vmeasure_score) def _compute_zeroone_score(labels, predictions): - zeroone_score = math_ops.to_float( + zeroone_score = math_ops.cast( math_ops.equal( math_ops.reduce_sum( - math_ops.to_int32(math_ops.equal(labels, predictions))), - array_ops.shape(labels)[0])) + math_ops.cast(math_ops.equal(labels, predictions), dtypes.int32)), + array_ops.shape(labels)[0]), + dtypes.float32) return zeroone_score @@ -710,8 +720,8 @@ def _find_loss_augmented_facility_idx(pairwise_distances, labels, chosen_ids, candidate_scores = math_ops.add( candidate_scores, margin_multiplier * nmi_scores) - argmax_index = math_ops.to_int32( - math_ops.argmax(candidate_scores, axis=0)) + argmax_index = math_ops.cast( + math_ops.argmax(candidate_scores, axis=0), dtypes.int32) return candidate_ids[argmax_index] @@ -786,7 +796,7 @@ def update_medoid_per_cluster(pairwise_distances, pairwise_distances_subset, def func_body(iteration, scores_margin): # swap the current medoid with the candidate cluster member - candidate_medoid = math_ops.to_int32(cluster_member_ids[iteration]) + candidate_medoid = math_ops.cast(cluster_member_ids[iteration], dtypes.int32) tmp_chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, candidate_medoid) predictions = get_cluster_assignment(pairwise_distances, tmp_chosen_ids) metric_score = compute_clustering_score(labels, predictions, margin_type) @@ -810,10 +820,10 @@ def update_medoid_per_cluster(pairwise_distances, pairwise_distances_subset, [iteration, scores_margin]) candidate_scores = math_ops.add(scores_fac, margin_multiplier * scores_margin) - argmax_index = math_ops.to_int32( - math_ops.argmax(candidate_scores, axis=0)) + argmax_index = math_ops.cast( + math_ops.argmax(candidate_scores, axis=0), dtypes.int32) - best_medoid = math_ops.to_int32(cluster_member_ids[argmax_index]) + best_medoid = math_ops.cast(cluster_member_ids[argmax_index], dtypes.int32) chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, best_medoid) return chosen_ids @@ 
-841,7 +851,8 @@ def update_all_medoids(pairwise_distances, predictions, labels, chosen_ids, def func_body_augmented_pam(iteration, chosen_ids): """Call the update_medoid_per_cluster subroutine.""" mask = math_ops.equal( - math_ops.to_int64(predictions), math_ops.to_int64(iteration)) + math_ops.cast(predictions, dtypes.int64), + math_ops.cast(iteration, dtypes.int64)) this_cluster_ids = array_ops.where(mask) pairwise_distances_subset = array_ops.transpose( diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 7ea6e34cf50ed8e292f11314550d992c3dde34c0..37cb198e22166a14d2a65704f563d1d530f7f322 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -439,7 +439,6 @@ $(MARCH_OPTION) \ -Itensorflow/core/kernels \ -I$(MAKEFILE_DIR)/downloads/cub \ -I$(MAKEFILE_DIR)/downloads/cub/cub_archive/cub/device \ --Ithird_party/toolchains/gpus/cuda \ -I$(JETPACK)/cuda/include \ -I$(JETPACK) \ -I$(JETPACK)/cuDNN/aarch64 \ @@ -630,6 +629,9 @@ BENCHMARK_NAME := $(BINDIR)benchmark CORE_CC_ALL_SRCS := \ $(ABSL_CC_SRCS) \ +tensorflow/c/c_api.cc \ +tensorflow/c/kernels.cc \ +tensorflow/c/tf_status_helper.cc \ $(wildcard tensorflow/core/*.cc) \ $(wildcard tensorflow/core/common_runtime/*.cc) \ $(wildcard tensorflow/core/framework/*.cc) \ diff --git a/tensorflow/contrib/makefile/compile_ios_protobuf.sh b/tensorflow/contrib/makefile/compile_ios_protobuf.sh index 8fa20213633414d134d6c6a50e151cce2ac8a368..d2fbf696f8f3bc0031db132b021b3da3591d5ce6 100755 --- a/tensorflow/contrib/makefile/compile_ios_protobuf.sh +++ b/tensorflow/contrib/makefile/compile_ios_protobuf.sh @@ -24,11 +24,11 @@ fi usage() { echo "Usage: $(basename "$0") [-a]" echo "-a [build_arch] build for specified arch comma separate for multiple archs (eg: x86_64,arm64)" - echo "default arch i386, x86_64, armv7, armv7s, arm64" + echo "default arch x86_64, armv7, armv7s, arm64" exit 1 } -BUILD_TARGET="i386 x86_64 armv7 armv7s arm64" 
+BUILD_TARGET="x86_64 armv7 armv7s arm64" while getopts "a:" opt_name; do case "$opt_name" in a) BUILD_TARGET="${OPTARG}";; @@ -115,39 +115,6 @@ package_pb_library() { build_target() { case "$1" in - i386) make distclean - ./configure \ - --host=i386-apple-${OSX_VERSION} \ - --disable-shared \ - --enable-cross-compile \ - --with-protoc="${PROTOC_PATH}" \ - --prefix=${LIBDIR}/iossim_386 \ - --exec-prefix=${LIBDIR}/iossim_386 \ - "CFLAGS=${CFLAGS} \ - -mios-simulator-version-min=${MIN_SDK_VERSION} \ - -arch i386 \ - -fembed-bitcode \ - -isysroot ${IPHONESIMULATOR_SYSROOT}" \ - "CXX=${CXX}" \ - "CXXFLAGS=${CXXFLAGS} \ - -mios-simulator-version-min=${MIN_SDK_VERSION} \ - -arch i386 \ - -fembed-bitcode \ - -isysroot \ - ${IPHONESIMULATOR_SYSROOT}" \ - LDFLAGS="-arch i386 \ - -fembed-bitcode \ - -mios-simulator-version-min=${MIN_SDK_VERSION} \ - ${LDFLAGS} \ - -L${IPHONESIMULATOR_SYSROOT}/usr/lib/ \ - -L${IPHONESIMULATOR_SYSROOT}/usr/lib/system" \ - "LIBS=${LIBS}" - make -j"${JOB_COUNT}" - make install - - package_pb_library "iossim_386" - ;; - x86_64) make distclean ./configure \ --host=x86_64-apple-${OSX_VERSION} \ diff --git a/tensorflow/contrib/makefile/compile_ios_tensorflow.sh b/tensorflow/contrib/makefile/compile_ios_tensorflow.sh index ae82163e1178216fc22aad37cd07fd1734c2bedb..3822f0d7da78afbf67fa1fa0389730b0196b637b 100755 --- a/tensorflow/contrib/makefile/compile_ios_tensorflow.sh +++ b/tensorflow/contrib/makefile/compile_ios_tensorflow.sh @@ -46,11 +46,11 @@ fi usage() { echo "Usage: $(basename "$0") [-a]" echo "-a [build_arch] build for specified arch comma separate for multiple archs (eg: x86_64,arm64)" - echo "default is [i386, x86_64, armv7, armv7s, arm64]" + echo "default is [x86_64, armv7, armv7s, arm64]" exit 1 } -BUILD_TARGET="i386 x86_64 armv7 armv7s arm64" +BUILD_TARGET="x86_64 armv7 armv7s arm64" while getopts "a:f:h:n:" opt_name; do case "$opt_name" in a) BUILD_TARGET="${OPTARG}";; @@ -126,18 +126,6 @@ case "$1" in fi package_tf_library "arm64" ;; 
- i386) - make -j"${JOB_COUNT}" -f tensorflow/contrib/makefile/Makefile \ - TARGET=IOS IOS_ARCH=I386 LIB_NAME=${LIB_PREFIX}-i386.a \ - OPTFLAGS="${BUILD_OPT}" HOST_NSYNC_LIB="${NSYNC_HOST}" \ - TARGET_NSYNC_LIB="${NSYNC_TARGET}" - if [ $? -ne 0 ] - then - echo "i386 compilation failed." - exit 1 - fi - package_tf_library "i386" - ;; x86_64) make -j"${JOB_COUNT}" -f tensorflow/contrib/makefile/Makefile \ TARGET=IOS IOS_ARCH=X86_64 LIB_NAME=${LIB_PREFIX}-x86_64.a \ diff --git a/tensorflow/contrib/makefile/compile_nsync.sh b/tensorflow/contrib/makefile/compile_nsync.sh index cb4c94d92fc630c1ce4158c618cd82be80de6741..e154b8223c64fd0b42eb4cc6ae7da1188b92ac36 100755 --- a/tensorflow/contrib/makefile/compile_nsync.sh +++ b/tensorflow/contrib/makefile/compile_nsync.sh @@ -22,7 +22,7 @@ set -e prog=compile_nsync.sh android_api_version=21 default_android_arch=armeabi-v7a -default_ios_arch="i386 x86_64 armv7 armv7s arm64" +default_ios_arch="x86_64 armv7 armv7s arm64" usage="usage: $prog [-t linux|ios|android|macos|native] [-a architecture] [-v android_api_version] @@ -130,7 +130,7 @@ for arch in $archs; do ios) arch_flags= case "$arch" in - i386|x86_64) + x86_64) arch_flags="$arch_flags -mios-simulator-version-min=8.0" arch_flags="$arch_flags -isysroot $(xcrun --sdk iphonesimulator --show-sdk-path)" ;; diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 2a5232b476712a96f84be0f4725beb78bc138297..af3c541dc214c30e9e59fdcca995ffc53b028df4 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -142,5 +142,6 @@ replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DAT # TODO(satok): Remove this once protobuf/autogen.sh is fixed. 
replace_by_sed 's#https://googlemock.googlecode.com/files/gmock-1.7.0.zip#http://download.tensorflow.org/deps/gmock-1.7.0.zip#' \ "${DOWNLOADS_DIR}/protobuf/autogen.sh" +cat "third_party/eigen3/gebp_neon.patch" | patch "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h" echo "download_dependencies.sh completed successfully." >&2 diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 9ea94c74330e3e49414a6a84cd5bc0db3778114a..0a0ba36232075460b561bc54a95fc24973017571 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -40,7 +40,6 @@ tensorflow/core/lib/wav/wav_io.cc tensorflow/core/platform/cpu_info.cc tensorflow/core/platform/default/logging.cc tensorflow/core/platform/default/mutex.cc -tensorflow/core/platform/default/protobuf.cc tensorflow/core/platform/default/tracing.cc tensorflow/core/platform/denormal.cc tensorflow/core/platform/env.cc @@ -53,6 +52,7 @@ tensorflow/core/platform/posix/error.cc tensorflow/core/platform/posix/load_library.cc tensorflow/core/platform/posix/port.cc tensorflow/core/platform/posix/posix_file_system.cc +tensorflow/core/platform/protobuf.cc tensorflow/core/platform/protobuf_util.cc tensorflow/core/platform/setround.cc tensorflow/core/platform/tensor_coding.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 87c73ec1ca610cac6d63468887bc350bada5910b..1c1460ce77c99d29785c7e8b8a8e9f770a45b59f 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.cc tensorflow/core/framework/versions.pb.cc tensorflow/core/grappler/costs/op_performance_data.pb.cc tensorflow/core/lib/core/error_codes.pb.cc +tensorflow/core/protobuf/trackable_object_graph.pb.cc 
tensorflow/core/protobuf/cluster.pb.cc tensorflow/core/protobuf/config.pb.cc tensorflow/core/protobuf/eager_service.pb.cc @@ -34,8 +35,11 @@ tensorflow/core/protobuf/meta_graph.pb.cc tensorflow/core/protobuf/named_tensor.pb.cc tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc +tensorflow/core/protobuf/saved_object_graph.pb.cc tensorflow/core/protobuf/saver.pb.cc +tensorflow/core/protobuf/struct.pb.cc tensorflow/core/protobuf/tensorflow_server.pb.cc +tensorflow/core/protobuf/verifier_config.pb.cc tensorflow/core/util/event.pb.cc tensorflow/core/util/memmapped_file_system.pb.cc tensorflow/core/util/saved_tensor_slice.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 4120ea52ec5255b1efce7a6ce6890fc79c1e4831..5def632e8a7b65272a1339bdacd92c1fa23012d2 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.h tensorflow/core/framework/versions.pb.h tensorflow/core/grappler/costs/op_performance_data.pb.h tensorflow/core/lib/core/error_codes.pb.h +tensorflow/core/protobuf/trackable_object_graph.pb.h tensorflow/core/protobuf/cluster.pb.h tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/debug.pb.h @@ -34,9 +35,12 @@ tensorflow/core/protobuf/meta_graph.pb.h tensorflow/core/protobuf/named_tensor.pb.h tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/rewriter_config.pb.h +tensorflow/core/protobuf/saved_object_graph.pb.h tensorflow/core/protobuf/saver.pb.h +tensorflow/core/protobuf/struct.pb.h tensorflow/core/protobuf/tensor_bundle.pb.h tensorflow/core/protobuf/tensorflow_server.pb.h +tensorflow/core/protobuf/verifier_config.pb.h tensorflow/core/util/event.pb.h tensorflow/core/util/memmapped_file_system.pb.h tensorflow/core/util/saved_tensor_slice.pb.h diff --git 
a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 2cd7d6d519a55423a96526b541845392d9ec6bc2..ea5f5913c66e85ed6ea84e7127c8eb85a51d609e 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -1,3 +1,4 @@ +tensorflow/c/kernels/bitcast_op.cc tensorflow/contrib/boosted_trees/ops/model_ops.cc tensorflow/contrib/boosted_trees/ops/prediction_ops.cc tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -43,7 +44,9 @@ tensorflow/core/kernels/conv_grad_input_ops.cc tensorflow/core/kernels/conv_grad_ops.cc tensorflow/core/kernels/conv_ops.cc tensorflow/core/kernels/conv_ops_3d.cc -tensorflow/core/kernels/conv_ops_fused.cc +tensorflow/core/kernels/conv_ops_fused_double.cc +tensorflow/core/kernels/conv_ops_fused_float.cc +tensorflow/core/kernels/conv_ops_fused_half.cc tensorflow/core/kernels/conv_ops_using_gemm.cc tensorflow/core/kernels/crop_and_resize_op.cc tensorflow/core/kernels/ctc_decoder_ops.cc diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index f94d70db9046cec43073ab1406762aea1f28c8e3..13e3b6422d1989b0d499d8d20901d919554c630e 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -29,5 +29,6 @@ tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc tensorflow/core/protobuf/tensor_bundle.pb_text.cc +tensorflow/core/protobuf/verifier_config.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/util/saved_tensor_slice.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 2712e906d719e72dacb60e213205ad68895f905f..deb6a5b94020a02b878bdd68a33b3737a97fcf2b 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ 
-31,6 +31,7 @@ tensorflow/core/framework/versions.proto tensorflow/core/grappler/costs/op_performance_data.proto tensorflow/core/kernels/boosted_trees/boosted_trees.proto tensorflow/core/lib/core/error_codes.proto +tensorflow/core/protobuf/trackable_object_graph.proto tensorflow/core/protobuf/cluster.proto tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/debug.proto @@ -40,9 +41,12 @@ tensorflow/core/protobuf/meta_graph.proto tensorflow/core/protobuf/named_tensor.proto tensorflow/core/protobuf/queue_runner.proto tensorflow/core/protobuf/rewriter_config.proto +tensorflow/core/protobuf/saved_object_graph.proto tensorflow/core/protobuf/saver.proto +tensorflow/core/protobuf/struct.proto tensorflow/core/protobuf/tensor_bundle.proto tensorflow/core/protobuf/tensorflow_server.proto +tensorflow/core/protobuf/verifier_config.proto tensorflow/core/util/event.proto tensorflow/core/util/memmapped_file_system.proto tensorflow/core/util/saved_tensor_slice.proto diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD index 63843b993c16363a80b64622af665aaa64e05830..93701249cc8bf722c8c8558e91e0b700ca1c4a04 100644 --- a/tensorflow/contrib/memory_stats/BUILD +++ b/tensorflow/contrib/memory_stats/BUILD @@ -10,6 +10,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -45,6 +46,28 @@ tf_gen_op_wrapper_py( deps = [":memory_stats_ops_op_lib"], ) +tf_gen_op_wrapper_cc( + name = "memory_stats_ops", + out_ops_file = "memory_stats_ops", +) + +cc_library( + name = "memory_stats_cc", + srcs = ["memory_stats_ops.cc"], + hdrs = 
["memory_stats_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":memory_stats_kernels", + ":memory_stats_ops_op_lib", + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + tf_custom_op_py_library( name = "memory_stats_py", srcs = [ diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc index 974fb537499c5ea4591a0a128f53d2dea67b9e57..7ae1dbeaa2d04d7846e7fada117f3941319cc1c1 100644 --- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc @@ -24,13 +24,15 @@ class MemoryStatsOp : public OpKernel { void Compute(OpKernelContext* context) override { Allocator* allocator = context->device()->GetAllocator(AllocatorAttributes()); - AllocatorStats allocator_stats; - allocator->GetStats(&allocator_stats); + absl::optional allocator_stats = allocator->GetStats(); + if (!allocator_stats) { + allocator_stats = AllocatorStats(); // Deref of empty optional is UB. + } Tensor* output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape({}), &output_tensor)); - output_tensor->scalar()() = ExtractAllocatorStats(allocator_stats); + output_tensor->scalar()() = ExtractAllocatorStats(*allocator_stats); } protected: @@ -71,7 +73,7 @@ class BytesLimitOp : public MemoryStatsOp { private: int64 ExtractAllocatorStats( const AllocatorStats& allocator_stats) const override { - return allocator_stats.bytes_limit ?
*allocator_stats.bytes_limit : -1; } }; @@ -93,7 +95,7 @@ class MaxBytesInUseOp : public MemoryStatsOp { private: int64 ExtractAllocatorStats( const AllocatorStats& allocator_stats) const override { - return allocator_stats.max_bytes_in_use; + return allocator_stats.peak_bytes_in_use; } }; diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 7b432f8bd20989c6d95310bcaca88d44ce3e0d1f..c7316229cf00d0b8db11aa6c37a1e42dd67dcfe1 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -772,7 +772,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if weights is not None: broadcast_weights = weights_broadcast_ops.broadcast_weights( - math_ops.to_float(weights), predictions) + math_ops.cast(weights, dtypes.float32), predictions) weights_tiled = array_ops.tile( array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1]) thresh_tiled.get_shape().assert_is_compatible_with( @@ -786,8 +786,8 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'tp' in includes: true_positives = metrics_impl.metric_variable( [num_thresholds], dtypes.float32, name='true_positives') - is_true_positive = math_ops.to_float( - math_ops.logical_and(label_is_pos, pred_is_pos)) + is_true_positive = math_ops.cast( + math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32) if weights_tiled is not None: is_true_positive *= weights_tiled update_ops['tp'] = state_ops.assign_add(true_positives, @@ -798,8 +798,8 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'fn' in includes: false_negatives = metrics_impl.metric_variable( [num_thresholds], dtypes.float32, name='false_negatives') - is_false_negative = math_ops.to_float( - math_ops.logical_and(label_is_pos, pred_is_neg)) + is_false_negative = math_ops.cast( + math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32) if weights_tiled is not None: 
is_false_negative *= weights_tiled update_ops['fn'] = state_ops.assign_add(false_negatives, @@ -810,8 +810,8 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'tn' in includes: true_negatives = metrics_impl.metric_variable( [num_thresholds], dtypes.float32, name='true_negatives') - is_true_negative = math_ops.to_float( - math_ops.logical_and(label_is_neg, pred_is_neg)) + is_true_negative = math_ops.cast( + math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32) if weights_tiled is not None: is_true_negative *= weights_tiled update_ops['tn'] = state_ops.assign_add(true_negatives, @@ -822,8 +822,8 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if 'fp' in includes: false_positives = metrics_impl.metric_variable( [num_thresholds], dtypes.float32, name='false_positives') - is_false_positive = math_ops.to_float( - math_ops.logical_and(label_is_neg, pred_is_pos)) + is_false_positive = math_ops.cast( + math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32) if weights_tiled is not None: is_false_positive *= weights_tiled update_ops['fp'] = state_ops.assign_add(false_positives, @@ -1356,9 +1356,8 @@ def _compute_placement_auc(labels, predictions, weights, alpha, weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) / (total_0 - 1. + _EPSILON)) var_1 = ( - math_ops.reduce_sum( - weights_1 * math_ops.square(placement_values_1 - auc_1)) / - (total_1 - 1. + _EPSILON)) + math_ops.reduce_sum(weights_1 * math_ops.squared_difference( + placement_values_1, auc_1)) / (total_1 - 1. + _EPSILON)) auc_std_err = math_ops.sqrt( (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON))) @@ -2165,7 +2164,7 @@ def streaming_recall_at_k(predictions, either `metrics_collections` or `updates_collections` are not a list or tuple. 
""" - in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) + in_top_k = math_ops.cast(nn.in_top_k(predictions, labels, k), dtypes.float32) return streaming_mean(in_top_k, weights, metrics_collections, updates_collections, name or _at_k_name('recall', k)) @@ -3206,7 +3205,8 @@ def streaming_covariance(predictions, [], dtypes.float32, name='comoment') if weights is None: - batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn + batch_count = math_ops.cast( + array_ops.size(labels), dtypes.float32) # n_B in eqn weighted_predictions = predictions weighted_labels = labels else: @@ -3766,15 +3766,15 @@ def count(values, count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') if weights is None: - num_values = math_ops.to_float(array_ops.size(values)) + num_values = math_ops.cast(array_ops.size(values), dtypes.float32) else: - values = math_ops.to_float(values) + values = math_ops.cast(values, dtypes.float32) values, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=values, labels=None, weights=weights) weights = weights_broadcast_ops.broadcast_weights( - math_ops.to_float(weights), values) + math_ops.cast(weights, dtypes.float32), values) num_values = math_ops.reduce_sum(weights) with ops.control_dependencies([values]): @@ -3896,10 +3896,11 @@ def cohen_kappa(labels, total = math_ops.reduce_sum(pe_row) pe_sum = math_ops.reduce_sum( math_ops.div_no_nan( - math_ops.to_double(pe_row * pe_col), math_ops.to_double(total))) - po_sum, pe_sum, total = (math_ops.to_double(po_sum), - math_ops.to_double(pe_sum), - math_ops.to_double(total)) + math_ops.cast(pe_row * pe_col, dtypes.float64), + math_ops.cast(total, dtypes.float64))) + po_sum, pe_sum, total = (math_ops.cast(po_sum, dtypes.float64), + math_ops.cast(pe_sum, dtypes.float64), + math_ops.cast(total, dtypes.float64)) # kappa = (po - pe) / (N - pe) k = metrics_impl._safe_scalar_div( # pylint: disable=protected-access po_sum - 
pe_sum, diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index fc64f343ab4add17f04eabeccab922e8be51a692..aec07241e7a23410870de2285ab0470b1650cc7b 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -5810,9 +5810,10 @@ class StreamingCovarianceTest(test.TestCase): def testVars(self): metrics.streaming_covariance( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10])) + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10]))) _assert_metric_variables(self, ( 'covariance/comoment:0', 'covariance/count:0', @@ -5823,18 +5824,20 @@ class StreamingCovarianceTest(test.TestCase): def testMetricsCollection(self): my_collection_name = '__metrics__' cov, _ = metrics.streaming_covariance( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10])), metrics_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [cov]) def testUpdatesCollection(self): my_collection_name = '__updates__' _, update_op = metrics.streaming_covariance( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10])), updates_collections=[my_collection_name]) 
self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) @@ -5857,8 +5860,8 @@ class StreamingCovarianceTest(test.TestCase): def testSingleUpdateIdentical(self): with self.cached_session() as sess: - predictions = math_ops.to_float(math_ops.range(10)) - labels = math_ops.to_float(math_ops.range(10)) + predictions = math_ops.cast(math_ops.range(10), dtypes_lib.float32) + labels = math_ops.cast(math_ops.range(10), dtypes_lib.float32) cov, update_op = metrics.streaming_covariance(predictions, labels) @@ -5982,9 +5985,10 @@ class StreamingPearsonRTest(test.TestCase): def testVars(self): metrics.streaming_pearson_correlation( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10])) + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10]))) _assert_metric_variables(self, ( 'pearson_r/covariance/comoment:0', 'pearson_r/covariance/count:0', @@ -6003,18 +6007,20 @@ class StreamingPearsonRTest(test.TestCase): def testMetricsCollection(self): my_collection_name = '__metrics__' pearson_r, _ = metrics.streaming_pearson_correlation( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10])), metrics_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [pearson_r]) def testUpdatesCollection(self): my_collection_name = '__updates__' _, update_op = metrics.streaming_pearson_correlation( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - 
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10])), updates_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) @@ -6038,8 +6044,8 @@ class StreamingPearsonRTest(test.TestCase): def testSingleUpdateIdentical(self): with self.cached_session() as sess: - predictions = math_ops.to_float(math_ops.range(10)) - labels = math_ops.to_float(math_ops.range(10)) + predictions = math_ops.cast(math_ops.range(10), dtypes_lib.float32) + labels = math_ops.cast(math_ops.range(10), dtypes_lib.float32) pearson_r, update_op = metrics.streaming_pearson_correlation( predictions, labels) diff --git a/tensorflow/contrib/mpi/mpi_server_lib.cc b/tensorflow/contrib/mpi/mpi_server_lib.cc index a31fa9ce0b3110d875689d74a41ca9f9cc85f532..e44e10af0814ba8d6d964dfc34a0470ce45c0b40 100644 --- a/tensorflow/contrib/mpi/mpi_server_lib.cc +++ b/tensorflow/contrib/mpi/mpi_server_lib.cc @@ -54,7 +54,10 @@ MPIServer::~MPIServer() { Status MPIServer::Init(ServiceInitFunction service_func, RendezvousMgrCreationFunction rendezvous_mgr_func) { - Status s = GrpcServer::Init(service_func, rendezvous_mgr_func); + GrpcServerOptions opts; + opts.service_func = service_func; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + Status s = GrpcServer::Init(opts); return s; } diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc index b04abde4694199d827a1738850bded9bf696d56c..ca3ddfa721d45a2de3ea51c80d6adfa2371c3c94 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc @@ -96,13 +96,14 @@ __global__ void elemwise_accum(T* out, const T* in, const size_t N) { // Synchronously accumulate tensors on the GPU, using a different stream than // the default and than TensorFlow to avoid synchronizing on operations // 
unrelated to the allreduce. -#define GENERATE_ACCUMULATE(type) \ - template <> \ - void AccumulateTensorData(type * dst, type * src, \ - size_t size) { \ - auto stream = CudaStreamForMPI(); \ - elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size); \ - cudaStreamSynchronize(stream); \ +#define GENERATE_ACCUMULATE(type) \ + template <> \ + void AccumulateTensorData(type * dst, type * src, \ + size_t size) { \ + auto stream = CudaStreamForMPI(); \ + TF_CHECK_OK(CudaLaunchKernel(elemwise_accum, 32, 256, 0, stream, \ + dst, src, size)); \ + cudaStreamSynchronize(stream); \ }; GENERATE_ACCUMULATE(int); GENERATE_ACCUMULATE(long long); diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/ring.cu.cc index 2f3eef366a9a3c10e59cd5298fc1626e1094dff8..c73156d230820e8f89d88d8d4c8599fd1a5f68d8 100644 --- a/tensorflow/contrib/mpi_collectives/ring.cu.cc +++ b/tensorflow/contrib/mpi_collectives/ring.cu.cc @@ -96,13 +96,14 @@ __global__ void elemwise_accum(T* out, const T* in, const size_t N) { // Synchronously accumulate tensors on the GPU, using a different stream than // the default and than TensorFlow to avoid synchronizing on operations // unrelated to the allreduce. 
-#define GENERATE_ACCUMULATE(type) \ - template <> \ - void AccumulateTensorData(type * dst, type * src, \ - size_t size) { \ - auto stream = CudaStreamForMPI(); \ - elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size); \ - cudaStreamSynchronize(stream); \ +#define GENERATE_ACCUMULATE(type) \ + template <> \ + void AccumulateTensorData(type * dst, type * src, \ + size_t size) { \ + auto stream = CudaStreamForMPI(); \ + TF_CHECK_OK(CudaLaunchKernel(elemwise_accum, 32, 256, 0, stream, \ + dst, src, size)); \ + cudaStreamSynchronize(stream); \ }; GENERATE_ACCUMULATE(int); GENERATE_ACCUMULATE(long long); diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h index cae57ce60eb09509af69f8ccab9eacedea361548..9b5d52e1b648e62af93d5420885e4f22796e3ea1 100644 --- a/tensorflow/contrib/mpi_collectives/ring.h +++ b/tensorflow/contrib/mpi_collectives/ring.h @@ -129,7 +129,7 @@ cudaStream_t CudaStreamForMPI(); * has the fully accumulated Segment 1; and so on. The scatter-reduce is * complete. * - * Next, the allgather distributes these fully accumululated chunks across all + * Next, the allgather distributes these fully accumulated chunks across all * nodes. Communication proceeds in the same ring, once again in N-1 steps. At * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i). 
* For example, at the first iteration, the following transfers will occur: diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 12320d9e456ae93cbf95639a0c9e0c7f414f3518..f30643cf3059754daaeee4093938ac47b26f76ea 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -413,8 +413,9 @@ py_test( py_test( name = "shampoo_test", - size = "large", + size = "medium", srcs = ["python/training/shampoo_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":opt_py", diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py index c5c9fc74deaf0171a33d0eb1b5c6f60b3aa5e533..0b149ed17533adff3bd7cd8fd8ff94d171f72911 100644 --- a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Adam rewrite to use global step for computing beta1 & beta2 accumulation.""" from __future__ import absolute_import from __future__ import division @@ -38,9 +37,14 @@ class AdamGSOptimizer(optimizer.Optimizer): ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). """ - def __init__(self, global_step=0, learning_rate=0.001, - beta1=0.9, beta2=0.999, epsilon=1e-8, - use_locking=False, name="Adam"): + def __init__(self, + global_step=0, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + use_locking=False, + name="Adam"): r"""Construct a new Adam optimizer. Branched from tf.train.AdamOptimizer. The only difference is to pass @@ -112,9 +116,6 @@ class AdamGSOptimizer(optimizer.Optimizer): self._beta2_t = None self._epsilon_t = None - # Created in SparseApply if needed. 
- self._updated_lr = None - def _get_beta_accumulators(self): return (math_ops.pow(self._beta1_t, self._global_step_on_worker), math_ops.pow(self._beta2_t, self._global_step_on_worker)) @@ -146,28 +147,34 @@ class AdamGSOptimizer(optimizer.Optimizer): v = self.get_slot(var, "v") beta1_power, beta2_power = self._get_beta_accumulators() return training_ops.apply_adam( - var, m, v, + var, + m, + v, math_ops.cast(beta1_power, var.dtype.base_dtype), math_ops.cast(beta2_power, var.dtype.base_dtype), math_ops.cast(self._lr_t, var.dtype.base_dtype), math_ops.cast(self._beta1_t, var.dtype.base_dtype), math_ops.cast(self._beta2_t, var.dtype.base_dtype), math_ops.cast(self._epsilon_t, var.dtype.base_dtype), - grad, use_locking=self._use_locking).op + grad, + use_locking=self._use_locking).op def _resource_apply_dense(self, grad, var): m = self.get_slot(var, "m") v = self.get_slot(var, "v") beta1_power, beta2_power = self._get_beta_accumulators() return training_ops.resource_apply_adam( - var.handle, m.handle, v.handle, + var.handle, + m.handle, + v.handle, math_ops.cast(beta1_power, grad.dtype.base_dtype), math_ops.cast(beta2_power, grad.dtype.base_dtype), math_ops.cast(self._lr_t, grad.dtype.base_dtype), math_ops.cast(self._beta1_t, grad.dtype.base_dtype), math_ops.cast(self._beta2_t, grad.dtype.base_dtype), math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), - grad, use_locking=self._use_locking) + grad, + use_locking=self._use_locking) def _apply_sparse_shared(self, grad, var, indices, scatter_add): beta1_power, beta2_power = self._get_beta_accumulators() @@ -181,8 +188,7 @@ class AdamGSOptimizer(optimizer.Optimizer): # m_t = beta1 * m + (1 - beta1) * g_t m = self.get_slot(var, "m") m_scaled_g_values = grad * (1 - beta1_t) - m_t = state_ops.assign(m, m * beta1_t, - use_locking=self._use_locking) + m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) with ops.control_dependencies([m_t]): m_t = scatter_add(m, indices, m_scaled_g_values) # v_t = beta2 * v 
+ (1 - beta2) * (g_t * g_t) @@ -192,23 +198,26 @@ class AdamGSOptimizer(optimizer.Optimizer): with ops.control_dependencies([v_t]): v_t = scatter_add(v, indices, v_scaled_g_values) v_sqrt = math_ops.sqrt(v_t) - var_update = state_ops.assign_sub(var, - lr * m_t / (v_sqrt + epsilon_t), - use_locking=self._use_locking) + var_update = state_ops.assign_sub( + var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) return control_flow_ops.group(*[var_update, m_t, v_t]) def _apply_sparse(self, grad, var): return self._apply_sparse_shared( - grad.values, var, grad.indices, + grad.values, + var, + grad.indices, lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda - x, i, v, use_locking=self._use_locking)) + x, + i, + v, + use_locking=self._use_locking)) def _resource_scatter_add(self, x, i, v): with ops.control_dependencies( - [resource_variable_ops.resource_scatter_add( - x.handle, i, v)]): + [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): return x.value() def _resource_apply_sparse(self, grad, var, indices): - return self._apply_sparse_shared( - grad, var, indices, self._resource_scatter_add) + return self._apply_sparse_shared(grad, var, indices, + self._resource_scatter_add) diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py index 6dc17fe5a5210fa1700e1382016e40fa0a792917..df0cb2b0071b932418492e83aebca9fe70027162 100644 --- a/tensorflow/contrib/opt/python/training/ggt.py +++ b/tensorflow/contrib/opt/python/training/ggt.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import numpy as np from tensorflow.contrib.optimizer_v2 import optimizer_v2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -224,7 +225,7 @@ class GGTOptimizer(optimizer_v2.OptimizerV2): window = state.get_hyper("window") grad_buffer = 
self._get_grad_buffer(state) next_grad_index = math_ops.floormod( - math_ops.to_int32(update_global_step - 1.), window) + math_ops.cast(update_global_step - 1., dtypes.int32), window) # grad_buffer[(t-1) % window] := moment1_t update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index, update_moment1) diff --git a/tensorflow/contrib/opt/python/training/matrix_functions.py b/tensorflow/contrib/opt/python/training/matrix_functions.py index baab577638626fb39bfbd9b60d98b5848d481a1c..1c5d2fe17876cf6dda45194b445c3a12b65b1210 100644 --- a/tensorflow/contrib/opt/python/training/matrix_functions.py +++ b/tensorflow/contrib/opt/python/training/matrix_functions.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -57,7 +58,7 @@ def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4): current_err = math_ops.sqrt(math_ops.reduce_sum(residual * residual)) / norm return i + 1, current_mat_y, mat_y, current_mat_z, mat_z, current_err, err - identity = linalg_ops.eye(math_ops.to_int32(mat_a_size)) + identity = linalg_ops.eye(math_ops.cast(mat_a_size, dtypes.int32)) mat_a = mat_a + ridge_epsilon * identity norm = math_ops.sqrt(math_ops.reduce_sum(mat_a * mat_a)) mat_init_y = mat_a / norm @@ -100,7 +101,7 @@ def matrix_inverse_pth_root(mat_g, mat_g^alpha """ - identity = linalg_ops.eye(math_ops.to_int32(mat_g_size)) + identity = linalg_ops.eye(math_ops.cast(mat_g_size, dtypes.int32)) def mat_power(mat_m, p): """Computes mat_m^p, for p a positive integer. 
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py index e542f46892a3cea60b758a1a95ce2f20d5f29a67..efbafac662b78fe2cde9e50e6778bb787af29e31 100644 --- a/tensorflow/contrib/opt/python/training/shampoo.py +++ b/tensorflow/contrib/opt/python/training/shampoo.py @@ -24,6 +24,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.opt.python.training import matrix_functions +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -120,7 +121,7 @@ class ShampooOptimizer(optimizer.Optimizer): super(ShampooOptimizer, self).__init__(use_locking, name) - self._global_step = math_ops.to_float(global_step) + self._global_step = math_ops.cast(global_step, dtypes.float32) self._max_matrix_size = max_matrix_size self._gbar_decay = gbar_decay self._gbar_weight = gbar_weight @@ -246,7 +247,8 @@ class ShampooOptimizer(optimizer.Optimizer): if mat_g_size == 1: mat_h = math_ops.pow(mat_g + self._epsilon, alpha) else: - damping = self._epsilon * linalg_ops.eye(math_ops.to_int32(mat_g_size)) + damping = self._epsilon * linalg_ops.eye( + math_ops.cast(mat_g_size, dtypes.int32)) diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True) mat_h = math_ops.matmul( mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha), diff --git a/tensorflow/contrib/opt/python/training/sign_decay.py b/tensorflow/contrib/opt/python/training/sign_decay.py index e8870c072110da145c0bb78e20c3584083438ea0..99cd0f6e60e1d2fda14060c571c9aab8c7d32da2 100644 --- a/tensorflow/contrib/opt/python/training/sign_decay.py +++ b/tensorflow/contrib/opt/python/training/sign_decay.py @@ -23,7 +23,9 @@ from __future__ import division from __future__ import print_function import math + from tensorflow.python.framework import constant_op +from tensorflow.python.framework import 
dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -51,10 +53,10 @@ def get_linear_decay_fn(decay_steps): if global_step is None: raise ValueError("global_step is required for linear_decay.") global_step = math_ops.minimum(global_step, decay_steps) - remaining_steps = math_ops.to_int32(decay_steps) - math_ops.to_int32( - global_step) - decayed = math_ops.to_float(remaining_steps) / math_ops.to_float( - decay_steps) + remaining_steps = math_ops.cast( + decay_steps, dtypes.int32) - math_ops.cast(global_step, dtypes.int32) + decayed = (math_ops.cast(remaining_steps, dtypes.float32) / + math_ops.cast(decay_steps, dtypes.float32)) return math_ops.maximum(0.0, decayed) # pylint:enable=missing-docstring return linear_decay_fn @@ -92,8 +94,8 @@ def get_cosine_decay_fn(decay_steps, num_periods=0.5, zero_after=None): if global_step is None: raise ValueError("global_step is required for cosine_decay.") global_step = math_ops.minimum(global_step, decay_steps) - completed_fraction = math_ops.to_float(global_step) / math_ops.to_float( - decay_steps) + completed_fraction = (math_ops.cast(global_step, dtypes.float32) / + math_ops.cast(decay_steps, dtypes.float32)) fraction = 2.0 * num_periods * completed_fraction decayed = 0.5 * ( 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) @@ -143,14 +145,14 @@ def get_restart_decay_fn(decay_steps, num_periods=1, zero_after=None): if global_step is None: raise ValueError("global_step is required for cosine_decay.") global_step = math_ops.minimum(global_step, decay_steps) - num = math_ops.mod(num_periods * math_ops.to_float(global_step), + num = math_ops.mod(num_periods * math_ops.cast(global_step, dtypes.float32), decay_steps) - fraction = num / math_ops.to_float(decay_steps) + fraction = num / math_ops.cast(decay_steps, dtypes.float32) decayed = 0.5 * ( 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) if zero_after is not None: - tmp = math_ops.to_float( - num_periods * 
global_step) / math_ops.to_float(decay_steps) + tmp = (math_ops.cast(num_periods * global_step, dtypes.float32) / + math_ops.cast(decay_steps, dtypes.float32)) decayed = array_ops.where( math_ops.greater_equal(tmp, zero_after), 0.0, decayed) return decayed diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 0243927ce44aec626973744507e75b20a42253e9..b469ebff25fafc5d97a3e457732954d238cbb5af 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -44,14 +44,15 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import graph_view +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util -class NonLayerCheckpointable(tracking.AutoCheckpointable): +class NonLayerTrackable(tracking.AutoTrackable): def __init__(self): - super(NonLayerCheckpointable, self).__init__() + super(NonLayerTrackable, self).__init__() self.a_variable = util.add_variable( self, name="a_variable", shape=[]) @@ -64,8 +65,8 @@ class MyModel(training.Model): super(MyModel, self).__init__() self._named_dense = core.Dense(1, use_bias=True) self._second = core.Dense(1, use_bias=False) - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() + # We can still track Trackables which aren't Layers. 
+ self._non_layer = NonLayerTrackable() def call(self, values): ret = self._second(self._named_dense(values)) @@ -100,7 +101,7 @@ class CheckpointingTests(test.TestCase): other_model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) if context.executing_eagerly(): optimizer.minimize( @@ -116,11 +117,10 @@ class CheckpointingTests(test.TestCase): other_model(input_value), global_step=optimizer_step) self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) - named_variables, serialized_graph, _ = ( - util._serialize_object_graph( - root_checkpointable, saveables_cache=None)) + named_variables, serialized_graph, _ = graph_view.ObjectGraphView( + root_trackable).serialize_object_graph() expected_checkpoint_names = ( # Created in the root node, so no prefix. "optimizer_step", @@ -208,7 +208,7 @@ class CheckpointingTests(test.TestCase): def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model) input_value = constant_op.constant([[3.]]) if context.executing_eagerly(): @@ -217,24 +217,24 @@ class CheckpointingTests(test.TestCase): else: train_op = optimizer.minimize(model(input_value)) # TODO(allenl): Make initialization more pleasant when graph building. 
- root_checkpointable.save_counter # pylint: disable=pointless-statement + root_trackable.save_counter # pylint: disable=pointless-statement self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.])) m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m") self.evaluate(state_ops.assign(m_bias_slot, [1.5])) - save_path = root_checkpointable.save(file_prefix=prefix) + save_path = root_trackable.save(file_prefix=prefix) self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.])) - self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3)) + self.evaluate(state_ops.assign(root_trackable.save_counter, 3)) optimizer_variables = self.evaluate(optimizer.variables()) self.evaluate(state_ops.assign(m_bias_slot, [-2.])) # Immediate restoration - status = root_checkpointable.restore(save_path=save_path).assert_consumed() + status = root_trackable.restore(save_path=save_path).assert_consumed() status.run_restore_ops() self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1])) - self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) + self.assertAllEqual(1, self.evaluate(root_trackable.save_counter)) self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) if not context.executing_eagerly(): return # Restore-on-create is only supported when executing eagerly @@ -302,7 +302,7 @@ class CheckpointingTests(test.TestCase): with ops.Graph().as_default(): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = util.Checkpoint( + root = util.CheckpointV1( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) input_value = constant_op.constant([[3.]]) @@ -440,7 +440,7 @@ class CheckpointingTests(test.TestCase): def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = 
tracking.AutoCheckpointable() + root = util.Checkpoint() root.var = util.add_variable( root, name="var", initializer=0.) optimizer = adam.AdamOptimizer(0.1) @@ -455,21 +455,17 @@ class CheckpointingTests(test.TestCase): util.Checkpoint(root=root, optimizer=optimizer))) self.evaluate(train_op) self.evaluate(state_ops.assign(root.var, 12.)) - no_slots_path = util.CheckpointableSaver(root).save( - os.path.join(checkpoint_directory, "no_slots")) + no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots")) root.optimizer = optimizer self.evaluate(state_ops.assign(root.var, 13.)) self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.)) - slots_path = util.CheckpointableSaver(root).save( - os.path.join(checkpoint_directory, "with_slots")) - new_root = tracking.AutoCheckpointable() + slots_path = root.save(os.path.join(checkpoint_directory, "with_slots")) + new_root = util.Checkpoint() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). - slot_status = util.CheckpointableSaver( - new_root).restore(slots_path) - no_slot_status = util.CheckpointableSaver( - new_root).restore(no_slots_path) + slot_status = new_root.restore(slots_path) + no_slot_status = new_root.restore(no_slots_path) with self.assertRaises(AssertionError): no_slot_status.assert_consumed() new_root.var = util.add_variable( @@ -508,15 +504,14 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.AutoCheckpointable() + obj = util.Checkpoint() obj.var = variable_scope.get_variable(name="v", initializer=0.) 
obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) self.evaluate(util.gather_initializers(obj)) - saver = util.CheckpointableSaver(obj) - saver.save(checkpoint_prefix) + obj.save(checkpoint_prefix) before_ops = graph.get_operations() - saver.save(checkpoint_prefix) + obj.save(checkpoint_prefix) self.assertEqual(before_ops, graph.get_operations()) def testManyRestoresGraph(self): @@ -526,16 +521,15 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.AutoCheckpointable() + obj = util.Checkpoint() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) self.evaluate(util.gather_initializers(obj)) - saver = util.CheckpointableSaver(obj) - save_path = saver.save(checkpoint_prefix) - saver.restore(save_path) + save_path = obj.save(checkpoint_prefix) + obj.restore(save_path) before_ops = graph.get_operations() - saver.restore(save_path) + obj.restore(save_path) self.assertEqual(before_ops, graph.get_operations()) def testMultipleGraphsNonSlotVariables(self): @@ -548,11 +542,11 @@ class CheckpointingTests(test.TestCase): first_session = session_lib.Session(graph=first_graph) with first_graph.as_default(), first_session.as_default(): first_variable = resource_variable_ops.ResourceVariable([1.]) - first_root_checkpointable = util.Checkpoint( + first_root_trackable = util.Checkpoint( optimizer=optimizer, variable=first_variable) train_op = optimizer.minimize(first_variable.read_value) self.evaluate(util.gather_initializers( - first_root_checkpointable)) + first_root_trackable)) self.evaluate(train_op) self.evaluate(first_variable.assign([1.])) self.evaluate(optimizer.get_slot( @@ -564,23 +558,23 @@ class CheckpointingTests(test.TestCase): second_graph = ops.Graph() with second_graph.as_default(), 
session_lib.Session(graph=second_graph): second_variable = resource_variable_ops.ResourceVariable([1.]) - second_root_checkpointable = util.Checkpoint( + second_root_trackable = util.Checkpoint( optimizer=optimizer, variable=second_variable) train_op = optimizer.minimize(second_variable.read_value) - second_root_checkpointable.restore(None).initialize_or_restore() + second_root_trackable.restore(None).initialize_or_restore() self.evaluate(train_op) self.evaluate(second_variable.assign([4.])) self.evaluate(optimizer.get_slot( var=second_variable, name="m").assign([5.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(6.)) - save_path = second_root_checkpointable.save(checkpoint_prefix) + save_path = second_root_trackable.save(checkpoint_prefix) self.evaluate(second_variable.assign([7.])) self.evaluate(optimizer.get_slot( var=second_variable, name="m").assign([8.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.assertAllEqual(6., self.evaluate(beta_1_power)) - status = second_root_checkpointable.restore(save_path) + status = second_root_trackable.restore(save_path) status.assert_consumed().run_restore_ops() self.assertAllEqual([4.], self.evaluate(second_variable)) self.assertAllEqual([5.], self.evaluate(optimizer.get_slot( @@ -600,7 +594,7 @@ class CheckpointingTests(test.TestCase): class TemplateTests(test.TestCase): @test_util.run_in_graph_and_eager_modes - def test_checkpointable_save_restore(self): + def test_trackable_save_restore(self): def _templated(): v = variable_scope.get_variable( @@ -647,13 +641,13 @@ class CheckpointCompatibilityTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) train_op = optimizer.minimize( functools.partial(model, input_value), global_step=optimizer_step) 
self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) # A regular variable, a slot variable, and a non-slot Optimizer variable # with known values to check when loading. @@ -662,24 +656,24 @@ class CheckpointCompatibilityTests(test.TestCase): var=model._named_dense.bias, name="m").assign([2.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(3.)) - return root_checkpointable + return root_trackable - def _set_sentinels(self, root_checkpointable): - self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.])) + def _set_sentinels(self, root_trackable): + self.evaluate(root_trackable.model._named_dense.bias.assign([101.])) self.evaluate( - root_checkpointable.optimizer.get_slot( - var=root_checkpointable.model._named_dense.bias, name="m") + root_trackable.optimizer.get_slot( + var=root_trackable.model._named_dense.bias, name="m") .assign([102.])) - beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(103.)) - def _check_sentinels(self, root_checkpointable): + def _check_sentinels(self, root_trackable): self.assertAllEqual( - [1.], self.evaluate(root_checkpointable.model._named_dense.bias)) + [1.], self.evaluate(root_trackable.model._named_dense.bias)) self.assertAllEqual([2.], self.evaluate( - root_checkpointable.optimizer.get_slot( - var=root_checkpointable.model._named_dense.bias, name="m"))) - beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + root_trackable.optimizer.get_slot( + var=root_trackable.model._named_dense.bias, name="m"))) + beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators() self.assertAllEqual(3., self.evaluate(beta_1_power)) def _write_name_based_checkpoint(self): @@ -704,14 +698,13 @@ class CheckpointCompatibilityTests(test.TestCase): self._set_sentinels(root) with 
self.assertRaises(AssertionError): self._check_sentinels(root) - object_saver = util.CheckpointableSaver(root) + object_saver = util.TrackableSaver(graph_view.ObjectGraphView(root)) self._set_sentinels(root) status = object_saver.restore(save_path) if context.executing_eagerly(): self._check_sentinels(root) if context.executing_eagerly(): - with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"): - status.assert_consumed() + status.assert_consumed() else: # When graph building, we haven't read any keys, so we don't know # whether the restore will be complete. @@ -733,10 +726,9 @@ class CheckpointCompatibilityTests(test.TestCase): with context.graph_mode(): save_graph = ops.Graph() with save_graph.as_default(), self.test_session( - graph=save_graph) as session: + graph=save_graph): root = self._initialized_model() - save_path = root.save( - session=session, file_prefix=checkpoint_prefix) + save_path = root.save(file_prefix=checkpoint_prefix) with context.eager_mode(): root = self._initialized_model() self._set_sentinels(root) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 1323ed014c9e51e273491694fa44a8e36cc723d0..436ece79a79810d4688e259523a4f86a1ca7f5a5 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -24,7 +24,6 @@ import abc import six -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop @@ -39,7 +38,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator -from tensorflow.python.training.checkpointable import base as checkpointable +from 
tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest @@ -224,7 +223,7 @@ class _OptimizerV2State(object): } self._slots = {} self._non_slot_dict = {} - # Extra state to help Optimizers implement Checkpointable. Holds information + # Extra state to help Optimizers implement Trackable. Holds information # about variables which will be restored as soon as they're created. self._deferred_dependencies = {} # Non-slot variables self._deferred_slot_restorations = {} # Slot variables @@ -367,8 +366,8 @@ class _OptimizerV2State(object): slot variable needs to be restored). Args: - slot_variable_position: A `checkpointable._CheckpointPosition` object - indicating the slot variable `Checkpointable` object to be restored. + slot_variable_position: A `trackable._CheckpointPosition` object + indicating the slot variable `Trackable` object to be restored. slot_name: The name of this `Optimizer`'s slot to restore into. variable: The variable object this slot is being created for. optional_op_name: Name to use when scoping the Variable that needs to be @@ -386,7 +385,7 @@ class _OptimizerV2State(object): # (aside from double initialization), and makes variable creator scopes # behave the same way they do when graph building. and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access - initializer = checkpointable.CheckpointInitialValue( + initializer = trackable.CheckpointInitialValue( checkpoint_position=slot_variable_position) slot_variable = self.create_slot( var=variable, @@ -661,7 +660,7 @@ class OptimizerV2(optimizer_v1.Optimizer): name=None, grad_loss=None, stop_gradients=None, - scale_loss_by_num_replicas=None): + scale_loss_by_num_replicas=False): """Add operations to minimize `loss` by updating `var_list`. This method simply combines calls `compute_gradients()` and @@ -685,8 +684,7 @@ class OptimizerV2(optimizer_v1.Optimizer): stop_gradients: Optional. 
A Tensor or list of tensors not to differentiate through. scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down - by the number of replicas. By default, auto-detects whether this is - needed. + by the number of replicas. DEPRECATED and generally no longer needed. Returns: An Operation that updates the variables in `var_list`. If `global_step` @@ -732,7 +730,7 @@ class OptimizerV2(optimizer_v1.Optimizer): aggregation_method=None, grad_loss=None, stop_gradients=None, - scale_loss_by_num_replicas=None): + scale_loss_by_num_replicas=False): """Compute gradients of `loss` for the variables in `var_list`. This is the first part of `minimize()`. It returns a list @@ -756,8 +754,7 @@ class OptimizerV2(optimizer_v1.Optimizer): stop_gradients: Optional. A Tensor or list of tensors not to differentiate through. scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down - by the number of replicas. By default, auto-detects whether this is - needed. + by the number of replicas. DEPRECATED and generally no longer needed. Returns: A list of (gradient, variable) pairs. Variable is always present, but @@ -781,9 +778,7 @@ class OptimizerV2(optimizer_v1.Optimizer): tape.watch(var_list) loss_value = loss() - # Scale loss for number of replicas (callable-loss case). In this case, - # we have to be careful to call distribute_lib.get_loss_reduction() - # *after* loss() is evaluated, so we know what loss reduction it uses. + # Scale loss for number of replicas (callable-loss case). 
loss_value = self._scale_loss(loss_value, scale_loss_by_num_replicas) if var_list is None: @@ -839,9 +834,6 @@ class OptimizerV2(optimizer_v1.Optimizer): @staticmethod def _scale_loss(loss_value, scale_loss_by_num_replicas): """Scale loss for the number of replicas.""" - if scale_loss_by_num_replicas is None: - scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN) if scale_loss_by_num_replicas: num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync if num_replicas > 1: @@ -973,7 +965,7 @@ class OptimizerV2(optimizer_v1.Optimizer): # `update_op`. # TODO(josh11b): Make different state objects for each device to # avoid needing to set the device_policy. - device_policy = context.context().device_policy( + device_policy = context.device_policy( context.DEVICE_PLACEMENT_SILENT) with ops.name_scope("update_" + scope_name), device_policy: return processor.update_op(self, g, state) @@ -989,7 +981,7 @@ class OptimizerV2(optimizer_v1.Optimizer): def finish(): # TODO(josh11b): Make different state objects for each device to # avoid needing to set the device_policy. - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + with context.device_policy(context.DEVICE_PLACEMENT_SILENT): return self._finish(state) update_ops = control_flow_ops.group(update_ops) @@ -1267,10 +1259,10 @@ class OptimizerV2(optimizer_v1.Optimizer): return self._per_graph_state.get(var._graph_key, None) # -------------- - # Overridden methods from Checkpointable. + # Overridden methods from Trackable. # -------------- - def _track_checkpointable(self, *args, **kwargs): + def _track_trackable(self, *args, **kwargs): """Optimizers may not track dependencies. Raises an error.""" raise NotImplementedError( "Optimizers may not have dependencies. File a feature request if this " @@ -1278,7 +1270,7 @@ class OptimizerV2(optimizer_v1.Optimizer): @property def _checkpoint_dependencies(self): - """From Checkpointable. 
Gather graph-specific non-slot variables to save.""" + """From Trackable. Gather graph-specific non-slot variables to save.""" current_graph_non_slot_variables = [] state = self._get_per_graph_state() if state is not None: @@ -1287,14 +1279,14 @@ class OptimizerV2(optimizer_v1.Optimizer): # Avoid comparing variables key=lambda item: item[0]): current_graph_non_slot_variables.append( - checkpointable.CheckpointableReference( + trackable.TrackableReference( name=name, ref=variable_object)) # Note: ignores super(); Optimizers may not have any dependencies outside of # state objects. return current_graph_non_slot_variables def _lookup_dependency(self, name): - """From Checkpointable. Find a non-slot variable in the current graph.""" + """From Trackable. Find a non-slot variable in the current graph.""" state = self._get_per_graph_state() if state is None: return None @@ -1303,10 +1295,10 @@ class OptimizerV2(optimizer_v1.Optimizer): @property def _deferred_dependencies(self): - """Lets Checkpointable know where non-slot variables are created. + """Lets Trackable know where non-slot variables are created. If necessary, creates a new state object for the current default graph. - Checkpointable will then add entries to that state's deferred dependency + Trackable will then add entries to that state's deferred dependency dictionary. The state object will check that dictionary when creating non-slot variables, restoring their value if an entry is found. @@ -1319,14 +1311,14 @@ class OptimizerV2(optimizer_v1.Optimizer): def _create_or_restore_slot_variable(self, slot_variable_position, slot_name, variable): - """Checkpointable: Restore a slot variable's value, possibly creating it. + """Trackable: Restore a slot variable's value, possibly creating it. Called when a variable which has an associated slot variable is created or restored. 
Args: - slot_variable_position: A `checkpointable._CheckpointPosition` object - indicating the slot variable `Checkpointable` object to be restored. + slot_variable_position: A `trackable._CheckpointPosition` object + indicating the slot variable `Trackable` object to be restored. slot_name: The name of this `Optimizer`'s slot to restore into. variable: The variable object this slot is being created for. """ diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py index dd7f2f44055a2e48e8a48d01c1da3a8e7513255d..2fc0b5ea4de2332ff3bf32f9a12a15eee566d5c4 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import gradients_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables @@ -71,7 +71,7 @@ class OptimizerTest(test.TestCase): opt_op = sgd_op.minimize( cost, global_step, [var0, var1], - aggregation_method=gradients_impl.AggregationMethod. + aggregation_method=gradients_util.AggregationMethod. 
EXPERIMENTAL_ACCUMULATE_N) variables.global_variables_initializer().run() diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py index 202c1e9afc0623a5837aa82480f1b406834007ee..ab47b74c65a6ddce3ace4b56d12ecd416cc74e54 100644 --- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py +++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py @@ -25,10 +25,12 @@ from absl.testing import parameterized import numpy as np from tensorflow.contrib.optimizer_v2 import rmsprop +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -448,5 +450,56 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): ]), var1.eval()) +class SlotColocationTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([True, False]) + @test_util.run_in_graph_and_eager_modes + def testRunMinimizeOnGPUForCPUVariables(self, use_resource): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + + with ops.device("/device:CPU:0"): + if use_resource: + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], + dtype=dtypes.float32) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], + dtype=dtypes.float32) + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step") + else: + var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32) + var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32) + global_step = variables.Variable( + array_ops.zeros([], dtypes.int64), name="global_step") + + def loss(): + return 5 * var0 + 3 * var1 + + opt = rmsprop.RMSPropOptimizer( + 
learning_rate=1.0, decay=0.9, momentum=0.5, epsilon=1.0) + + # Fetch params to validate initial values + self.evaluate(variables.global_variables_initializer()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 1 step through optimizer on GPU. + # Slot variables are created the first time optimizer is used on some + # variable. This tests that slot variables will be colocated with the base + # variable. + with ops.device("/device:GPU:0"): + # Note that for eager execution, minimize expects a function instead of a + # Tensor. + opt_op = opt.minimize(loss, global_step, [var0, var1]) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(opt_op) + + # Validate updated params, All variables should have decreased. + self.assertTrue(all(v < 0.0 for v in self.evaluate(var0)), + msg="updated variables: %s" % self.evaluate(var0)) + self.assertTrue(all(v < 2.0 for v in self.evaluate(var1)), + msg="updated variables: %s" % self.evaluate(var1)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py index 17b69c7b35dce130c45ab0aadb28be330b4bfb88..13749837e0cd2ed9dec3748ad5209088c1b3fdd9 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py @@ -84,7 +84,10 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): values = field_dict[field.name] self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) - fd = field.value.DESCRIPTOR.fields_by_name[field.name] + if 'ext_value' in field.name: + fd = test_example_pb2.PrimitiveValue() + else: + fd = field.value.DESCRIPTOR.fields_by_name[field.name] # Values has the same shape as the input plus an extra # dimension for repeats. 
@@ -92,13 +95,16 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): # Nested messages are represented as TF strings, requiring # some special handling. - if field.name == 'message_value': + if field.name == 'message_value' or 'ext_value' in field.name: vs = [] for buf in values.flat: msg = test_example_pb2.PrimitiveValue() msg.ParseFromString(buf) vs.append(msg) - evs = getattr(field.value, field.name) + if 'ext_value' in field.name: + evs = field.value.Extensions[test_example_pb2.ext_value] + else: + evs = getattr(field.value, field.name) if len(vs) != len(evs): self.fail('Field %s decoded %d outputs, expected %d' % (fd.name, len(vs), len(evs))) @@ -223,7 +229,8 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): sanitize=False, force_disordered=True) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + @parameterized.named_parameters( + *test_base.ProtoOpTestBase.named_parameters(extension=False)) def testPacked(self, case): # Now try with the packed serialization. # @@ -235,8 +242,7 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): # Note: float_format='.17g' is necessary to ensure preservation of # doubles and floats in text format. 
text_format.Parse( - text_format.MessageToString( - value, float_format='.17g'), + text_format.MessageToString(value, float_format='.17g'), test_example_pb2.PackedTestValue()).SerializeToString() for value in case.values ] @@ -290,14 +296,13 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): field_names = ['sizes'] field_types = [dtypes.int32] - with self.cached_session() as sess: - ctensor, vtensor = self._decode_module.decode_proto( - batch, - message_type=msg_type, - field_names=field_names, - output_types=field_types, - sanitize=sanitize) - with self.assertRaisesRegexp(errors.DataLossError, - 'Unable to parse binary protobuf' - '|Failed to consume entire buffer'): - _ = sess.run([ctensor] + vtensor) + with self.assertRaisesRegexp( + errors.DataLossError, 'Unable to parse binary protobuf' + '|Failed to consume entire buffer'): + self.evaluate( + self._decode_module.decode_proto( + batch, + message_type=msg_type, + field_names=field_names, + output_types=field_types, + sanitize=sanitize)) diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py index 01b3ccc7fd3918c4ff910281289e31177e5a8097..fac2453527dde46fe9ee065f7112f02a285823ea 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py @@ -15,9 +15,6 @@ # ============================================================================= """Table-driven test for encode_proto op. -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. - It tests that encode_proto is a lossless inverse of decode_proto (for the specified fields). 
""" @@ -33,7 +30,9 @@ from google.protobuf import text_format from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -53,56 +52,86 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): self._decode_module = decode_module self._encode_module = encode_module + def testBadSizesShape(self): + if context.executing_eagerly(): + expected_error = (errors.InvalidArgumentError, + r'Invalid shape for field double_value.') + else: + expected_error = (ValueError, + r'Shape must be at least rank 2 but is rank 0') + with self.assertRaisesRegexp(*expected_error): + self.evaluate( + self._encode_module.encode_proto( + sizes=1, + values=[np.double(1.0)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value'])) + def testBadInputs(self): # Invalid field name - with self.cached_session(): - with self.assertRaisesOpError('Unknown field: non_existent_field'): - self._encode_module.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.TestValue', - field_names=['non_existent_field']).eval() + with self.assertRaisesOpError('Unknown field: non_existent_field'): + self.evaluate( + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['non_existent_field'])) # Incorrect types. 
- with self.cached_session(): - with self.assertRaisesOpError( - 'Incompatible type for field double_value.'): - self._encode_module.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.TestValue', - field_names=['double_value']).eval() + with self.assertRaisesOpError('Incompatible type for field double_value.'): + self.evaluate( + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value'])) # Incorrect shapes of sizes. - with self.cached_session(): + for sizes_value in 1, np.array([[[0, 0]]]): with self.assertRaisesOpError( r'sizes should be batch_size \+ \[len\(field_names\)\]'): - sizes = array_ops.placeholder(dtypes.int32) - values = array_ops.placeholder(dtypes.float64) - self._encode_module.encode_proto( - sizes=sizes, - values=[values], - message_type='tensorflow.contrib.proto.TestValue', - field_names=['double_value']).eval(feed_dict={ - sizes: [[[0, 0]]], - values: [[0.0]] - }) + if context.executing_eagerly(): + self.evaluate( + self._encode_module.encode_proto( + sizes=sizes_value, + values=[np.array([[0.0]])], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value'])) + else: + with self.cached_session(): + sizes = array_ops.placeholder(dtypes.int32) + values = array_ops.placeholder(dtypes.float64) + self._encode_module.encode_proto( + sizes=sizes, + values=[values], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value']).eval(feed_dict={ + sizes: sizes_value, + values: [[0.0]] + }) # Inconsistent shapes of values. 
- with self.cached_session(): - with self.assertRaisesOpError( - 'Values must match up to the last dimension'): - sizes = array_ops.placeholder(dtypes.int32) - values1 = array_ops.placeholder(dtypes.float64) - values2 = array_ops.placeholder(dtypes.int32) - (self._encode_module.encode_proto( - sizes=[[1, 1]], - values=[values1, values2], - message_type='tensorflow.contrib.proto.TestValue', - field_names=['double_value', 'int32_value']).eval(feed_dict={ - values1: [[0.0]], - values2: [[0], [0]] - })) + with self.assertRaisesOpError('Values must match up to the last dimension'): + if context.executing_eagerly(): + self.evaluate( + self._encode_module.encode_proto( + sizes=[[1, 1]], + values=[np.array([[0.0]]), + np.array([[0], [0]])], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value', 'int32_value'])) + else: + with self.cached_session(): + values1 = array_ops.placeholder(dtypes.float64) + values2 = array_ops.placeholder(dtypes.int32) + (self._encode_module.encode_proto( + sizes=[[1, 1]], + values=[values1, values2], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value', 'int32_value']).eval(feed_dict={ + values1: [[0.0]], + values2: [[0], [0]] + })) def _testRoundtrip(self, in_bufs, message_type, fields): @@ -145,7 +174,8 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): # loss of packing in the encoding). 
self.assertEqual(in_buf, out_buf) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + @parameterized.named_parameters( + *test_base.ProtoOpTestBase.named_parameters(extension=False)) def testRoundtrip(self, case): in_bufs = [value.SerializeToString() for value in case.values] @@ -154,7 +184,8 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): return self._testRoundtrip( in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + @parameterized.named_parameters( + *test_base.ProtoOpTestBase.named_parameters(extension=False)) def testRoundtripPacked(self, case): # Now try with the packed serialization. # We test the packed representations by loading the same test cases using diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py index 2950c7dfdc59a11ba7d2c07d8406bd4af26b5bd9..1a636486a1765ad9544b5cb5e52961cc47f92950 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py @@ -38,17 +38,18 @@ class ProtoOpTestBase(test.TestCase): ct.cdll.LoadLibrary(lib) @staticmethod - def named_parameters(): - return ( - ("defaults", ProtoOpTestBase.defaults_test_case()), - ("minmax", ProtoOpTestBase.minmax_test_case()), - ("nested", ProtoOpTestBase.nested_test_case()), - ("optional", ProtoOpTestBase.optional_test_case()), - ("promote", ProtoOpTestBase.promote_test_case()), - ("ragged", ProtoOpTestBase.ragged_test_case()), - ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), - ("simple", ProtoOpTestBase.simple_test_case()), - ) + def named_parameters(extension=True): + parameters = [("defaults", ProtoOpTestBase.defaults_test_case()), + ("minmax", ProtoOpTestBase.minmax_test_case()), + ("nested", ProtoOpTestBase.nested_test_case()), + 
("optional", ProtoOpTestBase.optional_test_case()), + ("promote", ProtoOpTestBase.promote_test_case()), + ("ragged", ProtoOpTestBase.ragged_test_case()), + ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), + ("simple", ProtoOpTestBase.simple_test_case())] + if extension: + parameters.append(("extension", ProtoOpTestBase.extension_test_case())) + return parameters @staticmethod def defaults_test_case(): @@ -399,6 +400,21 @@ class ProtoOpTestBase(test.TestCase): field.value.bool_value.append(True) return test_case + @staticmethod + def extension_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + message_value = value.Extensions[test_example_pb2.ext_value].add() + message_value.double_value = 23.5 + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = test_example_pb2.ext_value.full_name + field.dtype = types_pb2.DT_STRING + message_value = field.value.Extensions[test_example_pb2.ext_value].add() + message_value.double_value = 23.5 + return test_case + @staticmethod def simple_test_case(): test_case = test_example_pb2.TestCase() diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto index 674d881220a1113631def47c5111e3ef401b99f3..b1ce66de4feb9c6666ca9ccf39403b4e12840fcf 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto +++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto @@ -61,6 +61,8 @@ message TestValue { optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; optional sint32 sint32_value_with_default = 33 [default = 12]; optional sint64 sint64_value_with_default = 34 [default = 13]; + + extensions 100 to 199; } // A PackedTestValue looks exactly the same as a TestValue in the text format, @@ -68,7 +70,7 @@ message TestValue { // by loading the same test cases using this definition instead of TestValue. 
// // NOTE: This definition must be kept in sync with TestValue in every way except -// the packed=true declaration. +// the packed=true declaration and the lack of extensions. message PackedTestValue { repeated double double_value = 1 [packed = true]; repeated float float_value = 2 [packed = true]; @@ -132,6 +134,10 @@ message ExtraFields { optional bool bool_value = 1777; } +extend TestValue { + repeated PrimitiveValue ext_value = 100; +} + // The messages below are for yet-to-be created tests. message EnumValue { diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index b35c4fde1a2c704880e023a0c3ac1e0766493514..b67e68ea96a15f94e62050c92405eec4fe4be70f 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -202,8 +202,9 @@ py_test( py_test( name = "quantize_parameterized_test", - size = "large", + size = "medium", srcs = ["python/quantize_parameterized_test.py"], + shard_count = 4, srcs_version = "PY2AND3", # TODO(b/118839526): Re-enable msan test. tags = [ diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index 5b8da92491fb747c5a37dcfe03bcb21b5b903560..b335e1af69b7b2e6020f8e745c43bb1bdc95a62d 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -8,9 +8,9 @@ for both training and inference. There are two aspects to this: For efficient inference, TensorFlow combines batch normalization with the preceding convolutional and fully-connected layers prior to quantization by -[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}. +[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}. 
-The quantization error is modeled using [fake quantization](../api_guides/python/array_ops.md#Fake_quantization) +The quantization error is modeled using [fake quantization](../../api_guides/python/array_ops.md#Fake_quantization) nodes to simulate the effect of quantization in the forward and backward passes. The forward-pass models quantization, while the backward-pass models quantization as a straight-through estimator. Both the forward- and backward-pass simulate the quantization @@ -105,7 +105,7 @@ toco \ --std_value=127.5 --mean_value=127.5 ``` -See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../lite/). +See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../../lite/). ## Quantized accuracy results diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index e0c6da00d86fe4c5f881bcab7b444182da092b8f..a70f748fad60c6467946225ad5035caaf89c2aaf 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -454,7 +454,7 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor, strides=layer_op.get_attr('strides'), padding=layer_op.get_attr('padding'), use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), - data_format=layer_op.get_attr('data_format'), + data_format=layer_op.get_attr('data_format').decode(), name=new_layer_name) elif layer_op.type == 'MatMul': return math_ops.matmul( @@ -867,7 +867,7 @@ class _OpCloner(object): strides=op.get_attr('strides'), padding=op.get_attr('padding'), use_cudnn_on_gpu=op.get_attr('use_cudnn_on_gpu'), - data_format=op.get_attr('data_format'), + data_format=op.get_attr('data_format').decode(), name=new_name).op def _CloneDepthwiseConv2d(self, op, inputs, new_name): diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 
8619708cdaecd78bcc7de0e8e0cbf2baa11bf6a2..39082cacf9770619cf5fb529ac9a0aad6e955c6d 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -224,8 +224,8 @@ def MovingAvgQuantize(inputs, None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope: scope.set_partitioner(None) input_shape = inputs.get_shape() - input_dim = len(input_shape) if per_channel: + input_dim = len(input_shape) # Only support quantizing 1-, 2- and 4-dimensional tensors. assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' ' scope: %s' % (input_shape, name_prefix)) diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py index 36d2af94e059cdc75b758bbf607d26c4e1ee73e9..c636c90d23a0f5a6de9d14085c824283cb41f6ca 100644 --- a/tensorflow/contrib/quantize/python/quant_ops_test.py +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -63,6 +63,12 @@ class QuantOpsTest(googletest.TestCase): self.assertAlmostEqual(min_value, -0.5, delta=1e-3) self.assertAlmostEqual(max_value, 0.5, delta=1e-3) + def testMovingAvgQuantizeTrainingAssignNoShape(self): + min_value, max_value = self._GetMinMaxValues( + quant_ops.MovingAvgQuantize, [[-1, 1], [0, 0]], shape=None) + self.assertAlmostEqual(min_value, -0.5, delta=1e-3) + self.assertAlmostEqual(max_value, 0.5, delta=1e-3) + def testMovingAvgSymmetricQuantizeTrainingAssign(self): min_value, max_value = self._GetMinMaxValues( quant_ops.MovingAvgQuantize, [[-1, 0.5], [0, 0]], symmetric=True) @@ -109,10 +115,10 @@ class QuantOpsTest(googletest.TestCase): is_training=True, vars_collection=_MIN_MAX_VARS) - def _GetMinMaxValues(self, quantize_fn, input_values, **kwds): + def _GetMinMaxValues(self, quantize_fn, input_values, shape=(2), **kwds): g = ops.Graph() with session.Session(graph=g) as sess: - x = array_ops.placeholder(dtypes.float32, shape=[2]) + x = array_ops.placeholder(dtypes.float32, shape=shape) y 
= quantize_fn( x, init_min=0.0, diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py index f51de755d81b74e39a26551a282a1f2a47557ebc..b8540258e63016bce6f5d7c0cd02d9ceef1255ed 100644 --- a/tensorflow/contrib/recurrent/python/ops/recurrent.py +++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py @@ -100,7 +100,7 @@ def _Update(struct_acc, struct_x, t): to_skip_update = set() acc_lst = nest.flatten(struct_acc) x_lst = nest.flatten(struct_x) - t = math_ops.to_int32([t]) # tf.to_int32 casts on-device tensors. + t = math_ops.cast([t], dtypes.int32) # tf.to_int32 casts on-device tensors. lst = [] for acc, x in zip(acc_lst, x_lst): if acc in to_skip_update: @@ -429,7 +429,8 @@ class _Recurrent(object): acc_extras = _EmptyAcc(slen_dim, extras) t = slen_dim - max_input_length if self._aligned_end else 0 - dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t) + dev_t = math_ops.cast(t, dtypes.int32) if use_tpu else math_ops.cast( + t, dtypes.int64) run = functional_ops.For( start=t, limit=slen_dim if self._aligned_end else max_input_length, @@ -568,7 +569,8 @@ class _Recurrent(object): # Loop backwards. Note the loop's limit is open-ended, so goes through # t=0. 
t = slen_dim - 1 if self._aligned_end else max_input_length - 1 - dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t) + dev_t = math_ops.cast(t, dtypes.int32) if use_tpu else math_ops.cast( + t, dtypes.int64) limit = slen_dim - max_input_length - 1 if self._aligned_end else -1 run = functional_ops.For( start=t, diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc index 204b83f7f5f118f418815edb6c482b1c06673845..13fbd974e9ce6a680a31507f7f49df17d121535f 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc @@ -77,10 +77,10 @@ namespace functor { sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop, \ 0, 0); \ \ - ReduceSliceDeviceKernel##reduceop \ - <<>>( \ - config, indices_width, bound, beginning(), indices.data(), \ - data.data(), output.data()); \ + TF_CHECK_OK(CudaLaunchKernel( \ + ReduceSliceDeviceKernel##reduceop, config.block_count, \ + config.thread_per_block, 0, d.stream(), config, indices_width, \ + bound, beginning(), indices.data(), data.data(), output.data())); \ } \ }; diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py index 2054367f0d1461c8868e3332d82322a8a3dd38af..7e79785d2867de586f0730373d4864602ef770ae 100644 --- a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py @@ -50,13 +50,13 @@ def remote_fused_graph_execute(inputs, if default_graph_input_tensor_type_shapes: for type_shape in default_graph_input_tensor_type_shapes: type_shape_proto = info_proto.default_graph_input_tensor_shape.add() - type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0])) + type_shape_proto.dtype = 
dtypes.as_dtype(type_shape[0]).as_datatype_enum for dim in type_shape[1]: type_shape_proto.shape.dim.add().size = dim if default_graph_output_tensor_type_shapes: for type_shape in default_graph_output_tensor_type_shapes: type_shape_proto = info_proto.default_graph_output_tensor_shape.add() - type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0])) + type_shape_proto.dtype = dtypes.as_dtype(type_shape[0]).as_datatype_enum for dim in type_shape[1]: type_shape_proto.shape.dim.add().size = dim diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc index 3c07051f685c74b6e45fb782c80871f38dffbbf4..3b2ee098b3e24287298273a04f80e41f6d9dcd86 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc +++ b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc @@ -119,10 +119,10 @@ struct Resampler2DFunctor { batch_size * num_sampling_points * data_channels; ::tensorflow::CudaLaunchConfig config = ::tensorflow::GetCudaLaunchConfig(output_data_size, d); - Resampler2DKernel - <<>>( - data, warp, output, batch_size, data_height, data_width, - data_channels, num_sampling_points); + TF_CHECK_OK(CudaLaunchKernel( + Resampler2DKernel, config.block_count, config.thread_per_block, 0, + d.stream(), data, warp, output, batch_size, data_height, data_width, + data_channels, num_sampling_points)); } }; @@ -254,22 +254,23 @@ struct ResamplerGrad2DFunctor { ::tensorflow::CudaLaunchConfig config = ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d); - ::tensorflow:: - SetZero<<>>( - grad_warp_size, grad_warp); + TF_CHECK_OK(::tensorflow::CudaLaunchKernel( + SetZero, config.block_count, config.thread_per_block, 0, d.stream(), + grad_warp_size, grad_warp)); config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d); - ::tensorflow:: - SetZero<<>>( - grad_data_size, grad_data); + TF_CHECK_OK(::tensorflow::CudaLaunchKernel( + SetZero, config.block_count, config.thread_per_block, 0, 
d.stream(), + grad_data_size, grad_data)); const int resampler_output_size = batch_size * num_sampling_points * data_channels; config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d); - ResamplerGrad2DKernel - <<>>( - data, warp, grad_output, grad_data, grad_warp, batch_size, - data_height, data_width, data_channels, num_sampling_points); + TF_CHECK_OK(CudaLaunchKernel(ResamplerGrad2DKernel, config.block_count, + config.thread_per_block, 0, d.stream(), data, + warp, grad_output, grad_data, grad_warp, + batch_size, data_height, data_width, + data_channels, num_sampling_points)); } }; diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 39b688596875ab1b208d97a5d6f9a5ee811674cb..24fa740d24502a28cb42c994715d09180ee99899 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -102,26 +102,6 @@ cuda_py_tests( xla_enabled = True, ) -cuda_py_tests( - name = "core_rnn_cell_test", - size = "medium", - srcs = ["python/kernel_tests/core_rnn_cell_test.py"], - additional_deps = [ - ":rnn_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:rnn", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "@absl_py//absl/testing:parameterized", - ], -) - cuda_py_tests( name = "rnn_test", size = "medium", @@ -144,32 +124,6 @@ cuda_py_tests( ], ) -cuda_py_tests( - name = "core_rnn_test", - size = "medium", - srcs = ["python/kernel_tests/core_rnn_test.py"], - additional_deps = [ - ":rnn_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - 
"//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:rnn", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/eager:context", - ], - shard_count = 10, -) - tf_py_test( name = "fused_rnn_cell_test", size = "medium", @@ -388,6 +342,13 @@ py_binary( name = "checkpoint_convert", srcs = ["python/tools/checkpoint_convert.py"], srcs_version = "PY2AND3", + deps = [":checkpoint_convert_lib"], +) + +py_library( + name = "checkpoint_convert_lib", + srcs = ["python/tools/checkpoint_convert.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_ops", @@ -406,7 +367,7 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":checkpoint_convert", + ":checkpoint_convert_lib", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python:session", diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index 15ae95f13cffa5d1469d737b23f2a83b9e5a694f..81beb2942c183e6a831b64e946fea89c050b88db 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -242,8 +242,9 @@ void LSTMBlockCellFpropWithCUDA( const int block_dim = 128; const int grid_dim = Eigen::divup(batch_size * (cell_size + input_size), block_dim); - concat_xh<<>>( - xh.data(), x.data(), h_prev.data(), batch_size, cell_size, input_size); + TF_CHECK_OK(CudaLaunchKernel(concat_xh, grid_dim, block_dim, 0, cu_stream, + xh.data(), x.data(), h_prev.data(), batch_size, + cell_size, input_size)); // states1 = xh * w typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); @@ -261,15 +262,17 @@ void LSTMBlockCellFpropWithCUDA( Eigen::divup(cell_size, static_cast(block_dim_2d.y))); if (use_peephole) { - lstm_gates<<>>( + 
TF_CHECK_OK(CudaLaunchKernel( + lstm_gates, grid_dim_2d, block_dim_2d, 0, cu_stream, icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), - i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size); + i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size)); } else { - lstm_gates<<>>( + TF_CHECK_OK(CudaLaunchKernel( + lstm_gates, grid_dim_2d, block_dim_2d, 0, cu_stream, icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), - i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size); + i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size)); } } @@ -374,12 +377,13 @@ void LSTMBlockCellBpropWithCUDA( dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast(block_dim_2d.x)), Eigen::divup(cell_size, static_cast(block_dim_2d.y))); - lstm_gates_bprop<<>>( + TF_CHECK_OK(CudaLaunchKernel( + lstm_gates_bprop, grid_dim_2d, block_dim_2d, 0, cu_stream, cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(), wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(), co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(), dci.data(), df.data(), di.data(), dicfo.data(), cs_prev_grad.data(), - batch_size, cell_size, use_peephole); + batch_size, cell_size, use_peephole)); if (use_peephole) { Eigen::array p_shape({1, cell_size}); diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py deleted file mode 100644 index 7bad4a60a149011d5b8d745f45359fd25473e54e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ /dev/null @@ -1,1210 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for RNN cells.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib import rnn as contrib_rnn -from tensorflow.contrib.rnn.python.ops import core_rnn_cell -from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util -from tensorflow.python.keras import layers as keras_layers -from tensorflow.python.layers import base as base_layer -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import rnn -from tensorflow.python.ops import rnn_cell_impl -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import util as checkpointable_utils - -# pylint: enable=protected-access -Linear = core_rnn_cell._Linear # pylint: 
disable=invalid-name - - -class RNNCellTest(test.TestCase): - - def testLinear(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(1.0)): - x = array_ops.zeros([1, 2]) - l = Linear([x], 2, False)([x]) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([l], {x.name: np.array([[1., 2.]])}) - self.assertAllClose(res[0], [[3.0, 3.0]]) - - # Checks prevent you from accidentally creating a shared function. - with self.assertRaises(ValueError): - l1 = Linear([x], 2, False)([x]) - - # But you can create a new one in a new scope and share the variables. - with variable_scope.variable_scope("l1") as new_scope: - l1 = Linear([x], 2, False)([x]) - with variable_scope.variable_scope(new_scope, reuse=True): - Linear([l1], 2, False)([l1]) - self.assertEqual(len(variables_lib.trainable_variables()), 2) - - def testBasicRNNCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - cell = rnn_cell_impl.BasicRNNCell(2) - g, _ = cell(x, m) - self.assertEqual([ - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME - ], [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testBasicRNNCellNotTrainable(self): - with self.cached_session() as sess: - - def not_trainable_getter(getter, *args, **kwargs): - kwargs["trainable"] = False - return getter(*args, **kwargs) - - with variable_scope.variable_scope( - "root", - initializer=init_ops.constant_initializer(0.5), - custom_getter=not_trainable_getter): - x = 
array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - cell = rnn_cell_impl.BasicRNNCell(2) - g, _ = cell(x, m) - self.assertFalse(cell.trainable_variables) - self.assertEqual([ - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME - ], [v.name for v in cell.non_trainable_variables]) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testIndRNNCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - cell = contrib_rnn_cell.IndRNNCell(2) - g, _ = cell(x, m) - self.assertEqual([ - "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME - ], [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testGRUCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - g, _ = rnn_cell_impl.GRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.175991, 0.175991]]) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test GRUCell with input_size != num_units. 
- x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 2]) - g, _ = rnn_cell_impl.GRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.156736, 0.156736]]) - - def testIndyGRUCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.185265, 0.17704]]) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test IndyGRUCell with input_size != num_units. - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.155127, 0.157328]]) - - def testSRUCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.SRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.509682, 0.509682]]) - - def testSRUCellKerasRNN(self): - """Tests that SRUCell works with keras RNN layer.""" - cell = contrib_rnn_cell.SRUCell(10) - seq_input = ops.convert_to_tensor( - np.random.rand(2, 3, 5), name="seq_input", 
dtype=dtypes.float32) - rnn_layer = keras_layers.RNN(cell=cell) - rnn_outputs_keras = rnn_layer(seq_input) - with self.cached_session() as sess: - sess.run([variables_lib.global_variables_initializer()]) - self.assertEqual(sess.run(rnn_outputs_keras).shape, (2, 10)) - - def testSRUCellBiasType(self): - """Tests that the bias' dtype is properly set.""" - cell = contrib_rnn_cell.SRUCell(10) - cell.build((2, 3, 5)) - self.assertEqual(cell._bias.dtype, dtypes.float32_ref) - - cell = contrib_rnn_cell.SRUCell(10, dtype=dtypes.int32) - cell.build((2, 3, 5)) - self.assertEqual(cell._bias.dtype, dtypes.int32_ref) - - cell_input = ops.convert_to_tensor( - np.random.rand(2, 5), name="cell_input", dtype=dtypes.float16) - cell_state = ops.convert_to_tensor( - np.random.rand(2, 10), name="cell_state", dtype=dtypes.float16) - cell = contrib_rnn_cell.SRUCell(10) - cell(cell_input, [cell_state]) - self.assertEqual(cell._bias.dtype, dtypes.float16_ref) - - def testSRUCellWithDiffSize(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.SRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.55255556, 0.55255556]]) - - def testBasicLSTMCell(self): - for dtype in [dtypes.float16, dtypes.float32]: - np_dtype = dtype.as_numpy_dtype - with self.session(graph=ops.Graph()) as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2], dtype=dtype) - m = array_ops.zeros([1, 8], dtype=dtype) - cell = rnn_cell_impl.MultiRNNCell( - [ - rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) - for _ in range(2) - ], - state_is_tuple=False) - self.assertEqual(cell.dtype, None) - 
self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) - self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) - cell.get_config() # Should not throw an error - g, out_m = cell(x, m) - # Layer infers the input type. - self.assertEqual(cell.dtype, dtype.name) - expected_variable_names = [ - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME - ] - self.assertEqual(expected_variable_names, - [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, out_m], { - x.name: np.array([[1., 1.]]), - m.name: 0.1 * np.ones([1, 8]) - }) - self.assertEqual(len(res), 2) - variables = variables_lib.global_variables() - self.assertEqual(expected_variable_names, [v.name for v in variables]) - # The numbers in results were not calculated, this is just a - # smoke test. - self.assertAllClose(res[0], np.array( - [[0.240, 0.240]], dtype=np_dtype), 1e-2) - expected_mem = np.array( - [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]], - dtype=np_dtype) - self.assertAllClose(res[1], expected_mem, 1e-2) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test BasicLSTMCell with input_size != num_units. 
- x = array_ops.zeros([1, 3], dtype=dtype) - m = array_ops.zeros([1, 4], dtype=dtype) - g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], { - x.name: np.array([[1., 1., 1.]], dtype=np_dtype), - m.name: 0.1 * np.ones([1, 4], dtype=np_dtype) - }) - self.assertEqual(len(res), 2) - - def testBasicLSTMCellDimension0Error(self): - """Tests that dimension 0 in both(x and m) shape must be equal.""" - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - num_units = 2 - state_size = num_units * 2 - batch_size = 3 - input_size = 4 - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size - 1, state_size]) - with self.assertRaises(ValueError): - g, out_m = rnn_cell_impl.BasicLSTMCell( - num_units, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - sess.run( - [g, out_m], { - x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size - 1, state_size]) - }) - - def testBasicLSTMCellStateSizeError(self): - """Tests that state_size must be num_units * 2.""" - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - num_units = 2 - state_size = num_units * 3 # state_size must be num_units * 2 - batch_size = 3 - input_size = 4 - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size, state_size]) - with self.assertRaises(ValueError): - g, out_m = rnn_cell_impl.BasicLSTMCell( - num_units, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - sess.run( - [g, out_m], { - x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size, state_size]) - }) - - def testBasicLSTMCellStateTupleType(self): - with self.cached_session(): - with 
variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m0 = (array_ops.zeros([1, 2]),) * 2 - m1 = (array_ops.zeros([1, 2]),) * 2 - cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)], - state_is_tuple=True) - self.assertTrue(isinstance(cell.state_size, tuple)) - self.assertTrue( - isinstance(cell.state_size[0], rnn_cell_impl.LSTMStateTuple)) - self.assertTrue( - isinstance(cell.state_size[1], rnn_cell_impl.LSTMStateTuple)) - - # Pass in regular tuples - _, (out_m0, out_m1) = cell(x, (m0, m1)) - self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) - self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) - - # Pass in LSTMStateTuples - variable_scope.get_variable_scope().reuse_variables() - zero_state = cell.zero_state(1, dtypes.float32) - self.assertTrue(isinstance(zero_state, tuple)) - self.assertTrue(isinstance(zero_state[0], rnn_cell_impl.LSTMStateTuple)) - self.assertTrue(isinstance(zero_state[1], rnn_cell_impl.LSTMStateTuple)) - _, (out_m0, out_m1) = cell(x, zero_state) - self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) - self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) - - def testBasicLSTMCellWithStateTuple(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m0 = array_ops.zeros([1, 4]) - m1 = array_ops.zeros([1, 4]) - cell = rnn_cell_impl.MultiRNNCell( - [ - rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) - for _ in range(2) - ], - state_is_tuple=True) - g, (out_m0, out_m1) = cell(x, (m0, m1)) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m0, out_m1], { - x.name: np.array([[1., 1.]]), - m0.name: 0.1 * np.ones([1, 4]), - m1.name: 0.1 * np.ones([1, 4]) - }) - self.assertEqual(len(res), 3) - # The numbers in results were not calculated, 
this is just a smoke test. - # Note, however, these values should match the original - # version having state_is_tuple=False. - self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) - expected_mem0 = np.array( - [[0.68967271, 0.68967271, 0.44848421, 0.44848421]]) - expected_mem1 = np.array( - [[0.39897051, 0.39897051, 0.24024698, 0.24024698]]) - self.assertAllClose(res[1], expected_mem0) - self.assertAllClose(res[2], expected_mem1) - - def testIndyLSTMCell(self): - for dtype in [dtypes.float16, dtypes.float32]: - np_dtype = dtype.as_numpy_dtype - with self.session(graph=ops.Graph()) as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2], dtype=dtype) - state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 - state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 - cell = rnn_cell_impl.MultiRNNCell( - [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)]) - self.assertEqual(cell.dtype, None) - self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) - self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) - cell.get_config() # Should not throw an error - g, (out_state_0, out_state_1) = cell(x, (state_0, state_1)) - # Layer infers the input type. 
- self.assertEqual(cell.dtype, dtype.name) - expected_variable_names = [ - "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME - ] - self.assertEqual(expected_variable_names, - [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_state_0, out_state_1], { - x.name: np.array([[1., 1.]]), - state_0[0].name: 0.1 * np.ones([1, 2]), - state_0[1].name: 0.1 * np.ones([1, 2]), - state_1[0].name: 0.1 * np.ones([1, 2]), - state_1[1].name: 0.1 * np.ones([1, 2]), - }) - self.assertEqual(len(res), 3) - variables = variables_lib.global_variables() - self.assertEqual(expected_variable_names, [v.name for v in variables]) - # Only check the range of outputs as this is just a smoke test. - self.assertAllInRange(res[0], -1.0, 1.0) - self.assertAllInRange(res[1], -1.0, 1.0) - self.assertAllInRange(res[2], -1.0, 1.0) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test IndyLSTMCell with input_size != num_units. 
- x = array_ops.zeros([1, 3], dtype=dtype) - state = (array_ops.zeros([1, 2], dtype=dtype),) * 2 - g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_state], { - x.name: np.array([[1., 1., 1.]], dtype=np_dtype), - state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype), - state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype), - }) - self.assertEqual(len(res), 2) - - def testLSTMCell(self): - with self.cached_session() as sess: - num_units = 8 - num_proj = 6 - state_size = num_units + num_proj - batch_size = 3 - input_size = 2 - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size, state_size]) - cell = rnn_cell_impl.LSTMCell( - num_units=num_units, - num_proj=num_proj, - forget_bias=1.0, - state_is_tuple=False) - output, state = cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [output, state], { - x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]), - m.name: 0.1 * np.ones((batch_size, state_size)) - }) - self.assertEqual(len(res), 2) - # The numbers in results were not calculated, this is mostly just a - # smoke test. 
- self.assertEqual(res[0].shape, (batch_size, num_proj)) - self.assertEqual(res[1].shape, (batch_size, state_size)) - # Different inputs so different outputs and states - for i in range(1, batch_size): - self.assertTrue( - float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6) - self.assertTrue( - float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) - - def testLSTMCellVariables(self): - with self.cached_session(): - num_units = 8 - num_proj = 6 - state_size = num_units + num_proj - batch_size = 3 - input_size = 2 - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size, state_size]) - cell = rnn_cell_impl.LSTMCell( - num_units=num_units, - num_proj=num_proj, - forget_bias=1.0, - state_is_tuple=False) - cell(x, m) # Execute to create variables - variables = variables_lib.global_variables() - self.assertEquals(variables[0].op.name, "root/lstm_cell/kernel") - self.assertEquals(variables[1].op.name, "root/lstm_cell/bias") - self.assertEquals(variables[2].op.name, - "root/lstm_cell/projection/kernel") - - def testLSTMCellLayerNorm(self): - with self.cached_session() as sess: - num_units = 2 - num_proj = 3 - batch_size = 1 - input_size = 4 - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([batch_size, input_size]) - c = array_ops.zeros([batch_size, num_units]) - h = array_ops.zeros([batch_size, num_proj]) - state = rnn_cell_impl.LSTMStateTuple(c, h) - cell = contrib_rnn_cell.LayerNormLSTMCell( - num_units=num_units, - num_proj=num_proj, - forget_bias=1.0, - layer_norm=True, - norm_gain=1.0, - norm_shift=0.0) - g, out_m = cell(x, state) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], { - x.name: np.ones((batch_size, input_size)), - c.name: 0.1 * np.ones((batch_size, num_units)), - h.name: 0.1 * np.ones((batch_size, num_proj)) - 
}) - self.assertEqual(len(res), 2) - # The numbers in results were not calculated, this is mostly just a - # smoke test. - self.assertEqual(res[0].shape, (batch_size, num_proj)) - self.assertEqual(res[1][0].shape, (batch_size, num_units)) - self.assertEqual(res[1][1].shape, (batch_size, num_proj)) - # Different inputs so different outputs and states - for i in range(1, batch_size): - self.assertTrue( - float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6) - self.assertTrue( - float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testWrapperCheckpointing(self): - for wrapper_type in [ - rnn_cell_impl.DropoutWrapper, - rnn_cell_impl.ResidualWrapper, - lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: - cell = rnn_cell_impl.BasicRNNCell(1) - wrapper = wrapper_type(cell) - wrapper(array_ops.ones([1, 1]), - state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) - self.evaluate([v.initializer for v in cell.variables]) - checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(cell._bias.assign([40.])) - save_path = checkpoint.save(prefix) - self.evaluate(cell._bias.assign([0.])) - checkpoint.restore(save_path).assert_consumed().run_restore_ops() - self.assertAllEqual([40.], self.evaluate(cell._bias)) - - def testOutputProjectionWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 3]) - cell = contrib_rnn.OutputProjectionWrapper(rnn_cell_impl.GRUCell(3), 2) - g, new_m = cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, new_m], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1]]) - }) - self.assertEqual(res[1].shape, (1, 3)) - # The numbers in results were not calculated, this is just a smoke test. 
- self.assertAllClose(res[0], [[0.231907, 0.231907]]) - - def testInputProjectionWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 3]) - cell = contrib_rnn.InputProjectionWrapper( - rnn_cell_impl.GRUCell(3), num_proj=3) - g, new_m = cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, new_m], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1]]) - }) - self.assertEqual(res[1].shape, (1, 3)) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) - - def testResidualWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 3]) - base_cell = rnn_cell_impl.GRUCell(3) - g, m_new = base_cell(x, m) - variable_scope.get_variable_scope().reuse_variables() - wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell) - (name, dep), = wrapper_object._checkpoint_dependencies - wrapper_object.get_config() # Should not throw an error - self.assertIs(dep, base_cell) - self.assertEqual("cell", name) - - g_res, m_new_res = wrapper_object(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, g_res, m_new, m_new_res], { - x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.1, 0.1]]) - }) - # Residual connections - self.assertAllClose(res[1], res[0] + [1., 1., 1.]) - # States are left untouched - self.assertAllClose(res[2], res[3]) - - def testResidualWrapperWithSlice(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 5]) - m = array_ops.zeros([1, 3]) - base_cell = rnn_cell_impl.GRUCell(3) - g, m_new = 
base_cell(x, m) - variable_scope.get_variable_scope().reuse_variables() - - def residual_with_slice_fn(inp, out): - inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) - return inp_sliced + out - - g_res, m_new_res = rnn_cell_impl.ResidualWrapper( - base_cell, residual_with_slice_fn)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res_g, res_g_res, res_m_new, res_m_new_res = sess.run( - [g, g_res, m_new, m_new_res], { - x: np.array([[1., 1., 1., 1., 1.]]), - m: np.array([[0.1, 0.1, 0.1]]) - }) - # Residual connections - self.assertAllClose(res_g_res, res_g + [1., 1., 1.]) - # States are left untouched - self.assertAllClose(res_m_new, res_m_new_res) - - def testDeviceWrapper(self): - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 3]) - wrapped = rnn_cell_impl.GRUCell(3) - cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159") - (name, dep), = cell._checkpoint_dependencies - cell.get_config() # Should not throw an error - self.assertIs(dep, wrapped) - self.assertEqual("cell", name) - - outputs, _ = cell(x, m) - self.assertTrue("cpu:14159" in outputs.device.lower()) - - def _retrieve_cpu_gpu_stats(self, run_metadata): - cpu_stats = None - gpu_stats = None - step_stats = run_metadata.step_stats - for ds in step_stats.dev_stats: - if "cpu:0" in ds.device[-5:].lower(): - cpu_stats = ds.node_stats - if "gpu:0" == ds.device[-5:].lower(): - gpu_stats = ds.node_stats - return cpu_stats, gpu_stats - - def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self): - if not test.is_gpu_available(): - # Can't perform this test w/o a GPU - return - - gpu_dev = test.gpu_device_name() - with self.session(use_gpu=True) as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 1, 3]) - cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), gpu_dev) - with 
ops.device("/cpu:0"): - outputs, _ = rnn.dynamic_rnn( - cell=cell, inputs=x, dtype=dtypes.float32) - run_metadata = config_pb2.RunMetadata() - opts = config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE) - - sess.run([variables_lib.global_variables_initializer()]) - _ = sess.run(outputs, options=opts, run_metadata=run_metadata) - - cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) - self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) - self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) - - def testEmbeddingWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 1], dtype=dtypes.int32) - m = array_ops.zeros([1, 2]) - embedding_cell = contrib_rnn.EmbeddingWrapper( - rnn_cell_impl.GRUCell(2), embedding_classes=3, embedding_size=2) - self.assertEqual(embedding_cell.output_size, 2) - g, new_m = embedding_cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, new_m], { - x.name: np.array([[1]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[1].shape, (1, 2)) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res[0], [[0.17139, 0.17139]]) - - def testEmbeddingWrapperWithDynamicRnn(self): - with self.cached_session() as sess: - with variable_scope.variable_scope("root"): - inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64) - input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64) - embedding_cell = contrib_rnn.EmbeddingWrapper( - rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True), - embedding_classes=1, - embedding_size=2) - outputs, _ = rnn.dynamic_rnn( - cell=embedding_cell, - inputs=inputs, - sequence_length=input_lengths, - dtype=dtypes.float32) - sess.run([variables_lib.global_variables_initializer()]) - # This will fail if output's dtype is inferred from input's. 
- sess.run(outputs) - - def testMultiRNNCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 4]) - multi_rnn_cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) for _ in range(2)], - state_is_tuple=False) - _, ml = multi_rnn_cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run(ml, { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1, 0.1]]) - }) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) - self.assertEqual(len(multi_rnn_cell.weights), 2 * 4) - self.assertTrue( - [x.dtype == dtypes.float32 for x in multi_rnn_cell.weights]) - - def testMultiRNNCellWithStateTuple(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m_bad = array_ops.zeros([1, 4]) - m_good = (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])) - - # Test incorrectness of state - with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"): - rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) for _ in range(2)], - state_is_tuple=True)(x, m_bad) - - _, ml = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) for _ in range(2)], - state_is_tuple=True)(x, m_good) - - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - ml, { - x.name: np.array([[1., 1.]]), - m_good[0].name: np.array([[0.1, 0.1]]), - m_good[1].name: np.array([[0.1, 0.1]]) - }) - - # The numbers in results were not calculated, this is just a - # smoke test. However, these numbers should match those of - # the test testMultiRNNCell. 
- self.assertAllClose(res[0], [[0.175991, 0.175991]]) - self.assertAllClose(res[1], [[0.13248, 0.13248]]) - - -class DropoutWrapperTest(test.TestCase, parameterized.TestCase): - - def _testDropoutWrapper(self, - batch_size=None, - time_steps=None, - parallel_iterations=None, - wrapper_type=None, - **kwargs): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - if batch_size is None and time_steps is None: - # 2 time steps, batch size 1, depth 3 - batch_size = 1 - time_steps = 2 - x = constant_op.constant( - [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32) - m = rnn_cell_impl.LSTMStateTuple( - *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32 - )] * 2) - else: - x = constant_op.constant( - np.random.randn(time_steps, batch_size, 3).astype(np.float32)) - m = rnn_cell_impl.LSTMStateTuple(*[ - constant_op. - constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32) - ] * 2) - outputs, final_state = rnn.dynamic_rnn( - cell=wrapper_type( - rnn_cell_impl.LSTMCell(3), dtype=x.dtype, **kwargs), - time_major=True, - parallel_iterations=parallel_iterations, - inputs=x, - initial_state=m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([outputs, final_state]) - self.assertEqual(res[0].shape, (time_steps, batch_size, 3)) - self.assertEqual(res[1].c.shape, (batch_size, 3)) - self.assertEqual(res[1].h.shape, (batch_size, 3)) - return res - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperProperties(self, wrapper_type): - cell = rnn_cell_impl.BasicRNNCell(10) - wrapper = wrapper_type(cell) - # Github issue 15810 - self.assertEqual(wrapper.wrapped_cell, cell) - self.assertEqual(wrapper.state_size, 10) - self.assertEqual(wrapper.output_size, 10) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperZeroState(self, 
wrapper_type): - class _Cell(rnn_cell_impl.BasicRNNCell): - - def zero_state(self, batch_size=None, dtype=None): - return "wrapped_cell_zero_state" - wrapper = wrapper_type(_Cell(10)) - self.assertEqual(wrapper.zero_state(10, dtypes.float32), - "wrapped_cell_zero_state") - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepAllConstantInput(self, wrapper_type): - keep = array_ops.ones([]) - res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, - wrapper_type=wrapper_type) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - self.assertAllClose(true_full_output, res[0]) - self.assertAllClose(true_full_output[1], res[1].h) - self.assertAllClose(true_full_final_c, res[1].c) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepAll(self, wrapper_type): - keep = variable_scope.get_variable("all", initializer=1.0) - res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, - wrapper_type=wrapper_type) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - self.assertAllClose(true_full_output, res[0]) - self.assertAllClose(true_full_output[1], res[1].h) - self.assertAllClose(true_full_final_c, res[1].c) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperWithSeed(self, wrapper_type): - keep_some = 0.5 - random_seed.set_random_seed(2) - ## Use parallel_iterations = 1 in both calls to - ## _testDropoutWrapper to ensure the (per-time step) dropout is - ## consistent 
across both calls. Otherwise the seed may not end - ## up being munged consistently across both graphs. - res_standard_1 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - seed=10, - parallel_iterations=1, - wrapper_type=wrapper_type) - # Clear away the graph and the test session (which keeps variables around) - ops.reset_default_graph() - self._ClearCachedSession() - random_seed.set_random_seed(2) - res_standard_2 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - seed=10, - parallel_iterations=1, - wrapper_type=wrapper_type) - self.assertAllClose(res_standard_1[0], res_standard_2[0]) - self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c) - self.assertAllClose(res_standard_1[1].h, res_standard_2[1].h) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepNoOutput(self, wrapper_type): - keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-6) - res = self._testDropoutWrapper( - input_keep_prob=keep_all, - output_keep_prob=keep_none, - state_keep_prob=keep_all, - wrapper_type=wrapper_type) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - self.assertAllClose(np.zeros(res[0].shape), res[0]) - self.assertAllClose(true_full_output[1], res[1].h) - self.assertAllClose(true_full_final_c, res[1].c) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self, wrapper_type): - keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-6) - # Even though we dropout state, 
by default DropoutWrapper never - # drops out the memory ("c") term of an LSTMStateTuple. - res = self._testDropoutWrapper( - input_keep_prob=keep_all, - output_keep_prob=keep_all, - state_keep_prob=keep_none, - wrapper_type=wrapper_type) - true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - self.assertAllClose(true_full_output[0], res[0][0]) - # Second output is modified by zero input state - self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4) - # h state has been set to zero - self.assertAllClose(np.zeros(res[1].h.shape), res[1].h) - # c state of an LSTMStateTuple is NEVER modified. - self.assertAllClose(true_c_state, res[1].c) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepNoInput(self, wrapper_type): - keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-6) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - # All outputs are different because inputs are zeroed out - res = self._testDropoutWrapper( - input_keep_prob=keep_none, - output_keep_prob=keep_all, - state_keep_prob=keep_all, - wrapper_type=wrapper_type) - self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4) - self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4) - self.assertGreater(np.linalg.norm(res[1].c - true_full_final_c), 1e-4) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperRecurrentOutput(self, wrapper_type): - keep_some = 0.8 - keep_all = variable_scope.get_variable("all", initializer=1.0) - res = 
self._testDropoutWrapper( - input_keep_prob=keep_all, - output_keep_prob=keep_some, - state_keep_prob=keep_all, - variational_recurrent=True, - wrapper_type=wrapper_type, - input_size=3, - batch_size=5, - time_steps=7) - # Ensure the same dropout pattern for all time steps - output_mask = np.abs(res[0]) > 1e-6 - for m in output_mask[1:]: - self.assertAllClose(output_mask[0], m) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperRecurrentStateInputAndOutput(self, wrapper_type): - keep_some = 0.9 - res = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - variational_recurrent=True, - wrapper_type=wrapper_type, - input_size=3, - batch_size=5, - time_steps=7) - - # Smoke test for the state/input masks. - output_mask = np.abs(res[0]) > 1e-6 - for time_step in output_mask: - # Ensure the same dropout output pattern for all time steps - self.assertAllClose(output_mask[0], time_step) - for batch_entry in time_step: - # Assert all batch entries get the same mask - self.assertAllClose(batch_entry, time_step[0]) - - # For state, ensure all batch entries have the same mask - state_c_mask = np.abs(res[1].c) > 1e-6 - state_h_mask = np.abs(res[1].h) > 1e-6 - for batch_entry in state_c_mask: - self.assertAllClose(batch_entry, state_c_mask[0]) - for batch_entry in state_h_mask: - self.assertAllClose(batch_entry, state_h_mask[0]) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperRecurrentStateInputAndOutputWithSeed( - self, wrapper_type): - keep_some = 0.9 - random_seed.set_random_seed(2347) - np.random.seed(23487) - res0 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - variational_recurrent=True, - wrapper_type=wrapper_type, - input_size=3, - batch_size=5, - time_steps=7, - seed=-234987) - 
ops.reset_default_graph() - self._ClearCachedSession() - random_seed.set_random_seed(2347) - np.random.seed(23487) - res1 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - variational_recurrent=True, - wrapper_type=wrapper_type, - input_size=3, - batch_size=5, - time_steps=7, - seed=-234987) - - output_mask = np.abs(res0[0]) > 1e-6 - for time_step in output_mask: - # Ensure the same dropout output pattern for all time steps - self.assertAllClose(output_mask[0], time_step) - for batch_entry in time_step: - # Assert all batch entries get the same mask - self.assertAllClose(batch_entry, time_step[0]) - - # For state, ensure all batch entries have the same mask - state_c_mask = np.abs(res0[1].c) > 1e-6 - state_h_mask = np.abs(res0[1].h) > 1e-6 - for batch_entry in state_c_mask: - self.assertAllClose(batch_entry, state_c_mask[0]) - for batch_entry in state_h_mask: - self.assertAllClose(batch_entry, state_h_mask[0]) - - # Ensure seeded calculation is identical. 
- self.assertAllClose(res0[0], res1[0]) - self.assertAllClose(res0[1].c, res1[1].c) - self.assertAllClose(res0[1].h, res1[1].h) - - def testDropoutWrapperKerasStyle(self): - """Tests if DropoutWrapperV2 cell is instantiated in keras style scope.""" - wrapped_cell_v2 = rnn_cell_impl.DropoutWrapperV2( - rnn_cell_impl.BasicRNNCell(1)) - self.assertTrue(wrapped_cell_v2._keras_style) - - wrapped_cell = rnn_cell_impl.DropoutWrapper(rnn_cell_impl.BasicRNNCell(1)) - self.assertFalse(wrapped_cell._keras_style) - - def testDropoutWrapperV2VariableNames(self): - """Tests that variables names do not depend on wrapper in RNN layer.""" - - def _rnn_input(apply_wrapper): - """Creates a RNN layer with/without wrapper and returns built rnn cell.""" - with base_layer.keras_style_scope(): - base_cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)]) - if apply_wrapper: - rnn_cell = rnn_cell_impl.DropoutWrapperV2(base_cell) - else: - rnn_cell = base_cell - rnn_layer = keras_layers.RNN(rnn_cell) - inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32) - _ = rnn_layer(inputs) - return base_cell._cells[0] - - rnn_1 = _rnn_input(True) - ops.reset_default_graph() - rnn_2 = _rnn_input(False) - - self.assertLen(rnn_1.weights, expected_len=2) - self.assertCountEqual([v.name for v in rnn_1.weights], - [v.name for v in rnn_2.weights]) - - def testDropoutWrapperV2Caller(self): - """Tests that DropoutWrapperV2 is using the LayerRNNCell's caller.""" - - with base_layer.keras_style_scope(): - base_cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)]) - rnn_cell = rnn_cell_impl.DropoutWrapperV2(base_cell) - inputs = ops.convert_to_tensor([[1]], dtype=dtypes.float32) - state = ops.convert_to_tensor([[1]], dtype=dtypes.float32) - _ = rnn_cell(inputs, [state, state]) - weights = base_cell._cells[0].weights - self.assertLen(weights, expected_len=2) - self.assertTrue(all(["dropout_wrapper" in v.name for v in weights])) - - def 
testDropoutWrapperV2Build(self): - cell = rnn_cell_impl.LSTMCell(10) - wrapper = rnn_cell_impl.DropoutWrapperV2(cell) - wrapper.build((1,)) - self.assertTrue(cell.built) - - -def basic_rnn_cell(inputs, state, num_units, scope=None): - if state is None: - if inputs is not None: - batch_size = inputs.get_shape()[0] - dtype = inputs.dtype - else: - batch_size = 0 - dtype = dtypes.float32 - init_output = array_ops.zeros( - array_ops.stack([batch_size, num_units]), dtype=dtype) - init_state = array_ops.zeros( - array_ops.stack([batch_size, num_units]), dtype=dtype) - init_output.set_shape([batch_size, num_units]) - init_state.set_shape([batch_size, num_units]) - return init_output, init_state - else: - with variable_scope.variable_scope(scope, "basic_rnn_cell", - [inputs, state]): - output = math_ops.tanh( - Linear([inputs, state], num_units, True)([inputs, state])) - return output, output - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index d7ee7fb8faacb0876218a983d68f007e1905c11e..921b4baae43d86f5ac2a86df0828c9691d9dbb2a 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -22,6 +22,7 @@ import itertools import numpy as np +from tensorflow.contrib.rnn.python.ops import core_rnn_cell as legacy_rnn_cell from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -29,6 +30,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.keras import initializers from 
tensorflow.python.keras import layers as keras_layers @@ -53,6 +55,314 @@ from tensorflow.python.util import nest class RNNCellTest(test.TestCase): + def _assert_cell_builds(self, cell_class, dtype, batch_size, in_size, + out_size): + cell = cell_class(out_size, dtype=dtype) + in_shape = tensor_shape.TensorShape((batch_size, in_size)) + cell.build(in_shape) + state_output = cell.get_initial_state( + inputs=None, batch_size=batch_size, dtype=dtype) + cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output) + self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list()) + + def testCellsBuild(self): + f32 = dtypes.float32 + f64 = dtypes.float64 + self._assert_cell_builds(contrib_rnn_cell.IndRNNCell, f32, 5, 7, 3) + self._assert_cell_builds(contrib_rnn_cell.IndRNNCell, f64, 5, 7, 3) + self._assert_cell_builds(contrib_rnn_cell.IndyGRUCell, f32, 5, 7, 3) + self._assert_cell_builds(contrib_rnn_cell.IndyGRUCell, f64, 5, 7, 3) + self._assert_cell_builds(contrib_rnn_cell.IndyLSTMCell, f32, 5, 7, 3) + self._assert_cell_builds(contrib_rnn_cell.IndyLSTMCell, f64, 5, 7, 3) + + def testIndRNNCell(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + cell = contrib_rnn_cell.IndRNNCell(2) + g, _ = cell(x, m) + self.assertEqual([ + "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME + ], [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + self.assertEqual(res[0].shape, (1, 2)) + + def testIndyGRUCell(self): + with self.cached_session() as sess: + with 
variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.185265, 0.17704]]) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyGRUCell with input_size != num_units. + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.155127, 0.157328]]) + + def testIndyLSTMCell(self): + for dtype in [dtypes.float16, dtypes.float32]: + np_dtype = dtype.as_numpy_dtype + with self.session(graph=ops.Graph()) as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2], dtype=dtype) + state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + cell = rnn_cell_impl.MultiRNNCell( + [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)]) + self.assertEqual(cell.dtype, None) + self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) + self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) + cell.get_config() # Should not throw an error + g, (out_state_0, out_state_1) = cell(x, (state_0, state_1)) + # Layer infers the input type. 
+ self.assertEqual(cell.dtype, dtype.name) + expected_variable_names = [ + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME + ] + self.assertEqual(expected_variable_names, + [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [g, out_state_0, out_state_1], { + x.name: np.array([[1., 1.]]), + state_0[0].name: 0.1 * np.ones([1, 2]), + state_0[1].name: 0.1 * np.ones([1, 2]), + state_1[0].name: 0.1 * np.ones([1, 2]), + state_1[1].name: 0.1 * np.ones([1, 2]), + }) + self.assertEqual(len(res), 3) + global_variables = variables.global_variables() + self.assertEqual(expected_variable_names, + [v.name for v in global_variables]) + # Only check the range of outputs as this is just a smoke test. + self.assertAllInRange(res[0], -1.0, 1.0) + self.assertAllInRange(res[1], -1.0, 1.0) + self.assertAllInRange(res[2], -1.0, 1.0) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyLSTMCell with input_size != num_units. 
+ x = array_ops.zeros([1, 3], dtype=dtype) + state = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [g, out_state], { + x.name: np.array([[1., 1., 1.]], dtype=np_dtype), + state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + }) + self.assertEqual(len(res), 2) + + def testLSTMCellLayerNorm(self): + with self.cached_session() as sess: + num_units = 2 + num_proj = 3 + batch_size = 1 + input_size = 4 + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([batch_size, input_size]) + c = array_ops.zeros([batch_size, num_units]) + h = array_ops.zeros([batch_size, num_proj]) + state = rnn_cell_impl.LSTMStateTuple(c, h) + cell = contrib_rnn_cell.LayerNormLSTMCell( + num_units=num_units, + num_proj=num_proj, + forget_bias=1.0, + layer_norm=True, + norm_gain=1.0, + norm_shift=0.0) + g, out_m = cell(x, state) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [g, out_m], { + x.name: np.ones((batch_size, input_size)), + c.name: 0.1 * np.ones((batch_size, num_units)), + h.name: 0.1 * np.ones((batch_size, num_proj)) + }) + self.assertEqual(len(res), 2) + # The numbers in results were not calculated, this is mostly just a + # smoke test. 
+ self.assertEqual(res[0].shape, (batch_size, num_proj)) + self.assertEqual(res[1][0].shape, (batch_size, num_units)) + self.assertEqual(res[1][1].shape, (batch_size, num_proj)) + # Different inputs so different outputs and states + for i in range(1, batch_size): + self.assertTrue( + float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6) + self.assertTrue( + float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) + + def testOutputProjectionWrapper(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 3]) + cell = legacy_rnn_cell.OutputProjectionWrapper( + rnn_cell_impl.GRUCell(3), 2) + g, new_m = cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, new_m], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]]) + }) + self.assertEqual(res[1].shape, (1, 3)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.231907, 0.231907]]) + + def testInputProjectionWrapper(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 3]) + cell = legacy_rnn_cell.InputProjectionWrapper( + rnn_cell_impl.GRUCell(3), num_proj=3) + g, new_m = cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, new_m], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]]) + }) + self.assertEqual(res[1].shape, (1, 3)) + # The numbers in results were not calculated, this is just a smoke test. 
+ self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) + + def testEmbeddingWrapper(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 1], dtype=dtypes.int32) + m = array_ops.zeros([1, 2]) + embedding_cell = legacy_rnn_cell.EmbeddingWrapper( + rnn_cell_impl.GRUCell(2), embedding_classes=3, embedding_size=2) + self.assertEqual(embedding_cell.output_size, 2) + g, new_m = embedding_cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, new_m], { + x.name: np.array([[1]]), + m.name: np.array([[0.1, 0.1]]) + }) + self.assertEqual(res[1].shape, (1, 2)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.17139, 0.17139]]) + + def testEmbeddingWrapperWithDynamicRnn(self): + with self.cached_session() as sess: + with variable_scope.variable_scope("root"): + inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64) + input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64) + embedding_cell = legacy_rnn_cell.EmbeddingWrapper( + rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True), + embedding_classes=1, + embedding_size=2) + outputs, _ = rnn.dynamic_rnn( + cell=embedding_cell, + inputs=inputs, + sequence_length=input_lengths, + dtype=dtypes.float32) + sess.run([variables.global_variables_initializer()]) + # This will fail if output's dtype is inferred from input's. 
+ sess.run(outputs) + + def testSRUCell(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.SRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.509682, 0.509682]]) + + def testSRUCellKerasRNN(self): + """Tests that SRUCell works with keras RNN layer.""" + cell = contrib_rnn_cell.SRUCell(10) + seq_input = ops.convert_to_tensor( + np.random.rand(2, 3, 5), name="seq_input", dtype=dtypes.float32) + rnn_layer = keras_layers.RNN(cell=cell) + rnn_outputs_keras = rnn_layer(seq_input) + with self.cached_session() as sess: + sess.run([variables.global_variables_initializer()]) + self.assertEqual(sess.run(rnn_outputs_keras).shape, (2, 10)) + + def testSRUCellBiasType(self): + """Tests that the bias' dtype is properly set.""" + cell = contrib_rnn_cell.SRUCell(10) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.float32_ref) + + cell = contrib_rnn_cell.SRUCell(10, dtype=dtypes.int32) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.int32_ref) + + cell_input = ops.convert_to_tensor( + np.random.rand(2, 5), name="cell_input", dtype=dtypes.float16) + cell_state = ops.convert_to_tensor( + np.random.rand(2, 10), name="cell_state", dtype=dtypes.float16) + cell = contrib_rnn_cell.SRUCell(10) + cell(cell_input, [cell_state]) + self.assertEqual(cell._bias.dtype, dtypes.float16_ref) + + def testSRUCellWithDiffSize(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.SRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = 
sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.55255556, 0.55255556]]) + def testCoupledInputForgetGateLSTMCell(self): with self.cached_session() as sess: num_units = 2 @@ -1265,6 +1575,61 @@ class RNNCellTest(test.TestCase): self.assertEqual(len(outputs), batch) self.assertEqual(len(state), batch) + def testNTMCell(self): + expected_output = np.array( + [[-0.04973561, -0.00020032, -0.09586009, -0.05049511], + [-0.02199885, 0.02302885, -0.05558189, -0.02051288], + [-0.01399924, 0.02543444, -0.06975862, -0.03782758], + [-0.02238393, 0.0135776, -0.09102941, -0.05594013]], + dtype=np.float32) + expected_read_vector_list = np.array( + [[1e-6, 1e-6, 1e-6, 1e-6], [1e-6, 1e-6, 1e-6, 1e-6], + [1e-6, 1e-6, 1e-6, 1e-6], [1e-6, 1e-6, 1e-6, 1e-6]], + dtype=np.float32) + expected_w_list = np.array( + [[[0.15837428, 0.21354634, 0.22115856, 0.21117255, 0.19574821], + [0.15826838, 0.2150458, 0.2228198, 0.20747298, 0.19639312], + [0.15750293, 0.21550071, 0.22280747, 0.20737495, 0.19681393], + [0.15763053, 0.21473582, 0.22187267, 0.20920397, 0.19655706]], + [[0.21703579, 0.19425659, 0.22143759, 0.18024713, 0.18702294], + [0.2164267, 0.19451937, 0.22112325, 0.18051708, 0.18741359], + [0.21567065, 0.1947548, 0.22107735, 0.18058982, 0.18790732], + [0.2163743, 0.194361, 0.22131558, 0.18042919, 0.1875199]]], + dtype=np.float32) + expected_M_0 = np.array( + [[-0.00553495, -0.01089884, 0.00683121, -0.00273276], + [-0.00495392, -0.00975483, 0.00611433, -0.00244583], + [-0.00564722, -0.0111199, 0.00696973, -0.0027882], + [-0.00459658, -0.00905126, 0.00567345, -0.00226937], + [-0.00476941, -0.00939155, 0.00588669, -0.00235472]], + dtype=np.float32) + + with session.Session() as sess: + with variable_scope.variable_scope("root"): + seed = 1234 + random_seed.set_random_seed(seed) + batch_size = 4 + inputs = random_ops.random_uniform((batch_size, 4), + 0.0, + 1.0, + seed=seed + 1) + cell = 
contrib_rnn_cell.NTMCell( + controller=rnn_cell_impl.LSTMCell(num_units=4), + memory_size=5, + memory_vector_dim=4, + read_head_num=1, + write_head_num=1) + output, state = cell(inputs, cell.zero_state(batch_size, + dtypes.float32)) + sess.run([variables.global_variables_initializer()]) + res, read_vector_list, w_list, M = sess.run( + [output, state.read_vector_list, state.w_list, state.M]) + # Smoke test + self.assertAllClose(res, expected_output) + self.assertAllClose(read_vector_list[0], expected_read_vector_list) + self.assertAllClose(w_list, expected_w_list) + self.assertAllClose(M[0], expected_M_0) + class LayerNormBasicLSTMCellTest(test.TestCase): diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index b043026bc556a8879b15b432829baf8136250c0e..ed3c7609368819295e142a2a4c5a4e5f66c2ee36 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -691,9 +691,10 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype) if sequence_length is None: - max_seq_len = math_ops.to_int64(time_len) + max_seq_len = math_ops.cast(time_len, dtypes.int64) else: - max_seq_len = math_ops.to_int64(math_ops.reduce_max(sequence_length)) + max_seq_len = math_ops.cast(math_ops.reduce_max(sequence_length), + dtypes.int64) _, cs, _, _, _, _, h = gen_lstm_ops.block_lstm( seq_len_max=max_seq_len, diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py index 0266b72dcb15e4aba01a9a31b4be75c5b84d44da..41b1698321e20f4360d75fa2db79f9bd8a806cea 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn.py +++ b/tensorflow/contrib/rnn/python/ops/rnn.py @@ -131,7 +131,8 @@ def stack_bidirectional_dynamic_rnn(cells_fw, sequence_length=None, parallel_iterations=None, time_major=False, - scope=None): + scope=None, + swap_memory=False): """Creates a dynamic bidirectional recurrent neural network. 
Stacks several bidirectional rnn layers. The combined forward and backward @@ -171,6 +172,10 @@ def stack_bidirectional_dynamic_rnn(cells_fw, data is batch-major, so by default this function accepts input and emits output in batch-major form. scope: VariableScope for the created subgraph; defaults to None. + swap_memory: Transparently swap the tensors produced in forward inference + but needed for back prop from GPU to CPU. This allows training RNNs + which would typically not fit on a single GPU, with very minimal (or no) + performance penalty. Returns: A tuple (outputs, output_state_fw, output_state_bw) where: @@ -230,6 +235,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw, sequence_length=sequence_length, parallel_iterations=parallel_iterations, dtype=dtype, + swap_memory=swap_memory, time_major=time_major) # Concat the outputs to create the new input. prev_layer = array_ops.concat(outputs, 2) diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 482e547a16be85804beec88a91fa03b053d09b27..9ada8e244ba5657c3bfc91d17aa69c3c5f49a0df 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -33,6 +33,7 @@ from tensorflow.python.keras import initializers from tensorflow.python.keras.engine import input_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -3153,7 +3154,7 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): r"""Independently Gated Recurrent Unit cell. 
Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell, - yet with the \(U_r\), \(U_z\), and \(U\) matrices in equations 5, 6, and + yet with the \\(U_r\\), \\(U_z\\), and \\(U\\) matrices in equations 5, 6, and 8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal matrices, i.e. a Hadamard product with a single vector: @@ -3164,12 +3165,10 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j + [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$ - where \(\circ\) denotes the Hadamard operator. This means that each IndyGRU + where \\(\circ\\) denotes the Hadamard operator. This means that each IndyGRU node sees only its own state, as opposed to seeing all states in the same layer. - TODO(gonnet): Write a paper describing this and add a reference here. - Args: num_units: int, The number of units in the GRU cell. activation: Nonlinearity to use. Default: `tanh`. @@ -3254,7 +3253,7 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): self.built = True def call(self, inputs, state): - """Gated recurrent unit (GRU) with nunits cells.""" + """Recurrently independent Gated Recurrent Unit (GRU) with nunits cells.""" gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + ( gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u) @@ -3278,10 +3277,9 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): r"""Basic IndyLSTM recurrent network cell. Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to - BasicLSTMCell, yet with the \(U_f\), \(U_i\), \(U_o\) and \(U_c\) - matrices in - https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate - replaced by diagonal matrices, i.e. a Hadamard product with a single vector: + BasicLSTMCell, yet with the \\(U_f\\), \\(U_i\\), \\(U_o\\) and \\(U_c\\) + matrices in the regular LSTM equations replaced by diagonal matrices, i.e. 
a + Hadamard product with a single vector: $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$ $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$ @@ -3289,8 +3287,8 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): $$c_t = f_t \circ c_{t-1} + i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$ - where \(\circ\) denotes the Hadamard operator. This means that each IndyLSTM - node sees only its own state \(h\) and \(c\), as opposed to seeing all + where \\(\circ\\) denotes the Hadamard operator. This means that each IndyLSTM + node sees only its own state \\(h\\) and \\(c\\), as opposed to seeing all states in the same layer. We add forget_bias (default: 1) to the biases of the forget gate in order to @@ -3298,11 +3296,6 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. - - For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell` - that follows. - - TODO(gonnet): Write a paper describing this and add a reference here. """ def __init__(self, @@ -3417,6 +3410,354 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): return new_h, new_state +NTMControllerState = collections.namedtuple( + "NTMControllerState", + ("controller_state", "read_vector_list", "w_list", "M", "time")) + + +class NTMCell(rnn_cell_impl.LayerRNNCell): + """Neural Turing Machine Cell with RNN controller. 
+ + Implementation based on: + https://arxiv.org/abs/1807.08518 + Mark Collier, Joeran Beel + + which is in turn based on the source code of: + https://github.com/snowkylin/ntm + + and of course the original NTM paper: + Neural Turing Machines + https://arxiv.org/abs/1410.5401 + A Graves, G Wayne, I Danihelka + """ + + def __init__(self, + controller, + memory_size, + memory_vector_dim, + read_head_num, + write_head_num, + shift_range=1, + output_dim=None, + clip_value=20, + dtype=dtypes.float32, + name=None): + """Initialize the NTM Cell. + + Args: + controller: an RNNCell, the RNN controller. + memory_size: int, The number of memory locations in the NTM memory + matrix + memory_vector_dim: int, The dimensionality of each location in the NTM + memory matrix + read_head_num: int, The number of read heads from the controller into + memory + write_head_num: int, The number of write heads from the controller into + memory + shift_range: int, The number of places to the left/right it is possible + to iterate the previous address to in a single step + output_dim: int, The number of dimensions to make a linear projection of + the NTM controller outputs to. If None, no linear projection is + applied + clip_value: float, The maximum absolute value the controller parameters + are clipped to + dtype: Default dtype of the layer (default of `None` means use the type + of the first input). Required when `build` is called before `call`. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. 
+ """ + super(NTMCell, self).__init__(dtype=dtype, name=name) + + rnn_cell_impl.assert_like_rnncell("NTM RNN controller cell", controller) + + self.controller = controller + self.memory_size = memory_size + self.memory_vector_dim = memory_vector_dim + self.read_head_num = read_head_num + self.write_head_num = write_head_num + self.clip_value = clip_value + + self.output_dim = output_dim + self.shift_range = shift_range + + self.num_parameters_per_head = ( + self.memory_vector_dim + 2 * self.shift_range + 4) + self.num_heads = self.read_head_num + self.write_head_num + self.total_parameter_num = ( + self.num_parameters_per_head * self.num_heads + + self.memory_vector_dim * 2 * self.write_head_num) + + @property + def state_size(self): + return NTMControllerState( + controller_state=self.controller.state_size, + read_vector_list=[ + self.memory_vector_dim for _ in range(self.read_head_num) + ], + w_list=[ + self.memory_size + for _ in range(self.read_head_num + self.write_head_num) + ], + M=tensor_shape.TensorShape([self.memory_size * self.memory_vector_dim]), + time=tensor_shape.TensorShape([])) + + @property + def output_size(self): + return self.output_dim + + def build(self, inputs_shape): + if self.output_dim is None: + if inputs_shape[1].value is None: + raise ValueError( + "Expected inputs.shape[-1] to be known, saw shape: %s" % + inputs_shape) + else: + self.output_dim = inputs_shape[1].value + + def _create_linear_initializer(input_size, dtype=dtypes.float32): + stddev = 1.0 / math.sqrt(input_size) + return init_ops.truncated_normal_initializer(stddev=stddev, dtype=dtype) + + self._params_kernel = self.add_variable( + "parameters_kernel", + shape=[self.controller.output_size, self.total_parameter_num], + initializer=_create_linear_initializer(self.controller.output_size)) + + self._params_bias = self.add_variable( + "parameters_bias", + shape=[self.total_parameter_num], + initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) + + 
self._output_kernel = self.add_variable( + "output_kernel", + shape=[ + self.controller.output_size + + self.memory_vector_dim * self.read_head_num, self.output_dim + ], + initializer=_create_linear_initializer(self.controller.output_size + + self.memory_vector_dim * + self.read_head_num)) + + self._output_bias = self.add_variable( + "output_bias", + shape=[self.output_dim], + initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) + + self._init_read_vectors = [ + self.add_variable( + "initial_read_vector_%d" % i, + shape=[1, self.memory_vector_dim], + initializer=initializers.glorot_uniform()) + for i in range(self.read_head_num) + ] + + self._init_address_weights = [ + self.add_variable( + "initial_address_weights_%d" % i, + shape=[1, self.memory_size], + initializer=initializers.glorot_uniform()) + for i in range(self.read_head_num + self.write_head_num) + ] + + self._M = self.add_variable( + "memory", + shape=[self.memory_size, self.memory_vector_dim], + initializer=init_ops.constant_initializer(1e-6, dtype=self.dtype)) + + self.built = True + + def call(self, x, prev_state): + # Addressing Mechanisms (Sec 3.3) + + def _prev_read_vector_list_initial_value(): + return [ + self._expand( + math_ops.tanh( + array_ops.squeeze( + math_ops.matmul( + array_ops.ones([1, 1]), self._init_read_vectors[i]))), + dim=0, + N=x.shape[0].value or array_ops.shape(x)[0]) + for i in range(self.read_head_num) + ] + + prev_read_vector_list = control_flow_ops.cond( + math_ops.equal(prev_state.time, + 0), _prev_read_vector_list_initial_value, lambda: + prev_state.read_vector_list) + if self.read_head_num == 1: + prev_read_vector_list = [prev_read_vector_list] + + controller_input = array_ops.concat([x] + prev_read_vector_list, axis=1) + controller_output, controller_state = self.controller( + controller_input, prev_state.controller_state) + + parameters = math_ops.matmul(controller_output, self._params_kernel) + parameters = nn_ops.bias_add(parameters, self._params_bias) + 
parameters = clip_ops.clip_by_value(parameters, -self.clip_value, + self.clip_value) + head_parameter_list = array_ops.split( + parameters[:, :self.num_parameters_per_head * self.num_heads], + self.num_heads, + axis=1) + erase_add_list = array_ops.split( + parameters[:, self.num_parameters_per_head * self.num_heads:], + 2 * self.write_head_num, + axis=1) + + def _prev_w_list_initial_value(): + return [ + self._expand( + nn_ops.softmax( + array_ops.squeeze( + math_ops.matmul( + array_ops.ones([1, 1]), + self._init_address_weights[i]))), + dim=0, + N=x.shape[0].value or array_ops.shape(x)[0]) + for i in range(self.read_head_num + self.write_head_num) + ] + + prev_w_list = control_flow_ops.cond( + math_ops.equal(prev_state.time, 0), + _prev_w_list_initial_value, lambda: prev_state.w_list) + if (self.read_head_num + self.write_head_num) == 1: + prev_w_list = [prev_w_list] + + prev_M = control_flow_ops.cond( + math_ops.equal(prev_state.time, 0), lambda: self._expand( + self._M, dim=0, N=x.shape[0].value or array_ops.shape(x)[0]), + lambda: prev_state.M) + + w_list = [] + for i, head_parameter in enumerate(head_parameter_list): + k = math_ops.tanh(head_parameter[:, 0:self.memory_vector_dim]) + beta = nn_ops.softplus(head_parameter[:, self.memory_vector_dim]) + g = math_ops.sigmoid(head_parameter[:, self.memory_vector_dim + 1]) + s = nn_ops.softmax(head_parameter[:, self.memory_vector_dim + + 2:(self.memory_vector_dim + 2 + + (self.shift_range * 2 + 1))]) + gamma = nn_ops.softplus(head_parameter[:, -1]) + 1 + w = self._addressing(k, beta, g, s, gamma, prev_M, prev_w_list[i]) + w_list.append(w) + + # Reading (Sec 3.1) + + read_w_list = w_list[:self.read_head_num] + read_vector_list = [] + for i in range(self.read_head_num): + read_vector = math_ops.reduce_sum( + array_ops.expand_dims(read_w_list[i], dim=2) * prev_M, axis=1) + read_vector_list.append(read_vector) + + # Writing (Sec 3.2) + + write_w_list = w_list[self.read_head_num:] + M = prev_M + for i in 
range(self.write_head_num): + w = array_ops.expand_dims(write_w_list[i], axis=2) + erase_vector = array_ops.expand_dims( + math_ops.sigmoid(erase_add_list[i * 2]), axis=1) + add_vector = array_ops.expand_dims( + math_ops.tanh(erase_add_list[i * 2 + 1]), axis=1) + erase_M = array_ops.ones_like(M) - math_ops.matmul(w, erase_vector) + M = M * erase_M + math_ops.matmul(w, add_vector) + + output = math_ops.matmul( + array_ops.concat([controller_output] + read_vector_list, axis=1), + self._output_kernel) + output = nn_ops.bias_add(output, self._output_bias) + output = clip_ops.clip_by_value(output, -self.clip_value, self.clip_value) + + return output, NTMControllerState( + controller_state=controller_state, + read_vector_list=read_vector_list, + w_list=w_list, + M=M, + time=prev_state.time + 1) + + def _expand(self, x, dim, N): # N may be a Tensor; tile instead of range(N). + multiples = [N if i == dim else 1 for i in range(x.shape.ndims + 1)] + return array_ops.tile(array_ops.expand_dims(x, dim), multiples) + + def _addressing(self, k, beta, g, s, gamma, prev_M, prev_w): + # Sec 3.3.1 Focusing by Content + + k = array_ops.expand_dims(k, axis=2) + inner_product = math_ops.matmul(prev_M, k) + k_norm = math_ops.sqrt( + math_ops.reduce_sum(math_ops.square(k), axis=1, keepdims=True)) + M_norm = math_ops.sqrt( + math_ops.reduce_sum(math_ops.square(prev_M), axis=2, keepdims=True)) + norm_product = M_norm * k_norm + + # eq (6); squeeze only axis 2 so a batch of size 1 keeps its batch axis + K = array_ops.squeeze(inner_product / (norm_product + 1e-8), axis=2) + + K_amplified = math_ops.exp(array_ops.expand_dims(beta, axis=1) * K) + + # eq (5) + w_c = K_amplified / math_ops.reduce_sum(K_amplified, axis=1, keepdims=True) + + # Sec 3.3.2 Focusing by Location + + g = array_ops.expand_dims(g, axis=1) + + # eq (7) + w_g = g * w_c + (1 - g) * prev_w + + s = array_ops.concat([ + s[:, :self.shift_range + 1], + array_ops.zeros([ + s.shape[0].value or array_ops.shape(s)[0], self.memory_size - + (self.shift_range * 2 + 1) + ], dtype=s.dtype), s[:, -self.shift_range:] + ], + axis=1) + t = array_ops.concat( + [array_ops.reverse(s, axis=[1]), + array_ops.reverse(s,
axis=[1])], + axis=1) + s_matrix = array_ops.stack([ + t[:, self.memory_size - i - 1:self.memory_size * 2 - i - 1] + for i in range(self.memory_size) + ], + axis=1) + + # eq (8) + w_ = math_ops.reduce_sum( + array_ops.expand_dims(w_g, axis=1) * s_matrix, axis=2) + w_sharpen = math_ops.pow(w_, array_ops.expand_dims(gamma, axis=1)) + + # eq (9) + w = w_sharpen / math_ops.reduce_sum(w_sharpen, axis=1, keepdims=True) + + return w + + def zero_state(self, batch_size, dtype): + # Use `dtype` throughout; zeros() would otherwise default to float32. + read_vector_list = [ + array_ops.zeros([batch_size, self.memory_vector_dim], dtype=dtype) + for _ in range(self.read_head_num) + ] + w_list = [ + array_ops.zeros([batch_size, self.memory_size], dtype=dtype) + for _ in range(self.read_head_num + self.write_head_num) + ] + + controller_init_state = self.controller.zero_state(batch_size, dtype) + + M = array_ops.zeros( + [batch_size, self.memory_size, self.memory_vector_dim], dtype=dtype) + return NTMControllerState( + controller_state=controller_init_state, + read_vector_list=read_vector_list, + w_list=w_list, + M=M, + time=0) + + class MinimalRNNCell(rnn_cell_impl.LayerRNNCell): """MinimalRNN cell. @@ -3429,7 +3770,7 @@ class MinimalRNNCell(rnn_cell_impl.LayerRNNCell): Propagation in Recurrent Neural Networks." ICML, 2018. A MinimalRNN cell first projects the input to the hidden space. The new - hidden state is then calcuated as a weighted sum of the projected input and + hidden state is then calculated as a weighted sum of the projected input and the previous hidden state, using a single update gate. """ @@ -3543,7 +3884,7 @@ class CFNCell(rnn_cell_impl.LayerRNNCell): "A recurrent neural network without chaos." ICLR, 2017. A CFN cell first projects the input to the hidden space. The hidden state - goes through a contractive mapping. The new hidden state is then calcuated + goes through a contractive mapping. The new hidden state is then calculated as a linear combination of the projected input and the contracted previous hidden state, using decoupled input and forget gates.
""" diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py index 3fc6bfbb4d03a39906d4441e48b2788423caa234..d8ab9eba7049e468b373a1641f92dc781aa22558 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py @@ -61,10 +61,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase): self._server = server def tearDown(self): - # TODO(ebrevdo): Figure out why this sometimes times out. - # self._service.ExitLoop() - # self._service_thread.join() - # self._server.stop() + self._server.stop(grace=None) super(RpcOpTest, self).tearDown() diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index 0392ed9eee79391c60318faf68d8dfd6eb64a994..a61e9579b84a60d74b73e45a6100a2c772d9cff8 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -22,5 +22,5 @@ from tensorflow.python.keras import saving # TODO(kathywu): Remove all contrib callers, switch to tf.keras. 
-save_keras_model = saving.export +save_keras_model = saving.export_saved_model load_keras_model = saving.load_from_saved_model diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 89176180ae0dd963bccc34aa2d0fc52be839dd3f..f42a2953ef96a863bf6b0c33e763413da569bb41 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -74,9 +74,6 @@ tf_custom_op_library( "kernels/beam_search_ops_gpu.cu.cc", "kernels/beam_search_ops.h", ], - deps = [ - "//tensorflow/core/kernels:eigen_helpers", - ], ) tf_gen_op_wrapper_py( @@ -96,7 +93,6 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", ], ) @@ -139,6 +135,27 @@ cuda_py_test( ], ) +cuda_py_test( + name = "basic_decoder_v2_test", + size = "medium", + srcs = ["python/kernel_tests/basic_decoder_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "beam_search_ops_test", size = "medium", @@ -173,6 +190,27 @@ cuda_py_test( ], ) +cuda_py_test( + name = "decoder_v2_test", + size = "medium", + srcs = ["python/kernel_tests/decoder_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", 
+ "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "beam_search_decoder_test", size = "medium", @@ -213,3 +251,19 @@ cuda_py_test( "//tensorflow/python:variables", ], ) + +cuda_py_test( + name = "attention_wrapper_v2_test", + size = "medium", + srcs = ["python/kernel_tests/attention_wrapper_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], + shard_count = 4, +) diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h index 34da8c82cdab9b6f82af328c49a365ae1cb951ed..c0b3091fb8d98589f26818b93140f1b58eb73794 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h @@ -18,7 +18,6 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/eigen_activations.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index bc28d492fe1a25afe0d0783539aa9e759e7b703f..be2aa4782c3cbc2ecce23b57d332e9bf0cec18bc 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -91,16 +91,11 @@ struct GatherTree<GPUDevice, T> { beams.device(d) = beams.constant(end_token); CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); - // clang-format off - GatherTreeOpKernel<T> - <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( - batch_size, max_time, beam_width, - step_ids.data(), - parent_ids.data(), - max_sequence_length.data(), - end_token, - beams.data()); - // clang-format on + TF_CHECK_OK(CudaLaunchKernel( + GatherTreeOpKernel<T>, config.block_count, config.thread_per_block, 0, + d.stream(), batch_size, max_time, beam_width, step_ids.data(), + parent_ids.data(), max_sequence_length.data(), end_token, + beams.data())); } }; diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index d815f81f847ad79ddcc6c6ecf5c050598e185d8d..98e54db4584037a0cb6aea13bc6846f38007ecba 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -13,11 +13,9 @@ # limitations under the License.
# ============================================================================== """Tests for contrib.seq2seq.python.ops.attention_wrapper.""" -# pylint: disable=unused-import,g-bad-import-order from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: enable=unused-import import collections import functools @@ -30,6 +28,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py from tensorflow.contrib.seq2seq.python.ops import basic_decoder from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -66,6 +65,7 @@ def get_result_summary(x): return x +@test_util.run_v1_only class AttentionWrapperTest(test.TestCase): def assertAllCloseOrEqual(self, x, y, **kwargs): @@ -358,7 +358,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00597103), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.6)) + shape=(5, 3), dtype=dtype('int32'), mean=1.4)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -387,7 +387,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333)) + shape=(5, 3), dtype=dtype('int32'), mean=1.4)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -454,7 +454,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333333333)) + shape=(5, 3), 
dtype=dtype('int32'), mean=1.4)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -696,7 +696,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0025896581), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.6)) + shape=(5, 3), dtype=dtype('int32'), mean=1.73333333)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -707,12 +707,12 @@ class AttentionWrapperTest(test.TestCase): shape=(5, 6), dtype=dtype('float32'), mean=-0.00069823361), time=3, alignments=ResultSummary( - shape=(5, 8), dtype=dtype('float32'), mean=0.028698336), + shape=(5, 8), dtype=dtype('float32'), mean=0.029914695), attention_state=ResultSummary( - shape=(5, 8), dtype=dtype('float32'), mean=0.028698336), + shape=(5, 8), dtype=dtype('float32'), mean=0.029914695), alignment_history=()) expected_final_alignment_history = ResultSummary( - shape=(3, 5, 8), dtype=dtype('float32'), mean=0.04865776002407074) + shape=(3, 5, 8), dtype=dtype('float32'), mean=0.0465225502849) self._testWithAttention( create_attention_mechanism, @@ -921,9 +921,9 @@ class AttentionWrapperTest(test.TestCase): expected_final_output = BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11723966), + shape=(5, 3, 20), dtype=dtype('float32'), mean=0.115853324533), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=7.266666666666667)) + shape=(5, 3), dtype=dtype('int32'), mean=8.6)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -931,7 +931,7 @@ class AttentionWrapperTest(test.TestCase): h=ResultSummary( shape=(5, 9), dtype=dtype('float32'), mean=-0.0018327223)), attention=ResultSummary( - shape=(5, 20), dtype=dtype('float32'), mean=0.11601614207), + shape=(5, 20), dtype=dtype('float32'), mean=0.11462739855), time=3, alignments=(ResultSummary( 
shape=(5, 8), dtype=dtype('float32'), mean=0.125), diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66a464dc2183f272215921d26f89ce282bbdf07b --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py @@ -0,0 +1,745 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for contrib.seq2seq.python.ops.attention_wrapper.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py +from tensorflow.python import keras +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.keras import initializers +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + +@test_util.run_all_in_graph_and_eager_modes +class AttentionMechanismTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + super(AttentionMechanismTest, self).setUp() + self.batch = 10 + self.timestep = 5 + self.memory_size = 6 + self.units = 8 + + self.memory = np.random.randn(self.batch, self.timestep, + self.memory_size).astype(np.float32) + self.query = np.random.randn(self.batch, self.units).astype(np.float32) + self.state = np.random.randn(self.batch, self.timestep).astype(np.float32) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_attention_shape_inference(self, attention_cls): + attention = attention_cls(self.units, self.memory) + attention_score = attention([self.query, self.state]) + self.assertLen(attention_score, 2) + self.assertEqual(attention_score[0].shape, (self.batch, self.timestep)) + 
self.assertEqual(attention_score[1].shape, (self.batch, self.timestep)) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_get_config(self, attention_cls): + attention = attention_cls(self.units, self.memory) + config = attention.get_config() + + attention_from_config = attention_cls.from_config(config) + config_from_clone = attention_from_config.get_config() + + self.assertDictEqual(config, config_from_clone) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_layer_output(self, attention_cls): + attention = attention_cls(self.units, self.memory) + score = attention([self.query, self.state]) + self.evaluate(variables.variables_initializer(attention.variables)) + + score_val = self.evaluate(score) + self.assertLen(score_val, 2) + self.assertEqual(score_val[0].shape, (self.batch, self.timestep)) + self.assertEqual(score_val[1].shape, (self.batch, self.timestep)) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_passing_memory_from_call(self, attention_cls): + attention = attention_cls(self.units, self.memory) + weights_before_query = attention.get_weights() + ref_score = attention([self.query, self.state]) + + self.evaluate(variables.global_variables_initializer()) + ref_score_val = self.evaluate(ref_score) + + all_weights = attention.get_weights() + config = attention.get_config() + # Simulate the twice invocation of calls here. 
+ attention_from_config = attention_cls.from_config(config) + attention_from_config.build(self.memory.shape) + attention_from_config.set_weights(weights_before_query) + attention_from_config(self.memory, setup_memory=True) + attention_from_config.build([self.query.shape, self.state.shape]) + attention_from_config.set_weights(all_weights) + score = attention_from_config([self.query, self.state]) + + score_val = self.evaluate(score) + self.assertAllClose(ref_score_val, score_val) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_save_load_layer(self, attention_cls): + vocab = 20 + embedding_dim = 6 + inputs = keras.layers.Input(shape=[self.timestep]) + encoder_input = keras.layers.Embedding( + vocab, embedding_dim, mask_zero=True)( + inputs) + encoder_output = keras.layers.LSTM( + self.memory_size, return_sequences=True)( + encoder_input) + + attention = attention_cls(self.units, encoder_output) + query = keras.layers.Input(shape=[self.units]) + state = keras.layers.Input(shape=[self.timestep]) + + score = attention([query, state]) + + x = np.random.randint(vocab, size=(self.batch, self.timestep)) + x_test = np.random.randint(vocab, size=(self.batch, self.timestep)) + y = np.random.randn(self.batch, self.timestep) + model = keras.models.Model([inputs, query, state], score) + model.compile("rmsprop", "mse") + model.fit([x, self.query, self.state], (y, y)) + y_ref = model.predict_on_batch([x_test, self.query, self.state]) + + config = model.get_config() + weights = model.get_weights() + loaded_model = keras.models.Model.from_config( + config, custom_objects={attention_cls.__name__: attention_cls}) + loaded_model.set_weights(weights) + + y = loaded_model.predict_on_batch([x_test, self.query, self.state]) + + self.assertAllClose(y_ref, y) + + # TODO(scottzhu): Add 
tests for model.compile(run_eagerly=True) + + +class ResultSummary( + collections.namedtuple("ResultSummary", ("shape", "dtype", "mean"))): + pass + + +def get_result_summary(x): + if isinstance(x, np.ndarray): + return ResultSummary(x.shape, x.dtype, x.mean()) + return x + + +@test_util.run_all_in_graph_and_eager_modes +class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): + + def assertAllCloseOrEqual(self, x, y, **kwargs): + if isinstance(x, np.ndarray) or isinstance(x, float): + return super(AttentionWrapperV2Test, self).assertAllClose( + x, y, atol=1e-3, **kwargs) + else: + self.assertAllEqual(x, y, **kwargs) + + def setUp(self): + super(AttentionWrapperV2Test, self).setUp() + self.batch = 64 + self.units = 128 + self.encoder_timestep = 10 + self.encoder_dim = 256 + self.decoder_timestep = 12 + self.encoder_outputs = np.random.randn(self.batch, self.encoder_timestep, + self.encoder_dim) + self.encoder_sequence_length = np.random.randint( + self.encoder_timestep, size=(self.batch,)).astype(np.int32) + self.decoder_inputs = np.random.randn(self.batch, self.decoder_timestep, + self.units) + self.decoder_sequence_length = np.random.randint( + self.decoder_timestep, size=(self.batch,)).astype(np.int32) + + def _testWithAttention(self, + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=3, + alignment_history=False, + expected_final_alignment_history=None, + attention_layer_size=6, + attention_layer=None, + create_query_layer=False, + create_memory_layer=True, + create_attention_kwargs=None): + attention_layer_sizes = ([attention_layer_size] + if attention_layer_size is not None else None) + attention_layers = ([attention_layer] + if attention_layer is not None else None) + self._testWithMaybeMultiAttention( + is_multi=False, + create_attention_mechanisms=[create_attention_mechanism], + expected_final_output=expected_final_output, + expected_final_state=expected_final_state, + 
attention_mechanism_depths=[attention_mechanism_depth], + alignment_history=alignment_history, + expected_final_alignment_history=expected_final_alignment_history, + attention_layer_sizes=attention_layer_sizes, + attention_layers=attention_layers, + create_query_layer=create_query_layer, + create_memory_layer=create_memory_layer, + create_attention_kwargs=create_attention_kwargs) + + def _testWithMaybeMultiAttention(self, + is_multi, + create_attention_mechanisms, + expected_final_output, + expected_final_state, + attention_mechanism_depths, + alignment_history=False, + expected_final_alignment_history=None, + attention_layer_sizes=None, + attention_layers=None, + create_query_layer=False, + create_memory_layer=True, + create_attention_kwargs=None): + # Allow is_multi to be True with a single mechanism to enable test for + # passing in a single mechanism in a list. + assert len(create_attention_mechanisms) == 1 or is_multi + encoder_sequence_length = [3, 2, 3, 1, 1] + decoder_sequence_length = [2, 0, 1, 2, 3] + batch_size = 5 + encoder_max_time = 8 + decoder_max_time = 4 + input_depth = 7 + encoder_output_depth = 10 + cell_depth = 9 + create_attention_kwargs = create_attention_kwargs or {} + + if attention_layer_sizes is not None: + # Compute sum of attention_layer_sizes. Use encoder_output_depth if None. + attention_depth = sum(attention_layer_size or encoder_output_depth + for attention_layer_size in attention_layer_sizes) + elif attention_layers is not None: + # Compute sum of attention_layers output depth. 
+ attention_depth = sum( + attention_layer.compute_output_shape( + [batch_size, cell_depth + encoder_output_depth]).dims[-1].value + for attention_layer in attention_layers) + else: + attention_depth = encoder_output_depth * len(create_attention_mechanisms) + + decoder_inputs = np.random.randn(batch_size, decoder_max_time, + input_depth).astype(np.float32) + encoder_outputs = np.random.randn(batch_size, encoder_max_time, + encoder_output_depth).astype(np.float32) + + attention_mechanisms = [] + for creator, depth in zip(create_attention_mechanisms, + attention_mechanism_depths): + # Create a memory layer with deterministic initializer to avoid randomness + # in the test between graph and eager. + if create_query_layer: + create_attention_kwargs["query_layer"] = keras.layers.Dense( + depth, kernel_initializer="ones", use_bias=False) + if create_memory_layer: + create_attention_kwargs["memory_layer"] = keras.layers.Dense( + depth, kernel_initializer="ones", use_bias=False) + + attention_mechanisms.append( + creator( + units=depth, + memory=encoder_outputs, + memory_sequence_length=encoder_sequence_length, + **create_attention_kwargs)) + + with self.cached_session(use_gpu=True): + attention_layer_size = attention_layer_sizes + attention_layer = attention_layers + if not is_multi: + if attention_layer_size is not None: + attention_layer_size = attention_layer_size[0] + if attention_layer is not None: + attention_layer = attention_layer[0] + cell = keras.layers.LSTMCell(cell_depth, + recurrent_activation="sigmoid", + kernel_initializer="ones", + recurrent_initializer="ones") + cell = wrapper.AttentionWrapper( + cell, + attention_mechanisms if is_multi else attention_mechanisms[0], + attention_layer_size=attention_layer_size, + alignment_history=alignment_history, + attention_layer=attention_layer) + if cell._attention_layers is not None: + for layer in cell._attention_layers: + if getattr(layer, "kernel_initializer") is None: + layer.kernel_initializer = 
initializers.glorot_uniform(seed=1337) + + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + initial_state = cell.get_initial_state( + dtype=dtypes.float32, batch_size=batch_size) + final_outputs, final_state, _ = my_decoder( + decoder_inputs, + initial_state=initial_state, + sequence_length=decoder_sequence_length) + + self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput) + self.assertIsInstance(final_state, wrapper.AttentionWrapperState) + + expected_time = ( + expected_final_state.time if context.executing_eagerly() else None) + self.assertEqual((batch_size, expected_time, attention_depth), + tuple(final_outputs.rnn_output.get_shape().as_list())) + self.assertEqual((batch_size, expected_time), + tuple(final_outputs.sample_id.get_shape().as_list())) + + self.assertEqual((batch_size, attention_depth), + tuple(final_state.attention.get_shape().as_list())) + self.assertEqual((batch_size, cell_depth), + tuple(final_state.cell_state[0].get_shape().as_list())) + self.assertEqual((batch_size, cell_depth), + tuple(final_state.cell_state[1].get_shape().as_list())) + + if alignment_history: + if is_multi: + state_alignment_history = [] + for history_array in final_state.alignment_history: + history = history_array.stack() + self.assertEqual((expected_time, batch_size, encoder_max_time), + tuple(history.get_shape().as_list())) + state_alignment_history.append(history) + state_alignment_history = tuple(state_alignment_history) + else: + state_alignment_history = final_state.alignment_history.stack() + self.assertEqual((expected_time, batch_size, encoder_max_time), + tuple(state_alignment_history.get_shape().as_list())) + nest.assert_same_structure(cell.state_size, + cell.zero_state(batch_size, dtypes.float32)) + # Remove the history from final_state for purposes of the + # remainder of the tests. 
+ final_state = final_state._replace(alignment_history=()) # pylint: disable=protected-access + else: + state_alignment_history = () + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "final_outputs": final_outputs, + "final_state": final_state, + "state_alignment_history": state_alignment_history, + }) + + final_output_info = nest.map_structure(get_result_summary, + eval_result["final_outputs"]) + final_state_info = nest.map_structure(get_result_summary, + eval_result["final_state"]) + print("final_output_info: ", final_output_info) + print("final_state_info: ", final_state_info) + + nest.map_structure(self.assertAllCloseOrEqual, expected_final_output, + final_output_info) + nest.map_structure(self.assertAllCloseOrEqual, expected_final_state, + final_state_info) + if alignment_history: # by default, the wrapper emits attention as output + final_alignment_history_info = nest.map_structure( + get_result_summary, eval_result["state_alignment_history"]) + print("final_alignment_history_info: ", final_alignment_history_info) + nest.map_structure( + self.assertAllCloseOrEqual, + # outputs are batch major but the stacked TensorArray is time major + expected_final_alignment_history, + final_alignment_history_info) + + # TODO(b/126893309): reenable np.float16 once the bug is fixed. 
+  @parameterized.parameters([np.float32, np.float64])
+  def testBahdanauNormalizedDType(self, dtype):
+    # Decodes with a BasicDecoderV2 over an LSTM cell wrapped in normalized
+    # Bahdanau attention, with memory and inputs cast to `dtype`, and checks
+    # that the decoder output carries that same dtype.
+    memory = self.encoder_outputs.astype(dtype)
+    inputs = self.decoder_inputs.astype(dtype)
+
+    bahdanau = wrapper.BahdanauAttentionV2(
+        units=self.units,
+        memory=memory,
+        memory_sequence_length=self.encoder_sequence_length,
+        normalize=True,
+        dtype=dtype)
+    lstm_cell = keras.layers.LSTMCell(
+        self.units, recurrent_activation="sigmoid")
+    attention_cell = wrapper.AttentionWrapper(lstm_cell, bahdanau)
+
+    decoder = basic_decoder.BasicDecoderV2(
+        cell=attention_cell, sampler=sampler_py.TrainingSampler())
+
+    # The initial state is requested in the dtype under test as well.
+    start_state = attention_cell.zero_state(
+        dtype=dtype, batch_size=self.batch)
+    outputs, state, _ = decoder(
+        inputs,
+        initial_state=start_state,
+        sequence_length=self.decoder_sequence_length)
+
+    self.assertIsInstance(outputs, basic_decoder.BasicDecoderOutput)
+    self.assertEqual(outputs.rnn_output.dtype, dtype)
+    self.assertIsInstance(state, wrapper.AttentionWrapperState)
+
+  # TODO(b/126893309): reenable np.float16 once the bug is fixed.
+ @parameterized.parameters([np.float32, np.float64]) + def testLuongScaledDType(self, dtype): + # Test case for GitHub issue 18099 + encoder_outputs = self.encoder_outputs.astype(dtype) + decoder_inputs = self.decoder_inputs.astype(dtype) + attention_mechanism = wrapper.LuongAttentionV2( + units=self.units, + memory=encoder_outputs, + memory_sequence_length=self.encoder_sequence_length, + scale=True, + dtype=dtype, + ) + cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid") + cell = wrapper.AttentionWrapper(cell, attention_mechanism) + + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + + final_outputs, final_state, _ = my_decoder( + decoder_inputs, + initial_state=cell.zero_state(dtype=dtype, batch_size=self.batch), + sequence_length=self.decoder_sequence_length) + self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput) + self.assertEqual(final_outputs.rnn_output.dtype, dtype) + self.assertIsInstance(final_state, wrapper.AttentionWrapperState) + + def testBahdanauNotNormalized(self): + create_attention_mechanism = wrapper.BahdanauAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones"} + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=0.051747426), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype(np.int32), mean=3.33333333)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=[ + ResultSummary( + shape=(5, 9), dtype=np.dtype(np.float32), mean=0.44189346), + ResultSummary( + shape=(5, 9), dtype=np.dtype(np.float32), mean=0.65429491)], + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype(np.float32), mean=0.073610783), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125), + alignment_history=()) + 
expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype(np.float32), mean=0.125) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + alignment_history=True, + create_query_layer=True, + expected_final_alignment_history=expected_final_alignment_history, + create_attention_kwargs=create_attention_kwargs) + + def testBahdanauNormalized(self): + create_attention_mechanism = wrapper.BahdanauAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.047594748), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.6)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=[ + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.41311637), + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.61683208)], + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=0.090581432), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testLuongNotNormalized(self): + create_attention_mechanism = wrapper.LuongAttentionV2 + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.05481226), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=[ + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.38453412), + 
ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.5785929)], + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=0.16311775), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9) + + def testLuongScaled(self): + create_attention_mechanism = wrapper.LuongAttentionV2 + create_attention_kwargs = {"scale": True} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.05481226), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=[ + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.38453412), + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.5785929)], + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=0.16311775), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9, + create_attention_kwargs=create_attention_kwargs) + + def testNotUseAttentionLayer(self): + create_attention_mechanism = wrapper.BahdanauAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones"} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 10), dtype=np.dtype("float32"), mean=0.072406612), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666)) + 
expected_final_state = wrapper.AttentionWrapperState( + cell_state=[ + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742), + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002)], + attention=ResultSummary( + shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_layer_size=None, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testBahdanauMonotonicNotNormalized(self): + create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones"} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.041342419), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.53333333)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=[ + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.33866978), + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.46913195)], + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=0.092498459), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.12079944), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.12079944), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.121448785067) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + alignment_history=True, + 
expected_final_alignment_history=expected_final_alignment_history, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testBahdanauMonotonicNormalized(self): + create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones", + "normalize": True} + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.043294173), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.53333333)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=[ + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.40034312), + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.5925445)], + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=0.096119694), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.1211452), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.1211452), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12258384) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testLuongMonotonicNotNormalized(self): + create_attention_mechanism = wrapper.LuongMonotonicAttentionV2 + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.027387079), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.133333333)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=[ + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), 
mean=0.32660431), + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.52464348)], + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=0.089345723), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12194442004) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9, + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history) + + def testLuongMonotonicScaled(self): + create_attention_mechanism = wrapper.LuongMonotonicAttentionV2 + create_attention_kwargs = {"scale": True} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.027387079), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=[ + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.32660431), + ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.52464348)], + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=0.089345723), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12194442004) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9, + alignment_history=True, + 
expected_final_alignment_history=expected_final_alignment_history, + create_attention_kwargs=create_attention_kwargs) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index abcf71c61b6e6df9462bf06323b8b11d5cc0d9a8..599abf5a361fa6e2067cd18725a9a471add8ddeb 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -13,31 +13,30 @@ # limitations under the License. # ============================================================================== """Tests for contrib.seq2seq.python.seq2seq.basic_decoder.""" -# pylint: disable=unused-import,g-bad-import-order from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: enable=unused-import import numpy as np -from tensorflow.contrib.seq2seq.python.ops import helper as helper_py from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import helper as helper_py from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell -from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test -# pylint: enable=g-import-not-at-top +@test_util.run_v1_only class BasicDecoderTest(test.TestCase): def _testStepWithTrainingHelper(self, use_output_layer): diff --git 
a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2341ebb77ab6ecad1e979bc8bed0080128a804da --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py @@ -0,0 +1,670 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder_v2.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.layers import core as layers_core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +@keras_parameterized.run_all_keras_modes +class 
BasicDecoderTest(keras_parameterized.TestCase): + """Unit test for basic_decoder.BasicDecoderV2.""" + + @parameterized.named_parameters( + ("use_output_layer", True), + ("without_output_layer", False)) + def testStepWithTrainingHelperOutputLayer(self, use_output_layer): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = 10 + output_layer_depth = 3 + + with self.cached_session(use_gpu=True): + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + cell = rnn_cell.LSTMCell(cell_depth) + sampler = sampler_py.TrainingSampler(time_major=False) + if use_output_layer: + output_layer = layers_core.Dense(output_layer_depth, use_bias=False) + expected_output_depth = output_layer_depth + else: + output_layer = None + expected_output_depth = cell_depth + initial_state = cell.zero_state(dtype=dtypes.float32, + batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler, + output_layer=output_layer) + + (first_finished, + first_inputs, + first_state) = my_decoder.initialize(input_t, + initial_state=initial_state, + sequence_length=sequence_length) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(expected_output_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, expected_output_depth), + step_outputs[0].get_shape()) 
+ self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + if use_output_layer: + # The output layer was accessed + self.assertEqual(len(output_layer.variables), 1) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + self.assertAllEqual([False, False, False, False, True], + eval_result["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + eval_result["step_finished"]) + self.assertEqual(output_dtype.sample_id, + eval_result["step_outputs"].sample_id.dtype) + self.assertAllEqual( + np.argmax(eval_result["step_outputs"].rnn_output, -1), + eval_result["step_outputs"].sample_id) + + def DISABLED_testStepWithGreedyEmbeddingHelper(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size # cell's logits must match vocabulary size + input_depth = 10 + start_tokens = np.random.randint(0, vocabulary_size, size=batch_size) + end_token = 1 + + with self.cached_session(use_gpu=True): + embeddings = np.random.randn(vocabulary_size, + input_depth).astype(np.float32) + embeddings_t = constant_op.constant(embeddings) + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.GreedyEmbeddingSampler() + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + 
embeddings_t, + start_tokens=start_tokens, + end_token=end_token, + initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + expected_sample_ids = np.argmax( + eval_result["step_outputs"].rnn_output, -1) + expected_step_finished = (expected_sample_ids == end_token) + expected_step_next_inputs = embeddings[expected_sample_ids] + self.assertAllEqual([False, False, False, False, False], + eval_result["first_finished"]) + self.assertAllEqual(expected_step_finished, eval_result["step_finished"]) + self.assertEqual(output_dtype.sample_id, + 
eval_result["step_outputs"].sample_id.dtype) + self.assertAllEqual(expected_sample_ids, + eval_result["step_outputs"].sample_id) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + + def testStepWithSampleEmbeddingHelper(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size # cell's logits must match vocabulary size + input_depth = 10 + np.random.seed(0) + start_tokens = np.random.randint(0, vocabulary_size, size=batch_size) + end_token = 1 + + with self.cached_session(use_gpu=True): + embeddings = np.random.randn(vocabulary_size, + input_depth).astype(np.float32) + embeddings_t = constant_op.constant(embeddings) + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.SampleEmbeddingSampler(seed=0) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + (first_finished, + first_inputs, + first_state) = my_decoder.initialize(embeddings_t, + start_tokens=start_tokens, + end_token=end_token, + initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), 
first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = (sample_ids == end_token) + expected_step_next_inputs = embeddings[sample_ids] + self.assertAllEqual(expected_step_finished, + eval_result["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + + def testStepWithScheduledEmbeddingTrainingHelper(self): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + vocabulary_size = 10 + + with self.cached_session(use_gpu=True): + inputs = np.random.randn( + batch_size, max_time, input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + embeddings = np.random.randn( + vocabulary_size, input_depth).astype(np.float32) + half = constant_op.constant(0.5) + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.ScheduledEmbeddingTrainingSampler( + sampling_probability=half, + time_major=False) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + input_t, sequence_length=sequence_length, embedding=embeddings, + initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = 
my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(vocabulary_size, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, vocabulary_size), + step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + first_state[0].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + first_state[1].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + step_state[0].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + step_state[1].get_shape()) + self.assertEqual((batch_size, input_depth), + step_next_inputs.get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + self.assertAllEqual([False, False, False, False, True], + eval_result["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + eval_result["step_finished"]) + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + batch_where_not_sampling = np.where(sample_ids == -1) + batch_where_sampling = np.where(sample_ids > -1) + self.assertAllClose( + 
eval_result["step_next_inputs"][batch_where_sampling], + embeddings[sample_ids[batch_where_sampling]]) + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_not_sampling], + np.squeeze(inputs[batch_where_not_sampling, 1], axis=0)) + + def _testStepWithScheduledOutputTrainingHelper( + self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = input_depth + if use_auxiliary_inputs: + auxiliary_input_depth = 4 + auxiliary_inputs = np.random.randn( + batch_size, max_time, auxiliary_input_depth).astype(np.float32) + else: + auxiliary_inputs = None + + with self.cached_session(use_gpu=True): + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + cell = rnn_cell.LSTMCell(cell_depth) + sampling_probability = constant_op.constant(sampling_probability) + + if use_next_inputs_fn: + def next_inputs_fn(outputs): + # Use deterministic function for test. 
+ samples = math_ops.argmax(outputs, axis=1) + return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32) + else: + next_inputs_fn = None + + sampler = sampler_py.ScheduledOutputTrainingSampler( + sampling_probability=sampling_probability, + time_major=False, + next_inputs_fn=next_inputs_fn) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + + (first_finished, + first_inputs, + first_state) = my_decoder.initialize(input_t, + sequence_length=sequence_length, + initial_state=initial_state, + auxiliary_inputs=auxiliary_inputs) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + + if use_next_inputs_fn: + output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output) + + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + + fetches = { + "batch_size": batch_size_t, + "first_finished": 
first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + } + if use_next_inputs_fn: + fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn + + eval_result = self.evaluate(fetches) + + self.assertAllEqual([False, False, False, False, True], + eval_result["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + eval_result["step_finished"]) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + batch_where_not_sampling = np.where(np.logical_not(sample_ids)) + batch_where_sampling = np.where(sample_ids) + + auxiliary_inputs_to_concat = ( + auxiliary_inputs[:, 1] if use_auxiliary_inputs else + np.array([]).reshape(batch_size, 0).astype(np.float32)) + + expected_next_sampling_inputs = np.concatenate( + (eval_result["output_after_next_inputs_fn"][batch_where_sampling] + if use_next_inputs_fn else + eval_result["step_outputs"].rnn_output[batch_where_sampling], + auxiliary_inputs_to_concat[batch_where_sampling]), + axis=-1) + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_sampling], + expected_next_sampling_inputs) + + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_not_sampling], + np.concatenate( + (np.squeeze(inputs[batch_where_not_sampling, 1], axis=0), + auxiliary_inputs_to_concat[batch_where_not_sampling]), + axis=-1)) + + def testStepWithScheduledOutputTrainingHelperWithoutNextInputsFnOrAuxInputs( + self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=False, + use_auxiliary_inputs=False) + + def testStepWithScheduledOutputTrainingHelperWithNextInputsFn(self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=True, + use_auxiliary_inputs=False) + + def 
testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=False, + use_auxiliary_inputs=True) + + def testStepWithScheduledOutputTrainingHelperWithNextInputsFnAndAuxInputs( + self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=True, + use_auxiliary_inputs=True) + + def testStepWithScheduledOutputTrainingHelperWithNoSampling(self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.0, use_next_inputs_fn=True, + use_auxiliary_inputs=True) + + def testStepWithInferenceHelperCategorical(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size + start_token = 0 + end_token = 6 + + start_inputs = array_ops.one_hot( + np.ones(batch_size, dtype=np.int32) * start_token, + vocabulary_size) + + # The sample function samples categorically from the logits. + sample_fn = lambda x: sampler_py.categorical_sample(logits=x) + # The next inputs are a one-hot encoding of the sampled labels. 
+ next_inputs_fn = ( + lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)) + end_fn = lambda sample_ids: math_ops.equal(sample_ids, end_token) + + with self.cached_session(use_gpu=True): + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.InferenceSampler( + sample_fn, sample_shape=(), sample_dtype=dtypes.int32, end_fn=end_fn, + next_inputs_fn=next_inputs_fn) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + start_inputs, initial_state=initial_state) + + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": 
first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = (sample_ids == end_token) + expected_step_next_inputs = np.zeros((batch_size, vocabulary_size)) + expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0 + self.assertAllEqual(expected_step_finished, + eval_result["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + + def testStepWithInferenceHelperMultilabel(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size + start_token = 0 + end_token = 6 + + start_inputs = array_ops.one_hot( + np.ones(batch_size, dtype=np.int32) * start_token, + vocabulary_size) + + # The sample function samples independent bernoullis from the logits. + sample_fn = ( + lambda x: sampler_py.bernoulli_sample(logits=x, dtype=dtypes.bool)) + # The next inputs are a one-hot encoding of the sampled labels. 
+ next_inputs_fn = math_ops.to_float + end_fn = lambda sample_ids: sample_ids[:, end_token] + + with self.cached_session(use_gpu=True): + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.InferenceSampler( + sample_fn, sample_shape=[cell_depth], sample_dtype=dtypes.bool, + end_fn=end_fn, next_inputs_fn=next_inputs_fn) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + start_inputs, initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, cell_depth), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, 
+ "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = sample_ids[:, end_token] + expected_step_next_inputs = sample_ids.astype(np.float32) + self.assertAllEqual(expected_step_finished, + eval_result["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 5e28e651c666b1c448f778fc9c02d637ce817bae..8c84cd13588b624d8a50ca2dc3e4432cf6422473 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -25,10 +25,13 @@ from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops @@ -184,14 +187,23 @@ class TestArrayShapeChecks(test.TestCase): shape=dynamic_shape) batch_size = array_ops.constant(batch_size) - check_op = beam_search_decoder._check_batch_beam(t, batch_size, beam_width) # pylint: disable=protected-access - with self.cached_session() as sess: - if is_valid: - 
sess.run(check_op) + def _test_body(): + # pylint: disable=protected-access + if context.executing_eagerly(): + beam_search_decoder._check_batch_beam(t, batch_size, beam_width) else: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(check_op) + with self.cached_session(): + check_op = beam_search_decoder._check_batch_beam( + t, batch_size, beam_width) + self.evaluate(check_op) + # pylint: enable=protected-access + + if is_valid: + _test_body() + else: + with self.assertRaises(errors.InvalidArgumentError): + _test_body() def test_array_shape_dynamic_checks(self): self._test_array_shape_dynamic_checks( @@ -460,6 +472,7 @@ class TestLargeBeamStep(test.TestCase): self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]]) +@test_util.run_v1_only class BeamSearchDecoderTest(test.TestCase): def _testDynamicDecodeRNN(self, time_major, has_attention, @@ -530,11 +543,10 @@ class BeamSearchDecoderTest(test.TestCase): return (shape[1], shape[0]) + shape[2:] return shape - self.assertTrue( - isinstance(final_outputs, - beam_search_decoder.FinalBeamSearchDecoderOutput)) - self.assertTrue( - isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)) + self.assertIsInstance( + final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput) + self.assertIsInstance( + final_state, beam_search_decoder.BeamSearchDecoderState) beam_search_decoder_output = final_outputs.beam_search_decoder_output self.assertEqual( @@ -574,5 +586,119 @@ class BeamSearchDecoderTest(test.TestCase): with_alignment_history=True) +@test_util.run_all_in_graph_and_eager_modes +class BeamSearchDecoderV2Test(test.TestCase): + + def _testDynamicDecodeRNN(self, time_major, has_attention, + with_alignment_history=False): + encoder_sequence_length = np.array([3, 2, 3, 1, 1]) + decoder_sequence_length = np.array([2, 0, 1, 2, 3]) + batch_size = 5 + decoder_max_time = 4 + input_depth = 7 + cell_depth = 9 + attention_depth = 6 + vocab_size = 20 + end_token = vocab_size - 1 + 
start_token = 0 + embedding_dim = 50 + max_out = max(decoder_sequence_length) + output_layer = layers.Dense(vocab_size, use_bias=True, activation=None) + beam_width = 3 + + with self.cached_session(): + batch_size_tensor = constant_op.constant(batch_size) + embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32) + cell = rnn_cell.LSTMCell(cell_depth) + initial_state = cell.zero_state(batch_size, dtypes.float32) + coverage_penalty_weight = 0.0 + if has_attention: + coverage_penalty_weight = 0.2 + inputs = array_ops.placeholder_with_default( + np.random.randn(batch_size, decoder_max_time, input_depth).astype( + np.float32), + shape=(None, None, input_depth)) + tiled_inputs = beam_search_decoder.tile_batch( + inputs, multiplier=beam_width) + tiled_sequence_length = beam_search_decoder.tile_batch( + encoder_sequence_length, multiplier=beam_width) + attention_mechanism = attention_wrapper.BahdanauAttention( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + initial_state = beam_search_decoder.tile_batch( + initial_state, multiplier=beam_width) + cell = attention_wrapper.AttentionWrapper( + cell=cell, + attention_mechanism=attention_mechanism, + attention_layer_size=attention_depth, + alignment_history=with_alignment_history) + cell_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) + if has_attention: + cell_state = cell_state.clone(cell_state=initial_state) + bsd = beam_search_decoder.BeamSearchDecoderV2( + cell=cell, + beam_width=beam_width, + output_layer=output_layer, + length_penalty_weight=0.0, + coverage_penalty_weight=coverage_penalty_weight, + output_time_major=time_major, + maximum_iterations=max_out) + + final_outputs, final_state, final_sequence_lengths = bsd( + embedding, + start_tokens=array_ops.fill([batch_size_tensor], start_token), + end_token=end_token, + initial_state=cell_state) + + def _t(shape): + if time_major: + return (shape[1], 
shape[0]) + shape[2:] + return shape + + self.assertIsInstance( + final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput) + self.assertIsInstance( + final_state, beam_search_decoder.BeamSearchDecoderState) + + beam_search_decoder_output = final_outputs.beam_search_decoder_output + expected_seq_length = 3 if context.executing_eagerly() else None + self.assertEqual( + _t((batch_size, expected_seq_length, beam_width)), + tuple(beam_search_decoder_output.scores.get_shape().as_list())) + self.assertEqual( + _t((batch_size, expected_seq_length, beam_width)), + tuple(final_outputs.predicted_ids.get_shape().as_list())) + + self.evaluate(variables.global_variables_initializer()) + eval_results = self.evaluate({ + 'final_outputs': final_outputs, + 'final_sequence_lengths': final_sequence_lengths + }) + + max_sequence_length = np.max(eval_results['final_sequence_lengths']) + + # A smoke test + self.assertEqual( + _t((batch_size, max_sequence_length, beam_width)), + eval_results['final_outputs'].beam_search_decoder_output.scores.shape) + self.assertEqual( + _t((batch_size, max_sequence_length, beam_width)), eval_results[ + 'final_outputs'].beam_search_decoder_output.predicted_ids.shape) + + def testDynamicDecodeRNNBatchMajorNoAttention(self): + self._testDynamicDecodeRNN(time_major=False, has_attention=False) + + def testDynamicDecodeRNNBatchMajorYesAttention(self): + self._testDynamicDecodeRNN(time_major=False, has_attention=True) + + def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self): + self._testDynamicDecodeRNN( + time_major=False, + has_attention=True, + with_alignment_history=True) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py index b41734d214e98cd24be0c98ee67f7cb5e58b7a61..5506aa8b8ee259fbacc80bc310cd954bdb66482b 100644 --- 
a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -49,8 +49,8 @@ class GatherTreeTest(test.TestCase): parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=end_token) - with self.session(use_gpu=True): - self.assertAllEqual(expected_result, beams.eval()) + with self.cached_session(use_gpu=True): + self.assertAllEqual(expected_result, self.evaluate(beams)) def testBadParentValuesOnCPU(self): # (batch_size = 1, max_time = 4, beams = 3) @@ -62,15 +62,14 @@ class GatherTreeTest(test.TestCase): [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) max_sequence_lengths = [3] with ops.device("/cpu:0"): - beams = beam_search_ops.gather_tree( - step_ids=step_ids, - parent_ids=parent_ids, - max_sequence_lengths=max_sequence_lengths, - end_token=end_token) - with self.cached_session(): with self.assertRaisesOpError( r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"): - _ = beams.eval() + beams = beam_search_ops.gather_tree( + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) + self.evaluate(beams) def testBadParentValuesOnGPU(self): # Only want to run this test on CUDA devices, as gather_tree is not @@ -93,8 +92,7 @@ class GatherTreeTest(test.TestCase): parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=end_token) - with self.session(use_gpu=True): - self.assertAllEqual(expected_result, beams.eval()) + self.assertAllEqual(expected_result, self.evaluate(beams)) def testGatherTreeBatch(self): batch_size = 10 @@ -103,7 +101,7 @@ class GatherTreeTest(test.TestCase): max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0] end_token = 5 - with self.session(use_gpu=True): + with self.cached_session(use_gpu=True): step_ids = np.random.randint( 0, high=end_token + 1, size=(max_time, batch_size, beam_width)) parent_ids = np.random.randint( @@ -116,7 +114,7 @@ class 
GatherTreeTest(test.TestCase): end_token=end_token) self.assertEqual((max_time, batch_size, beam_width), beams.shape) - beams_value = beams.eval() + beams_value = self.evaluate(beams) for b in range(batch_size): # Past max_sequence_lengths[b], we emit all end tokens. b_value = beams_value[max_sequence_lengths[b]:, b, :] diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py index 4c25489fade320f2f2218354343021a71af01baf..4a420221e27775c2844daaba6b6d2b3b2ce38828 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py @@ -13,26 +13,25 @@ # limitations under the License. # ============================================================================== """Tests for contrib.seq2seq.python.seq2seq.decoder.""" -# pylint: disable=unused-import,g-bad-import-order from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: enable=unused-import import numpy as np +from tensorflow.contrib.seq2seq.python.ops import basic_decoder from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.contrib.seq2seq.python.ops import helper as helper_py -from tensorflow.contrib.seq2seq.python.ops import basic_decoder from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import test -# pylint: enable=g-import-not-at-top +@test_util.run_v1_only class DynamicDecodeRNNTest(test.TestCase): def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None): diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_v2_test.py 
b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5bba2b32e940aa4d5984821ebd3845d7f272549 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_v2_test.py @@ -0,0 +1,169 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.seq2seq.python.seq2seq.decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +@keras_parameterized.run_all_keras_modes +class DecodeV2RNNTest(keras_parameterized.TestCase, test.TestCase): + """Tests for DecoderV2.""" + + def _testDecodeRNN(self, time_major, maximum_iterations=None): + + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth 
= 10 + max_out = max(sequence_length) + + with self.cached_session(use_gpu=True): + if time_major: + inputs = np.random.randn(max_time, batch_size, + input_depth).astype(np.float32) + else: + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + cell = rnn_cell.LSTMCell(cell_depth) + sampler = sampler_py.TrainingSampler(time_major=time_major) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler, + output_time_major=time_major, + maximum_iterations=maximum_iterations) + + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + (final_outputs, unused_final_state, final_sequence_length) = my_decoder( + input_t, initial_state=initial_state, sequence_length=sequence_length) + + def _t(shape): + if time_major: + return (shape[1], shape[0]) + shape[2:] + return shape + + if not context.executing_eagerly(): + self.assertEqual((batch_size,), + tuple(final_sequence_length.get_shape().as_list())) + self.assertEqual( + _t((batch_size, None, cell_depth)), + tuple(final_outputs.rnn_output.get_shape().as_list())) + self.assertEqual( + _t((batch_size, None)), + tuple(final_outputs.sample_id.get_shape().as_list())) + + self.evaluate(variables.global_variables_initializer()) + final_outputs = self.evaluate(final_outputs) + final_sequence_length = self.evaluate(final_sequence_length) + + # Mostly a smoke test + time_steps = max_out + expected_length = sequence_length + if maximum_iterations is not None: + time_steps = min(max_out, maximum_iterations) + expected_length = [min(x, maximum_iterations) for x in expected_length] + if context.executing_eagerly() and maximum_iterations != 0: + # Only check the shape of output when maximum_iterations > 0, see + # b/123431432 for more details. 
+ self.assertEqual( + _t((batch_size, time_steps, cell_depth)), + final_outputs.rnn_output.shape) + self.assertEqual( + _t((batch_size, time_steps)), final_outputs.sample_id.shape) + self.assertItemsEqual(expected_length, final_sequence_length) + + def testDynamicDecodeRNNBatchMajor(self): + self._testDecodeRNN(time_major=False) + + def testDynamicDecodeRNNTimeMajor(self): + self._testDecodeRNN(time_major=True) + + def testDynamicDecodeRNNZeroMaxIters(self): + self._testDecodeRNN(time_major=True, maximum_iterations=0) + + def testDynamicDecodeRNNOneMaxIter(self): + self._testDecodeRNN(time_major=True, maximum_iterations=1) + + def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( + self, use_sequence_length): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = 10 + max_out = max(sequence_length) + + with self.cached_session(use_gpu=True): + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + inputs = constant_op.constant(inputs) + + cell = rnn_cell.LSTMCell(cell_depth) + zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size) + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, sampler=sampler, impute_finished=use_sequence_length) + + final_decoder_outputs, final_decoder_state, _ = my_decoder( + inputs, initial_state=zero_state, sequence_length=sequence_length) + + final_rnn_outputs, final_rnn_state = rnn.dynamic_rnn( + cell, + inputs, + sequence_length=sequence_length if use_sequence_length else None, + initial_state=zero_state) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "final_decoder_outputs": final_decoder_outputs, + "final_decoder_state": final_decoder_state, + "final_rnn_outputs": final_rnn_outputs, + "final_rnn_state": final_rnn_state + }) + + # Decoder only runs out to max_out; ensure values are identical + # to dynamic_rnn, which also zeros out outputs and 
passes along state. + self.assertAllClose(eval_result["final_decoder_outputs"].rnn_output, + eval_result["final_rnn_outputs"][:, 0:max_out, :]) + if use_sequence_length: + self.assertAllClose(eval_result["final_decoder_state"], + eval_result["final_rnn_state"]) + + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNWithSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( + use_sequence_length=True) + + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNNoSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( + use_sequence_length=False) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py index 41b2a53ca5b178be9b04446c81d832575e5ed75b..7eb544a921c595c667083b783757f4b719be5aa0 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import test @test_util.run_all_in_graph_and_eager_modes class LossTest(test.TestCase): - def setUp(self): + def config_default_values(self): self.batch_size = 2 self.sequence_length = 3 self.number_of_classes = 5 @@ -56,7 +56,8 @@ class LossTest(test.TestCase): self.expected_loss = 1.60944 def testSequenceLoss(self): - with self.test_session(use_gpu=True): + self.config_default_values() + with self.cached_session(use_gpu=True): average_loss_per_example = loss.sequence_loss( self.logits, self.targets, self.weights, average_across_timesteps=True, @@ -90,7 +91,8 @@ class LossTest(test.TestCase): self.assertAllClose(compare_total, res) def testSequenceLossClass(self): - with self.test_session(use_gpu=True): + self.config_default_values() + with self.cached_session(use_gpu=True): seq_loss = loss.SequenceLoss(average_across_timesteps=True, average_across_batch=True, sum_over_timesteps=False, @@ -132,7 +134,8 @@ class 
LossTest(test.TestCase): self.assertAllClose(compare_total, res) def testSumReduction(self): - with self.test_session(use_gpu=True): + self.config_default_values() + with self.cached_session(use_gpu=True): seq_loss = loss.SequenceLoss(average_across_timesteps=False, average_across_batch=False, sum_over_timesteps=True, @@ -174,6 +177,7 @@ class LossTest(test.TestCase): self.assertAllClose(compare_total, res) def testWeightedSumReduction(self): + self.config_default_values() weights = [ constant_op.constant(1.0, shape=[self.batch_size]) for _ in range(self.sequence_length) @@ -181,7 +185,7 @@ class LossTest(test.TestCase): # Make the last element in the sequence to have zero weights. weights[-1] = constant_op.constant(0.0, shape=[self.batch_size]) self.weights = array_ops.stack(weights, axis=1) - with self.test_session(use_gpu=True): + with self.cached_session(use_gpu=True): seq_loss = loss.SequenceLoss(average_across_timesteps=False, average_across_batch=False, sum_over_timesteps=True, @@ -225,12 +229,13 @@ class LossTest(test.TestCase): self.assertAllClose(compare_total, res) def testZeroWeights(self): + self.config_default_values() weights = [ constant_op.constant(0.0, shape=[self.batch_size]) for _ in range(self.sequence_length) ] weights = array_ops.stack(weights, axis=1) - with self.test_session(use_gpu=True): + with self.cached_session(use_gpu=True): average_loss_per_example = loss.sequence_loss( self.logits, self.targets, weights, average_across_timesteps=True, diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 31c62d5849ac3bdb35cbd00f03b298cb5952162c..577a3efbd7da58e8931c6668af2c2f1be91e7298 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -25,10 +25,13 @@ import math import numpy as np from tensorflow.contrib.framework.python.framework import tensor_util +from 
tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras import initializers +from tensorflow.python.keras import layers +from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.layers import base as layers_base from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops @@ -218,25 +221,46 @@ class _BaseAttentionMechanism(AttentionMechanism): return self.initial_alignments(batch_size, dtype) -class _BaseAttentionMechanismV2(AttentionMechanism, Layer): +class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): """A base AttentionMechanism class providing common functionality. Common functionality includes: 1. Storing the query and memory layers. 2. Preprocessing and storing the memory. - Note that this layer only support Keras functional API since it takes multiple - input tensors, which is not available in sequential model. + Note that this layer takes memory as its init parameter, which is an + anti-pattern of Keras API, we have to keep the memory as init parameter for + performance and dependency reason. Under the hood, during `__init__()`, it + will invoke `base_layer.__call__(memory, setup_memory=True)`. This will let + keras to keep track of the memory tensor as the input of this layer. Once + the `__init__()` is done, then user can query the attention by + `score = att_obj([query, state])`, and use it as a normal keras layer. + + Special attention is needed when adding using this class as the base layer for + new attention: + 1. Build() could be invoked at least twice. So please make sure weights are + not duplicated. + 2. Layer.get_weights() might return different set of weights if the instance + has `query_layer`. 
The query_layer weights is not initialized until the + memory is configured. + + Also note that this layer does not work with Keras model when + `model.compile(run_eagerly=True)` due to the fact that this layer is stateful. + The support for that will be added in a future version. """ def __init__(self, + memory, probability_fn, query_layer=None, memory_layer=None, + memory_sequence_length=None, **kwargs): """Construct base AttentionMechanism class. Args: + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. probability_fn: A `callable`. Converts the score and previous alignments to probabilities. Its signature should be: `probabilities = probability_fn(score, state)`. @@ -247,15 +271,18 @@ class _BaseAttentionMechanismV2(AttentionMechanism, Layer): depth must match the depth of `query_layer`. If `memory_layer` is not provided, the shape of `memory` must match that of `query_layer`. + memory_sequence_length (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. **kwargs: Dictionary that contains other common arguments for layer creation. 
""" if (query_layer is not None - and not isinstance(query_layer, layers_base.Layer)): + and not isinstance(query_layer, layers.Layer)): raise TypeError( "query_layer is not a Layer: %s" % type(query_layer).__name__) if (memory_layer is not None - and not isinstance(memory_layer, layers_base.Layer)): + and not isinstance(memory_layer, layers.Layer)): raise TypeError( "memory_layer is not a Layer: %s" % type(memory_layer).__name__) self.query_layer = query_layer @@ -273,18 +300,127 @@ class _BaseAttentionMechanismV2(AttentionMechanism, Layer): self.batch_size = None self._memory_initialized = False self._check_inner_dims_defined = True + self.supports_masking = True + self.score_mask_value = dtypes.as_dtype(self.dtype).as_numpy_dtype(-np.inf) + + if memory is not None: + # Setup the memory by self.__call__() with memory and memory_seq_length. + # This will make the attention follow the keras convention which takes + # all the tensor inputs via __call__(). + if memory_sequence_length is None: + inputs = memory + else: + inputs = [memory, memory_sequence_length] + + self.values = super(_BaseAttentionMechanismV2, self).__call__( + inputs, setup_memory=True) def build(self, input_shape): - if self.query_layer is not None: - self.query_layer.build(input_shape) - if self.memory_layer is not None: - self.memory_layer.build(input_shape) - # dtype of the layer is known at this moment, create the score_mask_value if - # needed. - self.score_mask_value = dtypes.as_dtype(self.dtype).as_numpy_dtype(-np.inf) - self.built = True + if not self._memory_initialized: + # This is for setting up the memory, which contains memory and optional + # memory_sequence_length. Build the memory_layer with memory shape. + if self.memory_layer is not None and not self.memory_layer.built: + if isinstance(input_shape, list): + self.memory_layer.build(input_shape[0]) + else: + self.memory_layer.build(input_shape) + else: + # The input_shape should be query.shape and state.shape. 
Use the query + # to init the query layer. + if self.query_layer is not None and not self.query_layer.built: + self.query_layer.build(input_shape[0]) + + def __call__(self, inputs, **kwargs): + """Preprocess the inputs before calling `base_layer.__call__()`. + + Note that there are situation here, one for setup memory, and one with + actual query and state. + 1. When the memory has not been configured, we just pass all the param to + base_layer.__call__(), which will then invoke self.call() with proper + inputs, which allows this class to setup memory. + 2. When the memory has already been setup, the input should contain query + and state, and optionally processed memory. If the processed memory is + not included in the input, we will have to append it to the inputs and + give it to the base_layer.__call__(). The processed memory is the output + of first invocation of self.__call__(). If we don't add it here, then from + keras perspective, the graph is disconnected since the output from + previous call is never used. - def _setup_memory(self, memory, memory_mask=None): + Args: + inputs: the inputs tensors. + **kwargs: dict, other keyeword arguments for the `__call__()` + """ + if self._memory_initialized: + if len(inputs) not in (2, 3): + raise ValueError("Expect the inputs to have 2 or 3 tensors, got %d" % + len(inputs)) + if len(inputs) == 2: + # We append the calculated memory here so that the graph will be + # connected. + inputs.append(self.values) + return super(_BaseAttentionMechanismV2, self).__call__(inputs, **kwargs) + + def call(self, inputs, mask=None, setup_memory=False, **kwargs): + """Setup the memory or query the attention. + + There are two case here, one for setup memory, and the second is query the + attention score. `setup_memory` is the flag to indicate which mode it is. + The input list will be treated differently based on that flag. 
+ + Args: + inputs: a list of tensor that could either be `query` and `state`, or + `memory` and `memory_sequence_length`. + `query` is the tensor of dtype matching `memory` and shape + `[batch_size, query_depth]`. + `state` is the tensor of dtype matching `memory` and shape + `[batch_size, alignments_size]`. (`alignments_size` is memory's + `max_time`). + `memory` is the memory to query; usually the output of an RNN encoder. + The tensor should be shaped `[batch_size, max_time, ...]`. + `memory_sequence_length` (optional) is the sequence lengths for the + batch entries in memory. If provided, the memory tensor rows are masked + with zeros for values past the respective sequence lengths. + mask: optional bool tensor with shape `[batch, max_time]` for the mask of + memory. If it is not None, the corresponding item of the memory should + be filtered out during calculation. + setup_memory: boolean, whether the input is for setting up memory, or + query attention. + **kwargs: Dict, other keyword arguments for the call method. + Returns: + Either processed memory or attention score, based on `setup_memory`. + """ + if setup_memory: + if isinstance(inputs, list): + if len(inputs) not in (1, 2): + raise ValueError("Expect inputs to have 1 or 2 tensors, got %d" % + len(inputs)) + memory = inputs[0] + memory_sequence_length = inputs[1] if len(inputs) == 2 else None + memory_mask = mask + else: + memory, memory_sequence_length = inputs, None + memory_mask = mask + self._setup_memory(memory, memory_sequence_length, memory_mask) + # We force the self.built to false here since only memory is initialized, + # but the real query/state has not been call() yet. The layer should be + # build and call again. + self.built = False + # Return the processed memory in order to create the Keras connectivity + # data for it. 
+ return self.values + else: + if not self._memory_initialized: + raise ValueError("Cannot query the attention before the setup of " + "memory") + if len(inputs) not in (2, 3): + raise ValueError("Expect the inputs to have query, state, and optional " + "processed memory, got %d items" % len(inputs)) + # Ignore the rest of the inputs and only care about the query and state + query, state = inputs[0], inputs[1] + return self._calculate_attention(query, state) + + def _setup_memory(self, memory, memory_sequence_length=None, + memory_mask=None): """Pre-process the memory before actually query the memory. This should only be called once at the first invocation of call(). @@ -292,17 +428,30 @@ class _BaseAttentionMechanismV2(AttentionMechanism, Layer): Args: memory: The memory to query; usually the output of an RNN encoder. This tensor should be shaped `[batch_size, max_time, ...]`. - memory_mask: The boolean tensor with shape `[batch_size, max_time]`. For - any value equal to False, the corresponding value in memory should be - ignored. + memory_sequence_length (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros for + values past the respective sequence lengths. + memory_mask: (Optional) The boolean tensor with shape `[batch_size, + max_time]`. For any value equal to False, the corresponding value in + memory should be ignored. 
""" if self._memory_initialized: raise ValueError("The memory for the attention has already been setup.") + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError("memory_sequence_length and memory_mask cannot be " + "used at same time for attention.") with ops.name_scope( self.name, "BaseAttentionMechanismInit", nest.flatten(memory)): self.values = _prepare_memory( - memory, memory_mask=memory_mask, + memory, + memory_sequence_length=memory_sequence_length, + memory_mask=memory_mask, check_inner_dims_defined=self._check_inner_dims_defined) + # Mark the value as check since the memory and memory mask might not + # passed from __call__(), which does not have proper keras metadata. + # TODO(omalleyt): Remove this hack once the mask the has proper keras + # history. + base_layer_utils.mark_checked(self.values) if self.memory_layer is not None: self.keys = self.memory_layer(self.values) else: @@ -310,39 +459,28 @@ class _BaseAttentionMechanismV2(AttentionMechanism, Layer): self.batch_size = ( tensor_shape.dimension_value(self.keys.shape[0]) or array_ops.shape(self.keys)[0]) - self.alignments_size = (tensor_shape.dimension_value(self.keys.shape[1]) - or array_ops.shape(self.keys)[1]) + self._alignments_size = (tensor_shape.dimension_value(self.keys.shape[1]) + or array_ops.shape(self.keys)[1]) if memory_mask is not None: - self.probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda - self.probability_fn(_maybe_mask_score( - score, self.score_mask_value, memory_mask=memory_mask), prev)) + unwrapped_probability_fn = self.probability_fn + def _mask_probability_fn(score, prev): + return unwrapped_probability_fn( + _maybe_mask_score( + score, + memory_mask=memory_mask, + memory_sequence_length=memory_sequence_length, + score_mask_value=self.score_mask_value), prev) + self.probability_fn = _mask_probability_fn self._memory_initialized = True - def call(self, inputs, mask=None, **kwargs): - """Base method to calculate the attention 
score. - - Args: - inputs: a list of tensor that contains `query`, `state`, and `memory`. - `query` is the tensor of dtype matching `memory` and shape - `[batch_size, query_depth]`. - `state` is the tensor of dtype matching `memory` and shape - `[batch_size, alignments_size]`. (`alignments_size` is memory's - `max_time`). - `memory` is the memory to query; usually the output of an RNN encoder. - This tensor should be shaped `[batch_size, max_time, feature]`. - mask: optional bool tensor with shape `[batch, max_time]` for the mask of - memory. If it is not None, the corresponding item of the memory should - be filtered out during calculation. - **kwargs: Dict, other keyword arguments for the call method. - """ - query, state, memory, memory_mask = self._process_inputs(inputs, mask) - if not self._memory_initialized: - self._setup_memory(memory, memory_mask=memory_mask) - return self.calculate_attention(query, state) - - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): raise NotImplementedError( - "calculate_attention need to be implemented by subclasses.") + "_calculate_attention need to be implemented by subclasses.") + + def compute_mask(self, inputs, mask=None): + # There real input of the attention is query and state, and the memory layer + # mask shouldn't be pass down. Returning None for all output mask here. + return None, None def get_config(self): config = {} @@ -359,16 +497,12 @@ class _BaseAttentionMechanismV2(AttentionMechanism, Layer): "class_name": self.memory_layer.__class__.__name__, "config": self.memory_layer.get_config(), } + # memory is a required init parameter and its a tensor. It cannot be + # serialized to config, so we put a placeholder for it. 
+ config["memory"] = None base_config = super(_BaseAttentionMechanismV2, self).get_config() return dict(list(base_config.items()) + list(config.items())) - def _process_inputs(self, inputs, mask): - if len(inputs) != 3: - raise ValueError( - "Expect to have 3 inputs for attention, got %d" % len(inputs)) - query, state, memory = inputs - return query, state, memory, mask - def _process_probability_fn(self, func_name): """Helper method to retrieve the probably function by string input.""" valid_probability_fns = { @@ -398,6 +532,8 @@ class _BaseAttentionMechanismV2(AttentionMechanism, Layer): """ # Reconstruct the query and memory layer for parent class. from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + # Instead of updating the input, create a copy and use that. + config = config.copy() query_layer_config = config.pop("query_layer", None) if query_layer_config: query_layer = deserialize_layer(query_layer_config, @@ -410,6 +546,50 @@ class _BaseAttentionMechanismV2(AttentionMechanism, Layer): config["memory_layer"] = memory_layer return config + @property + def alignments_size(self): + return self._alignments_size + + @property + def state_size(self): + return self._alignments_size + + def initial_alignments(self, batch_size, dtype): + """Creates the initial alignment values for the `AttentionWrapper` class. + + This is important for AttentionMechanisms that use the previous alignment + to calculate the alignment at the next time step (e.g. monotonic attention). + + The default behavior is to return a tensor of all zeros. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A `dtype` tensor shaped `[batch_size, alignments_size]` + (`alignments_size` is the values' `max_time`). 
+ """ + max_time = self._alignments_size + return _zero_state_tensors(max_time, batch_size, dtype) + + def initial_state(self, batch_size, dtype): + """Creates the initial state values for the `AttentionWrapper` class. + + This is important for AttentionMechanisms that use the previous alignment + to calculate the alignment at the next time step (e.g. monotonic attention). + + The default behavior is to return the same output as initial_alignments. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A structure of all-zero tensors with shapes as described by `state_size`. + """ + return self.initial_alignments(batch_size, dtype) + def _luong_score(query, keys, scale): """Implements Luong-style (multiplicative) scoring function. @@ -579,6 +759,8 @@ class LuongAttentionV2(_BaseAttentionMechanismV2): def __init__(self, units, + memory, + memory_sequence_length=None, scale=False, probability_fn="softmax", dtype=None, @@ -588,6 +770,11 @@ class LuongAttentionV2(_BaseAttentionMechanismV2): Args: units: The depth of the attention mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. scale: Python boolean. Whether to scale the energy term. probability_fn: (optional) string, the name of function to convert the attention score to probabilities. 
The default is `softmax` which is @@ -606,27 +793,31 @@ class LuongAttentionV2(_BaseAttentionMechanismV2): wrapped_probability_fn = lambda score, _: probability_fn(score) if dtype is None: dtype = dtypes.float32 + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.scale = scale + self.scale_weight = None super(LuongAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, query_layer=None, - memory_layer=layers_core.Dense( - units, name="memory_layer", use_bias=False, dtype=dtype), + memory_layer=memory_layer, probability_fn=wrapped_probability_fn, name=name, dtype=dtype, **kwargs) - self.units = units - self.scale = scale def build(self, input_shape): super(LuongAttentionV2, self).build(input_shape) - if self.scale: + if self.scale and self.scale_weight is None: self.scale_weight = self.add_weight( "attention_g", initializer=init_ops.ones_initializer, shape=()) - else: - self.scale_weight = None self.built = True - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): """Score the query based on the keys and values. Args: @@ -640,6 +831,7 @@ class LuongAttentionV2(_BaseAttentionMechanismV2): alignments: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]` (`alignments_size` is memory's `max_time`). + next_state: Same as the alignments. 
""" score = _luong_score(query, self.keys, self.scale_weight) alignments = self.probability_fn(score, state) @@ -839,8 +1031,11 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): def __init__(self, units, + memory, + memory_sequence_length=None, normalize=False, probability_fn="softmax", + kernel_initializer="glorot_uniform", dtype=None, name="BahdanauAttention", **kwargs): @@ -848,12 +1043,19 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): Args: units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. normalize: Python boolean. Whether to normalize the energy term. probability_fn: (optional) string, the name of function to convert the attention score to probabilities. The default is `softmax` which is `tf.nn.softmax`. Other options is `hardmax`, which is hardmax() within this module. Any other value will result into validation error. Default to use `softmax`. + kernel_initializer: (optional), the name of the initializer for the + attention kernel. dtype: The data type for the query and memory layers of the attention mechanism. name: Name to use when creating ops. 
@@ -865,35 +1067,47 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): wrapped_probability_fn = lambda score, _: probability_fn(score) if dtype is None: dtype = dtypes.float32 + query_layer = kwargs.pop("query_layer", None) + if not query_layer: + query_layer = layers.Dense( + units, name="query_layer", use_bias=False, dtype=dtype) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.normalize = normalize + self.kernel_initializer = initializers.get(kernel_initializer) + self.attention_v = None + self.attention_g = None + self.attention_b = None super(BahdanauAttentionV2, self).__init__( - query_layer=layers_core.Dense( - units, name="query_layer", use_bias=False, dtype=dtype), - memory_layer=layers_core.Dense( - units, name="memory_layer", use_bias=False, dtype=dtype), + memory=memory, + memory_sequence_length=memory_sequence_length, + query_layer=query_layer, + memory_layer=memory_layer, probability_fn=wrapped_probability_fn, name=name, dtype=dtype, **kwargs) - self.units = units - self.normalize = normalize def build(self, input_shape): super(BahdanauAttentionV2, self).build(input_shape) - self.attention_v = self.add_weight( - "attention_v", [self.units], dtype=self.dtype) - if self.normalize: + if self.attention_v is None: + self.attention_v = self.add_weight( + "attention_v", [self.units], + dtype=self.dtype, + initializer=self.kernel_initializer) + if self.normalize and self.attention_g is None and self.attention_b is None: self.attention_g = self.add_weight( "attention_g", initializer=init_ops.constant_initializer( math.sqrt((1. 
/ self.units))), shape=()) self.attention_b = self.add_weight( "attention_b", shape=[self.units], initializer=init_ops.zeros_initializer()) - else: - self.attention_g = None - self.attention_b = None self.built = True - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): """Score the query based on the keys and values. Args: @@ -907,6 +1121,7 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): alignments: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]` (`alignments_size` is memory's `max_time`). + next_state: same as alignments. """ processed_query = self.query_layer(query) if self.query_layer else query score = _bahdanau_score(processed_query, self.keys, self.attention_v, @@ -921,6 +1136,7 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): "units": self.units, "normalize": self.normalize, "probability_fn": self.probability_fn_name, + "kernel_initializer": initializers.serialize(self.kernel_initializer) } base_config = super(BahdanauAttentionV2, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -1280,11 +1496,14 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): def __init__(self, units, + memory, + memory_sequence_length=None, normalize=False, sigmoid_noise=0., sigmoid_noise_seed=None, score_bias_init=0., mode="parallel", + kernel_initializer="glorot_uniform", dtype=None, name="BahdanauMonotonicAttention", **kwargs): @@ -1292,6 +1511,11 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): Args: units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. normalize: Python boolean. 
Whether to normalize the energy term. sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring for `_monotonic_probability_fn` for more information. @@ -1302,6 +1526,8 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): mode: How to compute the attention distribution. Must be one of 'recursive', 'parallel', or 'hard'. See the docstring for `tf.contrib.seq2seq.monotonic_attention` for more information. + kernel_initializer: (optional), the name of the initializer for the + attention kernel. dtype: The data type for the query and memory layers of the attention mechanism. name: Name to use when creating ops. @@ -1314,34 +1540,47 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): wrapped_probability_fn = functools.partial( _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, seed=sigmoid_noise_seed) - super(BahdanauMonotonicAttentionV2, self).__init__( - query_layer=layers_core.Dense( - units, name="query_layer", use_bias=False, dtype=dtype), - memory_layer=layers_core.Dense( - units, name="memory_layer", use_bias=False, dtype=dtype), - probability_fn=wrapped_probability_fn, - name=name, - dtype=dtype, - **kwargs) + query_layer = kwargs.pop("query_layer", None) + if not query_layer: + query_layer = layers.Dense( + units, name="query_layer", use_bias=False, dtype=dtype) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) self.units = units self.normalize = normalize self.sigmoid_noise = sigmoid_noise self.sigmoid_noise_seed = sigmoid_noise_seed self.score_bias_init = score_bias_init self.mode = mode + self.kernel_initializer = initializers.get(kernel_initializer) + self.attention_v = None + self.attention_score_bias = None + self.attention_g = None + self.attention_b = None + super(BahdanauMonotonicAttentionV2, self).__init__( + memory=memory, + 
memory_sequence_length=memory_sequence_length, + query_layer=query_layer, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) def build(self, input_shape): super(BahdanauMonotonicAttentionV2, self).build(input_shape) - self.attention_v = self.add_weight( - "attention_v", [self.units], dtype=self.dtype) - self.attention_score_bias = self.add_weight( - "attention_score_bias", shape=(), dtype=self.dtype, - initializer=init_ops.constant_initializer( - self.score_bias_init, dtype=self.dtype)) - if not self.normalize: - self.attention_g = None - self.attention_b = None - else: + if self.attention_v is None: + self.attention_v = self.add_weight( + "attention_v", [self.units], dtype=self.dtype, + initializer=self.kernel_initializer) + if self.attention_score_bias is None: + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), dtype=self.dtype, + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) + if self.normalize and self.attention_g is None and self.attention_b is None: self.attention_g = self.add_weight( "attention_g", dtype=self.dtype, initializer=init_ops.constant_initializer( @@ -1352,7 +1591,7 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): initializer=init_ops.zeros_initializer()) self.built = True - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): """Score the query based on the keys and values. 
Args: @@ -1384,6 +1623,7 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): "sigmoid_noise_seed": self.sigmoid_noise_seed, "score_bias_init": self.score_bias_init, "mode": self.mode, + "kernel_initializer": initializers.serialize(self.kernel_initializer), } base_config = super(BahdanauMonotonicAttentionV2, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -1517,6 +1757,8 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): def __init__(self, units, + memory, + memory_sequence_length=None, scale=False, sigmoid_noise=0., sigmoid_noise_seed=None, @@ -1529,6 +1771,11 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): Args: units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. scale: Python boolean. Whether to scale the energy term. sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring for `_monotonic_probability_fn` for more information. 
@@ -1551,35 +1798,41 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): wrapped_probability_fn = functools.partial( _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, seed=sigmoid_noise_seed) - super(LuongMonotonicAttentionV2, self).__init__( - query_layer=None, - memory_layer=layers_core.Dense( - units, name="memory_layer", use_bias=False, dtype=dtype), - probability_fn=wrapped_probability_fn, - name=name, - dtype=dtype, - **kwargs) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) self.units = units self.scale = scale self.sigmoid_noise = sigmoid_noise self.sigmoid_noise_seed = sigmoid_noise_seed self.score_bias_init = score_bias_init self.mode = mode + self.attention_g = None + self.attention_score_bias = None + super(LuongMonotonicAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, + query_layer=None, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) def build(self, input_shape): super(LuongMonotonicAttentionV2, self).build(input_shape) - if self.scale: + if self.scale and self.attention_g is None: self.attention_g = self.add_weight( "attention_g", initializer=init_ops.ones_initializer, shape=()) - else: - self.attention_g = None - self.attention_score_bias = self.add_weight( - "attention_score_bias", shape=(), - initializer=init_ops.constant_initializer( - self.score_bias_init, dtype=self.dtype)) + if self.attention_score_bias is None: + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) self.built = True - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): """Score the query based on the keys and values. 
Args: @@ -1593,6 +1846,7 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): alignments: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]` (`alignments_size` is memory's `max_time`). + next_state: Same as alignments """ score = _luong_score(query, self.keys, self.attention_g) score += self.attention_score_bias @@ -1603,7 +1857,7 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): def get_config(self): config = { "units": self.units, - "normalize": self.normalize, + "scale": self.scale, "sigmoid_noise": self.sigmoid_noise, "sigmoid_noise_seed": self.sigmoid_noise_seed, "score_bias_init": self.score_bias_init, @@ -1666,7 +1920,15 @@ class AttentionWrapperState( def with_same_shape(old, new): """Check and set new tensor's shape.""" if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor): - return tensor_util.with_same_shape(old, new) + if not context.executing_eagerly(): + return tensor_util.with_same_shape(old, new) + else: + if old.shape.as_list() != new.shape.as_list(): + raise ValueError("The shape of the AttentionWrapperState is " + "expected to be same as the one to clone. 
" + "self.shape: %s, input.shape: %s" % + (old.shape, new.shape)) + return new return new return nest.map_structure( @@ -1710,41 +1972,26 @@ def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, "but saw shape: %s" % (m.name, m.get_shape())) nest.map_structure(_check_dims, memory) if memory_sequence_length is None and memory_mask is None: - seq_len_mask = None - seq_len_batch_size = None + return memory elif memory_sequence_length is not None: seq_len_mask = array_ops.sequence_mask( memory_sequence_length, maxlen=array_ops.shape(nest.flatten(memory)[0])[1], dtype=nest.flatten(memory)[0].dtype) - seq_len_batch_size = ( - tensor_shape.dimension_value(memory_sequence_length.shape[0]) - or array_ops.shape(memory_sequence_length)[0]) else: # For memory_mask is not None - seq_len_mask = memory_mask - seq_len_batch_size = ( - tensor_shape.dimension_value(memory_mask.shape[0]) - or array_ops.shape(memory_mask)[0]) + seq_len_mask = math_ops.cast( + memory_mask, dtype=nest.flatten(memory)[0].dtype) def _maybe_mask(m, seq_len_mask): """Mask the memory based on the memory mask.""" rank = m.get_shape().ndims rank = rank if rank is not None else array_ops.rank(m) extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) - m_batch_size = tensor_shape.dimension_value( - m.shape[0]) or array_ops.shape(m)[0] - if seq_len_batch_size is not None: - message = ("memory_sequence_length and memory tensor batch sizes do not " - "match.") - with ops.control_dependencies([ - check_ops.assert_equal( - seq_len_batch_size, m_batch_size, message=message)]): - seq_len_mask = array_ops.reshape( - seq_len_mask, - array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) - return m * seq_len_mask - else: - return m + seq_len_mask = array_ops.reshape( + seq_len_mask, + array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) + return m * seq_len_mask + return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) @@ -1790,8 +2037,14 @@ def 
hardmax(logits, name=None): def _compute_attention(attention_mechanism, cell_output, attention_state, attention_layer): """Computes the attention and alignments for a given attention_mechanism.""" - alignments, next_attention_state = attention_mechanism( - cell_output, state=attention_state) + if isinstance(attention_mechanism, _BaseAttentionMechanismV2): + alignments, next_attention_state = attention_mechanism( + [cell_output, attention_state]) + else: + # For other class, assume they are following _BaseAttentionMechanism, which + # takes query and state as separate parameter. + alignments, next_attention_state = attention_mechanism( + cell_output, state=attention_state) # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] expanded_alignments = array_ops.expand_dims(alignments, 1) @@ -1804,13 +2057,13 @@ def _compute_attention(attention_mechanism, cell_output, attention_state, # the batched matmul is over memory_time, so the output shape is # [batch_size, 1, memory_size]. # we then squeeze out the singleton dim. 
- context = math_ops.matmul(expanded_alignments, attention_mechanism.values) - context = array_ops.squeeze(context, [1]) + context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values) + context_ = array_ops.squeeze(context_, [1]) if attention_layer is not None: - attention = attention_layer(array_ops.concat([cell_output, context], 1)) + attention = attention_layer(array_ops.concat([cell_output, context_], 1)) else: - attention = context + attention = context_ return attention, alignments, next_attention_state @@ -2094,7 +2347,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): if self._initial_cell_state is not None: cell_state = self._initial_cell_state else: - cell_state = self._cell.zero_state(batch_size, dtype) + cell_state = self._cell.get_initial_state(batch_size=batch_size, + dtype=dtype) error_message = ( "When calling zero_state of AttentionWrapper %s: " % self._base_name + "Non-matching batch sizes between the memory " diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index 7eb95e5a70de985dca0d4b565ba03bdf454b6161..16dfa7ed8268d761dee49ec0146efabcaaef1393 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -23,8 +23,10 @@ import collections from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.contrib.seq2seq.python.ops import helper as helper_py +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import layers from tensorflow.python.layers import base as layers_base from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.util import nest @@ -146,3 +148,102 @@ class BasicDecoder(decoder.Decoder): sample_ids=sample_ids) outputs = BasicDecoderOutput(cell_outputs, sample_ids) return (outputs, next_state, 
next_inputs, finished) + + +class BasicDecoderV2(decoder.BaseDecoder): + """Basic sampling decoder.""" + + def __init__(self, cell, sampler, output_layer=None, **kwargs): + """Initialize BasicDecoder. + + Args: + cell: An `RNNCell` instance. + sampler: A `Sampler` instance. + output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., + `tf.layers.Dense`. Optional layer to apply to the RNN output prior to + storing the result or sampling. + **kwargs: Other keyword arguments for layer creation. + + Raises: + TypeError: if `cell`, `sampler` or `output_layer` have an incorrect type. + """ + rnn_cell_impl.assert_like_rnncell("cell", cell) + if not isinstance(sampler, sampler_py.Sampler): + raise TypeError("sampler must be a Sampler, received: %s" % (sampler,)) + if (output_layer is not None and + not isinstance(output_layer, layers.Layer)): + raise TypeError( + "output_layer must be a Layer, received: %s" % (output_layer,)) + self.cell = cell + self.sampler = sampler + self.output_layer = output_layer + super(BasicDecoderV2, self).__init__(**kwargs) + + def initialize(self, inputs, initial_state=None, **kwargs): + """Initialize the decoder.""" + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + self._cell_dtype = nest.flatten(initial_state)[0].dtype + return self.sampler.initialize(inputs, **kwargs) + (initial_state,) + + @property + def batch_size(self): + return self.sampler.batch_size + + def _rnn_output_size(self): + size = tensor_shape.TensorShape(self.cell.output_size) + if self.output_layer is None: + return size + else: + # To use layer's compute_output_shape, we need to convert the + # RNNCell's output_size entries into shapes with an unknown + # batch size. We then pass this through the layer's + # compute_output_shape and read off all but the first (batch) + # dimensions to get the output size of the rnn with the layer + # applied to the top. 
+ output_shape_with_unknown_batch = nest.map_structure( + lambda s: tensor_shape.TensorShape([None]).concatenate(s), size) + layer_output_shape = self.output_layer.compute_output_shape( + output_shape_with_unknown_batch) + return nest.map_structure(lambda s: s[1:], layer_output_shape) + + @property + def output_size(self): + # Return the cell output and the id + return BasicDecoderOutput( + rnn_output=self._rnn_output_size(), + sample_id=self.sampler.sample_ids_shape) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and the sample_ids_dtype from the helper. + dtype = self._cell_dtype + return BasicDecoderOutput( + nest.map_structure(lambda _: dtype, self._rnn_output_size()), + self.sampler.sample_ids_dtype) + + def step(self, time, inputs, state): + """Perform a decoding step. + + Args: + time: scalar `int32` tensor. + inputs: A (structure of) input tensors. + state: A (structure of) state tensors and TensorArrays. + + Returns: + `(outputs, next_state, next_inputs, finished)`. 
+ """ + cell_outputs, cell_state = self.cell(inputs, state) + if self.output_layer is not None: + cell_outputs = self.output_layer(cell_outputs) + sample_ids = self.sampler.sample( + time=time, outputs=cell_outputs, state=cell_state) + (finished, next_inputs, next_state) = self.sampler.next_inputs( + time=time, + outputs=cell_outputs, + state=cell_state, + sample_ids=sample_ids) + outputs = BasicDecoderOutput(cell_outputs, sample_ids) + return (outputs, next_state, next_inputs, finished) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 8f8f057702951094758b277ce060955f3dc6e99d..44b7b2c09203c860cb05889c8556ac088c18f226 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -24,11 +24,12 @@ import numpy as np from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.layers import base as layers_base +from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops @@ -148,8 +149,8 @@ def gather_tree_from_array(t, parent_ids, sequence_length): array_ops.expand_dims(math_ops.range(beam_width), 0), 0) beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1]) - max_sequence_lengths = math_ops.to_int32( - math_ops.reduce_max(sequence_length, axis=1)) + max_sequence_lengths = math_ops.cast( + math_ops.reduce_max(sequence_length, axis=1), dtypes.int32) sorted_beam_ids = 
beam_search_ops.gather_tree( step_ids=beam_ids, parent_ids=parent_ids, @@ -182,11 +183,12 @@ def gather_tree_from_array(t, parent_ids, sequence_length): return ordered -def _check_maybe(t): +def _check_ndims(t): if t.shape.ndims is None: raise ValueError( "Expected tensor (%s) to have known rank, but ndims == None." % t) + def _check_static_batch_beam_maybe(shape, batch_size, beam_width): """Raises an exception if dimensions are known statically and can not be reshaped to [batch_size, beam_size, -1]. @@ -205,6 +207,7 @@ def _check_static_batch_beam_maybe(shape, batch_size, beam_width): return False return True + def _check_batch_beam(t, batch_size, beam_width): """Returns an Assert operation checking that the elements of the stacked TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point, @@ -215,7 +218,7 @@ def _check_batch_beam(t, batch_size, beam_width): "incompatible with the dynamic shape of %s elements. " "Consider setting reorder_tensor_arrays to False to disable " "TensorArray reordering during the beam search." - % (t.name)) + % (t if context.executing_eagerly() else t.name)) rank = t.shape.ndims shape = array_ops.shape(t) if rank == 2: @@ -229,70 +232,30 @@ def _check_batch_beam(t, batch_size, beam_width): return control_flow_ops.Assert(condition, [error_message]) +class BeamSearchDecoderMixin(object): + """BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder. -class BeamSearchDecoder(decoder.Decoder): - """BeamSearch sampling decoder. - - **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in - `AttentionWrapper`, then you must ensure that: - - - The encoder output has been tiled to `beam_width` via - `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). - - The `batch_size` argument passed to the `zero_state` method of this - wrapper is equal to `true_batch_size * beam_width`. 
- - The initial state created with `zero_state` above contains a - `cell_state` value containing properly tiled final state from the - encoder. - - An example: - - ``` - tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( - encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( - encoder_final_state, multiplier=beam_width) - tiled_sequence_length = tf.contrib.seq2seq.tile_batch( - sequence_length, multiplier=beam_width) - attention_mechanism = MyFavoriteAttentionMechanism( - num_units=attention_depth, - memory=tiled_inputs, - memory_sequence_length=tiled_sequence_length) - attention_cell = AttentionWrapper(cell, attention_mechanism, ...) - decoder_initial_state = attention_cell.zero_state( - dtype, batch_size=true_batch_size * beam_width) - decoder_initial_state = decoder_initial_state.clone( - cell_state=tiled_encoder_final_state) - ``` - - Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use - when computing scores(https://arxiv.org/pdf/1609.08144.pdf). It encourages - the translation to cover all inputs. + + It is expected to be used as a base class for concrete BeamSearchDecoder. Since + this is a mixin class, it is expected to be used together with other class as + base. """ def __init__(self, cell, - embedding, - start_tokens, - end_token, - initial_state, beam_width, output_layer=None, length_penalty_weight=0.0, coverage_penalty_weight=0.0, - reorder_tensor_arrays=True): - """Initialize the BeamSearchDecoder. + reorder_tensor_arrays=True, + **kwargs): + """Initialize the BeamSearchDecoderMixin. Args: cell: An `RNNCell` instance. - embedding: A callable that takes a vector tensor of `ids` (argmax ids), - or the `params` argument for `embedding_lookup`. - start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. - end_token: `int32` scalar, the token that marks end of decoding. - initial_state: A (possibly nested tuple of...) tensors and TensorArrays. 
beam_width: Python integer, the number of beams. - output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., - `tf.layers.Dense`. Optional layer to apply to the RNN output prior - to storing the result or sampling. + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. coverage_penalty_weight: Float weight to penalize the coverage of source sentence. Disabled with 0.0. @@ -302,59 +265,35 @@ class BeamSearchDecoder(decoder.Decoder): Otherwise, the `TensorArray` will be returned as is. Set this flag to `False` if the cell state contains `TensorArray`s that are not amenable to reordering. + **kwargs: Dict, other keyword arguments for parent class. Raises: TypeError: if `cell` is not an instance of `RNNCell`, - or `output_layer` is not an instance of `tf.layers.Layer`. - ValueError: If `start_tokens` is not a vector or - `end_token` is not a scalar. + or `output_layer` is not an instance of `tf.keras.layers.Layer`. 
""" rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access if (output_layer is not None and - not isinstance(output_layer, layers_base.Layer)): + not isinstance(output_layer, layers.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell self._output_layer = output_layer self._reorder_tensor_arrays = reorder_tensor_arrays - if callable(embedding): - self._embedding_fn = embedding - else: - self._embedding_fn = ( - lambda ids: embedding_ops.embedding_lookup(embedding, ids)) - - self._start_tokens = ops.convert_to_tensor( - start_tokens, dtype=dtypes.int32, name="start_tokens") - if self._start_tokens.get_shape().ndims != 1: - raise ValueError("start_tokens must be a vector") - self._end_token = ops.convert_to_tensor( - end_token, dtype=dtypes.int32, name="end_token") - if self._end_token.get_shape().ndims != 0: - raise ValueError("end_token must be a scalar") - - self._batch_size = array_ops.size(start_tokens) + self._start_tokens = None + self._end_token = None + self._batch_size = None self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight self._coverage_penalty_weight = coverage_penalty_weight - self._initial_cell_state = nest.map_structure( - self._maybe_split_batch_beams, initial_state, self._cell.state_size) - self._start_tokens = array_ops.tile( - array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) - self._start_inputs = self._embedding_fn(self._start_tokens) - - self._finished = array_ops.one_hot( - array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, - on_value=False, - off_value=True, - dtype=dtypes.bool) + super(BeamSearchDecoderMixin, self).__init__(**kwargs) @property def batch_size(self): return self._batch_size def _rnn_output_size(self): + """Get the output shape from the RNN layer.""" size = self._cell.output_size if self._output_layer is None: return size @@ -393,50 +332,6 @@ class 
BeamSearchDecoder(decoder.Decoder): predicted_ids=tensor_shape.TensorShape([self._beam_width]), parent_ids=tensor_shape.TensorShape([self._beam_width])) - @property - def output_dtype(self): - # Assume the dtype of the cell is the output_size structure - # containing the input_state's first component's dtype. - # Return that structure and int32 (the id) - dtype = nest.flatten(self._initial_cell_state)[0].dtype - return BeamSearchDecoderOutput( - scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), - predicted_ids=dtypes.int32, - parent_ids=dtypes.int32) - - def initialize(self, name=None): - """Initialize the decoder. - - Args: - name: Name scope for any created operations. - - Returns: - `(finished, start_inputs, initial_state)`. - """ - finished, start_inputs = self._finished, self._start_inputs - - dtype = nest.flatten(self._initial_cell_state)[0].dtype - log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) - array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, - on_value=ops.convert_to_tensor(0.0, dtype=dtype), - off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), - dtype=dtype) - init_attention_probs = get_attention_probs( - self._initial_cell_state, self._coverage_penalty_weight) - if init_attention_probs is None: - init_attention_probs = () - - initial_state = BeamSearchDecoderState( - cell_state=self._initial_cell_state, - log_probs=log_probs, - finished=finished, - lengths=array_ops.zeros( - [self._batch_size, self._beam_width], dtype=dtypes.int64), - accumulated_attention_probs=init_attention_probs) - - return (finished, start_inputs, initial_state) - def finalize(self, outputs, final_state, sequence_lengths): """Finalize and return the predicted_ids. @@ -456,8 +351,8 @@ class BeamSearchDecoder(decoder.Decoder): """ del sequence_lengths # Get max_sequence_length across all beams for each batch. 
- max_sequence_lengths = math_ops.to_int32( - math_ops.reduce_max(final_state.lengths, axis=1)) + max_sequence_lengths = math_ops.cast( + math_ops.reduce_max(final_state.lengths, axis=1), dtypes.int32) predicted_ids = beam_search_ops.gather_tree( outputs.predicted_ids, outputs.parent_ids, @@ -562,7 +457,7 @@ class BeamSearchDecoder(decoder.Decoder): """ if isinstance(t, tensor_array_ops.TensorArray): return t - _check_maybe(t) + _check_ndims(t) if t.shape.ndims >= 1: return self._split_batch_beams(t, s) else: @@ -586,7 +481,7 @@ class BeamSearchDecoder(decoder.Decoder): """ if isinstance(t, tensor_array_ops.TensorArray): return t - _check_maybe(t) + _check_ndims(t) if t.shape.ndims >= 2: return self._merge_batch_beams(t, s) else: @@ -609,11 +504,18 @@ class BeamSearchDecoder(decoder.Decoder): if not isinstance(t, tensor_array_ops.TensorArray): return t # pylint: disable=protected-access - if (not t._infer_shape or not t._element_shape - or t._element_shape[0].ndims is None - or t._element_shape[0].ndims < 1): + # This is a bad hack due to the implementation detail of eager/graph TA. + # TODO(b/124374427): Update this to use public property of TensorArray. + if context.executing_eagerly(): + element_shape = t._element_shape + else: + element_shape = t._element_shape[0] + if (not t._infer_shape + or not t._element_shape + or element_shape.ndims is None + or element_shape.ndims < 1): shape = ( - t._element_shape[0] if t._infer_shape and t._element_shape + element_shape if t._infer_shape and t._element_shape else tensor_shape.TensorShape(None)) tf_logging.warn("The TensorArray %s in the cell state is not amenable to " "sorting based on the beam search result. 
For a " @@ -621,10 +523,10 @@ class BeamSearchDecoder(decoder.Decoder): "defined and have at least a rank of 1, but saw shape: %s" % (t.handle.name, shape)) return t - shape = t._element_shape[0] # pylint: enable=protected-access if not _check_static_batch_beam_maybe( - shape, tensor_util.constant_value(self._batch_size), self._beam_width): + element_shape, tensor_util.constant_value(self._batch_size), + self._beam_width): return t t = t.stack() with ops.control_dependencies( @@ -684,6 +586,359 @@ class BeamSearchDecoder(decoder.Decoder): return (beam_search_output, beam_search_state, next_inputs, finished) +class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder): + # Note that the inheritance hierarchy is important here. The Mixin has to be + # the first parent class since we will use super().__init__(), and Mixin which + # is a object will properly invoke the __init__ method of other parent class. + """BeamSearch sampling decoder. + + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. 
+ + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + + Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use + when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages + the decoder to cover all inputs. + """ + + def __init__(self, + cell, + embedding, + start_tokens, + end_token, + initial_state, + beam_width, + output_layer=None, + length_penalty_weight=0.0, + coverage_penalty_weight=0.0, + reorder_tensor_arrays=True): + """Initialize the BeamSearchDecoder. + + Args: + cell: An `RNNCell` instance. + embedding: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + beam_width: Python integer, the number of beams. + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. + length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. 
Disabled with 0.0. + reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell + state will be reordered according to the beam search path. If the + `TensorArray` can be reordered, the stacked form will be returned. + Otherwise, the `TensorArray` will be returned as is. Set this flag to + `False` if the cell state contains `TensorArray`s that are not amenable + to reordering. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell`, + or `output_layer` is not an instance of `tf.keras.layers.Layer`. + ValueError: If `start_tokens` is not a vector or + `end_token` is not a scalar. + """ + super(BeamSearchDecoder, self).__init__( + cell, + beam_width, + output_layer=output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + reorder_tensor_arrays=reorder_tensor_arrays) + + if callable(embedding): + self._embedding_fn = embedding + else: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + def initialize(self, name=None): + """Initialize the decoder. 
+ + Args: + name: Name scope for any created operations. + + Returns: + `(finished, start_inputs, initial_state)`. + """ + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + +class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder): + # Note that the inheritance hierarchy is important here. The Mixin has to be + # the first parent class since we will use super().__init__(), and Mixin which + # is a object will properly invoke the __init__ method of other parent class. + """BeamSearch sampling decoder. 
+ + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. + + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + + Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use + when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages + the decoding to cover all inputs. + """ + + def __init__(self, + cell, + beam_width, + embedding_fn=None, + output_layer=None, + length_penalty_weight=0.0, + coverage_penalty_weight=0.0, + reorder_tensor_arrays=True, + **kwargs): + """Initialize the BeamSearchDecoderV2. + + Args: + cell: An `RNNCell` instance. + beam_width: Python integer, the number of beams. + embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids). + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. 
Optional layer to apply to the RNN output + prior to storing the result or sampling. + length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. + reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell + state will be reordered according to the beam search path. If the + `TensorArray` can be reordered, the stacked form will be returned. + Otherwise, the `TensorArray` will be returned as is. Set this flag to + `False` if the cell state contains `TensorArray`s that are not amenable + to reordering. + **kwargs: Dict, other keyword arguments for initialization. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell`, + or `output_layer` is not an instance of `tf.keras.layers.Layer`. + """ + super(BeamSearchDecoderV2, self).__init__( + cell, + beam_width, + output_layer=output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + reorder_tensor_arrays=reorder_tensor_arrays, + **kwargs) + + if embedding_fn is None or callable(embedding_fn): + self._embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be a callable, got %s" % + type(embedding_fn)) + + def initialize(self, + embedding, + start_tokens, + end_token, + initial_state): + """Initialize the decoder. + + Args: + embedding: A tensor from the embedding layer output, which is the + `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + Returns: + `(finished, start_inputs, initial_state)`. + Raises: + ValueError: If `start_tokens` is not a vector or `end_token` is not a + scalar. 
+ """ + if embedding is not None and self._embedding_fn is not None: + raise ValueError( + "embedding and embedding_fn cannot be provided at same time") + elif embedding is not None: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + 
accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + def call(self, embeddning, start_tokens, end_token, initial_state, **kwargs): + init_kwargs = kwargs + init_kwargs["start_tokens"] = start_tokens + init_kwargs["end_token"] = end_token + init_kwargs["initial_state"] = initial_state + return decoder.dynamic_decode(self, + output_time_major=self.output_time_major, + impute_finished=self.impute_finished, + maximum_iterations=self.maximum_iterations, + parallel_iterations=self.parallel_iterations, + swap_memory=self.swap_memory, + decoder_init_input=embeddning, + decoder_init_kwargs=init_kwargs) + + def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, beam_width, end_token, length_penalty_weight, coverage_penalty_weight): @@ -730,7 +985,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, on_value=np.int64(0), off_value=np.int64(1), dtype=dtypes.int64) - add_mask = math_ops.to_int64(not_finished) + add_mask = math_ops.cast(not_finished, dtypes.int64) lengths_to_add *= array_ops.expand_dims(add_mask, 2) new_prediction_lengths = ( lengths_to_add + array_ops.expand_dims(prediction_lengths, 2)) @@ -741,7 +996,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, attention_probs = get_attention_probs( next_cell_state, coverage_penalty_weight) if attention_probs is not None: - attention_probs *= array_ops.expand_dims(math_ops.to_float(not_finished), 2) + attention_probs *= array_ops.expand_dims( + math_ops.cast(not_finished, 
dtypes.float32), 2) accumulated_attention_probs = ( beam_state.accumulated_attention_probs + attention_probs) @@ -775,15 +1031,17 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, gather_shape=[-1], name="next_beam_probs") # Note: just doing the following - # math_ops.to_int32(word_indices % vocab_size, + # math_ops.cast( + # word_indices % vocab_size, + # dtypes.int32, # name="next_beam_word_ids") # would be a lot cleaner but for reasons unclear, that hides the results of # the op which prevents capturing it with tfdbg debug ops. raw_next_word_ids = math_ops.mod( word_indices, vocab_size, name="next_beam_word_ids") - next_word_ids = math_ops.to_int32(raw_next_word_ids) - next_beam_ids = math_ops.to_int32( - word_indices / vocab_size, name="next_beam_parent_ids") + next_word_ids = math_ops.cast(raw_next_word_ids, dtypes.int32) + next_beam_ids = math_ops.cast( + word_indices / vocab_size, dtypes.int32, name="next_beam_parent_ids") # Append new ids to current predictions previously_finished = _tensor_gather_helper( @@ -802,7 +1060,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, # 2. Beams that are now finished (EOS predicted) have their length # increased by 1. # 3. Beams that are not yet finished have their length increased by 1. - lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished)) + lengths_to_add = math_ops.cast( + math_ops.logical_not(previously_finished), dtypes.int64) next_prediction_len = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=beam_state.lengths, @@ -949,7 +1208,7 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight, coverage_penalty = math_ops.reduce_sum( math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2) # Apply coverage penalty to finished predictions. 
- coverage_penalty *= math_ops.to_float(finished) + coverage_penalty *= math_ops.cast(finished, dtypes.float32) weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1] weighted_coverage_penalty = array_ops.expand_dims( @@ -1002,8 +1261,9 @@ def _length_penalty(sequence_lengths, penalty_factor): static_penalty = tensor_util.constant_value(penalty_factor) if static_penalty is not None and static_penalty == 0: return 1.0 - return math_ops.div((5. + math_ops.to_float(sequence_lengths)) - **penalty_factor, (5. + 1.)**penalty_factor) + return math_ops.div( + (5. + math_ops.cast(sequence_lengths, dtypes.float32))**penalty_factor, + (5. + 1.)**penalty_factor) def _mask_probs(probs, eos_token, finished): @@ -1068,7 +1328,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, """ if isinstance(gather_from, tensor_array_ops.TensorArray): return gather_from - _check_maybe(gather_from) + _check_ndims(gather_from) if gather_from.shape.ndims >= len(gather_shape): return _tensor_gather_helper( gather_indices=gather_indices, diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index f58268eff525a4b592c79acb32207e1a3f62bdc7..33f7bac8159401175ce57c0463fff1398c1dd9bb 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util @@ -135,6 +136,127 @@ class Decoder(object): return False +class BaseDecoder(layers.Layer): + """An RNN Decoder that is based on a Keras 
layer. + + Concepts used by this interface: + - `inputs`: (structure of) tensors and TensorArrays that is passed as input to + the RNNCell composing the decoder, at each time step. + - `state`: (structure of) tensors and TensorArrays that is passed to the + RNNCell instance as the state. + - `memory`: (structure of) tensors that is usually the full output of the + encoder, which will be used for the attention wrapper for the RNNCell. + - `finished`: boolean tensor telling whether each sequence in the batch is + finished. + - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each + time step. + """ + + def __init__(self, + output_time_major=False, + impute_finished=False, + maximum_iterations=None, + parallel_iterations=32, + swap_memory=False, + **kwargs): + self.output_time_major = output_time_major + self.impute_finished = impute_finished + self.maximum_iterations = maximum_iterations + self.parallel_iterations = parallel_iterations + self.swap_memory = swap_memory + super(BaseDecoder, self).__init__(**kwargs) + + def call(self, inputs, initial_state=None, **kwargs): + init_kwargs = kwargs + init_kwargs["initial_state"] = initial_state + return dynamic_decode(self, + output_time_major=self.output_time_major, + impute_finished=self.impute_finished, + maximum_iterations=self.maximum_iterations, + parallel_iterations=self.parallel_iterations, + swap_memory=self.swap_memory, + decoder_init_input=inputs, + decoder_init_kwargs=init_kwargs) + + @property + def batch_size(self): + """The batch size of input values.""" + raise NotImplementedError + + @property + def output_size(self): + """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s].""" + raise NotImplementedError + + @property + def output_dtype(self): + """A (possibly nested tuple of...) dtype[s].""" + raise NotImplementedError + + def initialize(self, inputs, initial_state=None, **kwargs): + """Called before any decoding iterations. 
+ + This method must compute initial input values and initial state. + + Args: + inputs: (structure of) tensors that contains the input for the decoder. In + the normal case, it's a tensor with shape [batch, timestep, embedding]. + initial_state: (structure of) tensors that contains the initial state for + the RNNCell. + **kwargs: Other arguments that are passed in from layer.call() method. It + could contain items like input sequence_length, or masking for input. + + Returns: + `(finished, initial_inputs, initial_state)`: initial values of + 'finished' flags, inputs and state. + """ + raise NotImplementedError + + def step(self, time, inputs, state): + """Called per step of decoding (but only once for dynamic decoding). + + Args: + time: Scalar `int32` tensor. Current step number. + inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time + step. + state: RNNCell state (possibly nested tuple of) tensor[s] from previous + time step. + + Returns: + `(outputs, next_state, next_inputs, finished)`: `outputs` is an object + containing the decoder output, `next_state` is a (structure of) state + tensors and TensorArrays, `next_inputs` is the tensor that should be used + as input for the next step, `finished` is a boolean tensor telling whether + the sequence is complete, for each sequence in the batch. + """ + raise NotImplementedError + + def finalize(self, outputs, final_state, sequence_lengths): + raise NotImplementedError + + @property + def tracks_own_finished(self): + """Describes whether the Decoder keeps track of finished states. + + Most decoders will emit a true/false `finished` value independently + at each time step. In this case, the `dynamic_decode` function keeps track + of which batch entries are already finished, and performs a logical OR to + insert new batches to the finished set. 
+ + Some decoders, however, shuffle batches / beams between time steps and + `dynamic_decode` will mix up the finished state across these entries because + it does not track the reshuffle across time steps. In this case, it is + up to the decoder to declare that it will keep track of its own finished + state by setting this property to `True`. + + Returns: + Python bool. + """ + return False + + # TODO(scottzhu): Add build/get_config/from_config and other layer methods. + + def _create_zero_outputs(size, dtype, batch_size): """Create a zero outputs Tensor structure.""" def _create(s, d): @@ -149,7 +271,8 @@ def dynamic_decode(decoder, maximum_iterations=None, parallel_iterations=32, swap_memory=False, - scope=None): + scope=None, + **kwargs): """Perform dynamic decoding with `decoder`. Calls initialize() once and step() repeatedly on the Decoder object. @@ -171,6 +294,9 @@ def dynamic_decode(decoder, parallel_iterations: Argument passed to `tf.while_loop`. swap_memory: Argument passed to `tf.while_loop`. scope: Optional variable scope to use. + **kwargs: dict, other keyword arguments for dynamic_decode. It might contain + arguments for `BaseDecoder` to initialize, which takes all tensor inputs + during call(). Returns: `(final_outputs, final_state, final_sequence_lengths)`. @@ -179,7 +305,7 @@ def dynamic_decode(decoder, TypeError: if `decoder` is not an instance of `Decoder`. ValueError: if `maximum_iterations` is provided but is not a scalar. 
""" - if not isinstance(decoder, Decoder): + if not isinstance(decoder, (Decoder, BaseDecoder)): raise TypeError("Expected decoder to be type Decoder, but saw: %s" % type(decoder)) @@ -204,7 +330,14 @@ def dynamic_decode(decoder, if maximum_iterations.get_shape().ndims != 0: raise ValueError("maximum_iterations must be a scalar") - initial_finished, initial_inputs, initial_state = decoder.initialize() + if isinstance(decoder, Decoder): + initial_finished, initial_inputs, initial_state = decoder.initialize() + else: + # For BaseDecoder that takes tensor inputs during call. + decoder_init_input = kwargs.pop("decoder_init_input", None) + decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {}) + initial_finished, initial_inputs, initial_state = decoder.initialize( + decoder_init_input, **decoder_init_kwargs) zero_outputs = _create_zero_outputs(decoder.output_size, decoder.output_dtype, @@ -222,7 +355,7 @@ def dynamic_decode(decoder, def _shape(batch_size, from_shape): if (not isinstance(from_shape, tensor_shape.TensorShape) or from_shape.ndims == 0): - return tensor_shape.TensorShape(None) + return None else: batch_size = tensor_util.constant_value( ops.convert_to_tensor( diff --git a/tensorflow/contrib/seq2seq/python/ops/sampler.py b/tensorflow/contrib/seq2seq/python/ops/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3e48b3bc61c0ff94ae0a1794767c7ff6914969 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/ops/sampler.py @@ -0,0 +1,765 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A library of sampler for use with SamplingDecoders.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.util import nest + +__all__ = [ + "Sampler", + "TrainingSampler", + "GreedyEmbeddingSampler", + "SampleEmbeddingSampler", + "CustomSampler", + "ScheduledEmbeddingTrainingSampler", + "ScheduledOutputTrainingSampler", + "InferenceSampler", +] + +_transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access + + +@six.add_metaclass(abc.ABCMeta) +class Sampler(object): + """Interface for implementing sampling in seq2seq decoders. + + Sampler instances are used by `BasicDecoder`. 
The normal usage of a sampler is + like below: + sampler = Sampler(init_args) + (initial_finished, initial_inputs) = sampler.initialize(input_tensors) + for time_step in range(time): + cell_output, cell_state = cell.call(cell_input, previous_state) + sample_ids = sampler.sample(time_step, cell_output, cell_state) + (finished, next_inputs, next_state) = sampler.next_inputs( + time_step, cell_output, cell_state, sample_ids) + + Note that all the tensor inputs should not be fed to Sampler as __init__() + parameters, instead, they should be fed by decoders via initialize(). + """ + + @abc.abstractmethod + def initialize(self, inputs, **kwargs): + """initialize the sampler with the input tensors. + + This method is supposed to be invoked only once, before any other method + of the Sampler. + + Args: + inputs: A (structure of) input tensors, it could be a nested tuple or a + single tensor. + **kwargs: Other kwargs for initialization. It could contain tensors like + mask for inputs, or non tensor parameter. + + Returns: + `(initial_finished, initial_inputs)`. + """ + pass + + @abc.abstractmethod + def sample(self, time, outputs, state): + """Returns `sample_ids`.""" + pass + + @abc.abstractmethod + def next_inputs(self, time, outputs, state, sample_ids): + """Returns `(finished, next_inputs, next_state)`.""" + pass + + @abc.abstractproperty + def batch_size(self): + """Batch size of tensor returned by `sample`. + + Returns a scalar int32 tensor. The return value might not available before + the invocation of initialize(), in this case, ValueError is raised. + """ + raise NotImplementedError("batch_size has not been implemented") + + @abc.abstractproperty + def sample_ids_shape(self): + """Shape of tensor returned by `sample`, excluding the batch dimension. + + Returns a `TensorShape`. The return value might not available before the + invocation of initialize(). 
+ """ + raise NotImplementedError("sample_ids_shape has not been implemented") + + @abc.abstractproperty + def sample_ids_dtype(self): + """DType of tensor returned by `sample`. + + Returns a DType. The return value might not available before the + invocation of initialize(). + """ + raise NotImplementedError("sample_ids_dtype has not been implemented") + + +class CustomSampler(Sampler): + """Base abstract class that allows the user to customize sampling.""" + + def __init__(self, + initialize_fn, + sample_fn, + next_inputs_fn, + sample_ids_shape=None, + sample_ids_dtype=None): + """Initializer. + + Args: + initialize_fn: callable that returns `(finished, next_inputs)` for the + first iteration. + sample_fn: callable that takes `(time, outputs, state)` and emits tensor + `sample_ids`. + next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)` + and emits `(finished, next_inputs, next_state)`. + sample_ids_shape: Either a list of integers, or a 1-D Tensor of type + `int32`, the shape of each value in the `sample_ids` batch. Defaults to + a scalar. + sample_ids_dtype: The dtype of the `sample_ids` tensor. Defaults to int32. 
+ """ + self._initialize_fn = initialize_fn + self._sample_fn = sample_fn + self._next_inputs_fn = next_inputs_fn + self._batch_size = None + self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or []) + self._sample_ids_dtype = sample_ids_dtype or dtypes.int32 + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return self._sample_ids_shape + + @property + def sample_ids_dtype(self): + return self._sample_ids_dtype + + def initialize(self, inputs, **kwargs): + (finished, next_inputs) = self._initialize_fn(inputs, **kwargs) + if self._batch_size is None: + self._batch_size = array_ops.size(finished) + return (finished, next_inputs) + + def sample(self, time, outputs, state): + return self._sample_fn(time=time, outputs=outputs, state=state) + + def next_inputs(self, time, outputs, state, sample_ids): + return self._next_inputs_fn( + time=time, outputs=outputs, state=state, sample_ids=sample_ids) + + +class TrainingSampler(Sampler): + """A Sampler for use during training. + + Only reads inputs. + + Returned sample_ids are the argmax of the RNN output logits. + """ + + def __init__(self, time_major=False): + """Initializer. + + Args: + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + + Raises: + ValueError: if `sequence_length` is not a 1D tensor. + """ + self.time_major = time_major + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return tensor_shape.TensorShape([]) + + @property + def sample_ids_dtype(self): + return dtypes.int32 + + def initialize(self, inputs, sequence_length=None): + """Initialize the TrainSampler. 
+ + Args: + inputs: A (structure of) input tensors. + sequence_length: An int32 vector tensor. + + Returns: + (finished, next_inputs), a tuple of two items. The first item is a boolean + vector to indicate whether the item in the batch has finished. The + second item is the first slice of input data based on the timestep + dimension (usually the second dim of the input). + """ + self.inputs = ops.convert_to_tensor(inputs, name="inputs") + if not self.time_major: + inputs = nest.map_structure(_transpose_batch_time, inputs) + + self.input_tas = nest.map_structure(_unstack_ta, inputs) + if sequence_length is None: + raise ValueError("sequence_length is required for TrainingSampler") + self.sequence_length = ops.convert_to_tensor( + sequence_length, name="sequence_length") + if self.sequence_length.get_shape().ndims != 1: + raise ValueError( + "Expected sequence_length to be a vector, but received shape: %s" % + self.sequence_length.get_shape()) + + self.zero_inputs = nest.map_structure( + lambda inp: array_ops.zeros_like(inp[0, :]), inputs) + + self._batch_size = array_ops.size(self.sequence_length) + + finished = math_ops.equal(0, self.sequence_length) + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, + lambda: self.zero_inputs, + lambda: nest.map_structure(lambda inp: inp.read(0), self.input_tas)) + return (finished, next_inputs) + + def sample(self, time, outputs, state): + del state + sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32) + return sample_ids + + def next_inputs(self, time, outputs, state, sample_ids): + del sample_ids + next_time = time + 1 + finished = (next_time >= self.sequence_length) + all_finished = math_ops.reduce_all(finished) + + def read_from_ta(inp): + return inp.read(next_time) + + next_inputs = control_flow_ops.cond( + all_finished, + lambda: self.zero_inputs, + lambda: nest.map_structure(read_from_ta, self.input_tas)) + return (finished, next_inputs, state) + + 
+class ScheduledEmbeddingTrainingSampler(TrainingSampler): + """A training sampler that adds scheduled sampling. + + Returns -1s for sample_ids where no sampling took place; valid sample id + values elsewhere. + """ + + def __init__(self, + sampling_probability, + embedding_fn=None, + time_major=False, + seed=None, + scheduling_seed=None): + """Initializer. + + Args: + sampling_probability: A `float32` 0-D or 1-D tensor: the probability of + sampling categorically from the output ids instead of reading directly + from the inputs. + embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + seed: The sampling seed. + scheduling_seed: The schedule decision rule sampling seed. + + Raises: + ValueError: if `sampling_probability` is not a scalar or vector. + """ + if callable(embedding_fn) or embedding_fn is None: + self.embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be callable, got %s" + % type(embedding_fn)) + self.sampling_probability = ops.convert_to_tensor( + sampling_probability, name="sampling_probability") + if self.sampling_probability.get_shape().ndims not in (0, 1): + raise ValueError( + "sampling_probability must be either a scalar or a vector. 
" + "saw shape: %s" % (self.sampling_probability.get_shape())) + self.seed = seed + self.scheduling_seed = scheduling_seed + super(ScheduledEmbeddingTrainingSampler, + self).__init__(time_major=time_major) + + def initialize(self, inputs, sequence_length=None, embedding=None): + if self.embedding_fn is None: + if embedding is None: + raise ValueError("embedding is required as a keyword argument for " + "ScheduledEmbeddingTrainingSampler") + self.embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + return super(ScheduledEmbeddingTrainingSampler, self).initialize( + inputs, sequence_length=sequence_length) + + def sample(self, time, outputs, state): + del state + # Return -1s where we did not sample, and sample_ids elsewhere + select_sample = bernoulli_sample( + probs=self.sampling_probability, + dtype=dtypes.bool, + sample_shape=self.batch_size, + seed=self.scheduling_seed) + return array_ops.where(select_sample, + categorical_sample(logits=outputs, seed=self.seed), + gen_array_ops.fill([self.batch_size], -1)) + + def next_inputs(self, time, outputs, state, sample_ids): + (finished, base_next_inputs, state) = ( + super(ScheduledEmbeddingTrainingSampler, self).next_inputs( + time=time, outputs=outputs, state=state, sample_ids=sample_ids)) + + def maybe_sample(): + """Perform scheduled sampling.""" + where_sampling = math_ops.cast( + array_ops.where(sample_ids > -1), dtypes.int32) + where_not_sampling = math_ops.cast( + array_ops.where(sample_ids <= -1), dtypes.int32) + sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling) + inputs_not_sampling = array_ops.gather_nd(base_next_inputs, + where_not_sampling) + sampled_next_inputs = self.embedding_fn(sample_ids_sampling) + base_shape = array_ops.shape(base_next_inputs) + return (array_ops.scatter_nd( + indices=where_sampling, updates=sampled_next_inputs, shape=base_shape) + + array_ops.scatter_nd( + indices=where_not_sampling, + updates=inputs_not_sampling, + shape=base_shape)) 
+ + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond(all_finished, lambda: base_next_inputs, + maybe_sample) + return (finished, next_inputs, state) + + +class ScheduledOutputTrainingSampler(TrainingSampler): + """A training sampler that adds scheduled sampling directly to outputs. + + Returns False for sample_ids where no sampling took place; True elsewhere. + """ + + def __init__(self, + sampling_probability, + time_major=False, + seed=None, + next_inputs_fn=None): + """Initializer. + + Args: + sampling_probability: A `float32` scalar tensor: the probability of + sampling from the outputs instead of reading directly from the inputs. + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + seed: The sampling seed. + next_inputs_fn: (Optional) callable to apply to the RNN outputs to create + the next input when sampling. If `None` (default), the RNN outputs will + be used as the next inputs. + + Raises: + ValueError: if `sampling_probability` is not a scalar or vector. + """ + self.sampling_probability = ops.convert_to_tensor( + sampling_probability, name="sampling_probability") + if self.sampling_probability.get_shape().ndims not in (0, 1): + raise ValueError( + "sampling_probability must be either a scalar or a vector. 
" + "saw shape: %s" % (self.sampling_probability.get_shape())) + + self.seed = seed + self.next_inputs_fn = next_inputs_fn + + super(ScheduledOutputTrainingSampler, self).__init__(time_major=time_major) + + def initialize(self, inputs, sequence_length=None, auxiliary_inputs=None): + if auxiliary_inputs is None: + maybe_concatenated_inputs = inputs + else: + inputs = ops.convert_to_tensor(inputs) + auxiliary_inputs = ops.convert_to_tensor(auxiliary_inputs) + maybe_concatenated_inputs = nest.map_structure( + lambda x, y: array_ops.concat((x, y), -1), inputs, auxiliary_inputs) + if not self.time_major: + auxiliary_inputs = nest.map_structure(_transpose_batch_time, + auxiliary_inputs) + if auxiliary_inputs is not None: + self._auxiliary_input_tas = nest.map_structure(_unstack_ta, + auxiliary_inputs) + else: + self._auxiliary_input_tas = None + + return super(ScheduledOutputTrainingSampler, self).initialize( + maybe_concatenated_inputs, sequence_length=sequence_length) + + def sample(self, time, outputs, state): + del state + return bernoulli_sample( + probs=self.sampling_probability, + sample_shape=self.batch_size, + seed=self.seed) + + def next_inputs(self, time, outputs, state, sample_ids): + (finished, base_next_inputs, state) = ( + super(ScheduledOutputTrainingSampler, self).next_inputs( + time=time, outputs=outputs, state=state, sample_ids=sample_ids)) + sample_ids = math_ops.cast(sample_ids, dtypes.bool) + + def maybe_sample(): + """Perform scheduled sampling.""" + + def maybe_concatenate_auxiliary_inputs(outputs_, indices=None): + """Concatenate outputs with auxiliary inputs, if they exist.""" + if self._auxiliary_input_tas is None: + return outputs_ + + next_time = time + 1 + auxiliary_inputs = nest.map_structure(lambda ta: ta.read(next_time), + self._auxiliary_input_tas) + if indices is not None: + auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices) + return nest.map_structure(lambda x, y: array_ops.concat((x, y), -1), + outputs_, 
auxiliary_inputs) + + if self.next_inputs_fn is None: + return array_ops.where(sample_ids, + maybe_concatenate_auxiliary_inputs(outputs), + base_next_inputs) + + where_sampling = math_ops.cast(array_ops.where(sample_ids), dtypes.int32) + where_not_sampling = math_ops.cast( + array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32) + outputs_sampling = array_ops.gather_nd(outputs, where_sampling) + inputs_not_sampling = array_ops.gather_nd(base_next_inputs, + where_not_sampling) + sampled_next_inputs = maybe_concatenate_auxiliary_inputs( + self.next_inputs_fn(outputs_sampling), where_sampling) + + base_shape = array_ops.shape(base_next_inputs) + return (array_ops.scatter_nd( + indices=where_sampling, updates=sampled_next_inputs, shape=base_shape) + + array_ops.scatter_nd( + indices=where_not_sampling, + updates=inputs_not_sampling, + shape=base_shape)) + + all_finished = math_ops.reduce_all(finished) + no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids)) + next_inputs = control_flow_ops.cond( + math_ops.logical_or(all_finished, no_samples), lambda: base_next_inputs, + maybe_sample) + return (finished, next_inputs, state) + + +class GreedyEmbeddingSampler(Sampler): + """A sampler for use during inference. + + Uses the argmax of the output (treated as logits) and passes the + result through an embedding layer to get the next input. + """ + + def __init__(self, embedding_fn=None): + """Initializer. + + Args: + embedding_fn: An optional callable that takes a vector tensor of `ids` + (argmax ids), or the `params` argument for `embedding_lookup`. The + returned tensor will be passed to the decoder input. Default to use + `embedding_ops.embedding_lookup`. 
+ """ + if embedding_fn is None or callable(embedding_fn): + self.embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be a callable, got %s" % + type(embedding_fn)) + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return tensor_shape.TensorShape([]) + + @property + def sample_ids_dtype(self): + return dtypes.int32 + + def initialize(self, embedding, start_tokens=None, end_token=None): + """Initialize the GreedyEmbeddingSampler. + + Args: + embedding: tensor that contains embedding states matrix. It will be used + to generate outputs with start_tokens and end_tokens. The + embedding will be ignored if the embedding_fn has been provided at + __init__(). + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + + Returns: + Tuple of two items: `(finished, self.start_inputs)`. + Raises: + ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a + scalar. 
+ """ + if self.embedding_fn is None: + self.embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self.start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + self.end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self.start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._batch_size = array_ops.size(start_tokens) + if self.end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + self.start_inputs = self.embedding_fn(self.start_tokens) + + finished = array_ops.tile([False], [self._batch_size]) + return (finished, self.start_inputs) + + def sample(self, time, outputs, state): + """sample for GreedyEmbeddingHelper.""" + del time, state # unused by sample_fn + # Outputs are logits, use argmax to get the most probable id + if not isinstance(outputs, ops.Tensor): + raise TypeError( + "Expected outputs to be a single Tensor, got: %s" % type(outputs)) + sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32) + return sample_ids + + def next_inputs(self, time, outputs, state, sample_ids): + """next_inputs_fn for GreedyEmbeddingHelper.""" + del time, outputs # unused by next_inputs_fn + finished = math_ops.equal(sample_ids, self.end_token) + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, + # If we're finished, the next_inputs value doesn't matter + lambda: self.start_inputs, + lambda: self.embedding_fn(sample_ids)) + return (finished, next_inputs, state) + + +class SampleEmbeddingSampler(GreedyEmbeddingSampler): + """A sampler for use during inference. + + Uses sampling (from a distribution) instead of argmax and passes the + result through an embedding layer to get the next input. + """ + + def __init__(self, embedding_fn=None, softmax_temperature=None, seed=None): + """Initializer. 
+ + Args: + embedding_fn: (Optional) A callable that takes a vector tensor of `ids` + (argmax ids), or the `params` argument for `embedding_lookup`. The + returned tensor will be passed to the decoder input. + softmax_temperature: (Optional) `float32` scalar, value to divide the + logits by before computing the softmax. Larger values (above 1.0) result + in more random samples, while smaller values push the sampling + distribution towards the argmax. Must be strictly greater than 0. + Defaults to 1.0. + seed: (Optional) The sampling seed. + + Raises: + ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a + scalar. + """ + super(SampleEmbeddingSampler, self).__init__(embedding_fn) + self.softmax_temperature = softmax_temperature + self.seed = seed + + def sample(self, time, outputs, state): + """sample for SampleEmbeddingHelper.""" + del time, state # unused by sample_fn + # Outputs are logits, we sample instead of argmax (greedy). + if not isinstance(outputs, ops.Tensor): + raise TypeError( + "Expected outputs to be a single Tensor, got: %s" % type(outputs)) + if self.softmax_temperature is None: + logits = outputs + else: + logits = outputs / self.softmax_temperature + + return categorical_sample(logits=logits, seed=self.seed) + + +class InferenceSampler(Sampler): + """A helper to use during inference with a custom sampling function.""" + + def __init__(self, + sample_fn, + sample_shape, + sample_dtype, + end_fn, + next_inputs_fn=None): + """Initializer. + + Args: + sample_fn: A callable that takes `outputs` and emits tensor `sample_ids`. + sample_shape: Either a list of integers, or a 1-D Tensor of type `int32`, + the shape of the each sample in the batch returned by `sample_fn`. + sample_dtype: the dtype of the sample returned by `sample_fn`. + end_fn: A callable that takes `sample_ids` and emits a `bool` vector + shaped `[batch_size]` indicating whether each sample is an end token. 
+ next_inputs_fn: (Optional) A callable that takes `sample_ids` and returns + the next batch of inputs. If not provided, `sample_ids` is used as the + next batch of inputs. + """ + self.sample_fn = sample_fn + self.sample_shape = tensor_shape.TensorShape(sample_shape) + self.sample_dtype = sample_dtype + self.end_fn = end_fn + self.next_inputs_fn = next_inputs_fn + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return self.sample_shape + + @property + def sample_ids_dtype(self): + return self.sample_dtype + + def initialize(self, start_inputs): + self.start_inputs = ops.convert_to_tensor(start_inputs, name="start_inputs") + self._batch_size = array_ops.shape(start_inputs)[0] + finished = array_ops.tile([False], [self._batch_size]) + return (finished, self.start_inputs) + + def sample(self, time, outputs, state): + del time, state # unused by sample + return self.sample_fn(outputs) + + def next_inputs(self, time, outputs, state, sample_ids): + del time, outputs # unused by next_inputs + if self.next_inputs_fn is None: + next_inputs = sample_ids + else: + next_inputs = self.next_inputs_fn(sample_ids) + finished = self.end_fn(sample_ids) + return (finished, next_inputs, state) + + +# The following sample functions (_call_sampler, bernoulli_sample, +# categorical_sample) mimic TensorFlow Probability distribution semantics. +def _call_sampler(sample_n_fn, sample_shape, name=None): + """Reshapes vector of samples.""" + with ops.name_scope(name, "call_sampler", values=[sample_shape]): + sample_shape = ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32, name="sample_shape") + # Ensure sample_shape is a vector (vs just a scalar). 
+ pad = math_ops.cast( + math_ops.equal(array_ops.rank(sample_shape), 0), dtypes.int32) + sample_shape = array_ops.reshape( + sample_shape, + array_ops.pad( + array_ops.shape(sample_shape), + paddings=[[pad, 0]], + constant_values=1)) + samples = sample_n_fn(math_ops.reduce_prod(sample_shape)) + batch_event_shape = array_ops.shape(samples)[1:] + final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) + return array_ops.reshape(samples, final_shape) + + +def bernoulli_sample(probs=None, + logits=None, + dtype=dtypes.int32, + sample_shape=(), + seed=None): + """Samples from Bernoulli distribution.""" + if probs is None: + probs = math_ops.sigmoid(logits, name="probs") + else: + probs = ops.convert_to_tensor(probs, name="probs") + batch_shape_tensor = array_ops.shape(probs) + + def _sample_n(n): + """Sample vector of Bernoullis.""" + new_shape = array_ops.concat([[n], batch_shape_tensor], 0) + uniform = random_ops.random_uniform(new_shape, seed=seed, dtype=probs.dtype) + return math_ops.cast(math_ops.less(uniform, probs), dtype) + + return _call_sampler(_sample_n, sample_shape) + + +def categorical_sample(logits, dtype=dtypes.int32, sample_shape=(), seed=None): + """Samples from categorical distribution.""" + logits = ops.convert_to_tensor(logits, name="logits") + event_size = array_ops.shape(logits)[-1] + batch_shape_tensor = array_ops.shape(logits)[:-1] + + def _sample_n(n): + """Sample vector of categoricals.""" + if logits.shape.ndims == 2: + logits_2d = logits + else: + logits_2d = array_ops.reshape(logits, [-1, event_size]) + sample_dtype = dtypes.int64 if logits.dtype.size > 4 else dtypes.int32 + draws = random_ops.multinomial( + logits_2d, n, seed=seed, output_dtype=sample_dtype) + draws = array_ops.reshape( + array_ops.transpose(draws), + array_ops.concat([[n], batch_shape_tensor], 0)) + return math_ops.cast(draws, dtype) + + return _call_sampler(_sample_n, sample_shape) + + +def _unstack_ta(inp): + return tensor_array_ops.TensorArray( + 
dtype=inp.dtype, + size=array_ops.shape(inp)[0], + element_shape=inp.get_shape()[1:]).unstack(inp) diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md index f2bb458848fab5603128903868b52f29785efc92..7b54aafeb2cfb5f2a99a93b97d14fbc5bf6e8f9c 100644 --- a/tensorflow/contrib/slim/README.md +++ b/tensorflow/contrib/slim/README.md @@ -11,7 +11,7 @@ import tensorflow.contrib.slim as slim ## Why TF-Slim? -TF-Slim is a library that makes building, training and evaluation neural +TF-Slim is a library that makes defining, training and evaluating neural networks simple: * Allows the user to define models much more compactly by eliminating @@ -78,7 +78,7 @@ provides convenience wrappers for variable creation and manipulation. ## Defining Models Models can be succinctly defined using TF-Slim by combining its variables, -layers and scopes. Each of these elements are defined below. +layers and scopes. Each of these elements is defined below. ### Variables @@ -160,15 +160,15 @@ slim.add_model_variable(my_model_variable) ### Layers -While the set of TensorFlow operations is quite extensive, developers of -neural networks typically think of models in terms of higher level concepts -like "layers", "losses", "metrics", and "networks". A layer, -such as a Convolutional Layer, a Fully Connected Layer or a BatchNorm Layer -are more abstract than a single TensorFlow operation and typically involve -several operations. Furthermore, a layer usually (but not always) has -variables (tunable parameters) associated with it, unlike more primitive -operations. For example, a Convolutional Layer in a neural network -is composed of several low level operations: +While the set of TensorFlow operations is quite extensive, developers of neural +networks typically think of models in terms of higher level concepts like +"layers", "losses", "metrics", and "networks". 
A layer, such as a Convolutional +Layer, a Fully Connected Layer or a BatchNorm Layer is more abstract than a +single TensorFlow operation and typically involve several operations. +Furthermore, a layer usually (but not always) has variables (tunable parameters) +associated with it, unlike more primitive operations. For example, a +Convolutional Layer in a neural network is composed of several low level +operations: 1. Creating the weight and bias variables 2. Convolving the weights with the input from the previous layer @@ -455,9 +455,8 @@ loss = slim.losses.softmax_cross_entropy(predictions, labels) ``` In this example, we start by creating the model (using TF-Slim's VGG -implementation), and add the standard classification loss. Now, lets turn -to the case where we have a multi-task model that produces multiple outputs: - +implementation), and add the standard classification loss. Now, let's turn to +the case where we have a multi-task model that produces multiple outputs: ```python # Load the images and labels. @@ -555,8 +554,8 @@ that we'll save a model checkpoint every 10 minutes. ### Working Example: Training the VGG16 Model -To illustrate this, lets -examine the following sample of training the VGG network: +To illustrate this, let's examine the following sample of training the VGG +network: ```python import tensorflow as tf @@ -738,7 +737,7 @@ slim.learning.train(train_op, log_dir, init_fn=init_fn) Once we've trained a model (or even while the model is busy training) we'd like to see how well the model performs in practice. This is accomplished by picking -a set of evaluation metrics, which will grade the models performance, and the +a set of evaluation metrics, which will grade the model's performance, and the evaluation code which actually loads the data, performs inference, compares the results to the ground truth and records the evaluation scores. This step may be performed once or repeated periodically. 
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py index 99ad48763031cc2f98009449cea050fd90d01eb5..36d544d565b0cab05febcb91456e99dc95cb2307 100644 --- a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py +++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py @@ -252,8 +252,9 @@ def parallel_read(data_sources, common_queue = data_flow_ops.FIFOQueue( capacity=capacity, dtypes=dtypes, name='common_queue') - summary.scalar('fraction_of_%d_full' % capacity, - math_ops.to_float(common_queue.size()) * (1. / capacity)) + summary.scalar( + 'fraction_of_%d_full' % capacity, + math_ops.cast(common_queue.size(), tf_dtypes.float32) * (1. / capacity)) return ParallelReader( reader_class, diff --git a/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py b/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py index 62bd20036126b41040ca4329c7f13ea7671a8045..d3c5ab7d7c2c5ced79808ecc59e0b0218c461062 100644 --- a/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py +++ b/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops @@ -86,6 +87,7 @@ def prefetch_queue(tensors, enqueue_op = queue.enqueue(tensors) queue_runner.add_queue_runner( queue_runner.QueueRunner(queue, [enqueue_op] * num_threads)) - summary.scalar("fraction_of_%d_full" % capacity, - math_ops.to_float(queue.size()) * (1. / capacity)) + summary.scalar( + "fraction_of_%d_full" % capacity, + math_ops.cast(queue.size(), _dtypes.float32) * (1. 
/ capacity)) return queue diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 1b2b6acacca838f95cb758ae88f79263993ca69e..c3193171a0ccbdfcd1c0f563d80baf24b465e5a6 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -32,7 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import image_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops @@ -329,7 +329,7 @@ class SparseTensor(ItemHandler): shape = indices.dense_shape indices_shape = array_ops.shape(indices.indices) rank = indices_shape[1] - ids = math_ops.to_int64(indices.values) + ids = math_ops.cast(indices.values, dtypes.int64) indices_columns_to_preserve = array_ops.slice( indices.indices, [0, 0], array_ops.stack([-1, rank - 1])) new_indices = array_ops.concat( @@ -396,8 +396,8 @@ class Image(ItemHandler): image_format = keys_to_tensors[self._format_key] if self._repeated: - return functional_ops.map_fn(lambda x: self._decode(x, image_format), - image_buffer, dtype=self._dtype) + return map_fn.map_fn(lambda x: self._decode(x, image_format), + image_buffer, dtype=self._dtype) else: return self._decode(image_buffer, image_format) diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD index 8bbdf96384683c68648367c6433eeb89c64c22bf..e9595d1b324dbd3d570d2407a6620c5295b15548 100644 --- a/tensorflow/contrib/slim/python/slim/nets/BUILD +++ b/tensorflow/contrib/slim/python/slim/nets/BUILD @@ -115,9 +115,9 @@ py_library( py_test( name = "inception_v1_test", - size = "large", + size 
= "medium", srcs = ["inception_v1_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", deps = [ ":inception_v1", @@ -135,9 +135,9 @@ py_test( py_test( name = "inception_v2_test", - size = "large", + size = "medium", srcs = ["inception_v2_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", deps = [ ":inception_v2", @@ -155,9 +155,9 @@ py_test( py_test( name = "inception_v3_test", - size = "large", + size = "medium", srcs = ["inception_v3_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", deps = [ ":inception_v3", @@ -233,8 +233,9 @@ py_library( py_test( name = "resnet_v1_test", - size = "large", + size = "medium", srcs = ["resnet_v1_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":resnet_utils", @@ -268,8 +269,9 @@ py_library( py_test( name = "resnet_v2_test", - size = "large", + size = "medium", srcs = ["resnet_v2_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":resnet_utils", diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py index 8ff44fe4b5f21e6d174451c416b7e4107cebcde3..1cc54b15514157de1e48890feca398c906b16ec8 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py @@ -54,18 +54,19 @@ def create_test_input(batch_size, height, width, channels): return array_ops.placeholder(dtypes.float32, (batch_size, height, width, channels)) else: - return math_ops.to_float( + return math_ops.cast( np.tile( np.reshape( np.reshape(np.arange(height), [height, 1]) + np.reshape( np.arange(width), [1, width]), [1, height, width, 1]), - [batch_size, 1, 1, channels])) + [batch_size, 1, 1, channels]), dtypes.float32) class ResnetUtilsTest(test.TestCase): def testSubsampleThreeByThree(self): - x = array_ops.reshape(math_ops.to_float(math_ops.range(9)), [1, 3, 3, 1]) + x = array_ops.reshape(math_ops.cast(math_ops.range(9), 
dtypes.float32), + [1, 3, 3, 1]) x = resnet_utils.subsample(x, 2) expected = array_ops.reshape( constant_op.constant([0, 2, 6, 8]), [1, 2, 2, 1]) @@ -73,7 +74,8 @@ class ResnetUtilsTest(test.TestCase): self.assertAllClose(x.eval(), expected.eval()) def testSubsampleFourByFour(self): - x = array_ops.reshape(math_ops.to_float(math_ops.range(16)), [1, 4, 4, 1]) + x = array_ops.reshape(math_ops.cast(math_ops.range(16), dtypes.float32), + [1, 4, 4, 1]) x = resnet_utils.subsample(x, 2) expected = array_ops.reshape( constant_op.constant([0, 2, 8, 10]), [1, 2, 2, 1]) @@ -95,19 +97,20 @@ class ResnetUtilsTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() y1 = layers.conv2d(x, 1, [3, 3], stride=1, scope='Conv') - y1_expected = math_ops.to_float([[14, 28, 43, 26], [28, 48, 66, 37], - [43, 66, 84, 46], [26, 37, 46, 22]]) + y1_expected = math_ops.cast([[14, 28, 43, 26], [28, 48, 66, 37], + [43, 66, 84, 46], [26, 37, 46, 22]], + dtypes.float32) y1_expected = array_ops.reshape(y1_expected, [1, n, n, 1]) y2 = resnet_utils.subsample(y1, 2) - y2_expected = math_ops.to_float([[14, 43], [43, 84]]) + y2_expected = math_ops.cast([[14, 43], [43, 84]], dtypes.float32) y2_expected = array_ops.reshape(y2_expected, [1, n2, n2, 1]) y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv') y3_expected = y2_expected y4 = layers.conv2d(x, 1, [3, 3], stride=2, scope='Conv') - y4_expected = math_ops.to_float([[48, 37], [37, 22]]) + y4_expected = math_ops.cast([[48, 37], [37, 22]], dtypes.float32) y4_expected = array_ops.reshape(y4_expected, [1, n2, n2, 1]) with self.cached_session() as sess: @@ -132,14 +135,19 @@ class ResnetUtilsTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() y1 = layers.conv2d(x, 1, [3, 3], stride=1, scope='Conv') - y1_expected = math_ops.to_float([[14, 28, 43, 58, 34], [28, 48, 66, 84, 46], - [43, 66, 84, 102, 55], - [58, 84, 102, 120, 64], - [34, 46, 55, 64, 30]]) + y1_expected = math_ops.cast([[14, 28, 43, 58, 34], + 
[28, 48, 66, 84, 46], + [43, 66, 84, 102, 55], + [58, 84, 102, 120, 64], + [34, 46, 55, 64, 30]], + dtypes.float32) y1_expected = array_ops.reshape(y1_expected, [1, n, n, 1]) y2 = resnet_utils.subsample(y1, 2) - y2_expected = math_ops.to_float([[14, 43, 34], [43, 84, 55], [34, 55, 30]]) + y2_expected = math_ops.cast([[14, 43, 34], + [43, 84, 55], + [34, 55, 30]], + dtypes.float32) y2_expected = array_ops.reshape(y2_expected, [1, n2, n2, 1]) y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv') diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py index 055ecff1c32f76e0788fe141f410d6e6aac86cf5..31bdea9fbcd39ac486b9e13c0b7fd24e723f7fe9 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py @@ -54,18 +54,20 @@ def create_test_input(batch_size, height, width, channels): return array_ops.placeholder(dtypes.float32, (batch_size, height, width, channels)) else: - return math_ops.to_float( + return math_ops.cast( np.tile( np.reshape( np.reshape(np.arange(height), [height, 1]) + np.reshape( np.arange(width), [1, width]), [1, height, width, 1]), - [batch_size, 1, 1, channels])) + [batch_size, 1, 1, channels]), + dtypes.float32) class ResnetUtilsTest(test.TestCase): def testSubsampleThreeByThree(self): - x = array_ops.reshape(math_ops.to_float(math_ops.range(9)), [1, 3, 3, 1]) + x = array_ops.reshape(math_ops.cast(math_ops.range(9), dtypes.float32), + [1, 3, 3, 1]) x = resnet_utils.subsample(x, 2) expected = array_ops.reshape( constant_op.constant([0, 2, 6, 8]), [1, 2, 2, 1]) @@ -73,7 +75,8 @@ class ResnetUtilsTest(test.TestCase): self.assertAllClose(x.eval(), expected.eval()) def testSubsampleFourByFour(self): - x = array_ops.reshape(math_ops.to_float(math_ops.range(16)), [1, 4, 4, 1]) + x = array_ops.reshape(math_ops.cast(math_ops.range(16), dtypes.float32), + [1, 4, 4, 1]) x = resnet_utils.subsample(x, 2) 
expected = array_ops.reshape( constant_op.constant([0, 2, 8, 10]), [1, 2, 2, 1]) @@ -95,19 +98,22 @@ class ResnetUtilsTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() y1 = layers.conv2d(x, 1, [3, 3], stride=1, scope='Conv') - y1_expected = math_ops.to_float([[14, 28, 43, 26], [28, 48, 66, 37], - [43, 66, 84, 46], [26, 37, 46, 22]]) + y1_expected = math_ops.cast([[14, 28, 43, 26], + [28, 48, 66, 37], + [43, 66, 84, 46], + [26, 37, 46, 22]], + dtypes.float32) y1_expected = array_ops.reshape(y1_expected, [1, n, n, 1]) y2 = resnet_utils.subsample(y1, 2) - y2_expected = math_ops.to_float([[14, 43], [43, 84]]) + y2_expected = math_ops.cast([[14, 43], [43, 84]], dtypes.float32) y2_expected = array_ops.reshape(y2_expected, [1, n2, n2, 1]) y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv') y3_expected = y2_expected y4 = layers.conv2d(x, 1, [3, 3], stride=2, scope='Conv') - y4_expected = math_ops.to_float([[48, 37], [37, 22]]) + y4_expected = math_ops.cast([[48, 37], [37, 22]], dtypes.float32) y4_expected = array_ops.reshape(y4_expected, [1, n2, n2, 1]) with self.cached_session() as sess: @@ -132,17 +138,19 @@ class ResnetUtilsTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() y1 = layers.conv2d(x, 1, [3, 3], stride=1, scope='Conv') - y1_expected = math_ops.to_float([[14, 28, 43, 58, 34], - [28, 48, 66, 84, 46], - [43, 66, 84, 102, 55], - [58, 84, 102, 120, 64], - [34, 46, 55, 64, 30]]) + y1_expected = math_ops.cast([[14, 28, 43, 58, 34], + [28, 48, 66, 84, 46], + [43, 66, 84, 102, 55], + [58, 84, 102, 120, 64], + [34, 46, 55, 64, 30]], + dtypes.float32) y1_expected = array_ops.reshape(y1_expected, [1, n, n, 1]) y2 = resnet_utils.subsample(y1, 2) - y2_expected = math_ops.to_float([[14, 43, 34], - [43, 84, 55], - [34, 55, 30]]) + y2_expected = math_ops.cast([[14, 43, 34], + [43, 84, 55], + [34, 55, 30]], + dtypes.float32) y2_expected = array_ops.reshape(y2_expected, [1, n2, n2, 1]) y3 = resnet_utils.conv2d_same(x, 
1, 3, stride=2, scope='Conv') diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index f88b03ec4c2b1f250091594ea12d7d1862029fa2..7dd52df6b68caea6111813837ba1e872acbeccdb 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -4,17 +4,14 @@ exports_files([ "LICENSE", ]) -load( - "//tensorflow:tensorflow.bzl", - "py_test", - "tf_gen_op_wrapper_py", -) +load("//tensorflow:tensorflow.bzl", "py_test") py_test( name = "summary_ops_test", srcs = ["summary_ops_test.py"], srcs_version = "PY2AND3", deps = [ + ":summary", ":summary_test_util", "//tensorflow/python:array_ops", "//tensorflow/python:errors", @@ -22,7 +19,6 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python:state_ops", - "//tensorflow/python:summary_ops_v2", "//tensorflow/python:training", "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", @@ -35,6 +31,7 @@ py_test( srcs = ["summary_ops_graph_test.py"], srcs_version = "PY2AND3", deps = [ + ":summary", ":summary_test_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -43,7 +40,6 @@ py_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:summary_ops_v2", "//tensorflow/python:training", "//tensorflow/python:variables", "@six_archive//:six", diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py index 807741e05f92f6b666c175269742dc1af50c0054..8e13f7f56b23e47f046120b285b1519c6371ddab 100644 --- a/tensorflow/contrib/summary/summary_ops_graph_test.py +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -22,6 +22,7 @@ import time import six +from tensorflow.contrib.summary import summary as summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import 
node_def_pb2 @@ -32,7 +33,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 10e4556dacbc17ec02c2bd698389b04d517d7076..27bfdeb3601f4fdb9897feee509b06d5e8f9b873 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -25,6 +25,7 @@ import sqlite3 import numpy as np import six +from tensorflow.contrib.summary import summary as summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 @@ -36,7 +37,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 398ac314f4b520610ec100273b37c33bc4b5b43a..583bbf97c57cf263f65bc3b0a56b32cc2dce5482 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -537,8 +537,9 @@ py_library( py_test( name = "random_forest_test", - size = "large", + size = "medium", srcs = ["client/random_forest_test.py"], + shard_count = 6, srcs_version = "PY2AND3", tags = [ "noasan", diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index 
d8236a0a6fa6d0d0e383e454eb0146bb10b6f49d..0b4125f00f9261bb9cbfaf1eedadd08189ab8be0 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib import losses from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -35,7 +36,7 @@ FEATURE_IMPORTANCE_NAME = 'global_feature_importance' def _top_k_generator(k): def _top_k(probabilities, targets): - targets = math_ops.to_int32(targets) + targets = math_ops.cast(targets, dtypes.int32) if targets.get_shape().ndims > 1: targets = array_ops.squeeze(targets, axis=[1]) return metrics.mean(nn.in_top_k(probabilities, targets, k)) @@ -48,18 +49,19 @@ def _accuracy(predictions, targets, weights=None): def _r2(probabilities, targets, weights=None): - targets = math_ops.to_float(targets) + targets = math_ops.cast(targets, dtypes.float32) y_mean = math_ops.reduce_mean(targets, 0) - squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0) + squares_total = math_ops.reduce_sum( + math_ops.squared_difference(targets, y_mean), 0) squares_residuals = math_ops.reduce_sum( - math_ops.square(targets - probabilities), 0) + math_ops.squared_difference(targets, probabilities), 0) score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) return metrics.mean(score, weights=weights) def _squeeze_and_onehot(targets, depth): targets = array_ops.squeeze(targets, axis=[1]) - return array_ops.one_hot(math_ops.to_int32(targets), depth) + return array_ops.one_hot(math_ops.cast(targets, dtypes.int32), depth) def _sigmoid_entropy(probabilities, targets, weights=None): @@ -74,7 +76,7 @@ def _sigmoid_entropy(probabilities, targets, weights=None): def _softmax_entropy(probabilities, targets, weights=None): 
return metrics.mean( losses.sparse_softmax_cross_entropy(probabilities, - math_ops.to_int32(targets)), + math_ops.cast(targets, dtypes.int32)), weights=weights) diff --git a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py index a427a02b7cd597b7090155b5e0d300c7d71208c8..926e4dda916e3dd30015aea2d2a7b13df53cdb52 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py +++ b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py @@ -22,6 +22,7 @@ import collections from tensorflow.contrib import layers from tensorflow.contrib.framework.python.ops import variables as framework_variables +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -110,14 +111,15 @@ class HybridModel(object): """The loss to minimize while training.""" if self.is_regression: - diff = self.training_inference_graph(data) - math_ops.to_float(labels) + diff = self.training_inference_graph(data) - math_ops.cast( + labels, dtypes.float32) mean_squared_error = math_ops.reduce_mean(diff * diff) root_mean_squared_error = math_ops.sqrt(mean_squared_error, name="loss") loss = root_mean_squared_error else: loss = math_ops.reduce_mean( nn_ops.sparse_softmax_cross_entropy_with_logits( - labels=array_ops.squeeze(math_ops.to_int32(labels)), + labels=array_ops.squeeze(math_ops.cast(labels, dtypes.int32)), logits=self.training_inference_graph(data)), name="loss") if self.regularizer: diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index b9aad36f3d25b9fb7b8b525be54fb7a39394b373..76b1d2b4da269cda71f5b49878f2933d7d9b5776 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -304,7 +304,7 @@ class TraverseTreeV4Op : public OpKernel { auto 
worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; const int64 costPerTraverse = 500; - auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource, + auto traverse = [&set_leaf_ids, &data_set, decision_tree_resource, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index fe2c91c1047fe56710b1a86b2fa3206caf6ff3bc..0243f106814511c1b53a5aacb830b845214a00a3 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -307,7 +307,7 @@ class ProcessInputOp : public OpKernel { // from a digits run on local desktop. Heuristics might be necessary // if it really matters that much. const int64 costPerUpdate = 1000; - auto update = [this, &target, &leaf_ids_tensor, &num_targets, &data_set, + auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set, fertile_stats_resource, &locks, &set_lock, &ready_to_split, num_data](int64 start, int64 end) { CHECK(start <= end); @@ -317,7 +317,7 @@ class ProcessInputOp : public OpKernel { static_cast(end), &ready_to_split); }; - auto update_collated = [this, &target, &num_targets, fertile_stats_resource, + auto update_collated = [&target, &num_targets, fertile_stats_resource, tree_resource, &leaf_examples, &set_lock, &ready_to_split, &data_set, num_leaves](int64 start, int64 end) { diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h index e04eb60f9b27cfd8b6b4e1502594d4d310ae55cc..774da472f1543f938d1b607ebdef008f7b540211 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h @@ -18,10 +18,10 @@ #include #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" +#include 
"tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/random/distribution_sampler.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/contrib/tensor_forest/python/ops/data_ops.py b/tensorflow/contrib/tensor_forest/python/ops/data_ops.py index f878e5989cf2b43be960d34a45c4014d412f1c67..5c1fe23981d50ed067ae0bcf587b89d04c515629 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/data_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/data_ops.py @@ -44,7 +44,7 @@ def CastToFloat(tensor): if tensor.dtype == dtypes.string: return tensor_forest_ops.reinterpret_string_to_float(tensor) elif tensor.dtype.is_integer: - return math_ops.to_float(tensor) + return math_ops.cast(tensor, dtypes.float32) else: return tensor @@ -195,7 +195,7 @@ def ParseLabelTensorOrDict(labels): A 2-D tensor for labels/outputs. 
""" if isinstance(labels, dict): - return math_ops.to_float( + return math_ops.cast( array_ops.concat( [ sparse_ops.sparse_tensor_to_dense( @@ -203,10 +203,12 @@ def ParseLabelTensorOrDict(labels): labels, sparse_tensor.SparseTensor) else labels[k] for k in sorted(labels.keys()) ], - 1)) + 1), + dtypes.float32) else: if isinstance(labels, sparse_tensor.SparseTensor): - return math_ops.to_float(sparse_ops.sparse_tensor_to_dense( - labels, default_value=-1)) + return math_ops.cast( + sparse_ops.sparse_tensor_to_dense(labels, default_value=-1), + dtypes.float32) else: - return math_ops.to_float(labels) + return math_ops.cast(labels, dtypes.float32) diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py index 290c16fe3966791ea78986539750caf938a37322..d36d0eb0c46b0d68bea4b6fc29a20dc8876ac539 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py @@ -35,7 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking _model_ops = loader.load_op_library( @@ -103,9 +103,9 @@ class TreeVariable(tracking.TrackableResource): self._container = container self._init_op = None super(TreeVariable, self).__init__() - self._resource_handle = self.create_resource() + self._resource_handle = self._create_resource() - def create_resource(self): + def _create_resource(self): if context.executing_eagerly(): # TODO(allenl): This will leak memory due to kernel caching by the # shared_name attribute value (but is better than the alternative of @@ -117,7 +117,7 @@ class TreeVariable(tracking.TrackableResource): return gen_model_ops.decision_tree_resource_handle_op( self._container, 
shared_name=shared_name, name=self._name) - def initialize(self): + def _initialize(self): return gen_model_ops.create_tree_variable( self.resource_handle, self._tree_config, @@ -126,7 +126,7 @@ class TreeVariable(tracking.TrackableResource): @property def initializer(self): if self._init_op is None: - self._init_op = self.initialize() + self._init_op = self._initialize() return self._init_op def is_initialized(self): diff --git a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py index 9184198cd4c8fd2a7609714d094d5ef2b6868658..7ac68fed20c3c9dfeaff05013e3fc686eea8cc2e 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py @@ -32,7 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking _stats_ops = loader.load_op_library( @@ -98,9 +98,9 @@ class FertileStatsVariable(tracking.TrackableResource): self._container = container self._init_op = None super(FertileStatsVariable, self).__init__() - self._resource_handle = self.create_resource() + self._resource_handle = self._create_resource() - def create_resource(self): + def _create_resource(self): if context.executing_eagerly(): # TODO(allenl): This will leak memory due to kernel caching by the # shared_name attribute value (but is better than the alternative of @@ -112,7 +112,7 @@ class FertileStatsVariable(tracking.TrackableResource): return gen_stats_ops.fertile_stats_resource_handle_op( self._container, shared_name=shared_name, name=self._name) - def initialize(self): + def _initialize(self): return gen_stats_ops.create_fertile_stats_variable( self.resource_handle, self._stats_config, @@ -121,7 +121,7 @@ class 
FertileStatsVariable(tracking.TrackableResource): @property def initializer(self): if self._init_op is None: - self._init_op = self.initialize() + self._init_op = self._initialize() return self._init_op def is_initialized(self): diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 6f62cd11a9733949c350e35b6b0c436dd097cc33..65a3574e75238bac283fb551a8f0110fa1358bf2 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -30,6 +30,7 @@ from tensorflow.contrib.tensor_forest.python.ops import data_ops from tensorflow.contrib.tensor_forest.python.ops import model_ops from tensorflow.contrib.tensor_forest.python.ops import stats_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -540,7 +541,8 @@ class RandomForestGraphs(object): for i in range(self.params.num_trees): with ops.device(self.variables.device_dummies[i].device): sizes.append(self.trees[i].size()) - return math_ops.reduce_mean(math_ops.to_float(array_ops.stack(sizes))) + return math_ops.reduce_mean( + math_ops.cast(array_ops.stack(sizes), dtypes.float32)) # pylint: disable=unused-argument def training_loss(self, features, labels, name='training_loss'): @@ -603,7 +605,7 @@ class RandomTreeGraphs(object): The last op in the random tree training graph. """ # TODO(gilberth): Use this. 
- unused_epoch = math_ops.to_int32(get_epoch_variable()) + unused_epoch = math_ops.cast(get_epoch_variable(), dtypes.int32) if input_weights is None: input_weights = [] diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 3f1090676865a5cddc61810c385284f0db0fbbbb..91b6d2614a8963c21e35c385411dc4c9956e3146 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -11,602 +11,54 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "tf_cc_test", - "tf_copts", "tf_cuda_library", - "tf_custom_op_library", "tf_custom_op_library_additional_deps", - "tf_gen_op_libs", - "tf_gen_op_wrapper_py", ) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", ) -exports_files(glob([ - "test/testdata/*", -])) - -tf_cuda_cc_test( - name = "tensorrt_test_cc", - size = "small", - srcs = ["tensorrt_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - "//tensorflow/core:gpu_init", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:tensorrt", - ]), -) - -tf_custom_op_library( - name = "python/ops/_trt_ops.so", - srcs = [ - "ops/get_serialized_resource_op.cc", - "ops/trt_engine_op.cc", - ], - deps = [ - ":trt_shape_function", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]), -) - tf_cuda_library( name = "trt_shape_function", srcs = ["shape_fn/trt_shfn.cc"], hdrs = ["shape_fn/trt_shfn.h"], visibility = ["//visibility:public"], deps = [ - ":trt_logging", - ":trt_plugins", + 
"//tensorflow/compiler/tf2tensorrt:trt_logging", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", ] + if_tensorrt([ "@local_config_tensorrt//:tensorrt", ]) + tf_custom_op_library_additional_deps(), ) -cc_library( - name = "trt_op_kernels", - srcs = [ - "kernels/get_serialized_resource_op.cc", - "kernels/trt_engine_op.cc", - ], - hdrs = [ - "kernels/trt_engine_op.h", - ], - copts = tf_copts(), - visibility = ["//visibility:public"], - deps = [ - ":test_utils", - ":trt_allocator", - ":trt_conversion", - ":trt_logging", - ":trt_plugins", - ":trt_resources", - ":utils", - "@com_google_absl//absl/memory", - "//tensorflow/core:gpu_headers_lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:stream_executor_headers_lib", - "//tensorflow/core/grappler/costs:graph_properties", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]) + tf_custom_op_library_additional_deps(), - # TODO(laigd): fix this by merging header file in cc file. - alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs -) - -tf_cuda_cc_test( - name = "get_serialized_resource_op_test", - size = "small", - srcs = ["kernels/get_serialized_resource_op_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":get_serialized_resource_op_op_lib", - ":trt_op_kernels", - ":trt_resources", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/kernels:ops_testutil", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "trt_engine_op", - "get_serialized_resource_op", - ], -) - -tf_cuda_library( - name = "trt_logging", - srcs = ["log/trt_logger.cc"], - hdrs = ["log/trt_logger.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]), -) - -tf_gen_op_wrapper_py( - name = "trt_ops", - deps = [ - 
":get_serialized_resource_op_op_lib", - ":trt_engine_op_op_lib", - ":trt_logging", - ":trt_shape_function", - ], -) - -tf_custom_op_py_library( - name = "trt_ops_loader", - srcs = ["python/ops/trt_ops.py"], - dso = [ - ":python/ops/_trt_ops.so", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]), - kernels = [ - ":trt_op_kernels", - ":trt_engine_op_op_lib", - ":get_serialized_resource_op_op_lib", - ":trt_shape_function", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:resources", - ], -) - py_library( name = "init_py", srcs = [ "__init__.py", "python/__init__.py", + "python/trt_convert.py", ], srcs_version = "PY2AND3", deps = [ - ":tf_trt_integration_test_base", - ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:errors", + "//tensorflow/python/compiler/tensorrt:init_py", ], ) -py_library( - name = "trt_ops_py", - srcs_version = "PY2AND3", - deps = [ - ":trt_ops", - ":trt_ops_loader", - ], -) +# The following rules forward the libraries that were moved in order to not +# break other internal targets. -py_library( - name = "trt_convert_py", - srcs = ["python/trt_convert.py"], - srcs_version = "PY2AND3", - deps = [ - ":wrap_conversion", - "//tensorflow/python:graph_util", - "//tensorflow/python:session", - "//tensorflow/python:tf_optimizer", - "//tensorflow/python/saved_model:builder", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:tag_constants", - ], -) - -# TODO(aaroey): this wrapper has been causing troubles of double linking, so -# either get rid of it, or split to make it contain minimum dependencies. 
-tf_py_wrap_cc( - name = "wrap_conversion", - srcs = ["trt_conversion.i"], - copts = tf_copts(), - swig_includes = [ - "//tensorflow/python:platform/base.i", - ], - deps = [ - ":test_utils", - ":trt_conversion", - ":trt_op_kernels", - "//third_party/python_runtime:headers", - ], -) - -tf_cuda_library( - name = "trt_resources", - srcs = [ - "resources/trt_int8_calibrator.cc", - "resources/trt_resource_manager.cc", - "resources/trt_resources.cc", - ], - hdrs = [ - "resources/trt_int8_calibrator.h", - "resources/trt_lru_cache.h", - "resources/trt_resource_manager.h", - "resources/trt_resources.h", - ], - deps = [ - ":trt_allocator", - ":trt_logging", - ":utils", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]), -) - -tf_cuda_library( - name = "trt_allocator", - srcs = ["resources/trt_allocator.cc"], - hdrs = ["resources/trt_allocator.h"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]), -) - -tf_cc_test( - name = "trt_allocator_test", - size = "small", - srcs = ["resources/trt_allocator_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - ":trt_allocator", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -tf_cc_test( - name = "trt_lru_cache_test", - size = "small", - srcs = ["resources/trt_lru_cache_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - ":trt_resources", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -# Library for the node-level conversion portion of TensorRT operation creation -tf_cuda_library( +alias( name = "trt_conversion", - srcs = [ - "convert/convert_graph.cc", - "convert/convert_nodes.cc", - "convert/trt_optimization_pass.cc", - ], - hdrs = [ - "convert/convert_graph.h", - 
"convert/convert_nodes.h", - "convert/trt_optimization_pass.h", - ], - deps = [ - ":segment", - ":test_utils", - ":trt_allocator", - ":trt_plugins", - ":trt_logging", - ":trt_resources", - ":utils", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:utils", - "//tensorflow/core:framework", - "//tensorflow/core:framework_lite", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:devices", - "//tensorflow/core/grappler/clusters:virtual_cluster", - "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core/grappler/optimizers:meta_optimizer", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]) + tf_custom_op_library_additional_deps(), -) - -tf_cuda_cc_test( - name = "convert_graph_test", - size = "medium", - srcs = ["convert/convert_graph_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_conversion", - "@com_google_googletest//:gtest", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:direct_session", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]), -) - -tf_cuda_cc_test( - name = "convert_nodes_test", - size = "medium", - srcs = ["convert/convert_nodes_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_logging", 
- ":trt_conversion", - ":trt_plugins", - "@com_google_googletest//:gtest", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:tensorrt", - ]), -) - -# Library for the segmenting portion of TensorRT operation creation -cc_library( - name = "segment", - srcs = ["segment/segment.cc"], - hdrs = [ - "segment/segment.h", - "segment/union_find.h", - ], - deps = [ - "//tensorflow/core:graph", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:protos_all_cc", - "@protobuf_archive//:protobuf_headers", - ], -) - -tf_cc_test( - name = "segment_test", - size = "small", - srcs = ["segment/segment_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - ":segment", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:lib", - "//tensorflow/core:ops", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], + actual = "//tensorflow/compiler/tf2tensorrt:trt_conversion", ) -# Library for the plugin factory -tf_cuda_library( - name = "trt_plugins", - srcs = [ - "plugin/trt_plugin.cc", - "plugin/trt_plugin_factory.cc", - "plugin/trt_plugin_utils.cc", - ], - hdrs = [ - "plugin/trt_plugin.h", - "plugin/trt_plugin_factory.h", - "plugin/trt_plugin_utils.h", - ], - deps = [ - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]), -) - -tf_cuda_cc_test( - name = "trt_plugin_factory_test", - size = "small", - srcs = 
["plugin/trt_plugin_factory_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_plugins", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:tensorrt", - ]), -) - -py_library( - name = "tf_trt_integration_test_base", - srcs = ["test/tf_trt_integration_test_base.py"], - deps = [ - ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], -) - -cuda_py_test( - name = "trt_convert_test", - srcs = ["python/trt_convert_test.py"], - additional_deps = [ - ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:graph_util", - "//tensorflow/python/saved_model:builder", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:signature_def_utils", - "//tensorflow/python/saved_model:tag_constants", - "//tensorflow/python/saved_model:utils", - "//tensorflow/python/tools:freeze_graph_lib", - "//tensorflow/python/tools:saved_model_utils", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], -) - -cuda_py_tests( - name = "tf_trt_integration_test", - srcs = [ - "test/base_test.py", - "test/batch_matmul_test.py", - "test/biasadd_matmul_test.py", - "test/binary_tensor_weight_broadcast_test.py", - "test/concatenation_test.py", - "test/const_broadcast_test.py", - "test/conv2d_test.py", - "test/dynamic_input_shapes_test.py", - "test/identity_output_test.py", - "test/int32_test.py", - "test/lru_cache_test.py", - "test/manual_test.py", - "test/memory_alignment_test.py", - "test/multi_connection_neighbor_engine_test.py", - "test/neighboring_engine_test.py", - "test/quantization_test.py", - "test/rank_two_test.py", - "test/reshape_transpose_test.py", - "test/topk_test.py", 
- "test/unary_test.py", - "test/vgg_block_nchw_test.py", - "test/vgg_block_test.py", - ], - additional_deps = [ - ":tf_trt_integration_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], -) - -cuda_py_test( - name = "quantization_mnist_test", - srcs = ["test/quantization_mnist_test.py"], - additional_deps = [ - ":tf_trt_integration_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python/keras:keras", - "//tensorflow/python/estimator:estimator", - ], - data = [ - "test/testdata/checkpoint", - "test/testdata/model.ckpt-46900.data-00000-of-00001", - "test/testdata/model.ckpt-46900.index", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_pip", - "no_tap", # It is not able to download the mnist data. - "no_windows", - "nomac", - ], -) - -cc_library( - name = "utils", - srcs = ["convert/utils.cc"], - hdrs = ["convert/utils.h"], - copts = tf_copts(), - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], +alias( + name = "trt_op_kernels", + actual = "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", ) -cc_library( - name = "test_utils", - srcs = ["test/utils.cc"], - hdrs = ["test/utils.h"], - deps = [ - "//tensorflow/core:lib", - "@com_googlesource_code_re2//:re2", - ], +alias( + name = "trt_engine_op_op_lib", + actual = "//tensorflow/compiler/tf2tensorrt:trt_engine_op_op_lib", ) diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py index 140ad4828208ae4844a49bf664955b50cd9e51cd..fd551d70b4385b14b84b7b98a6d16b0c03733d38 100644 --- a/tensorflow/contrib/tensorrt/__init__.py +++ b/tensorflow/contrib/tensorrt/__init__.py @@ -18,18 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import errors - -# pylint: 
disable=unused-import,wildcard-import,g-import-not-at-top -try: - from tensorflow.contrib.tensorrt.python import * -except errors.NotFoundError as e: - no_trt_message = ( - '**** Failed to initialize TensorRT. This is either because the TensorRT' - ' installation path is not in LD_LIBRARY_PATH, or because you do not have' - ' it installed. If not installed, please go to' - ' https://developer.nvidia.com/tensorrt to download and install' - ' TensorRT ****') - print(no_trt_message) - raise e -# pylint: enable=unused-import,wildcard-import,g-import-not-at-top +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.tensorrt.python import * +# pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h deleted file mode 100644 index 1f39f56f6392ba33af3d74fec12c326ed4451cb6..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ /dev/null @@ -1,126 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ - -#include - -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/grappler/clusters/cluster.h" -#include "tensorflow/core/grappler/costs/graph_properties.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/types.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT - -namespace tensorflow { -namespace tensorrt { -namespace convert { - -// Helper class for the segmenter to determine whether given TF node is -// supported by TRT. -class TrtCandidateSelector { - public: - TrtCandidateSelector(const grappler::GraphProperties& graph_properties, - int precision_mode); - - // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added - // to TRT subgraph and later converted into TRT engine. - Status IsTensorRTCandidate(const tensorflow::Node* node); - - private: - // The TF-TRT node converter used to verify whether individual node is - // supported. It will operate in validation-only mode. - TrtNodeValidator validator_; - - // GraphProperties of the graph whose nodes are to be validated by - // IsTensorRTCandidate(). - const grappler::GraphProperties& graph_properties_; - - // Quantization ops are only converted when using quantized precisions. 
- const int precision_mode_; -}; - -struct ConversionParams { - ConversionParams() - : input_graph_def(nullptr), - max_batch_size(1), - max_workspace_size_bytes(1 << 30), - output_graph_def(nullptr), - precision_mode(1), - minimum_segment_size(3), - graph_properties(nullptr), - cluster(nullptr), - is_dyn_op(false), - fixed_input_size(true), - use_calibration(true), - max_cached_engines(1) {} - const tensorflow::GraphDef* input_graph_def; - const std::vector* output_names; - size_t max_batch_size; - size_t max_workspace_size_bytes; - tensorflow::GraphDef* output_graph_def; - int precision_mode; - int minimum_segment_size; - const tensorflow::grappler::GraphProperties* graph_properties; - const tensorflow::grappler::Cluster* cluster; - bool is_dyn_op; // Whether to create engine on conversion or execution time - bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed - int max_cached_engines; // maximum number of cached engines - bool use_calibration; - std::vector cached_engine_batches; // list of cached engines -}; - -// This method extracts calibration information from the resource managers -// and puts them in to engine nodedefs. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def, - bool is_dyn_op); - -// - max_batch_size: maximum batch size which can be used for inference for -// optimization targets inference run with max batch size. -// - max_workspace_size_bytes: The upper bound of memory allowance for engine -// building. 
-tensorflow::Status ConvertGraphDefToTensorRT( - const tensorflow::GraphDef& graph_def, - const std::vector& output_names, size_t max_batch_size, - size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode = 1, int minimum_segment_size = 3, - bool is_dyn_op = false, int max_cached_engines = 1, - std::vector cached_engine_batches = {}, bool use_calibration = true); - -// Method to call from optimization pass -tensorflow::Status ConvertAfterShapes(ConversionParams& params); - -// Return compile time TensorRT library version information. -std::vector GetLinkedTensorRTVersion(); - -// Return runtime time TensorRT library version information. -std::vector GetLoadedTensorRTVersion(); - -// Helper method for the conversion, expose for testing. -std::pair GetDeviceAndAllocator( - const ConversionParams& params, const EngineInfo& engine); - -} // namespace convert -} // namespace tensorrt -} // namespace tensorflow - -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA - -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 09990fb5e3f93d20ba51929dcc6f49e37e0ff0ac..0a2cf105baf5efb62d0c535c1f2d081973ec0ea3 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -45,7 +45,7 @@ tf_custom_op_library( "inc_op_kernel.cu.cc", ], deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", "//tensorflow/core:framework_lite", ] + if_tensorrt([ "@local_config_tensorrt//:tensorrt", @@ -64,7 +64,7 @@ tf_kernel_library( "inc_op_kernel.cu.cc", ], deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", "//tensorflow/core:stream_executor_headers_lib", ] + if_tensorrt([ "@local_config_tensorrt//:tensorrt", diff --git 
a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 11335d7da637c813b301b4d4657462f4aae0c190..b683c14c0d77ebac74ad4d9b479c5ed493a3900a 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -21,9 +21,10 @@ limitations under the License. #include #define EIGEN_USE_GPU -#include "tensorflow/core/framework/op_kernel.h" #include "cuda/include/cuda_runtime_api.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/cuda_launch_config.h" namespace tensorflow { namespace tensorrt { @@ -38,8 +39,8 @@ void IncrementKernel(const float* d_input, float inc, float* d_output, int threads_per_block = 256; int blocks_per_grid = (count + threads_per_block - 1) / threads_per_block; - VecInc<<>>(d_input, inc, - d_output, count); + TF_CHECK_OK(CudaLaunchKernel(VecInc, threads_per_block, blocks_per_grid, 0, + stream, d_input, inc, d_output, count)); } // Note: this kernel definition is not needed in the plugin_test rule, but it is diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 8d4c893af56689185da72398919e2241d451594b..7c9075142a02546ddd580e861ac87cb86badd739 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 189e9c939b9ffd4450f7ba95fe1abdbbc049b430..fb048d7b19da0f010ed918b147013b20d37ed0dd 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h deleted file mode 100644 index 7f0f05aa0a07f84368d79ff033c6d4b0837812d8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ - -#include -#include - -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_lru_cache.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { -struct TRTInt8Calibrator; -class TRTCalibrationResource; -class AsyncHelper; -// TODO(Sami): Remove this file? - -// This OP can construct TRTEngine on the fly and if construction of engine -// fails, executes equivalent subgraph as a TensorFlow function. -class TRTEngineOp : public AsyncOpKernel { - public: - explicit TRTEngineOp(OpKernelConstruction* context); - - void ComputeAsync(OpKernelContext* context, - AsyncOpKernel::DoneCallback done) override; - - private: - // TODO(samikama): context should go to a resource manager! - - // Execute calibration - void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); - - // Construct a function handle for executing native funcdef graph - Status ConstructFunctionHandle(OpKernelContext* ctx); - - // Execute replaced native segment as function Op. - void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); - - // Execute the tensorrt engine. Returns whether we need to retry by running - // the native segment. 
- bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context); - - // Allocate necessary resources for calibration - Status AllocateCalibrationResources(OpKernelContext* ctx, - TRTCalibrationResource** cr); - - // Get engine for the input shape - EngineContext* GetEngine(const std::vector& input_shapes, - OpKernelContext* ctx); - - // Return engine batch in cached_engne_batch_sizes_ which is closest to input - // batch. - bool GetCompatibleCachedEngine( - const std::vector& actual_input_shapes, - std::vector* engine_input_shapes); - - std::vector input_nodes_; - std::vector output_nodes_; - - // serialized protobuf segment or trt engine depending on static_engine_ flag. - string serialized_segment_; - - // Name of the function for TF native execution of the segment. - string funcdef_name_; - - // GraphDef representation of the segment. - GraphDef segment_graph_; - - // Engine Precision mode. - int precision_mode_; - - // Whether engine is constructed during the conversion or needs to be - // constructed from protobuf segment. - bool static_engine_; - - // Whether to calibrate INT8 engine. - bool calibration_mode_; - - // Batches of the cached engines - std::vector cached_engine_batches_; - - // Maximum number of cached engines - int max_cached_engines_; - - int64 workspace_size_; - mutex engine_mutex_; - FunctionLibraryRuntime::Handle native_func_; - - // The finalized calibrator for inference. - std::unique_ptr calibrator_; - - // If true, create calibration graph for INT8 mode. Otherwise, we are using - // user-provided quantization ranges. 
- bool use_calibration_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA - -#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index e2cf253ca07244d99245e1ff83f2e5addc26f28f..c29665b9a82c2f6ec098d34bbb77d40a2d4e85f7 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -19,12 +19,5 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.tensorrt.python.ops import trt_ops -from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value -from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph -from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph -from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value -from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value -from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_ops.py b/tensorflow/contrib/tensorrt/python/ops/trt_ops.py deleted file mode 100644 index 1fee06854ff0a8bd45249acd5a449fea45a312b6..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/python/ops/trt_ops.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= -"""Exposes the Python wrapper of TRTEngineOp.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import platform - -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tensorrt.ops.gen_trt_ops import * - - from tensorflow.contrib.util import loader - from tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - - _trt_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("_trt_ops.so")) -else: - raise RuntimeError("Windows platforms are not supported") diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 49d72232aa0cfba3f5bf533de04f4d50e65275fd..8f4f1edae0bdeba007b67f18226683b39942df24 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -18,411 +18,34 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six as _six -# pylint: disable=unused-import,line-too-long -from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value -from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert -from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values -from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value 
-from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version -from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version -from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value -from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled -# pylint: enable=unused-import,line-too-long -from tensorflow.core.framework import graph_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.client import session -from tensorflow.python.framework import errors_impl as _impl -from tensorflow.python.framework import graph_util -from tensorflow.python.framework import importer -from tensorflow.python.framework import ops -from tensorflow.python.grappler import tf_optimizer -from tensorflow.python.platform import tf_logging -from tensorflow.python.saved_model import builder -from tensorflow.python.saved_model import loader_impl -from tensorflow.python.saved_model import tag_constants -from tensorflow.python.training import saver - - -def _to_bytes(s): - """Encode s if it is a sequence of chars.""" - if isinstance(s, _six.text_type): - return s.encode("utf-8", errors="surrogateescape") - return s - - -def _to_string(s): - """Decode s if it is a sequence of bytes.""" - if isinstance(s, _six.binary_type): - return s.decode("utf-8") - return s - - -class TrtPrecisionMode(object): - FP32 = "FP32" - FP16 = "FP16" - INT8 = "INT8" - - @staticmethod - def supported_precision_modes(): - return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8] - - -def get_tensorrt_rewriter_config(rewriter_config=None, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode=TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batches=None, - use_calibration=True): - """Returns a RewriterConfig proto 
for TRT transformation. - - Args: - rewriter_config: a template RewriterConfig proto used to create a - TRT-enabled RewriterConfig. If None, it will use a default one. - max_batch_size: max size for the input batch - max_workspace_size_bytes: the maximum GPU temporary memory which the TRT - engine can use at execution time. This corresponds to the 'workspaceSize' - parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). - precision_mode: one of TrtPrecisionMode.supported_precision_modes(). - minimum_segment_size: the minimum number of nodes required for a subgraph to - be replaced by TRTEngineOp. - is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT - network and engine at run time. - maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. - If the number of cached engines is already at max but none of them can - serve the input, the TRTEngineOp will fall back to run the TF function - based on which the TRTEngineOp is created. - cached_engine_batches: a list of batch sizes used to create cached - engines, only used when is_dynamic_op is True. The length of the list - should be <= maximum_cached_engines, and the dynamic TRT op will - use this list to determine the batch sizes of the cached engines, instead - of making the decision on the fly. This is useful when we know the most - common batch size(s) the application is going to generate. - use_calibration: this argument is ignored if precision_mode is not INT8. If - set to True, a calibration graph will be created to calibrate the missing - ranges. The calibration graph must be converted to an inference graph - using calib_graph_to_infer_graph() after running calibration. if set to - False, quantization nodes will be expected for every tensor in the graph - (exlcuding those which will be fused). If a range is missing, an error - will occur. 
Please note that accuracy may be negatively affected if there - is a mismatch between which tensors TRT quantizes and which tensors were - trained with fake quantization. - - Returns: - A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. - - Raises: - TypeError: if any of the parameters are of unexpected type. - ValueError: if any of the parameters are of unexpected value. - """ - if rewriter_config is not None and not isinstance( - rewriter_config, rewriter_config_pb2.RewriterConfig): - raise TypeError("rewriter_config should be a RewriterConfig proto.") - - rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() - if rewriter_config is None: - # Layout optimizer may add Const nodes followed by Reshape nodes, thus we - # need to run constant folding again. - rewriter_config_with_trt.optimizers.extend( - ["constfold", "layout", "constfold"]) - rewriter_config_with_trt.meta_optimizer_iterations = ( - rewriter_config_pb2.RewriterConfig.ONE) - else: - rewriter_config_with_trt.CopyFrom(rewriter_config) - - if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(): - raise ValueError(("precision mode '{}' is not supported." 
- "It should be one of {}").format( - precision_mode, - TrtPrecisionMode.supported_precision_modes)) - - optimizer = rewriter_config_with_trt.custom_optimizers.add() - optimizer.name = "TensorRTOptimizer" - optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size - optimizer.parameter_map["max_batch_size"].i = max_batch_size - optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op - optimizer.parameter_map[ - "max_workspace_size_bytes"].i = max_workspace_size_bytes - optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode) - optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines - if cached_engine_batches: - if not isinstance(cached_engine_batches, list): - raise TypeError("cached_engine_batches should be a list.") - if len(cached_engine_batches) > maximum_cached_engines: - raise ValueError("cached_engine_batches should not contain more than " - "maximum_cached_engines items.") - optimizer.parameter_map["cached_engine_batches"].list.i.extend( - cached_engine_batches) - optimizer.parameter_map["use_calibration"].b = use_calibration - return rewriter_config_with_trt - - -def create_inference_graph(input_graph_def, - outputs, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode=TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batches=None, - use_calibration=True, - input_saved_model_dir=None, - input_saved_model_tags=None, - output_saved_model_dir=None, - session_config=None): - """Python wrapper for the TRT transformation. - - Args: - input_graph_def: a GraphDef object containing a model to be transformed. If - set to None, the graph will be read from the SavedModel loaded from - input_saved_model_dir. - outputs: list of tensors or node names for the model outputs. Only used when - input_graph_def is not None. - max_batch_size: max size for the input batch. 
- max_workspace_size_bytes: the maximum GPU temporary memory which the TRT - engine can use at execution time. This corresponds to the 'workspaceSize' - parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). - precision_mode: one of TrtPrecisionMode.supported_precision_modes(). - minimum_segment_size: the minimum number of nodes required for a subgraph to - be replaced by TRTEngineOp. - is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT - network and engine at run time. - maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. - If the number of cached engines is already at max but none of them can - serve the input, the TRTEngineOp will fall back to run the TF function - based on which the TRTEngineOp is created. - cached_engine_batches: a list of batch sizes used to create cached - engines, only used when is_dynamic_op is True. The length of the list - should be <= maximum_cached_engines, and the dynamic TRT op will - use this list to determine the batch sizes of the cached engines, instead - of making the decision on the fly. This is useful when we know the most - common batch size(s) the application is going to generate. - use_calibration: this argument is ignored if precision_mode is not INT8. If - set to True, a calibration graph will be created to calibrate the missing - ranges. The calibration graph must be converted to an inference graph - using calib_graph_to_infer_graph() after running calibration. if set to - False, quantization nodes will be expected for every tensor in the graph - (exlcuding those which will be fused). If a range is missing, an error - will occur. Please note that accuracy may be negatively affected if there - is a mismatch between which tensors TRT quantizes and which tensors were - trained with fake quantization. - input_saved_model_dir: the directory to load the SavedModel which contains - the input graph to transforms. Used only when input_graph_def is None. 
- input_saved_model_tags: list of tags to load the SavedModel. - output_saved_model_dir: if not None, construct a SavedModel using the - returned GraphDef and save it to the specified directory. This option only - works when the input graph is loaded from a SavedModel, i.e. when - input_saved_model_dir is specified and input_graph_def is None. - session_config: the ConfigProto used to create a Session. It's also used as - a template to create a TRT-enabled ConfigProto for conversion. If not - specified, a default ConfigProto will be used. - - Returns: - A GraphDef transformed from input_graph_def (or the SavedModel graph def - loaded from input_saved_model_dir, if input_graph_def is not present), where - all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF - function is added for each of the subgraphs. - - If is_dynamic_op is True, each TRTEngineOp will contain a serialized - subgraph GraphDef, which will be converted to a TRT engine at execution time - and the TRT engine will be cached for future usage. A new TRT engine will be - created each time when none of the cached engines match the input shapes. If - it fails to execute the TRT engine or the number of cached engines reaches - maximum_cached_engines, the op will fall back to call the corresponding TF - function. - - If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT - engine created from the corresponding subgraph. No more engines will be - created on the fly, and the op will fall back to call the corresponding TF - function when it fails to execute the engine. - - Raises: - ValueError: if the combination of the parameters is invalid. - RuntimeError: if the TensorRT library version is incompatible. - """ - compiled_version = get_linked_tensorrt_version() - loaded_version = get_loaded_tensorrt_version() - version_mismatch = False - if loaded_version[0] < compiled_version[0]: - tf_logging.error( - "TensorRT version mismatch. 
Tensorflow was compiled against " + - "TensorRT %s but library loaded from environment is TensorRT %s" % - (".".join([str(x) for x in compiled_version]), - ".".join([str(x) for x in loaded_version])) + - ". Please make sure that correct version of TensorRT " + - "is available in the system and added to ldconfig or LD_LIBRARY_PATH") - raise RuntimeError("Incompatible TensorRT library version") - for i in zip(loaded_version, compiled_version): - if i[0] != i[1]: - tf_logging.warn("TensorRT mismatch. Compiled against version " + - "%s, but loaded %s. Things may not work" % - (".".join([str(x) for x in compiled_version]), - ".".join([str(x) for x in loaded_version]))) - version_mismatch = True - break - if not version_mismatch: - tf_logging.info("Running against TensorRT version %s" % ".".join( - [str(x) for x in loaded_version])) - - if session_config is None: - session_config = config_pb2.ConfigProto() - - if input_saved_model_tags is None: - input_saved_model_tags = [tag_constants.SERVING] - saved_model_loader = None - grappler_meta_graph_def = None - - if input_graph_def is None: - # Read from SavedModel and freeze the graph if necessary. - if input_saved_model_dir is None: - raise ValueError("input_graph_def and input_saved_model_dir cannot be " - "both None") - with ops.Graph().as_default(): - with session.Session(config=session_config) as sess: - saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir) - input_meta_graph_def = saved_model_loader.load(sess, - input_saved_model_tags) - output_node_names = set() - - def _gather_names(tensor_info): - """Get the node names from a TensorInfo.""" - return set( - [tensor_info[key].name.split(":")[0] for key in tensor_info]) - - # Get input and outputs from all SignatureDef. 
- for key in input_meta_graph_def.signature_def: - signature_def = input_meta_graph_def.signature_def[key] - output_node_names.update(_gather_names(signature_def.inputs)) - output_node_names.update(_gather_names(signature_def.outputs)) - - # Freeze the variables in the SavedModel graph and copy the frozen - # graph over. - frozen_graph_def = graph_util.convert_variables_to_constants( - sess, sess.graph.as_graph_def(add_shapes=True), - list(output_node_names)) - grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef() - grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def) - - # Copy the collections that are not variables. - for key in input_meta_graph_def.collection_def: - # TODO(laigd): currently we use the collection key to filter out - # collections that depend on variable ops, but this may miss some - # other user-defined collections. A better way would be to use - # CollectionDef::NodeList for the filtering. - if key not in [ - "variables", "local_variables", "model_variables", - "trainable_variables", "train_op", "table_initializer" - ]: - grappler_meta_graph_def.collection_def[key].CopyFrom( - input_meta_graph_def.collection_def[key]) - - # Copy other information. - grappler_meta_graph_def.meta_info_def.CopyFrom( - input_meta_graph_def.meta_info_def) - for key in input_meta_graph_def.signature_def: - grappler_meta_graph_def.signature_def[key].CopyFrom( - input_meta_graph_def.signature_def[key]) - # TODO(laigd): maybe add back AssetFileDef. - else: - if output_saved_model_dir is not None: - raise ValueError("output_saved_model_dir cannot be set when " - "input_graph_def is set") - # Create MetaGraphDef from input graph. 
- graph = ops.Graph() - with graph.as_default(): - importer.import_graph_def(input_graph_def, name="") - grappler_meta_graph_def = saver.export_meta_graph( - graph_def=graph.as_graph_def(add_shapes=True), graph=graph) - if outputs: - output_collection = meta_graph_pb2.CollectionDef() - output_list = output_collection.node_list.value - for i in outputs: - if isinstance(i, ops.Tensor): - output_list.append(_to_bytes(i.name)) - else: - output_list.append(_to_bytes(i)) - # TODO(laigd): use another key as the outputs are really not train_op. - grappler_meta_graph_def.collection_def["train_op"].CopyFrom( - output_collection) - - # Create TRT-enabled ConfigProto. - session_config_with_trt = config_pb2.ConfigProto() - session_config_with_trt.CopyFrom(session_config) - rewriter_config = None - if (session_config_with_trt.HasField("graph_options") and - session_config_with_trt.graph_options.HasField("rewrite_options")): - rewriter_config = session_config_with_trt.graph_options.rewrite_options - rewriter_config_with_trt = get_tensorrt_rewriter_config( - rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode, - minimum_segment_size, is_dynamic_op, maximum_cached_engines, - cached_engine_batches, use_calibration) - session_config_with_trt.graph_options.rewrite_options.CopyFrom( - rewriter_config_with_trt) - - # Run Grappler. - transformed_graph_def = tf_optimizer.OptimizeGraph( - session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph") - - # Optionally write the transformed graphdef as SavedModel. - if output_saved_model_dir is not None: - saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) - with ops.Graph().as_default(): - importer.import_graph_def(transformed_graph_def, name="") - # We don't use TRT here. 
- with session.Session(config=session_config) as sess: - saved_model_builder.add_meta_graph_and_variables( - sess, - input_saved_model_tags, - signature_def_map=grappler_meta_graph_def.signature_def) - # Ignore other meta graphs from the input SavedModel. - saved_model_builder.save() - - return transformed_graph_def - - -def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): - """Convert an existing calibration graph to inference graph. - - Args: - calibration_graph_def: the calibration GraphDef object with calibration data - is_dynamic_op: whether to create dynamic static engines from calibration - - Returns: - New GraphDef with TRTEngineOps placed in graph replacing calibration nodes. - Raises: - RuntimeError: if the returned status message is malformed. - """ - - is_calib_graph = False - for n in calibration_graph_def.node: - if n.op == "TRTEngineOp": - is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s - if not is_calib_graph: - tf_logging.error( - "Not a calib graph. 
Doesn't seem to contain any calibration nodes.") - return None - graph_str = calibration_graph_def.SerializeToString() - out = calib_convert(graph_str, is_dynamic_op) - status = _to_string(out[0]) - output_graph_def_string = out[1] - del graph_str # Save some memory - if len(status) < 2: - raise _impl.UnknownError(None, None, status) - if status[:2] != "OK": - msg = status.split(";") - if len(msg) == 1: - raise RuntimeError("Status message is malformed {}".format(status)) - # pylint: disable=protected-access - raise _impl._make_specific_exception(None, None, ";".join(msg[1:]), - int(msg[0])) - # pylint: enable=protected-access - output_graph_def = graph_pb2.GraphDef() - output_graph_def.ParseFromString(output_graph_def_string) - del output_graph_def_string # Save some memory - return output_graph_def +from tensorflow.python.compiler.tensorrt import trt_convert + + +def create_inference_graph( + input_graph_def, + outputs, + max_batch_size=1, + max_workspace_size_bytes=trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, + precision_mode=trt_convert.TrtPrecisionMode.FP32, + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=None, + input_saved_model_dir=None, + input_saved_model_tags=None, + output_saved_model_dir=None, + session_config=None): + return trt_convert.create_inference_graph( + input_graph_def=input_graph_def, + outputs=outputs, + max_batch_size=max_batch_size, + max_workspace_size_bytes=max_workspace_size_bytes, + precision_mode=precision_mode, + minimum_segment_size=minimum_segment_size, + is_dynamic_op=is_dynamic_op, + maximum_cached_engines=maximum_cached_engines, + cached_engine_batches=cached_engine_batches, + input_saved_model_dir=input_saved_model_dir, + input_saved_model_tags=input_saved_model_tags, + output_saved_model_dir=output_saved_model_dir, + session_config=session_config) diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py 
b/tensorflow/contrib/tensorrt/python/trt_convert_test.py deleted file mode 100644 index 3ef18e3e150e0d421baa76cbda0b0daa929a7e91..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/python/trt_convert_test.py +++ /dev/null @@ -1,299 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities to test TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from tensorflow.contrib.tensorrt.python import trt_convert -# pylint: disable=unused-import -from tensorflow.contrib.tensorrt.python.ops import trt_ops -# pylint: enable=unused-import -from tensorflow.core.framework import graph_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import graph_util -from tensorflow.python.framework import importer -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.saved_model import builder -from tensorflow.python.saved_model import loader -from 
tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import signature_def_utils -from tensorflow.python.saved_model import tag_constants -from tensorflow.python.saved_model import utils -from tensorflow.python.tools import saved_model_utils - - -class TrtConvertTest(test_util.TensorFlowTestCase): - """Class to test Tensorflow-TensorRT integration python API.""" - - def testGetTensorrtRewriterConfig(self): - """Test case for trt_convert.get_tensorrt_rewriter_config().""" - rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( - rewriter_config=None, - max_batch_size=128, - max_workspace_size_bytes=1234, - precision_mode="INT8", - minimum_segment_size=10, - is_dynamic_op=True, - maximum_cached_engines=2, - cached_engine_batches=[1, 128]) - self.assertEqual(["constfold", "layout", "constfold"], - rewriter_cfg.optimizers) - self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE, - rewriter_cfg.meta_optimizer_iterations) - trt_optimizer = None - for optimizer in rewriter_cfg.custom_optimizers: - if optimizer.name == "TensorRTOptimizer": - self.assertTrue(trt_optimizer is None) - trt_optimizer = optimizer - self.assertTrue(trt_optimizer is not None) - for key in [ - "minimum_segment_size", "max_batch_size", "is_dynamic_op", - "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines", - "cached_engine_batches" - ]: - self.assertTrue(key in trt_optimizer.parameter_map) - self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i) - self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i) - self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b) - self.assertEqual(1234, - trt_optimizer.parameter_map["max_workspace_size_bytes"].i) - self.assertEqual( - trt_convert._to_bytes("INT8"), - trt_optimizer.parameter_map["precision_mode"].s) - self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i) - self.assertEqual( - [1, 128], 
trt_optimizer.parameter_map["cached_engine_batches"].list.i) - - def _GetConfigProto(self): - """Get ConfigProto for session creation.""" - config = config_pb2.ConfigProto( - gpu_options=config_pb2.GPUOptions(allow_growth=True)) - return config - - def _GetGraph(self): - """Get the graph for testing.""" - g = ops.Graph() - with g.as_default(): - with g.device("/GPU:0"): - inp = array_ops.placeholder( - dtype=dtypes.float32, shape=[None, 1, 1], name="input") - var = variables.VariableV1([[[1.0]]], dtype=dtypes.float32, name="v1") - add = inp + var.value() - mul = inp * add - add = mul + add - out = array_ops.identity(add, name="output") - return g, var, inp, out - - def _GetGraphDef(self): - """Get the graph def for testing.""" - g, var, _, _ = self._GetGraph() - with self.session(graph=g, config=self._GetConfigProto()) as sess: - sess.run(var.initializer) - graph_def = graph_util.convert_variables_to_constants( - sess, g.as_graph_def(add_shapes=True), ["output"]) - node_name_to_op = {node.name: node.op for node in graph_def.node} - self.assertEqual({ - "v1": "Const", - "v1/read": "Identity", - "input": "Placeholder", - "add": "Add", - "mul": "Mul", - "add_1": "Add", - "output": "Identity" - }, node_name_to_op) - return graph_def - - def _WriteInputSavedModel(self, input_saved_model_dir): - """Write the saved model as an input for testing.""" - g, var, inp, out = self._GetGraph() - signature_def = signature_def_utils.build_signature_def( - inputs={"myinput": utils.build_tensor_info(inp)}, - outputs={"myoutput": utils.build_tensor_info(out)}, - method_name=signature_constants.PREDICT_METHOD_NAME) - saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir) - with self.session(graph=g, config=self._GetConfigProto()) as sess: - sess.run(var.initializer) - saved_model_builder.add_meta_graph_and_variables( - sess, [tag_constants.SERVING], - signature_def_map={"mypredict": signature_def}) - saved_model_builder.save() - - def _TestCreateInferenceGraph(self, - 
input_saved_model_dir=None, - output_saved_model_dir=None): - """General method to test trt_convert.create_inference_graph().""" - input_graph_def = None if input_saved_model_dir else self._GetGraphDef() - output_graph_def = trt_convert.create_inference_graph( - input_graph_def, ["output"], - input_saved_model_dir=input_saved_model_dir, - output_saved_model_dir=output_saved_model_dir, - session_config=self._GetConfigProto()) - graph_defs_to_verify = [output_graph_def] - if output_saved_model_dir is not None: - saved_model_graph_def = saved_model_utils.get_meta_graph_def( - output_saved_model_dir, tag_constants.SERVING).graph_def - self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef)) - graph_defs_to_verify.append(saved_model_graph_def) - - for graph_def in graph_defs_to_verify: - node_name_to_op = {node.name: node.op for node in graph_def.node} - self.assertEqual({ - "input": "Placeholder", - "TRTEngineOp_0": "TRTEngineOp", - "output": "Identity" - }, node_name_to_op) - - def testCreateInferenceGraph_BasicConversion(self): - """Test case for trt_convert.create_inference_graph().""" - if not trt_convert.is_tensorrt_enabled(): - return - - # Use GraphDef as input. - self._TestCreateInferenceGraph() - - # Use SavedModel as input. 
- tmp_dir = self.get_temp_dir() - input_saved_model_dir = os.path.join(tmp_dir, "in_dir1") - output_saved_model_dir = os.path.join(tmp_dir, "out_dir1") - self._WriteInputSavedModel(input_saved_model_dir) - self._TestCreateInferenceGraph(input_saved_model_dir, - output_saved_model_dir) - - def _TestRun(self, sess, batch_size, expect_engine_is_run): - trt_convert.clear_test_values("") - result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size}) - self.assertAllEqual([[[4.0]]] * batch_size, result) - execute_engine_test_value = ("done" if expect_engine_is_run else "") - execute_native_segment_test_value = ("" if expect_engine_is_run else "done") - self.assertEqual( - execute_engine_test_value, - trt_convert.get_test_value("TRTEngineOp_0:ExecuteTrtEngine")) - self.assertEqual( - execute_native_segment_test_value, - trt_convert.get_test_value("TRTEngineOp_0:ExecuteNativeSegment")) - - def testCreateInferenceGraph_MinimumSegmentSize(self): - if not trt_convert.is_tensorrt_enabled(): - return - output_graph_def = trt_convert.create_inference_graph( - self._GetGraphDef(), ["output"], - minimum_segment_size=5, - is_dynamic_op=False) - node_name_to_op = {node.name: node.op for node in output_graph_def.node} - self.assertEqual({ - "v1/read": "Const", - "input": "Placeholder", - "add": "Add", - "mul": "Mul", - "add_1": "Add", - "output": "Identity" - }, node_name_to_op) - - def testCreateInferenceGraph_DynamicOp(self): - if not trt_convert.is_tensorrt_enabled(): - return - trt_convert.enable_test_value() - - tmp_dir = self.get_temp_dir() - input_saved_model_dir = os.path.join(tmp_dir, "in_dir2") - output_saved_model_dir = os.path.join(tmp_dir, "out_dir2") - self._WriteInputSavedModel(input_saved_model_dir) - output_graph_def = trt_convert.create_inference_graph( - None, - None, - is_dynamic_op=True, - maximum_cached_engines=2, - input_saved_model_dir=input_saved_model_dir, - output_saved_model_dir=output_saved_model_dir, - 
session_config=self._GetConfigProto()) - - # Test the output GraphDef. - with ops.Graph().as_default(): - importer.import_graph_def(output_graph_def, name="") - with self.test_session(config=self._GetConfigProto()) as sess: - # Run with batch size 1, a new engine is created and cached. - self._TestRun(sess, 1, True) - # Run with batch size 2, a new engine is created and cached. - self._TestRun(sess, 2, True) - # Run with batch size 3, since the number of cached engines has reached - # the max, it should evict an old engine and create a new one. - self._TestRun(sess, 3, True) - - # Test the output SavedModel - with ops.Graph().as_default(): - with self.test_session(config=self._GetConfigProto()) as sess: - loader.load(sess, [tag_constants.SERVING], output_saved_model_dir) - # Run with batch size 1, a new engine is created and cached. - self._TestRun(sess, 1, True) - # Run with batch size 2, a new engine is created and cached. - self._TestRun(sess, 2, True) - # Run with batch size 3, since the number of cached engines has reached - # the max, it should evict an old engine and create a new one. - self._TestRun(sess, 3, True) - - def testCreateInferenceGraph_StaticOp(self): - if not trt_convert.is_tensorrt_enabled(): - return - trt_convert.enable_test_value() - - tmp_dir = self.get_temp_dir() - input_saved_model_dir = os.path.join(tmp_dir, "in_dir3") - output_saved_model_dir = os.path.join(tmp_dir, "out_dir3") - self._WriteInputSavedModel(input_saved_model_dir) - output_graph_def = trt_convert.create_inference_graph( - None, - None, - max_batch_size=1, - is_dynamic_op=False, - maximum_cached_engines=2, # This is noop, added just for testing. - input_saved_model_dir=input_saved_model_dir, - output_saved_model_dir=output_saved_model_dir, - session_config=self._GetConfigProto()) - - # Test the output GraphDef. 
- with ops.Graph().as_default(): - importer.import_graph_def(output_graph_def, name="") - with self.test_session(config=self._GetConfigProto()) as sess: - # Run with batch size 1, the default engine embedded in the graphdef - # will be used. - self._TestRun(sess, 1, True) - # Run with batch size 2, which exceed the max_batch_size, it should fall - # back to TF function. - self._TestRun(sess, 2, False) - - # Test the output SavedModel - with ops.Graph().as_default(): - with self.test_session(config=self._GetConfigProto()) as sess: - loader.load(sess, [tag_constants.SERVING], output_saved_model_dir) - # Run with batch size 1, the default engine embedded in the graphdef - # will be used. - self._TestRun(sess, 1, True) - # Run with batch size 2, which exceed the max_batch_size, it should fall - # back to TF function. - self._TestRun(sess, 2, False) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc deleted file mode 100644 index 9c3698e5d1cc5d6d8d31a8fcaf03d103f1e1915d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace tensorrt { - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::instance() { - static std::shared_ptr instance_(new TRTResourceManager); - return instance_; -} - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) { - // mutex is held for lookup only. Most instantiations where mutex will be held - // longer will be during op creation and should be ok. - tensorflow::mutex_lock lock(map_mutex_); - auto s = managers_.find(op_name); - if (s == managers_.end()) { - auto it = managers_.emplace( - op_name, std::make_shared(op_name)); - VLOG(1) << "Returning a new manager " << op_name; - return it.first->second; - } - VLOG(1) << "Returning old manager " << op_name; - return s->second; -} - -} // namespace tensorrt -} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h deleted file mode 100644 index 19f39e6d3db1571573fb290dd2c30fd43ea604ef..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ -#include - -#include -#include -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { -namespace tensorrt { - -class TRTResourceManager { - TRTResourceManager() = default; - - public: - static std::shared_ptr instance(); - // returns a manager for given op, if it doesn't exists it creates one - std::shared_ptr getManager(const string& op_name); - - private: - std::unordered_map> - managers_; - tensorflow::mutex map_mutex_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index f30dba59ad55317d7ad7730e4dc66c9aba4e6a6b..5c60d6b589ed6a16276226726d989e949bcbf9d7 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -14,14 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorrt/include/NvInfer.h" diff --git a/tensorflow/contrib/tensorrt/test/manual_test.py b/tensorflow/contrib/tensorrt/test/manual_test.py deleted file mode 100644 index aad7b9f30728cbb3f4ec5fa730c5dbe46fe9fc3f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/test/manual_test.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Basic tests for TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ast -import os - -from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test -from tensorflow.core.framework import graph_pb2 -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test - - -class ManualTest(trt_test.TfTrtIntegrationTestBase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - super(ManualTest, self).__init__(methodName) - self._params_map = None - - def _GetEnv(self): - """Get an environment variable specifying the manual test parameters. - - The value of the environment variable is the string representation of a dict - which should contain the following keys: - - 'graph_path': the file path to the serialized frozen graphdef - - 'input_names': TfTrtIntegrationTestParams.input_names - - 'input_dims': TfTrtIntegrationTestParams.input_dims - - 'expected_output_dims': TfTrtIntegrationTestParams.expected_output_dims - - 'output_name': the name of op to fetch - - 'expected_engines_to_run': ExpectedEnginesToRun() will return this - - 'expected_engines_to_build': ExpectedEnginesToBuild() will return this - - 'max_batch_size': ConversionParams.max_batch_size - - Returns: - The value of the environment variable. 
- """ - return os.getenv('TRT_MANUAL_TEST_PARAMS', '') - - def _GetParamsMap(self): - """Parse the environment variable as a dict and return it.""" - if self._params_map is None: - self._params_map = ast.literal_eval(self._GetEnv()) - return self._params_map - - def GetParams(self): - """Testing conversion of manually provided frozen graph.""" - params_map = self._GetParamsMap() - gdef = graph_pb2.GraphDef() - with gfile.Open(params_map['graph_path'], 'rb') as f: - gdef.ParseFromString(f.read()) - return trt_test.TfTrtIntegrationTestParams( - gdef=gdef, - input_names=params_map['input_names'], - input_dims=[params_map['input_dims']], - output_names=params_map['output_names'], - expected_output_dims=[params_map['expected_output_dims']]) - - def GetConversionParams(self, run_params): - """Return a ConversionParams for test.""" - conversion_params = super(ManualTest, self).GetConversionParams(run_params) - params_map = self._GetParamsMap() - if 'max_batch_size' in params_map: - conversion_params = conversion_params._replace( - max_batch_size=params_map['max_batch_size']) - return conversion_params - - def ExpectedEnginesToBuild(self, run_params): - """Return the expected engines to build.""" - return self._GetParamsMap()['expected_engines_to_build'] - - def ExpectedEnginesToRun(self, run_params): - """Return the expected engines to run.""" - params_map = self._GetParamsMap() - if 'expected_engines_to_run' in params_map: - return params_map['expected_engines_to_run'] - return self.ExpectedEnginesToBuild(run_params) - - def ExpectedAbsoluteTolerance(self, run_params): - """The absolute tolerance to compare floating point results.""" - params_map = self._GetParamsMap() - if 'atol' in params_map: - return params_map['atol'] - return 1.e-3 - - def ExpectedRelativeTolerance(self, run_params): - """The relative tolerance to compare floating point results.""" - params_map = self._GetParamsMap() - if 'rtol' in params_map: - return params_map['rtol'] - return 1.e-3 - - def 
ShouldRunTest(self, run_params): - """Whether to run the test.""" - return len(self._GetEnv()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py deleted file mode 100644 index 090aa8bdb0487973e186631af3b4edac48096a5f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Script to test TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import numpy as np -import six as _six - -# normally we should do import tensorflow as tf and then -# tf.placeholder, tf.constant, tf.nn.conv2d etc but -# it looks like internal builds don't like it so -# importing every module individually - -from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 as cpb2 -from tensorflow.core.protobuf import rewriter_config_pb2 as rwpb2 -from tensorflow.python.client import session as csess -from tensorflow.python.framework import constant_op as cop -from tensorflow.python.framework import dtypes as dtypes -from tensorflow.python.framework import importer as importer -from tensorflow.python.framework import ops as ops -from tensorflow.python.ops import array_ops as aops -from tensorflow.python.ops import math_ops as mops -from tensorflow.python.ops import nn as nn -from tensorflow.python.ops import nn_ops as nn_ops - - -def py2bytes(inp): - return inp - - -def py3bytes(inp): - return inp.encode("utf-8", errors="surrogateescape") - - -def py2string(inp): - return inp - - -def py3string(inp): - return inp.decode("utf-8") - - -if _six.PY2: - to_bytes = py2bytes - to_string = py2string -else: - to_bytes = py3bytes - to_string = py3string - - -def get_multi_engine_graph_def(mode="FP32"): - """Create a simple graph and return its graph_def.""" - dtype = dtypes.float32 - if mode.upper() == "FP16": - dtype = dtypes.float16 - else: - pass - - g = ops.Graph() - with g.as_default(): - x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype) - with g.name_scope("Global_scope"): - with g.name_scope("first_scope"): - e = cop.constant( - np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype) - conv = nn.conv2d( - input=x, - 
filter=e, - data_format="NCHW", - strides=[1, 1, 1, 1], - padding="VALID", - name="conv") - b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype) - t = conv * b - - b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype) - q = conv / b - edge = mops.sin(q) - edge1 = mops.cos(conv) - with g.name_scope("test_scope"): - de = edge + edge1 - t -= edge1 - q *= edge - t += q - t -= de - k = aops.squeeze(t, name="output") - print(k.dtype) - return g.as_graph_def() - - -def get_simple_graph_def(): - """Create a simple graph and return its graph_def.""" - g = ops.Graph() - with g.as_default(): - a = aops.placeholder( - dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") - e = cop.constant( - [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], - name="weights", - dtype=dtypes.float32) - conv = nn.conv2d( - input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") - b = cop.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32) - t = nn.bias_add(conv, b, name="biasAdd") - relu = nn.relu(t, "relu") - idty = aops.identity(relu, "ID") - v = nn_ops.max_pool( - idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - aops.squeeze(v, name="output") - return g.as_graph_def() - - -def execute_graph(gdef, dumm_inp): - """Run given graphdef once.""" - print("executing") - gpu_options = None - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - sessconfig = cpb2.ConfigProto(gpu_options=gpu_options) - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - with csess.Session(config=sessconfig, graph=g) as sess: - val = sess.run(out, {inp: dumm_inp}) - return val - - -# Use real data that is representative of the inference dataset -# for calibration. 
For this test script it is random data. -def execute_calibration(gdef, dumm_inp): - """Run given calibration graph multiple times.""" - gpu_options = None - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - with csess.Session( - config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: - # run over real calibration data here, we are mimicking a calibration set of - # 30 different batches. Use as much calibration data as you want - for _ in range(30): - val = sess.run(out, {inp: dumm_inp}) - return val - - -def user(multi_engine, - run_graph=execute_graph, - run_calibration=execute_calibration): - """Example function that converts a graph to TFTRT graph.""" - if multi_engine: - inp_dims = (2, 3, 7, 5) - orig_graph = get_multi_engine_graph_def() - else: - inp_dims = (100, 24, 24, 2) - orig_graph = get_simple_graph_def() # use a frozen graph for inference - dummy_input = np.random.random_sample(inp_dims) - # Get optimized graph - trt_graph = trt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2, # minimum number of nodes in an engine - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batches=[]) - o1 = run_graph(orig_graph, dummy_input) - o2 = run_graph(trt_graph, dummy_input) - o3 = run_graph(trt_graph, dummy_input) - assert np.array_equal(o1, o2) - assert np.array_equal(o3, o2) # sanity check - fp16_graph = trt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - 
precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2, # minimum number of nodes in an engine - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batches=[]) - int8_calib_gdef = trt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2, # minimum number of nodes in an engine - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batches=[]) - o4 = run_graph(fp16_graph, dummy_input) - _ = run_calibration(int8_calib_gdef, dummy_input) - int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) - o5 = run_graph(int8_graph, dummy_input) - print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4)) - print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5)) - print("Pass") - - -def auto(multi_engine): - """Run the conversion as an optimization pass.""" - if multi_engine: - inp_dims = (2, 3, 7, 5) - orig_graph = get_multi_engine_graph_def() - else: - inp_dims = (100, 24, 24, 2) - orig_graph = get_simple_graph_def() # use a frozen graph for inference - dummy_input = np.random.random_sample(inp_dims) - opt_config = rwpb2.RewriterConfig() - opt_config.meta_optimizer_iterations = opt_config.ONE - opt_config.optimizers.extend(["constfold", "layout"]) - custom_op = opt_config.custom_optimizers.add() - custom_op.name = "TensorRTOptimizer" - custom_op.parameter_map["minimum_segment_size"].i = 3 - custom_op.parameter_map["precision_mode"].s = to_bytes("FP32") - custom_op.parameter_map["max_batch_size"].i = inp_dims[0] - custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 - print(custom_op) - gpu_options = None - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - graph_options = 
cpb2.GraphOptions(rewrite_options=opt_config) - sessconfig = cpb2.ConfigProto( - gpu_options=gpu_options, graph_options=graph_options) - print(sessconfig) - g = ops.Graph() - ops.reset_default_graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=orig_graph, return_elements=["input", "output"], name="") - inp = inp.outputs[0] - out = out.outputs[0] - with csess.Session(config=sessconfig, graph=g) as sess: - val = sess.run(out, {inp: dummy_input}) - print(val.shape) - - -if "__main__" in __name__: - P = argparse.ArgumentParser( - prog="tftrt_test", - description="Example utilization of TensorFlow-TensorRT integration") - P.add_argument( - "--automatic", - "-a", - action="store_true", - help="Do TRT conversion automatically", - default=False) - P.add_argument( - "--multi-engine", - "-m", - action="store_true", - help="Use a graph that will result in 2 engines", - default=False) - flags, unparsed = P.parse_known_args() - if flags.automatic: - auto(flags.multi_engine) - else: - user(flags.multi_engine) diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc deleted file mode 100644 index 276308b3a0a6ce864969afb0179c6a3f00d6b70b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/test/utils.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/contrib/tensorrt/test/utils.h" - -#include -#include - -#include "re2/re2.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace tensorrt { -namespace test { - -// TODO(aaroey): make this class thread-safe. -class TestValueManager { - public: - static TestValueManager* singleton() { - static TestValueManager* manager = new TestValueManager(); - return manager; - } - - void Enable() { - VLOG(1) << "Enabling test value"; - enabled_ = true; - } - - void Add(const string& label, const string& value) { - if (TF_PREDICT_FALSE(enabled_)) { - QCHECK_NE("", value); - VLOG(1) << "Adding test value: " << label << " -> " << value; - values_.insert({label, value}); - } - } - - string Get(const string& label) { - if (TF_PREDICT_FALSE(enabled_)) { - VLOG(1) << "Getting test value by " << label; - auto itr = values_.find(label); - if (itr == values_.end()) return ""; - return itr->second; - } - return ""; - } - - void Clear(const string& pattern) { - if (TF_PREDICT_FALSE(enabled_)) { - VLOG(1) << "Clearing test values"; - if (pattern.empty()) { - values_.clear(); - return; - } - std::vector keys_to_clear; - for (const auto& kv : values_) { - if (RE2::FullMatch(kv.first, pattern)) { - keys_to_clear.push_back(kv.first); - } - } - for (const string& key : keys_to_clear) { - values_.erase(key); - } - } - } - - private: - TestValueManager() : enabled_(false) {} - - bool enabled_; - std::unordered_map values_; -}; - -void EnableTestValue() { TestValueManager::singleton()->Enable(); } - -void ClearTestValues(const string& pattern) { - TestValueManager::singleton()->Clear(pattern); -} - -void AddTestValue(const string& label, const string& value) { - TestValueManager::singleton()->Add(label, value); -} - -string GetTestValue(const string& label) { - return TestValueManager::singleton()->Get(label); -} - -} // namespace test -} // namespace tensorrt -} // 
namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i deleted file mode 100644 index 6ea15fb8eff13663625420288a37ba002d57fa47..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ /dev/null @@ -1,234 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/* Wrap trt_conversion */ -%{ -#define SWIG_FILE_WITH_INIT -%} -%include "std_pair.i" -%include "tensorflow/python/platform/base.i" - -%{ -PyObject* pair_helper(std::pair* in) { - PyObject *first(nullptr), *second(nullptr), *tuple(nullptr); - first = PyBytes_FromStringAndSize(in->first.data(), in->first.length()); - if (!first) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_TypeError, "Pair conversion first argument failed"); - } - return NULL; - } - second = PyBytes_FromStringAndSize(in->second.data(), in->second.length()); - if (!second) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_TypeError, - "Pair conversion second argument failed"); - } - return NULL; - } - tuple = Py_BuildValue("(OO)", first, second); - if (!tuple) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_TypeError, - "Tuple creation from pair failed!"); - } - return NULL; - } - return tuple; -} - -struct version_struct{ - int vmajor; - int vminor; - int vpatch; -}; - -PyObject* 
version_helper(version_struct* in) { - PyObject *tuple(nullptr); - tuple = Py_BuildValue("(iii)", in->vmajor, in->vminor, in->vpatch); - if (!tuple) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_TypeError, - "Tuple creation from version structure failed!"); - } - return NULL; - } - return tuple; -} -/* Define converters for vector */ -template<> -bool _PyObjAs(PyObject *pyobj, int* dest) { - *dest = PyLong_AsLong(pyobj); - return true; -} - -template<> -PyObject *_PyObjFrom(const int& src) { - return PyLong_FromLong(src); -} - -%} - -_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); - -%typemap(out) std::pair { - PyObject *tuple = pair_helper(&$1); - if (!tuple) SWIG_fail; - $result = tuple; -} - -%typemap(out) version_struct { - PyObject *tuple = version_helper(&$1); - if (!tuple) SWIG_fail; - $result = tuple; -} - -%{ -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/stat_summarizer.h" -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" -%} - -%ignoreall -%unignore tensorflow; -%unignore calib_convert; -%unignore get_linked_tensorrt_version; -%unignore get_loaded_tensorrt_version; -%unignore is_tensorrt_enabled; -%unignore enable_test_value; -%unignore clear_test_values; -%unignore add_test_value; -%unignore get_test_value; - -%{ - -std::pair calib_convert( - string graph_def_string, bool is_dyn_op - // unfortunately we can't use TF_Status here since it - // is in c/c_api and brings in a lot of other libraries - // which in turn declare ops. These ops are included - // statically in our library and cause an abort when - // module is loaded due to double registration - // until Tensorflow properly exposes these headers - // we have to work around this by returning a string - // and converting it to exception on python side. 
- //,TF_Status* out_status) { -) { -#if GOOGLE_CUDA && GOOGLE_TENSORRT - string out_status; - - tensorflow::GraphDef graph_def; - if (!graph_def.ParseFromString(graph_def_string)) { - out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; - return std::pair{out_status, ""}; - } - graph_def_string.resize(0); - tensorflow::GraphDef out_graph; - tensorflow::Status conversion_status = - tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph( - graph_def, &out_graph, is_dyn_op); - if (!conversion_status.ok()) { - auto retCode = (int)conversion_status.code(); - char buff[2000]; - snprintf(buff, 2000, "%d;%s", retCode, - conversion_status.error_message().c_str()); - out_status = buff; - return std::pair{out_status, ""}; - } - string result; - if (!out_graph.SerializeToString(&result)) { - out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; - return std::pair{out_status, ""}; - } - out_status = "OK;All good!"; - return std::pair{out_status, result}; -#else - // Returns FAILED_PRECONDITION. - return std::pair{"9;TensorRT is not enabled!", ""}; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT -} - -version_struct get_linked_tensorrt_version() { - // Return the version at the link time. - version_struct s; -#if GOOGLE_CUDA && GOOGLE_TENSORRT - const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion(); - s.vmajor = lv[0]; - s.vminor = lv[1]; - s.vpatch = lv[2]; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT - return s; -} - -version_struct get_loaded_tensorrt_version() { - // Return the version from the loaded library. 
- version_struct s; -#if GOOGLE_CUDA && GOOGLE_TENSORRT - const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion(); - s.vmajor = lv[0]; - s.vminor = lv[1]; - s.vpatch = lv[2]; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT - return s; -} - -bool is_tensorrt_enabled() { - return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); -} - -void enable_test_value() { - tensorflow::tensorrt::test::EnableTestValue(); -} - -#if PY_MAJOR_VERSION < 3 -#define TRT_PY_TO_CPP_STRING PyString_AsString -#define TRT_CPP_TO_PY_STRING PyString_FromString -#else -#define TRT_PY_TO_CPP_STRING PyUnicode_AsUTF8 -#define TRT_CPP_TO_PY_STRING PyUnicode_FromString -#endif - -void clear_test_values(PyObject* pattern) { - tensorflow::tensorrt::test::ClearTestValues( - string(TRT_PY_TO_CPP_STRING(pattern))); -} - -void add_test_value(PyObject* label, PyObject* value) { - tensorflow::tensorrt::test::AddTestValue( - string(TRT_PY_TO_CPP_STRING(label)), string(TRT_PY_TO_CPP_STRING(value))); -} - -PyObject* get_test_value(PyObject* label) { - string value = tensorflow::tensorrt::test::GetTestValue( - string(TRT_PY_TO_CPP_STRING(label))); - return TRT_CPP_TO_PY_STRING(value.c_str()); -} - -%} - -std::pair calib_convert( - string graph_def_string, bool is_dyn_op); -version_struct get_linked_tensorrt_version(); -version_struct get_loaded_tensorrt_version(); -bool is_tensorrt_enabled(); -void enable_test_value(); -void clear_test_values(PyObject* pattern); -void add_test_value(PyObject* label, PyObject* value); -PyObject* get_test_value(PyObject* label); - -%unignoreall diff --git a/tensorflow/contrib/timeseries/BUILD b/tensorflow/contrib/timeseries/BUILD index f2b8786a527289fe20de86447355fbf552cd265e..18933227b3431fc56b91c6ab7376c975d3aa69a7 100644 --- a/tensorflow/contrib/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/BUILD @@ -23,10 +23,10 @@ py_library( name = "timeseries_pip", deps = [ ":timeseries", - "//tensorflow/contrib/timeseries/examples:known_anomaly", - 
"//tensorflow/contrib/timeseries/examples:lstm", - "//tensorflow/contrib/timeseries/examples:multivariate", - "//tensorflow/contrib/timeseries/examples:predict", + "//tensorflow/contrib/timeseries/examples:known_anomaly_main_lib", + "//tensorflow/contrib/timeseries/examples:lstm_main_lib", + "//tensorflow/contrib/timeseries/examples:multivariate_main_lib", + "//tensorflow/contrib/timeseries/examples:predict_main_lib", "//tensorflow/contrib/timeseries/python/timeseries:test_utils", "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:test_utils", ], diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index e10be88ece8ebba9635af955b3c3410f29e5503c..70c3a0720eed1971a90e0498d12f876abe4906d5 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -17,6 +17,14 @@ config_setting( py_binary( name = "predict", srcs = ["predict.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [":predict_main_lib"], +) + +py_library( + name = "predict_main_lib", + srcs = ["predict.py"], data = ["data/period_trend.csv"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -40,7 +48,7 @@ py_test( "notsan", # b/67513579 ], deps = [ - ":predict", + ":predict_main_lib", "//tensorflow/python:client_testlib", ], ) @@ -48,6 +56,14 @@ py_test( py_binary( name = "known_anomaly", srcs = ["known_anomaly.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [":known_anomaly_main_lib"], +) + +py_library( + name = "known_anomaly_main_lib", + srcs = ["known_anomaly.py"], data = ["data/changepoints.csv"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -66,7 +82,7 @@ py_test( srcs = ["known_anomaly_test.py"], srcs_version = "PY2AND3", deps = [ - ":known_anomaly", + ":known_anomaly_main_lib", "//tensorflow/python:client_testlib", ], ) @@ -74,6 +90,14 @@ py_test( py_binary( name = "multivariate", srcs = ["multivariate.py"], + srcs_version = "PY2AND3", + tags = 
["no_pip"], + deps = [":multivariate_main_lib"], +) + +py_library( + name = "multivariate_main_lib", + srcs = ["multivariate.py"], data = ["data/multivariate_level.csv"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -94,7 +118,7 @@ py_test( ], srcs_version = "PY2AND3", deps = [ - ":multivariate", + ":multivariate_main_lib", "//tensorflow/python:client_testlib", ], ) @@ -102,6 +126,15 @@ py_test( py_binary( name = "lstm", srcs = ["lstm.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + visibility = ["//visibility:public"], + deps = [":lstm_main_lib"], +) + +py_library( + name = "lstm_main_lib", + srcs = ["lstm.py"], data = ["data/multivariate_periods.csv"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -125,7 +158,7 @@ py_test( srcs_version = "PY2AND3", tags = ["notsan"], deps = [ - ":lstm", + ":lstm_main_lib", "//tensorflow/python:client_testlib", "//tensorflow/python/estimator:estimator_py", ], diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index 2a22295197dc225cefbedf2736adeea5491a9fc2..449ec8b0a83d0bd6247970302630ad4e0a902a40 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -155,13 +155,16 @@ py_library( py_test( name = "head_test", - size = "large", + size = "medium", srcs = [ "head_test.py", ], - shard_count = 4, + shard_count = 10, srcs_version = "PY2AND3", - tags = ["no_pip_gpu"], # b/63391119 + tags = [ + "no_pip_gpu", # b/63391119 + "notap", # b/124520733 + ], deps = [ ":estimators", ":feature_keys", @@ -169,7 +172,7 @@ py_test( ":input_pipeline", ":model", ":state_management", - "//tensorflow/contrib/timeseries/examples:lstm", + "//tensorflow/contrib/timeseries/examples:lstm_main_lib", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py 
b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index a8d5e1a49dd4313f58f2f515bc3f292ecce5cbd4..3626701d24163ef52564b42d8a630bd9c5a788eb 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -465,7 +465,8 @@ class ARModel(model.TimeSeriesModel): math_utils.normal_log_prob(targets, sigma, prediction)) else: assert self.loss == ARModel.SQUARED_LOSS, self.loss - loss_op = math_ops.reduce_sum(math_ops.square(prediction - targets)) + loss_op = math_ops.reduce_sum( + math_ops.squared_difference(prediction, targets)) loss_op /= math_ops.cast( math_ops.reduce_prod(array_ops.shape(targets)), loss_op.dtype) return loss_op diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index d898f05405707672ac4d6b1c11bb5931dfe475e1..ee1cd3213efb0fff3a99536bdf1abd93c0c32a6e 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -18,22 +18,18 @@ package( "//learning/brain:__subpackages__", "//learning/deepmind:__subpackages__", "//medical/pathology:__subpackages__", + "//smartass/brain:__subpackages__", "//tensorflow:__subpackages__", "//vr/perception:__subpackages__", ], ) -cc_library( - name = "all_ops", +py_library( + name = "tpu_py", + srcs = ["python/ops/tpu_ops.py"], + srcs_version = "PY2AND3", deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", - ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", + "//tensorflow/python/tpu:tpu_py", ], ) @@ -42,25 +38,14 @@ py_library( srcs = ["python/tpu/async_checkpoint.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - 
"//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:summary_ops_v2", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/tpu:async_checkpoint", ], ) py_library( name = "tpu_estimator", srcs = [ + "python/tpu/_tpu_estimator_embedding.py", "python/tpu/error_handling.py", "python/tpu/tpu_config.py", "python/tpu/tpu_context.py", @@ -70,136 +55,24 @@ py_library( srcs_version = "PY2AND3", deps = [ ":async_checkpoint", + ":feature_column", ":functional", + ":tpu_embedding", ":tpu_lib", - ":tpu_ordinal_selector_py", "//tensorflow/contrib/training:training_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:function", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:session", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:summary_ops_v2", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/estimator:util", - "@six_archive//:six", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "cross_replica_ops", - "heartbeat_ops", - "host_compute_ops", - "infeed_ops", - "outfeed_ops", - "replication_ops", - "tpu_configuration_ops", - "tpu_embedding_ops", - "tpu_ordinal_selector_op", - "functional_ops", - ], - deps = [ - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib", - "//tensorflow/core:lib_proto_parsing", - 
"//tensorflow/core:protos_all_cc", - ], -) - -tf_custom_op_library( - name = "python/ops/_tpu_ops.so", - srcs = [ - "ops/cross_replica_ops.cc", - "ops/heartbeat_ops.cc", - "ops/host_compute_ops.cc", - "ops/infeed_ops.cc", - "ops/outfeed_ops.cc", - "ops/replication_ops.cc", - "ops/tpu_configuration_ops.cc", - "ops/tpu_embedding_ops.cc", - ], - deps = [ - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib_proto_parsing", + "//tensorflow/python/tpu:tpu_estimator", ], ) -tf_gen_op_wrapper_py( - name = "tpu_ops", - hidden = [ - "SendTPUEmbeddingGradients", - "EnqueueTPUEmbeddingIntegerBatch", - "EnqueueTPUEmbeddingSparseBatch", - "EnqueueTPUEmbeddingSparseTensorBatch", - ], - deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", - ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", - ], -) - -tf_custom_op_library( - name = "python/ops/_tpu_ordinal_selector.so", - srcs = ["ops/tpu_ordinal_selector_op.cc"], -) - -tf_custom_op_py_library( - name = "tpu_ordinal_selector_py", - srcs = ["ops/gen_tpu_ordinal_selector_op.py"], - dso = [":python/ops/_tpu_ordinal_selector.so"], - kernels = [ - ":tpu_ordinal_selector_op_op_lib", - ], - srcs_version = "PY2AND3", - deps = [ - ":tpu_ordinal_selector_op", - ], -) - -tf_gen_op_wrapper_py( - name = "tpu_ordinal_selector_op", - deps = [ - ":tpu_ordinal_selector_op_op_lib", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_functional_ops", - out = "python/tpu/gen_functional_ops.py", - hidden = [ - "TPUPartitionedCall", - ], - deps = [":functional_ops_op_lib"], -) - py_library( name = "functional", srcs = ["python/tpu/functional.py"], + srcs_version = "PY2AND3", visibility = [ "//visibility:public", ], deps = [ - 
":gen_functional_ops", + "//tensorflow/python/tpu:functional", ], ) @@ -208,30 +81,7 @@ py_library( srcs = ["python/profiler/__init__.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_pb2_grpc", - "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_proto_py", - "//tensorflow/contrib/tpu/profiler:trace_events_proto_py", - "//tensorflow/python:util", - ], -) - -tf_custom_op_py_library( - name = "tpu_py", - srcs = glob(["python/ops/*.py"]), - dso = [":python/ops/_tpu_ops.so"], - kernels = [ - ":all_ops", - ], - srcs_version = "PY2AND3", - deps = [ - ":profiler", - ":tpu_ops", - "//tensorflow/contrib/compiler:xla", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow/python/tpu/profiler", ], ) @@ -248,6 +98,7 @@ py_library( ":tpu_embedding", ":tpu_estimator", ":tpu_lib", + "//tensorflow/python/tpu", ], ) @@ -270,8 +121,8 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/distribute", "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/core:protos_all_py", + "//tensorflow/core/protobuf/tpu:compilation_result_proto_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -311,29 +162,12 @@ py_library( srcs_version = "PY2AND3", deps = [ ":datasets", + ":functional", ":profiler", ":tpu_py", - "//tensorflow/compiler/xla/experimental/xla_sharding", - "//tensorflow/compiler/xla/python_api:xla_shape", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/compiler:xla", - "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", - "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py", - "//tensorflow/contrib/tpu/proto:topology_proto_py", - 
"//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", - "//tensorflow/contrib/tpu/proto:tpu_embedding_output_layout_proto_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:control_flow_util", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:framework_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/losses", + "//tensorflow/python/tpu:tpu_lib", ], ) @@ -344,125 +178,20 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:readers", - ], -) - -tf_py_test( - name = "datasets_test", - size = "medium", - srcs = ["python/tpu/datasets_test.py"], - additional_deps = [ - "//tensorflow/python:client_testlib", - ":datasets", - ], - grpc_enabled = True, - shard_count = 4, - tags = ["no_oss"], -) - -tf_py_test( - name = "tpu_test", - size = "small", - srcs = ["python/tpu/tpu_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:layers", - ], - tags = ["no_windows"], # TODO: needs investigation on Windows -) - -tf_py_test( - name = "tpu_sharding_test", - size = "small", - srcs = ["python/tpu/tpu_sharding_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - ], -) - -tf_py_test( - name = "bfloat16_test", - size = "small", - srcs = ["python/tpu/bfloat16_test.py"], - additional_deps = [ - ":tpu", - 
"//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - ], -) - -tf_py_test( - name = "tpu_infeed_test", - size = "small", - srcs = ["python/tpu/tpu_infeed_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - ], -) - -tf_py_test( - name = "tpu_config_test", - size = "small", - srcs = ["python/tpu/tpu_config_test.py"], - additional_deps = [ - ":tpu_estimator", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - ], -) - -tf_py_test( - name = "tpu_estimator_signals_test", - size = "small", - srcs = ["python/tpu/tpu_estimator_signals_test.py"], - additional_deps = [ - ":tpu_estimator", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - ], -) - -tf_py_test( - name = "topology_test", - size = "medium", - srcs = ["python/tpu/topology_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python/tpu:datasets", ], ) py_library( name = "tpu_embedding", - srcs = ["python/tpu/tpu_embedding.py"], + srcs = [ + "python/tpu/tpu_embedding.py", + "python/tpu/tpu_embedding_gradient.py", + ], srcs_version = "PY2AND3", deps = [ ":tpu_lib", - ":tpu_ops", - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "@six_archive//:six", + "//tensorflow/python/tpu:tpu_embedding", ], ) @@ -471,31 +200,6 @@ py_library( srcs = ["python/tpu/feature_column.py"], deps = [ ":tpu_lib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_py", - ], -) - 
-tf_py_test( - name = "feature_column_test", - srcs = [ - "python/tpu/feature_column_test.py", - ], - additional_deps = [ - ":feature_column", - "//third_party/py/numpy", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/tpu:feature_column", ], - main = "python/tpu/feature_column_test.py", ) diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 541fbf33a302a4d850422885fdbbc438bd6b9b7b..e2ce77e118182bb07193cbac82e176d3b2057e17 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -2,35 +2,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_cc") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos") - -tf_proto_library( - name = "tpu_profiler_proto", - srcs = ["tpu_profiler.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - protodeps = [":op_profile_proto"] + tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - -cc_library( - name = "dump_tpu_profile", - srcs = ["dump_tpu_profile.cc"], - hdrs = ["dump_tpu_profile.h"], - visibility = ["//visibility:public"], - deps = [ - ":op_profile_proto_cc", - ":tpu_profiler_proto_cc", - ":trace_events_proto_cc", - ":trace_events_to_json", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], -) cc_library( name = "version", @@ -43,71 +14,13 @@ tf_cc_binary( 
srcs = [ "capture_tpu_profile.cc", ], + tags = ["no_windows"], visibility = ["//visibility:public"], deps = [ - ":dump_tpu_profile", - ":tpu_profiler_analysis_proto_cc", - ":tpu_profiler_proto_cc", ":version", - "//tensorflow:grpc++", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/profiler/rpc/client:capture_profile", ], ) - -tf_proto_library( - name = "trace_events_proto", - srcs = ["trace_events.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - -cc_library( - name = "trace_events_to_json", - srcs = ["trace_events_to_json.cc"], - hdrs = ["trace_events_to_json.h"], - deps = [ - ":trace_events_proto_cc", - "//tensorflow/core:lib", - "@jsoncpp_git//:jsoncpp", - ], -) - -tf_cc_test( - name = "trace_events_to_json_test", - srcs = ["trace_events_to_json_test.cc"], - deps = [ - ":trace_events_to_json", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@jsoncpp_git//:jsoncpp", - ], -) - -tf_proto_library( - name = "op_profile_proto", - srcs = ["op_profile.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "tpu_profiler_analysis_proto", - srcs = ["tpu_profiler_analysis.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - protodeps = [":tpu_profiler_proto"] + tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - -py_library( - name = "tpu_profiler_analysis_pb2_grpc", - srcs = ["tpu_profiler_analysis_pb2_grpc.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [":tpu_profiler_analysis_proto_py"], -) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 1c5ea2d997a58ca57ddc212ffd56aad525e961da..32858850cdb27c985ee16946fcc5d2146644ef64 100644 --- 
a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -18,235 +18,11 @@ limitations under the License. // Initiates a TPU profiling on the TPUProfiler service at service_addr, // receives and dumps the profile data to a tensorboard log directory. -#include "grpcpp/grpcpp.h" - -#include -#include -#include - -#include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h" -#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" -#include "tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.grpc.pb.h" #include "tensorflow/contrib/tpu/profiler/version.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/util/command_line_flags.h" -namespace tensorflow { -namespace tpu { -namespace { - -using ::tensorflow::TPUProfileAnalysis; -using ::tensorflow::TPUProfiler; - -constexpr uint64 kMaxEvents = 1000000; - -string GetCurrentTimeStampAsString() { - char s[128]; - std::time_t t = std::time(nullptr); - CHECK_NE(std::strftime(s, sizeof(s), "%F_%T", std::localtime(&t)), 0); - return s; -} - -Status ValidateHostPortPair(const string& host_port) { - uint32 port; - std::vector parts = str_util::Split(host_port, ':'); - // Must be host:port, port must be a number, host must not contain a '/', - // host also must not be empty. 
- if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) || - parts[0].find("/") != string::npos || parts[0].empty()) { - return errors::InvalidArgument("Could not interpret \"", host_port, - "\" as a host-port pair."); - } - return Status::OK(); -} - -ProfileRequest PopulateProfileRequest(int duration_ms, - const string& repository_root, - const string& session_id, - const ProfileOptions& opts) { - ProfileRequest request; - request.set_duration_ms(duration_ms); - request.set_max_events(kMaxEvents); - if (tensorflow::str_util::StartsWith(repository_root, "gs://")) { - // For backward compatibilities, only generate tracetable etc when the - // user provide a GCS path for model directory. - request.set_repository_root(repository_root); - request.set_session_id(session_id); - } - request.add_tools("op_profile"); - request.add_tools("input_pipeline"); - request.add_tools("memory_viewer"); - request.add_tools("overview_page"); - *request.mutable_opts() = opts; - return request; -} - -// Returns whether the returned trace is empty. -// Failure are handled by CHECK, i.e. abort() -bool Profile(const string& service_addr, const string& logdir, int duration_ms, - const string& repository_root, const string& session_id, - const ProfileOptions& opts) { - ProfileRequest request = - PopulateProfileRequest(duration_ms, repository_root, session_id, opts); - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their - // `ValidateHostPortPair` checks for empty host string case. 
- channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, - std::numeric_limits::max()); - std::unique_ptr stub = - TPUProfiler::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - ProfileResponse response; - TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response))); - - if (!response.encoded_trace().empty()) { - TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile( - logdir, session_id, "", response, &std::cout)); - // Print this at the end so that it's not buried in irrelevant LOG messages. - std::cout - << "NOTE: using the trace duration " << duration_ms << "ms." - << std::endl - << "Set an appropriate duration (with --duration_ms) if you " - "don't see a full step in your trace or the captured trace is too " - "large." - << std::endl; - } - - return response.encoded_trace().empty(); -} - -// Start a new profiling session that include all the hosts included in -// hostnames, for the time interval of duration_ms. Possibly save the profiling -// result in the directory specified by repository_root and session_id. -bool NewSession(const string& service_addr, - const std::vector& hostnames, - int duration_ms, const string& repository_root, - const string& session_id, const ProfileOptions& opts) { - NewProfileSessionRequest new_session_request; - *new_session_request.mutable_request() = - PopulateProfileRequest(duration_ms, repository_root, session_id, opts); - new_session_request.set_repository_root(repository_root); - new_session_request.set_session_id(session_id); - for (const auto& hostname : hostnames) { - new_session_request.add_hosts(hostname); - } - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their - // `ValidateHostPortPair` checks for empty host string case. 
- channel_args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - // TODO(jiesun): GRPC support following relevant naming scheme: - // 1. dns:///host:port - // 2. ipv4:host:port or ipv6:[host]:port - // We might need to change the prefix which depends on what TPU name resolver - // will give us. - std::unique_ptr stub = - TPUProfileAnalysis::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - NewProfileSessionResponse new_session_response; - TF_QCHECK_OK(FromGrpcStatus( - stub->NewSession(&context, new_session_request, &new_session_response))); - - std::cout << "Profile session succeed for host(s):" - << str_util::Join(hostnames, ",") << std::endl; - return new_session_response.empty_trace(); -} - -// Starts tracing on a single or multiple TPU hosts and saves the result in the -// given logdir. If no trace was collected, retries tracing for -// num_tracing_attempts. -void StartTracing(const tensorflow::string& service_addr, - const tensorflow::string& logdir, - const tensorflow::string& workers_list, - bool include_dataset_ops, int duration_ms, - int num_tracing_attempts) { - // Use the current timestamp as the run name. - tensorflow::string session_id = GetCurrentTimeStampAsString(); - constexpr char kProfilePluginDirectory[] = "plugins/profile/"; - tensorflow::string repository_root = - io::JoinPath(logdir, kProfilePluginDirectory); - std::vector hostnames = - tensorflow::str_util::Split(workers_list, ","); - - bool empty_trace = false; - int remaining_attempts = num_tracing_attempts; - tensorflow::ProfileOptions opts; - opts.set_include_dataset_ops(include_dataset_ops); - while (true) { - std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. 
" - << "Remaining attempt(s): " << remaining_attempts-- << std::endl; - if (hostnames.empty()) { - empty_trace = tensorflow::tpu::Profile(service_addr, logdir, duration_ms, - repository_root, session_id, opts); - } else { - tensorflow::string tpu_master = service_addr; - empty_trace = - tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, - repository_root, session_id, opts); - } - if (remaining_attempts <= 0 || !empty_trace) break; - std::cout << "No trace event is collected. Automatically retrying." - << std::endl - << std::endl; - } - - if (empty_trace) { - std::cout << "No trace event is collected after " << num_tracing_attempts - << " attempt(s). " - << "Perhaps, you want to try again (with more attempts?)." - << std::endl - << "Tip: increase number of attempts with --num_tracing_attempts." - << std::endl; - } -} - -MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) { - MonitorRequest request; - request.set_duration_ms(duration_ms); - request.set_monitoring_level(monitoring_level); - return request; -} - -// Repeatedly collects profiles and shows user-friendly metrics for -// 'num_queries' time(s). 
-void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, - int monitoring_level, int num_queries) { - for (int query = 0; query < num_queries; ++query) { - MonitorRequest request = - PopulateMonitorRequest(duration_ms, monitoring_level); - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, - std::numeric_limits::max()); - std::unique_ptr stub = - TPUProfiler::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - MonitorResponse response; - TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response))); - - std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1 - << "):\n\n" - << response.data() << std::flush; - } -} - -} // namespace -} // namespace tpu -} // namespace tensorflow - int main(int argc, char** argv) { tensorflow::string FLAGS_service_addr; tensorflow::string FLAGS_logdir; @@ -300,8 +76,9 @@ int main(int argc, char** argv) { std::cout << usage.c_str() << std::endl; return 2; } - tensorflow::Status status = - tensorflow::tpu::ValidateHostPortPair(FLAGS_service_addr); + tensorflow::Status status; + status = + tensorflow::profiler::client::ValidateHostPortPair(FLAGS_service_addr); if (!status.ok()) { std::cout << status.error_message() << std::endl; std::cout << usage.c_str() << std::endl; @@ -324,12 +101,17 @@ int main(int argc, char** argv) { << FLAGS_service_addr << " for " << duration_ms << "ms and show metrics for " << num_queries << " time(s)." 
<< std::endl; - tensorflow::tpu::StartMonitoring(FLAGS_service_addr, duration_ms, - FLAGS_monitoring_level, num_queries); + tensorflow::profiler::client::StartMonitoring( + FLAGS_service_addr, duration_ms, FLAGS_monitoring_level, num_queries); } else { - tensorflow::tpu::StartTracing(FLAGS_service_addr, FLAGS_logdir, - FLAGS_workers_list, FLAGS_include_dataset_ops, - duration_ms, num_tracing_attempts); + status = tensorflow::profiler::client::StartTracing( + FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list, + FLAGS_include_dataset_ops, duration_ms, num_tracing_attempts); + if (!status.ok() && status.code() != tensorflow::error::Code::UNAVAILABLE) { + std::cout << status.error_message() << std::endl; + std::cout << usage.c_str() << std::endl; + return 2; + } } return 0; } diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 6a6eba282a12d68cc3cd4e46a46a1b4190fb737b..8605bae5c128513186d8c03835dcf49d3e4b6fd9 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -1,389 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= - -"""Operations for TPUs.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import platform - -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging - -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops import gen_tpu_ops - from tensorflow.contrib.tpu.ops.gen_tpu_ops import * - - from tensorflow.contrib.util import loader - from tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - - _tpu_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("_tpu_ops.so")) - - def _create_default_group_assignment(): - num_shards = tpu_function.get_tpu_context().number_of_shards - if num_shards is None: - logging.warning( - "cross_replica_sum should be used within a tpu_shard_context, but " - "got unset number_of_shards. Assuming 1.") - num_shards = 1 - group_assignment = [list(range(num_shards))] - return group_assignment - - def all_to_all(x, - concat_dimension, - split_dimension, - split_count, - group_assignment=None, - name=None): - """Exchange data across TPU replicas. - - Args: - x: The local tensor. - concat_dimension: The dimension number to concatenate. - split_dimension: The dimension number to split. - split_count: The number of splits, this number must equal to the sub-group - size(group_assignment.get_shape()[1]) - group_assignment: Optional 2d int32 lists with shape [num_groups, - num_replicas_per_group]. 
`group_assignment[i]` represents the replica - ids in the ith subgroup. - name: Optional op name. - - Returns: - A `Tensor` which is concatenated by data from different replicas. - """ - if group_assignment is None: - group_assignment = _create_default_group_assignment() - return gen_tpu_ops.all_to_all( - x, - group_assignment, - concat_dimension=concat_dimension, - split_dimension=split_dimension, - split_count=split_count, - name=name) - - @ops.RegisterGradient("AllToAll") - def _all_to_all_grad(op, grad): - # The gradient of a all-to-all is also a all-to-all but the - # split_dimension and concat_dimension is swapped. - # The graident with respect to group_assignment is None. - return [ - gen_tpu_ops.all_to_all( - grad, - op.inputs[1], - concat_dimension=op.get_attr("split_dimension"), - split_dimension=op.get_attr("concat_dimension"), - split_count=op.get_attr("split_count")), None - ] - - def cross_replica_sum(x, group_assignment=None, name=None): - """Sum the input tensor across replicas according to group_assignment. - - Args: - x: The local tensor to the sum. - group_assignment: Optional 2d int32 lists with shape [num_groups, - num_replicas_per_group]. `group_assignment[i]` represents the replica - ids in the ith subgroup. - name: Optional op name. - - Returns: - A `Tensor` which is summed across replicas. - """ - if group_assignment is None: - group_assignment = _create_default_group_assignment() - - return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name) - - def collective_permute(x, source_target_pairs, name=None): - """Permute the input tensor across replicas given source_target_pairs. - - For each source_target_pair , we send replica a's input to replica b. - Each replica id must only appear once in the source column. Also it must - only appear once in the target column. - For the replica id not in the target column, this op returns a zero tensor - with the same shape and dtype of the input x. 
- - For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing - source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs: - `[0, A, B, C]`. - - Args: - x: The local tensor to be permuted. - source_target_pairs: 2d int lists with shape [num_pairs, 2]. - source_target_pairs[i][0] represents the source replica id and - source_target_pairs[i][1] represents the target replica id. - name: Optional op name. - - Returns: - A `Tensor` which is permuted. - """ - return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name) - - @ops.RegisterGradient("CollectivePermute") - def _collective_permute_grad(op, grad): - # The gradient of a collective permute operation is also a collective - # permute, but with source/target pairs reversed. The gradient with respect - # to input argument `source_target_pairs` is `None`. - source_target_pairs = op.inputs[1][:, ::-1] - return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None] - - @ops.RegisterGradient("CrossReplicaSum") - def _cross_replica_sum_grad(op, grad): - # The gradient of a cross replica sum is also a cross-replica sum. - # The gradient with respect to group_assignment is None. - return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None] - - # This extra type checking exists to give a more helpful error message in - # the common case that uint8 and int64 values are infed. Remove when both - # types are supported. - - _SUPPORTED_INFEED_DTYPES = set([ - dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32, - dtypes.complex64 - ]) - - def infeed_dequeue(dtype, shape, name=None): - """A placeholder op for a value that will be fed into the computation. - - Args: - dtype: A `tf.DType`. The type of elements in the tensor. - shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor. - name: A name for the operation (optional). - - Returns: - A `Tensor` of type `dtype`. - A tensor that will be provided using the infeed mechanism. 
- - Raises: - TypeError: If 'dtype` is not a supported infeed type. - """ - if dtype not in _SUPPORTED_INFEED_DTYPES: - raise TypeError( - "{} is not a supported TPU infeed type. Supported types are: " - "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) - - return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name) - - # pylint: disable=redefined-outer-name - def infeed_dequeue_tuple(dtypes, shapes, name=None): - """A placeholder op for values fed into the TPU simultaneously as a tuple. - - Args: - dtypes: A list of `tf.DType`s that has length `>= 1`. - The element types of each element in `outputs`. - shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). - The shapes of each tensor in `outputs`. - name: A name for the operation (optional). - - Returns: - A list of `Tensor` objects of type `dtypes`. - A list of tensors that will be provided using the infeed mechanism. - - Raises: - TypeError: If a type in 'dtypes` is not a supported infeed type. - """ - for dtype in dtypes: - if dtype not in _SUPPORTED_INFEED_DTYPES: - raise TypeError( - "{} is not a supported TPU infeed type. Supported types are: " - "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) - return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name) - # pylint: enable=redefined-outer-name - - # pylint: disable=protected-access - def send_tpu_embedding_gradients(inputs, - config, - learning_rates=None, - name=None): - """A placeholder op for feeding per-sample gradients to the embedding layer. - - Args: - inputs: A TensorList of gradients with which to update embedding tables. - Contains one tensor per embedding table in the model. - config: Serialized TPUEmbeddingConfiguration proto. - learning_rates: A TensorList of float32 scalars, one for each embedding - table, containing the learning rates for each table when dynamic - learning rate is enabled through the OptimizationParameters in - TPUEmbeddingConfiguration. 
When the learning rate is constant, the list - should be empty (optional). - name: A name for the operation (optional). - - Returns: - A SendTPUEmbeddingGradients operation. - """ - if learning_rates is None: - learning_rates = [] - return gen_tpu_ops._send_tpu_embedding_gradients( - inputs=inputs, learning_rates=learning_rates, config=config, name=name) - - - send_tpu_embedding_gradients.__doc__ = ( - gen_tpu_ops._send_tpu_embedding_gradients.__doc__) - - # pylint: disable=protected-access - def enqueue_tpu_embedding_integer_batch(batch, - device_ordinal, - mode_override=None, - name=None): - """A placeholder op for enqueueing embedding IDs to the TPU. - - Args: - batch: A list of 1D tensors, one for each embedding table, containing the - indices into the tables. - device_ordinal: The TPU device to use. Should be >= 0 and less than the - number of TPU cores in the task on which the node is placed. - mode_override: A string input that overrides the mode specified in the - TPUEmbeddingConfiguration. Supported values are {'unspecified', - 'inference', 'training', 'backward_pass_only'}. When set to - 'unspecified', the mode set in TPUEmbeddingConfiguration is used, - otherwise mode_override is used (optional). - name: A name for the operation (optional). - - Returns: - An EnqueueTPUEmbeddingIntegerBatch operation. - """ - if mode_override is None: - mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_integer_batch( - batch=batch, - device_ordinal=device_ordinal, - mode_override=mode_override, - name=name) - - enqueue_tpu_embedding_integer_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_integer_batch.__doc__) - - # pylint: disable=protected-access - def enqueue_tpu_embedding_sparse_batch(sample_indices, - embedding_indices, - aggregation_weights, - device_ordinal, - combiners=None, - mode_override=None, - name=None): - """A placeholder op for enqueueing embedding IDs to the TPU. 
- - Args: - sample_indices: A list of rank 1 Tensors specifying the training example - and feature to which the corresponding embedding_indices and - aggregation_weights values belong. sample_indices[i] must equal b * nf + - f, where nf is the number of features from the corresponding table, f is - in [0, nf), and b is in [0, batch size). - embedding_indices: A list of rank 1 Tensors, indices into the embedding - tables. - aggregation_weights: A list of rank 1 Tensors containing per sample -- - i.e. per (training example, feature) -- aggregation weights. - device_ordinal: The TPU device to use. Should be >= 0 and less than the - number of TPU cores in the task on which the node is placed. - combiners: A list of string scalars, one for each embedding table that - specify how to normalize the embedding activations after weighted - summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is - invalid to have the sum of the weights be 0 for 'mean' or the sum of the - squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default - is to use 'sum' for all tables (optional). - mode_override: A string input that overrides the mode specified in the - TPUEmbeddingConfiguration. Supported values are {'unspecified', - 'inference', 'training', 'backward_pass_only'}. When set to - 'unspecified', the mode set in TPUEmbeddingConfiguration is used, - otherwise mode_override is used (optional). - name: A name for the operation (optional). - - Returns: - An EnqueueTPUEmbeddingSparseBatch operation. 
- """ - if mode_override is None: - mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_batch( - sample_indices=sample_indices, - embedding_indices=embedding_indices, - aggregation_weights=aggregation_weights, - device_ordinal=device_ordinal, - combiners=combiners, - mode_override=mode_override, - name=name) - - enqueue_tpu_embedding_sparse_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_batch.__doc__) - - # pylint: disable=protected-access - def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices, - embedding_indices, - aggregation_weights, - table_ids, - device_ordinal, - combiners=None, - mode_override=None, - name=None): - """A placeholder op for enqueueing embedding IDs to the TPU. - - Args: - sample_indices: A list of rank 1 Tensors specifying the training example - to which the corresponding embedding_indices and aggregation_weights - values - belong. It corresponds to sp_ids.indices[:,0] in - embedding_lookup_sparse(). - embedding_indices: A list of rank 1 Tensors, indices into the embedding - tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). - aggregation_weights: A list of rank 1 Tensors containing per training - example aggregation weights. It corresponds to sp_weights.values in - embedding_lookup_sparse(). - table_ids: A list of integers specifying the identifier of the embedding - table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to - lookup the corresponding input. The ith input is looked up using - table_ids[i]. The size of the table_ids list must be equal to that of - sample_indices, embedding_indices and aggregation_weights. - device_ordinal: The TPU device to use. Should be >= 0 and less than the - number of TPU cores in the task on which the node is placed. - combiners: A list of string scalars, one for each embedding table that - specify how to normalize the embedding activations after weighted - summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. 
It is - invalid to have the sum of the weights be 0 for 'mean' or the sum of the - squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default - is to use 'sum' for all tables (optional). - mode_override: A string input that overrides the mode specified in the - TPUEmbeddingConfiguration. Supported values are {'unspecified', - 'inference', 'training', 'backward_pass_only'}. When set to - 'unspecified', the mode set in TPUEmbeddingConfiguration is used, - otherwise mode_override is used (optional). - name: A name for the operation (optional). - - Returns: - An EnqueueTPUEmbeddingSparseTensorBatch operation. - """ - if mode_override is None: - mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch( - sample_indices=sample_indices, - embedding_indices=embedding_indices, - aggregation_weights=aggregation_weights, - table_ids=table_ids, - device_ordinal=device_ordinal, - combiners=combiners, - mode_override=mode_override, - name=name) - - enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch.__doc__) - -else: - # We have already built the appropriate libraries into the binary via CMake - # if we have built contrib, so we don't need this - pass +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.ops.tpu_ops import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/python/training/mode_keys_test.py b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py similarity index 63% rename from tensorflow/python/training/mode_keys_test.py rename to tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py index c4435b7d4870ac1675a3f2f4d80def111dc85ae5..788e1fe0568cf2f406c379e4d928100ea51a37a3 100644 --- a/tensorflow/python/training/mode_keys_test.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,18 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for `tf.train.ModeKeys.""" +"""Stub file to maintain backwards compatibility.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.platform import test -from tensorflow.python.training import mode_keys - - -class ModeKeysTest(test.TestCase): - - def testKeyEquality(self): - self.assertEqual(mode_keys.ModeKeys.PREDICT, 'predict') - self.assertEqual(mode_keys.ModeKeys.TRAIN, 'train') - self.assertEqual(mode_keys.ModeKeys.TEST, 'test') +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.ops.tpu_ordinal_selector_op import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/profiler/__init__.py b/tensorflow/contrib/tpu/python/profiler/__init__.py index 15ce6aceec299adacd7025f0021cf8b6f6ef765b..aeb061dbe114bc287946b50d08a86778c78c7b38 100644 --- a/tensorflow/contrib/tpu/python/profiler/__init__.py +++ b/tensorflow/contrib/tpu/python/profiler/__init__.py @@ -1,31 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Classes for TPU trace events.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.contrib.tpu.profiler.tpu_profiler_analysis_pb2 import * -from tensorflow.contrib.tpu.profiler.trace_events_pb2 import * +from tensorflow.python.tpu.profiler import * # pylint: enable=wildcard-import,unused-import - -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ['Trace', 'Resource', 'Device', 'TraceEvent'] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/tpu/python/tpu/__init__.py b/tensorflow/contrib/tpu/python/tpu/__init__.py index 0dffd7064b19f353aed6afa3ad383564643a4a90..82d4f68c0221013706f70bcf54ae4c97cc7db1d3 100644 --- a/tensorflow/contrib/tpu/python/tpu/__init__.py +++ b/tensorflow/contrib/tpu/python/tpu/__init__.py @@ -1,20 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Ops related to Tensor Processing Units.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..41aa4d267812cabe775459723df7e01efaa83c93 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py @@ -0,0 +1,23 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Stub file to maintain backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu._tpu_estimator_embedding import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index 1b09ce173a64ba3f93ec019c8fd65dc4710f0fcf..5eb8034e47474873ccef0b6123f2becd0668738c 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -1,212 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the 'License'); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== -"""Hook for asynchronous checkpointing. - -This hook dispatches checkpoint writing operations in a separate thread to -allow execution to continue on the main thread. 
-""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import threading -import time - -from tensorflow.core.util.event_pb2 import SessionLog -from tensorflow.python.framework import meta_graph -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import training_util -from tensorflow.python.training.session_run_hook import SessionRunArgs -from tensorflow.python.training.summary_io import SummaryWriterCache - - -class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): - """Saves checkpoints every N steps or seconds.""" - - def __init__(self, - checkpoint_dir, - save_secs=None, - save_steps=None, - saver=None, - checkpoint_basename="model.ckpt", - scaffold=None, - listeners=None): - """Initializes a `CheckpointSaverHook`. - - Args: - checkpoint_dir: `str`, base directory for the checkpoint files. - save_secs: `int`, save every N secs. - save_steps: `int`, save every N steps. - saver: `Saver` object, used for saving. - checkpoint_basename: `str`, base name for the checkpoint files. - scaffold: `Scaffold`, use to get saver object. - listeners: List of `CheckpointSaverListener` subclass instances. Used for - callbacks that run immediately before or after this hook saves the - checkpoint. - - Raises: - ValueError: One of `save_steps` or `save_secs` should be set. - ValueError: At most one of `saver` or `scaffold` should be set. 
- """ - logging.info("Create AsyncCheckpointSaverHook.") - if saver is not None and scaffold is not None: - raise ValueError("You cannot provide both saver and scaffold.") - self._saver = saver - self._save_thread = None - self._write_graph_thread = None - self._checkpoint_dir = checkpoint_dir - self._save_path = os.path.join(checkpoint_dir, checkpoint_basename) - self._scaffold = scaffold - self._timer = basic_session_run_hooks.SecondOrStepTimer( - every_secs=save_secs, every_steps=save_steps) - self._listeners = listeners or [] - self._steps_per_run = 1 - self._summary_writer = None - self._global_step_tensor = None - - self._last_checkpoint_step = None - - def _set_steps_per_run(self, steps_per_run): - self._steps_per_run = steps_per_run - - def begin(self): - self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir) - self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access - if self._global_step_tensor is None: - raise RuntimeError( - "Global step should be created to use CheckpointSaverHook.") - for l in self._listeners: - l.begin() - - def after_create_session(self, session, coord): - global_step = session.run(self._global_step_tensor) - - # We do write graph and saver_def at the first call of before_run. - # We cannot do this in begin, since we let other hooks to change graph and - # add variables in begin. Graph is finalized after all begin calls. 
- def _write_graph_fn(self): - training_util.write_graph( - ops.get_default_graph().as_graph_def(add_shapes=True), - self._checkpoint_dir, "graph.pbtxt") - self._write_graph_thread = threading.Thread(target=_write_graph_fn, - args=[self]) - self._write_graph_thread.start() - - saver_def = self._get_saver().saver_def if self._get_saver() else None - graph = ops.get_default_graph() - meta_graph_def = meta_graph.create_meta_graph_def( - graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def) - self._summary_writer.add_graph(graph) - self._summary_writer.add_meta_graph(meta_graph_def) - # The checkpoint saved here is the state at step "global_step". - self._save(session, global_step) - self._timer.update_last_triggered_step(global_step) - - def before_run(self, run_context): # pylint: disable=unused-argument - return SessionRunArgs(self._global_step_tensor) - - def after_run(self, run_context, run_values): - global_step = run_context.session.run(self._global_step_tensor) - if self._timer.should_trigger_for_step(global_step): - self._timer.update_last_triggered_step(global_step) - logging.info("Triggering checkpoint. 
%s", global_step) - if self._save(run_context.session, global_step): - run_context.request_stop() - - def end(self, session): - if self._save_thread: - logging.info("Waiting for any pending checkpoints to finish.") - self._save_thread.join() - if self._write_graph_thread: - logging.info("Waiting for any pending write_graph to finish.") - self._write_graph_thread.join() - - last_step = session.run(self._global_step_tensor) - - if self._last_checkpoint_step != last_step: - self._save(session, last_step, asynchronous=False) - - for l in self._listeners: - l.end(session, last_step) - - def _save(self, session, step, asynchronous=True): - """Saves the latest checkpoint, returns should_stop.""" - - # Skip saving on step 0 - if step == 0: - return - - def _save_fn(): - """Run the saver process.""" - logging.info("Saving checkpoints for %d into %s.", step, self._save_path) - - start_time = time.time() - for l in self._listeners: - l.before_save(session, step) - - self._get_saver().save(session, self._save_path, global_step=step) - self._summary_writer.add_session_log( - SessionLog( - status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), - step) - - for l in self._listeners: - l.after_save(session, step) - - end_time = time.time() - logging.info("Checkpoint actual writing time: (%.3f sec)", - end_time - start_time) - logging.info("Checkpoint finished for %d into %s.", step, self._save_path) - - if not asynchronous: - self._last_checkpoint_step = step - _save_fn() - return - - if self._save_thread is not None: - self._save_thread.join(timeout=0.1) - if self._save_thread.is_alive(): - logging.info("Saver thread still in progress, skipping checkpoint.") - return - - self._last_checkpoint_step = step - self._save_thread = threading.Thread(target=_save_fn) - self._save_thread.start() - - def _get_saver(self): - if self._saver is not None: - return self._saver - elif self._scaffold is not None: - return self._scaffold.saver - - # Get saver from the SAVERS collection if 
present. - collection_key = ops.GraphKeys.SAVERS - savers = ops.get_collection(collection_key) - if not savers: - raise RuntimeError( - "No items in collection {}. Please add a saver to the collection " - "or provide a saver or scaffold.".format(collection_key)) - elif len(savers) > 1: - raise RuntimeError( - "More than one item in collection {}. " - "Please indicate which one to use by passing it to the constructor." - .format(collection_key)) - - self._saver = savers[0] - return savers[0] +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.async_checkpoint import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/bfloat16.py b/tensorflow/contrib/tpu/python/tpu/bfloat16.py index fa74f651aa63c72d14eb78c8af479263810e9b7d..f3d392a8daec2a80f974d90051324a02be002afd 100644 --- a/tensorflow/contrib/tpu/python/tpu/bfloat16.py +++ b/tensorflow/contrib/tpu/python/tpu/bfloat16.py @@ -1,77 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= - -"""Helper context for running models with bfloat16.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import tf_contextlib - - -def _get_custom_getter(): - """Returns a custom getter that this class's methods must be called under. - - All methods of this class must be called under a variable scope that was - passed this custom getter. Example: - - ```python - network = ConvNetBuilder(...) - with tf.variable_scope('cg', custom_getter=network.get_custom_getter()): - network.conv(...) - # Call more methods of network here - ``` - - Currently, this custom getter only does anything if self.use_tf_layers is - True. In that case, it causes variables to be stored as dtype - self.variable_type, then casted to the requested dtype, instead of directly - storing the variable as the requested dtype. - """ - - def inner_custom_getter(getter, *args, **kwargs): - """Custom getter that forces variables to have type self.variable_type.""" - cast_to_bfloat16 = False - requested_dtype = kwargs['dtype'] - if requested_dtype == dtypes.bfloat16: - # Only change the variable dtype if doing so does not decrease variable - # precision. - kwargs['dtype'] = dtypes.float32 - cast_to_bfloat16 = True - var = getter(*args, **kwargs) - # This if statement is needed to guard the cast, because batch norm - # assigns directly to the return value of this custom getter. The cast - # makes the return value not a variable so it cannot be assigned. Batch - # norm variables are always in fp32 so this if statement is never - # triggered for them. 
- if cast_to_bfloat16: - var = math_ops.cast(var, dtypes.bfloat16) - return var - - return inner_custom_getter - - -@tf_contextlib.contextmanager -def bfloat16_scope(): - """Scope class for bfloat16 variables so that the model uses custom getter. - - This enables variables to be read as bfloat16 type when using get_variable. - """ - with variable_scope.variable_scope( - '', custom_getter=_get_custom_getter()) as varscope: - yield varscope +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.bfloat16 import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index bc0cd41d210ac6f8de1b20ebf744ee1e1dd04137..c20aac7e36aa31c5a9d88ca6fe02a8703f9ed5a3 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -1,191 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ====================================== -"""Library of Cloud TPU helper functions for data loading.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.experimental.ops import batching -from tensorflow.python.data.experimental.ops import interleave_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.ops import readers -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import functional_ops - - -def _TextLineDataset(filename): - buffer_size = 8 * 1024 * 1024 # 8 MiB per file - dataset = readers.TextLineDataset(filename, buffer_size=buffer_size) - return dataset - - -def _TFRecordDataset(filename): - buffer_size = 8 * 1024 * 1024 # 8 MiB per file - dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size) - return dataset - - -_FILETYPE_MAP = { - 'tfrecord': _TFRecordDataset, - 'textline': _TextLineDataset, - 'text': _TextLineDataset, -} - - -def StreamingFilesDataset(files, - filetype=None, - file_reader_job=None, - worker_job=None, - num_epochs=None, - filename_shuffle_buffer_size=None, - num_parallel_reads=None, - batch_transfer_size=None, - sloppy=None): - """StreamingFilesDataset constructs a dataset to stream from workers (GCE VM). - - Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read - files local to your GCE VM. In order to train using files stored on your local - VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset - helper to generate a dataset to feed your Cloud TPU with files from your GCE - VM. 
- - The resulting dataset may return an OutOfRangeError if there are no files - found as a result of the fileglob expansion. - - Note: StreamingFilesDataset assumes that the session is using a - TPUClusterResolver and has therefore a worker and a coordinator job. File - loading will be done on the coordinator job. - - Args: - files: A string glob to match files, or a `tf.data.Dataset` generating file - names. - filetype: A string (one of 'tfrecord', or 'textline') or a single-argument - TensorFlow function that when given a filename returns a dataset. - file_reader_job: An optional string that corresponds to the job that should - perform the file reads. - worker_job: An optional string that corresponds to the job that should - process the tensors (i.e. your GPU or TPU worker). - num_epochs: The number of epochs through the training set that should be - generated. By default, it will repeat infinitely. - filename_shuffle_buffer_size: An optional integer whose value controls the - shuffling of the file names. If you would like to read from the files in - the same order, set to 0 or False. - num_parallel_reads: An optional integer controlling the number of files to - read from concurrently. (Set to 1 for no parallelism.) - batch_transfer_size: An optional integer controlling the batching used to - amortize the remote function invocation overhead. Set to a very large - number to increase throughput. Set to a very small number to reduce memory - consumption. Set to False to skip batching. - sloppy: (Optional.) If `False`, read input data while maintaining a - deterministic order. (This may have significant performance impacts.) - sloppy defaults to: True. - Returns: - A `tf.data.Dataset` with an infinite stream of elements generated by a - parallel interleaving of the set of files matched (or generated) by `files` - with a type is the output of the dataset specified by `filetype`. - - Raises: - ValueError: if any argument is not of the expected type. 
- """ - if filetype is None: - filetype = 'tfrecord' - - if isinstance(filetype, str): - if filetype not in _FILETYPE_MAP: - raise ValueError('Unexpected filetype: %s' % filetype) - reader_fn = _FILETYPE_MAP[filetype] - elif callable(filetype): - reader_fn = filetype - else: - raise ValueError('filetype should be a string or a callable') - - file_reader_job = file_reader_job or 'coordinator' - - worker_job = worker_job or 'worker' - - if filename_shuffle_buffer_size is None: - filename_shuffle_buffer_size = 4096 - - num_parallel_reads = num_parallel_reads or 8 - - if batch_transfer_size is None: - batch_transfer_size = 256 - - if sloppy is None: - sloppy = True - - with ops.device('/job:%s' % file_reader_job): - if isinstance(files, str): - source_dataset = dataset_ops.Dataset.list_files(files) - elif isinstance(files, dataset_ops.DatasetV2): - source_dataset = files - else: - raise ValueError('files was not a string or a dataset: %s' % files) - - if filename_shuffle_buffer_size: - source_dataset = source_dataset.shuffle( - buffer_size=filename_shuffle_buffer_size) - - source_dataset = source_dataset.apply( - interleave_ops.parallel_interleave( - reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy)) - - source_dataset = source_dataset.repeat(num_epochs) - - if batch_transfer_size: - source_dataset = source_dataset.batch(batch_transfer_size) - - source_dataset = source_dataset.prefetch(1) - - source_iterator = dataset_ops.make_one_shot_iterator(source_dataset) - source_handle = source_iterator.string_handle() - - @function.Defun(dtypes.string) - def LoadingFunc(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, source_dataset.output_types, source_dataset.output_shapes) - return remote_iterator.get_next() - - def MapFn(unused_input): - if isinstance(source_dataset.output_types, dtypes.DType): - output_types = [source_dataset.output_types] - elif isinstance(source_dataset.output_types, (list, tuple)): - output_types = 
source_dataset.output_types - else: - raise ValueError('source dataset has invalid output types') - remote_calls = functional_ops.remote_call( - args=[source_handle], - Tout=output_types, - f=LoadingFunc, - target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job) - if len(remote_calls) == 1: - return remote_calls[0] - else: - return remote_calls - - with ops.device('/job:%s' % worker_job): - output_dataset = dataset_ops.Dataset.range(2).repeat().map( - MapFn, num_parallel_calls=4 if sloppy else None) - output_dataset = output_dataset.prefetch(1) - - if batch_transfer_size: - # Undo the batching used during the transfer. - output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1) - - return output_dataset +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.datasets import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index 3313dc749c2c7606101b2dc96614df2d052dfed1..05dffef3a1efdae2ad7306ca5ad3bc7a9eac04cf 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -1,313 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ====================================== -"""Library of TPU helper functions.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.tpu.python.tpu.topology import Topology - - -SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0]]] - - -def _compute_task_and_cores_to_replicas(core_assignment, topology): - """Computes a nested dict which maps task and logical core to replicas.""" - task_and_cores_to_replicas = {} - for replica in xrange(core_assignment.shape[0]): - for logical_core in xrange(core_assignment.shape[1]): - coordinates = core_assignment[replica, logical_core, :] - task_id = topology.task_ordinal_at_coordinates(coordinates) - if task_id not in task_and_cores_to_replicas: - task_and_cores_to_replicas[task_id] = {} - if logical_core not in task_and_cores_to_replicas[task_id]: - task_and_cores_to_replicas[task_id][logical_core] = set() - - task_and_cores_to_replicas[task_id][logical_core].add(replica) - - task_to_sorted_replica_id = {} - - for task, core_to_replicas in task_and_cores_to_replicas.items(): - core_to_sorted_replicas = {} - for core, replicas in core_to_replicas.items(): - core_to_sorted_replicas[core] = sorted(replicas) - - task_to_sorted_replica_id[task] = core_to_sorted_replicas - return task_to_sorted_replica_id - - -class DeviceAssignment(object): - """Mapping from logical cores in a computation to the physical TPU topology. - - Prefer to use the `device_assignment()` helper to construct a - `DeviceAssignment`; it is easier if less flexible than constructing a - `DeviceAssignment` directly. - """ - - def __init__(self, topology, core_assignment): - """Constructs a `DeviceAssignment` object. 
- - Args: - topology: A `Topology` object that describes the physical TPU topology. - core_assignment: A logical to physical core mapping, represented as a - rank 3 numpy array. See the description of the `core_assignment` - property for more details. - - Raises: - ValueError: If `topology` is not `Topology` object. - ValueError: If `core_assignment` is not a rank 3 numpy array. - """ - if not isinstance(topology, Topology): - raise ValueError("topology must be a Topology object, got {}".format( - type(topology))) - core_assignment = np.asarray(core_assignment, dtype=np.int32) - - self._topology = topology - - if core_assignment.ndim != 3: - raise ValueError("core_assignment must be a rank 3 numpy array, " - "got shape {}".format(core_assignment.shape)) - - self._num_replicas = core_assignment.shape[0] - self._num_cores_per_replica = core_assignment.shape[1] - - if core_assignment.shape[-1] != topology.mesh_rank: - raise ValueError( - "minor dimension of core_assignment must have size equal to topology " - "rank ({}), got shape {}".format(topology.mesh_rank, - core_assignment.shape)) - - self._core_assignment = core_assignment - self._task_and_cores_to_replicas = _compute_task_and_cores_to_replicas( - self._core_assignment, topology) - - @property - def topology(self): - """A `Topology` that describes the TPU topology.""" - return self._topology - - @property - def num_cores_per_replica(self): - """The number of cores per replica.""" - return self._num_cores_per_replica - - @property - def num_replicas(self): - """The number of replicas of the computation.""" - return self._num_replicas - - @property - def core_assignment(self): - """The logical to physical core mapping. - - Returns: - An integer numpy array of rank 3, with shape - `[num_replicas, num_cores_per_replica, topology_rank]`. Maps - (replica, logical core) pairs to physical topology coordinates. 
- """ - return self._core_assignment - - def _coordinates(self, replica, logical_core): - """Returns the physical topology coordinates of a logical core.""" - return tuple(self.core_assignment[replica, logical_core, :]) - - def lookup_replicas(self, task_id, logical_core): - """Lookup replica ids by task number and logical core. - - Args: - task_id: TensorFlow task number. - logical_core: An integer, identifying a logical core. - Returns: - A sorted list of the replicas that are attached to that task and - logical_core. - Raises: - ValueError: If no replica exists in the task which contains the logical - core. - """ - try: - return self._task_and_cores_to_replicas[task_id][logical_core] - except KeyError: - raise ValueError( - "Can not find any replica in task: {} contains logical_core: {} ". - format(task_id, logical_core)) - - def tpu_ordinal(self, replica=0, logical_core=0): - """Returns the ordinal of the TPU device assigned to a logical core.""" - coordinates = self._coordinates(replica, logical_core) - return self._topology.tpu_device_ordinal_at_coordinates(coordinates) - - def host_device(self, replica=0, logical_core=0, job=None): - """Returns the CPU device attached to a logical core.""" - coordinates = self._coordinates(replica, logical_core) - return self._topology.cpu_device_name_at_coordinates(coordinates, job=job) - - def tpu_device(self, replica=0, logical_core=0, job=None): - """Returns the name of the TPU device assigned to a logical core.""" - coordinates = self._coordinates(replica, logical_core) - return self._topology.tpu_device_name_at_coordinates(coordinates, job=job) - - -def device_assignment(topology, - computation_shape=None, - computation_stride=None, - num_replicas=1): - """Computes a device_assignment of a computation across a TPU topology. - - Attempts to choose a compact grid of cores for locality. - - Returns a `DeviceAssignment` that describes the cores in the topology assigned - to each core of each replica. 
- - `computation_shape` and `computation_stride` values should be powers of 2 for - optimal packing. - - Args: - topology: A `Topology` object that describes the TPU cluster topology. - To obtain a TPU topology, evaluate the `Tensor` returned by - `initialize_system` using `Session.run`. Either a serialized - `TopologyProto` or a `Topology` object may be passed. Note: you must - evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here. - computation_shape: A rank 1 int32 numpy array with size equal to the - topology rank, describing the shape of the computation's block of cores. - If None, the `computation_shape` is `[1] * topology_rank`. - computation_stride: A rank 1 int32 numpy array of size `topology_rank`, - describing the inter-core spacing of the `computation_shape` cores in the - TPU topology. If None, the `computation_stride` is `[1] * topology_rank`. - num_replicas: The number of computation replicas to run. The replicas will - be packed into the free spaces of the topology. - - Returns: - A DeviceAssignment object, which describes the mapping between the logical - cores in each computation replica and the physical cores in the TPU - topology. - - Raises: - ValueError: If `topology` is not a valid `Topology` object. - ValueError: If `computation_shape` or `computation_stride` are not 1D int32 - numpy arrays with shape [3] where all values are positive. - ValueError: If computation's replicas cannot fit into the TPU topology. - """ - # Deserialize the Topology proto, if it is a string. 
- if isinstance(topology, bytes): - topology = Topology(serialized=topology) - - if not isinstance(topology, Topology): - raise ValueError("`topology` is not a Topology object; got {}".format( - type(topology))) - - topology_rank = len(topology.mesh_shape) - mesh_shape = topology.mesh_shape - if computation_shape is None: - computation_shape = np.array([1] * topology_rank, dtype=np.int32) - else: - computation_shape = np.asarray(computation_shape, dtype=np.int32) - - if computation_stride is None: - computation_stride = np.array([1] * topology_rank, dtype=np.int32) - else: - computation_stride = np.asarray(computation_stride, dtype=np.int32) - - if computation_shape.shape != (topology_rank,): - raise ValueError("computation_shape must have shape [{}]; got {}".format( - topology_rank, computation_shape.shape)) - if computation_stride.shape != (topology_rank,): - raise ValueError("computation_stride must have shape [{}]; got {}".format( - topology_rank, computation_stride.shape)) - - if any(computation_shape < 1): - raise ValueError( - "computation_shape must be positive; got computation_shape={}".format( - computation_shape)) - if any(computation_stride < 1): - raise ValueError( - "computation_stride must be positive; got computation_stride={}".format( - computation_stride)) - - # Computes the physical size of one computation instance. - computation_footprint = computation_shape * computation_stride - if any(computation_footprint > mesh_shape): - raise ValueError( - "computation footprint {} does not fit in TPU topology shape {}".format( - computation_footprint, mesh_shape)) - - # Computes how many copies of the computation footprint fit in the mesh. 
- block_counts = mesh_shape // computation_footprint - - replica_counts = block_counts * computation_stride - max_replicas = np.prod(replica_counts) - if num_replicas > max_replicas: - raise ValueError( - "requested {} replicas but only {} replicas with shape {} and " - "computation_stride {} fit in a TPU mesh of shape {}".format( - num_replicas, max_replicas, computation_shape, computation_stride, - mesh_shape)) - - def ceil_of_ratio(n, m): - return (n + m - 1) // m - - replica_shape = [0] * topology_rank - if num_replicas > 0: - remaining_replicas = num_replicas - remaining_dims = topology_rank - - # Choose dimensions as close to an equal cube as possible, in order of - # increasing dimension size. By visiting dimensions in increasing size, we - # assign the most constrained dimension first, so we won't make infeasible - # choices. - # - # As a secondary sort order, visit the dimensions in reverse order. This - # means we try to use both cores on the same chip in preference to two cores - # on different chips. - for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))): - i = -ni - target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims))) - replica_shape[i] = min(target_size, x) - remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i]) - remaining_dims -= 1 - - assert remaining_replicas == 1 and remaining_dims == 0 - - # Assigns an offset to each replica such that no two replicas overlap. - replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32) - for replica in xrange(num_replicas): - # Chooses a replica number in each axis. - t = replica - pos = [] - for dim in replica_shape[::-1]: - pos.append(t % dim) - t //= dim - replica_pos = np.array(pos[::-1], dtype=np.int32) - - # Determines where that replica starts in each axis. 
- outer = replica_pos // computation_stride - inner = replica_pos % computation_stride - replica_offsets[replica, :] = outer * computation_footprint + inner - - # Computes a complete logical core -> physical core mapping for each replica. - indices = [ - np.arange(0, computation_shape[i] * computation_stride[i], - computation_stride[i]) for i in xrange(topology_rank) - ] - indices = np.concatenate( - [i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")], - axis=-1) - indices = indices.reshape((-1, topology_rank)) - assignment = indices + replica_offsets[:, np.newaxis, :] - return DeviceAssignment(topology, core_assignment=assignment) +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.device_assignment import * +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/error_handling.py b/tensorflow/contrib/tpu/python/tpu/error_handling.py index 52e1ea42370d653d1de7c12eee4b456ec7ce921c..1b1328b4075d9a737e40693c13e33e0b7c1fbedf 100644 --- a/tensorflow/contrib/tpu/python/tpu/error_handling.py +++ b/tensorflow/contrib/tpu/python/tpu/error_handling.py @@ -1,132 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# =================================================================== -"""ErrorRendezvous handler for collecting errors from multiple threads.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -import sys -import threading -import time - -import six - -from tensorflow.python.framework import errors -from tensorflow.python.platform import tf_logging as logging - -_UNINTERESTING_ERRORS = (errors.CancelledError,) - - -class ErrorRendezvous(object): - """Resolve errors from multiple threads during TPU execution. - - TPU errors can occur on the infeed or outfeed threads as well as the main - training thread. - - Depending on which thread "wins" and receives the session error first, we may - end up showing users a confusing and non-actionable error message (session - cancelled) instead of a root cause (e.g. a bad filename). - - The rendezvous object provides a location to capture these errors until all - threads terminate. At that point we can choose the most informative error - to report. - """ - - def __init__(self, num_sources): - # string -> (message, traceback) - self._errors = {} - self._num_sources = num_sources - self._session_cancel_timer = None - - def record_error(self, source, exc_info, session=None): - """Report an exception from the given source. - - If a session is passed, a timer will be registered to close it after a few - seconds. This is necessary to ensure the main training loop does not hang - if an infeed/oufeed error occurs. We sleep a few seconds to allow a more - interesting error from another thread to propagate. - - Args: - source: string, source of the error - exc_info: Output from `sys.exc_info` (type, value, traceback) - session: Session to close after delay. 
- """ - _, value, _ = exc_info - self._errors[source] = exc_info - logging.info('Error recorded from %s: %s', source, value) - - if session is not None and self._session_cancel_timer is None: - - def _cancel_session(): - time.sleep(5) - try: - session.close() - except: # pylint: disable=bare-except - pass - - self._session_cancel_timer = threading.Thread(target=_cancel_session,) - self._session_cancel_timer.daemon = True - self._session_cancel_timer.start() - - def record_done(self, source): - """Mark execution source `source` as done. - - If an error was originally reported from `source` it is left intact. - - Args: - source: `str`, source being recorded - """ - logging.info('%s marked as finished', source) - if source not in self._errors: - self._errors[source] = None - - @contextlib.contextmanager - def catch_errors(self, source, session=None): - """Context manager to report any errors within a block.""" - try: - yield - except Exception: # pylint: disable=broad-except - self.record_error(source, sys.exc_info(), session) - - def raise_errors(self, timeout_sec=0): - """Wait for up to `timeout` seconds for all error sources to finish. - - Preferentially raise "interesting" errors (errors not in the - _UNINTERESTING_ERRORS) set. - - Args: - timeout_sec: Seconds to wait for other error sources. - """ - for _ in range(timeout_sec): - if len(self._errors) == self._num_sources: - break - time.sleep(1) - - kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None] - - # First check for any interesting errors, then fall back on the session - # cancelled errors etc. 
- for k, (typ, value, traceback) in kept_errors: - if isinstance(value, _UNINTERESTING_ERRORS): - continue - else: - logging.warn('Reraising captured error') - six.reraise(typ, value, traceback) - - for k, (typ, value, traceback) in kept_errors: - logging.warn('Reraising captured error') - six.reraise(typ, value, traceback) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.error_handling import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column.py b/tensorflow/contrib/tpu/python/tpu/feature_column.py index 8edf131bc24fd003806263570b63ee8514c49896..ded75e975b10c4265370af260bf804687c9caebc 100644 --- a/tensorflow/contrib/tpu/python/tpu/feature_column.py +++ b/tensorflow/contrib/tpu/python/tpu/feature_column.py @@ -1,429 +1,30 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# =================================================================== -"""TPU Feature Column Library.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.feature_column import feature_column as fc -from tensorflow.python.feature_column import feature_column_lib as fc_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import variable_scope -# pylint: disable=protected-access - - -_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope' -_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn, - fc._VocabularyFileCategoricalColumn, - fc._VocabularyListCategoricalColumn, - fc._WeightedCategoricalColumn, - fc_lib.IdentityCategoricalColumn, - fc_lib.VocabularyFileCategoricalColumn, - fc_lib.VocabularyListCategoricalColumn, - fc_lib.WeightedCategoricalColumn) - - -def embedding_column(categorical_column, - dimension, - combiner='mean', - initializer=None): - """TPU embedding_column for `tf.feature_column.embedding_column`. - - Note that the interface for TPU embedding_column is different from the non-TPU - version. The following args available for the non-TPU version are NOT - supported: ckpt_to_load_from, tensor_name_in_ckp, max_norm and trainable. - - Args: - categorical_column: A categorical_column returned from - categorical_column_with_identity, weighted_categorical_column, - categorical_column_with_vocabulary_list or - categorical_column_with_vocabulary_file. - dimension: An integer specifying dimension of the embedding, must be > 0. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. 
For more information, see - `tf.feature_column.embedding_column`. - initializer: A variable initializer function to be used in embedding - variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean `0.0` and standard deviation - `1/sqrt(dimension)`. - - Returns: - A _TPUEmbeddingColumn. - - Raises: - ValueError: if `dimension` not > 0. - ValueError: if `initializer` is specified but not callable. - """ - if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): - raise TypeError( - 'categorical_column for tpu ' - ' embedding_column must be type %s, got %s.' % (' or '.join([ - cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS - ]), type(categorical_column))) - if (dimension is None) or (dimension < 1): - raise ValueError('Invalid dimension {}.'.format(dimension)) - - if (initializer is not None) and (not callable(initializer)): - raise ValueError('initializer must be callable if specified. ' - 'Embedding of column_name: {}'.format( - categorical_column.name)) - if initializer is None: - initializer = init_ops.truncated_normal_initializer( - mean=0.0, stddev=1 / math.sqrt(dimension)) - - embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access - - def _creator(weight_collections, scope): - embedding_column_layer = fc._EmbeddingColumnLayer( - embedding_shape=embedding_shape, - initializer=initializer, - weight_collections=weight_collections, - trainable=True, - name='embedding_column_layer') - return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable - - column = _TPUEmbeddingColumn( - categorical_column=categorical_column, - dimension=dimension, - combiner=combiner, - layer_creator=_creator, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True) - # For Embedding column, the initializer is hidden inside the creator Fn, which - # is not accessiable later. So, we attach it to a speicial field. 
Also note - # that non-TPU Embedding column and non-TPU shared Embedding column handle the - # initializer differently. See shared_embedding_columns for details. - column._tpu_initializer = initializer - return column - - -def shared_embedding_columns(categorical_columns, - dimension, - combiner='mean', - initializer=None, - shared_embedding_collection_name=None): - """List of dense columns that convert from sparse, categorical input.""" - for categorical_column in categorical_columns: - if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): - raise TypeError( - 'categorical_column for tpu ' - ' shared_embedding_columns must be type %s, got %s.' % (' or '.join([ - cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS - ]), type(categorical_column))) - columns = fc_lib.shared_embedding_columns( - categorical_columns, - dimension, - combiner=combiner, - initializer=initializer, - shared_embedding_collection_name=shared_embedding_collection_name, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True) - - # Use the initializer and shared_embedding_collection_name to create TPU - # version - initializer = columns[0].initializer - shared_embedding_collection_name = columns[0].shared_embedding_collection_name - tpu_columns = [] - - # Create the state (_SharedEmbeddingColumnLayer) here. 
- for categorical_column in categorical_columns: - column = _TPUSharedEmbeddingColumn( - categorical_column=categorical_column, - dimension=dimension, - combiner=combiner, - initializer=initializer, - shared_embedding_collection_name=shared_embedding_collection_name, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True) - tpu_columns.append(column) - - return tpu_columns - - -class _TPUBaseEmbeddingColumn(object): - """Base class for TPU Embedding Column.""" - - def __init__(self, categorical_column): - self._tpu_categorical_column = categorical_column - - def get_combiner(self): - """Returns the embedding combiner.""" - raise NotImplementedError('not implemented') - - def get_embedding_table_size(self): - """Returns the embedding table size, tuple of vocab size and dimension.""" - raise NotImplementedError('not implemented') - - def get_feature_key_name(self): - """Returns the feature key name in the features dict.""" - raise NotImplementedError('not impl') - - def get_weight_key_name(self): - """Return the key name for weights.""" - raise NotImplementedError('not impl') - - def get_embedding_var_name(self): - """Returns the embedding variable name. - - Feature key name and embedding variable name are usually one-to-one mapping. - But for shared embedding columns, it is many-to-one mapping. 
- """ - raise NotImplementedError('not impl') - - def get_initializer(self): - """Returns the initializer.""" - raise NotImplementedError('not impl') - - def is_categorical_column_weighted(self): - """Check if the categorical column of the embedding column is weighted.""" - raise NotImplementedError('not impl') - - -class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn): - """Core Embedding Column.""" - - def __new__(cls, - categorical_column, - dimension, - combiner='mean', - layer_creator=None, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True): - # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable - # are not supported on TPU. They are solely for matching the signature of - # __new__ of parent class fc._EmbeddingColumn. - return fc._EmbeddingColumn.__new__( - cls, - categorical_column, - dimension, - combiner=combiner, - layer_creator=layer_creator, - ckpt_to_load_from=ckpt_to_load_from, - tensor_name_in_ckpt=tensor_name_in_ckpt, - max_norm=max_norm, - trainable=trainable) - - def __init__(self, - categorical_column, - dimension, - combiner='mean', - layer_creator=None, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True): - _TPUBaseEmbeddingColumn.__init__(self, categorical_column) - self._key = None - - def get_combiner(self): - return self.combiner - - def get_embedding_table_size(self): - """Returns num_ids and width.""" - return (self.categorical_column._num_buckets, self.dimension) - - def get_feature_key_name(self): - """get_feature_key_name.""" - if self.is_categorical_column_weighted(): - return self.categorical_column.categorical_column.name - return self.categorical_column.name - - def get_weight_key_name(self): - """get_weight_key_name.""" - if self.is_categorical_column_weighted(): - return self.categorical_column.weight_feature_key - return None - - def get_embedding_var_name(self): - """get_embedding_var_name.""" - return 
self.categorical_column.name - - def get_initializer(self): - return self._tpu_initializer - - def is_categorical_column_weighted(self): - """Check if the categorical column of the embedding column is weighted.""" - if isinstance( - self.categorical_column, - ( - fc._WeightedCategoricalColumn, # pylint: disable=protected-access - fc_lib.WeightedCategoricalColumn)): - return True - return False - - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): - if tpu.under_tpu_inference_context(): - def host_computation(): - return fc._EmbeddingColumn._get_dense_tensor( - self, inputs, weight_collections, trainable) - return tpu.outside_compilation(host_computation) - - if _is_running_on_cpu(): - return fc._EmbeddingColumn._get_dense_tensor( - self, inputs, weight_collections, trainable) - - # TPU mode - # Get the embeddings from the LazyBuilder. - tensor = inputs.get(self.get_feature_key_name()) - - # Add to collection for _create_tpu_embedding_variables_and_ops - _record_variable_scope_and_name(self.get_embedding_var_name(), - 'embedding_weights') - - return tensor - - -class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn, - fc._SharedEmbeddingColumn): - """Core Shared Embedding Column.""" - - def __new__(cls, - categorical_column, - dimension, - combiner='mean', - initializer=None, - shared_embedding_collection_name=None, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True): - return fc._SharedEmbeddingColumn.__new__( - cls, - categorical_column, - dimension, - combiner=combiner, - initializer=initializer, - shared_embedding_collection_name=shared_embedding_collection_name, - ckpt_to_load_from=ckpt_to_load_from, - tensor_name_in_ckpt=tensor_name_in_ckpt, - max_norm=max_norm, - trainable=trainable) - - def __init__(self, - categorical_column, - dimension, - combiner='mean', - initializer=None, - shared_embedding_collection_name=None, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - 
trainable=True): - - _TPUBaseEmbeddingColumn.__init__(self, categorical_column) - self._key = None - - def get_combiner(self): - return self.combiner - - def get_embedding_table_size(self): - """Returns num_ids and width.""" - return (self.categorical_column._num_buckets, self.dimension) - - def get_feature_key_name(self): - """get_feature_key_name.""" - if self.is_categorical_column_weighted(): - return self.categorical_column.categorical_column.name - return self.categorical_column.name - - def get_weight_key_name(self): - """get_weight_key_name.""" - if self.is_categorical_column_weighted(): - return self.categorical_column.weight_feature_key - return None - - def get_embedding_var_name(self): - """get_embedding_var_name.""" - return self.shared_embedding_collection_name - - def get_initializer(self): - return self.initializer - - def is_categorical_column_weighted(self): - """Check if the categorical column of the embedding column is weighted.""" - if isinstance( - self.categorical_column, - ( - fc._WeightedCategoricalColumn, # pylint: disable=protected-access - fc_lib.WeightedCategoricalColumn)): - return True - return False - - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): - if tpu.under_tpu_inference_context(): - def host_computation(): - return fc._SharedEmbeddingColumn._get_dense_tensor( - self, inputs, weight_collections, trainable) - return tpu.outside_compilation(host_computation) - - if _is_running_on_cpu(): - return fc._SharedEmbeddingColumn._get_dense_tensor( - self, inputs, weight_collections, trainable) - - # TPU mode - # Get the embeddings from the LazyBuilder. 
- tensor = inputs.get(self.get_feature_key_name()) - - # Add to collection for _create_tpu_embedding_variables_and_ops - _record_variable_scope_and_name( - self.get_embedding_var_name(), - 'embedding_weights', - is_shared_embedding=True) - return tensor - - -def _record_variable_scope_and_name(embedding_var_name, - embedding_var_name_in_fc, - is_shared_embedding=False): - """Add embedding variable name and scope to collection.""" - g = ops.get_default_graph() - collection = g.get_collection_ref(_TPU_FC_TO_SCOPE) - if not collection: - collection.append({}) - - var_def_dict = collection[0] - - captured_scope = None - - if is_shared_embedding and (embedding_var_name in var_def_dict): - if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc: - raise ValueError( - 'For embedding var name {}, the shared embedding name is different, ' - 'got {}; expected {}'.format(embedding_var_name, - embedding_var_name_in_fc, - var_def_dict[embedding_var_name][1])) - else: - # scope contains var_scope_name. 
- captured_scope = variable_scope.get_variable_scope() - var_def_dict[embedding_var_name] = (captured_scope, - embedding_var_name_in_fc) - - -def _is_running_on_cpu(): - """Returns True if the current context is CPU model.""" - return tpu_function.get_tpu_context().number_of_shards is None +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.feature_column import * +# used by tests +from tensorflow.python.tpu.feature_column import _is_running_on_cpu +from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name +from tensorflow.python.tpu.feature_column import _TPU_FC_TO_SCOPE +from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn +from tensorflow.python.tpu.feature_column import _TPUEmbeddingColumn +from tensorflow.python.tpu.feature_column import _TPUSharedEmbeddingColumn +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/functional.py b/tensorflow/contrib/tpu/python/tpu/functional.py index 1ec9b5b33d007eb2eaa557438f32ea69053261c6..9a5759221ed9660200cc213df69961db56f8d490 100644 --- a/tensorflow/contrib/tpu/python/tpu/functional.py +++ b/tensorflow/contrib/tpu/python/tpu/functional.py @@ -1,25 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= -"""Functional operations.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.tpu.python.tpu import gen_functional_ops - - -TPUPartitionedCall = gen_functional_ops._tpu_partitioned_call # pylint: disable=invalid-name,protected-access - +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.functional import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 37fe9af8c4b154a2e20a957f6ca5d97df3d413be..14a484b2c46ecf4231adbfdfda3b575edb7ef4a1 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -55,8 +55,6 @@ import numpy as np import six from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib -from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables from tensorflow.contrib.tpu.python.tpu import tpu @@ -64,6 +62,8 @@ from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import tpu_optimizer from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result +from tensorflow.python import tf2 from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops 
import dataset_ops from tensorflow.python.data.ops import iterator_ops @@ -94,6 +94,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # TODO(b/114775106): temporary shim to optionally initialize the TPU @@ -200,13 +201,22 @@ class TPUDistributionStrategy(object): removed in future once the model replication functionality is mature enough. If `False` (default behavior), the system automatically finds the best configuration, in terms of number of TPU cores, for the model - replication, typically using all avaiable TPU cores. If overwrites as + replication, typically using all available TPU cores. If overwrites as `True`, force the model replication using single core, i.e., no replication. Raises: Exception: No TPU Found on the given worker. """ - + if tf2.enabled(): + raise RuntimeError( + 'Keras support is now deprecated in support of TPU Strategy. ' + 'Please follow the distribution strategy guide on tensorflow.org ' + 'to migrate to the 2.0 supported version.') + else: + logging.warning( + 'Keras support is now deprecated in support of TPU Strategy. 
' + 'Please follow the distribution strategy guide on tensorflow.org ' + 'to migrate to the 2.0 supported version.') if tpu_cluster_resolver is None: tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('') @@ -298,10 +308,11 @@ def _cross_replica_concat(tensor, core_id, num_cores, name): '{}.'.format(input_dtype, name)) batch_size = tensor.shape[0] - mask = math_ops.to_float( - math_ops.equal(np.arange(num_cores, dtype=np.int32), core_id)) + mask = math_ops.cast( + math_ops.equal(np.arange(num_cores, dtype=np.int32), core_id), + dtypes.float32) mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims) - result = mask * math_ops.to_float(tensor) + result = mask * math_ops.cast(tensor, dtypes.float32) local_tensor_with_holes = array_ops.reshape(result, [-1] + result.shape.as_list()[2:]) concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes) @@ -725,9 +736,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager): self._dataset = dataset self._tpu_assignment = tpu_assignment - dummy_x_shape = dataset.output_shapes[0].as_list() + dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset) + dummy_x_shape = dataset_output_shapes[0].as_list() dummy_x_shape[0] *= tpu_assignment.num_towers - dummy_y_shape = dataset.output_shapes[1].as_list() + dummy_y_shape = dataset_output_shapes[1].as_list() dummy_y_shape[0] *= tpu_assignment.num_towers self._iterator = dataset_ops.make_initializable_iterator(dataset) K.get_session().run(self._iterator.initializer) @@ -743,23 +755,26 @@ class TPUDatasetInfeedManager(TPUInfeedManager): # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. 
+ dataset_output_types = dataset_ops.get_legacy_output_types(dataset) self._dummy_x = np.zeros( - dummy_x_shape, dtype=dataset.output_types[0].as_numpy_dtype) + dummy_x_shape, dtype=dataset_output_types[0].as_numpy_dtype) self._dummy_y = np.zeros( - dummy_y_shape, dtype=dataset.output_types[1].as_numpy_dtype) + dummy_y_shape, dtype=dataset_output_types[1].as_numpy_dtype) input_specs = [] - if isinstance(self._iterator.output_shapes, tuple): - assert isinstance(self._iterator.output_types, tuple) - assert len(self._iterator.output_shapes) == len( - self._iterator.output_types) - for i in range(len(self._iterator.output_shapes)): - spec = tensor_spec.TensorSpec(self._iterator.output_shapes[i], - self._iterator.output_types[i]) + iterator_output_shapes = dataset_ops.get_legacy_output_shapes( + self._iterator) + iterator_output_types = dataset_ops.get_legacy_output_types(self._iterator) + if isinstance(iterator_output_shapes, tuple): + assert isinstance(iterator_output_types, tuple) + assert len(iterator_output_shapes) == len(iterator_output_types) + for i in range(len(iterator_output_shapes)): + spec = tensor_spec.TensorSpec(iterator_output_shapes[i], + iterator_output_types[i]) input_specs.append(spec) - elif isinstance(self._iterator.output_shapes, tensor_shape.TensorShape): - spec = tensor_spec.TensorSpec(self._iterator.output_shapes, - self._iterator.output_types) + elif isinstance(iterator_output_shapes, tensor_shape.TensorShape): + spec = tensor_spec.TensorSpec(iterator_output_shapes, + iterator_output_types) input_specs.append(spec) # Pre-process the inputs and get_next_ops before caching. 
@@ -770,24 +785,26 @@ class TPUDatasetInfeedManager(TPUInfeedManager): def _verify_dataset_shape(self, dataset): """Verifies a dataset is of an appropriate shape for TPUs.""" + dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset) + dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset) if not isinstance(dataset, dataset_ops.DatasetV2): raise ValueError('The function passed as the `x` parameter did not ' 'return a `tf.data.Dataset`.') - if not isinstance(dataset.output_classes, tuple): + if not isinstance(dataset_output_classes, tuple): raise ValueError('The dataset must return a tuple of tf.Tensors, ' - 'instead it returns: %s' % dataset.output_classes) - if len(dataset.output_classes) != 2: + 'instead it returns: %s' % dataset_output_classes) + if len(dataset_output_classes) != 2: raise ValueError('The dataset must return a 2-element tuple, got ' - '%s output classes instead.' % (dataset.output_classes,)) - for i, cls in enumerate(dataset.output_classes): + '%s output classes instead.' % (dataset_output_classes,)) + for i, cls in enumerate(dataset_output_classes): if cls != ops.Tensor: raise ValueError('The dataset returned a non-Tensor type (%s) at ' 'index %d.' % (cls, i)) - for i, shape in enumerate(dataset.output_shapes): + for i, shape in enumerate(dataset_output_shapes): if not shape: raise ValueError('The dataset returns a scalar tensor in ' 'tuple index %d. Did you forget to batch? ' - '(Output shapes: %s).' % (i, dataset.output_shapes)) + '(Output shapes: %s).' % (i, dataset_output_shapes)) for j, dim in enumerate(shape): if dim.value is None: if j == 0: @@ -800,7 +817,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): 'currently requires static shapes. The provided ' 'dataset only has a partially defined shape. 
' '(Dimension %d of output tensor %d is not statically known ' - 'for output shapes: %s.%s)' % (j, i, dataset.output_shapes, hint)) + 'for output shapes: %s.%s)' % (j, i, dataset_output_shapes, hint)) @property def dummy_x(self): @@ -1028,29 +1045,29 @@ class TPUFunction(object): # the Momentum optimizer) when _make_train_function is invoked. with keras_tpu_variables.replicated_variable_for_optimizer( self._tpu_assignment.num_towers): - self._cloned_model._make_fit_function() + self._cloned_model._make_train_function() else: - self._cloned_model._make_fit_function() + self._cloned_model._make_train_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in self._cloned_model._fit_function.outputs + for tensor in self._cloned_model.train_function.outputs ] return [ - self._cloned_model._fit_function.updates_op, + self._cloned_model.train_function.updates_op, tpu_ops.outfeed_enqueue_tuple( - self._cloned_model._fit_function.outputs, + self._cloned_model.train_function.outputs, name='outfeed-enqueue-train') ] elif is_test: - self._cloned_model._make_eval_function() + self._cloned_model._make_test_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in self._cloned_model._eval_function.outputs + for tensor in self._cloned_model.test_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - self._cloned_model._eval_function.outputs, + self._cloned_model.test_function.outputs, name='outfeed-enqueue-test') ] elif is_predict: @@ -1367,7 +1384,16 @@ class KerasTPUModel(models.Model): outputs=cpu_model.outputs, name=cpu_model.name, ) - + if tf2.enabled(): + raise RuntimeError( + 'Keras support is now deprecated in support of TPU Strategy. ' + 'Please follow the distribution strategy guide on tensorflow.org ' + 'to migrate to the 2.0 supported version.') + else: + logging.warning( + 'Keras support is now deprecated in support of TPU Strategy. 
' + 'Please follow the distribution strategy guide on tensorflow.org ' + 'to migrate to the 2.0 supported version.') # Create a mapping from numpy arrays to infeed managers. # Note: uses a list of tuples instead of a map because numpy arrays are # not hashable. @@ -1380,8 +1406,6 @@ class KerasTPUModel(models.Model): self.predict_function = None self.test_function = None self.train_function = None - self._fit_function = None - self._eval_function = None self._stateful_metric_functions = [] cluster_resolver = strategy._tpu_cluster_resolver @@ -2022,21 +2046,6 @@ class KerasTPUModel(models.Model): self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) return self.test_function - def _make_fit_function(self): - if not self._fit_function: - self._fit_function = TPUFunction( - self, - model_fn_lib.ModeKeys.TRAIN, - tpu_assignment=self._tpu_assignment) - - return self._fit_function - - def _make_eval_function(self): - if not self._eval_function: - self._eval_function = TPUFunction( - self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) - return self._eval_function - def _make_predict_function(self): if not self.predict_function: self.predict_function = TPUFunction( @@ -2172,7 +2181,10 @@ Output shape: %(output_shape)s # pylint: enable=bad-continuation -@experimental +@deprecated( + '2019-02-20', 'Switch to tf.contrib.distribute.TPUStrategy. ' + 'https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy' +) def tpu_model(model, strategy=None): """Copy `model` along with weights to the TPU. diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index f5735cecc38b7033f21fc4d4105cfead233379fa..ed8f9525c9b91208d39805654b01837abdbf3a77 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -1,437 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the 'License'); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== -"""Operations for handling session logging and shutdown notifications.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - -import time -from google.protobuf import text_format - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.util import event_pb2 -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import training_util - -_WATCHDOG = None - - -class CoordinatorShutdownException(Exception): - """Raised when the coordinator needs to shutdown.""" - pass - - -def _clone_session(session, graph=None): - return 
session_lib.Session( - target=session.sess_str, - config=session._config, # pylint: disable=protected-access - graph=graph if graph else session.graph) - - -def _make_heartbeat_op(session, device, request_ph): - """Return a heartbeat op or None if heartbeats are not supported by device.""" - try: - # Test if we can connect in a isolated graph + session - with ops.Graph().as_default(): - with _clone_session(session) as temp_session: - with ops.device(device): - heartbeat_op = tpu_ops.worker_heartbeat('') - options = config_pb2.RunOptions(timeout_in_ms=5000) - temp_session.run(heartbeat_op, options=options) - except errors.InvalidArgumentError as _: - logging.warning('Error running heartbeat on %s', device) - return None - except errors.DeadlineExceededError as _: - logging.warning('Timeout connecting to %s when testing heartbeat', device) - return None - - # If we successfully connected and pinged the worker, go ahead and construct - # the operation. - with ops.device(device): - return tpu_ops.worker_heartbeat(request_ph) - - -class WorkerHeartbeatManager(object): - """Manages the status/heartbeat monitor for a set of workers.""" - - def __init__(self, session, devices, heartbeat_ops, request_placeholder): - """Construct a new WorkerHeartbeatManager. - - (Prefer using `WorkerHeartbeatManager.from_devices` when possible.) - - Args: - session: `tf.Session`, session to use for heartbeat operations. - devices: `list[string]` Set of devices to connect to. - heartbeat_ops: `list[tf.Operation]` Heartbeat operations. - request_placeholder: `tf.Placeholder[String]` Placeholder used to specify - the WorkerHeartbeatRequest protocol buffer. 
- """ - self._session = session - self._devices = devices - self._ops = heartbeat_ops - self._request_placeholder = request_placeholder - - @staticmethod - def from_devices(session, devices): - """Construct a heartbeat manager for the given devices.""" - if not devices: - logging.error('Trying to create heartbeat manager with no devices?') - - logging.info('Creating heartbeat manager for %s', devices) - request_placeholder = array_ops.placeholder( - name='worker_heartbeat_request', dtype=dtypes.string) - - heartbeat_ops = [] - kept_devices = [] - for device in devices: - heartbeat_op = _make_heartbeat_op(session, device, request_placeholder) - if heartbeat_op is not None: - kept_devices.append(device) - heartbeat_ops.append(heartbeat_op) - else: - logging.warning('Heartbeat support not available for %s', device) - - return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops, - request_placeholder) - - def num_workers(self): - return len(self._devices) - - def configure(self, message): - """Configure heartbeat manager for all devices. 
- - Args: - message: `event_pb2.WorkerHeartbeatRequest` - Returns: `None` - """ - logging.info('Configuring worker heartbeat: %s', - text_format.MessageToString(message)) - self._session.run(self._ops, - {self._request_placeholder: message.SerializeToString()}) - - def ping(self, request=None, timeout_in_ms=5000): - """Ping all workers, returning the parsed status results.""" - if request is None: - request = event_pb2.WorkerHeartbeatRequest() - - options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms) - results = self._session.run( - self._ops, - feed_dict={self._request_placeholder: request.SerializeToString()}, - options=options) - parsed_results = [ - event_pb2.WorkerHeartbeatResponse.FromString(res_pb) - for res_pb in results - ] - logging.debug('Ping results: %s', parsed_results) - return parsed_results - - def lame_workers(self): - """Ping all workers, returning manager containing lame workers (or None).""" - ping_results = self.ping() - lame_workers = [] - - for ping_response, device, op in zip(ping_results, self._devices, - self._ops): - if ping_response.health_status != event_pb2.OK: - lame_workers.append((device, op)) - - if not lame_workers: - return None - - bad_devices, bad_ops = zip(*lame_workers) - return WorkerHeartbeatManager(self._session, bad_devices, bad_ops, - self._request_placeholder) - - def __repr__(self): - return 'HeartbeatManager(%s)' % ','.join(self._devices) - - def shutdown(self, timeout_ms=10000): - """Shutdown all workers after `shutdown_timeout_secs`.""" - logging.info('Shutting down %s.', self) - req = event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms)) - self.configure(req) - - # Wait for workers to shutdown. This isn't strictly required - # but it avoids triggering multiple checkpoints with the same lame worker. 
- logging.info('Waiting %dms for worker shutdown.', timeout_ms) - time.sleep(timeout_ms / 1000) - - -def all_worker_devices(session): - """Return a list of devices for each worker in the system.""" - devices = session.list_devices() - return [ - device.name - for device in devices - if ':CPU:' in device.name and 'coordinator' not in device.name - ] - - -class WatchdogManager(threading.Thread): - """Configures worker watchdog timer and handles periodic pings. - - Usage: - # Ping workers every minute, shutting down workers if they haven't received - # a ping after 1 hour. - watchdog_manager = WatchdogManager( - ping_interval=60, shutdown_timeout=3600 - ) - - # Use as a context manager, resetting watchdog on context exit: - with watchdog_manager: - session.run(...) - - # Or setup globally; watchdog will remain active until program exit. - watchdog_manager.configure_and_run() - """ - - def __init__(self, - session, - devices=None, - ping_interval=60, - shutdown_timeout=3600): - """Initialize a watchdog manager. - - Args: - session: Session connected to worker devices. A cloned session and graph - will be created for managing worker pings. - devices: Set of devices to monitor. If none, all workers will be - monitored. - ping_interval: Time, in seconds, between watchdog pings. - shutdown_timeout: Time, in seconds, before watchdog timeout. 
- """ - threading.Thread.__init__(self) - self.ping_interval = ping_interval - self.shutdown_timeout = shutdown_timeout - self.daemon = True - self._config = session._config # pylint: disable=protected-access - self._target = session.sess_str - self._running = False - self._devices = devices - - self._graph = None - self._session = None - self._worker_manager = None - - def _reset_manager(self): - """Reset the graph, session and worker manager.""" - self._graph = ops.Graph() - self._session = session_lib.Session( - target=self._target, - graph=self._graph, - config=self._config, - ) - - if self._devices is None: - self._devices = all_worker_devices(self._session) - - with self._graph.as_default(): - self._worker_manager = WorkerHeartbeatManager.from_devices( - self._session, self._devices) - - self._worker_manager.configure( - event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig( - timeout_ms=self.shutdown_timeout * 1000,), - shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) - - def configure_and_run(self): - logging.info( - 'Enabling watchdog timer with %d second timeout ' - 'and %d second ping interval.', self.shutdown_timeout, - self.ping_interval) - self._reset_manager() - self._running = True - self.start() - - def stop(self): - logging.info('Stopping worker watchdog.') - self._worker_manager.configure( - event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,), - shutdown_mode=event_pb2.NOT_CONFIGURED)) - self._running = False - self.join() - - def __enter__(self): - self.configure_and_run() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - - def run(self): - # Don't fetch logs or adjust timing: just ping the watchdog. - # - # If we hit an exception, reset our session as it is likely broken. 
- while self._running: - try: - self._worker_manager.ping(request=None) - time.sleep(self.ping_interval) - except errors.OpError as e: - # Catch any TF errors that occur so we don't stop sending heartbeats - logging.debug('Caught error while sending heartbeat: %s', e) - self._reset_manager() - - -def start_worker_watchdog(session, - devices=None, - ping_interval=60, - shutdown_timeout=3600): - """Start global worker watchdog to shutdown workers on coordinator exit.""" - global _WATCHDOG - if _WATCHDOG is None: - # Ensure we can send a few pings before we timeout! - ping_interval = min(shutdown_timeout / 10., ping_interval) - _WATCHDOG = WatchdogManager(session, devices, ping_interval, - shutdown_timeout) - _WATCHDOG.configure_and_run() - - -class GracefulShutdownHook(session_run_hook.SessionRunHook): - """Session hook that watches for shutdown events. - - If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a - SystemShutdown exception is raised to terminate the main session. If `saver` - is None the `SAVERS` collection will be read to find a saver. - - `on_shutdown_hooks` is an optional list of functions that should be called - after checkpointing. The function is called with (`run_context`, - `all_workers`, `lame_workers`). - - If `heartbeat_group` is not specified, it will default to all CPU workers - in the system. - """ - - def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None): - self._saver = saver - self._checkpoint_prefix = checkpoint_prefix - self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else [] - - # Worker heartbeats are managed independently of the main training graph. - self._graph = ops.Graph() - self._workers = None - self._session = None - self._heartbeat_supported = False - - def after_create_session(self, training_session, coord): # pylint: disable=unused-argument - # N.B. 
We have to pull the global step here to avoid it being unavailable - # at checkpoint time; the graph has been frozen at that point. - if training_util.get_global_step() is None and self.saver() is not None: - raise ValueError( - 'Saver defined but no global step. Run `get_or_create_global_step()`' - ' in your model definition to allow checkpointing.') - - with self._graph.as_default(): - logging.info('Installing graceful shutdown hook.') - self._session = _clone_session(training_session, self._graph) - self._workers = WorkerHeartbeatManager.from_devices( - self._session, all_worker_devices(self._session)) - self._heartbeat_supported = self._workers.num_workers() > 0 - if self._heartbeat_supported: - self._workers.configure( - event_pb2.WorkerHeartbeatRequest( - shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) - else: - logging.warn( - 'No workers support hearbeats. Failure handling will be disabled.') - - def saver(self): - if self._saver: - return self._saver - - savers = ops.get_collection(ops.GraphKeys.SAVERS) - if not savers: - return None - - if not isinstance(savers, list): - return savers - - if len(savers) > 1: - logging.error( - 'Multiple savers in the SAVERS collection. On-demand checkpointing ' - 'will be disabled. 
Pass an explicit `saver` to the constructor to ' - 'override this behavior.') - return None - - return savers[0] - - def after_run(self, run_context, run_values): - del run_values - - if not self._heartbeat_supported: - return - - lame_workers = self._workers.lame_workers() - if lame_workers: - logging.info('ShutdownHook: lame workers found: %s', lame_workers) - - if self.saver(): - logging.info('ShutdownHook: saving checkpoint to %s', - self._checkpoint_prefix) - self.saver().save( - run_context.session, - self._checkpoint_prefix, - global_step=training_util.get_global_step(), - write_state=True, - ) - else: - logging.info('ShutdownHook: no Saver defined.') - - for fn in self._on_shutdown_hooks: - fn(run_context, self._workers, lame_workers) - - -class RestartComputation(object): - """Restart the entire computation. - - This hook shuts down all workers and returns control to the top-level by - throwing a CoordinatorShutdownException. - """ - - def __init__(self, timeout_ms=10000): - self.timeout_ms = timeout_ms - - def __call__(self, run_context, all_workers, lame_workers): - del run_context, lame_workers - all_workers.shutdown(timeout_ms=self.timeout_ms) - - logging.info('Terminating coordinator.') - raise CoordinatorShutdownException() - - -class ShutdownLameWorkers(object): - """Shutdown lamed workers. - - Processing will continue normally (typically by waiting for the down - workers to be restarted). 
- """ - - def __init__(self, timeout_ms=10000): - self.timeout_in_ms = timeout_ms - - def __call__(self, run_context, all_workers, lame_workers): - lame_workers.shutdown(timeout_ms=self.timeout_in_ms) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.session_support import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index bf492e78a15acc92017663a286e8c8f0b2045339..73db253fd790f26679fb05bd6e7a5da6a99da1a7 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -1,1147 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ======================================================================== -"""A utility to trace tensor values on TPU.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import os.path -import re -import sys - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_util -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import gfile -from tensorflow.python.platform import tf_logging as logging - -_TRACER_LOG_PREFIX = ' [>>>TT>>>]' -_DEVICE_TYPE_TPU = 'tpu' -_DEVICE_TYPE_CPU = 'cpu' -_TRACE_MODE_NAN_INF = 'nan-inf' -_TRACE_MODE_PART_TENSOR = 'part-tensor' -_TRACE_MODE_PART_TENSOR_SIZE = 3 -_TRACE_MODE_FULL_TENSOR = 'full-tensor' -_TRACE_MODE_NORM = 'norm' -_TRACE_MODE_MAX_ABS = 'max-abs' -_SUBMODE_BRIEF = 'brief' -_SUBMODE_DETAILED = 'detailed' -_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' -_REASON_UNSAFE_OP = 'not-traced-unsafe-op' -_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' -_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' -_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' -_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' -_REASON_SCALAR_GET_TRACED = 'traced-scalar' -_REASON_TENSOR_GET_TRACED = 'traced-tensor' -_REASON_USER_INCLUDED = 'traced-user-included' -_REASON_USER_EXCLUDED = 'not-traced-user-excluded' 
-_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' -_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' -_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' -_MARKER_SECTION_END = '!!!!!!! section-end:' -_SECTION_NAME_CONFIG = 'configuration' -_SECTION_NAME_REASON = 'reason' -_SECTION_NAME_OP_LIST = 'op-list' -_SECTION_NAME_TENSOR_LIST = 'tensor-list' -_SECTION_NAME_GRAPH = 'graph' -_FIELD_NAME_VERSION = 'version:' -_FIELD_NAME_DEVICE = 'device:' -_FIELD_NAME_TRACE_MODE = 'trace-mode:' -_FIELD_NAME_SUBMODE = 'submode:' -_FIELD_NAME_NUM_REPLICAS = 'num-replicas:' -_FIELD_NAME_NUM_OPS = 'number-of-ops:' -_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:' -_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' -_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' -_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") -_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') -_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') -_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*') -_FLAG_NAME_ENABLE = 'enable' -_FLAG_NAME_TRACE_MODE = 'trace_mode' -_FLAG_NAME_SUBMODE = 'submode' -_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' -_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' -_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' -_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' -_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' -_FLAG_NAME_TRACE_FILE = 'trace_file_path' -_FLAG_NAME_REPORT_FILE = 'report_file_path' -_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' -_FLAG_NAME_OP_RANGE = 'op_range' -_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') -_OUTPUT_STREAM_ESCAPE = 'file://' -_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' -_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' -_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint' - - -def tensor_tracepoint(tensor, checkpoint_name): - """Adds a checkpoint with the given checkpoint name for the given tensor. 
- - The tensor will be added to the list of tensors that will be traced by the - tensor tracer. - - Args: - tensor: the tensor object for which the tracing is requested. - checkpoint_name: a string name for the checkpoint. This name has to be a - unique name if used within model comparison. The tensors that have the same - checkpoint identifier is compared in model comparison. - Returns: - The provided tensor. - """ - - tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) - tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, - (tensor, checkpoint_name)) - return tensor - - -def keras_layer_tracepoint(layer, checkpoint_name): - """An interface for adding the tensor outputs of a keras layer. - - Encapsulates tensor_tracepoint. - - Args: - layer: A keras layer. - checkpoint_name: a string name for the checkpoint. This name has to be a - unique name if used within model comparison. The tensors that have the same - checkpoint identifier is compared in model comparison. - - Returns: - The provided layer. - """ - try: - outputs = layer.output - if tensor_util.is_tensor(outputs): - tensor_tracepoint(outputs, '%s' % (checkpoint_name)) - else: - idx = 0 - for output_tensor in outputs: - if tensor_util.is_tensor(outputs): - tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) - idx += 1 - except AttributeError: - pass - except RuntimeError: - pass - return layer - - -class TensorTracer(object): - """A software construct for tracing tensor values in a TF graph on TPU. - - This utility is disabled by default. It can be enabled by setting - the TENSOR_TRACER_FLAGS env variable as: - export TENSOR_TRACER_FLAGS="--enable=1" - If it is enabled, it will trace the output tensor values of - selected Ops in the graph. It has two outputs: (1) the traces and (2) - a report. The traces are dumped to a specified local file on the TPU - host. The report is printed to the log.info of the TPU job. 
- By passing options via the env variable, users can change: - (1) the trace mode (e.g., detecting NaN/Inf, printing partial or - full tensor values) - (2) which Ops to be traced (via op.name or op.type) - (3) output trace file path. - """ - - @staticmethod - def _match_next_flag(flags, pos): - """Returns the match for the next TensorTracer flag. - - Args: - flags: a string that contains the flags. - pos: where in flags to start the search. - - Returns: - A pair where the first element is the regular-expression - match found and the second element indicates if the match - has a value. - """ - - match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos) - if match: - return match, True - match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos) - if match: - return match, True - match = _FLAG_NO_QUOTE_PAT.match(flags, pos) - if match: - return match, True - match = _FLAG_NO_EQUAL_PAT.match(flags, pos) - if match: - # The flag is found but is not given a value. - return match, False - # The flag is not found. - return None, False - - @staticmethod - def validate_flag_names(): - """Validates if the TensorTrace flags passed are valid.""" - valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, - _FLAG_NAME_SUBMODE, - _FLAG_NAME_EXCLUDED_OPNAMES, - _FLAG_NAME_EXCLUDED_OPTYPES, - _FLAG_NAME_INCLUDED_OPNAMES, - _FLAG_NAME_INCLUDED_OPTYPES, - _FLAG_NAME_TRACE_FILE, _FLAG_NAME_REPORT_FILE, - _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, - _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, - _FLAG_NAME_OP_RANGE] - tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) - if not tensor_tracer_flags: - return - pos = 0 - while True: - match, _ = TensorTracer._match_next_flag(tensor_tracer_flags, pos) - if not match: - break - flag_name = match.group(1) - if flag_name not in valid_flag_names: - raise ValueError( - 'The flag name "%s" passed via the environment variable "%s" ' - 'is invalid. 
Valid flag names are:' - '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names)) - pos = match.end() - - @staticmethod - def print_flag_values(): - """Prints all TensorTracer flags passed via environment variables.""" - - tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) - if not tensor_tracer_flags: - return 'Env variable "%s" is not set'%_FLAGS_ENV_VAR - result = 'Env variable "%s" is set to "%s"\n'%(_FLAGS_ENV_VAR, - tensor_tracer_flags) - result += 'Individual flag value:\n' - pos = 0 - while True: - match, has_value = TensorTracer._match_next_flag( - tensor_tracer_flags, pos) - if not match: - break - flag_name = match.group(1) - if has_value: - flag_value = match.group(2) - else: - flag_value = None - result += ' %s: %s\n'%(flag_name, flag_value) - pos = match.end() - result += '\n' - return result - - @staticmethod - def get_flag_value(wanted_flag_name): - """Returns the value of a TensorTracer flags. - - Args: - wanted_flag_name: the name the the flag we are looking for. - - Returns: - A pair where the first element indicates if the flag is - found and the second element is the value of the flag. - - Raises: - RuntimeError: If supposedly deadcode is reached. 
- """ - - tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR) - if not tensor_tracer_flags: - return False, None - pos = 0 - while True: - match, has_value = TensorTracer._match_next_flag( - tensor_tracer_flags, pos) - if not match: - return False, None - flag_name = match.group(1) - if has_value: - flag_value = match.group(2) - else: - flag_value = None - if flag_name == wanted_flag_name: - return True, flag_value - pos = match.end() - raise RuntimeError('Should not reach here.') - - @staticmethod - def flag_value_to_re_list(flag_name): - """Converts list of strings to compiled RE.""" - - re_list = [] - found, flag_value = TensorTracer.get_flag_value(flag_name) - if not found or not flag_value: - return re_list - list_of_values = flag_value.split() - for v in list_of_values: - r = re.compile(v) - re_list.append(r) - return re_list - - @staticmethod - def _is_flag_on(flag_name): - """Returns True if the given flag is on.""" - - found, flag_value = TensorTracer.get_flag_value(flag_name) - if not found: - return False - if flag_value is None: - return True - # Depends on the flag value. - flag_value = flag_value.lower() - enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] - return enabled - - @staticmethod - def is_enabled(): - """Returns True if TensorTracer is enabled.""" - - return TensorTracer._is_flag_on(_FLAG_NAME_ENABLE) - - @staticmethod - def use_test_undeclared_outputs_dir(): - """Decides the output directory of the report and trace files. - - Args: - None. - - Returns: - True if the output files should be written to the - test-undeclared-outputs-directory defined via an - env variable. 
- """ - - return TensorTracer._is_flag_on( - _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) - - - @staticmethod - def check_device_type(device_type): - """Checks if the given device type is valid.""" - - if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]: - raise ValueError('Invalid device_type "%s"'%device_type) - - @staticmethod - def check_trace_mode(trace_mode): - """Checks if the given trace mode is valid.""" - - valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, - _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM, - _TRACE_MODE_MAX_ABS] - if trace_mode not in valid_trace_modes: - raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' - 'Valid trace modes are: %s'%(trace_mode, - valid_trace_modes)) - - @staticmethod - def check_submode(submode): - """Checks if the given submode is valid.""" - - if not submode: - return - valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF] - if submode not in valid_submodes: - raise ValueError('Invalid submode "%s" given to the Tensor_Tracer.' - 'Valid submodes are: %s'%(submode, - valid_submodes)) - - @staticmethod - def unsafe_op(op): - """Returns True if this op is not safe to be traced.""" - - if control_flow_util.IsInCond(op): - return True - # Reasons for not including following op types: - # Assign: cause incorrect result with CPU tracing. - if op.type in ['Assign']: - return True - return False - - @staticmethod - def device_mismatch(device_type, op): - if device_type == _DEVICE_TYPE_TPU: - # pylint: disable=protected-access - return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr - # pylint: enable=protected-access - return False - - @staticmethod - def unsafe_scalar_trace(op): - """Return true if scalar output tensor from Op is not safe to be traced.""" - - # Tracing the following causes cycle in the graph on TPU. 
- if op.type in ['LoopCond', 'Enter', 'Merge', 'Const', - 'Switch', 'Less', 'ReadVariableOp']: - return True - # Tracing the following will cause casting-issue - # with the norm tracing mode or other compilation issues on CPU. - if op.type in ['VarHandleOp', 'IteratorToStringHandle', - 'IteratorGetNext', 'OneShotIterator', - 'IteratorV2', 'MakeIterator', - 'BatchDatasetV2', 'MapDataset', - 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', - 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']: - return True - return False - - @staticmethod - def less_interesting_op(op): - """Returns True if the given Op is not an interesting one to be traced.""" - - found, _ = TensorTracer.get_flag_value( - _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) - if found: - # users force to include all ops. - return False - # Following ops are highly unlikey to cause bugs. - return op.type in ['Const', 'Identity', 'Cast', 'Shape'] - - @staticmethod - def reason(op_idx, details): - """Returns reason why the Op at op_idx is traced or not.""" - - return '%d %s'%(op_idx, details) - - @staticmethod - def topological_sort(g): - """Performs topological sort on the given graph. - - Args: - g: the graph. - - Returns: - A pair where the first element indicates if the topological - sort succeeded (True if there is no cycle found; False if a - cycle is found) and the second element is either the sorted - list of nodes or the cycle of nodes found. - """ - - def visit(op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops): - """Recursively visits all Ops in a graph. - - Args: - op: the current Op being visited. - cycle: a cycle of Ops found. - permanently_marked_ops: the set of Ops that were already visited. - temporarily_marked_ops: the set of Ops that we have visited during - the current descent. - sorted_ops: the list of Ops sorted in topological order. 
- """ - - if cycle: - return - if op in permanently_marked_ops: - return - if op in temporarily_marked_ops: - cycle = temporarily_marked_ops - return - temporarily_marked_ops.add(op) - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - for consumer_op in out_tensor.consumers(): - visit(consumer_op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - # pylint: disable=protected-access - for ctrl_output_op in op._control_outputs: - # pylint: enable=protected-access - visit(ctrl_output_op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - temporarily_marked_ops.remove(op) - permanently_marked_ops.add(op) - sorted_ops.insert(0, op) - - graph_cycle = set([]) - sorted_ops = [] - permanently_marked_ops = set([]) - temporarily_marked_ops = set([]) - unsorted_ops = g.get_operations() - for op in unsorted_ops: - visit(op, graph_cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - if graph_cycle: - return (False, graph_cycle) - else: - assert len(unsorted_ops) == len(sorted_ops) - return (True, sorted_ops) - - @staticmethod - def _make_op_and_tensor_maps(op_list): - """Creates various maps and lists from op_list. - - Args: - op_list: a list of Ops - - Returns: - opname_idx_map: a map from Op's name to its index in op_list. - tensor_list: a list of output tensors of the Ops in op_list. - tensorname_idx_map: a map from output tensor name to its index - in tensor_list. - """ - - opname_idx_map = {} - tensor_list = [] - tensorname_idx_map = {} - for op_id, op in enumerate(op_list): - if op.name in opname_idx_map: - raise ValueError('Duplicated Op name: %s'%op.name) - opname_idx_map[op.name] = op_id - for output_tensor in op.outputs: - if output_tensor.name not in tensorname_idx_map: - tensor_list.append(output_tensor) - tensorname_idx_map[output_tensor.name] = len(tensor_list)-1 - return (opname_idx_map, tensor_list, tensorname_idx_map) - - def __init__(self): - """Initializes a TensorTracer. 
- - Sets the various member fields from the flags (if given) or the defaults. - """ - self._version = 'use-outside-compilation' - self._device_type = None - TensorTracer.validate_flag_names() - found, self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) - if not found or not self._trace_mode: - self._trace_mode = _TRACE_MODE_NAN_INF - TensorTracer.check_trace_mode(self._trace_mode) - found, self._submode = TensorTracer.get_flag_value(_FLAG_NAME_SUBMODE) - if not found or not self._submode: - self._submode = _SUBMODE_DETAILED - TensorTracer.check_submode(self._submode) - self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE - self._instrument_records = {} - self._set_trace_file_path() - self._set_report_file() - self._set_op_range() - self._set_excluded_opnames() - self._set_excluded_optypes() - self._set_included_opnames() - self._set_included_optypes() - self._num_replicas = None - self._replica_id = None - - def _add_replica_id_to_graph(self, num_replicas, result_tensor): - """Adds nodes for computing the replica ID to the graph.""" - - if not num_replicas: - self._replica_id = 'unknown' - return result_tensor - - self._num_replicas = num_replicas - - with ops.control_dependencies(None): - # Uses None as dependency to run outside of TPU graph rewrites. - self._replica_id = tpu_ops.tpu_replicated_input( - list(range(self._num_replicas)), - name='tt_replica_id') - use_replica_id = array_ops.identity(self._replica_id).op - with ops.control_dependencies([use_replica_id]): - # Adds a control dependency from the result_tensor to - # the replica_id to ensure that replica_id will be added to the graph. 
- return array_ops.identity(result_tensor) - - def _set_trace_file_path(self): - """Sets the path of the output trace file.""" - - found, self._trace_file_path = TensorTracer.get_flag_value( - _FLAG_NAME_TRACE_FILE) - if found and self._trace_file_path \ - and TensorTracer.use_test_undeclared_outputs_dir(): - if os.path.isabs(self._trace_file_path): - raise ValueError('If use_test_undeclared_outputs_dir is set,' - 'trace_file_path cannot be an absolute path (%s)' - %self._trace_file_path) - outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) - self._trace_file_path = os.path.join(outputs_dir, - self._trace_file_path) - - def _set_report_file(self): - """Sets the path of the output report file.""" - - found, self._report_file_path = TensorTracer.get_flag_value( - _FLAG_NAME_REPORT_FILE) - if found and self._report_file_path \ - and TensorTracer.use_test_undeclared_outputs_dir(): - if os.path.isabs(self._report_file_path): - raise ValueError('If use_test_undeclared_outputs_dir is set,' - 'report_file_path cannot be an absolute path (%s)' - %self._report_file_path) - outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) - self._report_file_path = os.path.join(outputs_dir, - self._report_file_path) - if not self._report_file_path: - self._report_file = None - return - try: - self._report_file = gfile.Open(self._report_file_path, 'w') - except IOError as e: - raise e - - def _close_report_file(self): - if self._report_file: - self._report_file.close() - - def _set_op_range(self): - """Sets the index range of the Ops that we will consider tracing.""" - - found, op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) - if not found or not op_range: - self._op_range = (-1, -1) # this means including all ops. - return - match = _OP_RANGE_PAT.match(op_range) - if not match: - self._op_range = (-1, -1) # this means including all ops. 
- return - self._op_range = (int(match.group(1)), int(match.group(2))) - - def _inside_op_range(self, idx): - """Return True if the given index is inside the selected range.""" - - if idx < self._op_range[0]: - return False - return self._op_range[1] < 0 or idx <= self._op_range[1] - - def _set_excluded_opnames(self): - self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list( - _FLAG_NAME_EXCLUDED_OPNAMES) - - def _set_excluded_optypes(self): - self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list( - _FLAG_NAME_EXCLUDED_OPTYPES) - - def _set_included_opnames(self): - self._included_opname_re_list = TensorTracer.flag_value_to_re_list( - _FLAG_NAME_INCLUDED_OPNAMES) - - def _set_included_optypes(self): - self._included_optype_re_list = TensorTracer.flag_value_to_re_list( - _FLAG_NAME_INCLUDED_OPTYPES) - - def _is_user_included_op(self, op): - for opname_re in self._included_opname_re_list: - if opname_re.match(op.name): - return True - for optype_re in self._included_optype_re_list: - if optype_re.match(op.type): - return True - return False - - def _is_user_excluded_op(self, op): - for opname_re in self._excluded_opname_re_list: - if opname_re.match(op.name): - return True - for optype_re in self._excluded_optype_re_list: - if optype_re.match(op.type): - return True - return False - - def _write_report(self, content): - """Writes the given content to the report.""" - - line = '%s %s'%(_TRACER_LOG_PREFIX, content) - if self._report_file: - self._report_file.write(line) - else: - logging.info(line) - - def _write_config_section(self): - """Writes the config section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG)) - self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version)) - self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type)) - self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode)) - self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE, self._submode)) - 
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas)) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) - - def _write_reason_section(self): - """Writes the reason section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON)) - for key in sorted(self._instrument_records): - self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) - - def _write_op_list_section(self, op_list): - """Writes the Op-list section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) - self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) - for i in range(0, len(op_list)): - op = op_list[i] - line = '%d "%s" %s'%(i, op.name, op.type) - for out_tensor in op.outputs: - if out_tensor.name not in self._tensorname_idx_map: - raise ValueError( - 'out_tensor %s is not in tensorname_idx_map'%out_tensor.name) - line += ' %d'%self._tensorname_idx_map[out_tensor.name] - line += '\n' - self._write_report(line) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) - - def _write_tensor_list_section(self, tensor_list, opname_idx_map): - """Writes the tensor-list section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, - _SECTION_NAME_TENSOR_LIST)) - self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list))) - for i in range(0, len(tensor_list)): - tensor = tensor_list[i] - line = '%d "%s"'%(i, tensor.name) - for consumer_op in tensor.consumers(): - if consumer_op.name not in opname_idx_map: - raise ValueError( - 'consumer_op %s is not in opname_idx_map'%consumer_op.name) - line += ' %d'%opname_idx_map[consumer_op.name] - line += '\n' - self._write_report(line) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, - _SECTION_NAME_TENSOR_LIST)) - - def _write_graph_section(self, succeed, sorted_or_cycle): - 
"""Writes the graph section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH)) - self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED, - succeed)) - l = list(sorted_or_cycle) - for i in range(0, len(l)): - self._write_report('%d "%s"\n'%(i, l[i].name)) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) - - def _preprocess_traced_tensor(self, tensor): - """Computes NAN/Norm/Max on TPUs before sending to CPU. - - Args: - tensor: The tensor to be traced. - Returns: - A tensor that should be input to the trace_function. - Raises: - RuntimeError: If the trace mode is invalid. - """ - - def _detect_nan_inf(tensor): - """Trace function for detecting any NaN/Inf in the tensor.""" - - if tensor.dtype.is_floating: - output_tensor = math_ops.reduce_any( - gen_math_ops.logical_or( - gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) - else: - output_tensor = constant_op.constant(False) - # The shape has to be 1. Set it if it does not have the information. - output_tensor = array_ops.reshape(output_tensor, [1]) - return output_tensor - - def _show_norm(tensor): - tensor = math_ops.cast(tensor, dtypes.float32) - output_tensor = linalg_ops.norm(tensor) - # The shape has to be 1. Set it if it does not have the information. - output_tensor = array_ops.reshape(output_tensor, [1]) - return output_tensor - - def _show_max_abs(tensor): - tensor = math_ops.cast(tensor, dtypes.float32) - output_tensor = math_ops.reduce_max(math_ops.abs(tensor)) - zero = constant_op.constant(0, dtypes.float32) - output_tensor = gen_math_ops.maximum(zero, output_tensor) - # The shape has to be 1. Set it if it does not have the information. 
- output_tensor = array_ops.reshape(output_tensor, [1]) - return output_tensor - - if self._trace_mode == _TRACE_MODE_NAN_INF: - return _detect_nan_inf(tensor) - if self._trace_mode == _TRACE_MODE_PART_TENSOR: - return tensor - if self._trace_mode == _TRACE_MODE_FULL_TENSOR: - return tensor - if self._trace_mode == _TRACE_MODE_NORM: - return _show_norm(tensor) - if self._trace_mode == _TRACE_MODE_MAX_ABS: - return _show_max_abs(tensor) - raise RuntimeError( - 'Tensor trace fun for %s is not yet implemented' % self._trace_mode) - - def _make_tensor_trace_fun(self, tensor_name): - """Makes the tensor tracing function called by outside compilation. - - Args: - tensor_name: name of the tensor being traced. - - Returns: - A function to be passed as the first argument to outside compilation. - - Raises: - RuntimeError: If the trace mode is invalid. - """ - - def _print_tensor(tensor_name, num_elements, tensor, output_tensor): - """Prints a tensor value to a file. - - Args: - tensor_name: name of the tensor being traced. - num_elements: number of elements to print (-1 means print all). - tensor: the tensor needs to be returned. - output_tensor: the tensor needs to be printed. - - Returns: - The same tensor passed via the "tensor" argument. - - Raises: - ValueError: If tensor_name is not already in - self._tensorname_idx_map. 
- """ - - if self._submode == _SUBMODE_BRIEF: - if tensor_name not in self._tensorname_idx_map: - raise ValueError( - 'Tensor name %s is not in the tensorname_idx_map'%tensor_name) - msg = '%d'%self._tensorname_idx_map[tensor_name] - else: - msg = '"%s"'%tensor_name - - if self._trace_file_path: - output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path - else: - output_stream = sys.stderr - print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), - '@', self._replica_id, - '\n', output_tensor, '\n', - summarize=num_elements, - output_stream=output_stream) - with ops.control_dependencies([print_op]): - return array_ops.identity(tensor).op - - - def _show_part_tensor(tensor): - """Trace function for printing part of the tensor.""" - - return _print_tensor(tensor_name, self._part_tensor_size, - tensor, tensor) - - def _show_full_tensor(tensor): - """Trace function for printing the entire tensor.""" - - return _print_tensor(tensor_name, -1, tensor, tensor) - - if self._trace_mode == _TRACE_MODE_PART_TENSOR: - return _show_part_tensor - # The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF, - # _TRACE_MODE_NORM, and _TRACE_MODE_MAX_ABS, as related computations are - # performed within TPUs and only their results are transferred to CPU. - # Simply, print the full tensor for these trace modes. 
- if self._trace_mode in [ - _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_FULL_TENSOR, - _TRACE_MODE_MAX_ABS - ]: - return _show_full_tensor - - raise RuntimeError('Tensor trace fun for %s is not yet implemented' - %self._trace_mode) - - def _skip_op(self, op_id, op, user_included, user_excluded, - in_exec_path=True): - """Returns True if we should not trace Op.""" - - if user_included: - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_USER_INCLUDED) - return False - if user_excluded: - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_USER_EXCLUDED) - return True - if not in_exec_path: - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_NOT_EXECUTED) - return True - if not self._inside_op_range(op_id): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_OUTSIDE_OP_RANGE) - return True - if TensorTracer.unsafe_op(op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_UNSAFE_OP) - return True - if TensorTracer.device_mismatch(self._device_type, op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_DEVICE_MISMATCH) - return True - if TensorTracer.less_interesting_op(op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_LESS_INTERESTING_OP) - return True - return False - - def _skip_tensor(self, op_id, out_tensor, user_included, - user_excluded): - """Returns True if we should not trace out_tensor.""" - - # Skips a tensor if the tensor has a non-numeric type. - # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) - # because it also excludes tensors with dtypes, bool, and - # float32_ref, which we actually want to trace. 
- non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, - dtypes.string]) - if out_tensor.dtype in non_numeric_tensor_types: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_NON_NUMERIC_TENSOR) - return True - - if user_included: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_USER_INCLUDED) - return False - if user_excluded: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_USER_EXCLUDED) - return True - if not out_tensor.get_shape().is_fully_defined(): - # If trace mode is nan-inf, norm or max, then the tensor will be reduced - # to a scalar before the outside compilation call. - if self._trace_mode in [ - _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS - ]: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_TENSOR_GET_TRACED) - return False - else: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_DYNAMIC_SHAPE) - return True - rank = len(out_tensor.shape) - if rank < 1: - # scalar - if TensorTracer.unsafe_scalar_trace(out_tensor.op): - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_UNSAFE_SCALAR) - return True - else: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_SCALAR_GET_TRACED) - return False - else: - # tensor - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_TENSOR_GET_TRACED) - return False - - def _filter_execution_path_operations(self, operations, fetches): - """Returns the set of ops in the execution path to compute given fetches.""" - # If no fetch provided, then return all operations. - if fetches is None: - return set(operations) - # Convert to list, if a single element is provided. - if not isinstance(fetches, (list, tuple)): - fetches = [fetches] - # If a tensor is given as fetch, convert it to op. 
- op_fetches = [] - for fetch in fetches: - if isinstance(fetch, ops.Operation): - op_fetches.append(fetch) - elif isinstance(fetch, ops.Tensor): - op_fetches.append(fetch.op) - else: - raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' - %fetch) - - execution_path_operations = set(op_fetches) - traverse_stack = list(op_fetches) - while True: - if not traverse_stack: - break - head_op = traverse_stack.pop() - input_ops = [tensor_input.op for tensor_input in head_op.inputs] - input_ops.extend(head_op.control_inputs) - - for input_op in input_ops: - if input_op not in execution_path_operations: - execution_path_operations.add(input_op) - traverse_stack.append(input_op) - return execution_path_operations - - def _pre_tracing(self, graph): - """Work needs to be done prior to TPU or CPU tracing.""" - - operations = graph.get_operations() - (opname_idx_map, tensor_list, self._tensorname_idx_map) = ( - TensorTracer._make_op_and_tensor_maps(operations)) - self._write_config_section() - self._write_op_list_section(operations) - self._write_tensor_list_section(tensor_list, opname_idx_map) - # Does the topological sort before adding any nodes to the graph. - (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) - return (operations, succeed, sorted_or_cycle) - - def _post_tracing(self, succeed, sorted_or_cycle): - """Work needs to be done after TPU or CPU tracing.""" - - self._write_reason_section() - self._write_graph_section(succeed, sorted_or_cycle) - self._close_report_file() - - def _get_checkpoints(self, graph): - """Returns the list of Ops that produce the tensors traced with API. - - Args: - graph: the graph of Ops. - - Returns: - A set of operation names which should be traced. 
- """ - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, - _TENSOR_TRACER_CHECKPOINT)) - checkpoint_operations = set() - tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION) - for (tensor, checkpoint_name) in tensor_tracer_variables: - self._write_report('%s %s\n'%(tensor.name, checkpoint_name)) - checkpoint_operations.add(tensor.op.name) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, - _TENSOR_TRACER_CHECKPOINT)) - return checkpoint_operations - - def trace_tpu(self, graph, result_tensor, num_replicas=None, fetches=None): - """Traces the tensors generated by TPU Ops in a TF graph. - - Args: - graph: the graph of Ops executed on the TPU. - result_tensor: a result tensor of evaluating the graph. - num_replicas: number of replicas used on the TPU. - fetches: the list of fetches given to session.run, used to determine the - ops in execution path. If None, the whole graph will be traced. - - Returns: - A tuple (result_tensor_copy, tracing_ops), where: - result_tensor_copy: an exact copy of result_tensor - tracing_ops: a list of tracing ops. If this list - is non empty, the caller of this function - should pose control dependencies upon these - Ops so that they will be executed when the - graph is evaluated. - """ - - def _cast_unsupported_dtypes(tensor): - """Casts tensor to a supported type.""" - - if tensor.dtype.__eq__(dtypes.int64): - # outside-compilation doesn't support int64 input yet. - return math_ops.cast(tensor, dtypes.int32) - if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( - dtypes.float16): - # Since host can't handle bf16, convert tensor to f32. - return math_ops.cast(tensor, dtypes.float32) - return tensor - - self._device_type = _DEVICE_TYPE_TPU - TensorTracer.check_device_type(self._device_type) - result_tensor_copy = self._add_replica_id_to_graph(num_replicas, - result_tensor) - (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) - # Filter out the operations that won't be executed. 
- # if fetches=None, then ops_in_exec_path = set(operations) - ops_in_exec_path = self._filter_execution_path_operations(operations, - fetches) - tracing_ops = [] - checkpoint_operations = self._get_checkpoints(graph) - - for op_id, op in enumerate(operations): - if checkpoint_operations and op.name not in checkpoint_operations: - continue - user_included = self._is_user_included_op(op) - user_excluded = self._is_user_excluded_op(op) - in_exec_path = op in ops_in_exec_path - if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path): - continue - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - if self._skip_tensor(op_id, out_tensor, user_included, - user_excluded): - continue - # Create the list of consumers before calling _preprocess_traced_tensor. - # Otherwise, adding control input below, will introduce a cycle in the - # graph. - consumers = out_tensor.consumers() - tensor_name = out_tensor.name - processed_out_tensor = self._preprocess_traced_tensor(out_tensor) - processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor) - trace_op = tpu.outside_compilation( - self._make_tensor_trace_fun(tensor_name), processed_out_tensor) - if consumers: - for consumer_op in consumers: - # pylint: disable=protected-access - consumer_op._add_control_input(trace_op) - # pylint: enable=protected-access - else: - # if there is no consumer, we will add the control dependence later - # when we add the control dependency to the output operations. - tracing_ops.append(trace_op) - self._post_tracing(succeed, sorted_or_cycle) - return (result_tensor_copy, tracing_ops) - - def trace_cpu(self, graph): - """Traces the tensors generated by CPU Ops in a TF graph. - - Args: - graph: the graph of Ops executed on the CPU. - - Returns: - tracing_calls: a map from keys to trace calls. - A key is constructed from an Op's name. - A trace call consists of a function and a tensor ( - the function will be invoked with the tensor). 
- """ - - self._device_type = _DEVICE_TYPE_CPU - TensorTracer.check_device_type(self._device_type) - self._num_replicas = 1 - self._replica_id = 0 - (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) - tracing_calls = {} - checkpoint_operations = self._get_checkpoints(graph) - - for op_id, op in enumerate(operations): - if checkpoint_operations and op.name not in checkpoint_operations: - continue - user_included = self._is_user_included_op(op) - user_excluded = self._is_user_excluded_op(op) - if self._skip_op(op_id, op, user_included, user_excluded): - continue - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - if self._skip_tensor(op_id, out_tensor, user_included, - user_excluded): - continue - processed_out_tensor = self._preprocess_traced_tensor(out_tensor) - trace_fun = self._make_tensor_trace_fun(out_tensor.name) - trace_call = (trace_fun, [processed_out_tensor]) - trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i) - tracing_calls[trace_call_key] = trace_call - self._post_tracing(succeed, sorted_or_cycle) - return tracing_calls +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tensor_tracer import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index 6ae718cc2c9716587849aeee8abcd0a1de82a9ae..5bf805752cf51b0a0f4b7400b18b63aae93cf831 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -1,220 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== -"""Defines the `Topology` class, that describes a TPU fabric topology.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.tpu.proto import topology_pb2 - - -def _tpu_device_name(job, task, device): - """Returns the device name for the TPU `device` on `task` of `job`.""" - if job is None: - return "/task:%d/device:TPU:%d" % (task, device) - else: - return "/job:%s/task:%d/device:TPU:%d" % (job, task, device) - - -def _tpu_host_device_name(job, task): - """Returns the device name for the CPU device on `task` of `job`.""" - if job is None: - return "/task:%d/device:CPU:0" % task - else: - return "/job:%s/task:%d/device:CPU:0" % (job, task) - - -class Topology(object): - """Describes a set of TPU devices. - - Represents both the shape of the physical mesh, and the mapping between - TensorFlow TPU devices to physical mesh coordinates. - """ - - def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None): - """Builds a Topology object. - - If `serialized` is not `None`, the topology is parsed from `serialized` and - the other arguments are ignored. Otherwise, the topology is computed from - `mesh_shape` and `device_coordinates`. 
- - Args: - serialized: A serialized `TopologyProto`, or `None`. If not `None`, the - serialized proto is parsed to discover the topology. - mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`, - the shape of the TPU topology, in number of cores. Ignored if - `serialized` is not `None`. - device_coordinates: A rank 3 numpy array that describes the mapping from - TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored - if `serialized is not `None`. - - Raises: - ValueError: If `serialized` does not describe a well-formed topology. - ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence - of 3 positive integers. - ValueError: If `serialized` is `None` and `device_coordinates` is not a - rank 3 numpy int32 array that describes a valid coordinate mapping. - """ - - self._serialized = serialized - - if serialized: - self._parse_topology(serialized) - else: - self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32) - self._device_coordinates = np.asarray(device_coordinates, np.int32) - if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1): - raise ValueError("`mesh_shape` must be a sequence of 3 positive " - "entries; got {}".format(self._mesh_shape)) - - if (len(self._device_coordinates.shape) != 3 or - self._device_coordinates.shape[2] != len(self._mesh_shape)): - raise ValueError("`device_coordinates` must be a rank 3 int32 array " - "with minor dimension equal to the mesh shape rank") - - self._topology_tasks, self._topology_devices = self._invert_topology() - - def _parse_topology(self, serialized): - """Parses a serialized `TopologyProto` into `self`.""" - proto = topology_pb2.TopologyProto() - proto.ParseFromString(serialized) - - self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32) - if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1): - raise ValueError("`mesh_shape` must be a vector of size 3 with positive " - "entries; got {}".format(self._mesh_shape)) - - if proto.num_tasks < 0: - 
raise ValueError("`num_tasks` must be >= 0; got {}".format( - proto.num_tasks)) - if proto.num_tpu_devices_per_task < 0: - raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format( - proto.num_tpu_devices_per_task)) - - expected_coordinates_size = ( - proto.num_tasks * proto.num_tpu_devices_per_task * len( - proto.mesh_shape)) - if len(proto.device_coordinates) != expected_coordinates_size: - raise ValueError("`device_coordinates` must have shape num_tasks ({}) * " - "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); " - "got shape {}".format(proto.num_tasks, - proto.num_tpu_devices_per_task, - proto.mesh_shape, - len(proto.device_coordinates))) - - coords = np.array(proto.device_coordinates, dtype=np.int32) - if any(coords < 0): - raise ValueError("`device_coordinates` must be >= 0") - coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task, - len(proto.mesh_shape))) - self._device_coordinates = coords - - def _invert_topology(self): - """Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps.""" - tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32) - devices = np.full(list(self.mesh_shape), -1, dtype=np.int32) - for task in xrange(self.device_coordinates.shape[0]): - for device in xrange(self.device_coordinates.shape[1]): - x, y, z = self.device_coordinates[task, device, :] - tasks[x, y, z] = task - devices[x, y, z] = device - return tasks, devices - - @property - def mesh_shape(self): - """A rank 1 int32 array describing the shape of the TPU topology.""" - return self._mesh_shape - - @property - def mesh_rank(self): - """Returns the number of dimensions in the mesh.""" - return len(self._mesh_shape) - - @property - def device_coordinates(self): - """Describes the mapping from TPU devices to topology coordinates. - - Returns: - A rank 3 int32 array with shape `[tasks, devices, axis]`. 
- `tasks` is the number of tasks in the TPU cluster, `devices` is the number - of TPU devices per task, and `axis` is the number of axes in the TPU - cluster topology. Each entry gives the `axis`-th coordinate in the - topology of a task/device pair. TPU topologies are 3-dimensional, with - dimensions `(x, y, core number)`. - """ - return self._device_coordinates - - def task_ordinal_at_coordinates(self, device_coordinates): - """Returns the TensorFlow task number attached to `device_coordinates`. - - Args: - device_coordinates: An integer sequence describing a device's physical - coordinates in the TPU fabric. - - Returns: - Returns the TensorFlow task number that contains the TPU device with those - physical coordinates. - """ - return self._topology_tasks[tuple(device_coordinates)] - - def tpu_device_ordinal_at_coordinates(self, device_coordinates): - """Returns the TensorFlow device number at `device_coordinates`. - - Args: - device_coordinates: An integer sequence describing a device's physical - coordinates in the TPU fabric. - - Returns: - Returns the TensorFlow device number within the task corresponding to - attached to the device with those physical coordinates. 
- """ - return self._topology_devices[tuple(device_coordinates)] - - def cpu_device_name_at_coordinates(self, device_coordinates, job=None): - """Returns the CPU device attached to a logical core.""" - return _tpu_host_device_name( - job, self._topology_tasks[tuple(device_coordinates)]) - - def tpu_device_name_at_coordinates(self, device_coordinates, job=None): - """Returns the name of the TPU device assigned to a logical core.""" - return _tpu_device_name(job, - self._topology_tasks[tuple(device_coordinates)], - self._topology_devices[tuple(device_coordinates)]) - - @property - def num_tasks(self): - """Returns the number of TensorFlow tasks in the TPU slice.""" - return self._device_coordinates.shape[0] - - @property - def num_tpus_per_task(self): - """Returns the number of TPU devices per task in the TPU slice.""" - return self._device_coordinates.shape[1] - - def serialized(self): - """Returns the serialized form of the topology.""" - if self._serialized is None: - proto = topology_pb2.TopologyProto() - proto.mesh_shape[:] = list(self._mesh_shape) - proto.num_tasks = self._device_coordinates.shape[0] - proto.num_tpu_devices_per_task = self._device_coordinates.shape[1] - proto.device_coordinates.extend(list(self._device_coordinates.flatten())) - self._serialized = proto.SerializeToString() - - return self._serialized +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.topology import * +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index b04baebfe6150c7f10c16e237de884ab788f2642..5364b20f231ac7af8adf943c3d5e21921b7a06a9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -1,1392 +1,25 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== - -"""Library of TPU helper functions.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.compiler import xla -from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function - -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.compat import compat as api_compat -from tensorflow.python.framework import device as pydev -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import compat -from tensorflow.python.util import nest - - -# Operations that indicate some error in the users graph, e.g. a placeholder -# that's introduced outside of the infeed. 
-_BLACKLISTED_OPS = set([ - "Placeholder", -]) - -# XLA doesn't currently support reading of intermediate tensors, thus some ops -# are not supported. -_UNSUPPORTED_OPS = set([ - "AudioSummary", - "AudioSummaryV2", - "HistogramSummary", - "ImageSummary", - "MergeSummary", - "Print", - "ScalarSummary", - "TensorSummary", - "TensorSummaryV2", - ]) - -_MAX_WARNING_LINES = 5 - -_TPU_REPLICATE_ATTR = "_tpu_replicate" -_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status" -_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" - - -def _tpu_system_device_name(job): - """Returns the device name for the TPU_SYSTEM device of `job`.""" - if job is None: - return "/device:TPU_SYSTEM:0" - else: - return "/job:%s/device:TPU_SYSTEM:0" % job - - -def initialize_system(embedding_config=None, job=None): - """Initializes a distributed TPU system for use with TensorFlow. - - Args: - embedding_config: If not None, a `TPUEmbeddingConfiguration` proto - describing the desired configuration of the hardware embedding lookup - tables. If embedding_config is None, no hardware embeddings can be used. - job: The job (the XXX in TensorFlow device specification /job:XXX) that - contains the TPU devices that will be initialized. If job=None it is - assumed there is only one job in the TensorFlow flock, and an error will - be returned if this assumption does not hold. - Returns: - A serialized `TopologyProto` that describes the TPU system. Note: - the topology must be evaluated using `Session.run` before it can be used. 
- """ - config_string = ("" if embedding_config is None else - embedding_config.SerializeToString()) - with ops.device(_tpu_system_device_name(job)): - return tpu_ops.configure_distributed_tpu(embedding_config=config_string) - - -def shutdown_system(job=None): - """Shuts down a running a distributed TPU system.""" - with ops.device(_tpu_system_device_name(job)): - shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu() - return shutdown_distributed_tpu - - -def core(num): - """Returns the device name for a core in a replicated TPU computation. - - Args: - num: the virtual core number within each replica to which operators should - be assigned. - Returns: - A device name, suitable for passing to `tf.device()`. - """ - return "device:TPU_REPLICATED_CORE:{}".format(num) - - -class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): - """A `ControlFlowContext` for nodes inside a TPU computation. - - The primary role of `TPUReplicateContext` is to mark operators inside a - tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ - is a unique name. - - We use a `ControlFlowContext` to perform the annotation since it integrates - with Tensorflow constructs like ResourceVariables. For example, if a - `ResourceVariable` is constructed inside a tpu.replicate() block, the - `ResourceVariable` implementation can use - `with ops.control_dependencies(None)` to build the variable's definition - outside the replicated computation. - """ - - def __init__(self, name, num_replicas, pivot): - """Builds a new TPUReplicateContext. - - Args: - name: a unique name for the context, used to populate the `_tpu_replicate` - attribute. - num_replicas: an integer that gives the number of replicas for the - computation. - pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any - inputs will have a control dependency on the pivot node. This ensures - that nodes are correctly included in any enclosing control flow - contexts. 
- """ - super(TPUReplicateContext, self).__init__() - self._num_replicas = num_replicas - self._outer_device_function_stack = None - self._oc_dev_fn_stack = None - self._outside_compilation_cluster = None - self._outside_compilation_counter = 0 - self._in_gradient_colocation = None - self._gradient_colocation_stack = [] - self._host_compute_core = [] - self._name = name - self._name_as_bytes = compat.as_bytes(name) - self._unsupported_ops = [] - self._pivot = pivot - self._replicated_vars = {} - - def get_replicated_var_handle(self, name, vars_): - """Returns a variable handle for replicated TPU variable 'var'. - - This is a method used by an experimental replicated variable implementation - and is not intended as a public API. - - Args: - name: The common name of the variable. - vars_: The replicated TPU variables. - - Returns: - The handle of the TPU replicated input node. - """ - handle = self._replicated_vars.get(name) - if handle is not None: - return handle - - # Builds a TPUReplicatedInput node for the variable, if one does not already - # exist. The TPUReplicatedInput node must belong to the enclosing - # control-flow scope of the TPUReplicateContext. - # TODO(phawkins): consider changing the contract of the TPU encapsulation - # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope - # instead. 
- - # pylint: disable=protected-access - graph = ops.get_default_graph() - saved_context = graph._get_control_flow_context() - graph._set_control_flow_context(self.outer_context) - handle = tpu_ops.tpu_replicated_input( - [v.handle for v in vars_], name=name + "/handle") - graph._set_control_flow_context(saved_context) - # pylint: enable=protected-access - self._replicated_vars[name] = handle - return handle - - def report_unsupported_operations(self): - if self._unsupported_ops: - op_str = "\n".join([" %s (%s)" % (op.type, op.name) - for op in self._unsupported_ops[:_MAX_WARNING_LINES]]) - logging.warning("%d unsupported operations found: \n%s", - len(self._unsupported_ops), op_str) - if len(self._unsupported_ops) > _MAX_WARNING_LINES: - logging.warning("... and %d more" % - (len(self._unsupported_ops) - _MAX_WARNING_LINES)) - - def EnterGradientColocation(self, op, gradient_uid): - if op is not None: - self._gradient_colocation_stack.append(op) - if not self._outside_compilation_cluster: - try: - outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR) - if self._in_gradient_colocation: - raise NotImplementedError( - "Cannot nest gradient colocation operations outside compilation" - ) - if gradient_uid == "__unsupported__": - raise NotImplementedError( - "No gradient_uid calling gradient within outside_compilation") - # When we take the gradient of an op X in an outside_compilation - # cluster C in a forward computation we would like to put the ops - # corresponding to the gradient of X into a new outside_compilation - # cluster C'. However, if we take the gradient of X twice, the second - # one should get yet another new outside_compilation cluster C''. - # - # The mechanism we adopt is to use a 'root_cluster' which is the - # cluster that X was in before we took gradients, and a 'gradient_uid' - # which is different for every invocation of gradients, and put the - # gradient of X in cluster 'root_cluster.gradient_uid'. 
- # - # When taking a gradient of a gradient, some ops will be colocated - # with Op in the forward pass (e.g., cluster root_cluster) and some in - # the backward pass (e.g., cluster root_cluster.initial_gradient_uid). - # We need all of the grad-of-grad ops to be in the same cluster to - # avoid cyclic dependencies between clusters. We adopt a heuristic - # that puts any op clustered with root_cluster. in - # root_cluster.gradient_uid, even if xxx was initial_gradient_uid. - self._in_gradient_colocation = op - parts = outside_attr.split(".") - cluster = parts[0] + "." + gradient_uid - self._EnterOutsideCompilationScope(cluster=cluster) - except ValueError: - # The attr was not present: do nothing. - pass - - def ExitGradientColocation(self, op, gradient_uid): - if op is not None: - if not self._gradient_colocation_stack: - raise errors.InternalError( - op.node_def, op, - "Badly nested gradient colocation: empty stack when popping Op " + - op.name) - last_op = self._gradient_colocation_stack.pop() - if op is last_op: - if op is self._in_gradient_colocation: - self._in_gradient_colocation = None - self._ExitOutsideCompilationScope() - else: - raise errors.InternalError( - op.node_def, op, "Badly nested gradient colocation, expected " + - last_op + ", got " + op.name) - - def _EnterOutsideCompilationScope(self, cluster=None): - - class FakeOp(object): - """A helper class to determine the current device. - - Supports only the type and device set/get methods needed to run the - graph's _apply_device_function method. 
- """ - - def __init__(self): - self._device = "" - - @property - def type(self): - return "FakeOp" - - @property - def device(self): - return self._device - - def _set_device(self, device): - if isinstance(device, pydev.DeviceSpec): - self._device = device.to_string() - else: - self._device = device - - if self._outside_compilation_cluster: - raise NotImplementedError("Cannot nest outside_compilation clusters") - if cluster: - self._outside_compilation_cluster = cluster - else: - self._outside_compilation_cluster = str(self._outside_compilation_counter) - self._outside_compilation_counter += 1 - graph = ops.get_default_graph() - fake_op = FakeOp() - graph._apply_device_functions(fake_op) # pylint: disable=protected-access - device = pydev.DeviceSpec.from_string(fake_op.device) - if (device.device_type == "TPU_REPLICATED_CORE" and - device.device_index is not None): - self._host_compute_core.append(self._outside_compilation_cluster + ":" + - str(device.device_index)) - self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access - graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access - - def _ExitOutsideCompilationScope(self): - if not self._outside_compilation_cluster: - raise NotImplementedError( - "Attempted to exit outside_compilation scope when not in scope") - self._outside_compilation_cluster = None - graph = ops.get_default_graph() - graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access - - def Enter(self): - if not self._outer_device_function_stack: - # Capture the device function stack at the time of first entry - # since that is the stack that will be used outside_compilation. 
- graph = ops.get_default_graph() - # pylint: disable=protected-access - self._outer_device_function_stack = graph._device_function_stack.copy() - # pylint: enable=protected-access - super(TPUReplicateContext, self).Enter() - - def HostComputeCore(self): - return self._host_compute_core - - def AddOp(self, op): - # pylint: disable=protected-access - if op.type in _BLACKLISTED_OPS: - logging.error("Operation of type %s (%s) is not supported on the TPU. " - "Execution will fail if this op is used in the graph. " % - (op.type, op.name)) - - if op.type in _UNSUPPORTED_OPS: - self._unsupported_ops.append(op) - - if any(x.dtype._is_ref_dtype for x in op.inputs): - raise NotImplementedError( - "Non-resource Variables are not supported inside TPU computations " - "(operator name: %s)" % op.name) - if _TPU_REPLICATE_ATTR in op.node_def.attr: - raise ValueError("TPU computations cannot be nested") - op._set_attr(_TPU_REPLICATE_ATTR, - attr_value_pb2.AttrValue(s=self._name_as_bytes)) - if self._outside_compilation_cluster: - op._set_attr( - _OUTSIDE_COMPILATION_ATTR, - attr_value_pb2.AttrValue( - s=compat.as_bytes(self._outside_compilation_cluster))) - if self._num_replicas > 1 or not self._outside_compilation_cluster: - # Prevent feeding or fetching anything that is being compiled, - # and any replicated outside_compilation Op. - op.graph.prevent_feeding(op) - op.graph.prevent_fetching(op) - - # Remove any control edges from outer control flow contexts. These may cause - # mismatched frame errors. - (internal_control_inputs, - external_control_inputs) = self._RemoveExternalControlEdges(op) - - if not op.inputs: - # Add a control edge from the control pivot to this op. 
- if not internal_control_inputs: - # pylint: disable=protected-access - op._add_control_input(self.GetControlPivot()) - # pylint: enable=protected-access - else: - for index in xrange(len(op.inputs)): - x = op.inputs[index] - real_x = self.AddValue(x) - if real_x != x: - op._update_input(index, real_x) # pylint: disable=protected-access - - if external_control_inputs: - # Use an identity to pull control inputs as data inputs. Note that we - # ignore ops which don't have outputs. TODO(phawkins): fix that. - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - # pylint: disable=protected-access - op._add_control_inputs(external_control_inputs) - # pylint: enable=protected-access - - # Mark op's outputs as seen by this context and any outer contexts. - output_names = [x.name for x in op.outputs] - context = self - while context is not None: - # pylint: disable=protected-access - context._values.update(output_names) - context = context._outer_context - # pylint: enable=protected-access - - if self._outer_context: - self._outer_context.AddInnerOp(op) - - def AddValue(self, val): - """Add `val` to the current context and its outer context recursively.""" - if val.name in self._values: - # Use the real value if it comes from outer context. 
- result = self._external_values.get(val.name) - return val if result is None else result - - result = val - self._values.add(val.name) - if self._outer_context: - result = self._outer_context.AddValue(val) - self._values.add(result.name) - - self._external_values[val.name] = result - - return result - - def AddInnerOp(self, op): - self.AddOp(op) - if self._outer_context: - self._outer_context.AddInnerOp(op) - - @property - def grad_state(self): - # Define the gradient loop state associated with the TPUReplicateContext to - # be None as the TPUReplicateContext does not get nested nor does the - # grad_state outside the TPUReplicateContext affect the graph inside so the - # grad_state should be as if this is the top-level gradient state. - return None - - @property - def back_prop(self): - """Forwards to the enclosing while context, if any.""" - if self.GetWhileContext(): - return self.GetWhileContext().back_prop - return False - - def GetControlPivot(self): - return self._pivot - - -def outside_compilation(computation, *args, **kwargs): - """Builds part of a computation outside any current TPU replicate scope. - - Args: - computation: A Python function that builds the computation to - place on the host. - *args: the positional arguments for the computation. - **kwargs: the keyword arguments for the computation. - - Returns: - The Tensors returned by computation. 
- """ - args = [] if args is None else args - graph = ops.get_default_graph() - - # If we are in a TPUReplicateContext, signal that we are now - # outside_compilation - initial_context = graph._get_control_flow_context() # pylint: disable=protected-access - context = initial_context - while context: - if isinstance(context, TPUReplicateContext): - context._EnterOutsideCompilationScope() # pylint: disable=protected-access - context = context.outer_context - - retval = computation(*args, **kwargs) - - # If we are in a TPUReplicateContext, signal that we are no longer - # outside_compilation - final_context = graph._get_control_flow_context() # pylint: disable=protected-access - if initial_context is not final_context: - raise NotImplementedError( - "Control-flow context cannot be different at start and end of an " - "outside_compilation scope") - context = initial_context - while context: - if isinstance(context, TPUReplicateContext): - context._ExitOutsideCompilationScope() # pylint: disable=protected-access - context = context.outer_context - - return retval - - -def replicate(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Builds a graph operator that runs a replicated TPU computation. - - Args: - computation: A Python function that builds the computation to replicate. - inputs: A list of lists of input tensors or `None` (equivalent to - `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. Each input can be a nested structure - containing values that are convertible to tensors. Note that passing an - N-dimension list of compatible values will result in a N-dimention list of - scalar tensors rather than a single Rank-N tensors. If you need different - behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to computation. 
- device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each replica of the computation uses - only one core, and there is either only one replica, or the number of - replicas is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - Returns: - A list of outputs, indexed by `[replica_num]` each output can be a nested - structure same as what computation() returns with a few exceptions. - - Exceptions include: - 1) None output: a NoOp would be returned which control-depends on - computation. - 2) Single value output: A tuple containing the value would be returned. - 3) Operation-only outputs: a NoOp would be returned which - control-depends on computation. - TODO(b/121383831): Investigate into removing these special cases. - - Raises: - ValueError: If all replicas do not have equal numbers of input tensors. - ValueError: If the number of inputs per replica does not match - the number of formal parameters to `computation`. - """ - return split_compile_and_replicate(computation, inputs, infeed_queue, - device_assignment, name)[1] - - -def split_compile_and_replicate(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None, - use_tpu=True): - """Builds graph operators that runs compilation and replicated computation. - - This is a lower level interface than replicate that returns a separate compile - and execute output tensor. In the generated graph the compile op feeds into - the execute op and no additional compilation is incurred when running the - compile op before the execute op. The compile op returns additional - information about the compilation but does not return the compiled program. - - Args: - computation: A Python function that builds the computation to replicate. 
- inputs: A list of lists of input tensors or `None` (equivalent to - `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. Each input can be a nested structure - containing values that are convertible to tensors. Note that passing an - N-dimension list of compatible values will result in a N-dimention list of - scalar tensors rather than a single Rank-N tensors. If you need different - behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to computation. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each replica of the computation uses - only one core, and there is either only one replica, or the number of - replicas is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU - backends. Currently, only supports a default placement (computation is - placed on GPU if one is available, and on CPU if not). - Returns: - A list of lists with the first list corresponding to the compile op and the - second a list of output tensors, indexed by `[replica_num][output_num]`. - Raises: - ValueError: If all replicas do not have equal numbers of input tensors. - ValueError: If the number of inputs per replica does not match - the number of formal parameters to `computation`. - """ - del name - inputs = [[]] if inputs is None else inputs - - metadata_kwargs = {} - if device_assignment is not None: - # Turn the Numpy array into a flattened list so we can pass it as an - # operator attribute. 
- metadata_kwargs = { - "topology": - device_assignment.topology.serialized(), - "device_assignment": - device_assignment.core_assignment.flatten().tolist() - } - # TODO(phawkins): remove this case after the forward compatibility window - # expires on 2018-10-5. - if api_compat.forward_compatible(2018, 10, 5): - metadata_kwargs["num_cores_per_replica"] = ( - device_assignment.num_cores_per_replica) - else: - metadata_kwargs["computation_shape"] = [ - device_assignment.num_cores_per_replica - ] - - if ((not isinstance(inputs, list)) or - any(not isinstance(inp, (list, tuple)) for inp in inputs)): - raise TypeError("tpu.replicate() inputs must be a list of lists/tuples") - - num_replicas = len(inputs) - - # No replicas? Nothing to do. - if num_replicas == 0: - return [] - - # Checks all replicas have the same structure. - for i in xrange(1, num_replicas): - nest.assert_same_structure(inputs[0], inputs[i]) - - # Flatten inputs. - flat_inputs = [ - nest.flatten(per_replica_input) for per_replica_input in inputs - ] - # Converts inputs to Tensors. - flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs] - - # Verifies that all replicas have matching numbers and types of inputs - flat_input_types = [x.dtype for x in flat_inputs[0]] - input_arity = len(inputs[0]) - flat_input_arity = len(flat_input_types) - for i in range(num_replicas): - if len(inputs[i]) != input_arity: - raise ValueError("Replicas must have the same number of inputs. " - "Replica 0 had {} inputs, replica {} had {} " - "inputs.".format(input_arity, i, len(inputs[i]))) - - types = [x.dtype for x in flat_inputs[i]] - if types != flat_input_types: - raise ValueError("Replicas must have matching input types. 
Replica 0 had " - "input types {}, replica {} had input types {}".format( - flat_input_types, i, types)) - - arg_error = xla.check_function_argument_count( - computation, input_arity, infeed_queue) - if arg_error is not None: - if infeed_queue is None: - raise TypeError( - "Supplied computation cannot be called with the specified inputs. " - "You specified %d inputs: %s, but the computation needs %s" % ( - input_arity, str([i.name for i in inputs[0]]), arg_error)) - else: - raise TypeError( - "Supplied computation cannot be called with the specified inputs. " - "You specified %d inputs: %s and %d additional inputs from infeed," - " but the computation needs %s" % (input_arity, str( - [i.name - for i in inputs[0]]), infeed_queue.number_of_tuple_elements, - arg_error)) - - graph = ops.get_default_graph() - - # Fan-in: Builds a TPUReplicatedInput node for each input. - computation_inputs = [] - for i in range(0, flat_input_arity): - replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)] - computation_inputs.append( - tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) - - cluster_name = graph.unique_name("cluster") - pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") - context = TPUReplicateContext( - name=cluster_name, num_replicas=num_replicas, pivot=pivot) - try: - context.Enter() - - metadata = tpu_ops.tpu_replicate_metadata( - num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) - - with tpu_function.tpu_shard_context( - num_replicas), ops.control_dependencies([metadata]): - - # Add identity ops so even unused inputs are "consumed" by the - # computation. This is to avoid orphaned TPUReplicatedInput nodes. - # TODO(phawkins): consider instead pruning unused TPUReplicatedInput - # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. 
- computation_inputs = [ - array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs) - ] - for i in computation_inputs: - # pylint: disable=protected-access - i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True)) - # pylint: enable=protected-access - - # Unflatten the computation inputs to match original input structure. - computation_inputs = nest.pack_sequence_as( - structure=inputs[0], flat_sequence=computation_inputs) - - # If there is an infeed queue, adds the dequeued values to the - # computation's inputs. - if infeed_queue is not None: - infeed_queue.set_number_of_shards(num_replicas) - for t in infeed_queue.generate_dequeue_op(): - computation_inputs.append(t) - - # Only resource variables work inside a TPU computation, so turn on - # resource variables for the computation. - # TODO(phawkins): consider removing this code. It will - # be less confusing to clients if they knowingly choose to use resource - # variables. - # Partitioned variables is not supported (b/112311320). - vscope = variable_scope.get_variable_scope() - saved_use_resource = vscope.use_resource - saved_custom_getter = vscope.custom_getter - - def custom_getter(getter, name, *args, **kwargs): - """Variables on TPU have a few restrictions.""" - partitioner = kwargs["partitioner"] - if partitioner is not None: - kwargs["partitioner"] = None - logging.warning( - "Partitioned variables are not supported on TPU. Got " - "`partitioner` that is {} for variable {}. " - "Setting `partitioner` to `None`." 
- .format(partitioner, name)) - if saved_custom_getter is None: - return getter(name, *args, **kwargs) - else: - return saved_custom_getter(getter, name, *args, **kwargs) - - vscope.set_use_resource(True) - vscope.set_custom_getter(custom_getter) - - outputs = computation(*computation_inputs) - - vscope.set_use_resource(saved_use_resource) - vscope.set_custom_getter(saved_custom_getter) - - outputs_is_flat = xla.is_flat(outputs) - if outputs_is_flat: - output_tensors, control_deps = _postprocess_flat_outputs(outputs) - else: - output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) - - context.ExitResult(output_tensors) - finally: - context.report_unsupported_operations() - context.Exit() - host_compute_core = context.HostComputeCore() - - if host_compute_core: - attr_value = attr_value_pb2.AttrValue() - attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core]) - metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access - - with ops.control_dependencies([metadata]): - if use_tpu: - compile_status = tpu_ops.tpu_compilation_result() - op = compile_status.op - attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) - op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access - else: - compile_status = control_flow_ops.no_op(name="compilation_status") - - if not output_tensors: - # Returns a list of NoOps dependent on the replication Op, indexed by - # [replica_num]. - return [ - compile_status, - [ - control_flow_ops.group(control_deps, name="shard_%d" % i) - for i in range(num_replicas) - ] - ] - - # Fan-out: Builds a TPUReplicatedOutput node for each output. - replicated_outputs = [[] for i in xrange(num_replicas)] - for i, t in enumerate(output_tensors): - # Fan-out: Builds a TPUReplicatedOutput node for each output. 
- ys = tpu_ops.tpu_replicated_output( - t, num_replicas, name="output{}".format(i)) - - # Wraps the outputs in identity operators so the names of any possible - # `fetch` nodes are preserved by the replication rewrite. - with ops.control_dependencies(control_deps): - for replica in xrange(num_replicas): - replicated_outputs[replica].append( - array_ops.identity( - ys[replica], name="output_%d_shard_%d" % (i, replica))) - - if not outputs_is_flat: - replicated_outputs = [ - nest.pack_sequence_as(outputs, replica_outs) - for replica_outs in replicated_outputs - ] - - return [compile_status, replicated_outputs] - - -def _postprocess_flat_outputs(outputs): - """Validates non-flat outputs, add backs device assignments and other attrs. - - Args: - outputs: Output from `computation` inside `tpu.rewrite`. - - Returns: - Tensors and Operations extracted from outputs. - """ - # Following code segment is to preserve legacy behavior. Previously we only - # supported flat outputs and thus for consistency it was nice to convert even - # single element into a tuple. But now that we support arbitrary output - # structure, this is no longer necessary. - # TODO(b/121383831): Migrate all legacy use cases and delete this special - # case. - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, makes it a tuple. - if not isinstance(outputs, collections.Sequence): - outputs = (outputs,) - - # Append `no_op` here so that fetching any return value of this function - # will trigger TPUExecute node. - outputs += (control_flow_ops.no_op(),) - try: - with ops.device(core(0)): - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - "TPU function return values must all either be Operations or " - "convertible to Tensors. Got '%s'" % str(e)) - - # Separates the returned Operations and Tensors. 
- output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - "TPU functions must return zero-or more Tensor values followed by " - "zero or more Operations.") - - # Wraps outputs in Identity ops. Otherwise a replicated input copied - # straight to an output would bypass the replicate(). This would be bad - # because the TPUReplicatedInput/TPUReplicatedOutput operator would not - # be rewritten away, leading to a runtime error. - # TODO(phawkins): extend the rewrite to elide these nodes instead. - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else core(0)): - o = array_ops.identity(t) - # pylint: disable=protected-access - o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) - # pylint: enable=protected-access - new_output_tensors.append(o) - return new_output_tensors, output_operations - - -def _postprocess_non_flat_outputs(outputs): - """Validates non-flat outputs, add backs device assignments and other attrs. - - Args: - outputs: Output from `computation` inside `tpu.rewrite`. - - Returns: - Tensors extracted from outputs and an empty list because Operations are not - allowed in non-flat outputs.. - """ - - # Flatten output items. - flat_outputs = nest.flatten(outputs) - - # Convert all non-Operation outputs to Tensors. - for i, o in enumerate(flat_outputs): - if isinstance(o, ops.Operation): - raise ValueError( - "tpu.rewrite does not support Operation as return value in non-flat " - "output structure. You can set returned Operations as control " - "dependencies of returned Tensors so Operations are triggered when " - 'Tensors are evaluated. 
Operation found: "%s"' % o.name) - - try: - o = ops.convert_to_tensor(o) - except Exception as e: - raise ValueError( - "TPU function return values must all either be Operations or " - 'convertible to Tensors. Got error: "%s"' % str(e)) - - # Wraps outputs in Identity ops. Otherwise a replicated input copied - # straight to an output would bypass the replicate(). This would be bad - # because the TPUReplicatedInput/TPUReplicatedOutput operator would not - # be rewritten away, leading to a runtime error. - # TODO(phawkins): extend the rewrite to elide these nodes instead. - with ops.device(core(0)): - o = array_ops.identity(o) - # pylint: disable=protected-access - o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) - # pylint: enable=protected-access - flat_outputs[i] = array_ops.identity(o) - - # All flat_outputs are Tensors, and no Operations. - return flat_outputs, [] - - -def split_compile_and_shard(computation, - inputs=None, - num_shards=1, - input_shard_axes=None, - outputs_from_all_shards=True, - output_shard_axes=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Shards `computation` for parallel execution. - - `inputs` must be a list of Tensors or None (equivalent to an empty list), each - of which has a corresponding split axis (from `input_shard_axes`). Each input - is split into `num_shards` pieces along the corresponding axis, and - computation is applied to each shard in parallel. - - Tensors are broadcast to all shards if they are lexically captured by - `computation`. e.g., - - x = tf.constant(7) - def computation(): - return x + 3 - ... = shard(computation, ...) - - TODO(phawkins): consider adding support for broadcasting Tensors passed - as inputs. - - If `outputs_from_all_shards` is true, the outputs from all shards of - `computation` are concatenated back together along their `output_shards_axes`. - Otherwise, each output is taken from an arbitrary shard. 
- - Inputs and outputs of the computation must be at least rank-1 Tensors. - - Args: - computation: A Python function that builds a computation to apply to each - shard of the input. - inputs: A list of input tensors or None (equivalent to an empty list). Each - input tensor has a corresponding shard axes, given by `input_shard_axes`, - which must have size divisible by `num_shards`. - num_shards: The number of shards. - input_shard_axes: A list of dimensions along which to shard `inputs`, or - `None`. `None` means "shard all inputs along dimension 0". If not `None`, - there must be one dimension per input. - outputs_from_all_shards: Boolean or list of boolean. For each output, if - `True`, outputs from all shards are concatenated along the corresponding - `output_shard_axes` entry. Otherwise, each output is taken - from an arbitrary shard. If the argument is a boolean, the argument's - value is used for each output. - output_shard_axes: A list of dimensions along which to concatenate the - outputs of `computation`, or `None`. `None` means "concatenate all outputs - along dimension 0". If not `None`, there must be one dimension per output. - Ignored if `outputs_from_all_shards` is False. - infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs - of `computation`. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each shard of the computation uses - only one core, and there is either only one shard, or the number of shards - is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - Returns: - A tuple of (compile op, [output tensors]). 
- Raises: - ValueError: If num_shards <= 0 - ValueError: If len(input_shard_axes) != len(inputs) - ValueError: If len(output_shard_axes) != len(outputs from `computation`) - """ - - if num_shards <= 0: - raise ValueError("num_shards must be a positive integer.") - - inputs = [] if inputs is None else inputs - if not isinstance(inputs, list): - raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.") - - # Converts inputs to Tensors. - inputs = [ops.convert_to_tensor(x) for x in inputs] - - if input_shard_axes is None: - input_shard_axes = [0] * len(inputs) - if len(inputs) != len(input_shard_axes): - raise ValueError("Length of input_shard_axes must be equal to the number " - "of inputs.") - - if inputs: - # Splits the `inputs` along the corresponding `input_shard_axes`, giving - # lists with layout [input][shard] - split_inputs = [ - array_ops.split(x, num_shards, axis=axis) - for (axis, x) in zip(input_shard_axes, inputs)] - - # Transposes the input lists to have layout [shard][input] - transposed_inputs = [list(i) for i in zip(*split_inputs)] - else: - transposed_inputs = [[]] * num_shards - - compile_op, outputs = split_compile_and_replicate( - computation, - transposed_inputs, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name) - - # There must be at least one shard since num_shards > 0. - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - if isinstance(outputs[0], ops.Operation): - # pylint: enable=indexing-exception - # There were no outputs from the computation and replicate returned a list - # of NoOps with control dependencies on the computation. Return the first - # one so it can be used as a control dependency or fetch node. - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - return compile_op, [outputs[0]] - # pylint: enable=indexing-exception - - # TODO(b/36647078) remove disable when pylint bug is fixed. 
- # pylint: disable=indexing-exception - num_outputs = len(outputs[0]) - # pylint: enable=indexing-exception - - if output_shard_axes is None: - output_shard_axes = [0] * num_outputs - if num_outputs != len(output_shard_axes): - raise ValueError("Length of output_shard_axes must be equal to the number " - "of outputs.") - - if isinstance(outputs_from_all_shards, bool): - outputs_from_all_shards = [outputs_from_all_shards] * num_outputs - - if num_outputs != len(outputs_from_all_shards): - raise ValueError("Length of outputs_from_all_shards must be equal to the " - "number of outputs.") - - results = [] - for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, - zip(*outputs)): - if all_shards: - # Concatenate all of the outputs together (use stack for scalars). - shape = x[0].shape - is_scalar = shape is not None and (shape.ndims == 0) - results.append((array_ops.stack(list(x)) if is_scalar - else array_ops.concat(list(x), axis=axis))) - else: - # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. - results.append(x[0]) - - return compile_op, results - - -def shard(computation, - inputs=None, - num_shards=1, - input_shard_axes=None, - outputs_from_all_shards=True, - output_shard_axes=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Shards `computation` for parallel execution. - - `inputs` must be a list of Tensors or None (equivalent to an empty list), each - of which has a corresponding split axis (from `input_shard_axes`). Each input - is split into `num_shards` pieces along the corresponding axis, and - computation is applied to each shard in parallel. - - Tensors are broadcast to all shards if they are lexically captured by - `computation`. e.g., - - x = tf.constant(7) - def computation(): - return x + 3 - ... = shard(computation, ...) - - TODO(phawkins): consider adding support for broadcasting Tensors passed - as inputs. 
- - If `outputs_from_all_shards` is true, the outputs from all shards of - `computation` are concatenated back together along their `output_shards_axes`. - Otherwise, each output is taken from an arbitrary shard. - - Inputs and outputs of the computation must be at least rank-1 Tensors. - - Args: - computation: A Python function that builds a computation to apply to each - shard of the input. - inputs: A list of input tensors or None (equivalent to an empty list). Each - input tensor has a corresponding shard axes, given by `input_shard_axes`, - which must have size divisible by `num_shards`. - num_shards: The number of shards. - input_shard_axes: A list of dimensions along which to shard `inputs`, or - `None`. `None` means "shard all inputs along dimension 0". If not `None`, - there must be one dimension per input. - outputs_from_all_shards: Boolean or list of boolean. For each output, if - `True`, outputs from all shards are concatenated along the corresponding - `output_shard_axes` entry. Otherwise, each output is taken - from an arbitrary shard. If the argument is a boolean, the argument's - value is used for each output. - output_shard_axes: A list of dimensions along which to concatenate the - outputs of `computation`, or `None`. `None` means "concatenate all outputs - along dimension 0". If not `None`, there must be one dimension per output. - Ignored if `outputs_from_all_shards` is False. - infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs - of `computation`. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each shard of the computation uses - only one core, and there is either only one shard, or the number of shards - is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. 
- Returns: - A list of output tensors. - Raises: - ValueError: If num_shards <= 0 - ValueError: If len(input_shard_axes) != len(inputs) - ValueError: If len(output_shard_axes) != len(outputs from `computation`) - """ - return split_compile_and_shard( - computation, - inputs=inputs, - num_shards=num_shards, - input_shard_axes=input_shard_axes, - outputs_from_all_shards=outputs_from_all_shards, - output_shard_axes=output_shard_axes, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name)[1] - - -def batch_parallel(computation, - inputs=None, - num_shards=1, - infeed_queue=None, - device_assignment=None, - name=None): - """Shards `computation` along the batch dimension for parallel execution. - - Convenience wrapper around shard(). - - `inputs` must be a list of Tensors or None (equivalent to an empty list). - Each input is split into `num_shards` pieces along the 0-th dimension, and - computation is applied to each shard in parallel. - - Tensors are broadcast to all shards if they are lexically captured by - `computation`. e.g., - - x = tf.constant(7) - def computation(): - return x + 3 - ... = shard(computation, ...) - - The outputs from all shards are concatenated back together along their 0-th - dimension. - - Inputs and outputs of the computation must be at least rank-1 Tensors. - - Args: - computation: A Python function that builds a computation to apply to each - shard of the input. - inputs: A list of input tensors or None (equivalent to an empty list). The - 0-th dimension of each Tensor must have size divisible by `num_shards`. - num_shards: The number of shards. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to `computation`. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. 
The - `DeviceAssignment` may be omitted if each shard of the computation uses - only one core, and there is either only one shard, or the number of shards - is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - Returns: - A list of output tensors. - Raises: - ValueError: If `num_shards <= 0` - """ - return shard( - computation, - inputs, - num_shards=num_shards, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name) - - -def rewrite(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Rewrites `computation` for execution on a TPU system. - - Args: - computation: A Python function that builds a computation to apply to the - input. If the function takes n inputs, 'inputs' should be a list of n - tensors. - - `computation` may return a list of operations and tensors. Tensors must - come before operations in the returned list. The return value of - `rewrite` is a list of tensors corresponding to the tensors from the - output of `computation`. - - All `Operation`s constructed during `computation` will be executed when - evaluating any of the returned output tensors, not just the ones returned. - inputs: A list of input tensors or `None` (equivalent to an empty list). - Each input can be a nested structure containing values that are - convertible to tensors. Note that passing an N-dimension list of - compatible values will result in a N-dimention list of scalar tensors - rather than a single Rank-N tensors. If you need different behavior, - convert part of inputs to tensors with `tf.convert_to_tensor`. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to `computation`. - device_assignment: if not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. 
May be omitted for a single-core computation, in which - case the core attached to task 0, TPU device 0 is used. - name: (Deprecated) Does nothing. - Returns: - Same data structure as if computation(*inputs) is called directly with some - exceptions for correctness. Exceptions include: - 1) None output: a NoOp would be returned which control-depends on - computation. - 2) Single value output: A tuple containing the value would be returned. - 3) Operation-only outputs: a NoOp would be returned which - control-depends on computation. - TODO(b/121383831): Investigate into removing these special cases. - """ - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - return replicate( - computation, - None if inputs is None else [inputs], - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name)[0] - # pylint: enable=indexing-exception - - # Operations that indicate some error in the user's inference graph. -_BLACKLISTED_INFERENCE_OPS = set([ - "ReadVariableOp", - "AssignVariableOp", - "AssignAddVariableOp", - "AssignSubVariableOp", - "VarHandleOp", - "Variable", - "VariableV2", -]) - - -def under_tpu_inference_context(): - """Check if it is currently under `tpu.rewrite_for_inference()`.""" - graph = ops.get_default_graph() - - context = graph._get_control_flow_context() # pylint: disable=protected-access - while context: - if isinstance(context, _TPUInferenceContext): - return True - context = context.outer_context - - return False - - -class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): - """A `ControlFlowContext` for nodes inside a TPU inference computation. - - The primary role of `TPUReplicateContext` is to sanity check operators inside - a tpu.rewrite_for_inference() computation. 
- """ - - def __init__(self, name): - super(_TPUInferenceContext, self).__init__() - self._name = name - - def AddOp(self, op): - self._AddOpInternal(op) - - def _AddOpInternal(self, op): - # pylint: disable=protected-access - if op.type in _BLACKLISTED_INFERENCE_OPS: - raise NotImplementedError( - "Operation of type %s (%s) is not supported on the TPU for inference." - " Execution will fail if this op is used in the graph. Make sure your" - " variables are using variable_scope." % (op.type, op.name)) - if self._outer_context: - self._outer_context.AddInnerOp(op) - - def AddValue(self, val): - result = val - if self._outer_context: - result = self._outer_context.AddValue(val) - return result - - def AddInnerOp(self, op): - self._AddOpInternal(op) - - @property - def grad_state(self): - return None - - -@experimental -def validate_inference_rewrite_for_variables(graph): - """Validates whether rewrite_for_inference() 'worked' for variables. - - The rewrite_for_inference() method is supposed to append GuaranteeConstOps - after ReadVariableOps, but this mechanism works only if you are using - tf.get_variable() to create and access variables in your tpu computation. - This validation method can be called immediately after calling - tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added - to the graph. - - Typical usages: - tpu.validate_inference_rewrite_for_variables(tf.get_default_graph()) - - tpu.validate_inference_rewrite_for_variables(sess.graph) - - Args: - graph: The graph which needs to be validated. - Raises: - RuntimeError: if validation failed. - """ - if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): - raise RuntimeError( - "No GuaranteeConst ops found in the graph after running " - "tpu.rewrite_for_inference(...). 
Please check that you are using " - "tf.get_variable() to create and access variables in your tpu " - "computation.") - - -@experimental -def rewrite_for_inference(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Rewrites `computation` for inference on a TPU system. - - Other than 'rewriting' the computation to run on a TPU, if using variables - in your computation, it moves the ReadVariableOps outside the TPU - computation, and adds GuaranteeConst ops just after the ReadVariableOps. - This mechanism works only if you are using tf.get_variable() to create and - access variables in your tpu computation. You can validate whether this - worked, by calling validate_inference_rewrite_for_variables() method - immediately after this method to check whether GuaranteeConstOps where - added to the graph. - - Args: - computation: A Python function that builds a computation to apply to the - input. If the function takes n inputs, 'inputs' should be a list of n - tensors. If the function returns m outputs, rewrite will return a list of - m tensors. - inputs: A list of input tensors or `None` (equivalent to an empty list). - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to `computation`. - device_assignment: if not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. May be omitted for a single-core computation, in which - case the core attached to task 0, TPU device 0 is used. - name: The name of the operator. - Returns: - A list of output tensors. 
- """ - - def guarantee_const_getter(getter, name, *args, **kwargs): - with ops.control_dependencies(None): - return array_ops.guarantee_const( - getter(name, *args, **kwargs), name=name + "/GuaranteeConst") - - def wrapped_computation(*args, **kwargs): - """Execute computation under `_TPUInferenceContext`.""" - context = _TPUInferenceContext( - name=ops.get_default_graph().unique_name("rewrite_for_inference")) - try: - context.Enter() - - vscope = variable_scope.get_variable_scope() - prev_custom_getter = vscope.custom_getter - prev_caching_device = vscope.caching_device - vscope.set_custom_getter(guarantee_const_getter) - vscope.set_caching_device(lambda op: op.device) - - result = computation(*args, **kwargs) - - vscope.set_custom_getter(prev_custom_getter) - vscope.set_caching_device(prev_caching_device) - finally: - context.Exit() - return result - - # pylint: disable=undefined-variable - return rewrite( - wrapped_computation, - inputs=inputs, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name) - # pylint: enable=undefined-variable +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu import * +# used by tests +from tensorflow.python.tpu.tpu import _TPU_REPLICATE_ATTR +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 9f8d14706845baa1ed45c84b2c15d372915a0eb4..c36aaa38c0e4823bfc438773e4aa5b5109794da4 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -1,275 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== - -"""A RunConfig subclass with TPU support.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import json -import os - -from tensorflow.contrib.tpu.python.tpu import util as util_lib -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.estimator import run_config as run_config_lib -from tensorflow.python.platform import tf_logging as logging - -# pylint: disable=protected-access -_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV -_SERVICE_KEY = run_config_lib._SERVICE_KEY -_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name' -# pylint: enable=protected-access - - -class InputPipelineConfig(object): - r"""Please see the definition of these values in TPUConfig.""" - PER_SHARD_V1 = 1 - PER_HOST_V1 = 2 - PER_HOST_V2 = 3 - BROADCAST = 4 - - -class TPUConfig( - collections.namedtuple('TPUConfig', [ - 'iterations_per_loop', - 'num_shards', - 'num_cores_per_replica', - 'per_host_input_for_training', - 'tpu_job_name', - 'initial_infeed_sleep_secs', - 'input_partition_dims', - ])): - r"""TPU related configuration required by `TPUEstimator`. - - Args: - iterations_per_loop: This is the number of train steps running in TPU - system before returning to CPU host for each `Session.run`. 
This means - global step is increased `iterations_per_loop` times in one `Session.run`. - It is recommended to be set as number of global steps for next checkpoint. - num_shards: (Deprecated, ignored by TPUEstimator). - The number of model replicas in the system. For non-model-parallelism - case, this number equals the total number of TPU cores. For - model-parallelism, the total number of TPU cores equals - num_cores_per_replica * num_shards. - num_cores_per_replica: Defaults to `None`, which disables model parallelism. - An integer which describes the number of TPU cores per model replica. This - is required by model-parallelism which enables partitioning - the model to multiple cores. Currently num_cores_per_replica must be - 1, 2, 4, or 8. - per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`, - `input_fn` is invoked once on each host. With the per-core input pipeline - configuration, it is invoked once for each core. - With a global batch size `train_batch_size` in `TPUEstimator` constructor, - the batch size for each shard is `train_batch_size` // #hosts in the - `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is - `train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only - invoked once on host 0 and the tensors are broadcasted to all other - replicas. The batch size equals to train_batch_size`. With the per-core - input pipeline configuration, the shard batch size is also - `train_batch_size` // #cores. - Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN. - tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred - within TPUEstimator, however when using ClusterSpec propagation in more - esoteric cluster configurations, you may need to specify the job name as a - string. - initial_infeed_sleep_secs: The number of seconds the infeed thread should - wait before enqueueing the first batch. This helps avoid timeouts for - models that require a long compilation time. 
- input_partition_dims: A nested list to describe the partition dims - for all the tensors from input_fn(). The structure of - input_partition_dims must match the structure of `features` and - `labels` from input_fn(). The total number of partitions must match - `num_cores_per_replica`. For example, if input_fn() returns two tensors: - images with shape [N, H, W, C] and labels [N]. - input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4 - pieces and feed into 4 TPU cores. labels tensor are directly broadcasted - to all the TPU cores since the partition dims is `None`. - Current limitations: This feature is only supported with the PER_HOST_V2 - input mode. - - Raises: - ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16. - """ - - def __new__(cls, - iterations_per_loop=2, - num_shards=None, - num_cores_per_replica=None, - per_host_input_for_training=True, - tpu_job_name=None, - initial_infeed_sleep_secs=None, - input_partition_dims=None): - - # Check iterations_per_loop. - util_lib.check_positive_integer(iterations_per_loop, - 'TPUConfig iterations_per_loop') - - # Check num_shards. 
- if num_shards is not None: - util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') - - if input_partition_dims is not None: - if len(input_partition_dims) != 1 and len(input_partition_dims) != 2: - raise ValueError( - 'input_partition_dims must be a list/tuple with one or two' - ' elements.') - - if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2: - raise ValueError( - 'input_partition_dims is only supported in PER_HOST_V2 mode.') - - if num_cores_per_replica is None: - raise ValueError( - 'input_partition_dims requires setting num_cores_per_replica.') - - # Check num_cores_per_replica - if num_cores_per_replica is not None: - if num_cores_per_replica not in [1, 2, 4, 8, 16]: - raise ValueError( - 'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format( - str(num_cores_per_replica))) - - # per_host_input_for_training may be True, False, or integer in [1..3]. - # Map legacy values (True, False) to numeric values. - if per_host_input_for_training is False: - per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1 - elif per_host_input_for_training is True: - per_host_input_for_training = InputPipelineConfig.PER_HOST_V1 - - # Check initial_infeed_sleep_secs. 
- if initial_infeed_sleep_secs: - util_lib.check_positive_integer(initial_infeed_sleep_secs, - 'TPUConfig initial_infeed_sleep_secs') - - tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config() - - return super(TPUConfig, cls).__new__( - cls, - iterations_per_loop=iterations_per_loop, - num_shards=num_shards, - num_cores_per_replica=num_cores_per_replica, - per_host_input_for_training=per_host_input_for_training, - tpu_job_name=tpu_job_name, - initial_infeed_sleep_secs=initial_infeed_sleep_secs, - input_partition_dims=input_partition_dims) - - -class RunConfig(run_config_lib.RunConfig): - """RunConfig with TPU support.""" - - def __init__(self, - tpu_config=None, - evaluation_master=None, - master=None, - cluster=None, - **kwargs): - """Constructs a RunConfig. - - Args: - tpu_config: the TPUConfig that specifies TPU-specific configuration. - evaluation_master: a string. The address of the master to use for eval. - Defaults to master if not set. - master: a string. The address of the master to use for training. - cluster: a ClusterResolver - **kwargs: keyword config parameters. - - Raises: - ValueError: if cluster is not None and the provided session_config has a - cluster_def already. - """ - super(RunConfig, self).__init__(**kwargs) - self._tpu_config = tpu_config or TPUConfig() - self._cluster = cluster - - # If user sets master and/or evaluation_master explicitly, including empty - # string '', take it. Otherwise, take the values set by parent class. - if master is not None: - if cluster is not None: - raise ValueError('Both master and cluster are set.') - self._master = master - else: - if cluster: - self._master = cluster.master() - - if evaluation_master is not None: - self._evaluation_master = evaluation_master - elif (not self._evaluation_master and - self.task_type != run_config_lib.TaskType.EVALUATOR): - # If the task type is EVALUATOR, it means some cluster manager sets the - # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG. 
- # - # Otherwise, it means user executes the code without external cluster - # manager. For that, we optimize the user experience by setting - # evaluation_master to master, unless user overwrites it. - self._evaluation_master = self._master - - # Set the ClusterSpec to use - if cluster: - self._cluster_spec = cluster.cluster_spec() - - # Merge the cluster_def into the ConfigProto. - if self._session_config is None: # pylint: disable=access-member-before-definition - self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) - if self._session_config.HasField('cluster_def'): - raise ValueError( - 'You cannot provide a ClusterResolver and ' - 'session_config.cluster_def.') - if self._cluster_spec: - self._session_config.cluster_def.CopyFrom( - self._cluster_spec.as_cluster_def()) - - def _maybe_overwrite_session_config_for_distributed_training(self): - # Overrides the parent class session_config overwrite for between-graph. TPU - # runs with in-graph, which should not have device filter. Doing nothing - # ("pass") basically disables it. - pass - - @property - def evaluation_master(self): - return self._evaluation_master - - @property - def master(self): - return self._master - - @property - def tpu_config(self): - return self._tpu_config - - @property - def cluster(self): - return self._cluster - - def replace(self, **kwargs): - if 'tpu_config' not in kwargs: - return super(RunConfig, self).replace(**kwargs) - - tpu_config = kwargs.pop('tpu_config') - new_instance = super(RunConfig, self).replace(**kwargs) - new_instance._tpu_config = tpu_config # pylint: disable=protected-access - return new_instance - - -def _get_tpu_job_name_from_tf_config(): - """Extracts the TPU job name from TF_CONFIG env variable.""" - # TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster - # spec propagation. 
- tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) - tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME) - if tpu_job_name: - logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name) - return tpu_job_name +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_config import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 672462447944b777375331d49727c4d5366cf295..b77b010cba6bf32c3b6d170bc522eebfb6a04f77 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -1,725 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# =================================================================== -"""TPU system metadata and associated tooling.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from contextlib import contextmanager -import copy - -from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment -from tensorflow.contrib.tpu.python.tpu import tpu_config -from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.platform import tf_logging as logging - - -_DEFAULT_JOB_NAME = 'tpu_worker' -_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' -_LOCAL_MASTERS = ('', 'local') -_NUM_CORES_TO_COMPUTATION_SHAPE = { - 1: [1, 1, 1], - 2: [1, 1, 2], - 4: [1, 2, 2], - 8: [2, 2, 2], - 16: [4, 2, 2], -} - - -class TPUContext(object): - """A context that holds the current configuration of the TPU computation.""" - - def __init__(self, - internal_ctx, - input_device=None, - invocation_index=None, - call_from_input_fn=True): - self._internal_ctx = internal_ctx - self._input_device = input_device - self._invocation_index = invocation_index - self._call_from_input_fn = call_from_input_fn - - def current_input_fn_deployment(self): - """The configuration of the current input_fn invocation. - - The configuration depends on `TPUConfig.per_host_input_for_training`. See - `TPUConfig` for details. - - Only set in params dict of input_fn - - Returns: - A tuple of - 1. Device spec string: String, is the current CPU host where the - input_fn is invoked. - 2. Current invocation index: Int, 0-based index of the input_fn - invocation. See next item for details. - 3. Total invocation count: Int, the total number of times to invoke the - input_fn on all CPU hosts. 
Each invocation will be passed with a new - `TPUContext` instance with current invocation index set properly. - 4. Total number of replicas consumed by current_invocation: Int, the - number of replicas fed by the data returned by current input_fn. For - example, for per_core input pipeline deployment - and non-model-parallelism, total invocation count is equal to - the number of cores in the system and num replicas consumed by - current invocation is 1. For per-host v2 input pipeline deployment, - total invocation count is equal to the number of hosts in the system - and num replicas consumed by current invocation is equal to number of - cores per host. - - Raises: - RuntimeError: If this method must not be called from input_fn. - """ - if not self._call_from_input_fn: - raise RuntimeError('This TPUContext instance must not be called from' - ' model_fn.') - - if self._internal_ctx.is_input_sharded_per_core(): - total_invocation_count = (self._internal_ctx.num_hosts - * self._internal_ctx.num_of_replicas_per_host) - replicas_consumed = 1 - elif self._internal_ctx.is_input_broadcast_with_iterators(): - total_invocation_count = 1 - replicas_consumed = self._internal_ctx.num_replicas - else: - total_invocation_count = self._internal_ctx.num_hosts - replicas_consumed = self._internal_ctx.num_of_replicas_per_host - return (self._input_device, self._invocation_index, - total_invocation_count, replicas_consumed) - - @property - def num_replicas(self): - """The total number of replicas. - - For non-model-parallelism, num_replicas should be the total num of TPU - cores in the system. - - Returns: - The number of replicas. 
- """ - return self._internal_ctx.num_replicas - - @property - def num_hosts(self): - """The number of hosts for the TPU system.""" - return self._internal_ctx.num_hosts - - @property - def current_host(self): - """The current host index for the TPU system.""" - return self._invocation_index - - @property - def num_of_replicas_per_host(self): - """The number of replicas for each host.""" - if self._internal_ctx.model_parallelism_enabled: - raise ValueError( - 'num_of_replicas_per_host is not supported for model_parallelism') - return self._internal_ctx.num_of_replicas_per_host - - @property - def device_assignment(self): - """Returns device_assignment object.""" - if self._call_from_input_fn: - raise RuntimeError('This TPUContext instance must not be called from' - ' input_fn.') - return self._internal_ctx.device_assignment - - def device_for_replica(self, replica_id): - """Returns the tuple of (CPU device and device ordinal) for replica. - - This should be used for full replicate for non-model-parallelism. - - Args: - replica_id: Int, the replica index. - - Returns: - A tuple of device spec for CPU device and int device ordinal. - """ - # Note that: For the non-model parallelism, the mapping could be - # a random permutation. The order should not matter in most cases - # as far as model is replicated to all cores in the system. - return self._internal_ctx.device_for_replica(replica_id) - - @property - def tpu_host_placement_function(self): - """Returns the TPU host place function. - - The place function takes host_id as the input and returns the TF device - for the correspoding host. - """ - - def _placement_function(host_id): - """Return the host device given host_id.""" - return self._internal_ctx.tpu_host_placement_function(host_id=host_id) - - return _placement_function - - -class _InternalTPUContext(object): - """A context holds immutable states of TPU computation. 
- - This immutable object holds TPUEstimator config, train/eval batch size, and - `TPUEstimator.use_tpu`, which is expected to be passed around. It also - provides utility functions, based on the current state, to determine other - information commonly required by TPU computation, such as TPU device names, - TPU hosts, shard batch size, etc. - - if eval_on_tpu is False, then execution of eval on TPU is disabled. - if eval_on_tpu is True, but use_tpu is False, a warning is issued, - and TPU execution is disabled for all modes. - - N.B. As `mode` is not immutable state in Estimator, but essential to - distinguish between TPU training and evaluation, a common usage for - _InternalTPUContext with `mode` is as follows: - ``` - with _ctx.with_mode(mode) as ctx: - if ctx.is_running_on_cpu(): - ... - ``` - """ - - def __init__(self, config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu=True): - self._config = config - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size - self._predict_batch_size = predict_batch_size - self._use_tpu = use_tpu - logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu) - if not use_tpu and eval_on_tpu: - logging.warning('eval_on_tpu ignored because use_tpu is False.') - - self._eval_on_tpu = eval_on_tpu - self._model_parallelism_enabled = ( - use_tpu and config.tpu_config.num_cores_per_replica) - self._mode = None - num_cores_per_replica = config.tpu_config.num_cores_per_replica - if num_cores_per_replica: - self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[ - num_cores_per_replica] - else: - self._computation_shape = None - self._lazy_tpu_system_metadata_dict = {} # key by master address - self._lazy_device_assignment_dict = {} # key by master address - self._lazy_validation_dict = {} # key by ModeKeys - - def _assert_mode(self): - if self._mode is None: - raise RuntimeError( - '`mode` needs to be set via contextmanager `with_mode`.') - return self._mode - - 
@contextmanager - def with_mode(self, mode): - # NOTE(xiejw): Shallow copy is enough. It will share he lazy dictionaries, - # such as _lazy_tpu_system_metadata_dict between new copy and the original - # one. Note that all lazy states stored in properties _lazy_foo are sort of - # immutable as they should be same for the process lifetime. - new_ctx = copy.copy(self) - new_ctx._mode = mode # pylint: disable=protected-access - yield new_ctx - - @property - def mode(self): - return self._assert_mode() - - def _get_master_address(self): - mode = self._assert_mode() - config = self._config - master = ( - config.master - if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master) - return master - - def _get_tpu_system_metadata(self): - """Gets the (maybe cached) TPU system metadata.""" - master = self._get_master_address() - tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) - if tpu_system_metadata is not None: - return tpu_system_metadata - - cluster_def = None - if (self._config.session_config and - self._config.session_config.cluster_def.job): - cluster_def = self._config.session_config.cluster_def - - # pylint: disable=protected-access - tpu_system_metadata = ( - tpu_system_metadata_lib._query_tpu_system_metadata( - master, - cluster_def=cluster_def, - query_topology=self.model_parallelism_enabled)) - - self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata - return tpu_system_metadata - - def _get_device_assignment(self): - """Gets the (maybe cached) TPU device assignment.""" - master = self._get_master_address() - device_assignment = self._lazy_device_assignment_dict.get(master) - if device_assignment is not None: - return device_assignment - - tpu_system_metadata = self._get_tpu_system_metadata() - - device_assignment = tpu_device_assignment.device_assignment( - tpu_system_metadata.topology, - computation_shape=self._computation_shape, - num_replicas=self.num_replicas) - - logging.info('num_cores_per_replica: %s', - 
str(self._config.tpu_config.num_cores_per_replica)) - logging.info('computation_shape: %s', str(self._computation_shape)) - logging.info('num_replicas: %d', self.num_replicas) - logging.info('device_assignment.topology.device_coordinates: %s', - str(device_assignment.topology.device_coordinates)) - logging.info('device_assignment.core_assignment: %s', - str(device_assignment.core_assignment)) - - self._lazy_device_assignment_dict[master] = device_assignment - return device_assignment - - @property - def model_parallelism_enabled(self): - return self._model_parallelism_enabled - - @property - def input_partition_dims(self): - return self._config.tpu_config.input_partition_dims - - @property - def device_assignment(self): - return (self._get_device_assignment() - if self._model_parallelism_enabled else None) - - @property - def num_of_cores_per_host(self): - metadata = self._get_tpu_system_metadata() - return metadata.num_of_cores_per_host - - @property - def num_cores(self): - metadata = self._get_tpu_system_metadata() - return metadata.num_cores - - @property - def num_of_replicas_per_host(self): - """Return the number of replicas per host.""" - if self.model_parallelism_enabled: - return self.num_replicas // self.num_hosts - else: - return self.num_of_cores_per_host - - @property - def num_replicas(self): - num_cores_in_system = self.num_cores - - if self.model_parallelism_enabled: - num_cores_per_replica = self._config.tpu_config.num_cores_per_replica - if num_cores_per_replica > num_cores_in_system: - raise ValueError( - 'The num of cores required by the model parallelism, specified by ' - 'TPUConfig.num_cores_per_replica, is larger than the total num of ' - 'TPU cores in the system. 
num_cores_per_replica: {}, num cores ' - 'in the system: {}'.format(num_cores_per_replica, - num_cores_in_system)) - - if num_cores_in_system % num_cores_per_replica != 0: - raise RuntimeError( - 'The num of cores in the system ({}) is not divisible by the num ' - 'of cores ({}) required by the model parallelism, specified by ' - 'TPUConfig.num_cores_per_replica. This should never happen!'.format( - num_cores_in_system, num_cores_per_replica)) - - return num_cores_in_system // num_cores_per_replica - else: - return num_cores_in_system - - @property - def num_hosts(self): - metadata = self._get_tpu_system_metadata() - return metadata.num_hosts - - @property - def config(self): - return self._config - - def is_input_sharded_per_core(self): - """Return true if input_fn is invoked per-core (other than per-host).""" - mode = self._assert_mode() - return (mode == model_fn_lib.ModeKeys.TRAIN and - (self._config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_SHARD_V1)) - - def is_input_per_host_with_iterators(self): - """Return true if input_fn should be run in the per-host v2 config.""" - return (self._config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_HOST_V2) - - def is_input_broadcast_with_iterators(self): - """Return true if input_fn should be run in the full_replicae config.""" - return (self._config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.BROADCAST) - - def is_running_on_cpu(self, is_export_mode=False): - """Determines whether the input_fn and model_fn should be invoked on CPU. - - This API also validates user provided configuration, such as batch size, - according the lazy initialized TPU system metadata. - - Args: - is_export_mode: Indicates whether the current mode is for exporting the - model, when mode == PREDICT. 
Only with this bool, we could - tell whether user is calling the Estimator.predict or - Estimator.export_savedmodel, which are running on TPU and CPU - respectively. Parent class Estimator does not distinguish these two. - - Returns: - bool, whether current input_fn or model_fn should be running on CPU. - - Raises: - ValueError: any configuration is invalid. - """ - - is_running_on_cpu = self._is_running_on_cpu(is_export_mode) - if not is_running_on_cpu: - self._validate_tpu_configuration() - return is_running_on_cpu - - def _is_running_on_cpu(self, is_export_mode): - """Determines whether the input_fn and model_fn should be invoked on CPU.""" - mode = self._assert_mode() - - if not self._use_tpu: - return True - - if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu: - logging.info('_is_running_on_cpu: eval_on_tpu disabled') - return True - - if is_export_mode: - return True - - return False - - @property - def global_batch_size(self): - mode = self._assert_mode() - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size - elif mode == model_fn_lib.ModeKeys.EVAL: - return self._eval_batch_size - elif mode == model_fn_lib.ModeKeys.PREDICT: - return self._predict_batch_size - else: - return None - - @property - def batch_size_for_input_fn(self): - """Returns the shard batch size for `input_fn`.""" - global_batch_size = self.global_batch_size - - if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): - return global_batch_size - - # On TPU - if self.is_input_sharded_per_core() or ( - self.is_input_per_host_with_iterators()): - return global_batch_size // self.num_replicas - else: - return global_batch_size // self.num_hosts - - @property - def batch_size_for_model_fn(self): - """Returns the shard batch size for `model_fn`.""" - global_batch_size = self.global_batch_size - - if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): - return global_batch_size - - # On TPU. always sharded per shard. 
- return global_batch_size // self.num_replicas - - @property - def master_job(self): - """Returns the job name to use to place TPU computations on. - - Returns: - A string containing the job name, or None if no job should be specified. - - Raises: - ValueError: If the user needs to specify a tpu_job_name, because we are - unable to infer the job name automatically, or if the user-specified job - names are inappropriate. - """ - run_config = self._config - # If the user specifies the tpu_job_name, use that. - if run_config.tpu_config.tpu_job_name: - return run_config.tpu_config.tpu_job_name - - # The tpu job is determined by the run_config. Right now, this method is - # required as tpu_config is not part of the RunConfig. - mode = self._assert_mode() - master = ( - run_config.evaluation_master - if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) - if master in _LOCAL_MASTERS: - return None - - if (not run_config.session_config or - not run_config.session_config.cluster_def.job): - return _DEFAULT_JOB_NAME - cluster_def = run_config.session_config.cluster_def - job_names = set([job.name for job in cluster_def.job]) - if _DEFAULT_JOB_NAME in job_names: - # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. - raise ValueError('Currently, tpu_worker is not an allowed job name.') - if len(job_names) == 1: - return cluster_def.job[0].name - if len(job_names) == 2: - if _DEFAULT_COORDINATOR_JOB_NAME in job_names: - job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) - return job_names.pop() - # TODO(b/67716447): Include more sophisticated heuristics. - raise ValueError( - 'Could not infer TPU job name. 
Please specify a tpu_job_name as part ' - 'of your TPUConfig.') - - @property - def tpu_host_placement_function(self): - """Returns the TPU host place function.""" - - master = self.master_job - - def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name - """Return the host device given replica_id or host_id.""" - assert _sentinal is None - if replica_id is not None and host_id is not None: - raise RuntimeError( - 'replica_id and host_id can have only one non-None value.') - - if master is None: - return '/replica:0/task:0/device:CPU:0' - else: - if replica_id is not None: - if self.model_parallelism_enabled: - return self.device_assignment.host_device( - replica=replica_id, job=master) - else: - host_id = replica_id / self.num_of_cores_per_host - - return '/job:%s/task:%d/device:CPU:0' % (master, host_id) - - return _placement_function - - @property - def tpu_device_placement_function(self): - """Returns a TPU device placement Fn.""" - master = self.master_job - job_device = '' if master is None else ('/job:%s' % master) - - def _placement_function(i): - if self.model_parallelism_enabled: - return self.device_assignment.tpu_device(replica=i, job=master) - else: - num_of_cores_per_host = self.num_of_cores_per_host - host_id = i / num_of_cores_per_host - ordinal_id = i % num_of_cores_per_host - return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id) - - return _placement_function - - def tpu_ordinal_function(self, host_id): - """Returns the TPU ordinal fn.""" - - def _tpu_ordinal_function(shard_index_in_host): - """Return the TPU ordinal associated with a shard. - - Required because the enqueue ops are placed on CPU. - - Args: - shard_index_in_host: the shard index - - Returns: - The ordinal of the TPU device the shard's infeed should be placed on. - """ - if self.model_parallelism_enabled: - # We put both enqueue/dequeue ops at tpu.core(0) in each replica. 
- replica = self.device_assignment.lookup_replicas(host_id, - 0)[shard_index_in_host] - return self.device_assignment.tpu_ordinal(replica=replica) - else: - return shard_index_in_host % self.num_of_cores_per_host - - return _tpu_ordinal_function - - def _validate_tpu_configuration(self): - """Validates the configuration based on the TPU system metadata.""" - mode = self._assert_mode() - if self._lazy_validation_dict.get(mode): - return - - # All following information is obtained from TPU system metadata. - num_cores = self.num_cores - num_replicas = self.num_replicas - num_hosts = self.num_hosts - - if not num_cores: - tpu_system_metadata = self._get_tpu_system_metadata() - raise RuntimeError( - 'Cannot find any TPU cores in the system. Please double check ' - 'Tensorflow master address and TPU worker(s). Available devices ' - 'are {}.'.format(tpu_system_metadata.devices)) - - if self._config.tpu_config.num_shards: - user_provided_num_replicas = self._config.tpu_config.num_shards - if user_provided_num_replicas != num_replicas: - message = ( - 'TPUConfig.num_shards is not set correctly. According to TPU ' - 'system metadata for Tensorflow master ({}): num_replicas should ' - 'be ({}), got ({}). For non-model-parallelism, num_replicas should ' - 'be the total num of TPU cores in the system. For ' - 'model-parallelism, the total number of TPU cores should be ' - 'num_cores_per_replica * num_replicas. 
Please set it ' - 'accordingly or leave it as `None`'.format( - self._get_master_address(), num_replicas, - user_provided_num_replicas)) - - raise ValueError(message) - - if self._config.tpu_config.num_cores_per_replica: - num_cores_per_replica = self._config.tpu_config.num_cores_per_replica - num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host - if num_cores_per_replica > num_cores_per_host: - raise ValueError( - 'The num of cores required by the model parallelism, specified by ' - 'TPUConfig.num_cores_per_replica, is larger than the ' - 'num_cores_per_host. num_cores_per_replica: {}, ' - 'num_cores_per_host: {}'.format(num_cores_per_replica, - num_cores_per_host)) - - if mode == model_fn_lib.ModeKeys.TRAIN: - if (self._train_batch_size % num_replicas != 0 and - not self.is_input_broadcast_with_iterators()): - raise ValueError( - 'train batch size {} must be divisible by number of replicas {}' - .format(self._train_batch_size, num_replicas)) - - elif mode == model_fn_lib.ModeKeys.EVAL: - if self._eval_batch_size is None: - raise ValueError( - 'eval_batch_size in TPUEstimator constructor cannot be `None`' - 'if .evaluate is running on TPU.') - if (self._eval_batch_size % num_replicas != 0 and - not self.is_input_broadcast_with_iterators()): - raise ValueError( - 'eval batch size {} must be divisible by number of replicas {}' - .format(self._eval_batch_size, num_replicas)) - if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): - raise ValueError( - 'TPUEstimator.evaluate should be running on single TPU' - ' instead of a Pod.') - else: - assert mode == model_fn_lib.ModeKeys.PREDICT - if self._predict_batch_size is None: - raise ValueError( - 'predict_batch_size in TPUEstimator constructor should not be ' - '`None` if .predict is running on TPU.') - if (self._predict_batch_size % num_replicas != 0 and - not self.is_input_broadcast_with_iterators()): - raise ValueError( - 'predict batch size {} must be divisible by number of 
replicas {}' - .format(self._predict_batch_size, num_replicas)) - if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): - raise ValueError( - 'TPUEstimator.predict should be running on single TPU worker. ' - 'got {}.'.format(num_hosts)) - - # Record the state "validated" into lazy dictionary. - self._lazy_validation_dict[mode] = True - - def device_for_replica(self, replica_id): - """Returns the tuple of (CPU device and device ordinal) for replica. - - This should be used for full replicate for non-model-parallelism. - - Args: - replica_id: Int, the replica index. - - Returns: - A tuple of device spec for CPU device and int device ordinal. - """ - master = self.master_job - - if self.model_parallelism_enabled: - return (self.device_assignment.host_device( - replica=replica_id, job=master), - self.device_assignment.tpu_ordinal(replica=replica_id)) - - job_device = '' if master is None else ('/job:%s' % master) - - num_of_replicas_per_host = self.num_of_replicas_per_host - host_id = replica_id / num_of_replicas_per_host - ordinal_id = replica_id % num_of_replicas_per_host - - host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) - return (host_device, ordinal_id) - - -class _OneCoreTPUContext(_InternalTPUContext): - """Special _InternalTPUContext for one core usage.""" - - def __init__(self, config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu): - - super(_OneCoreTPUContext, self).__init__( - config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu) - - def _get_tpu_system_metadata(self): - """Gets the (maybe cached) TPU system metadata.""" - master = self._get_master_address() - tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) - if tpu_system_metadata is not None: - return tpu_system_metadata - - tpu_system_metadata = ( - tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access - num_cores=1, - num_hosts=1, - num_of_cores_per_host=1, - topology=None, - devices=[])) 
- - self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata - return tpu_system_metadata - - -def _get_tpu_context(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu): - """Returns an instance of `_InternalTPUContext`.""" - - if (config.tpu_config.num_shards == 1 and - config.tpu_config.num_cores_per_replica is None): - logging.warning( - 'Setting TPUConfig.num_shards==1 is an unsupported behavior. ' - 'Please fix as soon as possible (leaving num_shards as None.)') - return _OneCoreTPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu) - - return _InternalTPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_context import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index 1a909a3ac6fae79070a7762b94bfa138f93a5fb5..cb38a8f1a6bee3c2adfbefc203c1d143303c3368 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -1,10 +1,10 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,1087 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""TPU embedding APIs.""" +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import copy -import math -import re -import six - -from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.ops import gen_tpu_ops -from tensorflow.contrib.tpu.proto import tpu_embedding_configuration_pb2 as elc -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import partitioned_variables -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables - -TRAINING = elc.TPUEmbeddingConfiguration.TRAINING -INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE - - -class TableConfig( - collections.namedtuple( - 'TableConfig', - ['vocabulary_size', 'dimension', 'initializer', 'combiner'])): - """Embedding table configuration.""" - - @experimental - def __new__(cls, - vocabulary_size, - dimension, - initializer=None, - combiner='mean'): - """Embedding table configuration. - - Args: - vocabulary_size: Number of vocabulary (/rows) in the table. - dimension: The embedding dimension. - initializer: A variable initializer function to be used in embedding - variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean `0.0` and standard deviation - `1/sqrt(dimension)`. 
- combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with - 'mean' the default. 'sqrtn' often achieves good accuracy, in particular - with bag-of-words columns. For more information, see - `tf.nn.embedding_lookup_sparse`. - - Returns: - `TableConfig`. - - Raises: - ValueError: if `vocabulary_size` is not positive integer. - ValueError: if `dimension` is not positive integer. - ValueError: if `initializer` is specified and is not callable. - ValueError: if `combiner` is not supported. - """ - if not isinstance(vocabulary_size, int) or vocabulary_size < 1: - raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size)) - - if not isinstance(dimension, int) or dimension < 1: - raise ValueError('Invalid dimension {}.'.format(dimension)) - - if (initializer is not None) and (not callable(initializer)): - raise ValueError('initializer must be callable if specified.') - if initializer is None: - initializer = init_ops.truncated_normal_initializer( - mean=0.0, stddev=1 / math.sqrt(dimension)) - - if combiner not in ('mean', 'sum', 'sqrtn'): - raise ValueError('Invalid combiner {}'.format(combiner)) - - return super(TableConfig, cls).__new__(cls, vocabulary_size, dimension, - initializer, combiner) - - -# TODO(shizhiw): Factor `use_gradient_accumulation` and -# `pipeline_execution_with_tensor_core` out of `_OptimizationParameters`. 
-class _OptimizationParameters(object): - """Parameters common to all optimizations.""" - - def __init__(self, learning_rate, use_gradient_accumulation, - pipeline_execution_with_tensor_core): - self.learning_rate = learning_rate - self.use_gradient_accumulation = use_gradient_accumulation - self.pipeline_execution_with_tensor_core = ( - pipeline_execution_with_tensor_core) - - -class AdagradParameters(_OptimizationParameters): - """Optimization parameters for Adagrad.""" - - def __init__(self, learning_rate, initial_accumulator, - use_gradient_accumulation=False, - pipeline_execution_with_tensor_core=True): - """Optimization parameters for Adagrad. - - Args: - learning_rate: used for updating embedding table. - initial_accumulator: initial accumulator for Adagrad. - use_gradient_accumulation: setting this to `True` makes embedding - gradients calculation more accurate but slower. Please see - `optimization_parameters.proto` for details. - for details. - pipeline_execution_with_tensor_core: setting this to `True` makes training - faster, but trained model will be different if step N and step N+1 - involve the same set of embedding ID. Please see - `tpu_embedding_configuration.proto` for details. - """ - super(AdagradParameters, self).__init__(learning_rate, - use_gradient_accumulation, - pipeline_execution_with_tensor_core) - self.initial_accumulator = initial_accumulator - - -class AdamParameters(_OptimizationParameters): - """Optimization parameters for Adam.""" - - def __init__(self, learning_rate, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - lazy_adam=True, - sum_inside_sqrt=True, - use_gradient_accumulation=False, - pipeline_execution_with_tensor_core=True): - """Optimization parameters for Adam. - - Args: - learning_rate: a floating point value. The learning rate. - beta1: A float value. - The exponential decay rate for the 1st moment estimates. - beta2: A float value. - The exponential decay rate for the 2nd moment estimates. 
- epsilon: A small constant for numerical stability. - lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. - Please see `optimization_parameters.proto` for details. - sum_inside_sqrt: This improves training speed. Please see - `optimization_parameters.proto` for details. - use_gradient_accumulation: setting this to `True` makes embedding - gradients calculation more accurate but slower. Please see - `optimization_parameters.proto` for details. - for details. - pipeline_execution_with_tensor_core: setting this to `True` makes training - faster, but trained model will be different if step N and step N+1 - involve the same set of embedding ID. Please see - `tpu_embedding_configuration.proto` for details. - """ - super(AdamParameters, self).__init__(learning_rate, - use_gradient_accumulation, - pipeline_execution_with_tensor_core) - self.beta1 = beta1 - self.beta2 = beta2 - self.epsilon = epsilon - self.lazy_adam = lazy_adam - self.sum_inside_sqrt = sum_inside_sqrt - - -class StochasticGradientDescentParameters(_OptimizationParameters): - """Optimization parameters for stochastic gradient descent. - - Args: - learning_rate: a floating point value. The learning rate. - use_gradient_accumulation: setting this to `True` makes embedding - gradients calculation more accurate but slower. Please see - `optimization_parameters.proto` for details. - pipeline_execution_with_tensor_core: setting this to `True` makes training - faster, but trained model will be different if step N and step N+1 - involve the same set of embedding ID. Please see - `tpu_embedding_configuration.proto` for details. - """ - - def __init__(self, learning_rate, use_gradient_accumulation=False, - pipeline_execution_with_tensor_core=True): - super(StochasticGradientDescentParameters, self).__init__( - learning_rate, use_gradient_accumulation, - pipeline_execution_with_tensor_core) - - -class TPUEmbedding(object): - """API for using TPU for embedding. 
- - Example: - ``` - table_config_user = tpu_embedding.TableConfig( - vocabulary_size=4, dimension=2, - initializer=initializer, combiner='mean') - table_to_config_dict = {'video': table_config_video, - 'user': table_config_user} - feature_to_table_dict = {'watched': 'video', - 'favorited': 'video', - 'friends': 'user'} - batch_size = 4 - num_hosts = 1 - optimization_parameters = tpu_embedding.AdagradParameters(1., 1.) - mode = tpu_embedding.TRAINING - embedding = tpu_embedding.TPUEmbedding( - table_to_config_dict, feature_to_table_dict, - batch_size, num_hosts, mode, optimization_parameters) - - batch_size_per_core = embedding.batch_size_per_core - sparse_features_list = [] - for host in hosts: - with ops.device(host): - for _ in range(embedding.num_cores_per_host): - sparse_features = {} - sparse_features['watched'] = sparse_tensor.SparseTensor(...) - sparse_features['favorited'] = sparse_tensor.SparseTensor(...) - sparse_features['friends'] = sparse_tensor.SparseTensor(...) - sparse_features_list.append(sparse_features) - - enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list) - - def computation(): - activations = embedding.get_activations() - loss = compute_loss(activations) - - base_optimizer = gradient_descent.GradientDescentOptimizer( - learning_rate=1) - cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer( - base_optimizer) - - train_op = cross_shard_optimizer.minimize(loss) - # `train_op` and `send_gradients_op` must happen in order. 
- with ops.control_dependencies([train_op]): - send_gradients_op = embedding.generate_send_gradients_op() - with ops.control_dependencies([send_gradients_op]): - loss = array_ops.identity(loss) - - loss = tpu.shard(computation, - num_shards=embedding.num_cores) - - with self.test_session() as sess: - sess.run(tpu.initialize_system(embedding_config= - embedding.config_proto)) - sess.run(variables.global_variables_initializer()) - sess.run(embedding.init_ops) - sess.run(enqueue_ops) - loss_val = sess.run(loss) - ``` - """ - - # TODO(shizhiw): Instead of `feature_to_table_dict` which maps to table - # name, consider `feature_to_config_dict` which maps to `FeatureConfig`. - # `FeatureConfig` could have fields other than table name. For example, it - # could have a field to indicate that the feature should not be used to - # update embedding table (cr/204852758, cr/204940540). Also, this can support - # different combiners for different features within the same table. - # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it - # to `FeatureConfig`? - - # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and - # `feature_to_table_dict` lists of `TableSpec` and `FeatureSpec` respectively? - - # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate - # for-loops around construction of inputs. - - # `optimization_parameter` applies to all tables. If the need arises, - # we can add `optimization_parameters` to `TableConfig` to override this - # global setting. - @experimental - def __init__(self, - table_to_config_dict, - feature_to_table_dict, - batch_size, - mode, - master, - optimization_parameters=None): - """API for using TPU for embedding lookups. - - Args: - table_to_config_dict: A dictionary mapping from string of table name to - `TableConfig`. Table refers to an embedding table, e.g. `params` - argument to `tf.nn.embedding_lookup_sparse()`. 
- feature_to_table_dict: A dictionary mapping from string of feature name - to string of table name. Feature refers to ids to lookup in embedding - table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`. - batch_size: An `int` representing the global batch size. - mode: `TRAINING` or `INFERENCE`. - master: A `string` representing the TensorFlow master to use. - optimization_parameters: `AdagradParameters`, `AdamParameters`, - `Stochasticgradientdescentparameters`. Must be set in training and must - be `None` in inference. - - Raises: - ValueError: if any input is invalid. - """ - _validate_table_to_config_dict(table_to_config_dict) - # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`. - self._table_to_config_dict = _create_ordered_dict(table_to_config_dict) - self._combiners = _create_combiners(self._table_to_config_dict) - - _validate_feature_to_table_dict(table_to_config_dict, feature_to_table_dict) - self._feature_to_table_dict = _create_ordered_dict(feature_to_table_dict) - self._table_to_features_dict = _create_table_to_features_dict( - self._feature_to_table_dict) - - self._batch_size = batch_size - - self._master = master - self._tpu_system_metadata = ( - tpu_system_metadata_lib._query_tpu_system_metadata(self._master)) # pylint: disable=protected-access - if self._tpu_system_metadata.num_cores == 0: - raise ValueError('TPUEmbedding needs TPUs, but master {} does not have ' - 'TPUs.'.format(self._master)) - self._num_hosts = self._tpu_system_metadata.num_hosts - self._hosts = [device.name for device in self._tpu_system_metadata.devices - if 'device:CPU:' in device.name] - self._num_cores_per_host = self._tpu_system_metadata.num_of_cores_per_host - self._num_cores = self._tpu_system_metadata.num_cores - - _validate_batch_size(self._batch_size, self._num_cores) - self._batch_size_per_core = self._batch_size // self._num_cores - - self._init_ops = [] - - # TODO(shizhiw): remove `mode`? 
- if mode == TRAINING: - _validate_optimization_parameters(optimization_parameters) - self._optimization_parameters = optimization_parameters - elif mode == INFERENCE: - if optimization_parameters is not None: - raise ValueError('`optimization_parameters` should be `None` ' - 'for inference mode.') - self._optimization_parameters = ( - StochasticGradientDescentParameters(1.)) - else: - raise ValueError('`mode` only supports {} and {}; got {}.' - .format(TRAINING, INFERENCE, mode)) - self._mode = mode - - # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler` - # and create special handler for inference that inherits from - # StochasticGradientDescentHandler with more user-friendly error message - # on get_slot(). - self._optimizer_handler = _get_optimization_handler( - self._optimization_parameters) - - dummy_table_variables_init_op = self._create_dummy_table_variables() - self._init_ops.append(dummy_table_variables_init_op) - - self._config_proto = self._create_config_proto() - - self._create_variables_and_ops() - self._init_ops.extend(self._load_parameters_ops) - - @property - def hosts(self): - """A list of device names for CPU hosts. - - Returns: - A list of device names for CPU hosts. - """ - return copy.copy(self._hosts) - - # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and - # to be consistent with `tpu_embedding_configuration.proto`. - @property - def num_cores_per_host(self): - """Number of TPU cores on a CPU host. - - Returns: - Number of TPU cores on a CPU host. - """ - return self._num_cores_per_host - - @property - def num_cores(self): - """Total number of TPU cores on all hosts. - - Returns: - Total number of TPU cores on all hosts. - """ - return self._num_cores - - @property - def batch_size_per_core(self): - """Batch size for each TPU core. - - The sparse tensors in `sparse_features_list` to `generate_enqueue_ops` - must have batch dimension equal to this. - - Returns: - Batch size for each TPU core. 
- """ - return self._batch_size_per_core - - @property - def config_proto(self): - """Create embedding config proto for `tpu.initialize_system()`. - - Returns: - an `TPUEmbeddingConfiguration` proto describing the desired - configuration of the hardware embedding lookup tables, which - is passed to `tpu.initialize_system()`. - """ - return self._config_proto - - @property - def init_ops(self): - """Initialization ops for TPU embedding. - - It must be called after all global variables have been initialized, - i.e. after `global_variables_initializer()`, as it loads embedding - tables into TPU. - - Returns: - A list of ops. - """ - return self._init_ops - - # TODO(shizhiw): get table variables the same way as getting slot variables. - @property - def table_to_table_variables_dict(self): - return copy.copy(self._table_to_table_variables_dict) - - def get_slot_names(self): - """Return a list of the names of slots created by `TPUEmbedding`.""" - return self._optimizer_handler.get_slot_names() - - def get_slot(self, table, name): - """Return a slot named `name` create for `table` by `TPUEmbedding`.""" - return self._optimizer_handler.get_slot(table, name) - - # TODO(shizhiw): expose load to user too? 
- @property - def retrieve_parameters_ops(self): - return self._retrieve_parameters_ops - - def _create_config_proto(self): - """Create `TPUEmbeddingConfiguration`.""" - config_proto = elc.TPUEmbeddingConfiguration() - for table in self._table_to_config_dict: - table_descriptor = config_proto.table_descriptor.add() - table_descriptor.name = table - - table_config = self._table_to_config_dict[table] - table_descriptor.vocabulary_size = table_config.vocabulary_size - table_descriptor.dimension = table_config.dimension - - features_for_table = self._table_to_features_dict[table] - table_descriptor.num_features = len(features_for_table) - - table_descriptor.optimization_parameters.learning_rate.constant = ( - self._optimization_parameters.learning_rate) - table_descriptor.optimization_parameters.use_gradient_accumulation = ( - self._optimization_parameters.use_gradient_accumulation) - self._optimizer_handler.set_optimization_parameters(table_descriptor) - - config_proto.mode = self._mode - config_proto.batch_size_per_tensor_core = self._batch_size_per_core - config_proto.num_hosts = self._num_hosts - config_proto.num_tensor_cores = self._num_cores - config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT - config_proto.pipeline_execution_with_tensor_core = ( - self._optimization_parameters.pipeline_execution_with_tensor_core) - - return config_proto - - def _create_variables_and_ops(self): - """Create embedding variables and return ops to load them into TPU.""" - self._load_parameters_ops = [] - self._retrieve_parameters_ops = [] - self._table_to_table_variables_dict = {} - for table in self._table_to_config_dict: - device_fn = _create_device_fn(self._hosts) - with ops.device(device_fn): - # TODO(shizhiw): allow user to specify variable name so that - # they could make the name consistent with CPU etc. 
- variable_name = table - table_variables = _create_partitioned_variables( - name=variable_name, - num_hosts=self._num_hosts, - vocabulary_size=self._table_to_config_dict[table].vocabulary_size, - embedding_dimension=self._table_to_config_dict[table].dimension, - initializer=self._table_to_config_dict[table].initializer, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - self._table_to_table_variables_dict[table] = table_variables - - self._optimizer_handler.create_variables_and_ops( - table, variable_name, self._num_hosts, - self._table_to_config_dict[table], table_variables, - self._load_parameters_ops, self._retrieve_parameters_ops) - - def _create_dummy_table_variables(self): - """Create dummy embedding table variables. - - The sole purpose of these dummy variables are to trigger gradient - calcuation wrt them so that the gradients wrt activation can be captured - and later sent to TPU embedding. - - Returns: - Initializer for these variables. - - Raises: - RuntimeError: if collection to store gradients already exists and is not - empty. - """ - self._dummy_table_variables = [] - # TODO(shizhiw): remove table id. - for table_id, table in enumerate(self._table_to_features_dict): - self._dummy_table_variables.append( - variable_scope.get_variable( - 'tpu_embedding_dummy_table_variable_%s' % table, - dtype=dtypes.float32, - shape=[1], - use_resource=True, - trainable=True, - # TODO(shizhiw): Remove these dummy variables as - # tensorflow optimizer creates slot variable for them which - # is undesirable. - # e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1}. - # Explicitly specifying collections prevents this variable from - # being added to the GLOBAL_VARIABLES collection, so that Saver() - # ignores it. 
- collections=['tpu_embedding_dummy_table_variables'])) - - g = ops.get_default_graph() - table_gradients = g.get_collection_ref( - 'tpu_embedding_gradients_table_%d' % table_id) - if table_gradients: - raise RuntimeError( - 'tpu_embedding_gradients_table_%d is not empty.' % table_id) - table_gradients.extend([None] * len(self._table_to_features_dict[table])) - - return variables.variables_initializer( - self._dummy_table_variables, - name='tpu_embedding_dummy_table_variables_init') - - def generate_enqueue_ops(self, sparse_features_list): - """Generate enqueue ops. - - Args: - sparse_features_list: a list of dictionary mapping from string - of feature names to sparse tensor. Each dictionary is for one - TPU core. Dictionaries for the same core should be contiguous - on the list. - - Returns: - Ops to enqueue to TPU for embedding. - """ - self._validate_generate_enqueue_ops_sparse_features_list( - sparse_features_list) - return [ - self._generate_enqueue_op( - sparse_features, device_ordinal=i % self._num_cores_per_host) - for i, sparse_features in enumerate(sparse_features_list) - ] - - def _validate_generate_enqueue_ops_sparse_features_list( - self, sparse_features_list): - """Validate `sparse_features_list`.""" - if len(sparse_features_list) != self._num_cores: - raise ValueError('Length of `sparse_features_list` should match the ' - 'number of cores; ' - '`len(sparse_features_list)` is {}, ' - 'number of cores is {}.'.format( - len(sparse_features_list), self._num_cores)) - - feature_set = set(self._feature_to_table_dict.keys()) - contiguous_device = None - for i, sparse_features in enumerate(sparse_features_list): - used_feature_set = set(sparse_features.keys()) - - # Check features are valid. 
- missing_feature_set = feature_set - used_feature_set - if missing_feature_set: - raise ValueError('`sparse_features_list[{}]` misses a feature that is ' - 'in `feature_to_config_dict`: {}.'.format( - i, missing_feature_set)) - - extra_feature_set = used_feature_set - feature_set - if extra_feature_set: - raise ValueError('`sparse_features_list[{}]` has a feature that is not ' - 'in `feature_to_config_dict`: {}.'.format( - i, extra_feature_set)) - - device = None - device_feature = None - for feature, tensor in six.iteritems(sparse_features): - if not isinstance(tensor, sparse_tensor.SparseTensor): - raise ValueError('`sparse_features_list[{}]` has a feature that is ' - 'not mapped to `SparseTensor`. ' - '`feature`: {}, type: {}'.format( - i, feature, type(tensor))) - - # Check all features are on the same device. - if device is None: - device = tensor.op.device - device_feature = feature - else: - if device != tensor.op.device: - raise ValueError('Devices are different between features in ' - '`sparse_features_list[{}]`; ' - 'devices: {}, {}; features: {}, {}.'.format( - i, device, tensor.op.device, feature, - device_feature)) - - if i % self._num_cores_per_host: - if device != contiguous_device: - raise ValueError('We expect the `sparse_features` which are on the ' - 'same host to be contiguous in ' - '`sparse_features_list`, ' - '`sparse_features_list[{}]` is on device {}, ' - 'but is expected to be on device {}.'.format( - i, device, contiguous_device)) - else: - contiguous_device = device - - def _generate_enqueue_op(self, sparse_features, device_ordinal): - with ops.colocate_with(list(sparse_features.values())[0]): - sample_idcs, embedding_idcs, aggregation_weights = ( - self._format_for_tpu_embedding_sparse_batch(sparse_features)) - return tpu_ops.enqueue_tpu_embedding_sparse_batch( - sample_idcs, - embedding_idcs, - aggregation_weights, - combiners=self._combiners, - device_ordinal=device_ordinal) - - def _format_for_tpu_embedding_sparse_batch(self, 
sparse_features): - """Format sparse features for `enqueue_tpu_embedding_sparse_batch()`. - - Args: - sparse_features: a `Dict` of `SparseTensor`s for embedding. - - Returns: - Arguments for `enqueue_tpu_embedding_sparse_batch()`. - """ - - sample_idcs, embedding_idcs, aggregation_weights = list(), list(), list() - for table in self._table_to_features_dict: - sample_t, indices_t, weights_t = list(), list(), list() - - features = self._table_to_features_dict[table] - for i, feature in enumerate(features): - tensor = sparse_features[feature] - sample_indices = tensor.indices[:, 0] - embedding_indices = tensor.values - weights = array_ops.ones_like(embedding_indices) - sample_t.append(i * self._batch_size_per_core + sample_indices) - indices_t.append(embedding_indices) - weights_t.append(weights) - - sample_idcs.append( - math_ops.cast(array_ops.concat(sample_t, axis=0), dtype=dtypes.int32)) - embedding_idcs.append( - math_ops.cast( - array_ops.concat(indices_t, axis=0), dtype=dtypes.int32)) - aggregation_weights.append( - math_ops.cast( - array_ops.concat(weights_t, axis=0), dtype=dtypes.float32)) - - return sample_idcs, embedding_idcs, aggregation_weights - - def get_activations(self): - """Get activations for features. - - This should be called within `computation` that is passed to - `tpu.replicate` and friends. - - Returns: - A dictionary mapping from `String` of feature name to `Tensor` - of activation. 
- """ - recv_activations = tpu_ops.recv_tpu_embedding_activations( - num_outputs=len(self._table_to_config_dict), - config=self._config_proto.SerializeToString()) - - activations = collections.OrderedDict() - for table_id, table in enumerate(self._table_to_features_dict): - features = self._table_to_features_dict[table] - for lookup_id, feature in enumerate(features): - start_row = lookup_id * self._batch_size_per_core - end_row = start_row + self._batch_size_per_core - activations[feature] = gen_tpu_ops.tpu_embedding_activations( - self._dummy_table_variables[table_id], - recv_activations[table_id][start_row:end_row, :], - table_id=table_id, - lookup_id=lookup_id) - return activations - - # TODO(shizhiw): Make `gradient_multiplier` per feature. Setting it to 0 would - # have the effect of `tf.stop_gradients()`. - # TODO(shizhiw): Consider alternative ways to capture gradients wrt embedding - # layer outputs to remove `_dummy_table_variables`, - # `_embedding_activation_grad` and `tpu_embedding_gradients_table_%d'. - def generate_send_gradients_op(self, gradient_multipliers=None): - """Retrieve gradients from collections and send them to TPU embedding. - - Args: - gradient_multipliers: None, or dict mapping table names to gradient - multiplier Tensors. - - Returns: - SendTPUEmbeddingGradients Op. - - Raises: - ValueError: If required gradients have not been defined. - RuntimeError: If `mode` is not `TRAINING`. - """ - if self._mode != TRAINING: - raise RuntimeError('Only in training mode gradients need to ' - 'be sent to TPU embedding; got mode {}.' 
- .format(self._mode)) - - g = ops.get_default_graph() - gradients = list() - for table_id, table in enumerate(self._table_to_config_dict): - table_gradients = g.get_collection( - 'tpu_embedding_gradients_table_%d' % table_id) - if any(gradient is None for gradient in table_gradients): - raise ValueError( - 'Table {}/{} has undefined gradients: this is probably because the ' - 'model asked TPUEmbedding to compute activations that were not ' - 'used.'.format(table_id, table)) - concat_table_grads = array_ops.concat(table_gradients, axis=0) - if gradient_multipliers is not None: - concat_table_grads *= gradient_multipliers[table.name] - gradients.append(concat_table_grads) - - return tpu_ops.send_tpu_embedding_gradients( - inputs=gradients, config=self.config_proto.SerializeToString()) - - -def _validate_table_to_config_dict(table_to_config_dict): - """Validate `table_to_config_dict`.""" - for k, v in six.iteritems(table_to_config_dict): - if not isinstance(v, TableConfig): - raise ValueError('Value of `table_to_config_dict` must be of type ' - '`TableConfig`, got {} for {}.'.format(type(v), k)) - - -def _validate_feature_to_table_dict(table_to_config_dict, - feature_to_table_dict): - """Validate `feature_to_table_dict`.""" - used_table_set = set(feature_to_table_dict.values()) - table_set = set(table_to_config_dict.keys()) - - unused_table_set = table_set - used_table_set - if unused_table_set: - raise ValueError('`table_to_config_dict` specifies table that is not ' - 'used in `feature_to_table_dict`: {}.' - .format(unused_table_set)) - - extra_table_set = used_table_set - table_set - if extra_table_set: - raise ValueError('`feature_to_table_dict` refers to a table that is not ' - 'specified in `table_to_config_dict`: {}.' - .format(extra_table_set)) - - -def _validate_batch_size(batch_size, num_cores): - if batch_size % num_cores: - raise ValueError('`batch_size` is not a multiple of number of ' - 'cores. 
`batch_size`={}, `_num_cores`={}.'.format( - batch_size, num_cores)) - - -def _validate_optimization_parameters(optimization_parameters): - if not isinstance(optimization_parameters, _OptimizationParameters): - raise ValueError('`optimization_parameters` must inherit from ' - '`_OptimizationPramaters`. ' - '`type(optimization_parameters)`={}'.format( - type(optimization_parameters))) - - -class _OptimizerHandler(object): - """Interface class for handling optimizer specific logic.""" - - def __init__(self, optimization_parameters): - self._optimization_parameters = optimization_parameters - - def set_optimization_parameters(self, table_descriptor): - raise NotImplementedError() - - def create_variables_and_ops(self, table, variable_name): - raise NotImplementedError() - - def get_slot_names(self): - raise NotImplementedError() - - def get_slot(self, table, name): - raise NotImplementedError() - - -class _AdagradHandler(_OptimizerHandler): - """Handles Adagrad specific logic.""" - - def __init__(self, optimization_parameters): - super(_AdagradHandler, self).__init__(optimization_parameters) - self._table_to_accumulator_variables_dict = {} - - def set_optimization_parameters(self, table_descriptor): - table_descriptor.optimization_parameters.adagrad.SetInParent() - - def create_variables_and_ops(self, table, variable_name, num_hosts, - table_config, table_variables, - load_parameters_ops, retrieve_parameters_ops): - optimizer_name = 'Adagrad' - accumulator_initializer = init_ops.constant_initializer( - self._optimization_parameters.initial_accumulator) - accumulator_variables = _create_partitioned_variables( - name='%s/%s' % (variable_name, optimizer_name), - num_hosts=num_hosts, - vocabulary_size=table_config.vocabulary_size, - embedding_dimension=table_config.dimension, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - initializer=accumulator_initializer) - - self._table_to_accumulator_variables_dict[table] = accumulator_variables - for host_id, table_variable, 
accumulator_variable in (zip( - range(num_hosts), table_variables, accumulator_variables)): - with ops.colocate_with(table_variable): - load_parameters_op = ( - tpu_ops.load_tpu_embedding_adagrad_parameters( - parameters=table_variable, - accumulators=accumulator_variable, - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieved_table, retrieved_accumulator = ( - tpu_ops.retrieve_tpu_embedding_adagrad_parameters( - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieve_parameters_op = control_flow_ops.group( - state_ops.assign(table_variable, retrieved_table), - state_ops.assign(accumulator_variable, retrieved_accumulator)) - - load_parameters_ops.append(load_parameters_op) - retrieve_parameters_ops.append(retrieve_parameters_op) - - def get_slot_names(self): - return ['accumulator'] - - def get_slot(self, table, name): - if name not in self.get_slot_names(): - raise ValueError('Adagrad has {} as slot names; got {}.' - .format(self.get_slot_names(), name)) - return self._table_to_accumulator_variables_dict[table] - - -class _AdamHandler(_OptimizerHandler): - """Handles Adam specific logic.""" - - def __init__(self, optimization_parameters): - super(_AdamHandler, self).__init__(optimization_parameters) - self._table_to_m_variables_dict = {} - self._table_to_v_variables_dict = {} - - def set_optimization_parameters(self, table_descriptor): - table_descriptor.optimization_parameters.adam.beta1 = ( - self._optimization_parameters.beta1) - table_descriptor.optimization_parameters.adam.beta2 = ( - self._optimization_parameters.beta2) - table_descriptor.optimization_parameters.adam.epsilon = ( - self._optimization_parameters.epsilon) - table_descriptor.optimization_parameters.adam.use_non_lazy_adam = ( - not self._optimization_parameters.lazy_adam) - table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = ( - self._optimization_parameters.sum_inside_sqrt) - - def create_variables_and_ops(self, table, variable_name, 
num_hosts, - table_config, table_variables, - load_parameters_ops, retrieve_parameters_ops): - optimizer_name = 'Adam' - m_initializer = init_ops.zeros_initializer() - m_variables = _create_partitioned_variables( - name='%s/%s/m' % (variable_name, optimizer_name), - num_hosts=num_hosts, - vocabulary_size=table_config.vocabulary_size, - embedding_dimension=table_config.dimension, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - initializer=m_initializer) - v_initializer = init_ops.zeros_initializer() - v_variables = _create_partitioned_variables( - name='%s/%s/v' % (variable_name, optimizer_name), - num_hosts=num_hosts, - vocabulary_size=table_config.vocabulary_size, - embedding_dimension=table_config.dimension, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - initializer=v_initializer) - - self._table_to_m_variables_dict[table] = m_variables - self._table_to_v_variables_dict[table] = v_variables - - for host_id, table_variable, m_variable, v_variable in (zip( - range(num_hosts), table_variables, - m_variables, v_variables)): - with ops.colocate_with(table_variable): - load_parameters_op = ( - tpu_ops.load_tpu_embedding_adam_parameters( - parameters=table_variable, - momenta=m_variable, - velocities=v_variable, - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieved_table, retrieved_m, retrieved_v = ( - tpu_ops.retrieve_tpu_embedding_adam_parameters( - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieve_parameters_op = control_flow_ops.group( - state_ops.assign(table_variable, retrieved_table), - state_ops.assign(m_variable, retrieved_m), - state_ops.assign(v_variable, retrieved_v)) - - load_parameters_ops.append(load_parameters_op) - retrieve_parameters_ops.append(retrieve_parameters_op) - - def get_slot_names(self): - return ['m', 'v'] - - def get_slot(self, table, name): - if name == 'm': - return self._table_to_m_variables_dict[table] - elif name == 'v': - return self._table_to_v_variables_dict[table] - else: - raise 
ValueError('Adam has {} as slot names; got {}.' - .format(self.get_slot_names(), name)) - - -class _StochasticGradientDescentHandler(_OptimizerHandler): - """Handles stochastic gradient descent specific logic.""" - - def set_optimization_parameters(self, table_descriptor): - (table_descriptor.optimization_parameters.stochastic_gradient_descent - .SetInParent()) - - def create_variables_and_ops(self, table, variable_name, num_hosts, - table_config, table_variables, - load_parameters_ops, retrieve_parameters_ops): - del table_config - - for host_id, table_variable in (zip( - range(num_hosts), table_variables)): - with ops.colocate_with(table_variable): - load_parameters_op = ( - tpu_ops - .load_tpu_embedding_stochastic_gradient_descent_parameters( - parameters=table_variable, - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieved_table = ( - tpu_ops - .retrieve_tpu_embedding_stochastic_gradient_descent_parameters( - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieve_parameters_op = control_flow_ops.group( - state_ops.assign(table_variable, retrieved_table)) - - load_parameters_ops.append(load_parameters_op) - retrieve_parameters_ops.append(retrieve_parameters_op) - - def get_slot_names(self): - return [] - - def get_slot(self, table, name): - raise ValueError('Stochastic gradient descent does not have slot variable.') - - -def _get_optimization_handler(optimization_parameters): - if isinstance(optimization_parameters, AdagradParameters): - return _AdagradHandler(optimization_parameters) - elif isinstance(optimization_parameters, AdamParameters): - return _AdamHandler(optimization_parameters) - elif isinstance(optimization_parameters, StochasticGradientDescentParameters): - return _StochasticGradientDescentHandler(optimization_parameters) - else: - return NotImplementedError() - - -def _create_ordered_dict(d): - """Create an OrderedDict from Dict.""" - return collections.OrderedDict((k, d[k]) for k in sorted(d)) - - -def 
_create_combiners(table_to_config_dict): - return [table_to_config_dict[t].combiner for t in table_to_config_dict] - - -def _create_table_to_features_dict(feature_to_table_dict): - """Create mapping from table to a list of its features.""" - table_to_features_dict_tmp = {} - for feature, table in six.iteritems(feature_to_table_dict): - if table in table_to_features_dict_tmp: - table_to_features_dict_tmp[table].append(feature) - else: - table_to_features_dict_tmp[table] = [feature] - - table_to_features_dict = collections.OrderedDict() - for table in sorted(table_to_features_dict_tmp): - table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table]) - return table_to_features_dict - - -def _create_device_fn(hosts): - """Create device_fn() to use with _create_partitioned_variables().""" - - def device_fn(op): - """Returns the `device` for `op`.""" - part_match = re.match(r'.*/part_(\d+)(/|$)', op.name) - - if part_match: - idx = int(part_match.group(1)) - else: - raise RuntimeError('Internal Error: ' - 'Expected %s to contain /part_*.' % op.name) - - device = hosts[idx] - return device - - return device_fn - - -def _create_partitioned_variables(name, - num_hosts, - vocabulary_size, - embedding_dimension, - initializer, - collections=None): # pylint: disable=redefined-outer-name - """Creates ParitionedVariables based on `num_hosts` for `table`.""" - # TODO(shizhiw): automatically place embedding lookup elsewhere? - if vocabulary_size < num_hosts: - raise ValueError('`vocabulary_size`({}) is smaller than `num_hosts`({}). 
' - 'As TPU embedding is not optimized for small tables, ' - 'please consider other ways for this embedding lookup.') - - return list(variable_scope.get_variable( - name, - shape=(vocabulary_size, embedding_dimension), - partitioner=partitioned_variables.fixed_size_partitioner(num_hosts), - dtype=dtypes.float32, - initializer=initializer, - collections=collections, - trainable=False)) - - -@ops.RegisterGradient('TPUEmbeddingActivations') -def _embedding_activations_grad(activations_op, grad_wrt_activations): - """Saves the gradient of embedding activations ops in a graph collection.""" - g = ops.get_default_graph() - table_id = activations_op.get_attr('table_id') - lookup_id = activations_op.get_attr('lookup_id') - table_gradients = g.get_collection_ref( - 'tpu_embedding_gradients_table_%d' % table_id) - - if not table_gradients: - raise RuntimeError( - 'Gradients for TPUEmbedding have been generated in non-training mode. ' - 'This is not expected. Consider putting your Optimizer.minimize code ' - 'behind the training mode condition check. For Estimator, you can ' - 'do \n\n' - ' if mode == tf.estimator.ModeKeys.TRAIN:\n' - ' train_op = opt.minimize(loss)\n' - '\n') - - table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations) - return [ - # RegisterGradient requires that value be returned for all inputs. Since - # the first argument (tpu_gradient_variable_{table_name}) has shape [1], - # we will return zeros(shape=[1]). The actual gradient w.r.t. the - # embedding activations (grad_wrt_activations) has the same shape as the - # activations returned by embedding_activations. 
- array_ops.zeros(arg.shape, dtype=dtypes.float32) - for arg in activations_op.inputs - ] +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_embedding import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..308adc77e9ad2d912d0461512655b55faa53da60 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py @@ -0,0 +1,23 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_embedding_gradient import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 27c8ccf6bd7e77743dabbd329cea01bc45697c52..893118412e1363ce50416e6ef36692bc23d04179 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -1,3655 +1,33 @@ -# Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== -"""TPUEstimator class.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import copy -import os -import signal -import sys -import threading -import time - -import numpy as np -import six -from six.moves import queue as Queue # pylint: disable=redefined-builtin -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.tpu.ops import gen_tpu_ordinal_selector_op -from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import error_handling -from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional -from tensorflow.contrib.tpu.python.tpu import session_support -from tensorflow.contrib.tpu.python.tpu import tensor_tracer -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_config -from tensorflow.contrib.tpu.python.tpu import tpu_context -from tensorflow.contrib.tpu.python.tpu import tpu_feed -from 
tensorflow.contrib.tpu.python.tpu import training_loop -from tensorflow.contrib.tpu.python.tpu import util as util_lib -from tensorflow.contrib.training.python.training import hparam -from tensorflow.core.framework import variable_pb2 -from tensorflow.core.framework.summary_pb2 import Summary -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session as tf_session -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest as data_nest -from tensorflow.python.estimator import estimator as estimator_lib -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.export import export_output as export_output_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops_v2 as contrib_summary -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.saved_model import tag_constants -from tensorflow.python.summary import summary -from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import evaluation -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import training -from tensorflow.python.training import training_util -from tensorflow.python.util import function_utils -from tensorflow.python.util import 
nest -from tensorflow.python.util import tf_inspect - -_INITIAL_LOSS = 1e7 -_ZERO_LOSS = 0. -_TPU_ESTIMATOR = 'tpu_estimator' -_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' -_BATCH_SIZE_KEY = 'batch_size' -_CTX_KEY = 'context' -_USE_TPU_KEY = 'use_tpu' -_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' -_ONE_GIGABYTE = 1024 * 1024 * 1024 -_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' -_TPU_TRAIN_OP = '_tpu_train_op' -_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' -_KEY_WHEN_PREDICTIONS_IS_A_TENSOR = '_key_when_predictions_is_a_tensor' - -# Ideally _USE_TPU_KEY should be reserved as well. However there are already -# models that make use of this key, thus it can not be reserved now to prevent -# breakage. In the long run, we would like to mitigate this by migrating models -# off of using _USE_TPU_KEY. -_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] - -# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is -# only used for per-core based deployments. For per-host based pipelines, if a -# user returns a Dataset instance it will be automatically wrapped in a -# tf.while_loop (This can be disabled by returning features and labels -# explicitly). 
-_WRAP_INPUT_FN_INTO_WHILE_LOOP = False - -ops.register_proto_function( - '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR), - proto_type=variable_pb2.VariableDef, - to_proto=resource_variable_ops._to_proto_fn, # pylint: disable=protected-access - from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access - - -def _is_iterable(obj): - """A Python 2 and 3 compatible util to check whether `obj` is iterable.""" - try: - iter(obj) - return True - except TypeError: - return False - - -class CatchInvalidHostcallFunctions(control_flow_ops.XLAControlFlowContext): - - def AddOp(self, op): - if op.type in [ - 'AudioSummary', 'AudioSummaryV2', 'HistogramSummary', 'ImageSummary', - 'MergeSummary', 'ScalarSummary', 'TensorSummary', 'TensorSummaryV2' - ]: - raise ValueError('Use tf.contrib.summary inside of host_calls.') - - -def _create_global_step(graph): - graph = graph or ops.get_default_graph() - if training.get_global_step(graph) is not None: - raise ValueError('"global_step" already exists.') - # Create in proper graph and base name_scope. - with graph.as_default() as g, g.name_scope(None): - return variable_scope.get_variable( - ops.GraphKeys.GLOBAL_STEP, - shape=[], - dtype=dtypes.int64, - initializer=init_ops.zeros_initializer(), - trainable=False, - use_resource=True, - collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) - - -def _create_or_get_iterations_per_loop(): - """Creates or gets the iterations_per_loop variable. - - In TPUEstimator, the user provided computation, the model_fn, is wrapped - inside a tf.while_loop for peak performance. The iterations of the loop are - specified by this variable, which adjusts its value on the CPU after each TPU - program execution and before the next TPU execution. - - The purpose of using a variable, rather then a constant, is to allow - TPUEstimator adapt the TPU training iterations according to the final steps - specified by users. 
For example, if the user sets the iterations_per_loop as 4 - in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop - variable will have the following value before each TPU training. - - - 1-th TPU execution: iterations_per_loop = 4 - - 2-th TPU execution: iterations_per_loop = 4 - - 3-th TPU execution: iterations_per_loop = 2 - - As model_fn increases the global step once per train_op invocation, the global - step is 10 after all TPU executions, matching the steps=10 inputs passed in by - users. - - Returns: - A TF non-trainable resource variable. - - Raises: - RuntimeError: If multi iterations_per_loop variables were found. - """ - graph = ops.get_default_graph() - collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) - iter_vars = graph.get_collection(collection_name) - if len(iter_vars) == 1: - return iter_vars[0] - elif len(iter_vars) > 1: - raise RuntimeError('Multiple iterations_per_loop_var in collection.') - - with ops.colocate_with(training_util.get_global_step()): - with variable_scope.variable_scope( - _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE): - return variable_scope.get_variable( - _ITERATIONS_PER_LOOP_VAR, - initializer=init_ops.zeros_initializer(), - shape=[], - dtype=dtypes.int32, - trainable=False, - collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES], - use_resource=True) - - -def _sync_variables_ops(ctx): - """Create varriables synchronization ops. - - Gets the variables back from TPU nodes. This means the variables updated - by TPU will now be *synced* to host memory. - In BROADCAST mode, we skip this sync since the variables are ususally too - big to transmit via RPC. - - Args: - ctx: A `_InternalTPUContext` instance with mode. - - Returns: - A list of sync ops. 
- """ - - if not ctx.is_input_broadcast_with_iterators(): - return [ - array_ops.check_numerics(v.read_value(), - 'Gradient for %s is NaN' % v.name).op - for v in variables.trainable_variables() - ] - else: - return [control_flow_ops.no_op()] - - -def _increase_eval_step_op(iterations_per_loop): - """Returns an op to increase the eval step for TPU evaluation. - - Args: - iterations_per_loop: Tensor. The number of eval steps running in TPU system - before returning to CPU host for each `Session.run`. - - Returns: - An operation - """ - eval_step = evaluation._get_or_create_eval_step() # pylint: disable=protected-access - # Estimator evaluate increases 1 by default. So, we increase the difference. - return state_ops.assign_add( - eval_step, - math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype), - use_locking=True) - - -def _extract_key_names(tensor_or_dict): - if isinstance(tensor_or_dict, dict): - return sorted(tensor_or_dict.keys()) - return [] - - -class _SIGNAL(object): - """Signal used to control the thread of infeed/outfeed. - - All preserved signals must be negative numbers. Positive numbers are used to - indicate the number of iterations for next training/evaluation loop. - """ - NEXT_BATCH = -1 - STOP = -2 - - -class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. - - See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and - `export_outputs`. - - For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where - `metric_fn` runs on CPU to generate metrics and `tensors` represents the - `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. - To be precise, TPU evaluation expects a slightly different signature from the - `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a - dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. 
- The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The - `tensors` usually specify the model logits, which are transferred back from - TPU system to CPU host. All tensors must have be batch-major, i.e., the batch - size is the first dimension. Once all tensors are available at CPU host from - all shards, they are concatenated (on CPU) and passed as positional arguments - to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is - a dict. `metric_fn` takes the `tensors` and returns a dict from metric string - name to the result of calling a metric function, namely a `(metric_tensor, - update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the - `eval_metrics`. - - `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This - function should not capture any Tensors in `model_fn`. - - `host_call` is a tuple of a `function` and a list or dictionary of `tensors` - to pass to that function and returns a list of Tensors. `host_call` currently - works for train() and evaluate(). The Tensors returned by the function is - executed on the CPU on every step, so there is communication overhead when - sending tensors from TPU to CPU. To reduce the overhead, try reducing the - size of the tensors. The `tensors` are concatenated along their major (batch) - dimension, and so must be >= rank 1. The `host_call` is useful for writing - summaries with `tf.contrib.summary.create_file_writer`. 
- """ - - def __new__(cls, - mode, - predictions=None, - loss=None, - train_op=None, - eval_metrics=None, - export_outputs=None, - scaffold_fn=None, - host_call=None, - training_hooks=None, - evaluation_hooks=None, - prediction_hooks=None): - """Creates a validated `TPUEstimatorSpec` instance.""" - host_calls = {} - if eval_metrics is not None: - host_calls['eval_metrics'] = eval_metrics - if host_call is not None: - host_calls['host_call'] = host_call - _OutfeedHostCall.validate(host_calls) - - training_hooks = tuple(training_hooks or []) - evaluation_hooks = tuple(evaluation_hooks or []) - prediction_hooks = tuple(prediction_hooks or []) - - for hook in training_hooks + evaluation_hooks + prediction_hooks: - if not isinstance(hook, session_run_hook.SessionRunHook): - raise TypeError('All hooks must be SessionRunHook instances, given: {}' - .format(hook)) - - return super(TPUEstimatorSpec, cls).__new__( - cls, - mode=mode, - predictions=predictions, - loss=loss, - train_op=train_op, - eval_metrics=eval_metrics, - export_outputs=export_outputs, - scaffold_fn=scaffold_fn, - host_call=host_call, - training_hooks=training_hooks, - evaluation_hooks=evaluation_hooks, - prediction_hooks=prediction_hooks) - - def as_estimator_spec(self): - """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" - host_calls = {} - if self.eval_metrics is not None: - host_calls['eval_metrics'] = self.eval_metrics - if self.host_call is not None: - host_calls['host_call'] = self.host_call - host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls) - eval_metric_ops = None - if self.eval_metrics is not None: - eval_metric_ops = host_call_ret['eval_metrics'] - hooks = None - if self.host_call is not None: - hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] - if tensor_tracer.TensorTracer.is_enabled(): - tt = tensor_tracer.TensorTracer() - tracing_calls = tt.trace_cpu(ops.get_default_graph()) - tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls) - 
tracing_functions = tracing_call_ret.values() - if tracing_functions: - if hooks: - hooks.extend([_OutfeedHostCallHook(tracing_functions)]) - else: - hooks = [_OutfeedHostCallHook(tracing_functions)] - hooks = tuple(hooks or []) - scaffold = self.scaffold_fn() if self.scaffold_fn else None - return model_fn_lib.EstimatorSpec( - mode=self.mode, - predictions=self.predictions, - loss=self.loss, - train_op=self.train_op, - eval_metric_ops=eval_metric_ops, - export_outputs=self.export_outputs, - scaffold=scaffold, - training_hooks=self.training_hooks + hooks, - evaluation_hooks=self.evaluation_hooks + hooks, - prediction_hooks=self.prediction_hooks + hooks) - - -class _OpQueueContext(object): - """Manages work queue and thread for a infeed/outfeed thread.""" - - def __init__(self, name, target, args): - self._name = name - self._queue = Queue.Queue() - args = (self,) + args - self._thread = threading.Thread(name=name, target=target, args=args) - self._thread.daemon = True - self._thread.start() - - def stop(self): - self._queue.put(_SIGNAL.STOP) - - def send_next_batch_signal(self, iterations): - self._queue.put(iterations) - - def read_iteration_counts(self): - while True: - iterations = self._queue.get(block=True) - logging.debug('%s read iterations %s', self._name, iterations) - if iterations == _SIGNAL.STOP: - logging.info('%s received shutdown signal, stopping.', self._name) - return - yield iterations - - def join(self): - logging.info('Shutting down %s thread.', self._name) - self.stop() - self._thread.join() - - -class _OpSignalOnceQueueContext(_OpQueueContext): - """Manages work queue and thread for a infeed/outfeed thread. - - This subclass only signals once. 
- """ - - def __init__(self, name, target, args): - super(_OpSignalOnceQueueContext, self).__init__(name, target, args) - self._has_signaled = False - - def send_next_batch_signal(self, iterations): - if not self._has_signaled: - self._queue.put(iterations) - self._has_signaled = True - - -class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): - """A Session hook setting up the TPU initialization, infeed, and outfeed. - - This hook does two major things: - 1. initialize and shutdown TPU system. - 2. launch and join the threads for infeed enqueue and (optional) outfeed - dequeue. - """ - - def __init__(self, - ctx, - enqueue_ops, - dequeue_ops, - tpu_compile_op, - run_infeed_loop_on_coordinator=True, - rendezvous=None, - master=None, - session_config=None): - self._master_job = ctx.master_job - self._enqueue_ops = enqueue_ops - self._dequeue_ops = dequeue_ops - self._rendezvous = rendezvous - self._master = master - self._session_config = session_config - self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator - self._initial_infeed_sleep_secs = ( - ctx.config.tpu_config.initial_infeed_sleep_secs) - - self._feed_error = None - self._finished = False - self._should_initialize_tpu = True - self._tpu_compile_op = tpu_compile_op - - def begin(self): - logging.info('TPU job name %s', self._master_job) - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_ops = [] - if self._should_initialize_tpu: - self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] - else: - self._finalize_ops = [] - - summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() - self._init_ops.extend(summary_writer_init_ops) - # Get all the writer resources from the initializer, so we know what to - # flush. 
- for op in summary_writer_init_ops: - self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - - def _run_infeed(self, queue_ctx, session): - logging.info('Starting infeed thread controller.') - if self._initial_infeed_sleep_secs: - logging.info('Infeed thread sleeping for %d seconds.', - self._initial_infeed_sleep_secs) - time.sleep(self._initial_infeed_sleep_secs) - logging.info('Infeed thread starting after sleep') - - with self._rendezvous.catch_errors(source='infeed', session=session): - if self._run_infeed_loop_on_coordinator: - for count, steps in enumerate(queue_ctx.read_iteration_counts()): - for i in xrange(steps): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) - session.run(self._enqueue_ops) - else: - for _ in queue_ctx.read_iteration_counts(): - session.run(self._enqueue_ops) - logging.info('Infeed thread finished, shutting down.') - - def _run_outfeed(self, queue_ctx, session): - logging.info('Starting outfeed thread controller.') - with self._rendezvous.catch_errors(source='outfeed', session=session): - for count, steps in enumerate(queue_ctx.read_iteration_counts()): - for i in xrange(steps): - logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) - session.run(self._dequeue_ops) - logging.info('Outfeed thread finished, shutting down.') - - def _create_infeed_controller(self, name, target, args): - return _OpQueueContext(name=name, target=target, args=args) - - def _assertCompilationSucceeded(self, result, coord): - proto = tpu_compilation_result.CompilationResultProto() - proto.ParseFromString(result) - if proto.status_error_message: - logging.error('Compilation failed: {}'.format(proto.status_error_message)) - coord.request_stop() - else: - logging.info('Compilation succeeded') - - def after_create_session(self, session, coord): - if self._should_initialize_tpu: - logging.info('Init TPU system') - start = time.time() - with ops.Graph().as_default(): - with tf_session.Session( - self._master, 
config=self._session_config) as sess: - sess.run(tpu.initialize_system(job=self._master_job)) - logging.info('Initialized TPU in %d seconds', time.time() - start) - - session.run(self._init_ops, - options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) - - if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1': - logging.info('Compiling user program: this may take a while...') - self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord) - - self._infeed_controller = self._create_infeed_controller( - name='InfeedController', target=self._run_infeed, args=(session,)) - - self._outfeed_controller = _OpQueueContext( - name='OutfeedController', target=self._run_outfeed, args=(session,)) - - # Enable the worker watchdog to terminate workers on coordinator exit. - watchdog_timeout = int(os.environ.get('TF_TPU_WATCHDOG_TIMEOUT', '0')) - if watchdog_timeout > 0: - session_support.start_worker_watchdog(session, - shutdown_timeout=watchdog_timeout) - - def before_run(self, run_context): - self._feed_error = None - - iterations = run_context.session.run(self._iterations_per_loop_var) - - logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) - self._infeed_controller.send_next_batch_signal(iterations) - - logging.info('Dequeue next (%d) batch(es) of data from outfeed.', - iterations) - self._outfeed_controller.send_next_batch_signal(iterations) - - def end(self, session): - self._finished = True - logging.info('Stop infeed thread controller') - self._infeed_controller.join() - self._rendezvous.record_done('infeed') - - logging.info('Stop output thread controller') - self._outfeed_controller.join() - self._rendezvous.record_done('outfeed') - - logging.info('Shutdown TPU system.') - session.run(self._finalize_ops) - - -class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - - def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op, - rendezvous=None, master=None, session_config=None): - 
super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( - ctx, - enqueue_ops, - dequeue_ops, - tpu_compile_op=tpu_compile_op, - run_infeed_loop_on_coordinator=False, - rendezvous=rendezvous, - master=master, - session_config=session_config) - - def _create_infeed_controller(self, name, target, args): - return _OpSignalOnceQueueContext(name=name, target=target, args=args) - - -class _TPUStopAtStepHook(session_run_hook.SessionRunHook): - """Hook that requests stop at a specified step. - - This hook is similar to the `session_run_hook._StopAfterNEvalsHook` with - following differences for TPU training: - - 1. This hook sets the variable for iterations_per_loop, which is used by - `TPUInfeedOutfeedSessionHook` to control the iterations for infeed/outfeed. - As the hook execution order is not guaranteed, the variable update is - handled in `after_create_session` and `after_run` as - `TPUInfeedOutfeedSessionHook` reads the variable value in `before_run`. - - 2. For each training loop (session.run), the global step could be increased - multiple times on TPU. The global step tensor value will be explicitly read - again in `after_run` to ensure the latest value is retrieved to avoid race - condition. - """ - - def __init__(self, iterations, num_steps=None, last_step=None): - """Initializes a `StopAtStepHook`. - - Args: - iterations: The number of iterations to run optimizer per training loop. - num_steps: Number of steps to execute. - last_step: Step after which to stop. - - Raises: - ValueError: If one of the arguments is invalid. 
- """ - if num_steps is None and last_step is None: - raise ValueError('One of num_steps or last_step must be specified.') - if num_steps is not None and last_step is not None: - raise ValueError('Only one of num_steps or last_step can be specified.') - self._num_steps = num_steps - self._last_step = last_step - self._iterations = iterations - - def _next_iterations(self, global_step, last_step): - gap = last_step - global_step - return min(gap, self._iterations) - - def begin(self): - self._global_step_tensor = training_util.get_global_step() - if self._global_step_tensor is None: - raise RuntimeError('Global step should be created.') - - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - global_step = session.run(self._global_step_tensor) - if self._last_step is None: - self._last_step = global_step + self._num_steps - - iterations = self._next_iterations(global_step, self._last_step) - - self._iterations_per_loop_var.load(iterations, session=session) - - def after_run(self, run_context, run_values): - # Global step cannot be retrieved via SessionRunArgs and before_run due to - # race condition. - global_step = run_context.session.run(self._global_step_tensor) - if global_step >= self._last_step: - run_context.request_stop() - else: - iterations = self._next_iterations(global_step, self._last_step) - self._iterations_per_loop_var.load( - iterations, session=run_context.session) - - -class _SetEvalIterationsHook(session_run_hook.SessionRunHook): - """Hook that requests stop at a specified step.""" - - def __init__(self, num_steps): - """Initializes a `_SetEvalIterationsHook`. - - Args: - num_steps: Number of steps to execute. 
- """ - self._num_steps = num_steps - - def begin(self): - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - self._iterations_per_loop_var.load(self._num_steps, session=session) - - -class _StoppingPredictHook(session_run_hook.SessionRunHook): - """Hook that requests stop according to the stopping signal in prediction.""" - - def __init__(self, scalar_stopping_signal): - self._scalar_stopping_signal = scalar_stopping_signal - - def begin(self): - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - # This is not necessary as we do not run infeed enqueue and outfeed dequeue - # in side threads for prediction model. But it makes the - # TPUInfeedOutfeedSessionHook prints nice message. - self._iterations_per_loop_var.load(1, session=session) - - def before_run(self, run_context): - return session_run_hook.SessionRunArgs(self._scalar_stopping_signal) - - def after_run(self, run_context, run_values): - _ = run_context - scalar_stopping_signal = run_values.results - if _StopSignals.should_stop(scalar_stopping_signal): - # NOTE(xiejw): In prediction, stopping signals are inserted for each - # batch. And we append one more batch to signal the system it should stop. - # The data flow might look like - # - # batch 0: images, labels, stop = 0 (user provided) - # batch 1: images, labels, stop = 0 (user provided) - # ... - # batch 99: images, labels, stop = 0 (user provided) - # batch 100: images, labels, stop = 1 (TPUEstimator appended) - # - # where the final batch (id = 100) is appended by TPUEstimator, so we - # should drop it before returning the predictions to user. - # To achieve that, we throw the OutOfRangeError in after_run. 
Once - # Monitored Session sees this error in SessionRunHook.after_run, the - # "current" prediction, i.e., batch with id=100, will be discarded - # immediately - raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') - - -def generate_per_core_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, host_device, host_id): - """Generates infeed enqueue ops for per-core input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """A fn returns enqueue_ops.""" - num_cores_per_host = ctx.num_of_cores_per_host - per_host_sharded_inputs = [] - for core_ordinal in range(num_cores_per_host): - with ops.name_scope('ordinal_%d' % (core_ordinal)): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, - input_device=host_device, - invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - if inputs.is_dataset: - raise TypeError( - '`input_fn` returning `Dataset` is not yet supported in ' - 'per-Core input pipeline deployment yet. 
Please set ' - 'TPUConfig.per_host_input_for_training to True or return ' - '`features` and `labels` from `input_fn`') - features, labels = inputs.features_and_labels() - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels)) - per_host_sharded_inputs.append(flattened_inputs) - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) - captured_infeed_queue.capture(infeed_queue) - - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) - return per_host_enqueue_ops - - return enqueue_ops_fn, captured_infeed_queue - - -def generate_per_host_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id): - """Generates infeed enqueue ops for per-host input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - - dataset_initializer = None - - with ops.device(device): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device, invocation_index=host_id) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - if not is_dataset: - raise TypeError( - 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' - '`features` and `labels`.') - if batch_axis is not None: - raise TypeError('For mode PREDICT, batch_axis is not supported yet.') - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True) - - if is_dataset: - dataset_initializer = inputs.dataset_initializer() - - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """A Fn returning the TPU infeed enqueue ops. 
- - By providing as a Fn, it can be invoked inside the tf.while_loop such that - the input pipeline for multiple iterations can be executed by one - Session.run call. - - Returns: - list of dict of ops. - """ - with ops.device(device): - num_of_replicas_per_host = ctx.num_of_replicas_per_host - # Convert user input to features and labels. If the user returns a - # dataset, it is initialized and the features and labels extracted via - # `dataset.iterator.get_next()` - features, labels = inputs.features_and_labels() - signals = inputs.signals() - - inputs_structure_recorder.validate_and_record_structure(features, labels) - unsharded_tensor_list = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - - infeed_queue = tpu_feed.InfeedQueue( - tuple_types=[t.dtype for t in unsharded_tensor_list], - tuple_shapes=[t.shape for t in unsharded_tensor_list], - shard_dimensions=batch_axis) - captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_number_of_shards(num_of_replicas_per_host) - per_host_enqueue_ops = ( - infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_tensor_list, - placement_function=lambda x: device, - tpu_ordinal_function=tpu_ordinal_function_impl)) - if signals is None: - return per_host_enqueue_ops - else: - return { - 'ops': per_host_enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, dataset_initializer - - -def generate_per_host_v2_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, device, host_id): - """Generates infeed enqueue ops for per-host input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - dataset_initializer = None - - with ops.device(device): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device, invocation_index=host_id) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if not is_dataset: - raise TypeError('`input_fn` must return a `Dataset` 
for the PER_HOST_V2 ' - 'input pipeline configuration.') - - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True, - num_invocations_per_step=ctx.num_of_replicas_per_host) - - dataset_initializer = inputs.dataset_initializer() - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """Generates the per_host enqueue ops.""" - control_deps = [] - per_host_sharded_inputs = [] - num_replicas_per_host = ctx.num_of_replicas_per_host - cached_signals = None - with ops.device(device): - if not inputs.is_dataset: - raise TypeError('`input_fn` must return a `Dataset` for this mode.') - for _ in range(num_replicas_per_host): - # Use control dependencies to ensure a deterministic ordering. - with ops.control_dependencies(control_deps): - features, labels = inputs.features_and_labels() # Calls get_next() - signals = inputs.signals() - - # All the replicas share the replica 0's stopping singal. - # This avoids inconsistent state among different model replcias. 
- if cached_signals: - signals['stopping'] = cached_signals['stopping'] - else: - cached_signals = signals - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - control_deps.extend(flattened_inputs) - per_host_sharded_inputs.append(flattened_inputs) - - if inputs_structure_recorder.flattened_input_dims: - input_partition_dims = inputs_structure_recorder.flattened_input_dims - if signals: - input_partition_dims += [None] * len(signals) - # pylint: disable=protected-access - infeed_queue = tpu_feed._PartitionedInfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0]), - host_id=host_id, - input_partition_dims=input_partition_dims, - device_assignment=ctx.device_assignment) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs) - else: - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, - tpu_ordinal_function=tpu_ordinal_function_impl) - captured_infeed_queue.capture(infeed_queue) - - if signals is None: - return per_host_enqueue_ops - else: - return { - 'ops': per_host_enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, dataset_initializer - - -def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, - num_hosts): - """Generates infeed enqueue ops for one input_fn on all the hosts.""" - captured_infeed_queue = _CapturedObject() - dataset_initializer = None - device_0 = ctx.tpu_host_placement_function(host_id=0) - with ops.device(device_0): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device_0, invocation_index=0) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - if not 
is_dataset: - raise TypeError( - 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' - '`features` and `labels`.') - - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True) - - if is_dataset: - dataset_initializer = inputs.dataset_initializer() - num_replicas_per_host = ctx.num_of_replicas_per_host - - def tpu_ordinal_function_impl(replica_id): - if ctx.device_assignment: - return ctx.device_assignment.tpu_ordinal(replica=replica_id) - else: - return replica_id % num_replicas_per_host - - def device_function_impl(replica_id): - return ctx.tpu_host_placement_function(replica_id=replica_id) - - def enqueue_ops_fn(): - """Generates enqueue ops for all the hosts.""" - broadcasted_inputs = [] - flattened_inputs = None # Cache result from input_fn. - signals = None - for host_id in xrange(num_hosts): - with ops.device(ctx.tpu_host_placement_function(host_id=host_id)): - for _ in xrange(ctx.num_of_replicas_per_host): - # Note: input_fn is only called once at host 0 for the first replica. - # The features and labels returned from that invocation are - # broadcasted to other replicas(including the replicas on other - # hosts). 
- if flattened_inputs is None: - features, labels = inputs.features_and_labels() # Calls get_next() - signals = inputs.signals() - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - broadcasted_inputs.append(flattened_inputs) - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(broadcasted_inputs[0])) - captured_infeed_queue.capture(infeed_queue) - enqueue_ops = infeed_queue.generate_enqueue_ops( - broadcasted_inputs, - tpu_ordinal_function=tpu_ordinal_function_impl, - placement_function=device_function_impl) - - if signals is None: - return enqueue_ops - else: - return { - 'ops': enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, dataset_initializer - - -class _InputPipeline(object): - """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. - - `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from - call site. To be precise, based on the configuration in - `_InternalTPUContext`, it invokes `input_fn` for all cores (usually - multi-host TPU training) or for one host (usually for single-host TPU - evaluation), and sends all `features` and `labels` returned by `input_fn` to - TPU infeed. For per-core invocation, `features` and `labels` are piped to - infeed directly, one tuple for each core. For per-host invocation, `features` - and `labels` are split at host (with respect to `batch_axis`) and piped to all - cores accordingly. - - In addition, flatten/unflatten are handled by `_InputPipeline` also. Model - inputs returned by the `input_fn` can have one of the following forms: - 1. features - 2. (features, labels) - 3. ((arbitrarily nested structure of features), labels) - - Internally, form 1 is reformed to `(features, None)` as features and labels - are passed separately to underlying methods. 
For TPU training, TPUEstimator - may expect multiple `features` and `labels` tuples one for each core. - - TPUEstimator allows various different structures for inputs (namely `features` - and `labels`). Both `features` and `labels` can be any nested sturcture - supported by TF nest (namely, dict, tuples, namedtuples or any nested - structure of such of Tensors). `labels` could be `None` as well. - - These are flattened before they are passed to the infeed/outfeed library - as that expectes flattend lists. - """ - - class InputsStructureRecorder(object): - """The recorder to record inputs structure.""" - - def __init__(self, input_partition_dims=None): - # Holds the structure of inputs - self._feature_structure = {} - self._flattened_input_dims = None - - if input_partition_dims: - # This should have been validated in TPUConfig. - assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.' - if len(input_partition_dims) == 2: - self._feature_dims, self._label_dims = input_partition_dims - else: - self._feature_dims = input_partition_dims[0] - self._label_dims = None - - assert self._feature_dims is not None, ('input_partition_dims[0] must ' - 'not be None') - else: - self._feature_dims = None - self._label_dims = None - - # Internal state. - self._initialized = False - - @property - def flattened_input_dims(self): - assert self._initialized, 'InputsStructureRecorder is not initialized.' - return self._flattened_input_dims - - def has_labels(self): - return 'labels' in self._feature_structure - - def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, - label_dims_names, label_names, has_labels): - """Flatten input dims with the same order as flattened input tensors.""" - flattened_input_dims = [] - if feature_dims_names: - # We need a fixed ordering for matching the tensors in features. 
- flattened_input_dims.extend( - [feature_dims[name] for name in feature_dims_names]) - else: - flattened_input_dims.append(feature_dims) - - if label_dims_names: - # We need a fixed ordering for matching the tensors in labels. - flattened_input_dims.extend( - [label_dims[name] for name in label_dims_names]) - else: - if label_names: - num_tensors_in_label = len(label_names) - else: - num_tensors_in_label = int(has_labels) - # Setting `None` in input_partition_dims[1] will apply `None` to - # all the tensors in labels, regardless of internal structure. - flattened_input_dims.extend([label_dims] * num_tensors_in_label) - - return flattened_input_dims - - def validate_and_record_structure(self, features, labels): - """Validates and records the structure of `features` and `labels`.""" - # Extract structure. - has_labels = labels is not None - feature_names = _extract_key_names(features) - label_names = _extract_key_names(labels) - - if not self._initialized: - # Record structure. - self._initialized = True - if self._feature_dims is not None: - feature_dims_names = _extract_key_names(self._feature_dims) - if feature_dims_names != feature_names: - raise ValueError( - 'TPUConfig.input_partition_dims[0] mismatched feature' - ' keys. Expected {}, got {}'.format(feature_names, - feature_dims_names)) - - label_dims_names = _extract_key_names(self._label_dims) - if self._label_dims is not None and label_dims_names != label_names: - raise ValueError( - 'TPUConfig.input_partition_dims[1] mismatched label' - ' keys. 
Expected {}, got {}'.format(label_names, - label_dims_names)) - - self._flattened_input_dims = self._flatten_input_dims( - self._feature_dims, feature_dims_names, self._label_dims, - label_dims_names, label_names, has_labels) - - def flatten_features_and_labels(self, features, labels, signals=None): - """Flattens the `features` and `labels` to a single tensor list.""" - self._feature_structure['features'] = features - if labels is not None: - self._feature_structure['labels'] = labels - if signals is not None: - self._feature_structure['signals'] = signals - return data_nest.flatten(self._feature_structure) - - def unflatten_features_and_labels(self, flattened_inputs): - """Restores the flattened inputs to original features and labels form. - - Args: - flattened_inputs: Flattened inputs for each shard. - - Returns: - A tuple of (`features`, `labels`), where `labels` could be None. - Each one, if present, should have identical structure (single tensor vs - dict) as the one returned by input_fn. - - Raises: - ValueError: If the number of expected tensors from `flattened_inputs` - mismatches the recorded structure. - """ - - unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure, - flattened_inputs) - return _Inputs( - unflattened_inputs['features'], - unflattened_inputs.get('labels'), - signals=unflattened_inputs.get('signals')) - - def __init__(self, input_fn, batch_axis, ctx): - """Constructor. - - Args: - input_fn: input fn for train or eval. - batch_axis: A python tuple of int values describing how each tensor - produced by the Estimator `input_fn` should be split across the TPU - compute shards. - ctx: A `_InternalTPUContext` instance with mode. - - Raises: - ValueError: If both `sharded_features` and `num_cores` are `None`. 
- """ - self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder( - ctx.input_partition_dims) - - self._sharded_per_core = ctx.is_input_sharded_per_core() - self._input_fn = input_fn - self._infeed_queue = None - self._ctx = ctx - self._batch_axis = batch_axis - - def generate_infeed_enqueue_ops_and_dequeue_fn(self): - """Generates infeed enqueue ops and dequeue_fn.""" - # While tf.while_loop is called, the body function, which invokes - # `enqueue_fn` passed in, is called to construct the graph. So, input_fn - # structure is recorded. - enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = ( - self._invoke_input_fn_and_record_structure()) - - self._validate_input_pipeline() - - def dequeue_fn(): - """dequeue_fn is used by TPU to retrieve the tensors.""" - # In the model-parallel case, both the host-side and device-side - # computations must agree on the core on which infeed takes place. We - # choose to perform infeed on logical core 0 of each replica. - values = self._infeed_queue.generate_dequeue_op(tpu_device=0) - # The unflatten process uses the structure information recorded above. - return self._inputs_structure_recorder.unflatten_features_and_labels( - values) - - return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator) - - def _invoke_input_fn_and_record_structure(self): - """Deploys the input pipeline and record input structure.""" - enqueue_ops = [] - infeed_queues = [] - all_dataset_initializers = [] - num_hosts = self._ctx.num_hosts - tpu_host_placement_fn = self._ctx.tpu_host_placement_function - - run_infeed_loop_on_coordinator = True - - if self._sharded_per_core: - # Per-Core input pipeline deployment. - # Invoke input pipeline for each core and placed on the corresponding - # host. 
- for host_id in range(num_hosts): - host_device = tpu_host_placement_fn(host_id=host_id) - with ops.device(host_device): - with ops.name_scope('input_pipeline_task%d' % (host_id)): - enqueue_ops_fn, captured_infeed_queue = ( - generate_per_core_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, self._inputs_structure_recorder, - host_device, host_id)) - - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - run_infeed_loop_on_coordinator = False - enqueue_ops.append( - _wrap_computation_in_while_loop( - device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - # Infeed_queue_getter must be called after enqueue_ops_fn is called. - infeed_queues.append(captured_infeed_queue.get()) - - elif self._ctx.is_input_broadcast_with_iterators(): - # Only calls input_fn in host 0. - host_device = tpu_host_placement_fn(host_id=0) - enqueue_ops_fn, captured_infeed_queue, dataset_initializer = ( - generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn, - self._inputs_structure_recorder, - num_hosts)) - if dataset_initializer: - all_dataset_initializers.append(dataset_initializer) - run_infeed_loop_on_coordinator = False - wrap_fn = ( - _wrap_computation_in_while_loop - if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else - _wrap_computation_in_while_loop_with_stopping_signals) - enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - infeed_queues.append(captured_infeed_queue.get()) - else: - for host_id in range(num_hosts): - host_device = tpu_host_placement_fn(host_id=host_id) - with ops.device(host_device): - with ops.name_scope('input_pipeline_task%d' % (host_id)): - if self._ctx.is_input_per_host_with_iterators(): - enqueue_ops_fn, captured_infeed_queue, dataset_initializer = ( - generate_per_host_v2_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, - self._inputs_structure_recorder, host_device, host_id)) - else: - enqueue_ops_fn, captured_infeed_queue, dataset_initializer = ( - 
generate_per_host_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, - self._inputs_structure_recorder, self._batch_axis, - host_device, host_id)) - - # NOTE(xiejw): We dispatch here based on the return type of the - # users `input_fn`. - # - # 1. If input_fn returns a Dataset instance, we initialize the - # iterator outside of tf.while_loop, and call the iterator.get_next - # inside tf.while_loop. This should be always safe. - # - # 2. If input_fn returns (features, labels), it is too late to wrap - # them inside tf.while_loop, as resource initialization cannot be - # handled in TF control flow properly. In this case, we will use - # python loop to enqueue the data into TPU system. This may be - # slow compared to the previous case. - if dataset_initializer: - all_dataset_initializers.append(dataset_initializer) - run_infeed_loop_on_coordinator = False - wrap_fn = ( - _wrap_computation_in_while_loop - if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else - _wrap_computation_in_while_loop_with_stopping_signals) - enqueue_ops.append( - wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - infeed_queues.append(captured_infeed_queue.get()) - # infeed_queue is used to generate dequeue ops. The only thing it uses for - # dequeue is dtypes and types. So, any one can be used. Here, grab the - # first one. - self._infeed_queue = infeed_queues[0] - return enqueue_ops, [ - util_lib.MultiHostDatasetInitializerHook(all_dataset_initializers) - ], run_infeed_loop_on_coordinator - - def _validate_input_pipeline(self): - """Validates the input pipeline. - - Perform some sanity checks to log user friendly information. We should - error out to give users better error message. But, if - _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break - user code, so, log a warning. - - Raises: - RuntimeError: If the validation failed. 
- """ - if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): - err_msg = ('Input pipeline contains one or more QueueRunners. ' - 'It could be slow and not scalable. Please consider ' - 'converting your input pipeline to use `tf.data` instead (see ' - 'https://www.tensorflow.org/guide/datasets for ' - 'instructions.') - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - raise RuntimeError(err_msg) - else: - logging.warn(err_msg) - - -def call_computation(computation, - experimental_exported_model_uses_all_cores=True): - """Call computation. - - computation uses a single-core for TPU inference. If - `experimental_exported_model_uses_all_cores` is `True`, this function will - round-robin - computation among all TPU cores visible to the host; otherwise, it will use - a single core. - - Args: - computation: A Python function that takes no inputs and builds computation - graph. If `computation` returns m outputs, this function will return a - list of m Tensors. - experimental_exported_model_uses_all_cores: Whether to round-robin among all - cores visible to the host, or to use a single core. - - Returns: - A list of output tensors. - """ - if experimental_exported_model_uses_all_cores: - # Using `TPUPartitionedCall` makes it possible to target a different - # TPU core with every `Session.run()` call. Note that the entire inference - # graph executes on a single core, and that invocations of this graph - # will round-robin among the cores attached to a host. - @function.Defun() - def tpu_subgraph(): - return computation() - - return tpu_functional.TPUPartitionedCall( - args=tpu_subgraph.captured_inputs, - device_ordinal=gen_tpu_ordinal_selector_op.tpu_ordinal_selector(), - Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg], - f=tpu_subgraph) - else: - return computation() - - -class _ModelFnWrapper(object): - """A `model_fn` wrapper. 
- - This makes calling model_fn on CPU and TPU easier and more consistent and - performs necessary check and mutation required by TPU training and evaluation. - - In addition, this wrapper manages converting the `model_fn` to a single TPU - train and eval step. - """ - - def __init__(self, model_fn, config, params, ctx): - self._model_fn = model_fn - self._config = config - self._params = params - self._ctx = ctx - - def call_without_tpu(self, features, labels, is_export_mode): - return self._call_model_fn(features, labels, is_export_mode=is_export_mode) - - def convert_to_single_tpu_train_step(self, dequeue_fn): - """Converts user provided model_fn` as a single train step on TPU. - - The user provided `model_fn` takes input tuple - (features, labels) and produces the EstimatorSpec with train_op and loss for - train `mode`. This usually represents a single train computation on CPU. - - For TPU training, a train (computation) step is first wrapped in a - tf.while_loop control flow to repeat for many times and then replicated to - all TPU shards. Besides the input should be taken from TPU infeed rather - than input pipeline (input_fn) directly. To fit TPU loop and replicate - pattern, the original train computation should be reformed, which is the - returned `train_step`. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of train_fn, host_calls, and captured scaffold_fn. The train_fn - representing the train step for TPU. - """ - - host_call = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_training_hooks = _CapturedObject() - - def train_step(loss): - """Training step function for use inside a while loop.""" - del loss # unused; required in function signature. 
- inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - - estimator_spec = self._verify_estimator_spec( - self._call_model_fn(features, labels)) - loss, train_op = estimator_spec.loss, estimator_spec.train_op - - if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - captured_scaffold_fn.capture(estimator_spec.scaffold_fn) - else: - captured_scaffold_fn.capture(None) - - captured_training_hooks.capture(estimator_spec.training_hooks) - - tracing_ops = [] - if tensor_tracer.TensorTracer.is_enabled(): - tt = tensor_tracer.TensorTracer() - loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), loss, - self._ctx.num_replicas, - fetches=[loss, train_op]) - - # We must run train_op to update the variables prior to running the - # outfeed. - with ops.control_dependencies([train_op]+tracing_ops): - host_call_outfeed_ops = [] - if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access - and estimator_spec.host_call is not None): - host_call.record({'host_call': estimator_spec.host_call}) - host_call_outfeed_ops = host_call.create_enqueue_op() - with ops.control_dependencies(host_call_outfeed_ops): - return array_ops.identity(loss) - - return (train_step, host_call, captured_scaffold_fn, - captured_training_hooks) - - def convert_to_single_tpu_eval_step(self, dequeue_fn): - """Converts user provided model_fn` as a single eval step on TPU. - - Similar to training, the user provided `model_fn` takes input tuple - (features, labels) and produces the TPUEstimatorSpec with eval_metrics for - eval `mode`. This usually represents a single evaluation computation on CPU. - - For TPU evaluation, a eval (computation) step is first wrapped in a - tf.while_loop control flow to repeat for many times and then replicated to - all TPU shards. Besides the input and output are slightly different. 
Input, - features and labels, should be taken from TPU infeed rather than input - pipeline (input_fn) directly. Output is managed in two stages. First, the - model outputs as the result of evaluation computation, usually model logits, - should be transferred from TPU system to CPU. Then, all model outputs are - concatenated first on CPU and sent to the metric_fn for metrics computation. - To fit TPU evaluation pattern, the original eval computation should be - reformed, which is the returned `eval_step`. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn - representing the eval step for TPU. - """ - host_calls = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_eval_hooks = _CapturedObject() - - def eval_step(total_loss): - """Evaluation step function for use inside a while loop.""" - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - - tpu_estimator_spec = self._call_model_fn(features, labels) - if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - raise RuntimeError( - 'estimator_spec used by TPU evaluation must have type' - '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) - - loss = tpu_estimator_spec.loss - captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks) - - to_record = {} - if tpu_estimator_spec.eval_metrics: - to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics - if tpu_estimator_spec.host_call is not None: - # We assume that evaluate won't update global step, so we don't wrap - # this host_call. 
- to_record['host_call'] = tpu_estimator_spec.host_call - host_calls.record(to_record) - - with ops.control_dependencies(host_calls.create_enqueue_op()): - return math_ops.add(total_loss, loss) - - return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks - - def convert_to_single_tpu_predict_step(self, dequeue_fn): - """Converts user provided model_fn` as a single predict step on TPU. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of predict_fn, host_calls, and captured scaffold_fn. The - predict_fn representing the predict step for TPU. - """ - host_calls = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_predict_hooks = _CapturedObject() - - def predict_step(unused_scalar_stopping_signal): - """Evaluation step function for use inside a while loop.""" - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - stopping_signals = inputs.signals() - - assert stopping_signals is not None, ( - 'Internal Error: `signals` is missing.') - - tpu_estimator_spec = self._call_model_fn( - features, labels, is_export_mode=False) - if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - raise RuntimeError( - 'estimator_spec used by TPU prediction must have type' - '`TPUEstimatorSpec`. 
Got {}'.format(type(tpu_estimator_spec))) - - self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) - - captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks) - to_record = {} - identity_fn = lambda **kwargs: kwargs - to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] - to_record['signals'] = [identity_fn, stopping_signals] - if tpu_estimator_spec.host_call is not None: - to_record['host_call'] = tpu_estimator_spec.host_call - host_calls.record(to_record) - - with ops.control_dependencies(host_calls.create_enqueue_op()): - return _StopSignals.as_scalar_stopping_signal(stopping_signals) - - return (predict_step, host_calls, captured_scaffold_fn, - captured_predict_hooks) - - def _verify_tpu_spec_predictions(self, predictions): - """Validates TPUEstimatorSpec.predictions dict.""" - # TODO(xiejw): Adds validation for prediction dictionrary. - # TODO(xiejw): Adds support for single tensor as predictions. - if not isinstance(predictions, dict): - raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') - - for (key, tensor) in predictions.items(): - if tensor.shape.dims[0].value is None: - raise ValueError( - 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' - 'dynamic shape (should be static). Tensor: {}'.format(key, tensor)) - return predictions - - def _validate_model_features_and_labels(self, features, labels, - is_export_mode): - """Validates that the features and labels for the model function are valid. - - A valid features/labels object is the one with: - - Type: A tensor or any nested structure of tensors supported by TF nest, - namely nested dictionary, tuple, namedtuple, or sequence of tensors. - - Static shape if is_export_mode is False. - - Args: - features: the features that would be input to the model function. - labels: the labels that would be input to the model function. 
- is_export_mode: boolean value specifying if in export mode. - - Raises: - TypeError: If features/labels are not of the correct type. - ValueError: If features/labels have dynamic shape. - """ - - def validate(obj, obj_name): - """Helper validate function.""" - if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): - return - if isinstance(obj, ops.Tensor): - if not obj.get_shape().is_fully_defined(): - raise ValueError( - 'The {} to the model returned by input_fn must have static shape.' - ' Tensor: {}'.format(obj_name, obj)) - else: - for tensor in data_nest.flatten(obj): - if not tensor.get_shape().is_fully_defined(): - raise ValueError( - ('The {} to the model returned by input_fn must have static ' - 'shape. Tensor: {}').format(obj_name, tensor)) - - validate(features, 'features') - if labels is not None: - validate(labels, 'labels') - - def _call_model_fn(self, features, labels, is_export_mode=False): - """Calls the model_fn with required parameters.""" - self._validate_model_features_and_labels(features, labels, is_export_mode) - model_fn_args = function_utils.fn_args(self._model_fn) - kwargs = {} - - # Makes deep copy with `config` and params` in case user mutates them. 
- config = copy.deepcopy(self._config) - params = copy.deepcopy(self._params) - - if 'labels' in model_fn_args: - kwargs['labels'] = labels - elif labels is not None: - raise ValueError( - 'model_fn does not take labels, but input_fn returns labels.') - if 'mode' in model_fn_args: - kwargs['mode'] = self._ctx.mode - if 'config' in model_fn_args: - kwargs['config'] = config - if 'params' in model_fn_args: - kwargs['params'] = params - - if 'params' not in model_fn_args: - raise ValueError('model_fn ({}) does not include params argument, ' - 'required by TPUEstimator to pass batch size as ' - 'params[\'batch_size\']'.format(self._model_fn)) - - if is_export_mode: - batch_size_for_model_fn = None - else: - batch_size_for_model_fn = self._ctx.batch_size_for_model_fn - - if batch_size_for_model_fn is not None: - _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) - - running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) - _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) - - if not running_on_cpu: - user_context = tpu_context.TPUContext( - internal_ctx=self._ctx, call_from_input_fn=False) - _add_item_to_params(params, _CTX_KEY, user_context) - - estimator_spec = self._model_fn(features=features, **kwargs) - if (running_on_cpu and - isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access - # The estimator_spec will be passed to `Estimator` directly, which expects - # type `EstimatorSpec`. - return estimator_spec.as_estimator_spec() - else: - return estimator_spec - - def _verify_estimator_spec(self, estimator_spec): - """Validates the estimator_spec.""" - if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - return estimator_spec - - err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' 
- if estimator_spec.training_chief_hooks: - raise ValueError( - err_msg.format('training_chief_hooks') + 'If you want' + - ' to pass training hooks, please pass via training_hooks.') - - if estimator_spec.scaffold: - logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. ' - 'Please use TPUEstimatorSpec.') - return estimator_spec - - -class _OutfeedHostCall(object): - """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec.""" - - def __init__(self, ctx): - self._ctx = ctx - self._names = [] - # All of these are dictionaries of lists keyed on the name. - self._host_fns = {} - self._tensor_keys = collections.defaultdict(list) - self._tensors = collections.defaultdict(list) - self._tensor_dtypes = collections.defaultdict(list) - self._tensor_shapes = collections.defaultdict(list) - - @staticmethod - def validate(host_calls): - """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`.""" - - for name, host_call in host_calls.items(): - if not isinstance(host_call, (tuple, list)): - raise ValueError('{} should be tuple or list'.format(name)) - if len(host_call) != 2: - raise ValueError('{} should have two elements.'.format(name)) - if not callable(host_call[0]): - raise TypeError('{}[0] should be callable.'.format(name)) - if not isinstance(host_call[1], (tuple, list, dict)): - raise ValueError('{}[1] should be tuple or list, or dict.'.format(name)) - - if isinstance(host_call[1], (tuple, list)): - fullargspec = tf_inspect.getfullargspec(host_call[0]) - fn_args = function_utils.fn_args(host_call[0]) - # wrapped_hostcall_with_global_step uses varargs, so we allow that. 
- if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): - raise RuntimeError( - 'In TPUEstimatorSpec.{}, length of tensors {} does not match ' - 'method args of the function, which takes {}.'.format( - name, len(host_call[1]), len(fn_args))) - - @staticmethod - def create_cpu_hostcall(host_calls): - """Runs on the host_call on CPU instead of TPU when use_tpu=False.""" - - _OutfeedHostCall.validate(host_calls) - ret = {} - for name, host_call in host_calls.items(): - host_fn, tensors = host_call - if isinstance(tensors, (tuple, list)): - ret[name] = host_fn(*tensors) - else: - # Must be dict. - try: - ret[name] = host_fn(**tensors) - except TypeError as e: - logging.warning( - 'Exception while calling %s: %s. It is likely the tensors ' - '(%s[1]) do not match the ' - 'function\'s arguments', name, e, name) - raise - return ret - - def record(self, host_calls): - """Records the host_call structure.""" - - for name, host_call in host_calls.items(): - host_fn, tensor_list_or_dict = host_call - self._names.append(name) - self._host_fns[name] = host_fn - - if isinstance(tensor_list_or_dict, dict): - for (key, tensor) in six.iteritems(tensor_list_or_dict): - self._tensor_keys[name].append(key) - self._tensors[name].append(tensor) - self._tensor_dtypes[name].append(tensor.dtype) - self._tensor_shapes[name].append(tensor.shape) - else: - # List or tuple. - self._tensor_keys[name] = None - for tensor in tensor_list_or_dict: - self._tensors[name].append(tensor) - self._tensor_dtypes[name].append(tensor.dtype) - self._tensor_shapes[name].append(tensor.shape) - - def create_enqueue_op(self): - """Create the op to enqueue the recorded host_calls. - - Returns: - A list of enqueue ops, which is empty if there are no host calls. - """ - if not self._names: - return [] - - tensors = [] - # TODO(jhseu): Consider deduping tensors. 
- for name in self._names: - tensors.extend(self._tensors[name]) - - with ops.device(tpu.core(0)): - return [tpu_ops.outfeed_enqueue_tuple(tensors)] - - def create_tpu_hostcall(self): - """Sends the tensors through outfeed and runs the host_fn on CPU. - - The tensors are concatenated along dimension 0 to form a global tensor - across all shards. The concatenated function is passed to the host_fn and - executed on the first host. - - Returns: - A dictionary mapping name to the return type of the host_call by that - name. - - Raises: - RuntimeError: If outfeed tensor is scalar. - """ - if not self._names: - return {} - - ret = {} - # For each i, dequeue_ops[i] is a list containing the tensors from all - # shards. This list is concatenated later. - dequeue_ops = [] - tensor_dtypes = [] - tensor_shapes = [] - for name in self._names: - for _ in self._tensors[name]: - dequeue_ops.append([]) - for dtype in self._tensor_dtypes[name]: - tensor_dtypes.append(dtype) - for shape in self._tensor_shapes[name]: - tensor_shapes.append(shape) - - # Outfeed ops execute on each replica's first logical core. Note: we must - # constraint it such that we have at most one outfeed dequeue and enqueue - # per replica. - for i in xrange(self._ctx.num_replicas): - host_device, ordinal_id = self._ctx.device_for_replica(i) - with ops.device(host_device): - outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( - dtypes=tensor_dtypes, - shapes=tensor_shapes, - device_ordinal=ordinal_id) - for j, item in enumerate(outfeed_tensors): - dequeue_ops[j].append(item) - - # Deconstruct dequeue ops. 
- flat_dequeue_ops = [] - for l in dequeue_ops: - flat_dequeue_ops.extend(l) - - dequeue_ops_by_name = {} - pos = 0 - for name in self._names: - dequeue_ops_by_name[name] = dequeue_ops[pos:pos + - len(self._tensors[name])] - pos += len(self._tensors[name]) - - def _call_host_fn(fn, *args, **kw): - context = CatchInvalidHostcallFunctions() - context.Enter() - result = fn(*args, **kw) - context.Exit() - context.ExitResult(result) - return result - - # It is assumed evaluation always happens on single host TPU system. So, - # place all ops on tpu host if possible. - # - # TODO(jhseu): Evaluate whether this is right for summaries. - with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)): - for name in self._names: - dequeue_ops = dequeue_ops_by_name[name] - for i, item in enumerate(dequeue_ops): - if dequeue_ops[i][0].shape.ndims == 0: - raise RuntimeError( - 'All tensors outfed from TPU should preserve batch size ' - 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) - # TODO(xiejw): Make the specification of the outfeed combinaton - # function more explicit and well-documented. We may want to give the - # user the option of concatenating along any axis. - if (self._ctx.config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.BROADCAST): - # If the infeed is in BROADCAST mode (each core recieving the same - # input), then we assume that the cores also produce identical - # copies of the same output, and we simply take the output from - # the first core. This mode is used by Mesh-TensorFlow. - with ops.control_dependencies(dequeue_ops[i]): - dequeue_ops[i] = array_ops.identity(dequeue_ops[i][0]) - else: - # Assume that the input has been batch-split and that axis 0 of the - # output tensors represents the batch size. Concatenate along - # the axis 0 to re-combine the batch. - dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) - - if self._tensor_keys[name] is not None: - # The user-provided eval_metrics[1] is a dict. 
- dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) - try: - ret[name] = _call_host_fn(self._host_fns[name], **dequeue_ops) - except TypeError as e: - logging.warning( - 'Exception while calling %s: %s. It is likely the tensors ' - '(%s[1]) do not match the ' - 'function\'s arguments', name, e, name) - raise - else: - ret[name] = _call_host_fn(self._host_fns[name], *dequeue_ops) - - # force all dequeue operations to be run if not consumed by the host calls - ret['__force_dequeue'] = control_flow_ops.group(*flat_dequeue_ops) - return ret - - -class _OutfeedHostCallHook(session_run_hook.SessionRunHook): - """Hook to run host calls when use_tpu=False.""" - - def __init__(self, tensors): - self._tensors = tensors - - def begin(self): - # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than - # create a separate hook to guarantee execution order, because summaries - # need to be initialized before the outfeed thread starts. - # TODO(jhseu): Make a wrapper hook instead? - self._init_ops = contrib_summary.summary_writer_initializer_op() - # Get all the writer resources from the initializer, so we know what to - # flush. 
- self._finalize_ops = [] - for op in self._init_ops: - self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - - def after_create_session(self, session, coord): - session.run(self._init_ops) - - def before_run(self, run_context): - return basic_session_run_hooks.SessionRunArgs(self._tensors) - - def end(self, session): - session.run(self._finalize_ops) - - -class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): - """Calculate and report global_step/sec and examples/sec during runtime.""" - - def __init__(self, - batch_size, - every_n_steps=100, - every_n_secs=None, - output_dir=None, - summary_writer=None): - self._batch_size = batch_size - super(ExamplesPerSecondHook, self).__init__( - every_n_steps=every_n_steps, - every_n_secs=every_n_secs, - output_dir=output_dir, - summary_writer=summary_writer) - - def _log_and_record(self, elapsed_steps, elapsed_time, global_step): - global_step_per_sec = elapsed_steps / elapsed_time - examples_per_sec = self._batch_size * global_step_per_sec - if self._summary_writer is not None: - global_step_summary = Summary(value=[ - Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec) - ]) - example_summary = Summary(value=[ - Summary.Value(tag='examples/sec', simple_value=examples_per_sec) - ]) - self._summary_writer.add_summary(global_step_summary, global_step) - self._summary_writer.add_summary(example_summary, global_step) - logging.info('global_step/sec: %g', global_step_per_sec) - logging.info('examples/sec: %g', examples_per_sec) - - -class InstallSignalHandlerHook(session_run_hook.SessionRunHook): - """Change SIGINT (CTRL^C) handler to force quit the process. - - The default behavior often results in hanging processes. - The original handler is restored after training/evaluation. 
- """ - - def __init__(self): - self._signal_fn = signal.getsignal(signal.SIGINT) - - def before_run(self, run_context): - signal.signal(signal.SIGINT, signal.SIG_DFL) - - def end(self, session): - signal.signal(signal.SIGINT, self._signal_fn) - - -class TPUEstimator(estimator_lib.Estimator): - """Estimator with TPU support. - - TPUEstimator also supports training on CPU and GPU. You don't need to define - a separate `tf.estimator.Estimator`. - - TPUEstimator handles many of the details of running on TPU devices, such as - replicating inputs and models for each core, and returning to host - periodically to run hooks. - - TPUEstimator transforms a global batch size in params to a per-shard batch - size when calling the `input_fn` and `model_fn`. Users should specify - global batch size in constructor, and then get the batch size for each shard - in `input_fn` and `model_fn` by `params['batch_size']`. - - - For training, `model_fn` gets per-core batch size; `input_fn` may get - per-core or per-host batch size depending on `per_host_input_for_training` - in `TPUConfig` (See docstring for TPUConfig for details). - - - For evaluation and prediction, `model_fn` gets per-core batch size and - `input_fn` get per-host batch size. - - Evaluation - ========== - - `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics` - for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return - `EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case - the following discussion on TPU evaluation does not apply. - - `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where - `tensors` could be a list of any nested structure of `Tensor`s (See - `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns - a dict from metric string name to the result of calling a metric function, - namely a `(metric_tensor, update_op)` tuple. - - One can set `use_tpu` to `False` for testing. 
All training, evaluation, and - predict will be executed on CPU. `input_fn` and `model_fn` will receive - `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`. - - Current limitations: - -------------------- - - 1. TPU evaluation only works on a single host (one TPU worker) except - BROADCAST mode. - - 2. `input_fn` for evaluation should **NOT** raise an end-of-input exception - (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all - batches should have the same size. - - Example (MNIST): - ---------------- - - ``` - # The metric Fn which runs on CPU. - def metric_fn(labels, logits): - predictions = tf.argmax(logits, 1) - return { - 'accuracy': tf.metrics.precision( - labels=labels, predictions=predictions), - } - - # Your model Fn which runs on TPU (eval_metrics is list in this example) - def model_fn(features, labels, mode, config, params): - ... - logits = ... - - if mode = tf.estimator.ModeKeys.EVAL: - return tpu_estimator.TPUEstimatorSpec( - mode=mode, - loss=loss, - eval_metrics=(metric_fn, [labels, logits])) - - # or specify the eval_metrics tensors as dict. - def model_fn(features, labels, mode, config, params): - ... - final_layer_output = ... - - if mode = tf.estimator.ModeKeys.EVAL: - return tpu_estimator.TPUEstimatorSpec( - mode=mode, - loss=loss, - eval_metrics=(metric_fn, { - 'labels': labels, - 'logits': final_layer_output, - })) - ``` - - Prediction - ========== - - Prediction on TPU is an experimental feature to support large batch inference. - It is not designed for latency-critical system. In addition, due to some - usability issues, for prediction with small dataset, CPU `.predict`, i.e., - creating a new `TPUEstimator` instance with `use_tpu=False`, might be more - convenient. - - Note: In contrast to TPU training/evaluation, the `input_fn` for prediction - *should* raise an end-of-input exception (`OutOfRangeError` or - `StopIteration`), which serves as the stopping signal to `TPUEstimator`. 
To be - precise, the ops created by `input_fn` produce one batch of the data. - The `predict()` API processes one batch at a time. When reaching the end of - the data source, an end-of-input exception should be raised by one of these - operations. The user usually does not need to do this manually. As long as the - dataset is not repeated forever, the `tf.data` API will raise an end-of-input - exception automatically after the last batch has been produced. - - Note: Estimator.predict returns a Python generator. Please consume all the - data from the generator so that TPUEstimator can shutdown the TPU system - properly for user. - - Current limitations: - -------------------- - 1. TPU prediction only works on a single host (one TPU worker). - - 2. `input_fn` must return a `Dataset` instance rather than `features`. In - fact, .train() and .evaluate() also support Dataset as return value. - - Example (MNIST): - ---------------- - ``` - height = 32 - width = 32 - total_examples = 100 - - def predict_input_fn(params): - batch_size = params['batch_size'] - - images = tf.random_uniform( - [total_examples, height, width, 3], minval=-1, maxval=1) - - dataset = tf.data.Dataset.from_tensor_slices(images) - dataset = dataset.map(lambda images: {'image': images}) - - dataset = dataset.batch(batch_size) - return dataset - - def model_fn(features, labels, params, mode): - # Generate predictions, called 'output', from features['image'] - - if mode == tf.estimator.ModeKeys.PREDICT: - return tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, - predictions={ - 'predictions': output, - 'is_padding': features['is_padding'] - }) - - tpu_est = TPUEstimator( - model_fn=model_fn, - ..., - predict_batch_size=16) - - # Fully consume the generator so that TPUEstimator can shutdown the TPU - # system. - for item in tpu_est.predict(input_fn=input_fn): - # Filter out item if the `is_padding` is 1. 
- # Process the 'predictions' - ``` - - Exporting - ========= - - `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`, - and another with `tag_constants.SERVING` and `tag_constants.TPU`. - At serving time, these tags are used to select metagraph to load. - - Before running the graph on TPU, TPU system needs to be initialized. If - TensorFlow Serving model-server is used, this is done automatically. If - not, please call `session.run(tpu.initialize_system())`. - - `tpu.outside_compilation` can be used to wrap TPU incompatible ops in - `model_fn`. - - Example: - ---------------- - - ``` - def model_fn(features, labels, mode, config, params): - ... - logits = ... - export_outputs = { - 'logits': export_output_lib.PredictOutput( - {'logits': logits}) - } - - def host_call(logits): - class_ids = math_ops.argmax(logits) - classes = string_ops.as_string(class_ids) - export_outputs['classes'] = - export_output_lib.ClassificationOutput(classes=classes) - - tpu.outside_compilation(host_call, logits) - - ... - ``` - - """ - - def __init__(self, - model_fn=None, - model_dir=None, - config=None, - params=None, - use_tpu=True, - train_batch_size=None, - eval_batch_size=None, - predict_batch_size=None, - batch_axis=None, - eval_on_tpu=True, - export_to_tpu=True, - export_to_cpu=True, - warm_start_from=None, - experimental_exported_model_uses_all_cores=False): - """Constructs an `TPUEstimator` instance. - - Args: - model_fn: Model function as required by `Estimator` which returns - EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks', - and `prediction_hooks` must not capure any TPU Tensor inside the - model_fn. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator to - continue training a previously saved model. If `None`, the model_dir in - `config` will be used if set. If both are set, they must be same. 
If - both are `None`, a temporary directory will be used. - config: An `tpu_config.RunConfig` configuration object. Cannot be `None`. - params: An optional `dict` of hyper parameters that will be passed into - `input_fn` and `model_fn`. Keys are names of parameters, values are - basic python types. There are reserved keys for `TPUEstimator`, - including 'batch_size'. - use_tpu: A bool indicating whether TPU support is enabled. Currently, - - TPU training and evaluation respect this bit, but eval_on_tpu can - override execution of eval. See below. - Predict still happens on CPU. - train_batch_size: An int representing the global training batch size. - TPUEstimator transforms this global batch size to a per-shard batch - size, as params['batch_size'], when calling `input_fn` and `model_fn`. - Cannot be `None` if `use_tpu` is `True`. Must be divisible by total - number of replicas. - eval_batch_size: An int representing evaluation batch size. Must be - divisible by total number of replicas. - predict_batch_size: An int representing the prediction batch size. Must be - divisible by total number of replicas. - batch_axis: A python tuple of int values describing how each tensor - produced by the Estimator `input_fn` should be split across the TPU - compute shards. For example, if your input_fn produced (images, labels) - where the images tensor is in `HWCN` format, your shard dimensions would - be [3, 0], where 3 corresponds to the `N` dimension of your images - Tensor, and 0 corresponds to the dimension along which to split the - labels to match up with the corresponding images. If None is supplied, - and per_host_input_for_training is True, batches will be sharded based - on the major dimension. If tpu_config.per_host_input_for_training is - False or `PER_HOST_V2`, batch_axis is ignored. - eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the - model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. 
- export_to_tpu: If True, `export_savedmodel()` exports a metagraph for - serving on TPU. Note that unsupported export modes such as EVAL will be - ignored. For those modes, only a CPU model will be exported. - Currently, export_to_tpu only supports PREDICT. - export_to_cpu: If True, `export_savedmodel()` exports a metagraph for - serving on CPU. - warm_start_from: Optional string filepath to a checkpoint or SavedModel to - warm-start from, or a `tf.estimator.WarmStartSettings` object to fully - configure warm-starting. If the string filepath is provided instead of - a `WarmStartSettings`, then all variables are warm-started, and it is - assumed that vocabularies and Tensor names are unchanged. - experimental_exported_model_uses_all_cores: Whether to round-robin among - all cores visible to the host which is serving the saved model, or to - use a single core. This is a temporary flag to enable using all TPU - cores for inference with TPUPartitionedCall(). Once outside compilation - is supported in TPUPartitionedCall(), this flag will be enabled by - default. - - Raises: - ValueError: `params` has reserved keys already. - """ - if config is None or not isinstance(config, tpu_config.RunConfig): - raise ValueError( - '`config` must be provided with type `tpu_config.RunConfig`') - - if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS): - raise ValueError('{} are reserved keys but existed in params {}.'.format( - _RESERVED_PARAMS_KEYS, params)) - - if use_tpu: - # Perform some very basic validations. More validations will be found in - # _InternalTPUContext. - if train_batch_size is None: - raise ValueError('`train_batch_size` cannot be `None`') - util_lib.check_positive_integer(train_batch_size, 'train_batch_size') - - if (config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_SHARD_V1 and - config.tpu_config.num_cores_per_replica): - raise ValueError( - 'Model parallelism only supports per host input for training. 
' - 'Please adjust TPURunconfig.per_host_input_for_training.') - - if eval_batch_size is not None: - util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size') - - if predict_batch_size is not None: - util_lib.check_positive_integer(predict_batch_size, - 'predict_batch_size') - - # Verifies the model_fn signature according to Estimator framework. - estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access - # We cannot store config and params in this constructor as parent - # constructor might change them, such as assigning a temp dir for - # config.model_dir. - model_function = self._augment_model_fn(model_fn, batch_axis) - - # Overwrite log_step_count_steps to disable TensorLoggingHook and - # StepCounterHook from being created in Estimator. TPUEstimator already - # added equivalent hooks in _augment_model_fn above. - self._log_every_n_steps = config.log_step_count_steps - config = config.replace(log_step_count_steps=None) - - # Passing non-None params as wrapped model_fn has it. - params = params or {} - super(TPUEstimator, self).__init__( - model_fn=model_function, - model_dir=model_dir, - config=config, - params=params, - warm_start_from=warm_start_from) - self._iterations_per_training_loop = ( - self._config.tpu_config.iterations_per_loop) - - # All properties passed to _InternalTPUContext are immutable. 
- # pylint: disable=protected-access - self._ctx = tpu_context._get_tpu_context( - self._config, train_batch_size, eval_batch_size, predict_batch_size, - use_tpu, eval_on_tpu) - - self._export_to_cpu = export_to_cpu - self._export_to_tpu = export_to_tpu - self._experimental_exported_model_uses_all_cores = ( - experimental_exported_model_uses_all_cores) - - self._is_input_fn_invoked = None - self._rendezvous = {} - - def _add_meta_graph_for_mode(self, - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables=True, - mode=model_fn_lib.ModeKeys.PREDICT, - export_tags=None, - check_variables=True): - if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: - logging.warning('TPUEstimator only handles mode PREDICT for exporting ' - 'when `export_to_tpu` is `True`; Mode {} will be ignored ' - 'for TPU.'.format(mode)) - - if not self._export_to_cpu and not self._export_to_tpu: - raise ValueError('One of export_to_cpu and export_to_tpu must be true.') - - if self._export_to_cpu: - (super(TPUEstimator, self)._add_meta_graph_for_mode( - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables, - mode=mode, - export_tags=export_tags, - check_variables=check_variables)) - - if self._export_to_tpu and mode == model_fn_lib.ModeKeys.PREDICT: - input_receiver_fn_map = { - _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode] - } - export_tags = [tag_constants.SERVING, tag_constants.TPU] - mode = _REWRITE_FOR_INFERENCE_MODE - - # See b/110052256 for why `check_variables` is `False`. 
- if not self._export_to_cpu: - check_variables = save_variables = True - else: - check_variables = save_variables = False - (super(TPUEstimator, self)._add_meta_graph_for_mode( - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables=save_variables, - mode=mode, - export_tags=export_tags, - check_variables=check_variables)) - - def _call_model_fn(self, features, labels, mode, config): - if mode == _REWRITE_FOR_INFERENCE_MODE: - return self._call_model_fn_for_inference(features, labels, mode, config) - else: - return super(TPUEstimator, self)._call_model_fn(features, labels, mode, - config) - - def _call_model_fn_for_inference(self, features, labels, mode, config): - """Wraps `_call_model_fn` for `export_savedmodel`.""" - if mode != _REWRITE_FOR_INFERENCE_MODE: - raise ValueError('mode must be {}; ' - 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) - - computation, capture = self._build_computation_for_inference( - features, labels, mode, config) - tensors = call_computation( - computation, - experimental_exported_model_uses_all_cores=self - ._experimental_exported_model_uses_all_cores) - estimator_spec, export_outputs_dict, predictions_dict, none_indices = ( - capture.get()) - predictions_list = tensors[:len(predictions_dict)] - export_outputs_list_without_none = tensors[len(predictions_dict):] - - # Reinsert `None`s which we've taken out in - # `_build_computation_for_inference()`. - export_outputs_list = [] - while none_indices or export_outputs_list_without_none: - if none_indices and none_indices[0] == len(export_outputs_list): - export_outputs_list.append(None) - none_indices.pop(0) - else: - export_outputs_list.append(export_outputs_list_without_none.pop(0)) - - # Reconstruct `export_outputs` with updated tensors. 
- new_export_outputs_dict = nest.pack_sequence_as(export_outputs_dict, - export_outputs_list) - export_outputs = estimator_spec.export_outputs - new_export_outputs = collections.OrderedDict( - (k, _clone_export_output_with_tensors(export_outputs[k], v)) - for k, v in six.iteritems(new_export_outputs_dict)) - # Reconstruct `predictions` with updated tensors. - new_predictions = nest.pack_sequence_as(predictions_dict, predictions_list) - if (len(new_predictions) == 1 and - _KEY_WHEN_PREDICTIONS_IS_A_TENSOR in new_predictions): - new_predictions = new_predictions[_KEY_WHEN_PREDICTIONS_IS_A_TENSOR] - - return estimator_spec._replace( - export_outputs=new_export_outputs, predictions=new_predictions) - - def _build_computation_for_inference(self, features, labels, mode, config): - capture = _CapturedObject() - - def computation(): - """Computation to be passed to `TPUPartitionedCall()`.""" - tpu_computation, tpu_capture = self._build_tpu_computation_for_inference( - features, labels, mode, config) - - tensors_on_cpu = tpu.rewrite_for_inference(tpu_computation) - (estimator_spec, export_outputs_dict, export_outputs_list, - predictions_dict) = ( - tpu_capture.get()) - predictions_list = tensors_on_cpu[:len(predictions_dict)] - export_outputs_tpu_on_cpu_list = tensors_on_cpu[len(predictions_dict):] - - # Reconstruct tensors used in export_outputs, with TPU tensors replaced - # with their CPU counterpart returned from `rewrite_for_inference()`. - # `function.Defun()` does not like `None`s in return values, so we leave - # `None`s out but record their positions for later reconstruction. 
- export_outputs_list_without_none = [] - none_indices = [] - for i, t in enumerate(export_outputs_list): - if t is None: - none_indices.append(i) - else: - export_outputs_list_without_none.append( - export_outputs_tpu_on_cpu_list.pop(0)) - - capture.capture((estimator_spec, export_outputs_dict, predictions_dict, - none_indices)) - return predictions_list + export_outputs_list_without_none - - return computation, capture - - def _build_tpu_computation_for_inference(self, features, labels, mode, - config): - capture = _CapturedObject() - - def computation(): - """Compute tpu tensors used in export_outputs. - - Passed to rewrite_for_inference so that model_fn will be called under - the rewriting contexts. Only tpu tensors are returned, but export_outputs - and scaffold are captured. - - Returns: - A list of Tensors used in export_outputs and not marked for - outside_compilation. - """ - # We should only call model fn once and it should be inside `computation` - # so that building the graph will happen under `rewrite_for_inference`. - mode = model_fn_lib.ModeKeys.PREDICT - estimator_spec = self._call_model_fn(features, labels, mode, config) - - # We pick the TPU tensors out from `export_output` and later return them - # from `computation` for rewriting. - export_outputs_dict = collections.OrderedDict( - (k, _export_output_to_tensors(v)) - for k, v in six.iteritems(estimator_spec.export_outputs)) - export_outputs_list = nest.flatten(export_outputs_dict) - export_outputs_tpu_list = [ - t for t in export_outputs_list if t is not None - ] - - if isinstance(estimator_spec.predictions, dict): - predictions_dict = collections.OrderedDict( - (k, v) for k, v in six.iteritems(estimator_spec.predictions)) - else: - predictions_dict = { - _KEY_WHEN_PREDICTIONS_IS_A_TENSOR: estimator_spec.predictions - } - predictions_list = nest.flatten(predictions_dict) - - # We cannot return everything we want through the return values, so - # capture the rest here for later use. 
- capture.capture((estimator_spec, export_outputs_dict, export_outputs_list, - predictions_dict)) - return predictions_list + export_outputs_tpu_list - - return computation, capture - - def _create_global_step(self, graph): - """Creates a global step suitable for TPUs. - - Args: - graph: The graph in which to create the global step. - - Returns: - A global step `Tensor`. - - Raises: - ValueError: if the global step tensor is already defined. - """ - return _create_global_step(graph) - - def _convert_train_steps_to_hooks(self, steps, max_steps): - with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: - if ctx.is_running_on_cpu(): - return super(TPUEstimator, self)._convert_train_steps_to_hooks( - steps, max_steps) - - # On TPU. - if steps is None and max_steps is None: - raise ValueError( - 'For TPU training, one of `steps` or `max_steps` must be set. ' - 'Cannot be both `None`.') - - # Estimator.train has explicit positiveness check. - if steps is not None: - util_lib.check_positive_integer(steps, 'Train steps') - if max_steps is not None: - util_lib.check_positive_integer(max_steps, 'Train max_steps') - - return [ - _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps) - ] - - def _convert_eval_steps_to_hooks(self, steps): - with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: - if ctx.is_running_on_cpu(): - return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) - - if steps is None: - raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') - - util_lib.check_positive_integer(steps, 'Eval steps') - - return [ - evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access - num_evals=steps), - _SetEvalIterationsHook(steps) - ] - - def _call_input_fn(self, input_fn, mode): - """Calls the input function. - - Args: - input_fn: The input function. - mode: ModeKeys - - Returns: - In TPU mode, returns an input_fn to be called later in model_fn. 
- Otherwise, calls the input_fn and returns either fatures or - (features, labels). - - Raises: - ValueError: if input_fn takes invalid arguments or does not have `params`. - """ - input_fn_args = function_utils.fn_args(input_fn) - config = self.config # a deep copy. - kwargs = {} - if 'params' in input_fn_args: - kwargs['params'] = self.params # a deep copy. - else: - raise ValueError('input_fn ({}) does not include params argument, ' - 'required by TPUEstimator to pass batch size as ' - 'params["batch_size"]'.format(input_fn)) - if 'config' in input_fn_args: - kwargs['config'] = config - - if 'mode' in input_fn_args: - kwargs['mode'] = mode - - # Records the fact input_fn has been invoked. - self._is_input_fn_invoked = True - - with self._ctx.with_mode(mode) as ctx: - # Setting the batch size in params first. This helps user to have same - # input_fn for use_tpu=True/False. - batch_size_for_input_fn = ctx.batch_size_for_input_fn - if batch_size_for_input_fn is not None: - _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY, - batch_size_for_input_fn) - - # For export_savedmodel, input_fn is never passed to Estimator. So, - # `is_export_mode` must be False. - if ctx.is_running_on_cpu(is_export_mode=False): - with ops.device('/device:CPU:0'): - return input_fn(**kwargs) - - # For TPU computation, input_fn should be invoked in a tf.while_loop for - # performance. While constructing the tf.while_loop, the structure of - # inputs returned by the `input_fn` needs to be recorded. The structure - # includes whether features or labels is dict or single Tensor, dict keys, - # tensor shapes, and dtypes. The recorded structure is used to create the - # infeed dequeue ops, which must be wrapped and passed as a Fn, called - # inside the TPU computation, as the TPU computation is wrapped inside a - # tf.while_loop also. So, we either pass input_fn to model_fn or pass - # dequeue_fn to model_fn. Here, `input_fn` is passed directly as - # `features` in `model_fn` signature. 
- def _input_fn(ctx): - _add_item_to_params(kwargs['params'], _CTX_KEY, ctx) - return input_fn(**kwargs) - - return _input_fn - - def _validate_features_in_predict_input(self, result): - """Skip the validation. - - For TPUEstimator, we do not need to check the result type. `_InputPipeline` - has stronger check. Parent class's check generates confusing warning msg. - - Args: - result: `features` returned by input_fn. - """ - pass - - def train(self, - input_fn, - hooks=None, - steps=None, - max_steps=None, - saving_listeners=None): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous - try: - return super(TPUEstimator, self).train( - input_fn=input_fn, - hooks=hooks, - steps=steps, - max_steps=max_steps, - saving_listeners=saving_listeners) - except Exception: # pylint: disable=broad-except - rendezvous.record_error('training_loop', sys.exc_info()) - finally: - rendezvous.record_done('training_loop') - rendezvous.raise_errors() - - def evaluate(self, - input_fn, - steps=None, - hooks=None, - checkpoint_path=None, - name=None): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous - try: - return super(TPUEstimator, self).evaluate( - input_fn, - steps=steps, - hooks=hooks, - checkpoint_path=checkpoint_path, - name=name) - except Exception: # pylint: disable=broad-except - rendezvous.record_error('evaluation_loop', sys.exc_info()) - finally: - rendezvous.record_done('evaluation_loop') - rendezvous.raise_errors() - - def predict(self, - input_fn, - predict_keys=None, - hooks=None, - checkpoint_path=None, - yield_single_examples=True): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous - try: - for result in super(TPUEstimator, self).predict( - input_fn=input_fn, - predict_keys=predict_keys, - hooks=hooks, - checkpoint_path=checkpoint_path, - 
yield_single_examples=yield_single_examples): - yield result - except Exception: # pylint: disable=broad-except - rendezvous.record_error('prediction_loop', sys.exc_info()) - finally: - rendezvous.record_done('prediction_loop') - rendezvous.raise_errors() - - rendezvous.record_done('prediction_loop') - rendezvous.raise_errors() - - def _augment_model_fn(self, model_fn, batch_axis): - """Returns a new model_fn, which wraps the TPU support.""" - - def _model_fn(features, labels, mode, config, params): - """A Estimator `model_fn` for TPUEstimator.""" - with self._ctx.with_mode(mode) as ctx: - model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) - - # `input_fn` is called in `train()`, `evaluate()`, and `predict()`, - # but not in `export_savedmodel()`. - if self._is_input_fn_invoked: - is_export_mode = False - else: - is_export_mode = True - - # Clear the bit. - self._is_input_fn_invoked = None - - # examples_hook is added to training_hooks for both CPU and TPU - # execution. - if self._log_every_n_steps is not None: - examples_hook = ExamplesPerSecondHook( - ctx.global_batch_size, - # pylint:disable=g-long-ternary - output_dir=(self.model_dir - if not config or config.save_summary_steps - else None), - # pylint:enable=g-long-ternary - every_n_steps=self._log_every_n_steps) - - if ctx.is_running_on_cpu(is_export_mode=is_export_mode): - logging.info('Running %s on CPU', mode) - estimator_spec = model_fn_wrapper.call_without_tpu( - features, labels, is_export_mode=is_export_mode) - if self._log_every_n_steps is not None: - estimator_spec = estimator_spec._replace( - training_hooks=estimator_spec.training_hooks + (examples_hook,)) - return estimator_spec - - assert labels is None, '`labels` passed to `model_fn` must be `None`.' - # TPUEstimator._call_input_fn passes `input_fn` as features to here. - assert callable(features), '`input_fn` is not callable.' 
- input_fn = features - - input_holders = _InputPipeline(input_fn, batch_axis, ctx) - enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( - input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) - - graph = ops.get_default_graph() - for enqueue_op in enqueue_ops: - if isinstance(enqueue_op, list): - graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op) - else: - graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) - - if mode == model_fn_lib.ModeKeys.TRAIN: - compile_op, loss, host_call, scaffold, training_hooks = ( - _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) - host_ops = host_call.create_tpu_hostcall() - if host_ops is None: - host_ops = [] - - shutdown_hooks = [] - shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE', - 'shutdown_worker') - if shutdown_mode: - if shutdown_mode == 'shutdown_worker': - finalizer_hooks = [ - session_support.ShutdownLameWorkers(timeout_ms=60 * 1000), - ] - elif shutdown_mode == 'shutdown_computation': - finalizer_hooks = [ - session_support.RestartComputation(timeout_ms=60 * 1000), - ] - else: - raise ValueError( - 'Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % shutdown_mode) - - shutdown_hooks.append( - session_support.GracefulShutdownHook( - checkpoint_prefix=self.model_dir + '/model.ckpt', - on_shutdown_hooks=finalizer_hooks)) - - with ops.control_dependencies([loss]): - global_step = array_ops.identity(training.get_global_step()) - hooks = input_hooks + shutdown_hooks - hooks.extend([ - TPUInfeedOutfeedSessionHook( - ctx, - enqueue_ops, - host_ops, - tpu_compile_op=compile_op, - run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode], - master=self._config.master, - session_config=self._session_config, - ), - InstallSignalHandlerHook() - ]) - if self._log_every_n_steps is not None: - logging_hook_frequency = ( # Divide and round up - (self._log_every_n_steps + - self._config.tpu_config.iterations_per_loop - 1) // - 
self._config.tpu_config.iterations_per_loop) - hooks.append( - training.LoggingTensorHook({ - 'loss': array_ops.identity(loss), - 'step': global_step, - }, - every_n_iter=logging_hook_frequency)) - examples_hook._set_steps_per_run( # pylint: disable=protected-access - self._config.tpu_config.iterations_per_loop) - hooks.append(examples_hook) - - if training_hooks: - hooks.extend(training_hooks) - - chief_hooks = [] - if (self._config.save_checkpoints_secs or - self._config.save_checkpoints_steps): - checkpoint_hook = training.CheckpointSaverHook( - self.model_dir, - save_secs=self._config.save_checkpoints_secs, - save_steps=self._config.save_checkpoints_steps, - scaffold=scaffold) - checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access - self._config.tpu_config.iterations_per_loop) - chief_hooks.append(checkpoint_hook) - - summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) - with ops.control_dependencies([loss]): - update_ops = _sync_variables_ops(ctx) - - # Validate the TPU training graph to catch basic errors - _validate_tpu_training_graph() - - train_op = control_flow_ops.group(*update_ops) - graph.add_to_collection(_TPU_TRAIN_OP, train_op) - - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - training_chief_hooks=chief_hooks, - training_hooks=hooks, - train_op=train_op, - scaffold=scaffold) - - if mode == model_fn_lib.ModeKeys.EVAL: - compile_op, total_loss, host_calls, scaffold, eval_hooks = ( - _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) - iterations_per_loop_var = _create_or_get_iterations_per_loop() - mean_loss = math_ops.div( - total_loss, - math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) - - with ops.control_dependencies([mean_loss]): - # After TPU evaluation computation is done (the mean_loss tensor), - # reads all variables back from TPU and updates the eval step - # counter properly - internal_ops_to_run = _sync_variables_ops(ctx) - internal_ops_to_run.append( - 
_increase_eval_step_op(iterations_per_loop_var)) - - host_call_ret = host_calls.create_tpu_hostcall() - eval_metric_ops = {} - eval_update_ops = [] - - eval_metrics = host_call_ret.get('eval_metrics', {}) - if eval_metrics: - # Creates a dummy metric update_op for all metrics. Estimator - # expects all metrics in `eval_metric_ops` have update_op and calls - # them one by one. The real metric update_ops are invoked in a - # separated thread. So, here give Estimator the dummy op for all - # metrics. - with ops.control_dependencies(internal_ops_to_run): - dummy_update_op = control_flow_ops.no_op() - - for k, v in eval_metrics.items(): - eval_metric_ops[k] = (v[0], dummy_update_op) - eval_update_ops.append(v[1]) - else: - # If no eval metrics are passed, create an identity node for the - # loss and add `internal_ops_to_run` to its dependencies. So - # `internal_ops_to_run` can be executed. - with ops.control_dependencies(internal_ops_to_run): - mean_loss = array_ops.identity(mean_loss) - - if 'host_call' not in host_call_ret: - host_ops = [] - else: - host_ops = host_call_ret['host_call'] - hooks = [ - TPUInfeedOutfeedSessionHook( - ctx, - enqueue_ops, - eval_update_ops + host_ops, - tpu_compile_op=compile_op, - run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode], - master=self._config.evaluation_master, - session_config=self._session_config, - )] + input_hooks - - if eval_hooks: - hooks.extend(eval_hooks) - - return model_fn_lib.EstimatorSpec( - mode, - loss=mean_loss, - evaluation_hooks=hooks, - eval_metric_ops=eval_metric_ops, - scaffold=scaffold) - - # Predict - assert mode == model_fn_lib.ModeKeys.PREDICT - - (compile_op, dummy_predict_op, host_calls, - scaffold, prediction_hooks) = _predict_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) - with ops.control_dependencies([dummy_predict_op]): - internal_ops_to_run = _sync_variables_ops(ctx) - with ops.control_dependencies(internal_ops_to_run): - 
dummy_predict_op = control_flow_ops.no_op() - - # In train and evaluation, the main TPU program is passed to monitored - # training session to run. Infeed enqueue and outfeed dequeue are - # executed in side threads. This is not the configuration for - # prediction mode. - # - # For prediction, the Estimator executes the EstimatorSpec.predictions - # directly and yield the element (via generator) to call site. So, the - # outfeed based prediction must be passed to MonitoredSession directly. - # Other parts of the TPU execution are organized as follows. - # - # 1. All outfeed based Tensors must be grouped with predictions Tensors - # to form a single invocation. This avoid the issue we might trigger - # multiple outfeeds incorrectly. To achieve this, `host_call` is - # placed in control_dependencies of `stopping_signals`, and - # `stopping_signals` is passed into _StoppingPredictHook, which sets - # the `stopping_signals` as SessionRunArgs. MonitoredSession merges - # all SessionRunArgs with the fetch in session.run together. - # - # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue) - # are grouped together. They will be launched once and only once in - # side threads and they quit naturally according to the SAME stopping - # condition. - enqueue_ops.append(dummy_predict_op) - - host_call_ret = host_calls.create_tpu_hostcall() - if 'host_call' not in host_call_ret: - host_ops = [] - else: - host_ops = host_call_ret['host_call'] - - predictions = host_call_ret['predictions'] - _verify_cross_hosts_transfer_size( - predictions, - message=( - 'The estimated size for TPUEstimatorSpec.predictions is too ' - 'large.')) - signals = host_call_ret['signals'] - - with ops.control_dependencies(host_ops): - host_ops = [] # Empty, we do do not need it anymore. 
- scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal( - signals) - predictions = _PaddingSignals.slice_tensor_or_dict( - predictions, signals) - - hooks = [ - _StoppingPredictHook(scalar_stopping_signal), - TPUInfeedOutfeedSessionHookForPrediction( - ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode], - tpu_compile_op=compile_op, - master=self._config.master, - session_config=self._session_config), - ] + input_hooks - - if prediction_hooks: - hooks.extend(prediction_hooks) - - return model_fn_lib.EstimatorSpec( - mode, - prediction_hooks=hooks, - predictions=predictions, - scaffold=scaffold) - - return _model_fn - - -def _export_output_to_tensors(export_output): - """Get a list of `Tensors` used in `export_output`. - - Args: - export_output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - - Returns: - a list of tensors used in export_output. - - Raises: - ValueError: if `export_output` is not one of `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - """ - if isinstance(export_output, export_output_lib.ClassificationOutput): - return [export_output.scores, export_output.classes] - elif isinstance(export_output, export_output_lib.RegressionOutput): - return [export_output.value] - elif isinstance(export_output, export_output_lib.PredictOutput): - return list(export_output.outputs.values()) - else: - raise ValueError( - '`export_output` must be have type `ClassificationOutput`, ' - '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) - - -def _clone_export_output_with_tensors(export_output, tensors): - """Clones `export_output` but with new `tensors`. - - Args: - export_output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - tensors: a list of `Tensors` used to construct a new `export_output`. - - Returns: - A dict similar to `export_output` but with `tensors`. 
- - Raises: - ValueError: if `export_output` is not one of `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - """ - if isinstance(export_output, export_output_lib.ClassificationOutput): - if len(tensors) != 2: - raise ValueError('tensors must be of length 2; ' - 'got {}.'.format(len(tensors))) - return export_output_lib.ClassificationOutput(*tensors) - elif isinstance(export_output, export_output_lib.RegressionOutput): - if len(tensors) != 1: - raise ValueError('tensors must be of length 1; ' - 'got {}'.format(len(tensors))) - return export_output_lib.RegressionOutput(*tensors) - elif isinstance(export_output, export_output_lib.PredictOutput): - return export_output_lib.PredictOutput( - dict(zip(export_output.outputs.keys(), tensors))) - else: - raise ValueError( - '`export_output` must be have type `ClassificationOutput`, ' - '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) - - -def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - iterations_per_loop_var = _create_or_get_iterations_per_loop() - - (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks - ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn) - - def multi_tpu_eval_steps_on_single_shard(): - return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step, - [_ZERO_LOSS]) - - (compile_op, loss,) = tpu.split_compile_and_shard( - multi_tpu_eval_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - loss = loss[0] - scaffold = _get_scaffold(captured_scaffold_fn) - return compile_op, loss, host_calls, scaffold, captured_eval_hooks.get() - - -def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - iterations_per_loop_var = _create_or_get_iterations_per_loop() - - 
(single_tpu_train_step, host_call, captured_scaffold_fn, - captured_training_hooks) = ( - model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) - - def multi_tpu_train_steps_on_single_shard(): - return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step, - [_INITIAL_LOSS]) - - (compile_op, loss,) = tpu.split_compile_and_shard( - multi_tpu_train_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - loss = loss[0] - scaffold = _get_scaffold(captured_scaffold_fn) - return compile_op, loss, host_call, scaffold, captured_training_hooks.get() - - -def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - (single_tpu_predict_step, host_calls, captured_scaffold_fn, - captured_predict_hooks - ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) - - def multi_tpu_predict_steps_on_single_shard(): - - def cond(scalar_stopping_signal): - return math_ops.logical_not( - _StopSignals.should_stop(scalar_stopping_signal)) - - inputs = [_StopSignals.NON_STOPPING_SIGNAL] - outputs = training_loop.while_loop( - cond, single_tpu_predict_step, inputs=inputs, name=b'loop') - return outputs - - (compile_op, dummy_predict_op,) = tpu.split_compile_and_shard( - multi_tpu_predict_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - dummy_predict_op = dummy_predict_op[0] - scaffold = _get_scaffold(captured_scaffold_fn) - return (compile_op, dummy_predict_op, host_calls, scaffold, - captured_predict_hooks.get()) - - -def _wrap_computation_in_while_loop(device, op_fn): - """Wraps the ops generated by `op_fn` in tf.while_loop.""" - - def computation(i): - with ops.control_dependencies(op_fn()): - return i + 1 - - iterations_per_loop_var = _create_or_get_iterations_per_loop() - # By setting 
parallel_iterations=1, the parallel execution in while_loop is - # basically turned off. - with ops.device(device): - iterations = array_ops.identity(iterations_per_loop_var) - return control_flow_ops.while_loop( - lambda i: i < iterations, - computation, [constant_op.constant(0)], - parallel_iterations=1) - - -def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn): - """Wraps the ops generated by `op_fn` in tf.while_loop.""" - - def cond(scalar_stopping_signal): - return math_ops.logical_not( - _StopSignals.should_stop(scalar_stopping_signal)) - - def computation(unused_scalar_stopping_signal): - return_value = op_fn() - execute_ops = return_value['ops'] - signals = return_value['signals'] - with ops.control_dependencies(execute_ops): - return _StopSignals.as_scalar_stopping_signal(signals) - - # By setting parallel_iterations=1, the parallel execution in while_loop is - # basically turned off. - with ops.device(device): - return control_flow_ops.while_loop( - cond, - computation, [_StopSignals.NON_STOPPING_SIGNAL], - parallel_iterations=1) - - -def _validate_tpu_training_graph(): - """Validate graph before running distributed training. - - Raises: - ValueError: If the graph seems invalid for running on device - """ - operations = ops.get_default_graph().get_operations() - - # Check if there is atleast one CrossReplicaSum operation in the graph - # This should be introduced by using the CrossShardOptimizer wrapper - cross_replica_sum_ops = [ - o for o in operations if o.type == _CROSS_REPLICA_SUM_OP - ] - if not cross_replica_sum_ops: - raise ValueError( - 'CrossShardOptimizer must be used for model training on TPUs.') - - -class _CapturedObject(object): - """A placeholder to capture an object. - - This is useful when we need to capture a Python object in the Tensorflow - control flow body function and use it outside the control flow. 
- """ - - def __init__(self): - self._object = None - self._captured = False - - def capture(self, o): - if self._captured: - raise RuntimeError( - 'InternalError: Object can capture only once. Please file bug.') - - self._captured = True - self._object = o - - def get(self): - if not self._captured: - raise RuntimeError( - 'InternalError: Object is not captured properly before `get`. ' - 'Please file bug.') - return self._object - - -def _get_scaffold(captured_scaffold_fn): - """Retrieves the Scaffold from `captured_scaffold_fn`.""" - with _CapturingContext(message='Inside scaffold_fn'): - scaffold_fn = captured_scaffold_fn.get() - if scaffold_fn: - scaffold = scaffold_fn() - if scaffold is None: - raise ValueError( - 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') - else: - scaffold = None - - if scaffold: - wrapped_finalize = scaffold.finalize - - def _finalize(): - with _CapturingContext('Inside Scaffold.finalize'): - wrapped_finalize() - - scaffold.finalize = _finalize - return scaffold - - -class _CapturingContext(control_flow_ops.ControlFlowContext): - """Tracks references to Tensors defined in TPU replication.""" - - def __init__(self, message): - control_flow_ops.ControlFlowContext.__init__(self) - self._message = message - - def to_control_flow_context_def(self, context_def, export_scope=None): - # pylint: disable=useless-super-delegation - # NOTE(slebedev): the method is required by `ControlFlowContext`. 
- super(_CapturingContext, self).to_control_flow_context_def( - context_def, export_scope) - - def AddOp(self, op): # pylint: disable=invalid-name - for c in op.inputs: - if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access - raise ValueError('{}: Op {} depends on TPU computation {}, ' - 'which is not allowed.'.format(self._message, op, c)) - - def __enter__(self): - # pylint: disable=protected-access - self._g = ops.get_default_graph() - self._old = self._g._get_control_flow_context() - self._g._set_control_flow_context(self) - # pylint: enable=protected-access - - def __exit__(self, _, __, ___): # pylint: disable=invalid-name - self._g._set_control_flow_context(self._old) # pylint: disable=protected-access - - -class _Inputs(object): - """A data structure representing the input_fn returned values. - - This also supports the returned value from input_fn as `Dataset`. - """ - - def __init__(self, features=None, labels=None, dataset=None, signals=None): - if dataset is not None and (features is not None or labels is not None or - signals is not None): - raise RuntimeError('Internal Error: Either (features and labels) or ' - 'dataset should be provided, not both. 
Please file ' - 'bug') - - self._features = features - self._labels = labels - self._signals = signals - - self._dataset = dataset - self._iterator = None - - @staticmethod - def from_input_fn(return_values): - """Returns an `_Inputs` instance according to `input_fn` return value.""" - if isinstance(return_values, dataset_ops.DatasetV2): - dataset = return_values - return _Inputs(dataset=dataset) - - features, labels = _Inputs._parse_inputs(return_values) - return _Inputs(features, labels) - - @staticmethod - def _parse_inputs(return_values): - if isinstance(return_values, tuple): - features, labels = return_values - else: - features, labels = return_values, None - return features, labels - - @property - def is_dataset(self): - """Returns True if the return value from input_fn is Dataset.""" - return self._dataset is not None - - def dataset_initializer(self): - """Returns the dataset's initializer. - - The initializer must be run before calling `features_and_labels`. - """ - self._iterator = dataset_ops.make_initializable_iterator(self._dataset) - return self._iterator.initializer - - def features_and_labels(self): - """Gets `features` and `labels`.""" - if self.is_dataset: - if self._iterator is None: - raise RuntimeError('Internal error: Must run dataset_initializer ' - 'before calling features_and_labels(). 
Please file ' - 'a bug!') - return _Inputs._parse_inputs(self._iterator.get_next()) - - return (self._features, self._labels) - - def signals(self): - return self._signals - - @property - def dataset(self): - return self._dataset - - -class _InputsWithStoppingSignals(_Inputs): - """Inputs with `_StopSignals` inserted into the dataset.""" - - def __init__(self, - dataset, - batch_size, - add_padding=False, - num_invocations_per_step=1): - - assert dataset is not None - user_provided_dataset = dataset.map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=False, batch_size=batch_size, add_padding=add_padding)) - if num_invocations_per_step == 1: - final_batch_dataset = dataset.take(1).map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size, add_padding=add_padding)) - else: - # We append (2 * num_invocations_per_step - 1) batches for exhausting the - # user_provided_dataset and stop properly. - # For example, if num_invocations_per_step is 2, we append 3 additional - # padding batches: b1, b2, b3. - # If user_provided_dataset contains two batches: a1, a2 - # Step 1: [a1, a2] - # Step 2: [b1, b2] -> STOP - # If user_provided_dataset contains three batches: a1, a2, a3. - # The training loops: - # Step 1: [a1, a2] - # Step 2: [a3, b1] - # Step 3: [b2, b3] -> STOP. - final_batch_dataset = dataset.take(1).map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size, add_padding=add_padding)) - final_batch_dataset = final_batch_dataset.repeat( - 2 * num_invocations_per_step - 1) - - def _set_mask(data_dict): - signals = data_dict['signals'] - signals['padding_mask'] = array_ops.ones_like(signals['padding_mask']) - data_dict['signals'] = signals - return data_dict - - # Mask out the extra batch. 
- final_batch_dataset = final_batch_dataset.map(_set_mask) - - dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2) - - super(_InputsWithStoppingSignals, self).__init__(dataset=dataset) - self._current_inputs = None - - def features_and_labels(self): - if self._current_inputs is not None: - raise RuntimeError( - 'Internal Error: The previous inputs have not been properly ' - 'consumed. First call features_and_labels, then call signals.') - - inputs_with_signals = self._iterator.get_next() - features = inputs_with_signals['features'] - labels = inputs_with_signals.get('labels') - - self._current_inputs = inputs_with_signals - return features, labels - - def signals(self): - """Returns the `Signals` from `_Inputs`.""" - if self._current_inputs is None: - raise RuntimeError( - 'Internal Error: The current inputs have not been properly ' - 'generated. First call features_and_labels, then call signals.') - signals = self._current_inputs['signals'] - self._current_inputs = None - return signals - - @staticmethod - def insert_stopping_signal(stop, batch_size, add_padding=False): - """Inserts stopping_signal into dataset via _map_fn. - - Here we change the data structure in the dataset, such that the return value - is a dictionary now and `features`, `labels`, and `signals` are three - distinguished keys in that dict. This provides a better structure, which - eases the process to decompose the inputs (see `features_and_labels`). - - Args: - stop: bool, state of current stopping signals. - batch_size: int, batch size. - add_padding: bool, whether to pad the tensor to full batch size. - - Returns: - A map_fn passed to dataset.map API. - """ - - def _map_fn(*args): - """The map fn to insert signals.""" - if len(args) == 1: - # Unpack the single Tensor/dict argument as features. This is required - # for the input_fn returns no labels. 
- args = args[0] - features, labels = _Inputs._parse_inputs(args) - new_input_dict = {} - - if add_padding: - padding_mask, features, labels = ( - _PaddingSignals.pad_features_and_labels(features, labels, - batch_size)) - - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels - - else: - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels - padding_mask = None - - new_input_dict['signals'] = _StopSignals( - stop=stop, batch_size=batch_size, - padding_mask=padding_mask).as_dict() - - return new_input_dict - - return _map_fn - - -class _StopSignals(object): - """Signals class holding all logic to handle TPU stopping condition.""" - - NON_STOPPING_SIGNAL = False - STOPPING_SIGNAL = True - - def __init__(self, stop, batch_size, padding_mask=None): - self._stop = stop - self._batch_size = batch_size - self._padding_mask = padding_mask - - def as_dict(self): - """Returns the signals as Python dict.""" - shape = [self._batch_size, 1] - dtype = dtypes.bool - - if self._stop: - stopping = array_ops.ones(shape=shape, dtype=dtype) - else: - stopping = array_ops.zeros(shape=shape, dtype=dtype) - - signals = {'stopping': stopping} - if self._padding_mask is not None: - signals['padding_mask'] = self._padding_mask - return signals - - @staticmethod - def as_scalar_stopping_signal(signals): - return array_ops.identity(signals['stopping'][0][0]) - - @staticmethod - def should_stop(scalar_stopping_signal): - """Detects whether scalar_stopping_signal indicates stopping.""" - if isinstance(scalar_stopping_signal, ops.Tensor): - # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF - # way to express the bool check whether scalar_stopping_signal is True. - return math_ops.logical_and(scalar_stopping_signal, - _StopSignals.STOPPING_SIGNAL) - else: - # For non Tensor case, it is used in SessionRunHook. So, we cannot modify - # the graph anymore. Here, we use pure Python. 
- return bool(scalar_stopping_signal) - - -class _PaddingSignals(object): - """Signals class holding all logic to handle padding.""" - - @staticmethod - def pad_features_and_labels(features, labels, batch_size): - """Pads out the batch dimension of features and labels.""" - real_batch_size = array_ops.shape( - _PaddingSignals._find_any_tensor(features))[0] - - batch_size_tensor = constant_op.constant(batch_size, dtypes.int32) - - check_greater = check_ops.assert_greater_equal( - batch_size_tensor, - real_batch_size, - data=(batch_size_tensor, real_batch_size), - message='The real batch size should not be greater than batch_size.') - - with ops.control_dependencies([check_greater]): - missing_count = batch_size_tensor - real_batch_size - - def pad_single_tensor(tensor): - """Pads out the batch dimension of a tensor to the complete batch_size.""" - rank = len(tensor.shape) - assert rank > 0 - padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1)) - padded_shape = (batch_size,) + tuple(tensor.shape[1:]) - padded_tensor = array_ops.pad(tensor, padding) - padded_tensor.set_shape(padded_shape) - return padded_tensor - - def nest_pad(tensor_or_dict): - return nest.map_structure(pad_single_tensor, tensor_or_dict) - - features = nest_pad(features) - if labels is not None: - labels = nest_pad(labels) - - padding_mask = _PaddingSignals._padding_mask(real_batch_size, missing_count, - batch_size) - - return padding_mask, features, labels - - @staticmethod - def slice_tensor_or_dict(tensor_or_dict, signals): - """Slice the real Tensors according to padding mask in signals.""" - - padding_mask = signals['padding_mask'] - batch_size = array_ops.shape(padding_mask)[0] - - def verify_batch_size(tensor): - check_batch_size = math_ops.equal(batch_size, tensor.shape[0]) - with ops.control_dependencies([check_batch_size]): - return array_ops.identity(tensor) - - def slice_single_tensor(tensor): - rank = len(tensor.shape) - assert rank > 0 - real_batch_size = batch_size 
- math_ops.reduce_sum(padding_mask) - return verify_batch_size(tensor)[0:real_batch_size] - - # As we split the Tensors to all TPU cores and concat them back, it is - # important to ensure the real data is placed before padded ones, i.e., - # order is preserved. By that, the sliced padding mask should have all 0's. - # If this assertion failed, # the slice logic here would not hold. - sliced_padding_mask = slice_single_tensor(padding_mask) - assert_padding_mask = math_ops.equal( - math_ops.reduce_sum(sliced_padding_mask), 0) - - with ops.control_dependencies([assert_padding_mask]): - should_stop = _StopSignals.should_stop( - _StopSignals.as_scalar_stopping_signal(signals)) - - is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0) - - def slice_fn(tensor): - # If the current batch is full batch or part of stopping signals, we do - # not need to slice to save performance. - return control_flow_ops.cond( - math_ops.logical_or(should_stop, is_full_batch), - (lambda: verify_batch_size(tensor)), - (lambda: slice_single_tensor(tensor))) - - return nest.map_structure(slice_fn, tensor_or_dict) - - @staticmethod - def _find_any_tensor(batch_features): - tensors = [ - x for x in nest.flatten(batch_features) if isinstance(x, ops.Tensor) - ] - if not tensors: - raise ValueError('Cannot find any Tensor in features dict.') - return tensors[0] - - @staticmethod - def _padding_mask(real_batch_size, missing_count, batch_size): - padding_mask = array_ops.concat([ - array_ops.zeros((real_batch_size,), dtype=dtypes.int32), - array_ops.ones((missing_count,), dtype=dtypes.int32) - ], - axis=0) - padding_mask.set_shape((batch_size,)) - return padding_mask - - -def _verify_cross_hosts_transfer_size(tensor_dict, message): - total_size = 0 - tensor_structure = {} - for key, tensor in tensor_dict.items(): - shape = tensor.shape - size = np.product(shape) * tensor.dtype.size - tensor_structure[key] = shape - total_size += size - if total_size >= _ONE_GIGABYTE: - raise 
ValueError( - '{} The transfer size is larger than the protobuf limit. Please ' - 'consider to use Tensors with smaller shapes or reduce batch ' - 'size. Given:\n' - '{}'.format( - message, '\n'.join([ - ' -- Key: {}, Shape: {}'.format(k, v) - for k, v in tensor_structure.items() - ]))) - - -def _add_item_to_params(params, key, value): - """Adds a new item into `params`.""" - if isinstance(params, hparam.HParams): - # For HParams, we need to use special API. - if key in params: - params.set_hparam(key, value) - else: - params.add_hparam(key, value) - else: - # Now params is Python dict. - params[key] = value - - -def export_estimator_savedmodel(estimator, - export_dir_base, - serving_input_receiver_fn, - assets_extra=None, - as_text=False, - checkpoint_path=None, - strip_default_attrs=False): - """Export `Estimator` trained model for TPU inference. - - Args: - estimator: `Estimator` with which model has been trained. - export_dir_base: A string containing a directory in which to create - timestamped subdirectories containing exported SavedModels. - serving_input_receiver_fn: A function that takes no argument and returns a - `ServingInputReceiver` or `TensorServingInputReceiver`. - assets_extra: A dict specifying how to populate the assets.extra directory - within the exported SavedModel, or `None` if no extra assets are needed. - as_text: whether to write the SavedModel proto in text format. - checkpoint_path: The checkpoint path to export. If `None` (the default), - the most recent checkpoint found within the model directory is chosen. - strip_default_attrs: Boolean. If `True`, default-valued attributes will be - removed from the NodeDefs. - - Returns: - The string path to the exported directory. - """ - # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use - # `estimator.config`. 
- config = tpu_config.RunConfig(model_dir=estimator.model_dir) - est = TPUEstimator( - estimator._model_fn, # pylint: disable=protected-access - config=config, - params=estimator.params, - use_tpu=True, - train_batch_size=2048, # Does not matter. - eval_batch_size=2048, # Does not matter. - ) - return est.export_savedmodel(export_dir_base, serving_input_receiver_fn, - assets_extra, as_text, checkpoint_path, - strip_default_attrs) +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu_estimator import * +# used by tests +from tensorflow.python.tpu.tpu_estimator import _clone_export_output_with_tensors +from tensorflow.python.tpu.tpu_estimator import _create_global_step +from tensorflow.python.tpu.tpu_estimator import _export_output_to_tensors +from tensorflow.python.tpu.tpu_estimator import _get_scaffold +from tensorflow.python.tpu.tpu_estimator import _Inputs +from tensorflow.python.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR +from tensorflow.python.tpu.tpu_estimator import _TPU_ENQUEUE_OPS +from tensorflow.python.tpu.tpu_estimator import _TPU_ESTIMATOR +from tensorflow.python.tpu.tpu_estimator import _TPU_TRAIN_OP +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index d5957b7e8ec40b40c7af8822378cee6134ef0d0f..af2542ea85290170ce6a38223188c4f9b871f032 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -1,898 +1,25 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== - -"""Helper library for handling infeed between hosts and TPUs. -""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools - -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_sharding - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.util import nest - - -class InfeedQueue(object): - """A helper object to build a device infeed queue. - - The InfeedQueue builds the host-side and device-side Ops to enqueue and - dequeue elements, respectively, and ensures that their types and - shapes match. - """ - - def __init__(self, - number_of_tuple_elements=None, - tuple_types=None, - tuple_shapes=None, - shard_dimensions=None, - name=None): - """Creates a new InfeedQueue with the given configuration. 
- - The configuration need not be fully specified at creation since it - can be modified subsequently by methods that set the values - explicitly or infer them from the shapes of inputs. - - Args: - number_of_tuple_elements: the number of Tensors fed atomically through the - queue, must be present unless it can be inferred from other arguments. - tuple_types: if not None, a list of types of the elements of the queue. - tuple_shapes: if not None, a list of shapes of the elements of the queue. - shard_dimensions: if not None, a list of dimensions on which the - elements of the queue should be sharded during automatic - parallelization. - name: the name of the queue. - - Raises: - ValueError: if number_of_tuple_elements <= 0; or - number_of_tuple_arguments, tuple_types, tuple_shapes, and - shard_dimensions are all None; or the length of tuple_types, - tuple_shapes, or shard_dimensions is not equal to - number_of_tuple_elements; or any element of shard_dimensions - can't be converted to a Dimension. - TypeError: if any element of tuple_types or tuple_shapes can't - be converted to a dtype or TensorShape, respectively. - """ - self._frozen = False - self._generated_enqueue_ops = False - self._generated_dequeue_op = False - self._name = "InfeedQueue" if name is None else name - if number_of_tuple_elements is None: - if tuple_types is not None: - number_of_tuple_elements = len(tuple_types) - elif tuple_shapes is not None: - number_of_tuple_elements = len(tuple_shapes) - elif shard_dimensions is not None: - number_of_tuple_elements = len(shard_dimensions) - else: - raise ValueError( - "number of tuple elements cannot be inferred from InfeedQueue " - "constructor") - if number_of_tuple_elements <= 0: - raise ValueError("number_of_tuple_elements %d must be > 0" % - number_of_tuple_elements) - # Make an empty sharding policy for each tuple element. 
- self._sharding_policies = [ - tpu_sharding.ShardingPolicy() - for _ in xrange(number_of_tuple_elements) - ] - if tuple_types is not None: - self.set_tuple_types(tuple_types) - else: - self._tuple_types = None - if tuple_shapes is not None: - self.set_tuple_shapes(tuple_shapes) - else: - self._tuple_shapes = None - if shard_dimensions is not None: - self.set_shard_dimensions(shard_dimensions) - self._validate() - - def _validate(self): - """Checks that the configuration is self-consistent. - - Raises: - ValueError: if the shapes and sharding policies don't match. - """ - if self.tuple_shapes is not None: - for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes): - # Raise an error if the policy is incompatible with the shape. - _ = policy.get_sharded_shape(shape) - - @property - def number_of_tuple_elements(self): - """Returns the number of InfeedQueue tuple elements.""" - return len(self._sharding_policies) - - @property - def tuple_types(self): - """Returns the types of the InfeedQueue tuple elements.""" - return self._tuple_types - - def set_tuple_types(self, tuple_types): - """Sets the type of each element of the queue. - - tuple_types must be a list of length - self.number_of_tuple_elements, and each element must be - convertible to a dtype. - - Args: - tuple_types: the types of each queue element. - - Raises: - ValueError: if tuple_types is not of length - self.number_of_tuple_elements. - TypeError: if an element of tuple_types cannot be converted to a - dtype. - """ - if len(tuple_types) != self.number_of_tuple_elements: - raise ValueError("tuple_types is %s, but must be a list of length %d" % - (str(tuple_types), self.number_of_tuple_elements)) - if self._frozen: - for (frozen, updated) in zip(self._tuple_types, tuple_types): - if frozen != updated: - raise ValueError( - "Trying to update InfeedQueue with frozen configuration with an " - "incompatible type. 
Frozen types are %s, updated types are %s" % ( - str(self._tuple_types), str(tuple_types))) - else: - try: - self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types] - except (TypeError) as e: - raise TypeError( - "tuple_types is %s, but must be a list of elements each " - "convertible to dtype: got error %s" % (str(tuple_types), str(e))) - - @property - def tuple_shapes(self): - """Returns the shapes of the InfeedQueue tuple elements.""" - return self._tuple_shapes - - def set_tuple_shapes(self, tuple_shapes): - """Sets the shape of each element of the queue. - - tuple_shapes must be a list of length - self.number_of_tuple_elements, and each element must be - convertible to a TensorShape. - - Args: - tuple_shapes: the shapes of each queue element. - - Raises: - ValueError: if tuple_shapes is not of length - self.number_of_tuple_elements. - TypeError: if an element of tuple_shapes cannot be converted to - a TensorShape. - """ - if len(tuple_shapes) != self.number_of_tuple_elements: - raise ValueError("tuple_shapes is %s, but must be a list of length %d" % - (str(tuple_shapes), self.number_of_tuple_elements)) - try: - tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes] - except (ValueError, TypeError) as e: - raise TypeError( - "tuple_shapes is %s, but must be a list of elements each " - "convertible to TensorShape: got error %s" % (str(tuple_shapes), - str(e))) - if self._frozen: - for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes): - if frozen != updated: - raise ValueError( - "Trying to update InfeedQueue with frozen configuration with an " - "incompatible shape. 
Frozen shapes are %s, updated shapes are %s" - % (str(self._tuple_shapes), str(tuple_shapes))) - else: - self._tuple_shapes = tuple_shapes - self._validate() - - @property - def sharding_policies(self): - """Returns the sharding policies of the InfeedQueue tuple elements.""" - return self._sharding_policies - - @property - def shard_dimensions(self): - """Gets the shard dimension of each tuple element. - - Returns: - A list of length number_of_tuple_elements, where each list entry - is the shard dimension of that tuple element or None if the - shard dimension has not been set. - """ - # The number of shards is always the same for all the policies. - return [policy.shard_dimension for policy in self._sharding_policies] - - def set_shard_dimensions(self, shard_dimensions): - """Sets the shard_dimension of each element of the queue. - - shard_dimensions must be a list of length - self.number_of_tuple_elements, and each element must be - convertible to a Dimension compatible with self.tuple_shapes. - - Args: - shard_dimensions: the dimensions of each queue element. - - Raises: - ValueError: if shard_dimensions is not of length - self.number_of_tuple_elements; or an element of - shard_dimensions cannot be converted to a Dimension; or an - element of shard_dimensions is a Dimension that is out of - range for the corresponding tuple element shape. - """ - if len(shard_dimensions) != self.number_of_tuple_elements: - raise ValueError("shard_dimensions is %s, but must be a list of length %d" - % (str(shard_dimensions), - self.number_of_tuple_elements)) - for (policy, dimension) in zip(self._sharding_policies, shard_dimensions): - policy.set_shard_dimension(dimension) - self._validate() - - @property - def number_of_shards(self): - """Gets the number of shards to use for the InfeedQueue. - - Returns: - Number of shards or None if the number of shards has not been set. - """ - # The number of shards is always the same for all the policies. 
- return self._sharding_policies[0].number_of_shards - - def set_number_of_shards(self, number_of_shards): - """Sets the number of shards to use for the InfeedQueue. - - Args: - number_of_shards: number of ways to shard the InfeedQueue. - - Raises: - ValueError: if number_of_shards is not > 0; or the policies have - been frozen and number_of_shards was already set to something - else. - """ - for policy in self._sharding_policies: - policy.set_number_of_shards(number_of_shards) - self._validate() - - def set_configuration_from_input_tensors(self, input_tensors): - """Sets the shapes and types of the queue tuple elements. - - input_tensors is a list of Tensors whose types and shapes are used - to set the queue configuration. - - Args: - input_tensors: list of Tensors of the same types and shapes as - the desired queue Tuple. - - Raises: - ValueError: if input_tensors is not a list of length - self.number_of_tuple_elements - """ - if len(input_tensors) != self.number_of_tuple_elements: - raise ValueError("input_tensors is %s, but should be a list of %d Tensors" - % (str(input_tensors), self.number_of_tuple_elements)) - self.set_tuple_shapes([t.shape for t in input_tensors]) - self.set_tuple_types([t.dtype for t in input_tensors]) - - def set_configuration_from_sharded_input_tensors(self, input_tensors): - """Sets the shapes and types of the queue tuple elements. - - input_tensors is a list of lists of Tensors whose types and shapes are used - to set the queue configuration. The length of the outer list is the number - of shards required, and each inner list is the tuple of Tensors to use to - determine the types and shapes of the corresponding shard. This method - depends on the shard dimension, and calling it freezes the shard policy. - - Args: - input_tensors: list of lists of Tensors. The outer list length corresponds - to the desired number of shards, and each inner list is the size - and shape of the desired configuration of the corresponding shard. 
- - Raises: - ValueError: if any inner list is not a list of length - self.number_of_tuple_elements; or the inner lists do not combine to - form a consistent unsharded shape. - TypeError: if the types of the Tensors in the inner lists do not match. - """ - if not self._frozen: - # Unset the tuple shapes in case the configuration becomes - # transiently inconsistent. - self._tuple_shapes = None - number_of_shards = len(input_tensors) - self.set_number_of_shards(number_of_shards) - for t in input_tensors: - if len(t) != self.number_of_tuple_elements: - raise ValueError( - "input_tensors is %s but must be a list of lists, where each inner" - " list has length number_of_tuple_elements=%d" % ( - str(input_tensors), self.number_of_tuple_elements)) - # Transpose the inputs to make a list of shard shapes for each tuple - # element. - sharded_shapes = [[t[i].shape for t in input_tensors] - for i in xrange(self.number_of_tuple_elements)] - # For each tuple, get the unsharded shape using that tuple's policy. - unsharded_shapes = [ - policy.get_unsharded_shape(s) - for (policy, s) in zip(self._sharding_policies, sharded_shapes) - ] - self.set_tuple_shapes(unsharded_shapes) - for i in xrange(1, self.number_of_shards): - for (t1, t2) in zip(input_tensors[0], input_tensors[i]): - if t1.dtype != t2.dtype: - raise TypeError( - "types of the tuple elements of input_tensors %s are not " - "consistent" % str(input_tensors)) - self.set_tuple_types([t.dtype for t in input_tensors[0]]) - - def freeze(self): - """Freezes the InfeedQueue so it can no longer be modified. - - The configuration is implicitly frozen before any host-side or - device-side Ops are generated. The configuration cannot be frozen - until the types and shapes of the tuple elements have been set. - - Raises: - ValueError: if the types or shapes of the tuple elements have not been - set. 
- """ - self._frozen = True - if self._tuple_types is None: - raise ValueError( - "Can't freeze an InfeedQueue without setting all tuple types.") - if self._tuple_shapes is None: - raise ValueError( - "Can't freeze an InfeedQueue without setting all tuple shapes.") - for shape in self._tuple_shapes: - if shape.dims is None: - raise ValueError( - "Can't freeze an InfeedQueue without setting all tuple shapes.") - for policy in self._sharding_policies: - policy.freeze() - self._validate() - - def generate_dequeue_op(self, tpu_device=0): - """Generates the device-side Op to dequeue a tuple from the queue. - - Implicitly freezes the queue configuration if it is not already - frozen, which will raise errors if the shapes and types have not - been fully specified. - - Args: - tpu_device: The TPU device ordinal where the infeed instruction should be - placed. If None, no explicit placement will be performed, and it is up - to the user to call this API from within a proper TPU device scope. - The XLA code will fail if the TPU dequeue instruction is not bound to - any device. - - Returns: - A list of Outputs corresponding to a shard of infeed dequeued - into XLA, suitable for use within a replicated block. - - Raises: - ValueError: if the types or shapes of the tuple elements have not been - set; or if a dequeue op has already been generated. 
- """ - self.freeze() - if self._generated_dequeue_op: - raise ValueError("Can't generate two dequeue Ops from the same queue") - self._generated_dequeue_op = True - full_name = "%s/dequeue" % self._name - sharded_shapes = [ - policy.get_sharded_shape(shape) - for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) - ] - if tpu_device is not None: - with ops.device(tpu.core(tpu_device)): - return tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - else: - return tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - - def _generate_enqueue_op(self, - inputs, - name_prefix, - index, - device=None, - tpu_ordinal=-1): - """Generate a host-side Op to enqueue a tuple to the queue. - - If device is None the inputs are all required to have the same - device specification, and the enqueue Op is colocated with - inputs[0]. Otherwise the enqueue Op is placed on 'device'. - - Args: - inputs: a list of Tensors with the types and shapes of the tuple elements. - name_prefix: the base name for the Op. - index: the shard index, used to uniquify the Op name. - device: device to place the Op on, or None if it should be - colocated with the inputs. - tpu_ordinal: ordinal of the TPU device on the host to use for - infeed if device is a CPU device. Should be set to -1 if device - is a TPU device. - - Returns: - An Op corresponding to a shard of infeed enqueued at the host, - suitable for use within a replicated block. - - Raises: - ValueError: if device is None and inputs do not all have the - same device specification. 
- """ - full_name = "%s/%d" % (name_prefix, index) - shapes = [t.shape for t in inputs] - if device is None: - devices = [t.device for t in inputs] - for i in xrange(1, self.number_of_tuple_elements): - if devices[0] != devices[i]: - raise ValueError( - "input devices for shard %d are %s, but should all be the same" % - (index, str(devices))) - with ops.colocate_with(inputs[0]): - return tpu_ops.infeed_enqueue_tuple( - inputs=inputs, - shapes=shapes, - name=full_name, - device_ordinal=tpu_ordinal) - else: - with ops.device(device): - return tpu_ops.infeed_enqueue_tuple( - inputs=inputs, - shapes=shapes, - name=full_name, - device_ordinal=tpu_ordinal) - - def generate_enqueue_ops(self, - sharded_inputs, - tpu_ordinal_function=None, - placement_function=None): - """Generates the host-side Ops to enqueue the shards of a tuple. - - sharded_inputs is a list, one for each shard, of lists of - Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed - shard 0 if the queue. Returns the host-side Ops that must be run to - enqueue the sharded tuple. The Op for shard i is colocated with the inputs - for shard i. - - Implicitly freezes the queue configuration if it is not already - frozen. If the configuration has already been frozen, and is not - compatible with the types and shapes of sharded_inputs, an error - will be raised. - - Args: - sharded_inputs: a list of lists of Tensors. The length of the outer list - determines the number of shards. Each inner list indicates the types - and shapes of the tuples in the corresponding shard. - tpu_ordinal_function: if not None, a function that takes the - shard index as input and returns the ordinal of the TPU device - the shard's infeed should be placed on. tpu_ordinal_function must be - set if the inputs are placed on CPU devices. - placement_function: if not None, a function that takes the shard index as - input and returns the host device where the enqueue op should be placed - on. 
- - Returns: - A list of host-side Ops, one for each shard, that when executed together - will enqueue a full-size element of infeed. - - Raises: - ValueError: if the queue configuration has previously been frozen and the - shapes of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the shapes of the elements of sharded_inputs - don't form a consistent unsharded tuple; or if the elements of a tuple - have different device constraints. - TypeError: if the queue configuration has previously been frozen and the - types of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the types of the elements of sharded_inputs - don't form a consistent unsharded tuple. - """ - self.set_configuration_from_sharded_input_tensors(sharded_inputs) - self.freeze() - if self._generated_enqueue_ops: - raise ValueError("Can't generate two enqueue Ops from the same queue") - self._generated_enqueue_ops = True - if tpu_ordinal_function is None: - tpu_ordinal_function = lambda index: -1 - name_prefix = "%s/enqueue" % self._name - return [ - self._generate_enqueue_op( - shard, - name_prefix, - index, - tpu_ordinal=tpu_ordinal_function(index), - device=placement_function(index) if placement_function else None) - for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) - ] - - # TODO(misard) Generalize this to the case of systems that don't - # have 8 devices per host, and figure out what to do with - # model-parallelism. - def _default_placement_function(self, index): - return "/task:%d/device:CPU:0" % (index / 8) - - def _default_ordinal_function(self, index): - return index % 8 - - # TODO(b/36470756) remove this from tutorials once we have a better story - # for automatic placement of input pipelines. - def split_inputs_and_generate_enqueue_ops(self, - inputs, - device_assignment=None, - placement_function=None, - tpu_ordinal_function=None): - """POORLY-PERFORMING ON MULTI-HOST SYSTEMS. 
- - Generates the host-side Ops to enqueue a tuple. - - This method performs poorly because it takes an entire input on a single - host, splits it, and distributes it to all of the cores. It is present only - to simplify tutorial examples. - - inputs is a list of Tensors to use to feed the queue. Each input is split - into self.number_of_shards shards. Returns an Op for each shard to enqueue - the shard. The Op for shard i is placed on device placement_function(i). - - Implicitly freezes the queue configuration if it is not already - frozen. If the configuration has already been frozen, and is not - compatible with the types and shapes of inputs, an error - will be raised. - - Args: - inputs: a list of Tensors which indicates the types and shapes of the - queue tuple. - device_assignment: if not `None`, a TPU `DeviceAssignment`. If - device_assignment is not `None`, but `placement_function` and - `ordinal_function` are None, then `device_assignment` will be used to - place infeeds on the first k TPU shards, where k is the number of shards - in the queue. If all three are `None`, then default placement and - ordinal functions are used. - placement_function: if not None, a function that takes the shard - index as input and returns a device string indicating which - device the shard's infeed should be placed on. If placement_function - and tpu_ordinal_function are None, inputs are sharded round-robin - across the devices in the system. - tpu_ordinal_function: if not None, a function that takes the - shard index as input and returns the ordinal of the TPU device - the shard's infeed should be placed on. If placement_function - and tpu_ordinal_function are None, inputs are sharded round-robin - across the devices in the system. - - Returns: - A list of host-side Ops, one for each shard, that when executed together - will enqueue a full-size element of infeed. 
- - Raises: - ValueError: if the queue configuration has previously been frozen and the - shapes of the elements of inputs are not compatible with the frozen - configuration. - TypeError: if the queue configuration has previously been frozen and the - types of the elements of inputs are not compatible with the frozen - configuration. - """ - if device_assignment is None: - if placement_function is None: - placement_function = self._default_placement_function - if tpu_ordinal_function is None: - tpu_ordinal_function = self._default_ordinal_function - else: - - def _placement_function_from_map(index): - return device_assignment.host_device(replica=index) - - def _ordinal_function_from_map(index): - return device_assignment.tpu_ordinal(replica=index) - - if placement_function is None: - placement_function = _placement_function_from_map - if tpu_ordinal_function is None: - tpu_ordinal_function = _ordinal_function_from_map - self.set_configuration_from_input_tensors(inputs) - self.freeze() - if self._generated_enqueue_ops: - raise ValueError("Can't generate two enqueue Ops from the same queue") - self._generated_enqueue_ops = True - split_name_prefix = "%s/split" % self._name - if self.number_of_shards == 1: - transposed_sharded_inputs = [[inp] for inp in inputs] - else: - - def split_fn(inp, num_shards, axis, name): - with ops.colocate_with(inp): - return array_ops.split(inp, num_shards, axis=axis, name=name) - - transposed_sharded_inputs = [ - split_fn( - inp, - self.number_of_shards, - axis=policy.shard_dimension, - name="%s/%d" % (split_name_prefix, index)) - for (inp, policy, index) in zip(inputs, self._sharding_policies, - xrange(self.number_of_tuple_elements)) - ] - sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs] - for i in xrange(self.number_of_shards)] - name_prefix = "%s/enqueue" % self._name - return [ - self._generate_enqueue_op( - shard, - name_prefix, - index, - device=placement_function(index), - 
tpu_ordinal=tpu_ordinal_function(index)) - for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) - ] - - -class _PartitionedInfeedQueue(InfeedQueue): - """A helper object to build a device infeed queue with input partition. - - Args: - number_of_tuple_elements: the number of Tensors fed atomically through the - queue, must be present unless it can be inferred from other arguments. - device_assignment: A TPU `DeviceAssignment` which is used to place all the - partitions to different TPU infeed queues. - host_id: The id of the host machine. - input_partition_dims: A nested list/tuple of integers. Each inner - list/tuple describes how to partition the corresponding input tensor. - tuple_types: If not None, a list of types of the elements of the queue. - tuple_shapes: If not None, a list of shapes of the elements of the queue. - name: The name of the queue. - """ - - def __init__(self, - number_of_tuple_elements, - device_assignment, - host_id, - input_partition_dims=None, - tuple_types=None, - tuple_shapes=None, - name=None): - super(_PartitionedInfeedQueue, self).__init__( - number_of_tuple_elements=number_of_tuple_elements, - tuple_types=tuple_types, - tuple_shapes=None, - shard_dimensions=None, - name="PartitionedInfeedQueue" if name is None else name) - self._input_partition_dims = input_partition_dims - self._host_id = host_id - self._device_assignment = device_assignment - - def generate_dequeue_op(self, tpu_device=0): - """Generate TPU dequeue ops. - - Args: - tpu_device: The TPU device ordinal where the infeed instruction should be - placed. - - Returns: - A list of Outputs corresponding to a partition of infeed dequeued - into XLA, suitable for use within a replicated block. - - Raises: - ValueError: if the types or shapes of the tuple elements have not been - set; or if a dequeue op has already been generated. 
- """ - self.freeze() - if self._generated_dequeue_op: - raise ValueError("Can't generate two dequeue Ops from the same queue") - self._generated_dequeue_op = True - full_name = "%s/dequeue" % self._name - sharded_shapes = [ - policy.get_sharded_shape(shape) - for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) - ] - with ops.device(tpu.core(tpu_device)): - values = tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - return self._tag_sharding_attribute_for_dequeued_tensors( - values, self._input_partition_dims) - - def generate_enqueue_ops(self, per_host_sharded_inputs): - """Generates the host-side Ops to enqueue the partitioned inputs. - - per_host_sharded_inputs is a list, one for each replica, of lists of - Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed - replica i. - sharded_inputs[i][j] is partitioned by self._input_partition_dims[j]. - - For example, if sharded_inputs[i][j] is a 2-D Tensor: - [[A, B, C, D], - [E ,F, G, H]] - self._input_partition_dims[j] is [2, 4]. - - sharded_inputs[i][j] will be partitioned and flattened into: - [A, B, C, D, E, F, G, H] and fed into the logical core ids: - [0, 1, 2, 3, 4, 5, 6, 7] respectively. - - Args: - per_host_sharded_inputs: a list of lists of Tensors. The length of the - outer list determines the number of shards. Each inner list indicates - the types and shapes of the tuples in the corresponding shard. - - Returns: - A list of host-side Ops, one for each shard, that when executed together - will enqueue a full-size element of infeed. - - Raises: - ValueError: if the queue configuration has previously been frozen and the - shapes of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the shapes of the elements of sharded_inputs - don't form a consistent unsharded tuple; or if the elements of a tuple - have different device constraints; or if the partition dims are invalid. 
- TypeError: if the queue configuration has previously been frozen and the - types of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the types of the elements of sharded_inputs - don't form a consistent unsharded tuple. - """ - self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs) - number_of_replicas_per_host = len(per_host_sharded_inputs) - number_of_tuple_elements = len(per_host_sharded_inputs[0]) - - assert len(self._input_partition_dims) == number_of_tuple_elements - per_host_enqueue_ops = [] - - for replica_index in range(number_of_replicas_per_host): - flattened_inputs = per_host_sharded_inputs[replica_index] - inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs, - self._input_partition_dims) - inputs_parted_iters = [ - iter(self._partition_or_replicate_on_host(x, dims)) for x, dims in - zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat) - ] - - for logical_core in xrange(self._device_assignment.num_cores_per_replica): - # Places different partitions to different logic cores. - replica_id = self._device_assignment.lookup_replicas( - self._host_id, logical_core)[replica_index] - ordinal = self._device_assignment.tpu_ordinal( - replica=replica_id, logical_core=logical_core) - infeed_inputs = [] - for it in inputs_parted_iters: - input_for_device = next(it, None) - if input_for_device is not None: - infeed_inputs.append(input_for_device) - - if infeed_inputs: - per_host_enqueue_ops.append( - tpu_ops.infeed_enqueue_tuple( - inputs=infeed_inputs, - shapes=[x.shape for x in infeed_inputs], - name="enqueue/replica_{0}/input_{1}".format( - replica_index, logical_core), - device_ordinal=ordinal)) - return per_host_enqueue_ops - - def _check_input_partition_dims(self, tensor, dims): - """Checks that input partition dims are valid for the `Tensor`. - - Args: - tensor: Input tensor for partitioning. 
- dims: 1-D np.array of the list of integer describes how to partition the - input tensor. - - Raises: - ValueError: If the tensor can't be partitioned by dims or the - num_cores_per_replica doesn't match the number of - partitions(dims.prod()). - """ - if (dims < 1).any(): - raise ValueError("All input partition dims must be >= 1.") - - # No partitioning, so don't perform further checks. - if dims.prod() == 1: - return - - if dims.prod() != self._device_assignment.num_cores_per_replica: - raise ValueError( - "The product of each input parition dim should equal to " - "num_cores_per_replica. (dim = {}, num_cores_per_replica " - "= {})".format(dims, self._device_assignment.num_cores_per_replica)) - if dims.shape[0] != tensor.shape.ndims: - raise ValueError( - "Input partition dims must have the same number of dimensions " - "as the `Tensor` to be partitioned. (tensor shape = {}, input " - "partition dims = {}).".format(tensor.shape.as_list(), dims)) - - tensor.shape.assert_is_fully_defined() - - def _partition_or_replicate_on_host(self, tensor, dims): - """Partitions or replicates the input tensor. - - The ops inside this function are placed on the host side. - - Args: - tensor: The input tensor which will be partioned or replicated. - dims: A list of integer describes how to partition the input tensor. - Returns: - An iterator of `Tensor`s or a list of partioned tensors. - """ - if dims is None: - return itertools.repeat(tensor) - dims = np.array(dims) - self._check_input_partition_dims(tensor, dims) - output = [tensor] - shape_list = np.array(tensor.shape.as_list()) - quotients, remainders = np.divmod(shape_list, dims) - for axis, (quotient, remainder, dim, original_size) in enumerate( - zip(quotients, remainders, dims, shape_list)): - if dim <= 1: - continue - if remainder > 0: - # For each dimension, when it cannot be evenly partitioned, XLA assumes - # tensors are partitioned in a greedy manner by using - # ceil_ratio(size/dim) first. E.g. 
2D tensor with shape (5, 14) and dims - # are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] => - # [[(3, 4), (3, 4), (2, 4), (2, 2)], - # [(2, 4), (2, 4), (2, 4), (2, 2)]] - ceil_ratio = quotient + 1 - num_full_slots, left_over = np.divmod(original_size, ceil_ratio) - num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over] - if len(num_or_size_splits) < dim: - num_or_size_splits += [0] * (dim - len(num_or_size_splits)) - new_output = [] - for x in output: - new_output.append( - array_ops.split( - x, num_or_size_splits=num_or_size_splits, axis=axis)) - output = new_output - else: - output = [array_ops.split(x, dim, axis=axis) for x in output] - output = nest.flatten(output) - return output - - def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): - """Tags appropriate XLA sharding attribute to the dequeued tensor. - - Args: - tensor: The dequeued tensor on TPU. - dims: A list of integer describes how the tensor is partitioned. - - Returns: - The same tensor with the xla_sharding attribute. - """ - if dims is None: - return xla_sharding.replicate(tensor) - elif np.prod(dims) == 1: - return xla_sharding.assign_device(tensor, 0) - else: - tile_assignment = np.arange(np.prod(dims)).reshape(dims) - return xla_sharding.tile( - tensor=tensor, - tile_assignment=tile_assignment) - - def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): - """Tags appropriate XLA sharding attribute to the dequeued tensors. - - Args: - dequeues: A list of dequeued tensors on TPU. - dims: A list of integer describes how the tensor is partitioned. - - Returns: - The same dequeues with appropriate xla_sharding attribute. 
- """ - nest.assert_shallow_structure(dequeues, dims) - return nest.map_structure_up_to( - dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues, - dims) +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu_feed import * +# used by tests +from tensorflow.python.tpu.tpu_feed import _PartitionedInfeedQueue +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function.py b/tensorflow/contrib/tpu/python/tpu/tpu_function.py index 84d5967ea547f0c036f7c9aa936ac0c99c141304..f2755c6979c2e49dbc19b6800462949601811496 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_function.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_function.py @@ -1,57 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= - -"""Helper library for functions used during TPU compilation.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib - - -class TpuContext(object): - """A context object holding state about the TPU computation being built.""" - - def __init__(self): - """Creates a new TpuContext.""" - self._number_of_shards = None - - @property - def number_of_shards(self): - return self._number_of_shards - - def set_number_of_shards(self, number_of_shards): - self._number_of_shards = number_of_shards - - -# The Tpu context holds the number of shards when a sharded computation is -# being built, or None if no computation is being built. -_current_tpu_context = TpuContext() - - -@contextlib.contextmanager -def tpu_shard_context(number_of_shards): - if _current_tpu_context.number_of_shards is not None: - raise NotImplementedError("tpu_shard_context cannot be nested.") - try: - _current_tpu_context.set_number_of_shards(number_of_shards) - yield - finally: - _current_tpu_context.set_number_of_shards(None) - - -def get_tpu_context(): - return _current_tpu_context +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_function import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py index 1e11de6421e360faf0b9ad573a84f9aecdf9c98f..ca58e78d7b342c7ca70400652d99092ccbecbbde 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -1,203 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Optimizer that implements cross-shard gradient reduction for TPU.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import ops -from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import optimizer - - -class CrossShardOptimizer(optimizer.Optimizer): - """An optimizer that averages gradients across TPU shards.""" - - def __init__(self, - opt, - reduction=losses.Reduction.MEAN, - name="CrossShardOptimizer", - group_assignment=None): - """Construct a new cross-shard optimizer. - - Args: - opt: An existing `Optimizer` to encapsulate. - reduction: The reduction to apply to the shard losses. - name: Optional name prefix for the operations created when applying - gradients. Defaults to "CrossShardOptimizer". - group_assignment: Optional 2d int32 lists with shape - [num_groups, num_replicas_per_group] which describles how to apply - optimizer to subgroups. 
- - Raises: - ValueError: If reduction is not a valid cross-shard reduction. - """ - if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN): - raise ValueError("Unsupported reduction: %s." % reduction) - - super(CrossShardOptimizer, self).__init__(False, name) - self._opt = opt - self._reduction = reduction - self._group_assignment = group_assignment - - def _verify_and_get_subgroup_size(self, group_assignment, num_shards): - """Verify group_assignment and get the subgroup size". - - Args: - group_assignment: list of group ids for applying the optimizer - to subgroups. - num_shards: The number of TPU shards. - - Returns: - The size of one subgroup in group_assignment. - - Raises: - ValueError: If group_assignment is invalid. - """ - if not group_assignment: - return None - if not (isinstance(group_assignment, list) and - all(isinstance(i, list) for i in group_assignment)): - raise ValueError("group_assignment must be a list of list. Got {}".format( - group_assignment)) - - replica_ids = set() - for g in group_assignment: - for i in g: - replica_ids.add(i) - - if set(range(num_shards)) != replica_ids: - raise ValueError("group_assignment must be a permutation of range({0})." - " Got group_assignment={1}".format( - num_shards, group_assignment)) - - subgroup_size_list = [len(group) for group in group_assignment] - if all(subgroup_size_list[0] == size for size in subgroup_size_list): - return subgroup_size_list[0] - else: - raise ValueError("The size of each subgroup in group_assignment must " - "be equal. Got group_assignment={}".format( - self._group_assignment)) - - def compute_gradients(self, loss, var_list=None, **kwargs): - """Compute gradients of "loss" for the variables in "var_list". - - This simply wraps the compute_gradients() from the real optimizer. The - gradients will be aggregated in the apply_gradients() so that user can - modify the gradients like clipping with per replica global norm if needed. 
- The global norm with aggregated gradients can be bad as one replica's huge - gradients can hurt the gradients from other replicas. - - Args: - loss: A Tensor containing the value to minimize. - var_list: Optional list or tuple of `tf.Variable` to update to minimize - `loss`. Defaults to the list of variables collected in the graph - under the key `GraphKey.TRAINABLE_VARIABLES`. - **kwargs: Keyword arguments for compute_gradients(). - - Returns: - A list of (gradient, variable) pairs. - - Raises: - ValueError: If not within a tpu_shard_context or group_assignment is - invalid. - """ - num_shards = tpu_function.get_tpu_context().number_of_shards - if num_shards is None: - logging.warning( - "CrossShardOptimizer should be used within a tpu_shard_context, but " - "got unset number_of_shards. Assuming 1.") - num_shards = 1 - - subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment, - num_shards) - - if num_shards > 1 and self._reduction == losses.Reduction.MEAN: - if self._group_assignment: - scale = 1.0 / subgroup_size - else: - scale = 1.0 / num_shards - loss *= scale - - return self._opt.compute_gradients(loss, var_list=var_list, **kwargs) - - def apply_gradients(self, grads_and_vars, global_step=None, name=None): - """Apply gradients to variables. - - Calls tpu_ops.cross_replica_sum() to sum gradient contributions across - replicas, and then applies the real optimizer. - - Args: - grads_and_vars: List of (gradient, variable) pairs as returned by - compute_gradients(). - global_step: Optional Variable to increment by one after the - variables have been updated. - name: Optional name for the returned operation. Default to the - name passed to the Optimizer constructor. - - Returns: - An `Operation` that applies the gradients. If `global_step` was not None, - that operation also increments `global_step`. - - Raises: - ValueError: If the grads_and_vars is malformed. 
- """ - summed_grads_and_vars = [] - for (grad, var) in grads_and_vars: - if grad is None: - summed_grads_and_vars.append((grad, var)) - else: - with ops.colocate_with(grad): - summed_grads_and_vars.append((tpu_ops.cross_replica_sum( - grad, self._group_assignment), var)) - return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) - - def get_slot(self, *args, **kwargs): - """Return a slot named "name" created for "var" by the Optimizer. - - This simply wraps the get_slot() from the actual optimizer. - - Args: - *args: Arguments for get_slot(). - **kwargs: Keyword arguments for get_slot(). - - Returns: - The `Variable` for the slot if it was created, `None` otherwise. - """ - return self._opt.get_slot(*args, **kwargs) - - def get_slot_names(self, *args, **kwargs): - """Return a list of the names of slots created by the `Optimizer`. - - This simply wraps the get_slot_names() from the actual optimizer. - - Args: - *args: Arguments for get_slot(). - **kwargs: Keyword arguments for get_slot(). - - Returns: - A list of strings. - """ - return self._opt.get_slot_names(*args, **kwargs) - - def variables(self): - """Forwarding the variables from the underlying optimizer.""" - return self._opt.variables() +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_optimizer import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py index f5af03f33ca8f13af517007672e9ce0e12be6205..93c52335a582e5fa83092f78212ca268079b7c12 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py @@ -1,253 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Helper library for sharding during TPU compilation.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.python.framework import tensor_shape - -_DEFAULT_NUMBER_OF_SHARDS = 1 -_DEFAULT_SHARD_DIMENSION = 0 - - -# TODO(b/36777903) change other parts of tpu.py to use this class. -class ShardingPolicy(object): - """An object use to hold the sharding policy for a Tensor. - """ - - def __init__(self): - self._number_of_shards = None - self._shard_dimension = None - self._frozen = False - - def __str__(self): - if self.number_of_shards is None or self.shard_dimension is None: - return "ShardingPolicy(unset)" - else: - return ("ShardingPolicy(%d shards dimension %d)" % - (self.number_of_shards, self.shard_dimension)) - - def _fill_default_values(self): - if self._number_of_shards is None: - self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS - if self._shard_dimension is None: - self._shard_dimension = tensor_shape.as_dimension( - _DEFAULT_SHARD_DIMENSION) - - def freeze(self): - """Prevents further modification to the sharding policy. - - Any values that have not been set when freeze is called are set to - defaults. 
If the ShardingPolicy is already frozen, this is a NoOp. - """ - if not self._frozen: - self._fill_default_values() - self._frozen = True - - @property - def number_of_shards(self): - """Returns the number of shards in the policy or None if unspecified.""" - return self._number_of_shards - - def set_number_of_shards(self, number_of_shards): - """Sets the number of shards for the current policy. - - If the policy has been frozen then number_of_shards must match the - existing setting. - - Args: - number_of_shards: The number of shards to use in the policy. - - Raises: - ValueError: If the policy has been frozen and number_of_shards - differs from the frozen value; or number_of_shards <= 0. - """ - if self._frozen: - if self._number_of_shards != number_of_shards: - raise ValueError( - "Can't set sharding policy to use %d shards since it has been " - "frozen to use %d." % (number_of_shards, self._number_of_shards)) - else: - if number_of_shards > 0: - self._number_of_shards = number_of_shards - else: - raise ValueError( - "Can't set sharding policy to use %s shards; value must be >0", - str(number_of_shards)) - - @property - def shard_dimension(self): - """Returns the shard dimension of the policy or None if unspecified.""" - return self._shard_dimension - - def set_shard_dimension(self, shard_dimension): - """Sets the shard dimension for the current policy. - - If the policy has been frozen then shard_dimension must match the - existing setting. - - Args: - shard_dimension: The shard dimension to use in the policy. - - Raises: - ValueError: If the policy has been frozen and shard_dimension - differs from the frozen value, or shard_dimension can't be - interpreted as a Dimension. - """ - if self._frozen: - if self._shard_dimension != shard_dimension: - raise ValueError( - "Can't set shard dimension to %d since it has been frozen to " - "use %d." 
% (shard_dimension, self._shard_dimension)) - else: - self._shard_dimension = tensor_shape.as_dimension(shard_dimension) - - def merge(self, other): - """Merges the policy of another policy into the current policy. - - Args: - other: The policy to merge into this one. - - Raises: - ValueError: If this policy has been frozen and the merge conflicts with - the frozen policy. - """ - if other.number_of_shards is not None: - self.set_number_of_shards(other.number_of_shards) - if other.shard_dimension is not None: - self.set_shard_dimension(other.shard_dimension) - - def get_sharded_shape(self, shape, shard_index=None): - """Returns the shape of a shard of a full Tensor. - - When given the shape of a 'full-size' Tensor, returns the shape of - the sub-Tensor after it has been sharded. Freezes the policy if it - has not yet been frozen. - - Args: - shape: The shape of the full-size Tensor to be sharded. - shard_index: The index of the shard whose shape should be returned. - shard_index can be None for sharding policies that use the same - shape for every shard. - freeze_config: - - Returns: - The shape of the sharded version of the Tensor. - - Raises: - ValueError: If shard_index is None when shards are of different - shapes; or shard_index is not None and - !(0<=shard_index= self.number_of_shards: - raise ValueError("shard_index %d, but must be in [0,%d)." % - (shard_index, self._number_of_shards)) - shape = tensor_shape.as_shape(shape) - if self._number_of_shards == 1: - # Don't do anything when there's only one shard. - return shape - ndims = shape.ndims - if ndims is None: - raise ValueError("shape must be a specified shape not Unknown") - if ndims <= self._shard_dimension: - raise ValueError("shape %s does not contain shard_dimension %d" % - (shape.as_list(), self._shard_dimension)) - dims = shape.as_list() - if dims[self._shard_dimension] is None: - raise ValueError("shape %s must have a fixed size for dimension %d " - "that is known at graph construction time." 
% - (shape.as_list(), self._shard_dimension)) - if (dims[self._shard_dimension] % self._number_of_shards) != 0: - raise ValueError("shape %s cannot be sharded %d ways along dimension %d" % - (shape.as_list(), self._number_of_shards, - self._shard_dimension)) - dims[self._shard_dimension] /= self._number_of_shards - return tensor_shape.as_shape(dims) - - def _unshard_shape(self, shape): - """Return the unsharded shape that would generate a given sharded shape. - - Args: - shape: the sharded shape to unshard - - Returns: - The unsharded shape. - - Raises: - ValueError: if shape is unknown or does not contain - self.shard_dimension - TypeError: if shape is not convertible to a TensorShape - """ - shape = tensor_shape.as_shape(shape) - if self._number_of_shards == 1: - # Don't do anything when there's only one shard. - return shape - ndims = shape.ndims - if ndims is None: - raise ValueError("shape must be a specified shape not Unknown") - if ndims <= self._shard_dimension: - raise ValueError("shape %s does not contain shard_dimension %d" % - (shape.as_list(), self._shard_dimension)) - dims = shape.as_list() - dims[self._shard_dimension] *= self._number_of_shards - return tensor_shape.as_shape(dims) - - def get_unsharded_shape(self, shapes): - """Returns the shape of an unsharded Tensor given a list of shards. - - When given a list of shapes of shards, returns the shape of the - unsharded Tensor that would generate the shards. Sets defaults for the - policy if number_of_shards or shard_dimension is None. - - Args: - shapes: The shapes of the Tensor shards to be combined. - - Returns: - The shape of the unsharded version of the Tensor. - - Raises: - ValueError: if shapes is not a list of length - self.number_of_shards; or any element of shapes is not a valid - shape consistent with the sharding policy; or the list of - shapes is not a valid sharding of a full shape. 
- TypeError: if an element of shapes is not convertible to a - TensorShape - """ - self._fill_default_values() - if len(shapes) != self.number_of_shards: - raise ValueError( - "shapes is %s but must be a list of length number_of_shards=%d" % ( - str(shapes), self.number_of_shards)) - unsharded_shapes = [self._unshard_shape(s) for s in shapes] - for i in xrange(self.number_of_shards - 1): - if not unsharded_shapes[i].is_compatible_with( - unsharded_shapes[self.number_of_shards - 1]): - raise ValueError( - "sharded shapes %s are not consistent shards of a full shape " - "sharded %d ways along dimension %d" % ( - str(shapes), self.number_of_shards, self.shard_dimension)) - return unsharded_shapes[0] +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu_sharding import * +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index d66ecfcf4a56b8da1c2d2f518bebe4baa76b315e..258d34ddaf5250e49c5a354caf018e4b64abae62 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -1,156 +1,25 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# =================================================================== -"""TPU system metadata and associated tooling.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import re - -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging - -_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000 # 1 min -_RETRY_TIMES = 120 -_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins - -_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$') - -# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration, -# including num_cores and num_hosts. -_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [ - 'num_cores', - 'num_hosts', - 'num_of_cores_per_host', - 'topology', - 'devices', -]) - - -def _query_tpu_system_metadata(master_address, cluster_def=None, - query_topology=False): - """Automatically detects the TPU system metadata in the system.""" - tpu_core_count = 0 - devices = [] - device_dict = collections.defaultdict(list) - - # TODO(b/120564445): Replace with standard library for retries. 
- retry_count = 1 - while True: - logging.info('Querying Tensorflow master (%s) for TPU system metadata.', - master_address) - try: - with ops.Graph().as_default(): - with session_lib.Session( - master_address, - config=get_session_config_with_timeout( - _PINGING_MASTER_TIMEOUT_IN_MS, - cluster_def)) as sess: - devices = sess.list_devices() - for device in devices: - match = _TPU_DEVICE_REG.match(device.name) - if match: - host_id = match.group(1) - core_id = match.group(2) - device_dict[host_id].append(core_id) - tpu_core_count += 1 - break - except errors.DeadlineExceededError: - msg = ('Failed to connect to the Tensorflow master. The TPU worker may ' - 'not be ready (still scheduling) or the Tensorflow master address ' - 'is incorrect: got (%s).' % - (master_address)) - - # TODO(xiejw): For local or grpc master we might not need retry logic - # here. - if retry_count <= _RETRY_TIMES: - logging.warning('%s', msg) - logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES) - retry_count += 1 - else: - raise ValueError(msg) - - num_of_cores_per_host = 0 - if tpu_core_count: - num_cores_per_host_set = set( - [len(core_ids) for core_ids in device_dict.values()]) - if len(num_cores_per_host_set) != 1: - raise RuntimeError( - 'TPU cores on each host is not same. This should not happen!. ' - 'devices: {}'.format(devices)) - num_of_cores_per_host = num_cores_per_host_set.pop() - - topology = None - if query_topology: - if not tpu_core_count: - raise RuntimeError( - 'Cannot find any TPU cores in the system (master address {}). ' - 'This usually means the master address is incorrect or the ' - 'TPU worker has some problems. 
Available devices: {}'.format( - master_address, devices)) - - topology = _obtain_topology(master_address, cluster_def) - - metadata = _TPUSystemMetadata( - num_cores=tpu_core_count, - num_hosts=len(device_dict), - num_of_cores_per_host=num_of_cores_per_host, - topology=topology, - devices=devices) - - if tpu_core_count: - logging.info('Found TPU system:') - logging.info('*** Num TPU Cores: %d', metadata.num_cores) - logging.info('*** Num TPU Workers: %d', metadata.num_hosts) - logging.info('*** Num TPU Cores Per Worker: %d', - metadata.num_of_cores_per_host) - for device in metadata.devices: - logging.info('*** Available Device: %s', device) - else: - logging.info('Failed to find TPU: %s', metadata) - return metadata - - -def _obtain_topology(master_address, cluster_def): - """Obtains TPU fabric topology.""" - try: - logging.info('Initializing TPU system (master: %s) to fetch topology ' - 'for model parallelism. This might take a while.', - master_address) - with ops.Graph().as_default(): - session_config = get_session_config_with_timeout( - _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def) - with session_lib.Session( - master_address, config=session_config) as sess: - topology = sess.run(tpu.initialize_system()) - return topology - except errors.DeadlineExceededError: - raise ValueError( - 'Fail to initialize TPU system with master (%s). ' - 'Please double check the TPU system is functional.' 
% ( - master_address)) - - -def get_session_config_with_timeout(timeout_in_secs, cluster_def): - """Returns a session given a timeout and a cluster configuration.""" - config = config_pb2.ConfigProto( - operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def) - return config +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_system_metadata import * +# used by tests +from tensorflow.python.tpu.tpu_system_metadata import _query_tpu_system_metadata +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py deleted file mode 100644 index 6bdaa528f9f946ae4b9813d554409da2406b1f8d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================= - -"""Tests for tpu_function helpers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_feed -from tensorflow.contrib.tpu.python.tpu import training_loop - -from tensorflow.python.framework import dtypes -from tensorflow.python.layers import convolutional -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import control_flow_util -from tensorflow.python.ops import math_ops - -from tensorflow.python.platform import test - - -class TPUContextTest(test.TestCase): - - def testIsInContext(self): - """Test that control_flow_util can check that we're in a TPU context.""" - z1 = array_ops.identity(1) - pivot = control_flow_ops.no_op() - context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot) - context.Enter() - z2 = array_ops.identity(1) - context.Exit() - self.assertFalse(control_flow_util.IsInXLAContext(z1.op)) - self.assertTrue(control_flow_util.IsInXLAContext(z2.op)) - - -class TPULayerRewriteTest(test.TestCase): - - def testUsingInfeedQueueWithRegularizer(self): - """Test that Layer regularizers can reference data created in loops.""" - - def make_regularizer(scale): - return lambda inputs: scale * math_ops.reduce_sum(math_ops.square(inputs)) - - def training_step(inputs, scale): - outputs = convolutional.conv2d( - inputs, - filters=16, - kernel_size=(3, 3), - data_format="channels_first", - kernel_regularizer=make_regularizer(scale)) - loss = math_ops.reduce_mean(math_ops.square(outputs)) - return loss.op - - inputs = array_ops.zeros(shape=(128, 32, 32, 16)) - scale = array_ops.ones(shape=()) - infeed = tpu_feed.InfeedQueue( - tuple_types=[dtypes.float32, dtypes.float32], - tuple_shapes=[inputs.shape, scale.shape]) - - def loop(): - return 
training_loop.repeat(5, training_step, infeed_queue=infeed) - - # This should not throw an error. - tpu.rewrite(loop) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/training_loop.py b/tensorflow/contrib/tpu/python/tpu/training_loop.py index 0187b4bec6ecc55943bf48b9268a74e18ea5b488..673359b232d6857d468723873c449cb3e48168c7 100644 --- a/tensorflow/contrib/tpu/python/tpu/training_loop.py +++ b/tensorflow/contrib/tpu/python/tpu/training_loop.py @@ -1,214 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Library for constructing a training loop, suitable for TPUs.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.compiler import xla -from tensorflow.contrib.tpu.python.tpu import tpu_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops - - -def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): - """Builds a training loop for TPUs. 
- - The set of loop-carried tensors corresponds to `inputs`. Both - `condition` and `body` take the current value of the loop-carried - tensors. 'body' additionally takes a tuple of infeed from - infeed_queue if infeed_queue is not None. `condition` must return a - single boolean value that determines whether iteration - continues. `body` must return an updated list of values for the - loop-carried tensors. - - Args: - condition: a Python function that builds the loop condition. - body: a Python function that builds the loop body. - inputs: a list of initial values passed into the training loop, or - None (equivalent to an empty list). - infeed_queue: if not None, the infeed queue from which to append a tuple - of arguments as inputs to condition. - name: (Deprecated) Does nothing. - - Returns: - The final values of the loop-carried tensors. - - Raises: - TypeError: if body or condition has the wrong signature. - """ - del name - # Converts inputs to Tensors. - inputs = [] if inputs is None else [ops.convert_to_tensor(x) for - x in inputs] - input_types = [x.dtype for x in inputs] - input_arity = len(inputs) - - body_arg_error = xla.check_function_argument_count( - body, input_arity, infeed_queue) - if body_arg_error is not None: - if infeed_queue is None: - raise TypeError( - "Supplied loop body function cannot be called with the specified " - "inputs. You specified %d inputs: %s, but the loop body needs %s" % ( - input_arity, str([i.name for i in inputs]), body_arg_error)) - else: - raise TypeError( - "Supplied loop body function cannot be called with the specified " - "inputs. 
You specified %d inputs: %s and %d additional inputs from " - "infeed, but the computation needs %s" % (input_arity, str( - [i.name for i in inputs]), infeed_queue.number_of_tuple_elements, - body_arg_error)) - condition_arg_error = xla.check_function_argument_count( - condition, input_arity, None) - if condition_arg_error is not None: - if infeed_queue is None: - raise TypeError( - "Supplied loop condition function cannot be called with the " - "specified inputs. You specified %d inputs: %s, but the loop " - "condition needs %s" % (input_arity, str([i.name for i in inputs]), - condition_arg_error)) - else: - raise TypeError( - "Supplied loop condition function cannot be called with the " - "specified inputs. You specified %d inputs: %s, but the loop " - "condition needs %s. Note that infeed is not passed to the loop " - "condition." % (input_arity, str([i.name for i in inputs]), - condition_arg_error)) - - def condition_wrapper(*inputs): - # Discards the dummy output added for arity-0 loops. - if input_arity == 0: - inputs = [] - return condition(*inputs) - - def body_wrapper(*inputs): - """Wrapper around `body` that handles infeed queues and control deps.""" - inputs = list(inputs) - - # Discards the dummy output added for arity-0 loops. - if input_arity == 0: - inputs = [] - - # Runs `body` with the dequeue_ops appended. - if infeed_queue: - number_of_shards = tpu_function.get_tpu_context().number_of_shards - if number_of_shards is None: - raise ValueError("Can't build training loop with infeed when there is " - "no tpu_shard_context. Are you building a loop or " - "graph directly rather than from inside tpu.rewrite, " - "tpu.batch_parallel, tpu.shard, or tpu.replicate?") - infeed_queue.set_number_of_shards(number_of_shards) - dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()] - else: - dequeue_ops = [] - outputs = body(*(inputs + dequeue_ops)) - - # If the computation only returned one value, make it a tuple. 
- if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs - if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - "TPU training loop body must return zero or more Tensor values " - "followed by zero or more Operations.") - - output_types = [op.dtype for op in output_tensors] - if input_types != output_types: - raise TypeError( - "Mismatch between input types and output types for training loop " - "body: {} vs {}".format(input_types, output_types)) - - # Add the dequeue operations to output_operations to ensure they are run - # by the loop, even if the programmer's loop body does not use them. - output_operations += dequeue_ops - - # Add a dummy output, if needed. - if not output_tensors: - output_tensors = array_ops.constant(0) - - if output_operations: - # TODO(phawkins): in principle this is too restrictive since it serializes - # the training loop steps. In practice it does not matter since this loop - # will be compiled by XLA. - return control_flow_ops.tuple(output_tensors, - control_inputs=output_operations) - else: - return output_tensors - - # If the body has arity 0, add a dummy loop-carried value to which we can add - # control dependencies from any side-effecting operations. - if input_arity == 0: - inputs = [array_ops.constant(0)] - return control_flow_ops.while_loop( - condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1) - - -def repeat(n, body, inputs=None, infeed_queue=None, name=None): - """Builds a training loop that executes a fixed number of iterations. - - The set of loop-carried tensors correspond to `inputs`. 
- `body` must be a function that takes and returns the values of the - loop-carried tensors. - - Args: - n: the number of loop iterations - body: a Python function that builds the loop body. - inputs: a list of initial values passed into the training loop or - None (equivalent to an empty list). - infeed_queue: if not None, the infeed queue from which to append a tuple - of arguments as inputs to condition. - name: (Deprecated) Does nothing. - Returns: - The final values of the loop-carried tensors. - Raises: - ValueError: if there is a type error. - """ - def _convert_to_list(xs): - if not isinstance(xs, (list, tuple)): - return [xs] - else: - return list(xs) - - def cond(i, *args): - del args - return i < n - - def body_wrapper(i, *args): - return [i + 1] + _convert_to_list(body(*args)) - - inputs = [0] if inputs is None else [0] + _convert_to_list(inputs) - outputs = while_loop( - cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name) - outputs = _convert_to_list(outputs) - if len(outputs) == 1: - # Returns the Op rather than an empty list. - return outputs[0].op - else: - return outputs[1:] +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.training_loop import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/util.py b/tensorflow/contrib/tpu/python/tpu/util.py index dfb8ce1d1821da05c853bb0d10b1db3a857ccb1b..8d9b70d46eb42c9a525eeafc51d07f0ad4241d52 100644 --- a/tensorflow/contrib/tpu/python/tpu/util.py +++ b/tensorflow/contrib/tpu/python/tpu/util.py @@ -1,51 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== - -"""Utilities for the functionalities.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time -import six - -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import training - -def check_positive_integer(value, name): - """Checks whether `value` is a positive integer.""" - if not isinstance(value, six.integer_types): - raise TypeError('{} must be int, got {}'.format(name, type(value))) - - if value <= 0: - raise ValueError('{} must be positive, got {}'.format(name, value)) - - -# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we -# release a tensorflow_estimator with MultiHostDatasetInitializerHook in -# python/estimator/util.py. 
-class MultiHostDatasetInitializerHook(training.SessionRunHook): - """Creates a SessionRunHook that initializes all passed iterators.""" - - def __init__(self, dataset_initializers): - self._initializers = dataset_initializers - - def after_create_session(self, session, coord): - del coord - start = time.time() - session.run(self._initializers) - logging.info('Initialized dataset iterators in %d seconds', - time.time() - start) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.util import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index f6427ae05a20f253edf030eff0f860361616042b..5bc4c3b88efd641b6f17a54753a29b0603c2b98c 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -264,9 +264,9 @@ py_test( py_test( name = "training_test", - size = "large", + size = "medium", srcs = ["python/training/training_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", tags = ["notsan"], deps = [ diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py index e7f23edc901eacfa3a753792c2dbf738bb5a9421..10f3f88f3eb877998f3498018863b4972ee45b07 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops.py +++ b/tensorflow/contrib/training/python/training/bucket_ops.py @@ -400,7 +400,7 @@ def bucket_by_sequence_length(input_length, math_ops.less_equal(buckets_min, input_length), math_ops.less(input_length, buckets_max)) which_bucket = math_ops.reduce_min(array_ops.where(conditions_c)) - which_bucket = math_ops.to_int32(which_bucket) + which_bucket = math_ops.cast(which_bucket, dtypes.int32) if shapes is not None: shapes = [tensor_shape.scalar()] + shapes diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 
27f0d9b2e38c433d4fb4573285ecb8c9946112e8..cb0a25f333b2bba9c4eee991180eab2a083eeb31 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -353,8 +353,10 @@ class HParams(object): def my_program(): # Create a HParams object specifying the names and values of the # model hyperparameters: - hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100, - activations=['relu', 'tanh']) + hparams = tf.contrib.training.HParams( + learning_rate=0.1, + num_hidden_units=100, + activations=['relu', 'tanh']) # Override hyperparameters values by parsing the command line hparams.parse(args.hparams) @@ -387,7 +389,7 @@ class HParams(object): # Define 3 hyperparameters: 'learning_rate' is a float parameter, # 'num_hidden_units' an integer parameter, and 'activation' a string # parameter. - hparams = tf.HParams( + hparams = tf.contrib.training.HParams( learning_rate=0.1, num_hidden_units=100, activation='relu') hparams.activation ==> 'relu' diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py index 53e4f23a7cd940c026e462dc7fb55cf9f175bf02..ce3d5ec1de569a61f4b58a4522acae1be1fc59ee 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py @@ -1597,7 +1597,7 @@ def _padding(sequences, num_unroll): else: # Only have SparseTensors sparse_lengths = [value.dense_shape[0] for value in sequences_dict.values() if isinstance(value, sparse_tensor.SparseTensor)] - length = math_ops.reduce_max(math_ops.to_int32(sparse_lengths)) + length = math_ops.reduce_max(math_ops.cast(sparse_lengths, dtypes.int32)) unroll = array_ops.constant(num_unroll) padded_length = length + ((unroll - (length % unroll)) % unroll) @@ -1620,8 +1620,9 @@ def _padding(sequences, num_unroll): # 3. 
concat values with paddings padded_sequences[key] = array_ops.concat([value, paddings], 0) else: - padded_shape = array_ops.concat([[math_ops.to_int64(padded_length)], - value.dense_shape[1:]], 0) + padded_shape = array_ops.concat( + [[math_ops.cast(padded_length, dtypes.int64)], value.dense_shape[1:]], + 0) padded_sequences[key] = sparse_tensor.SparseTensor( indices=value.indices, values=value.values, @@ -1834,8 +1835,8 @@ def _reconstruct_sparse_tensor_seq(sequence, Returns: A SparseTensor with a +1 higher rank than the input. """ - idx_batch = math_ops.to_int64( - math_ops.floor(sp_tensor.indices[:, 0] / num_unroll)) + idx_batch = math_ops.cast( + math_ops.floor(sp_tensor.indices[:, 0] / num_unroll), dtypes.int64) idx_time = math_ops.mod(sp_tensor.indices[:, 0], num_unroll) indices = array_ops.concat( [ diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index fc6e38ab4a5243cb7502f4ca42db03cbfd342a40..4ceb6e9350f5167efc8f7266d4e748cc6fa4ffd6 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -244,7 +244,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -354,11 +353,11 @@ def multiply_gradients(grads_and_vars, gradient_multipliers): raise ValueError('Requested multiple of `None` gradient.') if isinstance(grad, ops.IndexedSlices): - tmp = grad.values * constant_op.constant( + tmp = grad.values * ops.convert_to_tensor( gradient_multipliers[key], dtype=grad.dtype) grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape) else: - grad *= constant_op.constant( + grad *= ops.convert_to_tensor( gradient_multipliers[key], dtype=grad.dtype) 
multiplied_grads_and_vars.append((grad, var)) return multiplied_grads_and_vars @@ -433,7 +432,7 @@ def create_train_op(total_loss, else: # Make sure that variables_to_train are in tf.trainable_variables() for v in variables_to_train: - assert v in tf_variables.trainable_variables() + assert v.trainable or v in tf_variables.trainable_variables() assert variables_to_train diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index d9ccda8e89a4c9a1b3f3d24915b9ad3fb4d9be5f..7b2bc30e3a85ed890e3c66ceeb448cbeb61e86d3 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -16,10 +16,15 @@ cc_library( srcs = ["convert_graphdef_memmapped_format_lib.cc"], hdrs = ["convert_graphdef_memmapped_format_lib.h"], deps = [ + "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core:tensorflow", "//tensorflow/core/kernels:immutable_constant_op", ], @@ -39,7 +44,7 @@ tf_cc_test( name = "convert_graphdef_memmapped_format_test", srcs = ["convert_graphdef_memmapped_format_test.cc"], linkopts = select({ - "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], }), deps = [ diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.h b/tensorflow/contrib/verbs/grpc_verbs_service.h index 444c863b942ef8bce8d54d59765563b12eb6087e..e616778665a9c95b30099b128ec5d1e181ba0618 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service.h @@ -25,12 +25,6 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" #include "tensorflow/core/lib/core/refcount.h" -namespace grpc { -class ServerBuilder; -class ServerCompletionQueue; -class Alarm; -} // namespace grpc - namespace tensorflow { class GrpcVerbsService : public AsyncServiceInterface { diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index cfb9b7ddd7d88c150e47caff66f0865fcaec662c..2432c34ae2353d5d7bca03d80a043b5875ef8cce 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -27,14 +27,6 @@ limitations under the License. #include "tensorflow/contrib/verbs/verbs_service.pb.h" -namespace grpc { -class CompletionQueue; -class Channel; -class RpcService; -class ServerCompletionQueue; -class ServerContext; -} // namespace grpc - namespace tensorflow { namespace grpc { diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 9db80f6b5736d849d88e1e41ea467a5ff11844f5..b4b6b705f4bab74ac9579ea0354bb7306f06b312 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -1086,7 +1086,7 @@ void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed, // The tensor must be copied from GPU to CPU, because either: // 1. The tensor is located on a non GDR compatible GPU. // 2. The tensor's meta-data has changed. - Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetGpuHostAllocator(0); copy = Tensor(alloc, in.dtype(), in.shape()); CountCopies(rm_.name_, (void*)DMAHelper::base(&in), (void*)DMAHelper::base(©), in.TotalBytes(), true); @@ -1543,7 +1543,7 @@ bool RdmaTensorRequest::AllocateTensors() { if (mr_ == nullptr) { // Can't RDMA directly to result. Use a proxy. 
proxy_tensor_ = - new Tensor(GPUProcessState::singleton()->GetCUDAHostAllocator(0), + new Tensor(GPUProcessState::singleton()->GetGpuHostAllocator(0), result_tensor_->dtype(), result_tensor_->shape()); rdma_addr_ = DMAHelper::base(proxy_tensor_); mr_ = diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 2f2375427862ad1e99a0e6bfc506382d200e9b1d..5ac9f46447c518c342b565b2b32bac56ead7be5e 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -277,8 +277,8 @@ void RdmaMgr::InitAllocators() { ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); #if GOOGLE_CUDA - GPUProcessState::singleton()->AddCUDAHostAllocVisitor(0, alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostFreeVisitor(0, free_visitor); + GPUProcessState::singleton()->AddGpuHostAllocVisitor(0, alloc_visitor); + GPUProcessState::singleton()->AddGpuHostFreeVisitor(0, free_visitor); if (IsGDRAvailable()) { // Note we don't free allocated GPU memory so there is no free visitor diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index 19ef109f671ee57ce2aceb55110c50aa44352223..d07fd5ae6e9cc0dbf67c6b6a4e8db086b4c74aa1 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -81,7 +81,10 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, Status VerbsServer::Init(ServiceInitFunction service_func, RendezvousMgrCreationFunction rendezvous_mgr_func) { std::call_once(reg_mem_visitors_call, []() { RdmaMgr::RegMemVisitors(); }); - Status s = GrpcServer::Init(service_func, rendezvous_mgr_func); + GrpcServerOptions opts; + opts.service_func = service_func; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + Status s = GrpcServer::Init(opts); { mutex_lock l(mu_); CHECK_EQ(verbs_state_, DISCONNECTED); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 
084db7a0fd470574311501adb5bab8b5f7e19dab..9ff4b631634fe90e968cc6ff2bc4c2156deda1bb 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -63,7 +63,15 @@ # //tensorflow/tensorflow.bzl) will include the necessary symbols in binary # build targets. +package_group( + name = "dependency_whitelist", + packages = [ + "//learning/freud/topic_models/tensorflow/...", + ], +) + package(default_visibility = [ + ":dependency_whitelist", "//tensorflow:internal", "//tensorflow_models:__subpackages__", ]) @@ -77,6 +85,7 @@ load( "//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_android", + "if_emscripten", "if_ios", "if_linux_x86_64", "if_mobile", @@ -87,10 +96,12 @@ load( "tf_copts", "tf_cuda_library", "tf_features_nomodules_if_android", + "tf_features_nomodules_if_emscripten", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", + "tf_opts_nortti_if_emscripten", "transitive_hdrs", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") @@ -125,7 +136,9 @@ load( "tf_additional_libdevice_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_mpi_lib_defines", - "tf_additional_proto_compiler_hdrs", + "tf_additional_numa_deps", + "tf_additional_numa_lib_defines", + "tf_additional_numa_copts", "tf_additional_proto_hdrs", "tf_additional_proto_srcs", "tf_additional_test_deps", @@ -144,6 +157,7 @@ load( "tf_protos_grappler", "tf_protos_grappler_impl", "tf_pyclif_proto_library", + "tf_grpc_service_all", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -201,10 +215,12 @@ COMMON_PROTO_SRCS = [ "protobuf/cluster.proto", "protobuf/debug.proto", "protobuf/device_properties.proto", + "protobuf/graph_debug_info.proto", "protobuf/queue_runner.proto", "protobuf/rewriter_config.proto", "protobuf/tensor_bundle.proto", "protobuf/saver.proto", + "protobuf/verifier_config.proto", "util/event.proto", "util/memmapped_file_system.proto", "util/saved_tensor_slice.proto", @@ -224,13 +240,15 @@ 
CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS # ones with individual proto_library targets. ADDITIONAL_CORE_PROTO_SRCS = [ "example/example_parser_configuration.proto", - "protobuf/checkpointable_object_graph.proto", + "protobuf/trackable_object_graph.proto", "protobuf/control_flow.proto", # TODO(ebrevdo): Re-enable once CriticalSection is in core. # "protobuf/critical_section.proto", "protobuf/meta_graph.proto", "protobuf/named_tensor.proto", "protobuf/saved_model.proto", + "protobuf/saved_object_graph.proto", + "protobuf/struct.proto", "protobuf/tensorflow_server.proto", "protobuf/transport_options.proto", "util/test_log.proto", @@ -381,15 +399,15 @@ cc_library( ":platform_port_hdrs", ":platform_port_internal_hdrs", ], - copts = tf_copts(), + copts = tf_copts() + tf_additional_numa_copts(), visibility = ["//tensorflow/core:__subpackages__"], deps = [ ":lib_platform", ":platform_base", - "//tensorflow/core/platform/default/build_config:port", "@com_google_absl//absl/base", + "//tensorflow/core/platform/default/build_config:port", "@snappy", - ], + ] + tf_additional_numa_deps(), ) filegroup( @@ -413,9 +431,8 @@ cc_library( name = "platform_protobuf", srcs = tf_platform_hdrs([ "protobuf.h", - ]) + tf_platform_srcs([ - "protobuf.cc", ]) + [ + "platform/protobuf.cc", "platform/protobuf_util.cc", "lib/core/status.h", ], @@ -434,6 +451,17 @@ cc_library( ], ) +cc_library( + name = "grpc_services", + srcs = [], + hdrs = [ + "platform/grpc_services.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = tf_grpc_service_all(), +) + cc_library( name = "human_readable_json", srcs = tf_platform_srcs(["human_readable_json.cc"]), @@ -452,10 +480,7 @@ cc_library( hdrs = ["platform/logger.h"], copts = tf_copts(), visibility = ["//visibility:public"], - deps = [ - ":lib_proto_parsing", - "@protobuf_archive//:protobuf", - ], + deps = [":lib_proto_parsing"], ) filegroup( @@ -506,6 +531,7 @@ cc_library( ":platform_port", ":platform_protobuf", 
"//tensorflow/core/platform/default/build_config:env", + "//tensorflow/core/platform/default/build_config:port", ], ) @@ -661,7 +687,7 @@ cc_library( name = "lib_proto_compiler", hdrs = [ "platform/protobuf_compiler.h", - ] + tf_additional_proto_compiler_hdrs(), + ], copts = tf_copts(), deps = tf_lib_proto_compiler_deps() + [ ":lib_proto_parsing", @@ -907,6 +933,7 @@ tf_cuda_library( "framework/tensor_slice.h", "framework/tensor_types.h", "framework/tensor_util.h", + "framework/thread_factory.h", "framework/tracking_allocator.h", "framework/type_index.h", "framework/type_traits.h", @@ -953,7 +980,10 @@ tf_cuda_library( "util/mkl_util.h", ]), visibility = ["//visibility:public"], - deps = [":framework_internal"], + deps = [ + ":framework_internal", + "@com_google_absl//absl/base", + ], ) cc_library( @@ -1046,13 +1076,13 @@ cc_library( "platform/default/integral_types.h", "platform/default/logging.h", "platform/default/mutex.h", - "platform/default/protobuf.h", "platform/default/thread_annotations.h", "platform/dynamic_annotations.h", "platform/macros.h", "platform/mutex.h", "platform/platform.h", "platform/prefetch.h", + "platform/protobuf.h", "platform/thread_annotations.h", "platform/types.h", "platform/cpu_info.h", @@ -1138,6 +1168,13 @@ tf_gen_op_libs( deps = [":protos_all_cc"], ) +tf_gen_op_libs( + op_lib_names = [ + "mkl_array_ops", + ], + deps = [":protos_all_cc"], +) + tf_gen_op_libs( op_lib_names = [ "audio_ops", @@ -1158,6 +1195,29 @@ tf_gen_op_libs( deps = [":lib"], ) +tf_gen_op_libs( + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + deps = [ + ":lib", + ":lib_proto_parsing", + ":protos_all_cc", + "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", + "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", 
+ "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", + ], +) + # And one for all user ops cc_library( name = "user_ops_op_lib", @@ -1274,10 +1334,23 @@ cc_library( ":state_ops_op_lib", ":stateless_random_ops_op_lib", ":string_ops_op_lib", + ":tpu_configuration_ops_op_lib", + ":tpu_cross_replica_ops_op_lib", + ":tpu_embedding_ops_op_lib", + ":tpu_functional_ops_op_lib", + ":tpu_heartbeat_ops_op_lib", + ":tpu_host_compute_ops_op_lib", + ":tpu_infeed_ops_op_lib", + ":tpu_outfeed_ops_op_lib", + ":tpu_ordinal_selector_ops_op_lib", + ":tpu_replication_ops_op_lib", ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", - ] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(), + ] + if_mkl([ + ":mkl_array_ops_op_lib", + ":mkl_nn_ops_op_lib", + ]) + tf_additional_cloud_op_deps(), alwayslink = 1, ) @@ -1379,8 +1452,9 @@ cc_library( # This includes implementations of all kernels built into TensorFlow. cc_library( name = "all_kernels_impl", - visibility = ["//visibility:private"], + visibility = ["//tensorflow/core:__subpackages__"], deps = [ + "//tensorflow/c/kernels:bitcast_op", "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:audio", "//tensorflow/core/kernels:batch_kernels", @@ -1449,6 +1523,7 @@ cc_library( "//tensorflow/core/kernels:mkl_identity_op", "//tensorflow/core/kernels:mkl_input_conversion_op", "//tensorflow/core/kernels:mkl_lrn_op", + "//tensorflow/core/kernels:mkl_requantize_ops", "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", @@ -1522,6 +1597,7 @@ cc_library( "framework/function_testlib.h", "framework/shape_inference_testutil.h", "framework/tensor_testutil.h", + "graph/benchmark_testlib.h", "graph/testlib.h", # TODO(josh11b): Drop this once users are depending on # kernels:ops_testutil instead. 
@@ -1537,6 +1613,7 @@ cc_library( ":framework_internal", ":lib", ":lib_internal", + ":ops", ":protos_all_cc", ":shape_inference_testutil", ":tensor_testutil", @@ -1561,8 +1638,8 @@ cc_library( srcs = ["common_runtime/testlib_ops.cc"], linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", + ":framework", + ":lib", ], alwayslink = 1, ) @@ -1637,6 +1714,7 @@ filegroup( "platform/**/logger.cc", "platform/default/test_benchmark.*", "platform/cuda.h", + "platform/rocm.h", "platform/google/**/*", "platform/hadoop/**/*", "platform/gif.h", @@ -1663,6 +1741,7 @@ filegroup( "//tensorflow/core/kernels:android_srcs", "//tensorflow/core/util/ctc:android_srcs", "//tensorflow/core/util/tensor_bundle:android_srcs", + "//tensorflow/c:srcs", ] + glob( [ "common_runtime/**/*.h", @@ -1767,6 +1846,29 @@ cc_library( ], ) +cc_library( + name = "emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime", + srcs = if_emscripten(["//tensorflow/core:mobile_srcs_no_runtime"]), + copts = ["-DSUPPORT_SELECTIVE_REGISTRATION"] + tf_opts_nortti_if_emscripten(), + defines = ["TENSORFLOW_LITE_PROTOS"], + linkopts = ["-lz"], + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + ":emscripten_proto_lib_no_rtti_lite_runtime", + ":mobile_additional_lib_deps", + ":stats_calculator_portable", + "//third_party/eigen3", + "@double_conversion//:double-conversion", + "@nsync//:nsync_cpp", + "@zlib_archive//:zlib", + ], + alwayslink = 1, +) + # Native library support for iOS applications. 
# # bazel build --config=ios_x86_64 \ @@ -1860,6 +1962,7 @@ filegroup( "**/*testutil*", "**/*testlib*", "**/*main.cc", + "**/tpu_*", ], ), visibility = ["//visibility:public"], @@ -1962,6 +2065,14 @@ cc_library( ], ) +cc_library( + name = "rocm", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/platform/default/build_config:rocm", + ], +) + # ----------------------------------------------------------------------------- # Clif-related proto libraries. @@ -2066,6 +2177,29 @@ tf_pyclif_proto_library( # ----------------------------------------------------------------------------- # Internal targets +tf_proto_library( + name = "autotuning_proto", + srcs = ["protobuf/autotuning.proto"], + cc_api_version = 2, + default_header = True, + provide_cc_alias = True, + visibility = [ + "//tensorflow:internal", + ], +) + +tf_proto_library( + name = "conv_autotuning_proto", + srcs = ["protobuf/conv_autotuning.proto"], + cc_api_version = 2, + default_header = True, + protodeps = tf_additional_all_protos(), + provide_cc_alias = True, + visibility = [ + "//tensorflow:internal", + ], +) + tf_proto_library_cc( name = "worker_proto", srcs = ["protobuf/worker.proto"], @@ -2135,6 +2269,7 @@ LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob( "platform/jpeg.h", "platform/png.h", "platform/**/cuda.h", + "platform/**/rocm.h", "platform/**/stream_executor.h", ], ) @@ -2183,11 +2318,14 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ ] # Replicated for lib_internal and lib_internal_impl. 
-LIB_INTERNAL_DEFINES = (tf_additional_lib_defines() + [ - "TF_USE_SNAPPY", - ] + tf_additional_verbs_lib_defines() + - tf_additional_mpi_lib_defines() + - tf_additional_gdr_lib_defines()) +LIB_INTERNAL_DEFINES = ( + tf_additional_lib_defines() + [ + "TF_USE_SNAPPY", + ] + tf_additional_verbs_lib_defines() + + tf_additional_mpi_lib_defines() + + tf_additional_gdr_lib_defines() + + tf_additional_numa_lib_defines() +) cc_library( name = "lib_internal", @@ -2220,7 +2358,6 @@ cc_library( "lib/**/*.cc", "platform/*.cc", "platform/profile_utils/**/*.cc", - ] + [ "framework/resource_handle.cc", "util/env_var.cc", ], @@ -2238,12 +2375,14 @@ cc_library( "platform/**/logging.cc", "platform/**/human_readable_json.cc", "platform/abi.cc", + "platform/protobuf.cc", ], ) + tf_additional_lib_srcs( exclude = [ "**/*test*", "platform/**/cuda.h", "platform/**/cuda_libdevice_path.cc", + "platform/**/rocm.h", "platform/**/stream_executor.h", "platform/**/env_time.cc", "platform/**/device_tracer.cc", @@ -2260,17 +2399,20 @@ cc_library( copts = tf_copts(), defines = LIB_INTERNAL_DEFINES, deps = tf_additional_lib_deps() + [ - ":lib_hash_crc32c_accelerate_internal", - ":lib_proto_parsing", - ":abi", - ":core_stringpiece", - "//third_party/eigen3", - "//tensorflow/core/platform/default/build_config:platformlib", - "@snappy", - "@zlib_archive//:zlib", - "@double_conversion//:double-conversion", - "@protobuf_archive//:protobuf", - ] + tf_protos_all_impl() + tf_protos_grappler_impl(), + ":lib_hash_crc32c_accelerate_internal", + ":lib_proto_parsing", + ":abi", + ":core_stringpiece", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//third_party/eigen3", + "//tensorflow/core/platform/default/build_config:platformlib", + "@snappy", + "@zlib_archive//:zlib", + "@double_conversion//:double-conversion", + "@protobuf_archive//:protobuf", + ] + tf_protos_all_impl() + tf_protos_grappler_impl() + + tf_additional_numa_deps(), ) # File compiled with extra flags to get cpu-specific 
acceleration. @@ -2360,12 +2502,6 @@ cc_library( cc_library( name = "tflite_portable_logging", - srcs = [ - ] + if_ios([ - "platform/default/logging.cc", - "platform/env_time.cc", - "platform/posix/env_time.cc", - ]), hdrs = [ "lib/bfloat16/bfloat16.h", "platform/default/integral_types.h", @@ -2374,10 +2510,11 @@ cc_library( "platform/macros.h", "platform/platform.h", "platform/types.h", - ] + if_windows(["platform/windows/integral_types.h"]) + if_ios(["platform/env_time.h"]), + ], copts = tf_copts(), linkopts = ["-ldl"], deps = [ + ":platform_base", "//tensorflow/core/platform/default/build_config:logging", ], ) @@ -2643,7 +2780,6 @@ tf_cuda_library( "example/**/*.cc", "framework/**/*.cc", "util/**/*.cc", - ] + [ "graph/edgeset.cc", "graph/graph.cc", "graph/graph_def_builder.cc", @@ -2695,6 +2831,7 @@ tf_cuda_library( ":stats_calculator_portable", ":version_lib", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "//tensorflow/core/platform/default/build_config:platformlib", @@ -2741,6 +2878,7 @@ tf_cuda_library( srcs = ["platform/stream_executor.h"], hdrs = [ "platform/cuda.h", + "platform/rocm.h", "platform/stream_executor.h", ], deps = [ @@ -2831,6 +2969,7 @@ tf_cuda_library( "//third_party/eigen3", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) @@ -2887,6 +3026,7 @@ tf_cuda_library( CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/allocator_retry.h", + "common_runtime/shared_counter.h", "common_runtime/base_collective_executor.h", "common_runtime/bfc_allocator.h", "common_runtime/hierarchical_tree_broadcaster.h", @@ -2911,6 +3051,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/lower_if_while.h", "common_runtime/lower_while_op.h", "common_runtime/memory_types.h", + "common_runtime/metrics.h", 
"common_runtime/mkl_cpu_allocator.h", "common_runtime/optimization_registry.h", "common_runtime/pending_counts.h", @@ -2922,6 +3063,8 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", "common_runtime/ring_reducer.h", + "common_runtime/ring_alg.h", + "common_runtime/ring_gatherer.h", "common_runtime/session_factory.h", "common_runtime/single_threaded_cpu_device.h", "common_runtime/stats_publisher_interface.h", @@ -2946,6 +3089,8 @@ tf_cuda_library( "common_runtime/collective_param_resolver_local.cc", "common_runtime/collective_rma_local.cc", "common_runtime/collective_util.cc", + "common_runtime/colocation_graph.cc", + "common_runtime/colocation_graph.h", "common_runtime/constant_folding.cc", "common_runtime/copy_tensor.cc", "common_runtime/costmodel_manager.cc", @@ -2966,6 +3111,7 @@ tf_cuda_library( "common_runtime/lower_if_while.cc", "common_runtime/lower_while_op.cc", "common_runtime/memory_types.cc", + "common_runtime/metrics.cc", "common_runtime/mkl_cpu_allocator.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", @@ -2978,11 +3124,14 @@ tf_cuda_library( "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", + "common_runtime/ring_alg.cc", + "common_runtime/ring_gatherer.cc", "common_runtime/ring_reducer.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", "common_runtime/session_options.cc", "common_runtime/session_state.cc", + "common_runtime/single_threaded_cpu_device.cc", "common_runtime/stats_publisher_interface.cc", "common_runtime/step_stats_collector.cc", "common_runtime/threadpool_device.cc", @@ -3005,6 +3154,7 @@ tf_cuda_library( ":lib_internal", ":proto_text", ":protos_all_cc", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "//third_party/eigen3", @@ -3066,15 +3216,6 @@ cc_library( deps = 
[":lib_internal"], ) -tf_cuda_library( - name = "metrics", - srcs = ["common_runtime/metrics.cc"], - hdrs = ["common_runtime/metrics.h"], - deps = [ - ":lib", - ], -) - tf_cuda_library( name = "direct_session_internal", srcs = ["common_runtime/direct_session.cc"], @@ -3091,7 +3232,6 @@ tf_cuda_library( ":graph", ":lib", ":lib_internal", - ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", @@ -3175,7 +3315,7 @@ cc_library( ) GPU_RUNTIME_HEADERS = [ - "common_runtime/gpu/cuda_host_allocator.h", + "common_runtime/gpu/gpu_host_allocator.h", "common_runtime/gpu/gpu_bfc_allocator.h", "common_runtime/gpu/gpu_cudamalloc_allocator.h", "common_runtime/gpu/gpu_debug_allocator.h", @@ -3458,6 +3598,7 @@ tf_cc_tests( "platform/vmodule_benchmark_test.cc", ], deps = [ + ":core_cpu_internal", ":lib", ":lib_internal", ":lib_test_internal", @@ -3502,6 +3643,7 @@ tf_cc_test( name = "platform_strings_test", size = "small", srcs = ["platform/platform_strings_test.cc"], + features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs deps = [ ":lib", ":platform_strings", @@ -3667,6 +3809,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "lib_strings_proto_serialization_test", + srcs = ["lib/strings/proto_serialization_test.cc"], + deps = [ + ":lib", + ":lib_internal", + ":lib_test_internal", + ":protos_all_cc", + ":test", + ":test_main", + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "lib_random_weighted_picker_test", size = "medium", @@ -3719,7 +3875,6 @@ tf_cc_tests( srcs = [ "common_runtime/buf_rendezvous_test.cc", "common_runtime/collective_executor_mgr_test.cc", - "common_runtime/collective_param_resolver_local_test.cc", "common_runtime/collective_rma_local_test.cc", "common_runtime/device_resolver_local_test.cc", "common_runtime/device_set_test.cc", @@ -3799,7 +3954,7 @@ tf_cc_tests( "util/work_sharder_test.cc", ], linkopts = select({ - "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//tensorflow:macos": 
["-headerpad_max_install_names"], "//conditions:default": [], }), linkstatic = tf_kernel_tests_linkstatic(), @@ -3835,10 +3990,11 @@ tf_cc_tests( name = "higher_level_tests_needing_kernels", size = "small", srcs = [ + "common_runtime/collective_param_resolver_local_test.cc", "graph/graph_constructor_test.cc", ], linkopts = select({ - "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], }), linkstatic = tf_kernel_tests_linkstatic(), @@ -3874,7 +4030,6 @@ tf_cc_test( "ops/cudnn_rnn_ops_test.cc", ], deps = [ - ":cudnn_rnn_ops", "//tensorflow/core", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -3934,6 +4089,35 @@ tf_cc_tests_gpu( ], ) +tf_cc_tests_gpu( + name = "ring_gatherer_test", + size = "medium", + srcs = [ + "common_runtime/ring_gatherer_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":all_kernels", + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":gpu_runtime", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":protos_test_cc", + ":test", + ":test_main", + ":testlib", + "@com_google_absl//absl/memory", + ], +) + tf_cc_tests_gpu( name = "hierarchical_tree_broadcaster_test", size = "medium", @@ -4472,7 +4656,7 @@ tf_cc_test( "//tensorflow/cc:scope", "//tensorflow/core/kernels:cwise_op", "//third_party/eigen3", - ], + ] + if_mkl([":mkl_array_ops_op_lib"]), ) tf_cc_test( @@ -4885,6 +5069,7 @@ tf_cc_tests( "//tensorflow/cc:client_session", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", ], ) @@ -4908,6 +5093,7 @@ tf_cc_tests( "//tensorflow/cc:client_session", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "@com_google_absl//absl/algorithm:container", ], ) @@ -5025,6 +5211,39 @@ transitive_hdrs( # 
----------------------------------------------------------------------------- # Google-internal targets go here (must be at the end). +load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library") + +genrule( + name = "emscripten_proto_config_lite_runtime", + outs = ["emscripten_proto_config_lite_runtime.asciipb"], + cmd = tf_genrule_cmd_append_to_srcs("optimize_mode:LITE_RUNTIME"), + visibility = ["//visibility:private"], +) + +# We are keeping the "android" version of tf_android_core_proto_headers. All it does is +# normalize CORE_PROTO_SRCS to generate valid output file names. +tf_portable_proto_library( + name = "emscripten_proto_lib_no_rtti_lite_runtime", + config = ":emscripten_proto_config_lite_runtime", + copts = tf_opts_nortti_if_emscripten(), + features = tf_features_nomodules_if_emscripten(), + header_outs = tf_android_core_proto_headers(CORE_PROTO_SRCS) + ["//google/protobuf/any.proto.h"], + link_full_protobuf = False, + prefix_dir = "emscripten_proto_no_rtti", + proto_deps = [ + ":protos_all_cc", + "@protobuf_archive//:protobuf", + ], + visibility = ["//visibility:public"], +) + +# There is currently no need for a full proto version of emscripten tf lib lite. +alias( + name = "emscripten_lib_lite_no_runtime", + actual = "//tensorflow/core:emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime", + visibility = ["//visibility:public"], +) + alias( name = "android_srcs_no_runtime", actual = ":mobile_srcs_no_runtime", diff --git a/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e585bae4a373c6d5afe217b74acd37caa0262023 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt @@ -0,0 +1,68 @@ +op { + graph_op_name: "AllToAll" + visibility: HIDDEN + in_arg { + name: "input" + description: <

+// +//
+func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Sinh", + Type: "ParallelDynamicStitch", Input: []tf.Input{ - x, + tf.OutputList(indices), tf.OutputList(data), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the sum along sparse segments of a tensor. +// Partitions `data` into `num_partitions` tensors using indices from `partitions`. // -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. +// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` +// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` +// are placed in `outputs[i]` in lexicographic order of `js`, and the first +// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. +// In detail, // -// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first -// dimension, selecting a subset of dimension 0, specified by `indices`. +// ```python +// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] +// +// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) +// ``` +// +// `data.shape` must start with `partitions.shape`. // // For example: // // ```python -// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) -// -// # Select two rows, one segment. -// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) -// # => [[0 0 0 0]] +// # Scalar partitions. +// partitions = 1 +// num_partitions = 2 +// data = [10, 20] +// outputs[0] = [] # Empty with shape [0, 2] +// outputs[1] = [[10, 20]] // -// # Select two rows, two segment. -// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) -// # => [[ 1 2 3 4] -// # [-1 -2 -3 -4]] +// # Vector partitions. 
+// partitions = [0, 0, 1, 1, 0] +// num_partitions = 2 +// data = [10, 20, 30, 40, 50] +// outputs[0] = [10, 20, 50] +// outputs[1] = [30, 40] +// ``` // -// # Select all rows, two segments. -// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) -// # => [[0 0 0 0] -// # [5 6 7 8]] +// See `dynamic_stitch` for an example on how to merge partitions back. // -// # Which is equivalent to: -// tf.segment_sum(c, tf.constant([0, 0, 1])) -// ``` +//
+// +//
// // Arguments: // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { +// partitions: Any shape. Indices in the range `[0, num_partitions)`. +// num_partitions: The number of partitions to output. +func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_partitions": num_partitions} opspec := tf.OpSpec{ - Type: "SparseSegmentSum", + Type: "DynamicPartition", Input: []tf.Input{ - data, indices, segment_ids, + data, partitions, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes natural logarithm of x element-wise. -// -// I.e., \\(y = \log_e x\\). -func Log(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "Log", - Input: []tf.Input{ - x, - }, + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("DynamicPartition", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return outputs } -// Rounds the values of a tensor to the nearest integer, element-wise. +// Produces a string handle for the given MultiDeviceIterator. // -// Rounds half to even. Also known as bankers rounding. If you want to round -// according to the current system rounding mode use std::cint. -func Round(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// multi_device_iterator: A MultiDeviceIterator resource. +// +// Returns A string representing the resource. 
+func MultiDeviceIteratorToStringHandle(scope *Scope, multi_device_iterator tf.Output) (string_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Round", + Type: "MultiDeviceIteratorToStringHandle", Input: []tf.Input{ - x, + multi_device_iterator, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes reciprocal of square root of x element-wise. +// Checks whether a tree has been initialized. // -// I.e., \\(y = 1 / \sqrt{x}\\). -func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// tree_handle: Handle to the tree. +// +// Returns Whether the tree is initialized. +func TensorForestTreeIsInitializedOp(scope *Scope, tree_handle tf.Output) (is_initialized tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Rsqrt", + Type: "TensorForestTreeIsInitializedOp", Input: []tf.Input{ - x, + tree_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Get the value of the tensor specified by its handle. +// Gets next element for the provided shard number. // // Arguments: -// handle: The handle for a tensor stored in the session state. -// dtype: The type of the output value. +// multi_device_iterator: A MultiDeviceIterator resource. +// shard_num: Integer representing which shard to fetch data for. +// incarnation_id: Which incarnation of the MultiDeviceIterator is running. +// output_types: The type list for the return values. +// output_shapes: The list of shapes being produced. // -// Returns The tensor for the given handle. -func GetSessionTensor(scope *Scope, handle tf.Output, dtype tf.DataType) (value tf.Output) { +// Returns Result of the get_next on the dataset. 
+func MultiDeviceIteratorGetNextFromShard(scope *Scope, multi_device_iterator tf.Output, shard_num tf.Output, incarnation_id tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "GetSessionTensor", + Type: "MultiDeviceIteratorGetNextFromShard", Input: []tf.Input{ - handle, + multi_device_iterator, shard_num, incarnation_id, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("MultiDeviceIteratorGetNextFromShard", err) + return + } + return components } -// Computes the gradient for the sqrt of `x` wrt its input. +// Initializes the multi device iterator with the given dataset. // -// Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` -// is the corresponding input gradient. -func SqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { +// Arguments: +// dataset: Dataset to be iterated upon. +// multi_device_iterator: A MultiDeviceIteratorResource. +// max_buffer_size: The maximum size of the host side per device buffer to keep. +// +// Returns An int64 indicating which incarnation of the MultiDeviceIterator +// is running. +func MultiDeviceIteratorInit(scope *Scope, dataset tf.Output, multi_device_iterator tf.Output, max_buffer_size tf.Output) (incarnation_id tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SqrtGrad", + Type: "MultiDeviceIteratorInit", Input: []tf.Input{ - y, dy, + dataset, multi_device_iterator, max_buffer_size, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MatrixInverseAttr is an optional argument to MatrixInverse. 
-type MatrixInverseAttr func(optionalAttr) - -// MatrixInverseAdjoint sets the optional adjoint attribute to value. -// If not specified, defaults to false -func MatrixInverseAdjoint(value bool) MatrixInverseAttr { - return func(m optionalAttr) { - m["adjoint"] = value - } -} - -// Computes the inverse of one or more square invertible matrices or their +// Copy a tensor setting everything outside a central band in each innermost matrix // -// adjoints (conjugate transposes). +// to zero. // -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor of the same shape as the input -// containing the inverse for all input submatrices `[..., :, :]`. +// The `band` part is computed as follows: +// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a +// tensor with the same shape where // -// The op uses LU decomposition with partial pivoting to compute the inverses. +// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. // -// If a matrix is not invertible there is no guarantee what the op does. It -// may detect the condition and raise an exception or it may simply return a -// garbage result. +// The indicator function // -// Arguments: -// input: Shape is `[..., M, M]`. +// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && +// (num_upper < 0 || (n-m) <= num_upper)`. // -// Returns Shape is `[..., M, M]`. 
+// For example: // -// @compatibility(numpy) -// Equivalent to np.linalg.inv -// @end_compatibility -func MatrixInverse(scope *Scope, input tf.Output, optional ...MatrixInverseAttr) (output tf.Output) { +// ``` +// # if 'input' is [[ 0, 1, 2, 3] +// [-1, 0, 1, 2] +// [-2, -1, 0, 1] +// [-3, -2, -1, 0]], +// +// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] +// [-1, 0, 1, 2] +// [ 0, -1, 0, 1] +// [ 0, 0, -1, 0]], +// +// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] +// [-1, 0, 1, 0] +// [-2, -1, 0, 1] +// [ 0, -2, -1, 0]] +// ``` +// +// Useful special cases: +// +// ``` +// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part. +// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part. +// tf.matrix_band_part(input, 0, 0) ==> Diagonal. +// ``` +// +// Arguments: +// input: Rank `k` tensor. +// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire +// lower triangle. +// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep +// entire upper triangle. +// +// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor. +func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MatrixInverse", + Type: "MatrixBandPart", Input: []tf.Input{ - input, + input, num_lower, num_upper, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns x + y element-wise. -// -// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Gets the next output from the given iterator as an Optional variant. 
+func IteratorGetNextAsOptional(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (optional tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Add", + Type: "IteratorGetNextAsOptional", Input: []tf.Input{ - x, y, + iterator, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the derivative of a Gamma random sample w.r.t. `alpha`. -func RandomGammaGrad(scope *Scope, alpha tf.Output, sample tf.Output) (output tf.Output) { +// Returns the value stored in an Optional variant or raises an error if none exists. +func OptionalGetValue(scope *Scope, optional tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "RandomGammaGrad", + Type: "OptionalGetValue", Input: []tf.Input{ - alpha, sample, + optional, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes square of x element-wise. -// -// I.e., \\(y = x * x = x^2\\). -func Square(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "Square", - Input: []tf.Input{ - x, - }, + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("OptionalGetValue", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return components } -// Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. 
-// -// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) -// ](http://arxiv.org/abs/1511.07289) -func Elu(scope *Scope, features tf.Output) (activations tf.Output) { +// Returns true if and only if the given Optional variant has a value. +func OptionalHasValue(scope *Scope, optional tf.Output) (has_value tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Elu", + Type: "OptionalHasValue", Input: []tf.Input{ - features, + optional, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the reciprocal of x element-wise. +// Deserializes a proto into the tree handle // -// I.e., \\(y = 1 / x\\). -func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// tree_handle: Handle to the tree resource to be restored. +// tree_config: Serialied proto string of the boosted_trees.Tree proto. +// +// Returns the created operation. +func TensorForestTreeDeserialize(scope *Scope, tree_handle tf.Output, tree_config tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Reciprocal", + Type: "TensorForestTreeDeserialize", Input: []tf.Input{ - x, + tree_handle, tree_config, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Returns a batched matrix tensor with new batched diagonal values. -// -// Given `input` and `diagonal`, this operation returns a tensor with the -// same shape and values as `input`, except for the main diagonal of the -// innermost matrices. These will be overwritten by the values in `diagonal`. -// -// The output is computed as follows: -// -// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has -// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a -// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: -// -// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. 
-// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. -// -// Arguments: -// input: Rank `k+1`, where `k >= 1`. -// diagonal: Rank `k`, where `k >= 1`. -// -// Returns Rank `k+1`, with `output.shape = input.shape`. -func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { +// Constructs an Optional variant from a tuple of tensors. +func OptionalFromValue(scope *Scope, components []tf.Output) (optional tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixSetDiag", + Type: "OptionalFromValue", Input: []tf.Input{ - input, diagonal, + tf.OutputList(components), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the element-wise max of two SparseTensors. +// Creates a dataset by applying optimizations to `input_dataset`. // -// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// Creates a dataset by applying optimizations to `input_dataset`. // // Arguments: -// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, in the canonical lexicographic ordering. -// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. -// a_shape: 1-D. Shape of the input SparseTensor. -// b_indices: counterpart to `a_indices` for the other operand. -// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. -// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. +// input_dataset: A variant tensor representing the input dataset. +// optimizations: A `tf.string` vector `tf.Tensor` identifying optimizations to use. // -// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. 
-func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { +// +func OptimizeDataset(scope *Scope, input_dataset tf.Output, optimizations tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "SparseSparseMaximum", + Type: "OptimizeDataset", Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, + input_dataset, optimizations, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Computes the reciprocal of x element-wise. +// Converts the given variant tensor to an iterator and stores it in the given resource. // -// I.e., \\(y = 1 / x\\). -func Inv(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// resource_handle: A handle to an iterator resource. +// serialized: A variant tensor storing the state of the iterator contained in the +// resource. +// +// Returns the created operation. +func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Inv", + Type: "DeserializeIterator", Input: []tf.Input{ - x, + resource_handle, serialized, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Creates a dataset that batches input elements into a SparseTensor. +// Gather slices from `params` axis `axis` according to `indices`. // -// Arguments: -// input_dataset: A handle to an input dataset. Must have a single component. -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. 
-// row_shape: A vector representing the dense shape of each row in the produced -// SparseTensor. The shape may be partially specified, using `-1` to indicate -// that a particular dimension should use the maximum size of all batch elements. +// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +// Produces an output tensor with shape `params.shape[:axis] + indices.shape + +// params.shape[axis + 1:]` where: // +// ```python +// # Scalar indices (output is rank(params) - 1). +// output[a_0, ..., a_n, b_0, ..., b_n] = +// params[a_0, ..., a_n, indices, b_0, ..., b_n] // -func ExperimentalDenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalDenseToSparseBatchDataset", - Input: []tf.Input{ - input_dataset, batch_size, row_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ComplexAbsAttr is an optional argument to ComplexAbs. -type ComplexAbsAttr func(optionalAttr) - -// ComplexAbsTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func ComplexAbsTout(value tf.DataType) ComplexAbsAttr { - return func(m optionalAttr) { - m["Tout"] = value - } -} - -// Computes the complex absolute value of a tensor. +// # Vector indices (output is rank(params)). +// output[a_0, ..., a_n, i, b_0, ..., b_n] = +// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] // -// Given a tensor `x` of complex numbers, this operation returns a tensor of type -// `float` or `double` that is the absolute value of each element in `x`. All -// elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute -// value is computed as \\( \sqrt{a^2 + b^2}\\). 
-func ComplexAbs(scope *Scope, x tf.Output, optional ...ComplexAbsAttr) (y tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ComplexAbs", - Input: []tf.Input{ - x, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the truth value of x AND y element-wise. +// # Higher rank indices (output is rank(params) + rank(indices) - 1). +// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = +// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] +// ``` // -// *NOTE*: `LogicalAnd` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LogicalAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +//
+// +//
+// +// Note that on CPU, if an out of bound index is found, an error is returned. +// On GPU, if an out of bound index is found, a 0 is stored in the +// corresponding output value. +// +// See also `tf.batch_gather` and `tf.gather_nd`. +// +// Arguments: +// params: The tensor from which to gather values. Must be at least rank +// `axis + 1`. +// indices: Index tensor. Must be in range `[0, params.shape[axis])`. +// axis: The axis in `params` to gather `indices` from. Defaults to the first +// dimension. Supports negative indexes. +// +// Returns Values from `params` gathered from indices given by `indices`, with +// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`. +func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LogicalAnd", + Type: "GatherV2", Input: []tf.Input{ - x, y, + params, indices, axis, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// CastAttr is an optional argument to Cast. -type CastAttr func(optionalAttr) - -// CastTruncate sets the optional Truncate attribute to value. -// If not specified, defaults to false -func CastTruncate(value bool) CastAttr { - return func(m optionalAttr) { - m["Truncate"] = value - } -} - -// Cast x of type SrcT to y of DstT. -func Cast(scope *Scope, x tf.Output, DstT tf.DataType, optional ...CastAttr) (y tf.Output) { +// Converts the given `resource_handle` representing an iterator to a variant tensor. +// +// Arguments: +// resource_handle: A handle to an iterator resource. +// +// Returns A variant tensor storing the state of the iterator contained in the +// resource. 
+func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"DstT": DstT} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Cast", + Type: "SerializeIterator", Input: []tf.Input{ - x, + resource_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -8264,8 +7789,9 @@ type RegexReplaceAttr func(optionalAttr) // RegexReplaceReplaceGlobal sets the optional replace_global attribute to value. // -// value: If True, the replacement is global, otherwise the replacement -// is done only on the first match. +// value: If True, the replacement is global (that is, all matches of the `pattern` regular +// expression in each input string are rewritten), otherwise the `rewrite` +// substitution is only made for the first `pattern` match. // If not specified, defaults to true func RegexReplaceReplaceGlobal(value bool) RegexReplaceAttr { return func(m optionalAttr) { @@ -8273,16 +7799,18 @@ func RegexReplaceReplaceGlobal(value bool) RegexReplaceAttr { } } -// Replaces the match of pattern in input with rewrite. +// Replaces matches of the `pattern` regular expression in `input` with the +// replacement string provided in `rewrite`. // // It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) // // Arguments: // input: The text to be processed. -// pattern: The regular expression to match the input. -// rewrite: The rewrite to be applied to the matched expression. +// pattern: The regular expression to be matched in the `input` strings. +// rewrite: The rewrite string to be substituted for the `pattern` expression where it is +// matched in the `input` strings. // -// Returns The text after applying pattern and rewrite. +// Returns The text after applying pattern match and rewrite substitution. 
func RegexReplace(scope *Scope, input tf.Output, pattern tf.Output, rewrite tf.Output, optional ...RegexReplaceAttr) (output tf.Output) { if scope.Err() != nil { return @@ -8740,24 +8268,6 @@ func LogSoftmax(scope *Scope, logits tf.Output) (logsoftmax tf.Output) { return op.Output(0) } -// Returns the truth value of (x <= y) element-wise. -// -// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LessEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LessEqual", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes softmax activations. // // For each batch `i` and class `j` we have @@ -8827,85 +8337,33 @@ func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (ima return op.Output(0) } -// BatchMatMulAttr is an optional argument to BatchMatMul. -type BatchMatMulAttr func(optionalAttr) - -// BatchMatMulAdjX sets the optional adj_x attribute to value. -// -// value: If `True`, adjoint the slices of `x`. Defaults to `False`. -// If not specified, defaults to false -func BatchMatMulAdjX(value bool) BatchMatMulAttr { - return func(m optionalAttr) { - m["adj_x"] = value - } -} - -// BatchMatMulAdjY sets the optional adj_y attribute to value. -// -// value: If `True`, adjoint the slices of `y`. Defaults to `False`. -// If not specified, defaults to false -func BatchMatMulAdjY(value bool) BatchMatMulAttr { - return func(m optionalAttr) { - m["adj_y"] = value - } -} - -// Multiplies slices of two tensors in batches. -// -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. 
Each of the -// individual slices can optionally be adjointed (to adjoint a matrix -// means to transpose and conjugate it) before multiplication by setting -// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. -// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if adj_x else r_x -// c_o = r_y if adj_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -// -// Arguments: -// x: 2-D or higher with shape `[..., r_x, c_x]`. -// y: 2-D or higher with shape `[..., r_y, c_y]`. +// Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. // -// Returns 3-D or higher with shape `[..., r_o, c_o]` -func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMulAttr) (output tf.Output) { +// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) +// ](http://arxiv.org/abs/1511.07289) +func Elu(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "BatchMatMul", + Type: "Elu", Input: []tf.Input{ - x, y, + features, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns which elements of x are NaN. +// Computes square of x element-wise. // -// @compatibility(numpy) -// Equivalent to np.isnan -// @end_compatibility -func IsNan(scope *Scope, x tf.Output) (y tf.Output) { +// I.e., \\(y = x * x = x^2\\). +func Square(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IsNan", + Type: "Square", Input: []tf.Input{ x, }, @@ -8914,125 +8372,6 @@ func IsNan(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Identity op for gradient debugging. 
-// -// This op is hidden from public in Python. It is used by TensorFlow Debugger to -// register gradient tensors for gradient debugging. -// This op operates on non-reference-type tensors. -func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DebugGradientIdentity", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. -type ResourceSparseApplyAdadeltaAttr func(optionalAttr) - -// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// var: Should be from a Variable(). -// -// Arguments: -// -// accum: Should be from a Variable(). -// accum_update: : Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// -// Returns the created operation. 
-func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdadelta", - Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Checks whether a tree has been initialized. -// -// Arguments: -// tree_handle: Handle to the tree. -// -// Returns Whether the tree is initialized. -func TensorForestTreeIsInitializedOp(scope *Scope, tree_handle tf.Output) (is_initialized tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorForestTreeIsInitializedOp", - Input: []tf.Input{ - tree_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Gets next element for the provided shard number. -// -// Arguments: -// multi_device_iterator: A MultiDeviceIterator resource. -// shard_num: Integer representing which shard to fetch data for. -// incarnation_id: Which incarnation of the MultiDeviceIterator is running. -// output_types: The type list for the return values. -// output_shapes: The list of shapes being produced. -// -// Returns Result of the get_next on the dataset. 
-func MultiDeviceIteratorGetNextFromShard(scope *Scope, multi_device_iterator tf.Output, shard_num tf.Output, incarnation_id tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "MultiDeviceIteratorGetNextFromShard", - Input: []tf.Input{ - multi_device_iterator, shard_num, incarnation_id, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("MultiDeviceIteratorGetNextFromShard", err) - return - } - return components -} - // LeakyReluGradAttr is an optional argument to LeakyReluGrad. type LeakyReluGradAttr func(optionalAttr) @@ -9071,54 +8410,6 @@ func LeakyReluGrad(scope *Scope, gradients tf.Output, features tf.Output, option return op.Output(0) } -// Deprecated. Use TensorArrayGradV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 -func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArrayWriteV2", - Input: []tf.Input{ - handle, index, value, flow_in, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LeakyReluAttr is an optional argument to LeakyRelu. -type LeakyReluAttr func(optionalAttr) - -// LeakyReluAlpha sets the optional alpha attribute to value. -// If not specified, defaults to 0.2 -func LeakyReluAlpha(value float32) LeakyReluAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// Computes rectified linear: `max(features, features * alpha)`. 
-func LeakyRelu(scope *Scope, features tf.Output, optional ...LeakyReluAttr) (activations tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LeakyRelu", - Input: []tf.Input{ - features, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes rectified linear 6: `min(max(features, 0), 6)`. func Relu6(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { @@ -9232,7 +8523,7 @@ func SdcaOptimizerV2(scope *Scope, sparse_example_indices []tf.Output, sparse_fe // Computes the minimum along segments of a tensor. // // Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation) +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) // for an explanation of segments. // // This operator is similar to the unsorted segment sum operator found @@ -9246,6 +8537,15 @@ func SdcaOptimizerV2(scope *Scope, sparse_example_indices []tf.Output, sparse_fe // possible value for the specific numeric type, // `output[i] = numeric_limits::max()`. // +// For example: +// +// ``` python +// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) +// tf.unsorted_segment_min(c, tf.constant([0, 1, 0]), num_segments=2) +// # ==> [[ 1, 2, 2, 1], +// # [5, 6, 7, 8]] +// ``` +// // If the given segment ID `i` is negative, then the corresponding value is // dropped, and will not be included in the result. // @@ -9293,28 +8593,32 @@ func ReluGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops return op.Output(0) } -// Computes the gradient of morphological 2-D dilation with respect to the input. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. 
-// strides: 1-D of length 4. The stride of the sliding window for each dimension of -// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: 1-D of length 4. The input stride for atrous morphological dilation. -// Must be: `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. +// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. +type TensorArrayGatherV2Attr func(optionalAttr) + +// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to +func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// Deprecated. Use TensorArrayGatherV3 // -// Returns 4-D with shape `[batch, in_height, in_width, depth]`. -func Dilation2DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (in_backprop tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 +func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Dilation2DBackpropInput", + Type: "TensorArrayGatherV2", Input: []tf.Input{ - input, filter, out_backprop, + handle, indices, flow_in, }, Attrs: attrs, } @@ -9322,52 +8626,42 @@ func Dilation2DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, ou return op.Output(0) } -// Compute the polygamma function \\(\psi^{(n)}(x)\\). -// -// The polygamma function is defined as: -// -// -// \\(\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)\\) +// Returns the truth value of (x == y) element-wise. 
// -// where \\(\psi(x)\\) is the digamma function. -func Polygamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { +// *NOTE*: `Equal` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Polygamma", + Type: "Equal", Input: []tf.Input{ - a, x, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes second-order gradients of the maxpooling function. +// Compute the polygamma function \\(\psi^{(n)}(x)\\). // -// Arguments: -// input: The original input. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the -// input of `max_pool`. -// argmax: The indices of the maximum values chosen for each output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// The polygamma function is defined as: // -// Returns Gradients of gradients w.r.t. the input of `max_pool`. -func MaxPoolGradGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { +// +// \\(\psi^{(a)}(x) = \frac{d^a}{dx^a} \psi(x)\\) +// +// where \\(\psi(x)\\) is the digamma function. +// The polygamma function is defined only for non-negative integer orders \\a\\. 
+func Polygamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "MaxPoolGradGradWithArgmax", + Type: "Polygamma", Input: []tf.Input{ - input, grad, argmax, + a, x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -9421,6 +8715,19 @@ func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output return op.Output(0) } +// MaxPoolGradWithArgmaxAttr is an optional argument to MaxPoolGradWithArgmax. +type MaxPoolGradWithArgmaxAttr func(optionalAttr) + +// MaxPoolGradWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. +// +// value: Whether to include batch dimension in flattened index of `argmax`. +// If not specified, defaults to false +func MaxPoolGradWithArgmaxIncludeBatchInIndex(value bool) MaxPoolGradWithArgmaxAttr { + return func(m optionalAttr) { + m["include_batch_in_index"] = value + } +} + // Computes gradients of the maxpooling function. // // Arguments: @@ -9434,11 +8741,14 @@ func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output // padding: The type of padding algorithm to use. // // Returns Gradients w.r.t. the input of `max_pool`. 
-func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { +func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradWithArgmaxAttr) (output tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ Type: "MaxPoolGradWithArgmax", Input: []tf.Input{ @@ -9495,6 +8805,21 @@ func MutexV2(scope *Scope, optional ...MutexV2Attr) (resource tf.Output) { return op.Output(0) } +// Connects N inputs to an N-way replicated TPU computation. +func TPUReplicatedInput(scope *Scope, inputs []tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TPUReplicatedInput", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // AvgPool3DAttr is an optional argument to AvgPool3D. type AvgPool3DAttr func(optionalAttr) @@ -9542,98 +8867,6 @@ func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, pa return op.Output(0) } -// Returns element-wise remainder of division. This emulates C semantics in that -// -// the result here is consistent with a truncating divide. E.g. -// `tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`. -// -// *NOTE*: `Mod` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Mod", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes offsets of concat inputs within its output. 
-// -// For example: -// -// ``` -// # 'x' is [2, 2, 7] -// # 'y' is [2, 3, 7] -// # 'z' is [2, 5, 7] -// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] -// ``` -// -// This is typically used by gradient computations for a concat operation. -// -// Arguments: -// concat_dim: The dimension along which to concatenate. -// shape: The `N` int32 vectors representing shape of tensors being concatenated. -// -// Returns The `N` int32 vectors representing the starting offset -// of input tensors within the concatenated output. -func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConcatOffset", - Input: []tf.Input{ - concat_dim, tf.OutputList(shape), - }, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { - scope.UpdateErr("ConcatOffset", err) - return - } - return offset -} - -// Compute the lower regularized incomplete Gamma function `P(a, x)`. -// -// The lower regularized incomplete Gamma function is defined as: -// -// -// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) -// -// where -// -// \\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) -// -// is the lower incomplete Gamma function. -// -// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete -// Gamma function. -func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Igamma", - Input: []tf.Input{ - a, x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // DepthToSpaceAttr is an optional argument to DepthToSpace. type DepthToSpaceAttr func(optionalAttr) @@ -9822,37 +9055,20 @@ func Conv3DBackpropInputV2(scope *Scope, input_sizes tf.Output, filter tf.Output return op.Output(0) } -// Computes square root of x element-wise. 
-// -// I.e., \\(y = \sqrt{x} = x^{1/2}\\). -func Sqrt(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sqrt", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Conv3DBackpropFilterAttr is an optional argument to Conv3DBackpropFilter. -type Conv3DBackpropFilterAttr func(optionalAttr) +// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. +type Conv3DBackpropInputAttr func(optionalAttr) -// Conv3DBackpropFilterDilations sets the optional dilations attribute to value. +// Conv3DBackpropInputDilations sets the optional dilations attribute to value. // If not specified, defaults to -func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { +func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value } } -// Computes the gradients of 3-D convolution with respect to the filter. +// Computes the gradients of 3-D convolution with respect to the input. // -// DEPRECATED at GraphDef version 10: Use Conv3DBackpropFilterV2 +// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 // // Arguments: // input: Shape `[batch, depth, rows, cols, in_channels]`. @@ -9863,7 +9079,7 @@ func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { // strides: 1-D tensor of length 5. The stride of the sliding window for each // dimension of `input`. Must have `strides[0] = strides[4] = 1`. // padding: The type of padding algorithm to use. 
-func Conv3DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterAttr) (output tf.Output) { +func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -9872,7 +9088,7 @@ func Conv3DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_b a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropFilter", + Type: "Conv3DBackpropInput", Input: []tf.Input{ input, filter, out_backprop, }, @@ -9882,24 +9098,6 @@ func Conv3DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_b return op.Output(0) } -// Computes the gradient for the rsqrt of `x` wrt its input. -// -// Specifically, `grad = dy * -0.5 * y^3`, where `y = rsqrt(x)`, and `dy` -// is the corresponding input gradient. -func RsqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RsqrtGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. type DepthwiseConv2dNativeAttr func(optionalAttr) @@ -9977,78 +9175,6 @@ func DepthwiseConv2dNative(scope *Scope, input tf.Output, filter tf.Output, stri return op.Output(0) } -// MaxPoolGradV2Attr is an optional argument to MaxPoolGradV2. -type MaxPoolGradV2Attr func(optionalAttr) - -// MaxPoolGradV2DataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. 
-// If not specified, defaults to "NHWC" -func MaxPoolGradV2DataFormat(value string) MaxPoolGradV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients w.r.t. the output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients w.r.t. the input to `max_pool`. -func MaxPoolGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradV2Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradV2", - Input: []tf.Input{ - orig_input, orig_output, grad, ksize, strides, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Restore a reader to a previously saved state. -// -// Not all Readers support being restored, so this can produce an -// Unimplemented error. -// -// Arguments: -// reader_handle: Handle to a Reader. -// state: Result of a ReaderSerializeState of a Reader with type -// matching reader_handle. -// -// Returns the created operation. -func ReaderRestoreStateV2(scope *Scope, reader_handle tf.Output, state tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderRestoreStateV2", - Input: []tf.Input{ - reader_handle, state, - }, - } - return scope.AddOperation(opspec) -} - // MaxPoolGradAttr is an optional argument to MaxPoolGrad. 
type MaxPoolGradAttr func(optionalAttr) @@ -10180,6 +9306,236 @@ func CropAndResize(scope *Scope, image tf.Output, boxes tf.Output, box_ind tf.Ou return op.Output(0) } +// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter. +type Conv2DBackpropFilterAttr func(optionalAttr) + +// Conv2DBackpropFilterUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DBackpropFilterUseCudnnOnGpu(value bool) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["use_cudnn_on_gpu"] = value + } +} + +// Conv2DBackpropFilterExplicitPaddings sets the optional explicit_paddings attribute to value. +// +// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith +// dimension, the amount of padding inserted before and after the dimension is +// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If +// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. +// If not specified, defaults to <> +func Conv2DBackpropFilterExplicitPaddings(value []int64) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["explicit_paddings"] = value + } +} + +// Conv2DBackpropFilterDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv2DBackpropFilterDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. 
If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of convolution with respect to the filter. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, out_channels]` tensor. +// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. Must be in the same order as the dimension specified with +// format. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. +func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropFilterAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv2DBackpropFilter", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes Psi, the derivative of Lgamma (the log of the absolute value of +// +// `Gamma(x)`), element-wise. 
+func Digamma(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Digamma", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the number of work units this Reader has finished processing. +// +// Arguments: +// reader_handle: Handle to a Reader. +func ReaderNumWorkUnitsCompletedV2(scope *Scope, reader_handle tf.Output) (units_completed tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderNumWorkUnitsCompletedV2", + Input: []tf.Input{ + reader_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Conv2DAttr is an optional argument to Conv2D. +type Conv2DAttr func(optionalAttr) + +// Conv2DUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DUseCudnnOnGpu(value bool) Conv2DAttr { + return func(m optionalAttr) { + m["use_cudnn_on_gpu"] = value + } +} + +// Conv2DExplicitPaddings sets the optional explicit_paddings attribute to value. +// +// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith +// dimension, the amount of padding inserted before and after the dimension is +// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If +// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. +// If not specified, defaults to <> +func Conv2DExplicitPaddings(value []int64) Conv2DAttr { + return func(m optionalAttr) { + m["explicit_paddings"] = value + } +} + +// Conv2DDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. 
+// If not specified, defaults to "NHWC" +func Conv2DDataFormat(value string) Conv2DAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv2DDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv2DDilations(value []int64) Conv2DAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes a 2-D convolution given 4-D `input` and `filter` tensors. +// +// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +// and a filter / kernel tensor of shape +// `[filter_height, filter_width, in_channels, out_channels]`, this op +// performs the following: +// +// 1. Flattens the filter to a 2-D matrix with shape +// `[filter_height * filter_width * in_channels, output_channels]`. +// 2. Extracts image patches from the input tensor to form a *virtual* +// tensor of shape `[batch, out_height, out_width, +// filter_height * filter_width * in_channels]`. +// 3. For each patch, right-multiplies the filter matrix and the image patch +// vector. +// +// In detail, with the default NHWC format, +// +// output[b, i, j, k] = +// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * +// filter[di, dj, q, k] +// +// Must have `strides[0] = strides[3] = 1`. For the most common case of the same +// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// +// Arguments: +// input: A 4-D tensor. The dimension order is interpreted according to the value +// of `data_format`, see below for details. 
+// filter: A 4-D tensor of shape +// `[filter_height, filter_width, in_channels, out_channels]` +// strides: 1-D tensor of length 4. The stride of the sliding window for each +// dimension of `input`. The dimension order is determined by the value of +// `data_format`, see below for details. +// padding: The type of padding algorithm to use. +// +// Returns A 4-D tensor. The dimension order is determined by the value of +// `data_format`, see below for details. +func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv2DAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv2D", + Input: []tf.Input{ + input, filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Fills empty rows in the input 2-D `SparseTensor` with a default value. // // The input `SparseTensor` is represented via the tuple of inputs @@ -10243,36 +9599,66 @@ func SparseFillEmptyRows(scope *Scope, indices tf.Output, values tf.Output, dens return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Reduces `input` from `num_devices` using `reduction` to a single device. +// LoadTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingADAMParametersGradAccumDebug. +type LoadTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// Reduces `input` from `num_devices` using `reduction` to a single device. 
+// REQUIRES: value >= -1 +func LoadTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load ADAM embedding parameters with debug support. // -// The graph should be constructed so that all inputs have a valid device -// assignment, and the op itself is assigned one of these devices. +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // -// input: The input to the reduction. -// data: the value of the reduction across all `num_devices` devices. -// reduction: the reduction operation to perform. -func NcclReduce(scope *Scope, input []tf.Output, reduction string) (data tf.Output) { +// Arguments: +// parameters: Value of parameters used in the ADAM optimization algorithm. +// momenta: Value of momenta used in the ADAM optimization algorithm. +// velocities: Value of velocities used in the ADAM optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the ADAM optimization algorithm. +// +// +// +// Returns the created operation. 
+func LoadTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingADAMParametersGradAccumDebugAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"reduction": reduction} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "NcclReduce", + Type: "LoadTPUEmbeddingADAMParametersGradAccumDebug", Input: []tf.Input{ - tf.OutputList(input), + parameters, momenta, velocities, gradient_accumulators, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// BiasAddGradAttr is an optional argument to BiasAddGrad. -type BiasAddGradAttr func(optionalAttr) +// BiasAddAttr is an optional argument to BiasAdd. +type BiasAddAttr func(optionalAttr) -// BiasAddGradDataFormat sets the optional data_format attribute to value. +// BiasAddDataFormat sets the optional data_format attribute to value. // // value: Specify the data format of the input and output data. With the // default format "NHWC", the bias tensor will be added to the last dimension @@ -10282,23 +9668,23 @@ type BiasAddGradAttr func(optionalAttr) // The tensor will be added to "in_channels", the third-to-the-last // dimension. // If not specified, defaults to "NHWC" -func BiasAddGradDataFormat(value string) BiasAddGradAttr { +func BiasAddDataFormat(value string) BiasAddAttr { return func(m optionalAttr) { m["data_format"] = value } } -// The backward operation for "BiasAdd" on the "bias" tensor. +// Adds `bias` to `value`. // -// It accumulates all the values from out_backprop into the feature dimension. -// For NHWC data format, the feature dimension is the last. For NCHW data format, -// the feature dimension is the third-to-last. 
+// This is a special case of `tf.add` where `bias` is restricted to be 1-D. +// Broadcasting is supported, so `value` may have any number of dimensions. // // Arguments: -// out_backprop: Any number of dimensions. +// value: Any number of dimensions. +// bias: 1-D with size the last dimension of `value`. // -// Returns 1-D with size the feature dimension of `out_backprop`. -func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { +// Returns Broadcasted sum of `value` and `bias`. +func BiasAdd(scope *Scope, value tf.Output, bias tf.Output, optional ...BiasAddAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -10307,9 +9693,9 @@ func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAt a(attrs) } opspec := tf.OpSpec{ - Type: "BiasAddGrad", + Type: "BiasAdd", Input: []tf.Input{ - out_backprop, + value, bias, }, Attrs: attrs, } @@ -10317,13 +9703,151 @@ func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAt return op.Output(0) } -// Returns 0 if x == 0, and x / y otherwise, elementwise. -func Xdivy(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// SparseReduceSumSparseAttr is an optional argument to SparseReduceSumSparse. +type SparseReduceSumSparseAttr func(optionalAttr) + +// SparseReduceSumSparseKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceSumSparseKeepDims(value bool) SparseReduceSumSparseAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the sum of elements across dimensions of a SparseTensor. +// +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_sum()`. In contrast to SparseReduceSum, this Op returns a +// SparseTensor. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. 
Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Xdivy", + Type: "SparseReduceSumSparse", + Input: []tf.Input{ + input_indices, input_values, input_shape, reduction_axes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// LoadTPUEmbeddingStochasticGradientDescentParametersAttr is an optional argument to LoadTPUEmbeddingStochasticGradientDescentParameters. +type LoadTPUEmbeddingStochasticGradientDescentParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingStochasticGradientDescentParametersTableId sets the optional table_id attribute to value. 
+// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingStochasticGradientDescentParametersTableId(value int64) LoadTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingStochasticGradientDescentParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingStochasticGradientDescentParametersTableName(value string) LoadTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load SGD embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the stochastic gradient descent optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, parameters tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingStochasticGradientDescentParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingStochasticGradientDescentParameters", + Input: []tf.Input{ + parameters, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Selects the k nearest centers for each point. +// +// Rows of points are assumed to be input points. Rows of centers are assumed to be +// the list of candidate centers. For each point, the k centers that have least L2 +// distance to it are computed. 
+// +// Arguments: +// points: Matrix of shape (n, d). Rows are assumed to be input points. +// centers: Matrix of shape (m, d). Rows are assumed to be centers. +// k: Number of nearest centers to return for each point. If k is larger than m, then +// only m centers are returned. +// +// Returns Matrix of shape (n, min(m, k)). Each row contains the indices of the centers +// closest to the corresponding point, ordered by increasing distance.Matrix of shape (n, min(m, k)). Each row contains the squared L2 distance to the +// corresponding center in nearest_center_indices. +func NearestNeighbors(scope *Scope, points tf.Output, centers tf.Output, k tf.Output) (nearest_center_indices tf.Output, nearest_center_distances tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NearestNeighbors", + Input: []tf.Input{ + points, centers, k, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Returns x * y element-wise. +// +// *NOTE*: `Multiply` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Mul(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Mul", Input: []tf.Input{ x, y, }, @@ -10530,298 +10054,28 @@ func KmeansPlusPlusInitialization(scope *Scope, points tf.Output, num_to_sample return op.Output(0) } -// Shuffle dimensions of x according to a permutation. +// Transforms a Tensor into a serialized TensorProto proto. // -// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: -// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` -func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { +// Arguments: +// tensor: A Tensor of type `T`. +// +// Returns A serialized TensorProto proto of the input tensor. 
+func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Transpose", + Type: "SerializeTensor", Input: []tf.Input{ - x, perm, + tensor, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MinAttr is an optional argument to Min. -type MinAttr func(optionalAttr) - -// MinKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func MinKeepDims(value bool) MinAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the minimum of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Min", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the Bessel i1e function of `x` element-wise. -// -// Exponentially scaled modified Bessel function of order 0 defined as -// `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. -// -// This function is faster and numerically stabler than `bessel_i1(x)`. 
-func BesselI1e(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BesselI1e", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MapClearAttr is an optional argument to MapClear. -type MapClearAttr func(optionalAttr) - -// MapClearCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapClearCapacity(value int64) MapClearAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapClearMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapClearMemoryLimit(value int64) MapClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapClearContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapClearContainer(value string) MapClearAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapClearSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapClearSharedName(value string) MapClearAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes all elements in the underlying container. -// -// Returns the created operation. -func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MapClear", - - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// DecodeCSVAttr is an optional argument to DecodeCSV. -type DecodeCSVAttr func(optionalAttr) - -// DecodeCSVFieldDelim sets the optional field_delim attribute to value. -// -// value: char delimiter to separate fields in a record. 
-// If not specified, defaults to "," -func DecodeCSVFieldDelim(value string) DecodeCSVAttr { - return func(m optionalAttr) { - m["field_delim"] = value - } -} - -// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value. -// -// value: If false, treats double quotation marks as regular -// characters inside of the string fields (ignoring RFC 4180, Section 2, -// Bullet 5). -// If not specified, defaults to true -func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr { - return func(m optionalAttr) { - m["use_quote_delim"] = value - } -} - -// DecodeCSVNaValue sets the optional na_value attribute to value. -// -// value: Additional string to recognize as NA/NaN. -// If not specified, defaults to "" -func DecodeCSVNaValue(value string) DecodeCSVAttr { - return func(m optionalAttr) { - m["na_value"] = value - } -} - -// DecodeCSVSelectCols sets the optional select_cols attribute to value. -// If not specified, defaults to <> -func DecodeCSVSelectCols(value []int64) DecodeCSVAttr { - return func(m optionalAttr) { - m["select_cols"] = value - } -} - -// Convert CSV records to tensors. Each column maps to one tensor. -// -// RFC 4180 format is expected for the CSV records. -// (https://tools.ietf.org/html/rfc4180) -// Note that we allow leading and trailing spaces with int or float field. -// -// Arguments: -// records: Each string is a record/row in the csv and all records should have -// the same format. -// record_defaults: One tensor per column of the input record, with either a -// scalar default value for that column or an empty vector if the column is -// required. -// -// Returns Each tensor will have the same shape as records. 
-func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeCSV", - Input: []tf.Input{ - records, tf.OutputList(record_defaults), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("DecodeCSV", err) - return - } - return output -} - -// Convert JSON-encoded Example records to binary protocol buffer strings. -// -// This op translates a tensor containing Example records, encoded using -// the [standard JSON -// mapping](https://developers.google.com/protocol-buffers/docs/proto3#json), -// into a tensor containing the same records encoded as binary protocol -// buffers. The resulting tensor can then be fed to any of the other -// Example-parsing ops. -// -// Arguments: -// json_examples: Each string is a JSON object serialized according to the JSON -// mapping of the Example proto. -// -// Returns Each string is a binary Example protocol buffer corresponding -// to the respective element of `json_examples`. -func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DecodeJSONExample", - Input: []tf.Input{ - json_examples, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Transforms a Tensor into a serialized TensorProto proto. -// -// Arguments: -// tensor: A Tensor of type `T`. -// -// Returns A serialized TensorProto proto of the input tensor. 
-func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SerializeTensor", - Input: []tf.Input{ - tensor, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes acos of x element-wise. -func Acos(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Acos", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// UnbatchGradAttr is an optional argument to UnbatchGrad. -type UnbatchGradAttr func(optionalAttr) +// UnbatchGradAttr is an optional argument to UnbatchGrad. +type UnbatchGradAttr func(optionalAttr) // UnbatchGradContainer sets the optional container attribute to value. // If not specified, defaults to "" @@ -11086,53 +10340,100 @@ func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values } -// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize. -type QuantizeAndDequantizeAttr func(optionalAttr) +// SparseToDenseAttr is an optional argument to SparseToDense. +type SparseToDenseAttr func(optionalAttr) -// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value. +// SparseToDenseValidateIndices sets the optional validate_indices attribute to value. +// +// value: If true, indices are checked to make sure they are sorted in +// lexicographic order and that there are no repeats. 
// If not specified, defaults to true -func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr { +func SparseToDenseValidateIndices(value bool) SparseToDenseAttr { return func(m optionalAttr) { - m["signed_input"] = value + m["validate_indices"] = value } } -// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["num_bits"] = value +// Converts a sparse representation into a dense tensor. +// +// Builds an array `dense` with shape `output_shape` such that +// +// ``` +// # If sparse_indices is scalar +// dense[i] = (i == sparse_indices ? sparse_values : default_value) +// +// # If sparse_indices is a vector, then for each i +// dense[sparse_indices[i]] = sparse_values[i] +// +// # If sparse_indices is an n by d matrix, then for each i in [0, n) +// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] +// ``` +// +// All other values in `dense` are set to `default_value`. If `sparse_values` is a +// scalar, all sparse indices are set to this single value. +// +// Indices should be sorted in lexicographic order, and indices must not +// contain any repeats. If `validate_indices` is true, these properties +// are checked during execution. +// +// Arguments: +// sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete +// index where `sparse_values[i]` will be placed. +// output_shape: 1-D. Shape of the dense output tensor. +// sparse_values: 1-D. Values corresponding to each row of `sparse_indices`, +// or a scalar value to be used for all sparse indices. +// default_value: Scalar value to set for indices not specified in +// `sparse_indices`. +// +// Returns Dense output tensor of shape `output_shape`. 
+func SparseToDense(scope *Scope, sparse_indices tf.Output, output_shape tf.Output, sparse_values tf.Output, default_value tf.Output, optional ...SparseToDenseAttr) (dense tf.Output) { + if scope.Err() != nil { + return } -} - -// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value. -// If not specified, defaults to false -func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["range_given"] = value + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } -} - -// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value. -// If not specified, defaults to 0 -func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["input_min"] = value + opspec := tf.OpSpec{ + Type: "SparseToDense", + Input: []tf.Input{ + sparse_indices, output_shape, sparse_values, default_value, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value. -// If not specified, defaults to 0 -func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr { +// PreventGradientAttr is an optional argument to PreventGradient. +type PreventGradientAttr func(optionalAttr) + +// PreventGradientMessage sets the optional message attribute to value. +// +// value: Will be printed in the error when anyone tries to differentiate +// this operation. +// If not specified, defaults to "" +func PreventGradientMessage(value string) PreventGradientAttr { return func(m optionalAttr) { - m["input_max"] = value + m["message"] = value } } -// Use QuantizeAndDequantizeV2 instead. +// An identity op that triggers an error if a gradient is requested. 
// -// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2 -func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) { +// When executed in a graph, this op outputs its input tensor as-is. +// +// When building ops to compute gradients, the TensorFlow gradient system +// will return an error when trying to lookup the gradient of this op, +// because no gradient must ever be registered for this function. This +// op exists to prevent subtle bugs from silently returning unimplemented +// gradients in some corner cases. +// +// Arguments: +// input: any tensor. +// +// Returns the same input tensor. +func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -11141,7 +10442,7 @@ func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAn a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeAndDequantize", + Type: "PreventGradient", Input: []tf.Input{ input, }, @@ -11151,498 +10452,425 @@ func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAn return op.Output(0) } -// Returns locations of nonzero / true values in a tensor. +// Computes asin of x element-wise. +func Asin(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Asin", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the sum along sparse segments of a tensor. // -// This operation returns the coordinates of true elements in `condition`. The -// coordinates are returned in a 2-D tensor where the first dimension (rows) -// represents the number of true elements, and the second dimension (columns) -// represents the coordinates of the true elements. Keep in mind, the shape of -// the output tensor can vary depending on how many true values there are in -// `condition`. 
Indices are output in row-major order. +// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is +// missing, the `output` tensor at that position will be zeroed. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/sparse#Segmentation) +// for an explanation of segments. // // For example: // +// ```python +// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) +// +// tf.sparse_segment_sum_with_num_segments( +// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) +// # => [[0 0 0 0] +// # [0 0 0 0] +// # [0 0 0 0]] +// +// tf.sparse_segment_sum_with_num_segments(c, +// tf.constant([0, 1]), +// tf.constant([0, 2]), +// num_segments=4) +// # => [[ 1 2 3 4] +// # [ 0 0 0 0] +// # [-1 -2 -3 -4] +// # [ 0 0 0 0]] // ``` -// # 'input' tensor is [[True, False] -// # [True, False]] -// # 'input' has two true values, so output has two coordinates. -// # 'input' has rank of 2, so coordinates have two indices. -// where(input) ==> [[0, 0], -// [1, 0]] // -// # `condition` tensor is [[[True, False] -// # [True, False]] -// # [[False, True] -// # [False, True]] -// # [[False, False] -// # [False, True]]] -// # 'input' has 5 true values, so output has 5 coordinates. -// # 'input' has rank of 3, so coordinates have three indices. -// where(input) ==> [[0, 0, 0], -// [0, 1, 0], -// [1, 0, 1], -// [1, 1, 1], -// [2, 1, 1]] +// Arguments: // -// # `condition` tensor is [[[1.5, 0.0] -// # [-0.5, 0.0]] -// # [[0.0, 0.25] -// # [0.0, 0.75]] -// # [[0.0, 0.0] -// # [0.0, 0.01]]] -// # 'input' has 5 nonzero values, so output has 5 coordinates. -// # 'input' has rank of 3, so coordinates have three indices. -// where(input) ==> [[0, 0, 0], -// [0, 1, 0], -// [1, 0, 1], -// [1, 1, 1], -// [2, 1, 1]] +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs.
// -// # `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j] -// # [0.0 + 0.5j, 0.0 + 0.0j]] -// # [[0.0 + 0.0j, 0.25 + 1.5j] -// # [0.0 + 0.0j, 0.75 + 0.0j]] -// # [[0.0 + 0.0j, 0.0 + 0.0j] -// # [0.0 + 0.0j, 0.01 + 0.0j]]] -// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates. -// # 'input' has rank of 3, so coordinates have three indices. -// where(input) ==> [[0, 0, 0], -// [0, 1, 0], -// [1, 0, 1], -// [1, 1, 1], -// [2, 1, 1]] -// ``` -func Where(scope *Scope, condition tf.Output) (index tf.Output) { +// Returns Has same shape as data, except for dimension 0 which +// has size `num_segments`. +func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Where", + Type: "SparseSegmentSumWithNumSegments", Input: []tf.Input{ - condition, + data, indices, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// QueueDequeueV2Attr is an optional argument to QueueDequeueV2. -type QueueDequeueV2Attr func(optionalAttr) +// SparseReduceMaxAttr is an optional argument to SparseReduceMax. +type SparseReduceMaxAttr func(optionalAttr) -// QueueDequeueV2TimeoutMs sets the optional timeout_ms attribute to value. +// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. // -// value: If the queue is empty, this operation will block for up to -// timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueDequeueV2TimeoutMs(value int64) QueueDequeueV2Attr { +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { return func(m optionalAttr) { - m["timeout_ms"] = value + m["keep_dims"] = value } } -// Dequeues a tuple of one or more tensors from the given queue. 
+// Computes the max of elements across dimensions of a SparseTensor. // -// This operation has k outputs, where k is the number of components -// in the tuples stored in the given queue, and output i is the ith -// component of the dequeued tuple. +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` +// instead of a sparse one. // -// N.B. If the queue is empty, this operation will block until an element -// has been dequeued (or 'timeout_ms' elapses, if specified). +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. // -// Arguments: -// handle: The handle to a queue. -// component_types: The type of each component in a tuple. +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. // -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueV2(scope *Scope, handle tf.Output, component_types []tf.DataType, optional ...QueueDequeueV2Attr) (components []tf.Output) { +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// +// Returns `R-K`-D. The reduced Tensor. 
+func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QueueDequeueV2", + Type: "SparseReduceMax", Input: []tf.Input{ - handle, + input_indices, input_values, input_shape, reduction_axes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueV2", err) - return - } - return components + return op.Output(0) } -// ParseSequenceExampleAttr is an optional argument to ParseSequenceExample. -type ParseSequenceExampleAttr func(optionalAttr) +// DecodeRawAttr is an optional argument to DecodeRaw. +type DecodeRawAttr func(optionalAttr) -// ParseSequenceExampleNcontextSparse sets the optional Ncontext_sparse attribute to value. -// If not specified, defaults to 0 +// DecodeRawLittleEndian sets the optional little_endian attribute to value. // -// REQUIRES: value >= 0 -func ParseSequenceExampleNcontextSparse(value int64) ParseSequenceExampleAttr { +// value: Whether the input `bytes` are in little-endian order. +// Ignored for `out_type` values that are stored in a single byte like +// `uint8`. +// If not specified, defaults to true +func DecodeRawLittleEndian(value bool) DecodeRawAttr { return func(m optionalAttr) { - m["Ncontext_sparse"] = value + m["little_endian"] = value } } -// ParseSequenceExampleNcontextDense sets the optional Ncontext_dense attribute to value. -// If not specified, defaults to 0 +// Reinterpret the bytes of a string as a vector of numbers. 
// -// REQUIRES: value >= 0 -func ParseSequenceExampleNcontextDense(value int64) ParseSequenceExampleAttr { - return func(m optionalAttr) { - m["Ncontext_dense"] = value +// Arguments: +// bytes: All the elements must have the same length. +// +// +// Returns A Tensor with one more dimension than the input `bytes`. The +// added dimension will have size equal to the length of the elements +// of `bytes` divided by the number of bytes to represent `out_type`. +func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"out_type": out_type} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeRaw", + Input: []tf.Input{ + bytes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// ParseSequenceExampleNfeatureListSparse sets the optional Nfeature_list_sparse attribute to value. -// If not specified, defaults to 0 +// RetrieveTPUEmbeddingADAMParametersAttr is an optional argument to RetrieveTPUEmbeddingADAMParameters. +type RetrieveTPUEmbeddingADAMParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingADAMParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// REQUIRES: value >= 0 -func ParseSequenceExampleNfeatureListSparse(value int64) ParseSequenceExampleAttr { +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingADAMParametersTableId(value int64) RetrieveTPUEmbeddingADAMParametersAttr { return func(m optionalAttr) { - m["Nfeature_list_sparse"] = value + m["table_id"] = value } } -// ParseSequenceExampleNfeatureListDense sets the optional Nfeature_list_dense attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func ParseSequenceExampleNfeatureListDense(value int64) ParseSequenceExampleAttr { +// RetrieveTPUEmbeddingADAMParametersTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func RetrieveTPUEmbeddingADAMParametersTableName(value string) RetrieveTPUEmbeddingADAMParametersAttr { return func(m optionalAttr) { - m["Nfeature_list_dense"] = value + m["table_name"] = value } } -// ParseSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. +// Retrieve ADAM embedding parameters. // -// value: A list of Ncontext_sparse types; the data types of data in -// each context Feature given in context_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// REQUIRES: len(value) >= 0 -func ParseSequenceExampleContextSparseTypes(value []tf.DataType) ParseSequenceExampleAttr { - return func(m optionalAttr) { - m["context_sparse_types"] = value +// Returns Parameter parameters updated by the ADAM optimization algorithm. Parameter momenta updated by the ADAM optimization algorithm. Parameter velocities updated by the ADAM optimization algorithm. +func RetrieveTPUEmbeddingADAMParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output) { + if scope.Err() != nil { + return } -} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingADAMParameters", -// ParseSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value.
-// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSequenceExampleAttr { - return func(m optionalAttr) { - m["feature_list_dense_types"] = value + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// ParseSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. -// -// value: A list of Ncontext_dense shapes; the shapes of data in -// each context Feature given in context_dense_keys. -// The number of elements in the Feature corresponding to context_dense_key[j] -// must always equal context_dense_shapes[j].NumEntries(). -// The shape of context_dense_values[j] will match context_dense_shapes[j]. -// If not specified, defaults to <> +// FusedBatchNormAttr is an optional argument to FusedBatchNorm. +type FusedBatchNormAttr func(optionalAttr) + +// FusedBatchNormEpsilon sets the optional epsilon attribute to value. // -// REQUIRES: len(value) >= 0 -func ParseSequenceExampleContextDenseShapes(value []tf.Shape) ParseSequenceExampleAttr { +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { return func(m optionalAttr) { - m["context_dense_shapes"] = value + m["epsilon"] = value } } -// ParseSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. -// -// value: A list of Nfeature_list_sparse types; the data types -// of data in each FeatureList given in feature_list_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> +// FusedBatchNormDataFormat sets the optional data_format attribute to value. 
// -// REQUIRES: len(value) >= 0 -func ParseSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSequenceExampleAttr { +// value: The data format for x and y. Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { return func(m optionalAttr) { - m["feature_list_sparse_types"] = value + m["data_format"] = value } } -// ParseSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. -// -// value: A list of Nfeature_list_dense shapes; the shapes of -// data in each FeatureList given in feature_list_dense_keys. -// The shape of each Feature in the FeatureList corresponding to -// feature_list_dense_key[j] must always equal -// feature_list_dense_shapes[j].NumEntries(). -// If not specified, defaults to <> +// FusedBatchNormIsTraining sets the optional is_training attribute to value. // -// REQUIRES: len(value) >= 0 -func ParseSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSequenceExampleAttr { +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { return func(m optionalAttr) { - m["feature_list_dense_shapes"] = value + m["is_training"] = value } } -// Transforms a vector of brain.SequenceExample protos (as strings) into typed tensors. +// Batch normalization. +// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. // // Arguments: -// serialized: A vector containing binary serialized SequenceExample protos. -// debug_name: A vector containing the names of the serialized protos. -// May contain, for example, table key (descriptive) name for the -// corresponding serialized proto. 
This is purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty vector if no name is available. -// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). -// context_dense_defaults[j] provides default values -// when the SequenceExample's context map lacks context_dense_key[j]. -// If an empty Tensor is provided for context_dense_defaults[j], -// then the Feature context_dense_keys[j] is required. -// The input type is inferred from context_dense_defaults[j], even when it's -// empty. If context_dense_defaults[j] is not empty, its shape must match -// context_dense_shapes[j]. -// feature_list_dense_missing_assumed_empty: A vector listing the -// FeatureList keys which may be missing from the SequenceExamples. If the -// associated FeatureList is missing, it is treated as empty. By default, -// any FeatureList not listed in this vector must exist in the SequenceExamples. -// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). -// The keys expected in the Examples' features associated with context_sparse -// values. -// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' context features associated with -// dense values. -// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors -// (scalars). The keys expected in the FeatureLists associated with sparse -// values. -// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' feature_lists associated -// with lists of dense values. 
-func ParseSequenceExample(scope *Scope, serialized tf.Output, debug_name tf.Output, context_dense_defaults []tf.Output, feature_list_dense_missing_assumed_empty []string, context_sparse_keys []string, context_dense_keys []string, feature_list_sparse_keys []string, feature_list_dense_keys []string, optional ...ParseSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output, feature_list_dense_lengths []tf.Output) { +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// offset: A 1D Tensor for offset, to shift to the normalized x. +// mean: A 1D Tensor for population mean. Used for inference only; +// must be empty for training. +// variance: A 1D Tensor for population variance. Used for inference only; +// must be empty for training. +// +// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow +// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by +// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused +// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance +// in the cuDNN case), to be reused in the gradient computation. 
+func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"feature_list_dense_missing_assumed_empty": feature_list_dense_missing_assumed_empty, "context_sparse_keys": context_sparse_keys, "context_dense_keys": context_dense_keys, "feature_list_sparse_keys": feature_list_sparse_keys, "feature_list_dense_keys": feature_list_dense_keys} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ParseSequenceExample", + Type: "FusedBatchNorm", Input: []tf.Input{ - serialized, debug_name, tf.OutputList(context_dense_defaults), + x, scale, offset, mean, variance, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { - scope.UpdateErr("ParseSequenceExample", err) - return - } - if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { - scope.UpdateErr("ParseSequenceExample", err) - return - } - if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSequenceExample", err) - return - } - if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { - scope.UpdateErr("ParseSequenceExample", err) - return - } - if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { - scope.UpdateErr("ParseSequenceExample", err) - return - } - if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { - scope.UpdateErr("ParseSequenceExample", err) - return 
- } - if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSequenceExample", err) - return - } - if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { - scope.UpdateErr("ParseSequenceExample", err) - return - } - if feature_list_dense_lengths, idx, err = makeOutputList(op, idx, "feature_list_dense_lengths"); err != nil { - scope.UpdateErr("ParseSequenceExample", err) - return - } - return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values, feature_list_dense_lengths + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Computes the Gauss error function of `x` element-wise. -func Erf(scope *Scope, x tf.Output) (y tf.Output) { +// Creates a dataset that shards the input dataset. +// +// Creates a dataset that shards the input dataset by num_workers, returning a +// sharded dataset for the index-th worker. This attempts to automatically shard +// a dataset by examining the Dataset graph and inserting a shard op before the +// inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset). +// +// This dataset will throw a NotFound error if we cannot shard the dataset +// automatically. +// +// Arguments: +// input_dataset: A variant tensor representing the input dataset. +// num_workers: A scalar representing the number of workers to distribute this dataset across. +// index: A scalar representing the index of the current worker out of num_workers. 
+// +// +func ExperimentalAutoShardDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Output, index tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Erf", + Type: "ExperimentalAutoShardDataset", Input: []tf.Input{ - x, + input_dataset, num_workers, index, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns element-wise largest integer not greater than x. -func Floor(scope *Scope, x tf.Output) (y tf.Output) { +// RandomStandardNormalAttr is an optional argument to RandomStandardNormal. +type RandomStandardNormalAttr func(optionalAttr) + +// RandomStandardNormalSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomStandardNormalSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a normal distribution. +// +// The generated values will have mean 0 and standard deviation 1. +// +// Arguments: +// shape: The shape of the output tensor. +// dtype: The type of the output. +// +// Returns A tensor of the specified shape filled with random normal values. 
+func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Floor", + Type: "RandomStandardNormal", Input: []tf.Input{ - x, + shape, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// OneHotAttr is an optional argument to OneHot. -type OneHotAttr func(optionalAttr) +// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D. +type FusedResizeAndPadConv2DAttr func(optionalAttr) -// OneHotAxis sets the optional axis attribute to value. +// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value. // -// value: The axis to fill (default: -1, a new inner-most axis). -// If not specified, defaults to -1 -func OneHotAxis(value int64) OneHotAttr { +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr { return func(m optionalAttr) { - m["axis"] = value + m["resize_align_corners"] = value } } -// Returns a one-hot tensor. -// -// The locations represented by indices in `indices` take value `on_value`, -// while all other locations take value `off_value`. -// -// If the input `indices` is rank `N`, the output will have rank `N+1`, -// The new axis is created at dimension `axis` (default: the new axis is -// appended at the end). -// -// If `indices` is a scalar the output shape will be a vector of length `depth`. 
-// -// If `indices` is a vector of length `features`, the output shape will be: -// ``` -// features x depth if axis == -1 -// depth x features if axis == 0 -// ``` -// -// If `indices` is a matrix (batch) with shape `[batch, features]`, -// the output shape will be: -// ``` -// batch x features x depth if axis == -1 -// batch x depth x features if axis == 1 -// depth x batch x features if axis == 0 -// ``` -// -// -// Examples -// ========= -// -// Suppose that -// ``` -// indices = [0, 2, -1, 1] -// depth = 3 -// on_value = 5.0 -// off_value = 0.0 -// axis = -1 -// ``` -// -// Then output is `[4 x 3]`: -// ``` -// output = -// [5.0 0.0 0.0] // one_hot(0) -// [0.0 0.0 5.0] // one_hot(2) -// [0.0 0.0 0.0] // one_hot(-1) -// [0.0 5.0 0.0] // one_hot(1) -// ``` -// -// Suppose that -// ``` -// indices = [0, 2, -1, 1] -// depth = 3 -// on_value = 0.0 -// off_value = 3.0 -// axis = 0 -// ``` -// -// Then output is `[3 x 4]`: -// ``` -// output = -// [0.0 3.0 3.0 3.0] -// [3.0 3.0 3.0 0.0] -// [3.0 3.0 3.0 3.0] -// [3.0 0.0 3.0 3.0] -// // ^ one_hot(0) -// // ^ one_hot(2) -// // ^ one_hot(-1) -// // ^ one_hot(1) -// ``` -// -// Suppose that -// ``` -// indices = [[0, 2], [1, -1]] -// depth = 3 -// on_value = 1.0 -// off_value = 0.0 -// axis = -1 -// ``` +// Performs a resize and padding as a preprocess during a convolution. // -// Then output is `[2 x 2 x 3]`: -// ``` -// output = -// [ -// [1.0, 0.0, 0.0] // one_hot(0) -// [0.0, 0.0, 1.0] // one_hot(2) -// ][ -// [0.0, 1.0, 0.0] // one_hot(1) -// [0.0, 0.0, 0.0] // one_hot(-1) -// ] -// ``` +// It's often possible to do spatial transformations more efficiently as part of +// the packing stage of a convolution, so this op allows for an optimized +// implementation where these stages are fused together. This prevents the need to +// write out the intermediate results as whole tensors, reducing memory pressure, +// and we can get some latency gains by merging the transformation calculations. 
+// The data_format attribute for Conv2D isn't supported by this op, and defaults to +// 'NHWC' order. +// Internally this op uses a single per-graph scratch buffer, which means that it +// will block if multiple versions are being run in parallel. This is because this +// operator is primarily an optimization to minimize memory usage. // // Arguments: -// indices: A tensor of indices. -// depth: A scalar defining the depth of the one hot dimension. -// on_value: A scalar defining the value to fill in output when `indices[j] = i`. -// off_value: A scalar defining the value to fill in output when `indices[j] != i`. +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. // -// Returns The one-hot tensor. -func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output, off_value tf.Output, optional ...OneHotAttr) (output tf.Output) { +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. Must be in the same order as the dimension specified with format. +// padding: The type of padding algorithm to use. 
+func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OneHot", + Type: "FusedResizeAndPadConv2D", Input: []tf.Input{ - indices, depth, on_value, off_value, + input, size, paddings, filter, }, Attrs: attrs, } @@ -11650,98 +10878,5594 @@ func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output return op.Output(0) } -// CudnnRNNAttr is an optional argument to CudnnRNN. -type CudnnRNNAttr func(optionalAttr) - -// CudnnRNNRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNRnnMode(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } -} - -// CudnnRNNInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNInputMode(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["input_mode"] = value - } -} - -// CudnnRNNDirection sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNDirection(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["direction"] = value - } -} +// RandomUniformAttr is an optional argument to RandomUniform. +type RandomUniformAttr func(optionalAttr) -// CudnnRNNDropout sets the optional dropout attribute to value. +// RandomUniformSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. 
// If not specified, defaults to 0 -func CudnnRNNDropout(value float32) CudnnRNNAttr { +func RandomUniformSeed(value int64) RandomUniformAttr { return func(m optionalAttr) { - m["dropout"] = value + m["seed"] = value } } -// CudnnRNNSeed sets the optional seed attribute to value. +// RandomUniformSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. // If not specified, defaults to 0 -func CudnnRNNSeed(value int64) CudnnRNNAttr { +func RandomUniformSeed2(value int64) RandomUniformAttr { return func(m optionalAttr) { - m["seed"] = value + m["seed2"] = value } } -// CudnnRNNSeed2 sets the optional seed2 attribute to value. +// Outputs random values from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// +// Arguments: +// shape: The shape of the output tensor. +// dtype: The type of the output. +// +// Returns A tensor of the specified shape filled with uniform random values. +func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomUniform", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. +type ResourceApplyFtrlAttr func(optionalAttr) + +// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. 
+// If not specified, defaults to false +func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the Ftrl-proximal scheme. +// +// accum_new = accum + grad * grad +// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regulariation. Must be a scalar. +// l2: L2 regulariation. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. +func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyFtrl", + Input: []tf.Input{ + var_, accum, linear, grad, lr, l1, l2, lr_power, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Transforms a vector of brain.Example protos (as strings) into typed tensors. +// +// Arguments: +// serialized: A vector containing a batch of binary serialized Example protos. +// names: A vector containing the names of the serialized protos. +// May contain, for example, table key (descriptive) names for the +// corresponding serialized protos. These are purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty vector if no names are available. 
+// If non-empty, this vector must be the same length as "serialized". +// sparse_keys: A list of Nsparse string Tensors (scalars). +// The keys expected in the Examples' features associated with sparse values. +// dense_keys: A list of Ndense string Tensors (scalars). +// The keys expected in the Examples' features associated with dense values. +// dense_defaults: A list of Ndense Tensors (some may be empty). +// dense_defaults[j] provides default values +// when the example's feature_map lacks dense_key[j]. If an empty Tensor is +// provided for dense_defaults[j], then the Feature dense_keys[j] is required. +// The input type is inferred from dense_defaults[j], even when it's empty. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. +// sparse_types: A list of Nsparse types; the data types of data in each Feature +// given in sparse_keys. +// Currently the ParseExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// dense_shapes: A list of Ndense shapes; the shapes of data in each Feature +// given in dense_keys. +// The number of elements in the Feature corresponding to dense_key[j] +// must always equal dense_shapes[j].NumEntries(). +// If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output +// Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): +// The dense outputs are just the inputs row-stacked by batch. +// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case +// the shape of the output Tensor dense_values[j] will be +// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks +// of elements of length D1 * .... * DN, across all minibatch entries +// in the input. 
Any minibatch entry with less than M blocks of elements of +// length D1 * ... * DN will be padded with the corresponding default_value +// scalar element along the second dimension. +func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_keys []tf.Output, dense_keys []tf.Output, dense_defaults []tf.Output, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"sparse_types": sparse_types, "dense_shapes": dense_shapes} + opspec := tf.OpSpec{ + Type: "ParseExample", + Input: []tf.Input{ + serialized, names, tf.OutputList(sparse_keys), tf.OutputList(dense_keys), tf.OutputList(dense_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + return sparse_indices, sparse_values, sparse_shapes, dense_values +} + +// Compute the pairwise cross product. +// +// `a` and `b` must be the same shape; they can either be simple 3-element vectors, +// or any shape where the innermost dimension is 3. In the latter case, each pair +// of corresponding 3-element vectors is cross-multiplied independently. +// +// Arguments: +// a: A tensor containing 3-element vectors. +// b: Another tensor, of same type and shape as `a`. 
+// +// Returns Pairwise cross product of the vectors in `a` and `b`. +func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cross", + Input: []tf.Input{ + a, b, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LeakyReluAttr is an optional argument to LeakyRelu. +type LeakyReluAttr func(optionalAttr) + +// LeakyReluAlpha sets the optional alpha attribute to value. +// If not specified, defaults to 0.2 +func LeakyReluAlpha(value float32) LeakyReluAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// Computes rectified linear: `max(features, features * alpha)`. +func LeakyRelu(scope *Scope, features tf.Output, optional ...LeakyReluAttr) (activations tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LeakyRelu", + Input: []tf.Input{ + features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Outputs random integers from a uniform distribution. +// +// The generated values are uniform integers in the range `[minval, maxval)`. +// The lower bound `minval` is included in the range, while the upper bound +// `maxval` is excluded. +// +// The random integers are slightly biased unless `maxval - minval` is an exact +// power of two. The bias is small for values of `maxval - minval` significantly +// smaller than the range of the output (either `2^32` or `2^64`). +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// algorithm: The RNG algorithm. +// shape: The shape of the output tensor. +// minval: Minimum value (inclusive, scalar). +// maxval: Maximum value (exclusive, scalar). +// +// Returns Random values with specified shape. 
+func StatefulUniformInt(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StatefulUniformInt", + Input: []tf.Input{ + resource, algorithm, shape, minval, maxval, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. +type DecodeAndCropJpegAttr func(optionalAttr) + +// DecodeAndCropJpegChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. // If not specified, defaults to 0 -func CudnnRNNSeed2(value int64) CudnnRNNAttr { +func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["seed2"] = value + m["channels"] = value } } -// CudnnRNNIsTraining sets the optional is_training attribute to value. +// DecodeAndCropJpegRatio sets the optional ratio attribute to value. +// +// value: Downscaling ratio. +// If not specified, defaults to 1 +func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["ratio"] = value + } +} + +// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. +// +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). // If not specified, defaults to true -func CudnnRNNIsTraining(value bool) CudnnRNNAttr { +func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["is_training"] = value + m["fancy_upscaling"] = value } } -// A RNN backed by cuDNN. +// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. // -// Computes the RNN from the input and initial states, with respect to the params -// buffer. +// value: If true try to recover an image from truncated input. 
+// If not specified, defaults to false +func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["try_recover_truncated"] = value + } +} + +// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. // -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicate whether there is a linear projection between the input and -// the actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. Should be -// "unidirectional" or "bidirectional". -// dropout: Dropout probability. When set to 0., dropout is disabled. -// seed: The 1st part of a seed to initialize dropout. -// seed2: The 2nd part of a seed to initialize dropout. -// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. -// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, -// num_units]. -// input_c: For LSTM, a 3-D tensor with the shape of -// [num_layer * dir, batch, num_units]. For other models, it is ignored. -// params: A 1-D tensor that contains the weights and biases in an opaque layout. -// The size must be created through CudnnRNNParamsSize, and initialized -// separately. Note that they might not be compatible across different -// generations. So it is a good idea to save and restore -// output: A 3-D tensor with the shape of [seq_length, batch_size, -// dir * num_units]. -// output_h: The same shape has input_h. -// output_c: The same shape as input_c for LSTM. An empty tensor for other models. -// is_training: Indicates whether this operation is used for inferenece or -// training. -// reserve_space: An opaque tensor that can be used in backprop calculation. It -// is only produced if is_training is false. 
-func CudnnRNN(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, optional ...CudnnRNNAttr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output) { +// value: The minimum required fraction of lines before a truncated +// input is accepted. +// If not specified, defaults to 1 +func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["acceptable_fraction"] = value + } +} + +// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. +// +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["dct_method"] = value + } +} + +// Decode and Crop a JPEG-encoded image to a uint8 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. +// +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. +// +// +// It is equivalent to a combination of decode and crop, but much faster by only +// decoding partial jpeg image. +// +// Arguments: +// contents: 0-D. The JPEG-encoded image. +// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. 
+// +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeAndCropJpeg", + Input: []tf.Input{ + contents, crop_window, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatefulStandardNormalV2Attr is an optional argument to StatefulStandardNormalV2. +type StatefulStandardNormalV2Attr func(optionalAttr) + +// StatefulStandardNormalV2Dtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatefulStandardNormalV2Dtype(value tf.DataType) StatefulStandardNormalV2Attr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs random values from a normal distribution. +// +// The generated values will have mean 0 and standard deviation 1. +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// algorithm: The RNG algorithm. +// shape: The shape of the output tensor. +// +// Returns A tensor of the specified shape filled with random normal values. +func StatefulStandardNormalV2(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, optional ...StatefulStandardNormalV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatefulStandardNormalV2", + Input: []tf.Input{ + resource, algorithm, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatefulUniformFullIntAttr is an optional argument to StatefulUniformFullInt. 
+type StatefulUniformFullIntAttr func(optionalAttr) + +// StatefulUniformFullIntDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_UINT64 +func StatefulUniformFullIntDtype(value tf.DataType) StatefulUniformFullIntAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs random integers from a uniform distribution. +// +// The generated values are uniform integers covering the whole range of `dtype`. +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// algorithm: The RNG algorithm. +// shape: The shape of the output tensor. +// +// Returns Random values with specified shape. +func StatefulUniformFullInt(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, optional ...StatefulUniformFullIntAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatefulUniformFullInt", + Input: []tf.Input{ + resource, algorithm, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Locks a mutex resource. The output is the lock. So long as the lock tensor +// +// is alive, any other request to use `MutexLock` with this mutex will wait. 
+// +// This is particularly useful for creating a critical section when used in +// conjunction with `MutexLockIdentity`: +// +// ```python +// +// mutex = mutex_v2( +// shared_name=handle_name, container=container, name=name) +// +// def execute_in_critical_section(fn, *args, **kwargs): +// lock = gen_resource_variable_ops.mutex_lock(mutex) +// +// with ops.control_dependencies([lock]): +// r = fn(*args, **kwargs) +// +// with ops.control_dependencies(nest.flatten(r)): +// with ops.colocate_with(mutex): +// ensure_lock_exists = mutex_lock_identity(lock) +// +// # Make sure that if any element of r is accessed, all of +// # them are executed together. +// r = nest.map_structure(tf.identity, r) +// +// with ops.control_dependencies([ensure_lock_exists]): +// return nest.map_structure(tf.identity, r) +// ``` +// +// While `fn` is running in the critical section, no other functions which wish to +// use this critical section may run. +// +// Often the use case is that two executions of the same graph, in parallel, +// wish to run `fn`; and we wish to ensure that only one of them executes +// at a time. This is especially important if `fn` modifies one or more +// variables at a time. +// +// It is also useful if two separate functions must share a resource, but we +// wish to ensure the usage is exclusive. +// +// Arguments: +// mutex: The mutex resource to lock. +// +// Returns A tensor that keeps a shared pointer to a lock on the mutex; +// when the Tensor is destroyed, the use count on the shared pointer is decreased +// by 1. When it reaches 0, the lock is released. +func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MutexLock", + Input: []tf.Input{ + mutex, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Transforms a serialized tensorflow.TensorProto proto into a Tensor. 
+// +// Arguments: +// serialized: A scalar string containing a serialized TensorProto proto. +// out_type: The type of the serialized tensor. The provided type must match the +// type of the serialized tensor and no implicit conversion will take place. +// +// Returns A Tensor of type `out_type`. +func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + opspec := tf.OpSpec{ + Type: "ParseTensor", + Input: []tf.Input{ + serialized, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. +type MaxPoolWithArgmaxAttr func(optionalAttr) + +// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. +// If not specified, defaults to DT_INT64 +func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { + return func(m optionalAttr) { + m["Targmax"] = value + } +} + +// MaxPoolWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. +// +// value: Whether to include batch dimension in flattened index of `argmax`. +// If not specified, defaults to false +func MaxPoolWithArgmaxIncludeBatchInIndex(value bool) MaxPoolWithArgmaxAttr { + return func(m optionalAttr) { + m["include_batch_in_index"] = value + } +} + +// Performs max pooling on the input and outputs both max values and indices. +// +// The indices in `argmax` are flattened, so that a maximum value at position +// `[b, y, x, c]` becomes flattened index: +// `(y * width + x) * channels + c` if `include_batch_in_index` is False; +// `((b * height + y) * width + x) * channels + c` if `include_batch_in_index` is True. +// +// The indices returned are always in `[0, height) x [0, width)` before flattening, +// even if padding is involved and the mathematically correct answer is outside +// (either negative or too large). 
This is a bug, but fixing it is difficult to do +// in a safe backwards compatible way, especially due to flattening. +// +// Arguments: +// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. +func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolWithArgmax", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// ModelDatasetAttr is an optional argument to ModelDataset. +type ModelDatasetAttr func(optionalAttr) + +// ModelDatasetCpuBudget sets the optional cpu_budget attribute to value. +// If not specified, defaults to 0 +func ModelDatasetCpuBudget(value int64) ModelDatasetAttr { + return func(m optionalAttr) { + m["cpu_budget"] = value + } +} + +// Identity transformation that models performance. +// +// Identity transformation that models performance. +// +// Arguments: +// input_dataset: A variant tensor representing the input dataset. 
+// +// +func ModelDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ModelDatasetAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ModelDataset", + Input: []tf.Input{ + input_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Fast Fourier transform. +// +// Computes the 1-dimensional discrete Fourier transform over the inner-most +// dimension of `input`. +// +// Arguments: +// input: A complex tensor. +// +// Returns A complex tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft +// @end_compatibility +func FFT(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolAttr is an optional argument to MaxPool. +type MaxPoolAttr func(optionalAttr) + +// MaxPoolDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolDataFormat(value string) MaxPoolAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Performs max pooling on the input. +// +// Arguments: +// input: 4-D input to pool over. +// ksize: The size of the window for each dimension of the input tensor. 
+// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns The max pooled output tensor. +func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Multiplies sparse updates into the variable referenced by `resource`. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] *= updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] *= updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions multiply. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterMul(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterMul", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// Subtracts sparse updates from the variable referenced by `resource`. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] -= updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] -= updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions add. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterSub(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterSub", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// Adds sparse updates to the variable referenced by `resource`. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] += updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] += updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions add. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterAdd(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterAdd", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// Reads the value of a variable. +// +// The tensor returned by this operation is immutable. +// +// The value returned by this operation is guaranteed to be influenced by all the +// writes on which this operation depends directly or indirectly, and to not be +// influenced by any of the writes which depend directly or indirectly on this +// operation. +// +// Arguments: +// resource: handle to the resource in which to store the variable. +// dtype: the dtype of the value. +func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "ReadVariableOp", + Input: []tf.Input{ + resource, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyProximalAdagradAttr is an optional argument to ResourceSparseApplyProximalAdagrad. +type ResourceSparseApplyProximalAdagradAttr func(optionalAttr) + +// ResourceSparseApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. 
+// If not specified, defaults to false +func ResourceSparseApplyProximalAdagradUseLocking(value bool) ResourceSparseApplyProximalAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. +// +// That is for rows we have grad for, we update var and accum as follows: +// accum += grad * grad +// prox_v = var +// prox_v -= lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalAdagradAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyProximalAdagrad", + Input: []tf.Input{ + var_, accum, lr, l1, l2, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// DecodeJpegAttr is an optional argument to DecodeJpeg. +type DecodeJpegAttr func(optionalAttr) + +// DecodeJpegChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodeJpegChannels(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodeJpegRatio sets the optional ratio attribute to value. +// +// value: Downscaling ratio. 
+// If not specified, defaults to 1 +func DecodeJpegRatio(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["ratio"] = value + } +} + +// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. +// +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). +// If not specified, defaults to true +func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["fancy_upscaling"] = value + } +} + +// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. +// +// value: If true try to recover an image from truncated input. +// If not specified, defaults to false +func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["try_recover_truncated"] = value + } +} + +// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. +// +// value: The minimum required fraction of lines before a truncated +// input is accepted. +// If not specified, defaults to 1 +func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { + return func(m optionalAttr) { + m["acceptable_fraction"] = value + } +} + +// DecodeJpegDctMethod sets the optional dct_method attribute to value. +// +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeJpegDctMethod(value string) DecodeJpegAttr { + return func(m optionalAttr) { + m["dct_method"] = value + } +} + +// Decode a JPEG-encoded image to a uint8 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. 
+// +// Accepted values are: +// +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. +// +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. +// +// +// This op also supports decoding PNGs and non-animated GIFs since the interface is +// the same, though it is cleaner to use `tf.image.decode_image`. +// +// Arguments: +// contents: 0-D. The JPEG-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeJpeg", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DepthwiseConv2dNativeBackpropInputAttr is an optional argument to DepthwiseConv2dNativeBackpropInput. +type DepthwiseConv2dNativeBackpropInputAttr func(optionalAttr) + +// DepthwiseConv2dNativeBackpropInputDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. 
+// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dNativeBackpropInputAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// DepthwiseConv2dNativeBackpropInputDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of depthwise convolution with respect to the input. +// +// Arguments: +// input_sizes: An integer vector representing the shape of `input`, based +// on `data_format`. For example, if `data_format` is 'NHWC' then +// `input` is a 4-D `[batch, height, width, channels]` tensor. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, depthwise_multiplier]`. +// out_backprop: 4-D with shape based on `data_format`. +// For example, if `data_format` is 'NHWC' then +// out_backprop shape is `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape according to `data_format`. For example, if +// `data_format` is 'NHWC', output shape is `[batch, in_height, +// in_width, in_channels]`. Gradient w.r.t. the input of the +// convolution. 
+func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropInputAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DepthwiseConv2dNativeBackpropInput", + Input: []tf.Input{ + input_sizes, filter, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// EditDistanceAttr is an optional argument to EditDistance. +type EditDistanceAttr func(optionalAttr) + +// EditDistanceNormalize sets the optional normalize attribute to value. +// +// value: boolean (if true, edit distances are normalized by length of truth). +// +// The output is: +// If not specified, defaults to true +func EditDistanceNormalize(value bool) EditDistanceAttr { + return func(m optionalAttr) { + m["normalize"] = value + } +} + +// Computes the (possibly normalized) Levenshtein Edit Distance. +// +// The inputs are variable-length sequences provided by SparseTensors +// (hypothesis_indices, hypothesis_values, hypothesis_shape) +// and +// (truth_indices, truth_values, truth_shape). +// +// The inputs are: +// +// Arguments: +// hypothesis_indices: The indices of the hypothesis list SparseTensor. +// This is an N x R int64 matrix. +// hypothesis_values: The values of the hypothesis list SparseTensor. +// This is an N-length vector. +// hypothesis_shape: The shape of the hypothesis list SparseTensor. +// This is an R-length vector. +// truth_indices: The indices of the truth list SparseTensor. +// This is an M x R int64 matrix. +// truth_values: The values of the truth list SparseTensor. +// This is an M-length vector. +// truth_shape: truth indices, vector. +// +// Returns A dense float tensor with rank R - 1. 
+// +// For the example input: +// +// // hypothesis represents a 2x1 matrix with variable-length values: +// // (0,0) = ["a"] +// // (1,0) = ["b"] +// hypothesis_indices = [[0, 0, 0], +// [1, 0, 0]] +// hypothesis_values = ["a", "b"] +// hypothesis_shape = [2, 1, 1] +// +// // truth represents a 2x2 matrix with variable-length values: +// // (0,0) = [] +// // (0,1) = ["a"] +// // (1,0) = ["b", "c"] +// // (1,1) = ["a"] +// truth_indices = [[0, 1, 0], +// [1, 0, 0], +// [1, 0, 1], +// [1, 1, 0]] +// truth_values = ["a", "b", "c", "a"] +// truth_shape = [2, 2, 2] +// normalize = true +// +// The output will be: +// +// // output is a 2x2 matrix with edit distances normalized by truth lengths. +// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis +// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis +func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EditDistance", + Input: []tf.Input{ + hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns 0 if x == 0, and x * log(y) otherwise, elementwise. +func Xlogy(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Xlogy", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Stops gradient computation. +// +// When executed in a graph, this op outputs its input tensor as-is. +// +// When building ops to compute gradients, this op prevents the contribution of +// its inputs to be taken into account. 
Normally, the gradient generator adds ops +// to a graph to compute the derivatives of a specified 'loss' by recursively +// finding out inputs that contributed to its computation. If you insert this op +// in the graph it inputs are masked from the gradient generator. They are not +// taken into account for computing gradients. +// +// This is useful any time you want to compute a value with TensorFlow but need +// to pretend that the value was a constant. Some examples include: +// +// * The *EM* algorithm where the *M-step* should not involve backpropagation +// through the output of the *E-step*. +// * Contrastive divergence training of Boltzmann machines where, when +// differentiating the energy function, the training must not backpropagate +// through the graph that generated the samples from the model. +// * Adversarial training, where no backprop should happen through the adversarial +// example generation process. +func StopGradient(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StopGradient", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Eagerly executes a python function to compute func(input)->output. The +// +// semantics of the input, output, and attributes are the same as those for +// PyFunc. +func EagerPyFunc(scope *Scope, input []tf.Output, token string, Tout []tf.DataType) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"token": token, "Tout": Tout} + opspec := tf.OpSpec{ + Type: "EagerPyFunc", + Input: []tf.Input{ + tf.OutputList(input), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("EagerPyFunc", err) + return + } + return output +} + +// Concats all tensors in the list along the 0th dimension. 
+// +// Requires that all tensors have the same shape except the first dimension. +// +// input_handle: The input list. +// element_shape: The shape of the uninitialized elements in the list. If the first +// dimension is not -1, it is assumed that all list elements have the same +// leading dim. +// leading_dims: The list of leading dims of uninitialized list elements. Used if +// the leading dim of input_handle.element_shape or the element_shape input arg +// is not already set. +// tensor: The concated result. +// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. +// +func TensorListConcatV2(scope *Scope, input_handle tf.Output, element_shape tf.Output, leading_dims tf.Output, element_dtype tf.DataType) (tensor tf.Output, lengths tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + opspec := tf.OpSpec{ + Type: "TensorListConcatV2", + Input: []tf.Input{ + input_handle, element_shape, leading_dims, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve. +type MatrixTriangularSolveAttr func(optionalAttr) + +// MatrixTriangularSolveLower sets the optional lower attribute to value. +// +// value: Boolean indicating whether the innermost matrices in `matrix` are +// lower or upper triangular. +// If not specified, defaults to true +func MatrixTriangularSolveLower(value bool) MatrixTriangularSolveAttr { + return func(m optionalAttr) { + m["lower"] = value + } +} + +// MatrixTriangularSolveAdjoint sets the optional adjoint attribute to value. +// +// value: Boolean indicating whether to solve with `matrix` or its (block-wise) +// adjoint. 
+// +// @compatibility(numpy) +// Equivalent to scipy.linalg.solve_triangular +// @end_compatibility +// If not specified, defaults to false +func MatrixTriangularSolveAdjoint(value bool) MatrixTriangularSolveAttr { + return func(m optionalAttr) { + m["adjoint"] = value + } +} + +// Solves systems of linear equations with upper or lower triangular matrices by +// +// backsubstitution. +// +// `matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form +// square matrices. If `lower` is `True` then the strictly upper triangular part +// of each inner-most matrix is assumed to be zero and not accessed. +// If `lower` is False then the strictly lower triangular part of each inner-most +// matrix is assumed to be zero and not accessed. +// `rhs` is a tensor of shape `[..., M, K]`. +// +// The output is a tensor of shape `[..., M, K]`. If `adjoint` is +// `True` then the innermost matrices in `output` satisfy matrix equations +// `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +// If `adjoint` is `False` then the strictly then the innermost matrices in +// `output` satisfy matrix equations +// `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. +// +// Arguments: +// matrix: Shape is `[..., M, M]`. +// rhs: Shape is `[..., M, K]`. +// +// Returns Shape is `[..., M, K]`. +func MatrixTriangularSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixTriangularSolveAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MatrixTriangularSolve", + Input: []tf.Input{ + matrix, rhs, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Saves tensors in V2 checkpoint format. +// +// By default, saves the named tensors in full. 
If the caller wishes to save +// specific slices of full tensors, "shape_and_slices" should be non-empty strings +// and correspondingly well-formed. +// +// Arguments: +// prefix: Must have a single element. The prefix of the V2 checkpoint to which we +// write the tensors. +// tensor_names: shape {N}. The names of the tensors to be saved. +// shape_and_slices: shape {N}. The slice specs of the tensors to be saved. +// Empty strings indicate that they are non-partitioned tensors. +// tensors: `N` tensors to save. +// +// Returns the created operation. +func SaveV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, tensors []tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SaveV2", + Input: []tf.Input{ + prefix, tensor_names, shape_and_slices, tf.OutputList(tensors), + }, + } + return scope.AddOperation(opspec) +} + +// Concatenates quantized tensors along one dimension. +// +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// input_mins: The minimum scalar values for each of the input tensors. +// input_maxes: The maximum scalar values for each of the input tensors. +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. 
+func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "QuantizedConcat", + Input: []tf.Input{ + concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Slice a `SparseTensor` based on the `start` and `size`. +// +// For example, if the input is +// +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] +// +// Graphically the output tensors are: +// +// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] +// [ a ] +// [b c ] +// +// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] +// [ d e ] +// [ ] +// +// Arguments: +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// start: 1-D. tensor represents the start of the slice. +// size: 1-D. tensor represents the size of the slice. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. +// +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. 
+func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSlice", + Input: []tf.Input{ + indices, values, shape, start, size, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Runs multiple additive regression ensemble predictors on input instances and +// +// computes the logits. It is designed to be used during prediction. +// It traverses all the trees and calculates the final score for each instance. +// +// Arguments: +// +// bucketized_features: A list of rank 1 Tensors containing bucket id for each +// feature. +// logits_dimension: scalar, dimension of the logits, to be used for partial logits +// shape. +// +// Returns Output rank 2 Tensor containing logits for each example. +func BoostedTreesPredict(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (logits tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"logits_dimension": logits_dimension} + opspec := tf.OpSpec{ + Type: "BoostedTreesPredict", + Input: []tf.Input{ + tree_ensemble_handle, tf.OutputList(bucketized_features), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Pads a tensor with zeros. +// +// This operation pads a `input` with zeros according to the `paddings` you +// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the +// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many zeros to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` +// in that dimension. 
+// +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 1], [2, 2]] +// # 'paddings' is [[1, 1], [2, 2]] +// # rank of 't' is 2 +// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] +// [0, 0, 1, 1, 0, 0] +// [0, 0, 2, 2, 0, 0] +// [0, 0, 0, 0, 0, 0]] +// ``` +// +func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Pad", + Input: []tf.Input{ + input, paddings, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Checks whether a resource handle-based variable has been initialized. +// +// Arguments: +// resource: the input resource handle. +// +// Returns a scalar boolean which is true if the variable has been +// initialized. +func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "VarIsInitializedOp", + Input: []tf.Input{ + resource, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the min of x and y (i.e. x < y ? x : y) element-wise. +// +// *NOTE*: `Minimum` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Minimum", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` +// +// if < 0, `scale * features` otherwise. +// +// To be used together with +// `initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')`. +// For correct dropout, use `tf.contrib.nn.alpha_dropout`. 
+// +// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) +func Selu(scope *Scope, features tf.Output) (activations tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Selu", + Input: []tf.Input{ + features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SetSizeAttr is an optional argument to SetSize. +type SetSizeAttr func(optionalAttr) + +// SetSizeValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func SetSizeValidateIndices(value bool) SetSizeAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Number of unique elements along last dimension of input `set`. +// +// Input `set` is a `SparseTensor` represented by `set_indices`, `set_values`, +// and `set_shape`. The last dimension contains values in a set, duplicates are +// allowed but ignored. +// +// If `validate_indices` is `True`, this op validates the order and range of `set` +// indices. +// +// Arguments: +// set_indices: 2D `Tensor`, indices of a `SparseTensor`. +// set_values: 1D `Tensor`, values of a `SparseTensor`. +// set_shape: 1D `Tensor`, shape of a `SparseTensor`. +// +// Returns For `set` ranked `n`, this is a `Tensor` with rank `n-1`, and the same 1st +// `n-1` dimensions as `set`. Each value is the number of unique elements in +// the corresponding `[0...n-1]` dimension of `set`. +func SetSize(scope *Scope, set_indices tf.Output, set_values tf.Output, set_shape tf.Output, optional ...SetSizeAttr) (size tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SetSize", + Input: []tf.Input{ + set_indices, set_values, set_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds sparse `updates` to an existing tensor according to `indices`. 
+// +// This operation creates a new tensor by adding sparse `updates` to the passed +// in `tensor`. +// This operation is very similar to `tf.scatter_nd_add`, except that the updates +// are added onto an existing tensor (as opposed to a variable). If the memory +// for the existing tensor cannot be re-used, a copy is made and updated. +// +// `indices` is an integer tensor containing indices into a new tensor of shape +// `shape`. The last dimension of `indices` can be at most the rank of `shape`: +// +// indices.shape[-1] <= shape.rank +// +// The last dimension of `indices` corresponds to indices into elements +// (if `indices.shape[-1] = shape.rank`) or slices +// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of +// `shape`. `updates` is a tensor with shape +// +// indices.shape[:-1] + shape[indices.shape[-1]:] +// +// The simplest form of tensor_scatter_add is to add individual elements to a +// tensor by index. For example, say we want to add 4 elements in a rank-1 +// tensor with 8 elements. +// +// In Python, this scatter add operation would look like this: +// +// ```python +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// tensor = tf.ones([8], dtype=tf.int32) +// updated = tf.tensor_scatter_add(tensor, indices, updates) +// with tf.Session() as sess: +// print(sess.run(scatter)) +// ``` +// +// The resulting tensor would look like this: +// +// [1, 12, 1, 11, 10, 1, 1, 13] +// +// We can also, insert entire slices of a higher rank tensor all at once. For +// example, if we wanted to insert two slices in the first dimension of a +// rank-3 tensor with two matrices of new values. 
+// +// In Python, this scatter add operation would look like this: +// +// ```python +// indices = tf.constant([[0], [2]]) +// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], +// [7, 7, 7, 7], [8, 8, 8, 8]], +// [[5, 5, 5, 5], [6, 6, 6, 6], +// [7, 7, 7, 7], [8, 8, 8, 8]]]) +// tensor = tf.ones([4, 4, 4]) +// updated = tf.tensor_scatter_add(tensor, indices, updates) +// with tf.Session() as sess: +// print(sess.run(scatter)) +// ``` +// +// The resulting tensor would look like this: +// +// [[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], +// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], +// [[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], +// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]] +// +// Note that on CPU, if an out of bound index is found, an error is returned. +// On GPU, if an out of bound index is found, the index is ignored. +// +// Arguments: +// tensor: Tensor to copy/update. +// indices: Index tensor. +// updates: Updates to scatter into output. +// +// Returns A new tensor copied from tensor and updates added according to the indices. +func TensorScatterAdd(scope *Scope, tensor tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorScatterAdd", + Input: []tf.Input{ + tensor, indices, updates, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the sign and the log of the absolute value of the determinant of +// +// one or more square matrices. +// +// The input is a tensor of shape `[N, M, M]` whose inner-most 2 dimensions +// form square matrices. The outputs are two tensors containing the signs and +// absolute values of the log determinants for all N input submatrices +// `[..., :, :]` such that the determinant = sign*exp(log_abs_determinant). 
+// The log_abs_determinant is computed as det(P)*sum(log(diag(LU))) where LU +// is the LU decomposition of the input and P is the corresponding +// permutation matrix. +// +// Arguments: +// input: Shape is `[N, M, M]`. +// +// Returns The signs of the log determinants of the inputs. Shape is `[N]`.The logs of the absolute values of the determinants +// of the N input matrices. Shape is `[N]`. +func LogMatrixDeterminant(scope *Scope, input tf.Output) (sign tf.Output, log_abs_determinant tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LogMatrixDeterminant", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Says whether the targets are in the top `K` predictions. +// +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. +// +// More formally, let +// +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, +// +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// +// Arguments: +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. +// +// Returns Computed precision at `k` as a `bool Tensor`. 
+func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InTopKV2", + Input: []tf.Input{ + predictions, targets, k, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Check if the input matches the regex pattern. +// +// The input is a string tensor of any shape. The pattern is a scalar +// string tensor which is applied to every element of the input tensor. +// The boolean values (True or False) of the output tensor indicate +// if the input matches the regex pattern provided. +// +// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// +// Arguments: +// input: A string tensor of the text to be processed. +// pattern: A scalar string tensor containing the regular expression to match the input. +// +// Returns A bool tensor with the same shape as `input`. +func RegexFullMatch(scope *Scope, input tf.Output, pattern tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RegexFullMatch", + Input: []tf.Input{ + input, pattern, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Converts a `RaggedTensor` into a `SparseTensor` with the same values. +// +// input=ragged.from_nested_row_splits(rt_dense_values, rt_nested_splits) +// output=SparseTensor(indices=sparse_indices, values=sparse_values, +// dense_shape=sparse_dense_shape) +// +// Arguments: +// rt_nested_splits: The `row_splits` for the `RaggedTensor`. +// rt_dense_values: The `flat_values` for the `RaggedTensor`. +// +// Returns The indices for the `SparseTensor`.The values of the `SparseTensor`.`sparse_dense_shape` is a tight bounding box of the input `RaggedTensor`. 
+func RaggedTensorToSparse(scope *Scope, rt_nested_splits []tf.Output, rt_dense_values tf.Output) (sparse_indices tf.Output, sparse_values tf.Output, sparse_dense_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RaggedTensorToSparse", + Input: []tf.Input{ + tf.OutputList(rt_nested_splits), rt_dense_values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. +type FusedBatchNormGradV2Attr func(optionalAttr) + +// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. +// +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["epsilon"] = value + } +} + +// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. +// +// value: The data format for y_backprop, x, x_backprop. +// Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["is_training"] = value + } +} + +// Gradient for batch normalization. +// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// +// Arguments: +// y_backprop: A 4D Tensor for the gradient with respect to y. +// x: A 4D Tensor for input data. 
+// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch +// mean to be reused in gradient computation. When is_training is +// False, a 1D Tensor for the population mean to be reused in both +// 1st and 2nd order gradient computation. +// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch +// variance (inverted variance in the cuDNN case) to be reused in +// gradient computation. When is_training is False, a 1D Tensor +// for the population variance to be reused in both 1st and 2nd +// order gradient computation. +// +// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input +// in FusedBatchNorm. +func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FusedBatchNormGradV2", + Input: []tf.Input{ + y_backprop, x, scale, reserve_space_1, reserve_space_2, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + +// Component-wise multiplies a SparseTensor by a dense Tensor. +// +// The output locations corresponding to the implicitly zero elements in the sparse +// tensor will be zero (i.e., will not take up storage space), regardless of the +// contents of the dense tensor (even if it's +/-INF and that INF*0 == NaN). 
+// +// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not +// the other direction. +// +// Arguments: +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. +// +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseDenseCwiseMul", + Input: []tf.Input{ + sp_indices, sp_values, sp_shape, dense, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. +type MaxPool3DGradAttr func(optionalAttr) + +// MaxPool3DGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of max pooling function. +// +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. 
The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool3DGrad", + Input: []tf.Input{ + orig_input, orig_output, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the name of the device on which `resource` has been placed. +func ExperimentalIteratorGetDevice(scope *Scope, resource tf.Output) (device tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ExperimentalIteratorGetDevice", + Input: []tf.Input{ + resource, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SparseReduceSumAttr is an optional argument to SparseReduceSum. +type SparseReduceSumAttr func(optionalAttr) + +// SparseReduceSumKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceSumKeepDims(value bool) SparseReduceSumAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the sum of elements across dimensions of a SparseTensor. +// +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor` +// instead of a sparse one. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. 
If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// +// Returns `R-K`-D. The reduced Tensor. +func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseReduceSum", + Input: []tf.Input{ + input_indices, input_values, input_shape, reduction_axes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Records the latency of producing `input_dataset` elements in a StatsAggregator. +func ExperimentalLatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalLatencyStatsDataset", + Input: []tf.Input{ + input_dataset, tag, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. +// +// This Op does not require `a_indices` be sorted in standard lexicographic order. 
+// +// Arguments: +// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. +// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. +// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. +// b: `ndims`-D Tensor. With shape `a_shape`. +func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseTensorDenseAdd", + Input: []tf.Input{ + a_indices, a_values, a_shape, b, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Split a `SparseTensor` into `num_split` tensors along one dimension. +// +// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices +// `[0 : shape[split_dim] % num_split]` gets one extra dimension. +// For example, if `split_dim = 1` and `num_split = 2` and the input is +// +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] +// +// Graphically the output tensors are: +// +// output_tensor[0] = shape = [2, 4] +// [ a ] +// [b c ] +// +// output_tensor[1] = shape = [2, 3] +// [ d e ] +// [ ] +// +// Arguments: +// split_dim: 0-D. The dimension along which to split. Must be in the range +// `[0, rank(shape))`. +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. +// num_split: The number of ways to split. +// +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. 
+func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_split": num_split} + opspec := tf.OpSpec{ + Type: "SparseSplit", + Input: []tf.Input{ + split_dim, indices, values, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil { + scope.UpdateErr("SparseSplit", err) + return + } + if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil { + scope.UpdateErr("SparseSplit", err) + return + } + if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil { + scope.UpdateErr("SparseSplit", err) + return + } + return output_indices, output_values, output_shape +} + +// Applies sparse addition to `input` using individual values or slices +// +// from `updates` according to indices `indices`. The updates are non-aliasing: +// `input` is only modified in-place if no other operations will use it. +// Otherwise, a copy of `input` is made. This operation has a gradient with +// respect to both `input` and `updates`. +// +// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// +// `indices` must be integer tensor, containing indices into `input`. +// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`. +// +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or `(P-K)`-dimensional slices +// (if `K < P`) along the `K`th dimension of `input`. 
+// +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// $$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$ +// +// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 +// elements. In Python, that addition would look like this: +// +// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) +// with tf.Session() as sess: +// print(sess.run(output)) +// +// The resulting value `output` would look like this: +// +// [1, 13, 3, 14, 14, 6, 7, 20] +// +// See `tf.scatter_nd` for more details about how to make updates to slices. +// +// Arguments: +// input: A Tensor. +// indices: A Tensor. Must be one of the following types: `int32`, `int64`. +// A tensor of indices into `input`. +// updates: A Tensor. Must have the same type as ref. A tensor of updated values +// to add to `input`. +// +// Returns A `Tensor` with the same shape as `input`, containing values of `input` +// updated with `updates`. +func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ScatterNdNonAliasingAdd", + Input: []tf.Input{ + input, indices, updates, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a MultiDeviceIterator resource. +// +// Arguments: +// devices: A list of devices the iterator works across. +// shared_name: If non-empty, this resource will be shared under the given name +// across multiple sessions. +// container: If non-empty, this resource is placed in the given container. +// Otherwise, a default container is used. +// output_types: The type list for the return values. +// output_shapes: The list of shapes being produced. +// +// Returns Handle to the resource created. 
+func MultiDeviceIterator(scope *Scope, devices []string, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"devices": devices, "shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "MultiDeviceIterator", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. +type FractionalMaxPoolAttr func(optionalAttr) + +// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. +// +// value: When set to True, generates the pooling sequence in a +// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin +// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for +// difference between pseudorandom and random. +// If not specified, defaults to false +func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["pseudo_random"] = value + } +} + +// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. +// +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [20, 16] for fractional max pooling. +// If not specified, defaults to false +func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["overlapping"] = value + } +} + +// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. 
+// +// value: When set to True, a fixed pooling region will be used when +// iterating over a FractionalMaxPool node in the computation graph. Mainly used +// in unit tests to make FractionalMaxPool deterministic. +// If not specified, defaults to false +func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["deterministic"] = value + } +} + +// FractionalMaxPoolSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 is set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Performs fractional max pooling on the input. +// +// Fractional max pooling is slightly different than regular max pooling. In +// regular max pooling, you downsize an input set by taking the maximum value of +// smaller N x N subsections of the set (often 2x2), and try to reduce the set by +// a factor of N, where N is an integer. Fractional max pooling, as you might +// expect from the word "fractional", means that the overall reduction ratio N +// does not have to be an integer. +// +// The sizes of the pooling regions are generated randomly but are fairly uniform. +// For example, let's look at the height dimension, and the constraints on the +// list of rows that will be pool boundaries. +// +// First we define the following: +// +// 1. input_row_length : the number of rows from the input set +// 2. output_row_length : which will be smaller than the input +// 3.
alpha = input_row_length / output_row_length : our reduction ratio +// 4. K = floor(alpha) +// 5. row_pooling_sequence : this is the result list of pool boundary rows +// +// Then, row_pooling_sequence should satisfy: +// +// 1. a[0] = 0 : the first value of the sequence is 0 +// 2. a[end] = input_row_length : the last value of the sequence is the size +// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size +// 4. length(row_pooling_sequence) = output_row_length+1 +// +// For more details on fractional max pooling, see this paper: +// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) +// +// Arguments: +// value: 4-D with shape `[batch, height, width, channels]`. +// pooling_ratio: Pooling ratio for each dimension of `value`, currently only +// supports row and col dimension and should be >= 1.0. For example, a valid +// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements +// must be 1.0 because we don't allow pooling on batch and channels +// dimensions. 1.44 and 1.73 are the pooling ratios for the height and width dimensions +// respectively. +// +// Returns output tensor after fractional max pooling; row pooling sequence, needed to calculate gradient; column pooling sequence, needed to calculate gradient. +func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FractionalMaxPool", + Input: []tf.Input{ + value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Generates sparse cross from a list of sparse and dense tensors.
+// +// The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each +// representing features of one feature column. It outputs a 2D `SparseTensor` with +// the batchwise crosses of these features. +// +// For example, if the inputs are +// +// inputs[0]: SparseTensor with shape = [2, 2] +// [0, 0]: "a" +// [1, 0]: "b" +// [1, 1]: "c" +// +// inputs[1]: SparseTensor with shape = [2, 1] +// [0, 0]: "d" +// [1, 0]: "e" +// +// inputs[2]: Tensor [["f"], ["g"]] +// +// then the output will be +// +// shape = [2, 2] +// [0, 0]: "a_X_d_X_f" +// [1, 0]: "b_X_e_X_g" +// [1, 1]: "c_X_e_X_g" +// +// if hashed_output=true then the output will be +// +// shape = [2, 2] +// [0, 0]: FingerprintCat64( +// Fingerprint64("f"), FingerprintCat64( +// Fingerprint64("d"), Fingerprint64("a"))) +// [1, 0]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("b"))) +// [1, 1]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("c"))) +// +// Arguments: +// indices: 2-D. Indices of each input `SparseTensor`. +// values: 1-D. values of each `SparseTensor`. +// shapes: 1-D. Shapes of each `SparseTensor`. +// dense_inputs: 2-D. Columns represented by dense `Tensor`. +// hashed_output: If true, returns the hash of the cross instead of the string. +// This lets us avoid string manipulations. +// num_buckets: It is used if hashed_output is true. +// output = hashed_value%num_buckets if num_buckets > 0 else hashed_value. +// hash_key: Specify the hash_key that will be used by the `FingerprintCat64` +// function to combine the crosses fingerprints. +// +// +// +// Returns 2-D indices of the concatenated `SparseTensor`; 1-D non-empty values of the concatenated or hashed +// `SparseTensor`; 1-D shape of the concatenated `SparseTensor`.
+func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, dense_inputs []tf.Output, hashed_output bool, num_buckets int64, hash_key int64, out_type tf.DataType, internal_type tf.DataType) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"hashed_output": hashed_output, "num_buckets": num_buckets, "hash_key": hash_key, "out_type": out_type, "internal_type": internal_type} + opspec := tf.OpSpec{ + Type: "SparseCross", + Input: []tf.Input{ + tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), tf.OutputList(dense_inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Inverse real-valued fast Fourier transform. +// +// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most dimension of `input`. +// +// The inner-most dimension of `input` is assumed to be the result of `RFFT`: the +// `fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If +// `fft_length` is not provided, it is computed from the size of the inner-most +// dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to +// compute `input` is odd, it should be provided since it cannot be inferred +// properly. +// +// Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller +// than the corresponding dimension of `input`, the dimension is cropped. If it is +// larger, the dimension is padded with zeros. +// +// Arguments: +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [1]. The FFT length. +// +// Returns A float32 tensor of the same rank as `input`. The inner-most +// dimension of `input` is replaced with the `fft_length` samples of its inverse +// 1D Fourier transform.
+// +// @compatibility(numpy) +// Equivalent to np.fft.irfft +// @end_compatibility +func IRFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IRFFT", + Input: []tf.Input{ + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Concatenates a list of `SparseTensor` along the specified dimension. +// +// Concatenation is with respect to the dense versions of these sparse tensors. +// It is assumed that each input is a `SparseTensor` whose elements are ordered +// along increasing dimension number. +// +// All inputs' shapes must match, except for the concat dimension. The +// `indices`, `values`, and `shapes` lists must have the same length. +// +// The output shape is identical to the inputs', except along the concat +// dimension, where it is the sum of the inputs' sizes along that dimension. +// +// The output elements will be re-sorted to preserve the sort order along +// increasing dimension number. +// +// This op runs in `O(M log M)` time, where `M` is the total number of non-empty +// values across all inputs. This is due to the need for an internal sort in +// order to concatenate efficiently across an arbitrary dimension. +// +// For example, if `concat_dim = 1` and the inputs are +// +// sp_inputs[0]: shape = [2, 3] +// [0, 2]: "a" +// [1, 0]: "b" +// [1, 1]: "c" +// +// sp_inputs[1]: shape = [2, 4] +// [0, 1]: "d" +// [0, 2]: "e" +// +// then the output will be +// +// shape = [2, 7] +// [0, 2]: "a" +// [0, 4]: "d" +// [0, 5]: "e" +// [1, 0]: "b" +// [1, 1]: "c" +// +// Graphically this is equivalent to doing +// +// [ a] concat [ d e ] = [ a d e ] +// [b c ] [ ] [b c ] +// +// Arguments: +// indices: 2-D. Indices of each input `SparseTensor`. +// values: 1-D. Non-empty values of each `SparseTensor`. +// shapes: 1-D. Shapes of each `SparseTensor`. +// concat_dim: Dimension to concatenate along.
Must be in range [-rank, rank), +// where rank is the number of dimensions in each input `SparseTensor`. +// +// Returns 2-D indices of the concatenated `SparseTensor`; 1-D non-empty values of the concatenated `SparseTensor`; 1-D shape of the concatenated `SparseTensor`. +func SparseConcat(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, concat_dim int64) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"concat_dim": concat_dim} + opspec := tf.OpSpec{ + Type: "SparseConcat", + Input: []tf.Input{ + tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Deserialize and concatenate `SparseTensors` from a serialized minibatch. +// +// The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where +// `N` is the minibatch size and the rows correspond to packed outputs of +// `SerializeSparse`. The ranks of the original `SparseTensor` objects +// must all match. When the final `SparseTensor` is created, it has rank one +// higher than the ranks of the incoming `SparseTensor` objects +// (they have been concatenated along a new row dimension). +// +// The output `SparseTensor` object's shape values for all dimensions but the +// first are the max across the input `SparseTensor` objects' shape values +// for the corresponding dimensions. Its first shape value is `N`, the minibatch +// size. +// +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering.
+// +// For example, if the serialized input is a `[2 x 3]` matrix representing two +// original `SparseTensor` objects: +// +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// +// and +// +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// +// then the final deserialized `SparseTensor` will be: +// +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// +// Arguments: +// serialized_sparse: 2-D. The `N` serialized `SparseTensor` objects. +// Must have 3 columns. +// dtype: The `dtype` of the serialized `SparseTensor` objects. +func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "DeserializeManySparse", + Input: []tf.Input{ + serialized_sparse, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Deserialize `SparseTensor` objects. +// +// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where +// the last dimension stores serialized `SparseTensor` objects and the other N +// dimensions (N >= 0) correspond to a batch. The ranks of the original +// `SparseTensor` objects must all match. When the final `SparseTensor` is +// created, its rank is the rank of the incoming `SparseTensor` objects plus N; +// the sparse tensors have been concatenated along new dimensions, one for each +// batch. +// +// The output `SparseTensor` object's shape values for the original dimensions +// are the max across the input `SparseTensor` objects' shape values for the +// corresponding dimensions. The new dimensions match the size of the batch. +// +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order.
If this is not the case, after this +// step run `SparseReorder` to restore index ordering. +// +// For example, if the serialized input is a `[2 x 3]` matrix representing two +// original `SparseTensor` objects: +// +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// +// and +// +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// +// then the final deserialized `SparseTensor` will be: +// +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// +// Arguments: +// serialized_sparse: The serialized `SparseTensor` objects. The last dimension +// must have 3 columns. +// dtype: The `dtype` of the serialized `SparseTensor` objects. +func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "DeserializeSparse", + Input: []tf.Input{ + serialized_sparse, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. +type MaxPool3DGradGradAttr func(optionalAttr) + +// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", in which case the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes second-order gradients of the maxpooling function.
+// +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool3DGradGrad", + Input: []tf.Input{ + orig_input, orig_output, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. +type Conv3DBackpropFilterV2Attr func(optionalAttr) + +// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", in which case the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value.
+// +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to the op's registered default (value missing here; presumably [1, 1, 1, 1, 1] - confirm against the op definition) +func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of 3-D convolution with respect to the filter. +// +// Arguments: +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 5-D +// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` +// tensor. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv3DBackpropFilterV2", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Execute a sub graph on a remote processor.
+// +// The graph specifications (such as the graph itself, input tensors and output names) +// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo +// as serialized_remote_fused_graph_execute_info. +// The specifications will be passed to a dedicated registered +// remote fused graph executor. The executor will send the graph specifications +// to a remote processor and execute that graph. The execution results +// will be passed to consumer nodes as outputs of this node. +// +// Arguments: +// inputs: Arbitrary number of tensors with arbitrary data types +// +// serialized_remote_fused_graph_execute_info: Serialized protocol buffer +// of RemoteFusedGraphExecuteInfo which contains graph specifications. +// +// Returns Arbitrary number of tensors with arbitrary data types +func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} + opspec := tf.OpSpec{ + Type: "RemoteFusedGraphExecute", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("RemoteFusedGraphExecute", err) + return + } + return outputs +} + +// SerializeManySparseAttr is an optional argument to SerializeManySparse. +type SerializeManySparseAttr func(optionalAttr) + +// SerializeManySparseOutType sets the optional out_type attribute to value. +// +// value: The `dtype` to use for serialization; the supported types are `string` +// (default) and `variant`.
+// If not specified, defaults to DT_STRING +func SerializeManySparseOutType(value tf.DataType) SerializeManySparseAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object. +// +// The `SparseTensor` must have rank `R` greater than 1, and the first dimension +// is treated as the minibatch dimension. Elements of the `SparseTensor` +// must be sorted in increasing order of this first dimension. The serialized +// `SparseTensor` objects going into each row of `serialized_sparse` will have +// rank `R-1`. +// +// The minibatch size `N` is extracted from `sparse_shape[0]`. +// +// Arguments: +// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. +// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. +func SerializeManySparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeManySparseAttr) (serialized_sparse tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SerializeManySparse", + Input: []tf.Input{ + sparse_indices, sparse_values, sparse_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes inverse hyperbolic cosine of x element-wise. +func Acosh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Acosh", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes rectified linear 6 gradients for a Relu6 operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Relu6 operation.
+// features: The features passed as input to the corresponding Relu6 operation, or +// its output; using either one produces the same result. +// +// Returns The gradients: +// `gradients * (features > 0) * (features < 6)`. +func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Relu6Grad", + Input: []tf.Input{ + gradients, features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes natural logarithm of (1 + x) element-wise. +// +// I.e., \\(y = \log_e (1 + x)\\). +func Log1p(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Log1p", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResizeBicubicAttr is an optional argument to ResizeBicubic. +type ResizeBicubicAttr func(optionalAttr) + +// ResizeBicubicAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeBicubicAlignCorners(value bool) ResizeBicubicAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// ResizeBicubicHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeBicubicHalfPixelCenters(value bool) ResizeBicubicAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value + } +} + +// Resize `images` to `size` using bicubic interpolation. +// +// Input images can be of different types but output images are always float. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images.
+// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBicubicAttr) (resized_images tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeBicubic", + Input: []tf.Input{ + images, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SparseTensorDenseMatMulAttr is an optional argument to SparseTensorDenseMatMul. +type SparseTensorDenseMatMulAttr func(optionalAttr) + +// SparseTensorDenseMatMulAdjointA sets the optional adjoint_a attribute to value. +// +// value: Use the adjoint of A in the matrix multiply. If A is complex, this +// is transpose(conj(A)). Otherwise it's transpose(A). +// If not specified, defaults to false +func SparseTensorDenseMatMulAdjointA(value bool) SparseTensorDenseMatMulAttr { + return func(m optionalAttr) { + m["adjoint_a"] = value + } +} + +// SparseTensorDenseMatMulAdjointB sets the optional adjoint_b attribute to value. +// +// value: Use the adjoint of B in the matrix multiply. If B is complex, this +// is transpose(conj(B)). Otherwise it's transpose(B). +// If not specified, defaults to false +func SparseTensorDenseMatMulAdjointB(value bool) SparseTensorDenseMatMulAttr { + return func(m optionalAttr) { + m["adjoint_b"] = value + } +} + +// Multiply SparseTensor (of rank 2) "A" by dense matrix "B". +// +// No validity checking is performed on the indices of A. However, the following +// input format is recommended for optimal behavior: +// +// if adjoint_a == false: +// A should be sorted in lexicographically increasing order. Use SparseReorder +// if you're not sure. +// if adjoint_a == true: +// A should be sorted in order of increasing dimension 1 (i.e., "column major" +// order instead of "row major" order). +// +// Arguments: +// a_indices: 2-D.
The `indices` of the `SparseTensor`, size `[nnz, 2]` Matrix. +// a_values: 1-D. The `values` of the `SparseTensor`, size `[nnz]` Vector. +// a_shape: 1-D. The `shape` of the `SparseTensor`, size `[2]` Vector. +// b: 2-D. A dense Matrix. +func SparseTensorDenseMatMul(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output, optional ...SparseTensorDenseMatMulAttr) (product tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseTensorDenseMatMul", + Input: []tf.Input{ + a_indices, a_values, a_shape, b, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds two `SparseTensor` objects to produce another `SparseTensor`. +// +// The input `SparseTensor` objects' indices are assumed ordered in standard +// lexicographic order. If this is not the case, before this step run +// `SparseReorder` to restore index ordering. +// +// By default, if two values sum to zero at some index, the output `SparseTensor` +// would still include that particular location in its index, storing a zero in the +// corresponding value slot. To override this, callers can specify `thresh`, +// indicating that if the sum has a magnitude strictly smaller than `thresh`, its +// corresponding value and index would then not be included. In particular, +// `thresh == 0` (default) means everything is kept and actual thresholding happens +// only for a positive value. +// +// In the following shapes, `nnz` is the count after taking `thresh` into account. +// +// Arguments: +// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. +// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. +// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. +// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix.
+// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. +// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. +// thresh: 0-D. The magnitude threshold that determines if an output value/index +// pair takes space. +func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseAdd", + Input: []tf.Input{ + a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// EnqueueTPUEmbeddingSparseTensorBatchAttr is an optional argument to EnqueueTPUEmbeddingSparseTensorBatch. +type EnqueueTPUEmbeddingSparseTensorBatchAttr func(optionalAttr) + +// EnqueueTPUEmbeddingSparseTensorBatchDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. Should be >= 0 and less than the number +// of TPU cores in the task on which the node is placed. +// If not specified, defaults to -1 +func EnqueueTPUEmbeddingSparseTensorBatchDeviceOrdinal(value int64) EnqueueTPUEmbeddingSparseTensorBatchAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// EnqueueTPUEmbeddingSparseTensorBatchCombiners sets the optional combiners attribute to value. +// +// value: A list of string scalars, one for each embedding table, that specify +// how to normalize the embedding activations after weighted summation. +// Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have +// the sum of the weights be 0 for 'mean' or the sum of the squared weights be +// 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for +// all tables.
+// If not specified, defaults to <> +func EnqueueTPUEmbeddingSparseTensorBatchCombiners(value []string) EnqueueTPUEmbeddingSparseTensorBatchAttr { + return func(m optionalAttr) { + m["combiners"] = value + } +} + +// Eases the porting of code that uses tf.nn.embedding_lookup_sparse(). +// +// sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond +// to the ith feature. table_ids[i] indicates which embedding table to look up for the ith +// feature. +// +// The tensors at corresponding positions in the three input lists (sample_indices, +// embedding_indices and aggregation_weights) must have the same shape, i.e. rank 1 +// with dim_size() equal to the total number of lookups into the table described by +// the corresponding feature. +// +// Arguments: +// sample_indices: A list of rank 1 Tensors specifying the training example to +// which the corresponding embedding_indices and aggregation_weights values +// belong. It corresponds to sp_ids.indices[:,0] in embedding_lookup_sparse(). +// embedding_indices: A list of rank 1 Tensors, indices into the embedding tables. +// It corresponds to sp_ids.values in embedding_lookup_sparse(). +// aggregation_weights: A list of rank 1 Tensors containing per training example +// aggregation weights. It corresponds to sp_weights.values in +// embedding_lookup_sparse(). +// mode_override: A string input that overrides the mode specified in the +// TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', +// 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set +// in TPUEmbeddingConfiguration is used, otherwise mode_override is used. +// table_ids: A list of integers specifying the identifier of the embedding table +// (offset of TableDescriptor in the TPUEmbeddingConfiguration) to look up the +// corresponding input. The ith input is looked up using table_ids[i].
The size +// of the table_ids list must be equal to that of sample_indices, +// embedding_indices and aggregation_weights. +// +// Returns the created operation. +func EnqueueTPUEmbeddingSparseTensorBatch(scope *Scope, sample_indices []tf.Output, embedding_indices []tf.Output, aggregation_weights []tf.Output, mode_override tf.Output, table_ids []int64, optional ...EnqueueTPUEmbeddingSparseTensorBatchAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"table_ids": table_ids} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EnqueueTPUEmbeddingSparseTensorBatch", + Input: []tf.Input{ + tf.OutputList(sample_indices), tf.OutputList(embedding_indices), tf.OutputList(aggregation_weights), mode_override, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// The gradient operator for the SparseAdd op. +// +// The SparseAdd op calculates A + B, where A, B, and the sum are all represented +// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. +// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty +// values of A and B. +// +// Arguments: +// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to +// the non-empty values of the sum. +// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. +// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. +// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size +// `[nnz(sum), ndims]`. +// +// Returns: +// a_val_grad: 1-D with shape `[nnz(A)]`. The gradient with respect to the +// non-empty values of A. +// b_val_grad: 1-D with shape `[nnz(B)]`. The gradient with respect to the +// non-empty values of B.
+func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseAddGrad", + Input: []tf.Input{ + backprop_val_grad, a_indices, b_indices, sum_indices, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// DenseToSparseSetOperationAttr is an optional argument to DenseToSparseSetOperation. +type DenseToSparseSetOperationAttr func(optionalAttr) + +// DenseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// +// value: If true, the op validates the order and range of `set2` indices. +// If not specified, defaults to true +func DenseToSparseSetOperationValidateIndices(value bool) DenseToSparseSetOperationAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Applies set operation along last dimension of `Tensor` and `SparseTensor`. +// +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// +// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, +// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same +// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// If `validate_indices` is `True`, this op validates the order and range of `set2` +// indices. +// +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. +// +// Arguments: +// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored.
+// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. +// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must +// be the same as the 1st `n-1` dimensions of `set1`, `result_shape[n]` is the +// max set size across `n-1` dimensions. +// set_operation: Name of the set operation to apply; passed through as the +// `set_operation` attribute. +// +// Returns: +// result_indices: 2D indices of a `SparseTensor`. +// result_values: 1D values of a `SparseTensor`. +// result_shape: 1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func DenseToSparseSetOperation(scope *Scope, set1 tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...DenseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"set_operation": set_operation} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DenseToSparseSetOperation", + Input: []tf.Input{ + set1, set2_indices, set2_values, set2_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// L2 Loss. +// +// Computes half the L2 norm of a tensor without the `sqrt`: +// +// output = sum(t ** 2) / 2 +// +// Arguments: +// t: Typically 2-D, but may have any dimensions. +// +// Returns 0-D. +func L2Loss(scope *Scope, t tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "L2Loss", + Input: []tf.Input{ + t, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Conv3DAttr is an optional argument to Conv3D. +type Conv3DAttr func(optionalAttr) + +// Conv3DDataFormat sets the optional data_format attribute to value.
+// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DDataFormat(value string) Conv3DAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv3DDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to the op's registered default. +// NOTE(review): the default value is missing from this comment; presumably +// [1, 1, 1, 1, 1] (no dilation) - TODO confirm against the Conv3D op definition. +func Conv3DDilations(value []int64) Conv3DAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes a 3-D convolution given 5-D `input` and `filter` tensors. +// +// In signal processing, cross-correlation is a measure of similarity of +// two waveforms as a function of a time-lag applied to one of them. This +// is also known as a sliding dot product or sliding inner-product. +// +// Our Conv3D implements a form of cross-correlation. +// +// Arguments: +// input: Shape `[batch, in_depth, in_height, in_width, in_channels]`. +// filter: Shape `[filter_depth, filter_height, filter_width, in_channels, +// out_channels]`. `in_channels` must match between `input` and `filter`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use.
+func Conv3D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv3DAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv3D", + Input: []tf.Input{ + input, filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds up a SparseTensor and a dense Tensor, using these special rules: +// +// (1) Broadcasts the dense side to have the same shape as the sparse side, if +// eligible; +// (2) Then, only the dense values pointed to by the indices of the SparseTensor +// participate in the cwise addition. +// +// By these rules, the result is a logical SparseTensor with exactly the same +// indices and shape, but possibly with different non-zero values. The output of +// this Op is the resultant non-zero values. +// +// Arguments: +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. +// +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseDenseCwiseAdd", + Input: []tf.Input{ + sp_indices, sp_values, sp_shape, dense, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// UnicodeDecodeAttr is an optional argument to UnicodeDecode. +type UnicodeDecodeAttr func(optionalAttr) + +// UnicodeDecodeErrors sets the optional errors attribute to value. +// +// value: Error handling policy when there is invalid formatting found in the input. 
+// The value of 'strict' will cause the operation to produce an InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. +// If not specified, defaults to "replace" +func UnicodeDecodeErrors(value string) UnicodeDecodeAttr { + return func(m optionalAttr) { + m["errors"] = value + } +} + +// UnicodeDecodeReplacementChar sets the optional replacement_char attribute to value. +// +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. The default value is the default unicode replacement character, +// 0xFFFD (decimal 65533). +// If not specified, defaults to 65533 +func UnicodeDecodeReplacementChar(value int64) UnicodeDecodeAttr { + return func(m optionalAttr) { + m["replacement_char"] = value + } +} + +// UnicodeDecodeReplaceControlCharacters sets the optional replace_control_characters attribute to value. +// +// value: Whether to replace the C0 control characters (00-1F) with the +// `replacement_char`. Default is false. +// If not specified, defaults to false +func UnicodeDecodeReplaceControlCharacters(value bool) UnicodeDecodeAttr { + return func(m optionalAttr) { + m["replace_control_characters"] = value + } +} + +// Decodes each string in `input` into a sequence of Unicode code points. +// +// The character codepoints for all strings are returned using a single vector +// `char_values`, with strings expanded to characters in row-major order. +// +// The `row_splits` tensor indicates where the codepoints for +// each input string begin and end within the `char_values` tensor.
+// In particular, the values for the `i`th +// string (in row-major order) are stored in the slice +// `[row_splits[i]:row_splits[i+1]]`. Thus: +// +// * `char_values[row_splits[i]+j]` is the Unicode codepoint for the `j`th +// character in the `i`th string (in row-major order). +// * `row_splits[i+1] - row_splits[i]` is the number of characters in the `i`th +// string (in row-major order). +// +// Arguments: +// input: The text to be decoded. Can have any shape. Note that the output is flattened +// to a vector of char values. +// input_encoding: Text encoding of the input strings. This is any of the encodings supported +// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. +// +// Returns: +// row_splits: A 1D int32 tensor containing the row splits. +// char_values: A 1D int32 Tensor containing the decoded codepoints. +func UnicodeDecode(scope *Scope, input tf.Output, input_encoding string, optional ...UnicodeDecodeAttr) (row_splits tf.Output, char_values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"input_encoding": input_encoding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UnicodeDecode", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// QuantizeV2Attr is an optional argument to QuantizeV2. +type QuantizeV2Attr func(optionalAttr) + +// QuantizeV2Mode sets the optional mode attribute to value. +// If not specified, defaults to "MIN_COMBINED" +func QuantizeV2Mode(value string) QuantizeV2Attr { + return func(m optionalAttr) { + m["mode"] = value + } +} + +// QuantizeV2RoundMode sets the optional round_mode attribute to value. +// If not specified, defaults to "HALF_AWAY_FROM_ZERO" +func QuantizeV2RoundMode(value string) QuantizeV2Attr { + return func(m optionalAttr) { + m["round_mode"] = value + } +} + +// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'.
+// +// [min_range, max_range] are scalar floats that specify the range for +// the 'input' data. The 'mode' attribute controls exactly which calculations are +// used to convert the float values to their quantized equivalents. The +// 'round_mode' attribute controls which rounding tie-breaking algorithm is used +// when rounding float values to their quantized equivalents. +// +// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: +// +// ``` +// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) +// if T == qint8: out[i] -= (range(T) + 1) / 2.0 +// ``` +// +// here `range(T) = numeric_limits<T>::max() - numeric_limits<T>::min()` +// +// *MIN_COMBINED Mode Example* +// +// Assume the input is type float and has a possible range of [0.0, 6.0] and the +// output type is quint8 ([0, 255]). The min_range and max_range values should be +// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each +// value of the input by 255/6 and cast to quint8. +// +// If the output type was qint8 ([-128, 127]), the operation will additionally +// subtract each value by 128 prior to casting, so that the range of values aligns +// with the range of qint8. +// +// If the mode is 'MIN_FIRST', then this approach is used: +// +// ``` +// num_discrete_values = 1 << (# of bits in T) +// range_adjust = num_discrete_values / (num_discrete_values - 1) +// range = (range_max - range_min) * range_adjust +// range_scale = num_discrete_values / range +// quantized = round(input * range_scale) - round(range_min * range_scale) + +// numeric_limits<T>::min() +// quantized = max(quantized, numeric_limits<T>::min()) +// quantized = min(quantized, numeric_limits<T>::max()) +// ``` +// +// The biggest difference between this and MIN_COMBINED is that the minimum range +// is rounded first, before it's subtracted from the rounded value.
With +// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing +// and dequantizing will introduce a larger and larger error. +// +// *SCALED mode Example* +// +// `SCALED` mode matches the quantization approach used in +// `QuantizeAndDequantize{V2|V3}`. +// +// If the mode is `SCALED`, we do not use the full range of the output type, +// choosing to elide the lowest possible value for symmetry (e.g., output range is +// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to +// 0. +// +// We first find the range of values in our tensor. The +// range we use is always centered on 0, so we find m such that +// +// ```c++ +// m = max(abs(input_min), abs(input_max)) +// ``` +// +// Our input tensor range is then `[-m, m]`. +// +// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. +// If T is signed, this is +// +// ``` +// num_bits = sizeof(T) * 8 +// [min_fixed, max_fixed] = +// [-((1 << (num_bits - 1)) - 1), (1 << (num_bits - 1)) - 1] +// ``` +// +// Otherwise, if T is unsigned, the fixed-point range is +// +// ``` +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] +// ``` +// +// From this we compute our scaling factor, s: +// +// ```c++ +// s = (max_fixed - min_fixed) / (2 * m) +// ``` +// +// Now we can quantize the elements of our tensor: +// +// ```c++ +// result = round(input * s) +// ``` +// +// One thing to watch out for is that the operator may choose to adjust the +// requested minimum and maximum values slightly during the quantization process, +// so you should always use the output ports as the range for further calculations. +// For example, if the requested minimum and maximum values are close to equal, +// they will be separated by a small epsilon value to prevent ill-formed quantized +// buffers from being created.
Otherwise, you can end up with buffers where all the +// quantized values map to the same float value, which causes problems for +// operations that have to perform further calculations on them. +// +// Arguments: +// input: The float tensor to quantize. +// min_range: The minimum scalar value possibly produced for the input. +// max_range: The maximum scalar value possibly produced for the input. +// T: The quantized output type (for example quint8 or qint8, as used above). +// +// Returns: +// output: The quantized data produced from the float input. +// output_min: The actual minimum scalar value used for the output. +// output_max: The actual maximum scalar value used for the output. +func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizeV2", + Input: []tf.Input{ + input, min_range, max_range, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// HistogramFixedWidthAttr is an optional argument to HistogramFixedWidth. +type HistogramFixedWidthAttr func(optionalAttr) + +// HistogramFixedWidthDtype sets the optional dtype attribute to value. +// If not specified, defaults to DT_INT32 +func HistogramFixedWidthDtype(value tf.DataType) HistogramFixedWidthAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Return histogram of values. +// +// Given the tensor `values`, this operation returns a rank 1 histogram counting +// the number of entries in `values` that fall into every bin. The bins are +// equal width and determined by the arguments `value_range` and `nbins`.
+// +// ```python +// # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) +// nbins = 5 +// value_range = [0.0, 5.0] +// new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] +// +// with tf.get_default_session() as sess: +// hist = tf.histogram_fixed_width(new_values, value_range, nbins=5) +// variables.global_variables_initializer().run() +// sess.run(hist) => [2, 1, 1, 0, 2] +// ``` +// +// Arguments: +// values: Numeric `Tensor`. +// value_range: Shape [2] `Tensor` of same `dtype` as `values`. +// values <= value_range[0] will be mapped to hist[0], +// values >= value_range[1] will be mapped to hist[-1]. +// nbins: Scalar `int32 Tensor`. Number of histogram bins. +// +// Returns A 1-D `Tensor` holding histogram of values. +func HistogramFixedWidth(scope *Scope, values tf.Output, value_range tf.Output, nbins tf.Output, optional ...HistogramFixedWidthAttr) (out tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "HistogramFixedWidth", + Input: []tf.Input{ + values, value_range, nbins, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Serializes the tree handle to a proto. +// +// Arguments: +// tree_handle: Handle to the tree resource to be serialized. +// +// Returns Serialized proto string of the tree resource. +func TensorForestTreeSerialize(scope *Scope, tree_handle tf.Output) (tree_config tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorForestTreeSerialize", + Input: []tf.Input{ + tree_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SparseMatMulAttr is an optional argument to SparseMatMul. +type SparseMatMulAttr func(optionalAttr) + +// SparseMatMulTransposeA sets the optional transpose_a attribute to value.
+// If not specified, defaults to false +func SparseMatMulTransposeA(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// SparseMatMulTransposeB sets the optional transpose_b attribute to value. +// If not specified, defaults to false +func SparseMatMulTransposeB(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["a_is_sparse"] = value + } +} + +// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["b_is_sparse"] = value + } +} + +// Multiply matrix "a" by matrix "b". +// +// The inputs must be two-dimensional matrices and the inner dimension of "a" must +// match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not +// `SparseTensor`s. This op is optimized for the case where at least one of "a" or +// "b" is sparse, in the sense that they have a large proportion of zero values. +// The breakeven for using this versus a dense matrix multiply on one platform was +// 30% zero values in the sparse matrix. +// +// The gradient computation of this operation will only take advantage of sparsity +// in the input gradient when that gradient comes from a Relu. 
+func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + // NOTE: the loop variable `a` shadows the input parameter `a` only inside this + // loop; the Input list below still refers to the parameter. + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseMatMul", + Input: []tf.Input{ + a, b, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ExperimentalThreadPoolHandleAttr is an optional argument to ExperimentalThreadPoolHandle. +type ExperimentalThreadPoolHandleAttr func(optionalAttr) + +// ExperimentalThreadPoolHandleMaxIntraOpParallelism sets the optional max_intra_op_parallelism attribute to value. +// +// value: The maximum degree of parallelism to use within operations that execute on this +// threadpool. +// If not specified, defaults to 1 +func ExperimentalThreadPoolHandleMaxIntraOpParallelism(value int64) ExperimentalThreadPoolHandleAttr { + return func(m optionalAttr) { + m["max_intra_op_parallelism"] = value + } +} + +// ExperimentalThreadPoolHandleContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func ExperimentalThreadPoolHandleContainer(value string) ExperimentalThreadPoolHandleAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// ExperimentalThreadPoolHandleSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func ExperimentalThreadPoolHandleSharedName(value string) ExperimentalThreadPoolHandleAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a dataset that uses a custom thread pool to compute `input_dataset`. +// +// NOTE(review): per the Returns text below, this op yields a thread pool resource +// rather than a dataset; the summary above presumably belongs to the companion +// ExperimentalThreadPoolDataset op - confirm. +// +// Arguments: +// num_threads: The number of threads in the thread pool. +// display_name: A human-readable name for the threads that may be visible in some +// visualizations. +// +// Returns A resource that can be consumed by one or more ExperimentalThreadPoolDataset +// ops.
+func ExperimentalThreadPoolHandle(scope *Scope, num_threads int64, display_name string, optional ...ExperimentalThreadPoolHandleAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_threads": num_threads, "display_name": display_name} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ExperimentalThreadPoolHandle", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug. +type LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load proximal Adagrad embedding parameters with debug support. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the proximal Adagrad optimization algorithm. 
+// accumulators: Value of accumulators used in the proximal Adagrad optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the proximal Adagrad optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug", + Input: []tf.Input{ + parameters, accumulators, gradient_accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// LoadTPUEmbeddingProximalAdagradParametersAttr is an optional argument to LoadTPUEmbeddingProximalAdagradParameters. +type LoadTPUEmbeddingProximalAdagradParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingProximalAdagradParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingProximalAdagradParametersTableId(value int64) LoadTPUEmbeddingProximalAdagradParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingProximalAdagradParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingProximalAdagradParametersTableName(value string) LoadTPUEmbeddingProximalAdagradParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load proximal Adagrad embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. 
Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the proximal Adagrad optimization algorithm. +// accumulators: Value of accumulators used in the proximal Adagrad optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingProximalAdagradParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingProximalAdagradParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingProximalAdagradParameters", + Input: []tf.Input{ + parameters, accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Get the current size of the TensorArray. +// +// Arguments: +// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). +// flow_in: A float scalar that enforces proper chaining of operations. +// +// Returns The current size of the TensorArray. +func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArraySizeV3", + Input: []tf.Input{ + handle, flow_in, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes gradients for the scaled exponential linear (Selu) operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Selu operation. +// outputs: The outputs of the corresponding Selu operation. +// +// Returns The gradients: `gradients * (outputs + scale * alpha)` +// if outputs < 0, `scale * gradients` otherwise. 
+func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SeluGrad", + Input: []tf.Input{ + gradients, outputs, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. +type ResourceSparseApplyFtrlV2Attr func(optionalAttr) + +// ResourceSparseApplyFtrlV2UseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2Attr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// +// That is for rows we have grad for, we update var, accum and linear as follows: +// grad_with_shrinkage = grad + 2 * l2_shrinkage * var +// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage +// linear += grad_with_shrinkage + +// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 shrinkage regularization. Must be a scalar. +// l2_shrinkage: Shrinkage coefficient used in +// `grad_with_shrinkage = grad + 2 * l2_shrinkage * var` above. +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation.
+func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlV2Attr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyFtrlV2", + Input: []tf.Input{ + var_, accum, linear, grad, indices, lr, l1, l2, l2_shrinkage, lr_power, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// SumAttr is an optional argument to Sum. +type SumAttr func(optionalAttr) + +// SumKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SumKeepDims(value bool) SumAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the sum of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Sum", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. 
+type SparseToSparseSetOperationAttr func(optionalAttr) + +// SparseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func SparseToSparseSetOperationValidateIndices(value bool) SparseToSparseSetOperationAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Applies set operation along last dimension of 2 `SparseTensor` inputs. +// +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// +// If `validate_indices` is `True`, `SparseToSparseSetOperation` validates the +// order and range of `set1` and `set2` indices. +// +// Input `set1` is a `SparseTensor` represented by `set1_indices`, `set1_values`, +// and `set1_shape`. For `set1` ranked `n`, 1st `n-1` dimensions must be the same +// as `set2`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, +// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same +// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// If `validate_indices` is `True`, this op validates the order and range of `set1` +// and `set2` indices. +// +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. +// +// Arguments: +// set1_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set1_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. +// set1_shape: 1D `Tensor`, shape of a `SparseTensor`. 
`set1_shape[0...n-1]` must +// be the same as `set2_shape[0...n-1]`, `set1_shape[n]` is the +// max set size across `0...n-1` dimensions. +// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. +// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must +// be the same as `set1_shape[0...n-1]`, `set2_shape[n]` is the +// max set size across `0...n-1` dimensions. +// +// +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_values tf.Output, set1_shape tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...SparseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"set_operation": set_operation} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseToSparseSetOperation", + Input: []tf.Input{ + set1_indices, set1_values, set1_shape, set2_indices, set2_values, set2_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes softmax cross entropy cost and gradients to backpropagate. +// +// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept +// a matrix of label probabilities, but rather a single label per row +// of features. This label is considered to have probability 1.0 for the +// given row. +// +// Inputs are the logits, not probabilities. 
+// +// Arguments: +// features: batch_size x num_classes matrix +// labels: batch_size vector with values in [0, num_classes). +// This is the label for the given minibatch entry. +// +// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). +func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSoftmaxCrossEntropyWithLogits", + Input: []tf.Input{ + features, labels, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// StridedSliceGradAttr is an optional argument to StridedSliceGrad. +type StridedSliceGradAttr func(optionalAttr) + +// StridedSliceGradBeginMask sets the optional begin_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradBeginMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["begin_mask"] = value + } +} + +// StridedSliceGradEndMask sets the optional end_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradEndMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["end_mask"] = value + } +} + +// StridedSliceGradEllipsisMask sets the optional ellipsis_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradEllipsisMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["ellipsis_mask"] = value + } +} + +// StridedSliceGradNewAxisMask sets the optional new_axis_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradNewAxisMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["new_axis_mask"] = value + } +} + +// StridedSliceGradShrinkAxisMask sets the optional shrink_axis_mask attribute to value. 
+// If not specified, defaults to 0 +func StridedSliceGradShrinkAxisMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["shrink_axis_mask"] = value + } +} + +// Returns the gradient of `StridedSlice`. +// +// Since `StridedSlice` cuts out pieces of its `input` which is size +// `shape`, its gradient will have the same shape (which is passed here +// as `shape`). The gradient will be zero in any element that the slice +// does not select. +// +// Arguments are the same as StridedSliceGrad with the exception that +// `dy` is the input gradient to be propagated and `shape` is the +// shape of `StridedSlice`'s `input`. +func StridedSliceGrad(scope *Scope, shape tf.Output, begin tf.Output, end tf.Output, strides tf.Output, dy tf.Output, optional ...StridedSliceGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StridedSliceGrad", + Input: []tf.Input{ + shape, begin, end, strides, dy, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingRMSPropParametersAttr is an optional argument to LoadTPUEmbeddingRMSPropParameters. +type LoadTPUEmbeddingRMSPropParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingRMSPropParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingRMSPropParametersTableId(value int64) LoadTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingRMSPropParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingRMSPropParametersTableName(value string) LoadTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load RMSProp embedding parameters. 
+// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the RMSProp optimization algorithm. +// ms: Value of ms used in the RMSProp optimization algorithm. +// mom: Value of mom used in the RMSProp optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingRMSPropParameters(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingRMSPropParameters", + Input: []tf.Input{ + parameters, ms, mom, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Computes the gradient for the inverse of `x` wrt its input. +// +// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +// is the corresponding input gradient. +func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReciprocalGrad", + Input: []tf.Input{ + y, dy, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// EuclideanNormAttr is an optional argument to EuclideanNorm. +type EuclideanNormAttr func(optionalAttr) + +// EuclideanNormKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. 
+// If not specified, defaults to false +func EuclideanNormKeepDims(value bool) EuclideanNormAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the euclidean norm of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func EuclideanNorm(scope *Scope, input tf.Output, axis tf.Output, optional ...EuclideanNormAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EuclideanNorm", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the element-wise min of two SparseTensors. +// +// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// +// Arguments: +// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, in the canonical lexicographic ordering. +// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. +// a_shape: 1-D. Shape of the input SparseTensor. +// b_indices: counterpart to `a_indices` for the other operand. +// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. +// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. +// +// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. 
+func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSparseMinimum", + Input: []tf.Input{ + a_indices, a_values, a_shape, b_indices, b_values, b_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. +type ResourceSparseApplyAdagradDAAttr func(optionalAttr) + +// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. +// +// Arguments: +// var_: Should be from a Variable(). +// gradient_accumulator: Should be from a Variable(). +// gradient_squared_accumulator: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Learning rate. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// global_step: Training step number. Must be a scalar. +// +// Returns the created operation. 
+func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyAdagradDA", + Input: []tf.Input{ + var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// EncodeJpegAttr is an optional argument to EncodeJpeg. +type EncodeJpegAttr func(optionalAttr) + +// EncodeJpegFormat sets the optional format attribute to value. +// +// value: Per pixel image format. +// If not specified, defaults to "" +func EncodeJpegFormat(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["format"] = value + } +} + +// EncodeJpegQuality sets the optional quality attribute to value. +// +// value: Quality of the compression from 0 to 100 (higher is better and slower). +// If not specified, defaults to 95 +func EncodeJpegQuality(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["quality"] = value + } +} + +// EncodeJpegProgressive sets the optional progressive attribute to value. +// +// value: If True, create a JPEG that loads progressively (coarse to fine). +// If not specified, defaults to false +func EncodeJpegProgressive(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["progressive"] = value + } +} + +// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. +// +// value: If True, spend CPU/RAM to reduce size with no quality change. 
+// If not specified, defaults to false +func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["optimize_size"] = value + } +} + +// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. +// +// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. +// If not specified, defaults to true +func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["chroma_downsampling"] = value + } +} + +// EncodeJpegDensityUnit sets the optional density_unit attribute to value. +// +// value: Unit used to specify `x_density` and `y_density`: +// pixels per inch (`'in'`) or centimeter (`'cm'`). +// If not specified, defaults to "in" +func EncodeJpegDensityUnit(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["density_unit"] = value + } +} + +// EncodeJpegXDensity sets the optional x_density attribute to value. +// +// value: Horizontal pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegXDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["x_density"] = value + } +} + +// EncodeJpegYDensity sets the optional y_density attribute to value. +// +// value: Vertical pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegYDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["y_density"] = value + } +} + +// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. +// +// value: If not empty, embed this XMP metadata in the image header. +// If not specified, defaults to "" +func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["xmp_metadata"] = value + } +} + +// JPEG-encode an image. +// +// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. +// +// The attr `format` can be used to override the color format of the encoded +// output. 
Values can be: +// +// * `''`: Use a default format based on the number of channels in the image. +// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension +// of `image` must be 1. +// * `rgb`: Output an RGB JPEG image. The `channels` dimension +// of `image` must be 3. +// +// If `format` is not specified or is the empty string, a default format is picked +// in function of the number of channels in `image`: +// +// * 1: Output a grayscale image. +// * 3: Output an RGB image. +// +// Arguments: +// image: 3-D with shape `[height, width, channels]`. +// +// Returns 0-D. JPEG-encoded image. +func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EncodeJpeg", + Input: []tf.Input{ + image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MultinomialAttr is an optional argument to Multinomial. +type MultinomialAttr func(optionalAttr) + +// MultinomialSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 is set to be non-zero, the internal random number +// generator is seeded by the given seed. Otherwise, a random seed is used. +// If not specified, defaults to 0 +func MultinomialSeed(value int64) MultinomialAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// MultinomialSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func MultinomialSeed2(value int64) MultinomialAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// MultinomialOutputDtype sets the optional output_dtype attribute to value. 
+// If not specified, defaults to DT_INT64 +func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { + return func(m optionalAttr) { + m["output_dtype"] = value + } +} + +// Draws samples from a multinomial distribution. +// +// Arguments: +// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` +// represents the unnormalized log probabilities for all classes. +// num_samples: 0-D. Number of independent samples to draw for each row slice. +// +// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` +// contains the drawn class labels with range `[0, num_classes)`. +func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Multinomial", + Input: []tf.Input{ + logits, num_samples, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingRMSPropParametersAttr is an optional argument to RetrieveTPUEmbeddingRMSPropParameters. +type RetrieveTPUEmbeddingRMSPropParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingRMSPropParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingRMSPropParametersTableId(value int64) RetrieveTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingRMSPropParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingRMSPropParametersTableName(value string) RetrieveTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve RMSProp embedding parameters. 
+// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the RMSProp optimization algorithm.Parameter ms updated by the RMSProp optimization algorithm.Parameter mom updated by the RMSProp optimization algorithm. +func RetrieveTPUEmbeddingRMSPropParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingRMSPropParametersAttr) (parameters tf.Output, ms tf.Output, mom tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingRMSPropParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. +type QuantizedRelu6Attr func(optionalAttr) + +// QuantizedRelu6OutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` +// +// Arguments: +// +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. +// +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. 
+func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedRelu6", + Input: []tf.Input{ + features, min_features, max_features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// BatchMatMulAttr is an optional argument to BatchMatMul. +type BatchMatMulAttr func(optionalAttr) + +// BatchMatMulAdjX sets the optional adj_x attribute to value. +// +// value: If `True`, adjoint the slices of `x`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulAdjX(value bool) BatchMatMulAttr { + return func(m optionalAttr) { + m["adj_x"] = value + } +} + +// BatchMatMulAdjY sets the optional adj_y attribute to value. +// +// value: If `True`, adjoint the slices of `y`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulAdjY(value bool) BatchMatMulAttr { + return func(m optionalAttr) { + m["adj_y"] = value + } +} + +// Multiplies slices of two tensors in batches. +// +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. Each of the +// individual slices can optionally be adjointed (to adjoint a matrix +// means to transpose and conjugate it) before multiplication by setting +// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. 
+// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if adj_x else r_x +// c_o = r_y if adj_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +// +// Arguments: +// x: 2-D or higher with shape `[..., r_x, c_x]`. +// y: 2-D or higher with shape `[..., r_y, c_y]`. +// +// Returns 3-D or higher with shape `[..., r_o, c_o]` +func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMulAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "BatchMatMul", + Input: []tf.Input{ + x, y, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ParseSequenceExampleAttr is an optional argument to ParseSequenceExample. +type ParseSequenceExampleAttr func(optionalAttr) + +// ParseSequenceExampleNcontextSparse sets the optional Ncontext_sparse attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func ParseSequenceExampleNcontextSparse(value int64) ParseSequenceExampleAttr { + return func(m optionalAttr) { + m["Ncontext_sparse"] = value + } +} + +// ParseSequenceExampleNcontextDense sets the optional Ncontext_dense attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func ParseSequenceExampleNcontextDense(value int64) ParseSequenceExampleAttr { + return func(m optionalAttr) { + m["Ncontext_dense"] = value + } +} + +// ParseSequenceExampleNfeatureListSparse sets the optional Nfeature_list_sparse attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func ParseSequenceExampleNfeatureListSparse(value int64) ParseSequenceExampleAttr { + return func(m optionalAttr) { + m["Nfeature_list_sparse"] = value + } +} + +// ParseSequenceExampleNfeatureListDense sets the optional Nfeature_list_dense attribute to value. 
+// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func ParseSequenceExampleNfeatureListDense(value int64) ParseSequenceExampleAttr { + return func(m optionalAttr) { + m["Nfeature_list_dense"] = value + } +} + +// ParseSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. +// +// value: A list of Ncontext_sparse types; the data types of data in +// each context Feature given in context_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSequenceExampleContextSparseTypes(value []tf.DataType) ParseSequenceExampleAttr { + return func(m optionalAttr) { + m["context_sparse_types"] = value + } +} + +// ParseSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_types"] = value + } +} + +// ParseSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. +// +// value: A list of Ncontext_dense shapes; the shapes of data in +// each context Feature given in context_dense_keys. +// The number of elements in the Feature corresponding to context_dense_key[j] +// must always equal context_dense_shapes[j].NumEntries(). +// The shape of context_dense_values[j] will match context_dense_shapes[j]. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSequenceExampleContextDenseShapes(value []tf.Shape) ParseSequenceExampleAttr { + return func(m optionalAttr) { + m["context_dense_shapes"] = value + } +} + +// ParseSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. 
+// +// value: A list of Nfeature_list_sparse types; the data types +// of data in each FeatureList given in feature_list_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_sparse_types"] = value + } +} + +// ParseSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. +// +// value: A list of Nfeature_list_dense shapes; the shapes of +// data in each FeatureList given in feature_list_dense_keys. +// The shape of each Feature in the FeatureList corresponding to +// feature_list_dense_key[j] must always equal +// feature_list_dense_shapes[j].NumEntries(). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_shapes"] = value + } +} + +// Transforms a vector of brain.SequenceExample protos (as strings) into typed tensors. +// +// Arguments: +// serialized: A vector containing binary serialized SequenceExample protos. +// debug_name: A vector containing the names of the serialized protos. +// May contain, for example, table key (descriptive) name for the +// corresponding serialized proto. This is purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty vector if no name is available. +// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). +// context_dense_defaults[j] provides default values +// when the SequenceExample's context map lacks context_dense_key[j]. 
+// If an empty Tensor is provided for context_dense_defaults[j], +// then the Feature context_dense_keys[j] is required. +// The input type is inferred from context_dense_defaults[j], even when it's +// empty. If context_dense_defaults[j] is not empty, its shape must match +// context_dense_shapes[j]. +// feature_list_dense_missing_assumed_empty: A vector listing the +// FeatureList keys which may be missing from the SequenceExamples. If the +// associated FeatureList is missing, it is treated as empty. By default, +// any FeatureList not listed in this vector must exist in the SequenceExamples. +// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). +// The keys expected in the Examples' features associated with context_sparse +// values. +// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' context features associated with +// dense values. +// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors +// (scalars). The keys expected in the FeatureLists associated with sparse +// values. +// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' feature_lists associated +// with lists of dense values. 
+func ParseSequenceExample(scope *Scope, serialized tf.Output, debug_name tf.Output, context_dense_defaults []tf.Output, feature_list_dense_missing_assumed_empty []string, context_sparse_keys []string, context_dense_keys []string, feature_list_sparse_keys []string, feature_list_dense_keys []string, optional ...ParseSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output, feature_list_dense_lengths []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"feature_list_dense_missing_assumed_empty": feature_list_dense_missing_assumed_empty, "context_sparse_keys": context_sparse_keys, "context_dense_keys": context_dense_keys, "feature_list_sparse_keys": feature_list_sparse_keys, "feature_list_dense_keys": feature_list_dense_keys} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ParseSequenceExample", + Input: []tf.Input{ + serialized, debug_name, tf.OutputList(context_dense_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { + scope.UpdateErr("ParseSequenceExample", err) + return + } + if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { + scope.UpdateErr("ParseSequenceExample", err) + return + } + if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSequenceExample", err) + return + } + if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { + scope.UpdateErr("ParseSequenceExample", err) + return + } + 
if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { + scope.UpdateErr("ParseSequenceExample", err) + return + } + if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { + scope.UpdateErr("ParseSequenceExample", err) + return + } + if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSequenceExample", err) + return + } + if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { + scope.UpdateErr("ParseSequenceExample", err) + return + } + if feature_list_dense_lengths, idx, err = makeOutputList(op, idx, "feature_list_dense_lengths"); err != nil { + scope.UpdateErr("ParseSequenceExample", err) + return + } + return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values, feature_list_dense_lengths +} + +// QuantizedReluAttr is an optional argument to QuantizedRelu. +type QuantizedReluAttr func(optionalAttr) + +// QuantizedReluOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedReluOutType(value tf.DataType) QuantizedReluAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Computes Quantized Rectified Linear: `max(features, 0)` +// +// Arguments: +// +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. +// +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. 
+func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedRelu", + Input: []tf.Input{ + features, min_features, max_features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Reorders a SparseTensor into the canonical, row-major ordering. +// +// Note that by convention, all sparse ops preserve the canonical ordering along +// increasing dimension number. The only time ordering can be violated is during +// manual manipulation of the indices and values vectors to add entries. +// +// Reordering does not affect the shape of the SparseTensor. +// +// If the tensor has rank `R` and `N` non-empty values, `input_indices` has +// shape `[N, R]`, input_values has length `N`, and input_shape has length `R`. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// +// Returns 2-D. `N x R` matrix with the same indices as input_indices, but +// in canonical row-major ordering.1-D. `N` non-empty values corresponding to `output_indices`. 
+func SparseReorder(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseReorder", + Input: []tf.Input{ + input_indices, input_values, input_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// PrelinearizeTupleAttr is an optional argument to PrelinearizeTuple. +type PrelinearizeTupleAttr func(optionalAttr) + +// PrelinearizeTupleLayouts sets the optional layouts attribute to value. +// +// value: A vector holding the requested layout in minor-to-major sequence for all the +// tuple shapes in the order the shapes appear in the "shapes" input. The layout +// elements for a sub-shape can be set to -1 in which case the corresponding layout +// will be computed by the infeed operation. +// If not specified, defaults to <> +func PrelinearizeTupleLayouts(value []int64) PrelinearizeTupleAttr { + return func(m optionalAttr) { + m["layouts"] = value + } +} + +// An op which linearizes multiple Tensor values to an opaque variant tensor. +// +// Arguments: +// inputs: A list of tensors that will be provided using the infeed mechanism. +// shapes: The shapes of each tensor in `inputs`. +func PrelinearizeTuple(scope *Scope, inputs []tf.Output, shapes []tf.Shape, optional ...PrelinearizeTupleAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shapes": shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "PrelinearizeTuple", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ComplexAbsAttr is an optional argument to ComplexAbs. +type ComplexAbsAttr func(optionalAttr) + +// ComplexAbsTout sets the optional Tout attribute to value. 
+// If not specified, defaults to DT_FLOAT +func ComplexAbsTout(value tf.DataType) ComplexAbsAttr { + return func(m optionalAttr) { + m["Tout"] = value + } +} + +// Computes the complex absolute value of a tensor. +// +// Given a tensor `x` of complex numbers, this operation returns a tensor of type +// `float` or `double` that is the absolute value of each element in `x`. All +// elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute +// value is computed as \\( \sqrt{a^2 + b^2}\\). +func ComplexAbs(scope *Scope, x tf.Output, optional ...ComplexAbsAttr) (y tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ComplexAbs", + Input: []tf.Input{ + x, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// VariableShapeAttr is an optional argument to VariableShape. +type VariableShapeAttr func(optionalAttr) + +// VariableShapeOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func VariableShapeOutType(value tf.DataType) VariableShapeAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Returns the shape of the variable pointed to by `resource`. +// +// This operation returns a 1-D integer tensor representing the shape of `input`. +// +// For example: +// +// ``` +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// shape(t) ==> [2, 2, 3] +// ``` +func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "VariableShape", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// CompilationResultProto indicating the status of the TPU compilation. 
+func TPUCompilationResult(scope *Scope) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TPUCompilationResult", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics. +// +// Arguments: +// tree_ensemble_handle: Handle to the tree ensemble. +// +// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest +// layer. +func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BoostedTreesGetEnsembleStates", + Input: []tf.Input{ + tree_ensemble_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + +// Store the input tensor in the state of the current session. +// +// Arguments: +// value: The tensor to be stored. +// +// Returns The handle for the tensor stored in the session state, represented +// as a ResourceHandle object. +func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "GetSessionHandleV2", + Input: []tf.Input{ + value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. +type ResourceApplyAdamAttr func(optionalAttr) + +// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. 
+// +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. +// +// value: If `True`, uses the nesterov update. +// If not specified, defaults to false +func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update '*var' according to the Adam algorithm. +// +// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ +// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ +// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ +// $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// beta1_power: Must be a scalar. +// beta2_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. 
+func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdam", + Input: []tf.Input{ + var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// SdcaOptimizerAttr is an optional argument to SdcaOptimizer. +type SdcaOptimizerAttr func(optionalAttr) + +// SdcaOptimizerAdaptative sets the optional adaptative attribute to value. +// +// value: Whether to use Adaptive SDCA for the inner loop. +// If not specified, defaults to true +func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { + return func(m optionalAttr) { + m["adaptative"] = value + } +} + +// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for +// +// linear models with L1 + L2 regularization. As global optimization objective is +// strongly-convex, the optimizer optimizes the dual objective at each step. The +// optimizer applies each update one example at a time. Examples are sampled +// uniformly, and the optimizer is learning rate free and enjoys linear convergence +// rate. +// +// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
+// Shai Shalev-Shwartz, Tong Zhang. 2012 +// +// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ +// +// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
+// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, +// Peter Richtarik, Martin Takac. 2015 +// +// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
+// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 +// +// Arguments: +// sparse_example_indices: a list of vectors which contain example indices. +// sparse_feature_indices: a list of vectors which contain feature indices. +// sparse_feature_values: a list of vectors which contains feature value +// associated with each feature group. +// dense_features: a list of matrices which contains the dense feature values. +// example_weights: a vector which contains the weight associated with each +// example. +// example_labels: a vector which contains the label/target associated with each +// example. +// sparse_indices: a list of vectors where each value is the indices which has +// corresponding weights in sparse_weights. This field maybe omitted for the +// dense approach. +// sparse_weights: a list of vectors where each value is the weight associated with +// a sparse feature group. +// dense_weights: a list of vectors where the values are the weights associated +// with a dense feature group. +// example_state_data: a list of vectors containing the example state data. +// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, +// squared and hinge losses. +// l1: Symmetric l1 regularization strength. +// l2: Symmetric l2 regularization strength. +// num_loss_partitions: Number of partitions of the global loss function. +// num_inner_iterations: Number of iterations per mini-batch. +// +// Returns a list of vectors containing the updated example state +// data.a list of vectors where each value is the delta +// weights associated with a sparse feature group.a list of vectors where the values are the delta +// weights associated with a dense feature group. 
+func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SdcaOptimizer", + Input: []tf.Input{ + tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + out_example_state_data = op.Output(idx) + if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return + } + if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return + } + return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights +} + +// ExperimentalParseExampleDatasetAttr is an optional argument to ExperimentalParseExampleDataset. 
+type ExperimentalParseExampleDatasetAttr func(optionalAttr) + +// ExperimentalParseExampleDatasetSloppy sets the optional sloppy attribute to value. +// If not specified, defaults to false +func ExperimentalParseExampleDatasetSloppy(value bool) ExperimentalParseExampleDatasetAttr { + return func(m optionalAttr) { + m["sloppy"] = value + } +} + +// Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features. +// +// Arguments: +// +// +// dense_defaults: A dict mapping string keys to `Tensor`s. +// The keys of the dict must match the dense_keys of the feature. +// sparse_keys: A list of string keys in the examples features. +// The results for these keys will be returned as `SparseTensor` objects. +// dense_keys: A list of Ndense string Tensors (scalars). +// The keys expected in the Examples features associated with dense values. +// sparse_types: A list of `DTypes` of the same length as `sparse_keys`. +// Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), +// and `tf.string` (`BytesList`) are supported. +// dense_shapes: List of tuples with the same length as `dense_keys`. +// The shape of the data for each dense feature referenced by `dense_keys`. +// Required for any input tensors identified by `dense_keys`. Must be +// either fully defined, or may contain an unknown first dimension. +// An unknown first dimension means the feature is treated as having +// a variable number of blocks, and the output shape along this dimension +// is considered unknown at graph build time. Padding is applied for +// minibatch elements smaller than the maximum number of blocks for the +// given feature along this dimension. +// output_types: The type list for the return values. +// output_shapes: The list of shapes being produced. 
+func ExperimentalParseExampleDataset(scope *Scope, input_dataset tf.Output, num_parallel_calls tf.Output, dense_defaults []tf.Output, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ExperimentalParseExampleDatasetAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes, "output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ExperimentalParseExampleDataset", + Input: []tf.Input{ + input_dataset, num_parallel_calls, tf.OutputList(dense_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// PrelinearizeAttr is an optional argument to Prelinearize. +type PrelinearizeAttr func(optionalAttr) + +// PrelinearizeShape sets the optional shape attribute to value. +// +// value: The shape of the tensor. +// If not specified, defaults to <> +func PrelinearizeShape(value tf.Shape) PrelinearizeAttr { + return func(m optionalAttr) { + m["shape"] = value + } +} + +// PrelinearizeLayout sets the optional layout attribute to value. +// +// value: A vector holding the requested layout in minor-to-major sequence. If a layout +// attribute is passed but its values are all -1 the layout will be computed by +// the infeed operation. +// If not specified, defaults to <> +func PrelinearizeLayout(value []int64) PrelinearizeAttr { + return func(m optionalAttr) { + m["layout"] = value + } +} + +// An op which linearizes one Tensor value to an opaque variant tensor. +// +// Arguments: +// input: A tensor that will be linearized. 
+func Prelinearize(scope *Scope, input tf.Output, optional ...PrelinearizeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Prelinearize", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. +type QuantizedMatMulAttr func(optionalAttr) + +// QuantizedMatMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Toutput"] = value + } +} + +// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. +// +// value: If true, `a` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. +// +// value: If true, `b` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. +// +// value: The type of output produced by activation function +// following this operation. +// If not specified, defaults to DT_QUINT8 +func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Tactivation"] = value + } +} + +// Perform a quantized matrix multiplication of `a` by the matrix `b`. 
+// +// The inputs must be two-dimensional matrices and the inner dimension of +// `a` (after being transposed if `transpose_a` is non-zero) must match the +// outer dimension of `b` (after being transposed if `transposed_b` is +// non-zero). +// +// Arguments: +// a: Must be a two-dimensional tensor. +// b: Must be a two-dimensional tensor. +// min_a: The float value that the lowest quantized `a` value represents. +// max_a: The float value that the highest quantized `a` value represents. +// min_b: The float value that the lowest quantized `b` value represents. +// max_b: The float value that the highest quantized `b` value represents. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { if scope.Err() != nil { return } @@ -11750,115 +16474,122 @@ func CudnnRNN(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Outpu a(attrs) } opspec := tf.OpSpec{ - Type: "CudnnRNN", + Type: "QuantizedMatMul", Input: []tf.Input{ - input, input_h, input_c, params, + a, b, min_a, max_a, min_b, max_b, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) + return op.Output(0), op.Output(1), op.Output(2) } -// DecodeCompressedAttr is an optional argument to DecodeCompressed. -type DecodeCompressedAttr func(optionalAttr) - -// DecodeCompressedCompressionType sets the optional compression_type attribute to value. +// Inverse 2D real-valued fast Fourier transform. // -// value: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". 
-// If not specified, defaults to "" -func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { - return func(m optionalAttr) { - m["compression_type"] = value - } -} - -// Decompress strings. +// Computes the inverse 2-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most 2 dimensions of `input`. // -// This op decompresses each element of the `bytes` input `Tensor`, which -// is assumed to be compressed using the given `compression_type`. +// The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`: +// The inner-most dimension contains the `fft_length / 2 + 1` unique components of +// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed +// from the size of the inner-most 2 dimensions of `input`. If the FFT length used +// to compute `input` is odd, it should be provided since it cannot be inferred +// properly. // -// The `output` is a string `Tensor` of the same shape as `bytes`, -// each element containing the decompressed data from the corresponding -// element in `bytes`. +// Along each axis `IRFFT2D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// bytes: A Tensor of string which is compressed. +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. // -// Returns A Tensor with the same shape as input `bytes`, uncompressed -// from bytes. -func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { +// Returns A float32 tensor of the same rank as `input`. The inner-most 2 +// dimensions of `input` are replaced with the `fft_length` samples of their +// inverse 2D Fourier transform. 
+// +// @compatibility(numpy) +// Equivalent to np.fft.irfft2 +// @end_compatibility +func IRFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "DecodeCompressed", + Type: "IRFFT2D", Input: []tf.Input{ - bytes, + input, fft_length, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeRawAttr is an optional argument to DecodeRaw. -type DecodeRawAttr func(optionalAttr) +// InfeedEnqueueTupleAttr is an optional argument to InfeedEnqueueTuple. +type InfeedEnqueueTupleAttr func(optionalAttr) -// DecodeRawLittleEndian sets the optional little_endian attribute to value. +// InfeedEnqueueTupleLayouts sets the optional layouts attribute to value. // -// value: Whether the input `bytes` are in little-endian order. -// Ignored for `out_type` values that are stored in a single byte like -// `uint8`. -// If not specified, defaults to true -func DecodeRawLittleEndian(value bool) DecodeRawAttr { +// value: A vector holding the requested layout in minor-to-major sequence for +// all the tuple shapes, in the order the shapes appear in the "shapes" input. +// The layout elements for a sub-shape can be set to -1, in which case the +// corresponding layout will be computed by the infeed operation. +// If not specified, defaults to <> +func InfeedEnqueueTupleLayouts(value []int64) InfeedEnqueueTupleAttr { return func(m optionalAttr) { - m["little_endian"] = value + m["layouts"] = value } } -// Reinterpret the bytes of a string as a vector of numbers. +// InfeedEnqueueTupleDeviceOrdinal sets the optional device_ordinal attribute to value. // -// Arguments: -// bytes: All the elements must have the same length. +// value: The TPU device to use. This should be -1 when the Op +// is running on a TPU device, and >= 0 when the Op is running on the CPU +// device. 
+// If not specified, defaults to -1 +func InfeedEnqueueTupleDeviceOrdinal(value int64) InfeedEnqueueTupleAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// Feeds multiple Tensor values into the computation as an XLA tuple. // +// Arguments: +// inputs: A list of tensors that will be provided using the infeed mechanism. +// shapes: The shapes of each tensor in `inputs`. // -// Returns A Tensor with one more dimension than the input `bytes`. The -// added dimension will have size equal to the length of the elements -// of `bytes` divided by the number of bytes to represent `out_type`. -func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) { +// Returns the created operation. +func InfeedEnqueueTuple(scope *Scope, inputs []tf.Output, shapes []tf.Shape, optional ...InfeedEnqueueTupleAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} + attrs := map[string]interface{}{"shapes": shapes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeRaw", + Type: "InfeedEnqueueTuple", Input: []tf.Input{ - bytes, + tf.OutputList(inputs), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Computes natural logarithm of (1 + x) element-wise. +// Returns which elements of x are finite. // -// I.e., \\(y = \log_e (1 + x)\\). -func Log1p(scope *Scope, x tf.Output) (y tf.Output) { +// @compatibility(numpy) +// Equivalent to np.isfinite +// @end_compatibility +func IsFinite(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Log1p", + Type: "IsFinite", Input: []tf.Input{ x, }, @@ -11867,55 +16598,60 @@ func Log1p(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Computes rectified linear 6 gradients for a Relu6 operation. 
-// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Relu6 operation. -// features: The features passed as input to the corresponding Relu6 operation, or -// its output; using either one produces the same result. -// -// Returns The gradients: -// `gradients * (features > 0) * (features < 6)`. -func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return +// ResourceStridedSliceAssignAttr is an optional argument to ResourceStridedSliceAssign. +type ResourceStridedSliceAssignAttr func(optionalAttr) + +// ResourceStridedSliceAssignBeginMask sets the optional begin_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignBeginMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["begin_mask"] = value } - opspec := tf.OpSpec{ - Type: "Relu6Grad", - Input: []tf.Input{ - gradients, features, - }, +} + +// ResourceStridedSliceAssignEndMask sets the optional end_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignEndMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["end_mask"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// ResizeBicubicAttr is an optional argument to ResizeBicubic. -type ResizeBicubicAttr func(optionalAttr) +// ResourceStridedSliceAssignEllipsisMask sets the optional ellipsis_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignEllipsisMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["ellipsis_mask"] = value + } +} -// ResizeBicubicAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. 
-// If not specified, defaults to false -func ResizeBicubicAlignCorners(value bool) ResizeBicubicAttr { +// ResourceStridedSliceAssignNewAxisMask sets the optional new_axis_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignNewAxisMask(value int64) ResourceStridedSliceAssignAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["new_axis_mask"] = value } } -// Resize `images` to `size` using bicubic interpolation. +// ResourceStridedSliceAssignShrinkAxisMask sets the optional shrink_axis_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignShrinkAxisMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["shrink_axis_mask"] = value + } +} + +// Assign `value` to the sliced l-value reference of `ref`. // -// Input images can be of different types but output images are always float. +// The values of `value` are assigned to the positions in the variable +// `ref` that are selected by the slice parameters. The slice parameters +// `begin, `end`, `strides`, etc. work exactly as in `StridedSlice`. // -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// NOTE this op currently does not support broadcasting and so `value`'s +// shape must be exactly the shape produced by the slice of `ref`. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBicubicAttr) (resized_images tf.Output) { +// Returns the created operation. 
+func ResourceStridedSliceAssign(scope *Scope, ref tf.Output, begin tf.Output, end tf.Output, strides tf.Output, value tf.Output, optional ...ResourceStridedSliceAssignAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -11924,350 +16660,250 @@ func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...R a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeBicubic", + Type: "ResourceStridedSliceAssign", Input: []tf.Input{ - images, size, + ref, begin, end, strides, value, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Gather ragged slices from `params` axis `0` according to `indices`. -// -// Outputs a `RaggedTensor` output composed from `output_dense_values` and -// `output_nested_splits`, such that: -// -// ```python -// output.shape = indices.shape + params.shape[1:] -// output.ragged_rank = indices.shape.ndims + params.ragged_rank -// output[i...j, d0...dn] = params[indices[i...j], d0...dn] -// ``` -// -// where -// -// * `params = -// ragged.from_nested_row_splits(params_dense_values, params_nested_splits)` -// provides the values that should be gathered. -// * `indices` ia a dense tensor with dtype `int32` or `int64`, indicating which -// values should be gathered. -// * `output = -// ragged.from_nested_row_splits(output_dense_values, output_nested_splits)` -// is the output tensor. -// -// (Note: This c++ op is used to implement the higher-level python -// `tf.ragged.gather` op, which also supports ragged indices.) +// ArgMaxAttr is an optional argument to ArgMax. +type ArgMaxAttr func(optionalAttr) + +// ArgMaxOutputType sets the optional output_type attribute to value. +// If not specified, defaults to DT_INT64 +func ArgMaxOutputType(value tf.DataType) ArgMaxAttr { + return func(m optionalAttr) { + m["output_type"] = value + } +} + +// Returns the index with the largest value across dimensions of a tensor. 
// +// Note that in case of ties the identity of the return value is not guaranteed. // // Arguments: -// params_nested_splits: The `nested_row_splits` tensors that define the row-partitioning for the -// `params` RaggedTensor input. -// params_dense_values: The `flat_values` for the `params` RaggedTensor. There was a terminology change -// at the python level from dense_values to flat_values, so dense_values is the -// deprecated name. -// indices: Indices in the outermost dimension of `params` of the values that should be -// gathered. -// OUTPUT_RAGGED_RANK: The ragged rank of the output RaggedTensor. `output_nested_splits` will contain -// this number of `row_splits` tensors. This value should equal -// `indices.shape.ndims + params.ragged_rank - 1`. // -// Returns The `nested_row_splits` tensors that define the row-partitioning for the -// returned RaggedTensor.The `flat_values` for the returned RaggedTensor. -func RaggedGather(scope *Scope, params_nested_splits []tf.Output, params_dense_values tf.Output, indices tf.Output, OUTPUT_RAGGED_RANK int64) (output_nested_splits []tf.Output, output_dense_values tf.Output) { +// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. +// Describes which dimension of the input Tensor to reduce across. For vectors, +// use dimension = 0. 
+func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"OUTPUT_RAGGED_RANK": OUTPUT_RAGGED_RANK} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RaggedGather", + Type: "ArgMax", Input: []tf.Input{ - tf.OutputList(params_nested_splits), params_dense_values, indices, + input, dimension, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output_nested_splits, idx, err = makeOutputList(op, idx, "output_nested_splits"); err != nil { - scope.UpdateErr("RaggedGather", err) - return - } - output_dense_values = op.Output(idx) - return output_nested_splits, output_dense_values + return op.Output(0) } -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system. Note that this -// algorithm is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. +// Extract `patches` from `images` and put them in the "depth" output dimension. // -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. 
For example: +// Arguments: +// images: 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`. +// ksizes: The size of the sliding window for each dimension of `images`. +// strides: 1-D of length 4. How far the centers of two consecutive patches are in +// the images. Must be: `[1, stride_rows, stride_cols, 1]`. +// rates: 1-D of length 4. Must be: `[1, rate_rows, rate_cols, 1]`. This is the +// input stride, specifying how far two consecutive patch samples are in the +// input. Equivalent to extracting patches with +// `patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by +// subsampling them spatially by a factor of `rates`. This is equivalent to +// `rate` in dilated (a.k.a. Atrous) convolutions. +// padding: The type of padding algorithm to use. // -// selected_indices = tf.image.non_max_suppression_v2( -// boxes, scores, max_output_size, iou_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) +// We specify the size-related attributes as: // -// Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// iou_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too much with respect to IOU. +// ```python +// ksizes = [1, ksize_rows, ksize_cols, 1] +// strides = [1, strides_rows, strides_cols, 1] +// rates = [1, rates_rows, rates_cols, 1] +// ``` // -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. 
-func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output) (selected_indices tf.Output) { +// Returns 4-D Tensor with shape `[batch, out_rows, out_cols, ksize_rows * +// ksize_cols * depth]` containing image patches with size +// `ksize_rows x ksize_cols x depth` vectorized in the "depth" dimension. Note +// `out_rows` and `out_cols` are the dimensions of the output patches. +func ExtractImagePatches(scope *Scope, images tf.Output, ksizes []int64, strides []int64, rates []int64, padding string) (patches tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksizes": ksizes, "strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "NonMaxSuppressionV2", + Type: "ExtractImagePatches", Input: []tf.Input{ - boxes, scores, max_output_size, iou_threshold, + images, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts a `RaggedTensor` into a `SparseTensor` with the same values. +// Computes the mean along sparse segments of a tensor. // -// input=ragged.from_nested_row_splits(rt_dense_values, rt_nested_splits) -// output=SparseTensor(indices=sparse_indices, values=sparse_values, -// dense_shape=sparse_dense_shape) +// See `tf.sparse.segment_sum` for usage examples. +// +// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first +// dimension, selecting a subset of dimension 0, specified by `indices`. // // Arguments: -// rt_nested_splits: The `row_splits` for the `RaggedTensor`. -// rt_dense_values: The `flat_values` for the `RaggedTensor`. // -// Returns The indices for the `SparseTensor`.The values of the `SparseTensor`.`sparse_dense_shape` is a tight bounding box of the input `RaggedTensor`. 
-func RaggedTensorToSparse(scope *Scope, rt_nested_splits []tf.Output, rt_dense_values tf.Output) (sparse_indices tf.Output, sparse_values tf.Output, sparse_dense_shape tf.Output) { +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RaggedTensorToSparse", + Type: "SparseSegmentMean", Input: []tf.Input{ - tf.OutputList(rt_nested_splits), rt_dense_values, + data, indices, segment_ids, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Check if the input matches the regex pattern. -// -// The input is a string tensor of any shape. The pattern is a scalar -// string tensor which is applied to every element of the input tensor. -// The boolean values (True or False) of the output tensor indicate -// if the input matches the regex pattern provided. +// Deserializes a serialized tree ensemble config and replaces current tree // -// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// ensemble. // // Arguments: -// input: A string tensor of the text to be processed. -// pattern: A scalar string tensor containing the regular expression to match the input. +// tree_ensemble_handle: Handle to the tree ensemble. +// stamp_token: Token to use as the new value of the resource stamp. +// tree_ensemble_serialized: Serialized proto of the ensemble. // -// Returns A bool tensor with the same shape as `input`. -func RegexFullMatch(scope *Scope, input tf.Output, pattern tf.Output) (output tf.Output) { +// Returns the created operation. 
+func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RegexFullMatch", + Type: "BoostedTreesDeserializeEnsemble", Input: []tf.Input{ - input, pattern, + tree_ensemble_handle, stamp_token, tree_ensemble_serialized, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. -// -// More formally, let -// -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, -// -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// Transforms a tf.Example proto (as a string) into typed tensors. // // Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. -// k: Number of top elements to look at for computing precision. -// -// Returns Computed precision at `k` as a `bool Tensor`. -func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { +// serialized: A vector containing a batch of binary serialized Example protos. +// dense_defaults: A list of Tensors (some may be empty), whose length matches +// the length of `dense_keys`. dense_defaults[j] provides default values +// when the example's feature_map lacks dense_key[j]. 
If an empty Tensor is +// provided for dense_defaults[j], then the Feature dense_keys[j] is required. +// The input type is inferred from dense_defaults[j], even when it's empty. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. +// num_sparse: The number of sparse features to be parsed from the example. This +// must match the lengths of `sparse_keys` and `sparse_types`. +// sparse_keys: A list of `num_sparse` strings. +// The keys expected in the Examples' features associated with sparse values. +// dense_keys: The keys expected in the Examples' features associated with dense +// values. +// sparse_types: A list of `num_sparse` types; the data types of data in each +// Feature given in sparse_keys. +// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// dense_shapes: The shapes of data in each Feature given in dense_keys. +// The length of this list must match the length of `dense_keys`. The +// number of elements in the Feature corresponding to dense_key[j] must +// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == +// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] +// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, +// ..., DN), the shape of the output Tensor dense_values[j] will be (M, +// D1, .., DN), where M is the number of blocks of elements of length +// D1 * .... * DN, in the input. 
+func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} opspec := tf.OpSpec{ - Type: "InTopKV2", + Type: "ParseSingleExample", Input: []tf.Input{ - predictions, targets, k, + serialized, tf.OutputList(dense_defaults), }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RandomShuffleAttr is an optional argument to RandomShuffle. -type RandomShuffleAttr func(optionalAttr) - -// RandomShuffleSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomShuffleSeed(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomShuffleSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleSeed2(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Randomly shuffles a tensor along its first dimension. -// -// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped -// to one and only one `output[i]`. For example, a mapping that might occur for a -// 3x2 tensor is: -// -// ``` -// [[1, 2], [[5, 6], -// [3, 4], ==> [1, 2], -// [5, 6]] [3, 4]] -// ``` -// -// Arguments: -// value: The tensor to be shuffled. 
-// -// Returns A tensor of same shape and type as `value`, shuffled along its first -// dimension. -func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) + var idx int + var err error + if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return } - opspec := tf.OpSpec{ - Type: "RandomShuffle", - Input: []tf.Input{ - value, - }, - Attrs: attrs, + if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes sigmoid of `x` element-wise. -// -// Specifically, `y = 1 / (1 + exp(-x))`. -func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { + if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleExample", err) return } - opspec := tf.OpSpec{ - Type: "Sigmoid", - Input: []tf.Input{ - x, - }, + if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return sparse_indices, sparse_values, sparse_shapes, dense_values } -// FusedBatchNormAttr is an optional argument to FusedBatchNorm. -type FusedBatchNormAttr func(optionalAttr) - -// FusedBatchNormEpsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} +// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. 
+type WholeFileReaderV2Attr func(optionalAttr) -// FusedBatchNormDataFormat sets the optional data_format attribute to value. +// WholeFileReaderV2Container sets the optional container attribute to value. // -// value: The data format for x and y. Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["container"] = value } } -// FusedBatchNormIsTraining sets the optional is_training attribute to value. +// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. // -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { return func(m optionalAttr) { - m["is_training"] = value + m["shared_name"] = value } } -// Batch normalization. -// -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// A Reader that outputs the entire contents of a file as a value. // -// Arguments: -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// offset: A 1D Tensor for offset, to shift to the normalized x. -// mean: A 1D Tensor for population mean. Used for inference only; -// must be empty for training. -// variance: A 1D Tensor for population variance. 
Used for inference only; -// must be empty for training. +// To use, enqueue filenames in a Queue. The output of ReaderRead will +// be a filename (key) and the contents of that file (value). // -// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow -// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by -// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused -// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance -// in the cuDNN case), to be reused in the gradient computation. -func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { +// Returns The handle to reference the Reader. +func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } @@ -12276,216 +16912,193 @@ func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output a(attrs) } opspec := tf.OpSpec{ - Type: "FusedBatchNorm", - Input: []tf.Input{ - x, scale, offset, mean, variance, - }, + Type: "WholeFileReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return op.Output(0) } -// RandomStandardNormalAttr is an optional argument to RandomStandardNormal. -type RandomStandardNormalAttr func(optionalAttr) +// LoadTPUEmbeddingADAMParametersAttr is an optional argument to LoadTPUEmbeddingADAMParameters. +type LoadTPUEmbeddingADAMParametersAttr func(optionalAttr) -// RandomStandardNormalSeed sets the optional seed attribute to value. +// LoadTPUEmbeddingADAMParametersTableId sets the optional table_id attribute to value. 
+// If not specified, defaults to -1 // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr { +// REQUIRES: value >= -1 +func LoadTPUEmbeddingADAMParametersTableId(value int64) LoadTPUEmbeddingADAMParametersAttr { return func(m optionalAttr) { - m["seed"] = value + m["table_id"] = value } } -// RandomStandardNormalSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr { +// LoadTPUEmbeddingADAMParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingADAMParametersTableName(value string) LoadTPUEmbeddingADAMParametersAttr { return func(m optionalAttr) { - m["seed2"] = value + m["table_name"] = value } } -// Outputs random values from a normal distribution. +// Load ADAM embedding parameters. // -// The generated values will have mean 0 and standard deviation 1. +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. +// parameters: Value of parameters used in the ADAM optimization algorithm. +// momenta: Value of momenta used in the ADAM optimization algorithm. +// velocities: Value of velocities used in the ADAM optimization algorithm. // -// Returns A tensor of the specified shape filled with random normal values. 
-func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) { +// +// +// Returns the created operation. +func LoadTPUEmbeddingADAMParameters(scope *Scope, parameters tf.Output, momenta tf.Output, velocities tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingADAMParametersAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomStandardNormal", + Type: "LoadTPUEmbeddingADAMParameters", Input: []tf.Input{ - shape, + parameters, momenta, velocities, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D. -type FusedResizeAndPadConv2DAttr func(optionalAttr) +// InfeedEnqueuePrelinearizedBufferAttr is an optional argument to InfeedEnqueuePrelinearizedBuffer. +type InfeedEnqueuePrelinearizedBufferAttr func(optionalAttr) -// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value. +// InfeedEnqueuePrelinearizedBufferDeviceOrdinal sets the optional device_ordinal attribute to value. // -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr { +// value: The TPU device to use. This should be -1 when the Op is running on a TPU device +// and = 0 when the Op is running on the CPU device. 
+// If not specified, defaults to -1 +func InfeedEnqueuePrelinearizedBufferDeviceOrdinal(value int64) InfeedEnqueuePrelinearizedBufferAttr { return func(m optionalAttr) { - m["resize_align_corners"] = value + m["device_ordinal"] = value } } -// Performs a resize and padding as a preprocess during a convolution. -// -// It's often possible to do spatial transformations more efficiently as part of -// the packing stage of a convolution, so this op allows for an optimized -// implementation where these stages are fused together. This prevents the need to -// write out the intermediate results as whole tensors, reducing memory pressure, -// and we can get some latency gains by merging the transformation calculations. -// The data_format attribute for Conv2D isn't supported by this op, and defaults to -// 'NHWC' order. -// Internally this op uses a single per-graph scratch buffer, which means that it -// will block if multiple versions are being run in parallel. This is because this -// operator is primarily an optimization to minimize memory usage. +// An op which enqueues prelinearized buffer into TPU infeed. // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. +// input: A variant tensor representing linearized output. // -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. Must be in the same order as the dimension specified with format. -// padding: The type of padding algorithm to use. 
-func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) { +// Returns the created operation. +func InfeedEnqueuePrelinearizedBuffer(scope *Scope, input tf.Output, optional ...InfeedEnqueuePrelinearizedBufferAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FusedResizeAndPadConv2D", + Type: "InfeedEnqueuePrelinearizedBuffer", Input: []tf.Input{ - input, size, paddings, filter, + input, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// RandomUniformAttr is an optional argument to RandomUniform. -type RandomUniformAttr func(optionalAttr) - -// RandomUniformSeed sets the optional seed attribute to value. +// Fetches multiple values from infeed as an XLA tuple. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomUniformSeed(value int64) RandomUniformAttr { - return func(m optionalAttr) { - m["seed"] = value +// Arguments: +// dtypes: The element types of each element in `outputs`. +// shapes: The shapes of each tensor in `outputs`. +// +// Returns A list of tensors that will be provided using the infeed mechanism. +func InfeedDequeueTuple(scope *Scope, dtypes []tf.DataType, shapes []tf.Shape) (outputs []tf.Output) { + if scope.Err() != nil { + return } -} + attrs := map[string]interface{}{"dtypes": dtypes, "shapes": shapes} + opspec := tf.OpSpec{ + Type: "InfeedDequeueTuple", -// RandomUniformSeed2 sets the optional seed2 attribute to value. 
-// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomUniformSeed2(value int64) RandomUniformAttr { - return func(m optionalAttr) { - m["seed2"] = value + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("InfeedDequeueTuple", err) + return } + return outputs } -// Outputs random values from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// Enqueue multiple Tensor values on the computation outfeed. // // Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. +// inputs: A list of tensors that will be inserted into the outfeed queue as an +// XLA tuple. // -// Returns A tensor of the specified shape filled with uniform random values. -func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { +// Returns the created operation. +func OutfeedEnqueueTuple(scope *Scope, inputs []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "RandomUniform", + Type: "OutfeedEnqueueTuple", Input: []tf.Input{ - shape, + tf.OutputList(inputs), }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. -type ResourceApplyFtrlAttr func(optionalAttr) +// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. +type ResourceApplyAdagradAttr func(optionalAttr) -// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. 
+// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. // // value: If `True`, updating of the var and accum tensors will be protected // by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { +func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the Ftrl-proximal scheme. +// ResourceApplyAdagradUpdateSlots sets the optional update_slots attribute to value. +// If not specified, defaults to true +func ResourceApplyAdagradUpdateSlots(value bool) ResourceApplyAdagradAttr { + return func(m optionalAttr) { + m["update_slots"] = value + } +} + +// Update '*var' according to the adagrad scheme. // -// accum_new = accum + grad * grad -// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) // // Arguments: // var_: Should be from a Variable(). // accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. // lr: Scaling factor. Must be a scalar. -// l1: L1 regulariation. Must be a scalar. -// l2: L2 regulariation. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. +// grad: The gradient. // // Returns the created operation. 
-func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { +func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -12494,252 +17107,122 @@ func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf. a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyFtrl", + Type: "ResourceApplyAdagrad", Input: []tf.Input{ - var_, accum, linear, grad, lr, l1, l2, lr_power, + var_, accum, lr, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Computes exponential of x element-wise. \\(y = e^x\\). -func Exp(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Exp", - Input: []tf.Input{ - x, - }, +// CudnnRNNV3Attr is an optional argument to CudnnRNNV3. +type CudnnRNNV3Attr func(optionalAttr) + +// CudnnRNNV3RnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNV3RnnMode(value string) CudnnRNNV3Attr { + return func(m optionalAttr) { + m["rnn_mode"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// NthElementAttr is an optional argument to NthElement. -type NthElementAttr func(optionalAttr) - -// NthElementReverse sets the optional reverse attribute to value. -// -// value: When set to True, find the nth-largest value in the vector and vice -// versa. -// If not specified, defaults to false -func NthElementReverse(value bool) NthElementAttr { +// CudnnRNNV3InputMode sets the optional input_mode attribute to value. 
+// If not specified, defaults to "linear_input" +func CudnnRNNV3InputMode(value string) CudnnRNNV3Attr { return func(m optionalAttr) { - m["reverse"] = value + m["input_mode"] = value } } -// Finds values of the `n`-th order statistic for the last dimension. -// -// If the input is a vector (rank-1), finds the entries which is the nth-smallest -// value in the vector and outputs their values as scalar tensor. -// -// For matrices (resp. higher rank input), computes the entries which is the -// nth-smallest value in each row (resp. vector along the last dimension). Thus, -// -// values.shape = input.shape[:-1] -// -// Arguments: -// input: 1-D or higher with last dimension at least `n+1`. -// n: 0-D. Position of sorted vector to select along the last dimension (along -// each row for matrices). Valid range of n is `[0, input.shape[:-1])` -// -// Returns The `n`-th order statistic along each last dimensional slice. -func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "NthElement", - Input: []tf.Input{ - input, n, - }, - Attrs: attrs, +// CudnnRNNV3Direction sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNV3Direction(value string) CudnnRNNV3Attr { + return func(m optionalAttr) { + m["direction"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes the maximum along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. -// -// This operator is similar to the unsorted segment sum operator found -// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). 
-// Instead of computing the sum over segments, it computes the maximum such that: -// -// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such -// that `segment_ids[j...] == i`. -// -// If the maximum is empty for a given segment ID `i`, it outputs the smallest -// possible value for the specific numeric type, -// `output[i] = numeric_limits::lowest()`. -// -// If the given segment ID `i` is negative, then the corresponding value is -// dropped, and will not be included in the result. -// -//
-// -//
-// -// Arguments: -// -// segment_ids: A tensor whose shape is a prefix of `data.shape`. -// -// -// Returns Has same shape as data, except for the first `segment_ids.rank` -// dimensions, which are replaced with a single dimension which has size -// `num_segments`. -func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "UnsortedSegmentMax", - Input: []tf.Input{ - data, segment_ids, num_segments, - }, +// CudnnRNNV3Dropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNV3Dropout(value float32) CudnnRNNV3Attr { + return func(m optionalAttr) { + m["dropout"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Transforms a vector of brain.Example protos (as strings) into typed tensors. -// -// Arguments: -// serialized: A vector containing a batch of binary serialized Example protos. -// names: A vector containing the names of the serialized protos. -// May contain, for example, table key (descriptive) names for the -// corresponding serialized protos. These are purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty vector if no names are available. -// If non-empty, this vector must be the same length as "serialized". -// sparse_keys: A list of Nsparse string Tensors (scalars). -// The keys expected in the Examples' features associated with sparse values. -// dense_keys: A list of Ndense string Tensors (scalars). -// The keys expected in the Examples' features associated with dense values. -// dense_defaults: A list of Ndense Tensors (some may be empty). -// dense_defaults[j] provides default values -// when the example's feature_map lacks dense_key[j]. If an empty Tensor is -// provided for dense_defaults[j], then the Feature dense_keys[j] is required. 
-// The input type is inferred from dense_defaults[j], even when it's empty. -// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, -// then the shape of dense_defaults[j] must match that of dense_shapes[j]. -// If dense_shapes[j] has an undefined major dimension (variable strides dense -// feature), dense_defaults[j] must contain a single element: -// the padding element. -// sparse_types: A list of Nsparse types; the data types of data in each Feature -// given in sparse_keys. -// Currently the ParseExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// dense_shapes: A list of Ndense shapes; the shapes of data in each Feature -// given in dense_keys. -// The number of elements in the Feature corresponding to dense_key[j] -// must always equal dense_shapes[j].NumEntries(). -// If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output -// Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): -// The dense outputs are just the inputs row-stacked by batch. -// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case -// the shape of the output Tensor dense_values[j] will be -// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks -// of elements of length D1 * .... * DN, across all minibatch entries -// in the input. Any minibatch entry with less than M blocks of elements of -// length D1 * ... * DN will be padded with the corresponding default_value -// scalar element along the second dimension. 
-func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_keys []tf.Output, dense_keys []tf.Output, dense_defaults []tf.Output, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"sparse_types": sparse_types, "dense_shapes": dense_shapes} - opspec := tf.OpSpec{ - Type: "ParseExample", - Input: []tf.Input{ - serialized, names, tf.OutputList(sparse_keys), tf.OutputList(dense_keys), tf.OutputList(dense_defaults), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { - scope.UpdateErr("ParseExample", err) - return - } - if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { - scope.UpdateErr("ParseExample", err) - return - } - if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { - scope.UpdateErr("ParseExample", err) - return - } - if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { - scope.UpdateErr("ParseExample", err) - return +// CudnnRNNV3Seed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNV3Seed(value int64) CudnnRNNV3Attr { + return func(m optionalAttr) { + m["seed"] = value } - return sparse_indices, sparse_values, sparse_shapes, dense_values } -// Compute the pairwise cross product. -// -// `a` and `b` must be the same shape; they can either be simple 3-element vectors, -// or any shape where the innermost dimension is 3. In the latter case, each pair -// of corresponding 3-element vectors is cross-multiplied independently. -// -// Arguments: -// a: A tensor containing 3-element vectors. -// b: Another tensor, of same type and shape as `a`. 
-// -// Returns Pairwise cross product of the vectors in `a` and `b`. -func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Cross", - Input: []tf.Input{ - a, b, - }, +// CudnnRNNV3Seed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNV3Seed2(value int64) CudnnRNNV3Attr { + return func(m optionalAttr) { + m["seed2"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// StatefulStandardNormalAttr is an optional argument to StatefulStandardNormal. -type StatefulStandardNormalAttr func(optionalAttr) +// CudnnRNNV3IsTraining sets the optional is_training attribute to value. +// If not specified, defaults to true +func CudnnRNNV3IsTraining(value bool) CudnnRNNV3Attr { + return func(m optionalAttr) { + m["is_training"] = value + } +} -// StatefulStandardNormalDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatefulStandardNormalDtype(value tf.DataType) StatefulStandardNormalAttr { +// CudnnRNNV3TimeMajor sets the optional time_major attribute to value. +// If not specified, defaults to true +func CudnnRNNV3TimeMajor(value bool) CudnnRNNV3Attr { return func(m optionalAttr) { - m["dtype"] = value + m["time_major"] = value } } -// Outputs random values from a normal distribution. -// -// The generated values will have mean 0 and standard deviation 1. +// A RNN backed by cuDNN. // -// Arguments: -// resource: The handle of the resource variable that stores the state of the RNG. -// shape: The shape of the output tensor. +// Computes the RNN from the input and initial states, with respect to the params +// buffer. Accepts one extra input "sequence_lengths" than CudnnRNN. // -// Returns A tensor of the specified shape filled with random normal values. 
-func StatefulStandardNormal(scope *Scope, resource tf.Output, shape tf.Output, optional ...StatefulStandardNormalAttr) (output tf.Output) { +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicates whether there is a linear projection between the input and +// the actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. Should be +// "unidirectional" or "bidirectional". +// dropout: Dropout probability. When set to 0., dropout is disabled. +// seed: The 1st part of a seed to initialize dropout. +// seed2: The 2nd part of a seed to initialize dropout. +// input: If time_major is true, this is a 3-D tensor with the shape of +// [seq_length, batch_size, input_size]. If time_major is false, the shape is +// [batch_size, seq_length, input_size]. +// input_h: If time_major is true, this is a 3-D tensor with the shape of +// [num_layer * dir, batch_size, num_units]. If time_major is false, the shape +// is [batch_size, num_layer * dir, num_units]. +// input_c: For LSTM, a 3-D tensor with the shape of +// [num_layer * dir, batch, num_units]. For other models, it is ignored. +// params: A 1-D tensor that contains the weights and biases in an opaque layout. +// The size must be created through CudnnRNNParamsSize, and initialized +// separately. Note that they might not be compatible across different +// generations. So it is a good idea to save and restore +// sequence_lengths: a vector of lengths of each input sequence. +// output: If time_major is true, this is a 3-D tensor with the shape of +// [seq_length, batch_size, dir * num_units]. If time_major is false, the +// shape is [batch_size, seq_length, dir * num_units]. +// output_h: The same shape as input_h. +// output_c: The same shape as input_c for LSTM.
An empty tensor for other models. +// is_training: Indicates whether this operation is used for inference or +// training. +// time_major: Indicates whether the input/output format is time major or batch +// major. +// reserve_space: An opaque tensor that can be used in backprop calculation. It +// is only produced if is_training is true. +func CudnnRNNV3(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, sequence_lengths tf.Output, optional ...CudnnRNNV3Attr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output, host_reserved tf.Output) { if scope.Err() != nil { return } @@ -12748,94 +17231,72 @@ func StatefulStandardNormal(scope *Scope, resource tf.Output, shape tf.Output, o a(attrs) } opspec := tf.OpSpec{ - Type: "StatefulStandardNormal", + Type: "CudnnRNNV3", Input: []tf.Input{ - resource, shape, + input, input_h, input_c, params, sequence_lengths, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Locks a mutex resource. The output is the lock. So long as the lock tensor -// -// is alive, any other request to use `MutexLock` with this mutex will wait. -// -// This is particularly useful for creating a critical section when used in -// conjunction with `MutexLockIdentity`: -// -// ```python -// -// mutex = mutex_v2( -// shared_name=handle_name, container=container, name=name) -// -// def execute_in_critical_section(fn, *args, **kwargs): -// lock = gen_resource_variable_ops.mutex_lock(mutex) -// -// with ops.control_dependencies([lock]): -// r = fn(*args, **kwargs) -// -// with ops.control_dependencies(nest.flatten(r)): -// with ops.colocate_with(mutex): -// ensure_lock_exists = mutex_lock_identity(lock) -// -// # Make sure that if any element of r is accessed, all of -// # them are executed together.
-// r = nest.map_structure(tf.identity, r) +// Applies softmax to a batched N-D `SparseTensor`. // -// with ops.control_dependencies([ensure_lock_exists]): -// return nest.map_structure(tf.identity, r) -// ``` +// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` +// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. // -// While `fn` is running in the critical section, no other functions which wish to -// use this critical section may run. +// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost +// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly +// zero elements do not participate*. Specifically, the algorithm is equivalent +// to the following: // -// Often the use case is that two executions of the same graph, in parallel, -// wish to run `fn`; and we wish to ensure that only one of them executes -// at a time. This is especially important if `fn` modifies one or more -// variables at a time. +// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix +// with shape `[B, C]`, along the size-C dimension; +// (2) Masks out the original implicitly-zero locations; +// (3) Renormalizes the remaining elements. // -// It is also useful if two separate functions must share a resource, but we -// wish to ensure the usage is exclusive. +// Hence, the `SparseTensor` result has exactly the same non-zero indices and +// shape. // // Arguments: -// mutex: The mutex resource to lock. +// sp_indices: 2-D. `NNZ x R` matrix with the indices of non-empty values in a +// SparseTensor, in canonical ordering. +// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. // -// Returns A tensor that keeps a shared pointer to a lock on the mutex; -// when the Tensor is destroyed, the use count on the shared pointer is decreased -// by 1. When it reaches 0, the lock is released. 
-func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { +// Returns 1-D. The `NNZ` values for the result `SparseTensor`. +func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MutexLock", + Type: "SparseSoftmax", Input: []tf.Input{ - mutex, + sp_indices, sp_values, sp_shape, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Transforms a serialized tensorflow.TensorProto proto into a Tensor. +// Creates a Tensor by indexing into the TensorList. // -// Arguments: -// serialized: A scalar string containing a serialized TensorProto proto. -// out_type: The type of the serialized tensor. The provided type must match the -// type of the serialized tensor and no implicit conversion will take place. +// Each row in the produced Tensor corresponds to the element in the TensorList +// specified by the given index (see `tf.gather`). // -// Returns A Tensor of type `out_type`. -func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { +// input_handle: The input tensor list. +// indices: The indices used to index into the list. +// values: The tensor. +func TensorListGather(scope *Scope, input_handle tf.Output, indices tf.Output, element_shape tf.Output, element_dtype tf.DataType) (values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "ParseTensor", + Type: "TensorListGather", Input: []tf.Input{ - serialized, + input_handle, indices, element_shape, }, Attrs: attrs, } @@ -12843,477 +17304,473 @@ func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (outp return op.Output(0) } -// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. 
-type MaxPoolWithArgmaxAttr func(optionalAttr) +// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. +type FixedLengthRecordReaderV2Attr func(optionalAttr) -// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. -// If not specified, defaults to DT_INT64 -func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { +// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. +// +// value: Number of bytes in the header, defaults to 0. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { return func(m optionalAttr) { - m["Targmax"] = value + m["header_bytes"] = value } } -// Performs max pooling on the input and outputs both max values and indices. -// -// The indices in `argmax` are flattened, so that a maximum value at position -// `[b, y, x, c]` becomes flattened index -// `((b * height + y) * width + x) * channels + c`. -// -// The indices returned are always in `[0, height) x [0, width)` before flattening, -// even if padding is involved and the mathematically correct answer is outside -// (either negative or too large). This is a bug, but fixing it is difficult to do -// in a safe backwards compatible way, especially due to flattening. -// -// Arguments: -// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. // -// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. 
-func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolWithArgmax", - Input: []tf.Input{ - input, - }, - Attrs: attrs, +// value: Number of bytes in the footer, defaults to 0. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["footer_bytes"] = value } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) } -// Creates a TensorList which, when stacked, has the value of `tensor`. -// -// Each tensor in the result list corresponds to one row of the input tensor. +// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. // -// tensor: The input tensor. -// output_handle: The list. -func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Output) (output_handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorListFromTensor", - Input: []tf.Input{ - tensor, element_shape, - }, +// value: Number of bytes to hop before each read. Default of 0 means using +// record_bytes. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["hop_bytes"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Assigns sparse updates to the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] = updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] = updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] 
= updates[i, ..., j, ...] -// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. +// FixedLengthRecordReaderV2Container sets the optional container attribute to value. // -// Returns the created operation. -func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ResourceScatterUpdate", - Input: []tf.Input{ - resource, indices, updates, - }, +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value } - return scope.AddOperation(opspec) } -// MaxPoolAttr is an optional argument to MaxPool. -type MaxPoolAttr func(optionalAttr) +// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} -// MaxPoolDataFormat sets the optional data_format attribute to value. +// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. 
-// If not specified, defaults to "NHWC" -func MaxPoolDataFormat(value string) MaxPoolAttr { +// value: The type of encoding for the file. Currently ZLIB and GZIP +// are supported. Defaults to none. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["encoding"] = value } } -// Performs max pooling on the input. +// A Reader that outputs fixed-length records from a file. // // Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// record_bytes: Number of bytes in the record. // -// Returns The max pooled output tensor. -func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { +// Returns The handle to reference the Reader. +func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"record_bytes": record_bytes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPool", - Input: []tf.Input{ - input, - }, + Type: "FixedLengthRecordReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Multiplies sparse updates into the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] *= updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] *= updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] 
-// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions multiply. +// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign. +type ResourceApplyPowerSignAttr func(optionalAttr) + +// ResourceApplyPowerSignUseLocking sets the optional use_locking attribute to value. // -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// value: If `True`, updating of the var and m tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyPowerSignUseLocking(value bool) ResourceApplyPowerSignAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the PowerSign update. // -//
-// -//
+// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g +// variable <- variable - lr_t * update // // Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// logbase: Must be a scalar. +// sign_decay: Must be a scalar. +// beta: Must be a scalar. +// grad: The gradient. // // Returns the created operation. -func ResourceScatterMul(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, logbase tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyPowerSignAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ResourceScatterMul", + Type: "ResourceApplyPowerSign", Input: []tf.Input{ - resource, indices, updates, + var_, m, lr, logbase, sign_decay, beta, grad, }, + Attrs: attrs, } return scope.AddOperation(opspec) } -// Subtracts sparse updates from the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] -= updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] -= updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions add. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// -//
-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterSub(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +// Connects outputs of an N-way replicated computation to N outputs. +func TPUReplicatedOutput(scope *Scope, input tf.Output, num_replicas int64) (outputs []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_replicas": num_replicas} opspec := tf.OpSpec{ - Type: "ResourceScatterSub", + Type: "TPUReplicatedOutput", Input: []tf.Input{ - resource, indices, updates, + input, }, + Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// Adds sparse updates to the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] += updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] += updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions add. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// -//
-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterAdd(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + op := scope.AddOperation(opspec) if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "ResourceScatterAdd", - Input: []tf.Input{ - resource, indices, updates, - }, + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("TPUReplicatedOutput", err) + return } - return scope.AddOperation(opspec) + return outputs } -// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. -type ResourceSparseApplyFtrlV2Attr func(optionalAttr) +// LoadTPUEmbeddingFTRLParametersAttr is an optional argument to LoadTPUEmbeddingFTRLParameters. +type LoadTPUEmbeddingFTRLParametersAttr func(optionalAttr) -// ResourceSparseApplyFtrlV2UseLocking sets the optional use_locking attribute to value. +// LoadTPUEmbeddingFTRLParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2Attr { +// REQUIRES: value >= -1 +func LoadTPUEmbeddingFTRLParametersTableId(value int64) LoadTPUEmbeddingFTRLParametersAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["table_id"] = value } } -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// LoadTPUEmbeddingFTRLParametersTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func LoadTPUEmbeddingFTRLParametersTableName(value string) LoadTPUEmbeddingFTRLParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load FTRL embedding parameters. // -// That is for rows we have grad for, we update var, accum and linear as follows: -// grad_with_shrinkage = grad + 2 * l2_shrinkage * var -// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage -// linear += grad_with_shrinkage + -// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 shrinkage regulariation. Must be a scalar. +// parameters: Value of parameters used in the FTRL optimization algorithm. +// accumulators: Value of accumulators used in the FTRL optimization algorithm. +// linears: Value of linears used in the FTRL optimization algorithm. +// // -// lr_power: Scaling factor. Must be a scalar. // // Returns the created operation. 
-func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlV2Attr) (o *tf.Operation) { +func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, linears tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingFTRLParametersAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrlV2", + Type: "LoadTPUEmbeddingFTRLParameters", Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, l2_shrinkage, lr_power, + parameters, accumulators, linears, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Calculates gains for each feature and returns the best possible split information for the feature. -// -// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature. -// -// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split. -// -// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features). -// -// The length of output lists are all of the same length, `num_features`. -// The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature. 
-// -// Arguments: -// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive). -// stats_summary_list: A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used. -// l1: l1 regularization factor on leaf weights, per instance based. -// l2: l2 regularization factor on leaf weights, per instance based. -// tree_complexity: adjustment to the gain, per leaf based. -// min_node_weight: mininum avg of hessians in a node before required for the node to be considered for splitting. -// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors. +// Returns (x - y)(x - y) element-wise. // -// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. 
This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node. -func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) { +// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"max_splits": max_splits} opspec := tf.OpSpec{ - Type: "BoostedTreesCalculateBestGainsPerFeature", + Type: "SquaredDifference", Input: []tf.Input{ - node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Push an element onto the tensor_array. +// +// Arguments: +// handle: The handle to a TensorArray. +// index: The position to write to inside the TensorArray. +// value: The tensor to write to the TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// +// Returns A float scalar that enforces proper chaining of operations. 
+func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if right_node_contribs_list, idx, err = makeOutputList(op, idx, "right_node_contribs_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return + opspec := tf.OpSpec{ + Type: "TensorArrayWriteV3", + Input: []tf.Input{ + handle, index, value, flow_in, + }, } - return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list + op := scope.AddOperation(opspec) + return op.Output(0) } -// EncodePngAttr is an optional argument to EncodePng. -type EncodePngAttr func(optionalAttr) +// RetrieveTPUEmbeddingAdagradParametersAttr is an optional argument to RetrieveTPUEmbeddingAdagradParameters. +type RetrieveTPUEmbeddingAdagradParametersAttr func(optionalAttr) -// EncodePngCompression sets the optional compression attribute to value. -// -// value: Compression level. +// RetrieveTPUEmbeddingAdagradParametersTableId sets the optional table_id attribute to value. 
// If not specified, defaults to -1 -func EncodePngCompression(value int64) EncodePngAttr { +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingAdagradParametersTableId(value int64) RetrieveTPUEmbeddingAdagradParametersAttr { return func(m optionalAttr) { - m["compression"] = value + m["table_id"] = value } } -// PNG-encode an image. +// RetrieveTPUEmbeddingAdagradParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingAdagradParametersTableName(value string) RetrieveTPUEmbeddingAdagradParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve Adagrad embedding parameters. // -// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` -// where `channels` is: +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// * 1: for grayscale. -// * 2: for grayscale + alpha. -// * 3: for RGB. -// * 4: for RGBA. +// Returns Parameter parameters updated by the Adagrad optimization algorithm.Parameter accumulators updated by the Adagrad optimization algorithm. +func RetrieveTPUEmbeddingAdagradParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdagradParametersAttr) (parameters tf.Output, accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingAdagradParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. 
// -// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder -// default or a value from 0 to 9. 9 is the highest compression level, generating -// the smallest output, but is slower. +// Each comparison returns a boolean `true` (if `input_value > threshold`) +// or and `false` otherwise. +// +// This operation is useful for Locality-Sensitive-Hashing (LSH) and other +// algorithms that use hashing approximations of cosine and `L2` distances; +// codes can be generated from an input via: +// +// ```python +// codebook_size = 50 +// codebook_bits = codebook_size * 32 +// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], +// dtype=x.dtype, +// initializer=tf.orthogonal_initializer()) +// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) +// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 +// # now codes has shape x.shape[:-1] + [codebook_size] +// ``` +// +// **NOTE**: Currently, the innermost dimension of the tensor must be divisible +// by 8. +// +// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is +// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. // // Arguments: -// image: 3-D with shape `[height, width, channels]`. +// input: Values to compare against `threshold` and bitpack. +// threshold: Threshold to compare against. // -// Returns 0-D. PNG-encoded image. -func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { +// Returns The bitpacked comparisons. +func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "EncodePng", + Type: "CompareAndBitpack", Input: []tf.Input{ - image, + input, threshold, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute. 
-type DataFormatVecPermuteAttr func(optionalAttr) +// QuantizeAndDequantizeV2Attr is an optional argument to QuantizeAndDequantizeV2. +type QuantizeAndDequantizeV2Attr func(optionalAttr) -// DataFormatVecPermuteSrcFormat sets the optional src_format attribute to value. +// QuantizeAndDequantizeV2SignedInput sets the optional signed_input attribute to value. // -// value: source data format. -// If not specified, defaults to "NHWC" -func DataFormatVecPermuteSrcFormat(value string) DataFormatVecPermuteAttr { +// value: Whether the quantization is signed or unsigned. (actually this parameter should +// have been called `signed_output`) +// If not specified, defaults to true +func QuantizeAndDequantizeV2SignedInput(value bool) QuantizeAndDequantizeV2Attr { return func(m optionalAttr) { - m["src_format"] = value + m["signed_input"] = value } } -// DataFormatVecPermuteDstFormat sets the optional dst_format attribute to value. +// QuantizeAndDequantizeV2NumBits sets the optional num_bits attribute to value. // -// value: destination data format. -// If not specified, defaults to "NCHW" -func DataFormatVecPermuteDstFormat(value string) DataFormatVecPermuteAttr { +// value: The bitwidth of the quantization. +// If not specified, defaults to 8 +func QuantizeAndDequantizeV2NumBits(value int64) QuantizeAndDequantizeV2Attr { return func(m optionalAttr) { - m["dst_format"] = value + m["num_bits"] = value } } -// Returns the permuted vector/tensor in the destination data format given the +// QuantizeAndDequantizeV2RangeGiven sets the optional range_given attribute to value. // -// one in the source data format. +// value: Whether the range is given or should be determined from the `input` tensor. +// If not specified, defaults to false +func QuantizeAndDequantizeV2RangeGiven(value bool) QuantizeAndDequantizeV2Attr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// QuantizeAndDequantizeV2RoundMode sets the optional round_mode attribute to value. 
// -// Arguments: -// x: Vector of size 4 or Tensor of shape (4, 2) in source data format. +// value: The 'round_mode' attribute controls which rounding tie-breaking algorithm is +// used when rounding float values to their quantized equivalents. The following +// rounding modes are currently supported: // -// Returns Vector of size 4 or Tensor of shape (4, 2) in destination data format. -func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPermuteAttr) (y tf.Output) { +// * HALF_TO_EVEN: this is the default round_mode. +// * HALF_UP: round towards positive. In this mode 7.5 rounds up to 8 and -7.5 +// rounds up to -7. +// +// If not specified, defaults to "HALF_TO_EVEN" +func QuantizeAndDequantizeV2RoundMode(value string) QuantizeAndDequantizeV2Attr { + return func(m optionalAttr) { + m["round_mode"] = value + } +} + +// Quantizes then dequantizes a tensor. +// +// This op simulates the precision loss from the quantized forward pass by: +// +// 1. Quantizing the tensor to fixed point numbers, which should match the target +// quantization method when it is used in inference. +// 2. Dequantizing it back to floating point numbers for the following ops, most +// likely matmul. +// +// There are different ways to quantize. This version uses only scaling, so 0.0 +// maps to 0. +// +// From the specified 'num_bits' in the quantized output type, it determines +// minimum and maximum representable quantized values. +// +// e.g. +// +// * [-128, 127] for signed, num_bits = 8, or +// * [0, 255] for unsigned, num_bits = 8. +// +// If range_given == False, the initial input_min, input_max will be determined +// automatically as the minimum and maximum values in the input tensor, otherwise +// the specified values of input_min, input_max are used. +// +// Note: If the input_min, input_max are specified, they do not need to equal the +// actual minimum and maximum values in the tensor. e.g. 
in some cases it may be +// beneficial to specify these values such that the low probability extremes of the +// input distribution are clipped. +// +// This op determines the maximum scale_factor that would map the initial +// [input_min, input_max] range to a range that lies within the representable +// quantized range. +// +// It determines the scale from one of input_min and input_max, then updates the +// other one to maximize the respresentable range. +// +// e.g. +// +// * if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, +// 5.0]: it would use a scale_factor of -128 / -10.0 = 12.8 In this case, it +// would update input_max to be 127 / 12.8 = 9.921875 +// * if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, +// 10.0]: it would use a scale_factor of 127 / 10.0 = 12.7 In this case, it +// would update input_min to be 128.0 / 12.7 = -10.07874 +// * if the output is unsigned, input_min is forced to be 0, and only the +// specified input_max is used. +// +// After determining the scale_factor and updating the input range, it applies the +// following to each value in the 'input' tensor. +// +// output = round(clamp(value, input_min, input_max) * scale_factor) / scale_factor. +// +// The above round function rounds the value based on the given round_mode. +// +// +// Arguments: +// input: Tensor to quantize and then dequantize. +// input_min: If `range_given == True`, this specifies the minimum input value that needs to +// be represented, otherwise it is determined from the min value of the `input` +// tensor. +// input_max: If `range_given == True`, this specifies the maximum input value that needs to +// be represented, otherwise it is determined from the max value of the `input` +// tensor. 
+func QuantizeAndDequantizeV2(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, optional ...QuantizeAndDequantizeV2Attr) (output tf.Output) { if scope.Err() != nil { return } @@ -13322,9 +17779,9 @@ func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPe a(attrs) } opspec := tf.OpSpec{ - Type: "DataFormatVecPermute", + Type: "QuantizeAndDequantizeV2", Input: []tf.Input{ - x, + input, input_min, input_max, }, Attrs: attrs, } @@ -13332,235 +17789,327 @@ func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPe return op.Output(0) } -// Initializes the multi device iterator with the given dataset. +// A TPU core selector Op. // -// Arguments: -// dataset: Dataset to be iterated upon. -// multi_device_iterator: A MultiDeviceIteratorResource. -// max_buffer_size: The maximum size of the host side per device buffer to keep. +// This Op produces a set of TPU cores (for warm-up) or a single TPU core +// (for regular inference) to execute the TPU program on. The output is +// consumed by TPUPartitionedCall. // -// Returns An int64 indicating which incarnation of the MultiDeviceIterator -// is running. -func MultiDeviceIteratorInit(scope *Scope, dataset tf.Output, multi_device_iterator tf.Output, max_buffer_size tf.Output) (incarnation_id tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MultiDeviceIteratorInit", - Input: []tf.Input{ - dataset, multi_device_iterator, max_buffer_size, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradient of `igamma(a, x)` wrt `a`. -func IgammaGradA(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { +// Returns A vector 1 or more TPU cores. 
+func TPUOrdinalSelector(scope *Scope) (device_ordinals tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IgammaGradA", - Input: []tf.Input{ - a, x, - }, + Type: "TPUOrdinalSelector", } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. +// Looks up keys in a table, outputs the corresponding values. // -// The hash function is deterministic on the content of the string within the -// process. +// The tensor `keys` must of the same type as the keys of the table. +// The output `values` is of the type of the table values. // -// Note that the hash function may change from time to time. -// This functionality will be deprecated and it's recommended to use -// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. +// The scalar `default_value` is the value output for keys not present in the +// table. It must also be of the same type as the table values. // // Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. // -// num_buckets: The number of buckets. // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { +// Returns Same shape as `keys`. Values found in the table, or `default_values` +// for missing keys. +func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "StringToHashBucket", + Type: "LookupTableFindV2", Input: []tf.Input{ - string_tensor, + table_handle, keys, default_value, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StaticRegexReplaceAttr is an optional argument to StaticRegexReplace. 
-type StaticRegexReplaceAttr func(optionalAttr) +// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. +type ResourceSparseApplyRMSPropAttr func(optionalAttr) -// StaticRegexReplaceReplaceGlobal sets the optional replace_global attribute to value. +// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. // -// value: If True, the replacement is global, otherwise the replacement -// is done only on the first match. -// If not specified, defaults to true -func StaticRegexReplaceReplaceGlobal(value bool) StaticRegexReplaceAttr { +// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { return func(m optionalAttr) { - m["replace_global"] = value + m["use_locking"] = value } } -// Replaces the match of pattern in input with rewrite. +// Update '*var' according to the RMSProp algorithm. // -// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom // // Arguments: -// input: The text to be processed. -// pattern: The regular expression to match the input. -// rewrite: The rewrite to be applied to the matched expression. +// var_: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. 
Must be a scalar. +// rho: Decay rate. Must be a scalar. // -// Returns The text after applying pattern and rewrite. -func StaticRegexReplace(scope *Scope, input tf.Output, pattern string, rewrite string, optional ...StaticRegexReplaceAttr) (output tf.Output) { +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var, ms and mom. +// +// Returns the created operation. +func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"pattern": pattern, "rewrite": rewrite} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StaticRegexReplace", + Type: "ResourceSparseApplyRMSProp", Input: []tf.Input{ - input, + var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Computes gradients for the exponential linear (Elu) operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Elu operation. -// outputs: The outputs of the corresponding Elu operation. +// Returns the truth value of (x > y) element-wise. // -// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, -// `gradients` otherwise. -func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { +// *NOTE*: `Greater` supports broadcasting. 
More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "EluGrad", + Type: "Greater", Input: []tf.Input{ - gradients, outputs, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that contains `count` elements from the `input_dataset`. -// -// Arguments: -// -// count: A scalar representing the number of elements from the `input_dataset` -// that should be taken. A value of `-1` indicates that all of `input_dataset` -// is taken. +// Creates a TensorList by indexing into a Tensor. // +// Each member of the TensorList corresponds to one row of the input tensor, +// specified by the given index (see `tf.gather`). // -func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// tensor: The input tensor. +// indices: The indices used to index into the list. +// element_shape: The shape of the elements in the list (can be less specified than +// the shape of the tensor). +// num_elements: The size of the output list. Must be large enough to accommodate +// the largest index in indices. If -1, the list is just large enough to include +// the largest index in indices. +// output_handle: The TensorList. +func TensorListScatterV2(scope *Scope, tensor tf.Output, indices tf.Output, element_shape tf.Output, num_elements tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TakeDataset", + Type: "TensorListScatterV2", Input: []tf.Input{ - input_dataset, count, + tensor, indices, element_shape, num_elements, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Reads the value of a variable. 
+// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. +type SampleDistortedBoundingBoxAttr func(optionalAttr) + +// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. // -// The tensor returned by this operation is immutable. +// value: If either `seed` or `seed2` are set to non-zero, the random number +// generator is seeded by the given `seed`. Otherwise, it is seeded by a random +// seed. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. // -// The value returned by this operation is guaranteed to be influenced by all the -// writes on which this operation depends directly or indirectly, and to not be -// influenced by any of the writes which depend directly or indirectly on this -// operation. +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. // -// Arguments: -// resource: handle to the resource in which to store the variable. -// dtype: the dtype of the value. -func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { - if scope.Err() != nil { - return +// value: The cropped area of the image must contain at least this +// fraction of any bounding box supplied. The value of this parameter should be +// non-negative. In the case of 0, the cropped area does not need to overlap +// any of the bounding boxes supplied. 
+// If not specified, defaults to 0.1 +func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["min_object_covered"] = value } - attrs := map[string]interface{}{"dtype": dtype} - opspec := tf.OpSpec{ - Type: "ReadVariableOp", - Input: []tf.Input{ - resource, - }, - Attrs: attrs, +} + +// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. +// +// value: The cropped area of the image must have an aspect ratio = +// width / height within this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["aspect_ratio_range"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// This op consumes a lock created by `MutexLock`. +// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. // -// This op exists to consume a tensor created by `MutexLock` (other than -// direct control dependencies). It should be the only that consumes the tensor, -// and will raise an error if it is not. Its only purpose is to keep the -// mutex lock tensor alive until it is consumed by this op. +// value: The cropped area of the image must contain a fraction of the +// supplied image within this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["area_range"] = value + } +} + +// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. // -// **NOTE**: This operation must run on the same device as its input. This may -// be enforced via the `colocate_with` mechanism. +// value: Number of attempts at generating a cropped region of the image +// of the specified constraints. After `max_attempts` failures, return the entire +// image. 
+// If not specified, defaults to 100 +func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["max_attempts"] = value + } +} + +// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// +// value: Controls behavior if no bounding boxes supplied. +// If true, assume an implicit bounding box covering the whole input. If false, +// raise an error. +// If not specified, defaults to false +func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["use_image_if_no_bounding_boxes"] = value + } +} + +// Generate a single randomly distorted bounding box for an image. +// +// Bounding box annotations are often supplied in addition to ground-truth labels +// in image recognition or object localization tasks. A common technique for +// training such a system is to randomly distort an image while preserving +// its content, i.e. *data augmentation*. This Op outputs a randomly distorted +// localization of an object, i.e. bounding box, given an `image_size`, +// `bounding_boxes` and a series of constraints. +// +// The output of this Op is a single bounding box that may be used to crop the +// original image. The output is returned as 3 tensors: `begin`, `size` and +// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the +// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize +// what the bounding box looks like. +// +// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. +// +// For example, +// +// ```python +// # Generate a single distorted bounding box. 
+// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( +// tf.shape(image), +// bounding_boxes=bounding_boxes) +// +// # Draw the bounding box in an image summary. +// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), +// bbox_for_draw) +// tf.summary.image('images_with_box', image_with_box) +// +// # Employ the bounding box to distort the image. +// distorted_image = tf.slice(image, begin, size) +// ``` +// +// Note that if no bounding box information is available, setting +// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit +// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is +// false and no bounding boxes are supplied, an error is raised. // // Arguments: -// mutex_lock: A tensor returned by `MutexLock`. +// image_size: 1-D, containing `[height, width, channels]`. +// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes +// associated with the image. // -// Returns the created operation. -func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) { +// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to +// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to +// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. +// Provide as input to `tf.image.draw_bounding_boxes`. 
+func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ConsumeMutexLock", + Type: "SampleDistortedBoundingBox", Input: []tf.Input{ - mutex_lock, + image_size, bounding_boxes, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// ResourceScatterNdAddAttr is an optional argument to ResourceScatterNdAdd. -type ResourceScatterNdAddAttr func(optionalAttr) +// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate. +type ResourceScatterNdUpdateAttr func(optionalAttr) -// ResourceScatterNdAddUseLocking sets the optional use_locking attribute to value. +// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value. // // value: An optional bool. Defaults to True. If True, the assignment will // be protected by a lock; otherwise the behavior is undefined, // but may exhibit less contention. // If not specified, defaults to true -func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr { +func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Applies sparse addition to individual values or slices in a Variable. +// Applies sparse `updates` to individual values or slices within a given +// +// variable according to `indices`. // // `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 
// @@ -13574,24 +18123,24 @@ func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr { // `updates` is `Tensor` of rank `Q-1+P-K` with shape: // // ``` -// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. // ``` // -// For example, say we want to add 4 scattered elements to a rank-1 tensor to -// 8 elements. In Python, that addition would look like this: +// For example, say we want to update 4 scattered elements to a rank-1 tensor to +// 8 elements. In Python, that update would look like this: // // ```python -// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// add = tf.scatter_nd_add(ref, indices, updates) -// with tf.Session() as sess: -// print sess.run(add) +// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) +// indices = tf.constant([[4], [3], [1] ,[7]]) +// updates = tf.constant([9, 10, 11, 12]) +// update = tf.scatter_nd_update(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(update) // ``` // // The resulting update to ref would look like this: // -// [1, 13, 3, 14, 14, 6, 7, 20] +// [1, 11, 3, 10, 9, 6, 7, 12] // // See `tf.scatter_nd` for more details about how to make updates to // slices. @@ -13600,11 +18149,11 @@ func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr { // ref: A resource handle. Must be from a VarHandleOp. // indices: A Tensor. Must be one of the following types: int32, int64. // A tensor of indices into ref. -// updates: A Tensor. Must have the same type as ref. A tensor of +// updates: A Tensor. Must have the same type as ref. A tensor of updated // values to add to ref. // // Returns the created operation. 
-func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdAddAttr) (o *tf.Operation) { +func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -13613,7 +18162,7 @@ func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, update a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceScatterNdAdd", + Type: "ResourceScatterNdUpdate", Input: []tf.Input{ ref, indices, updates, }, @@ -13622,90 +18171,183 @@ func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, update return scope.AddOperation(opspec) } -// Updates the tree ensemble by either adding a layer to the last tree being grown +// UnicodeDecodeWithOffsetsAttr is an optional argument to UnicodeDecodeWithOffsets. +type UnicodeDecodeWithOffsetsAttr func(optionalAttr) + +// UnicodeDecodeWithOffsetsErrors sets the optional errors attribute to value. // -// or by starting a new tree. +// value: Error handling policy when there is invalid formatting found in the input. +// The value of 'strict' will cause the operation to produce a InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. +// If not specified, defaults to "replace" +func UnicodeDecodeWithOffsetsErrors(value string) UnicodeDecodeWithOffsetsAttr { + return func(m optionalAttr) { + m["errors"] = value + } +} + +// UnicodeDecodeWithOffsetsReplacementChar sets the optional replacement_char attribute to value. 
+// +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. The default value is the default unicode replacement character is +// 0xFFFD or U+65533.) +// If not specified, defaults to 65533 +func UnicodeDecodeWithOffsetsReplacementChar(value int64) UnicodeDecodeWithOffsetsAttr { + return func(m optionalAttr) { + m["replacement_char"] = value + } +} + +// UnicodeDecodeWithOffsetsReplaceControlCharacters sets the optional replace_control_characters attribute to value. +// +// value: Whether to replace the C0 control characters (00-1F) with the +// `replacement_char`. Default is false. +// If not specified, defaults to false +func UnicodeDecodeWithOffsetsReplaceControlCharacters(value bool) UnicodeDecodeWithOffsetsAttr { + return func(m optionalAttr) { + m["replace_control_characters"] = value + } +} + +// Decodes each string in `input` into a sequence of Unicode code points. +// +// The character codepoints for all strings are returned using a single vector +// `char_values`, with strings expanded to characters in row-major order. +// Similarly, the character start byte offsets are returned using a single vector +// `char_to_byte_starts`, with strings expanded in row-major order. +// +// The `row_splits` tensor indicates where the codepoints and start offsets for +// each input string begin and end within the `char_values` and +// `char_to_byte_starts` tensors. In particular, the values for the `i`th +// string (in row-major order) are stored in the slice +// `[row_splits[i]:row_splits[i+1]]`. Thus: +// +// * `char_values[row_splits[i]+j]` is the Unicode codepoint for the `j`th +// character in the `i`th string (in row-major order). +// * `char_to_bytes_starts[row_splits[i]+j]` is the start byte offset for the `j`th +// character in the `i`th string (in row-major order). 
+// * `row_splits[i+1] - row_splits[i]` is the number of characters in the `i`th +// string (in row-major order). // // Arguments: -// tree_ensemble_handle: Handle to the ensemble variable. -// feature_ids: Rank 1 tensor with ids for each feature. This is the real id of -// the feature that will be used in the split. -// node_ids: List of rank 1 tensors representing the nodes for which this feature -// has a split. -// gains: List of rank 1 tensors representing the gains for each of the feature's -// split. -// thresholds: List of rank 1 tensors representing the thesholds for each of the -// feature's split. -// left_node_contribs: List of rank 2 tensors with left leaf contribs for each of -// the feature's splits. Will be added to the previous node values to constitute -// the values of the left nodes. -// right_node_contribs: List of rank 2 tensors with right leaf contribs for each -// of the feature's splits. Will be added to the previous node values to constitute -// the values of the right nodes. -// max_depth: Max depth of the tree to build. -// learning_rate: shrinkage const for each new tree. -// pruning_mode: 0-No pruning, 1-Pre-pruning, 2-Post-pruning. +// input: The text to be decoded. Can have any shape. Note that the output is flattened +// to a vector of char values. +// input_encoding: Text encoding of the input strings. This is any of the encodings supported +// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. // -// Returns the created operation. 
-func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, feature_ids tf.Output, node_ids []tf.Output, gains []tf.Output, thresholds []tf.Output, left_node_contribs []tf.Output, right_node_contribs []tf.Output, max_depth tf.Output, learning_rate tf.Output, pruning_mode int64) (o *tf.Operation) { +// Returns A 1D int32 tensor containing the row splits.A 1D int32 Tensor containing the decoded codepoints.A 1D int32 Tensor containing the byte index in the input string where each +// character in `char_values` starts. +func UnicodeDecodeWithOffsets(scope *Scope, input tf.Output, input_encoding string, optional ...UnicodeDecodeWithOffsetsAttr) (row_splits tf.Output, char_values tf.Output, char_to_byte_starts tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"pruning_mode": pruning_mode} + attrs := map[string]interface{}{"input_encoding": input_encoding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BoostedTreesUpdateEnsemble", + Type: "UnicodeDecodeWithOffsets", Input: []tf.Input{ - tree_ensemble_handle, feature_ids, tf.OutputList(node_ids), tf.OutputList(gains), tf.OutputList(thresholds), tf.OutputList(left_node_contribs), tf.OutputList(right_node_contribs), max_depth, learning_rate, + input, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// Computes tan of x element-wise. -func Tan(scope *Scope, x tf.Output) (y tf.Output) { +// Returns x - y element-wise. +// +// *NOTE*: `Subtract` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Tan", + Type: "Sub", Input: []tf.Input{ - x, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Bucketizes 'input' based on 'boundaries'. 
+// LRNAttr is an optional argument to LRN. +type LRNAttr func(optionalAttr) + +// LRNDepthRadius sets the optional depth_radius attribute to value. // -// For example, if the inputs are -// boundaries = [0, 10, 100] -// input = [[-5, 10000] -// [150, 10] -// [5, 100]] +// value: 0-D. Half-width of the 1-D normalization window. +// If not specified, defaults to 5 +func LRNDepthRadius(value int64) LRNAttr { + return func(m optionalAttr) { + m["depth_radius"] = value + } +} + +// LRNBias sets the optional bias attribute to value. // -// then the output will be -// output = [[0, 3] -// [3, 2] -// [1, 3]] +// value: An offset (usually positive to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNBias(value float32) LRNAttr { + return func(m optionalAttr) { + m["bias"] = value + } +} + +// LRNAlpha sets the optional alpha attribute to value. // -// Arguments: -// input: Any shape of Tensor contains with int or float type. -// boundaries: A sorted list of floats gives the boundary of the buckets. +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNAlpha(value float32) LRNAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// LRNBeta sets the optional beta attribute to value. // -// Returns Same shape with 'input', each value of input replaced with bucket index. +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNBeta(value float32) LRNAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Local Response Normalization. // -// @compatibility(numpy) -// Equivalent to np.digitize. -// @end_compatibility -func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) { +// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last +// dimension), and each vector is normalized independently. Within a given vector, +// each component is divided by the weighted, squared sum of inputs within +// `depth_radius`. 
In detail, +// +// sqr_sum[a, b, c, d] = +// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) +// output = input / (bias + alpha * sqr_sum) ** beta +// +// For details, see [Krizhevsky et al., ImageNet classification with deep +// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). +// +// Arguments: +// input: 4-D. +func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"boundaries": boundaries} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Bucketize", + Type: "LRN", Input: []tf.Input{ input, }, @@ -13715,33 +18357,90 @@ func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.O return op.Output(0) } -// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. -type StatelessTruncatedNormalAttr func(optionalAttr) +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug. +type RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr func(optionalAttr) -// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: The type of the output. 
-// If not specified, defaults to DT_FLOAT -func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["dtype"] = value + m["table_id"] = value } } -// Outputs deterministic pseudorandom values from a truncated normal distribution. +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve proximal Adagrad embedding parameters with debug support. // -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// The outputs are a deterministic function of `shape` and `seed`. +// Returns Parameter parameters updated by the proximal Adagrad optimization algorithm.Parameter accumulators updated by the proximal Adagrad optimization algorithm.Parameter gradient_accumulators updated by the proximal Adagrad optimization algorithm. 
+func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. +type ResourceSparseApplyAdagradAttr func(optionalAttr) + +// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceSparseApplyAdagradUpdateSlots sets the optional update_slots attribute to value. +// If not specified, defaults to true +func ResourceSparseApplyAdagradUpdateSlots(value bool) ResourceSparseApplyAdagradAttr { + return func(m optionalAttr) { + m["update_slots"] = value + } +} + +// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. +// +// That is for rows we have grad for, we update var and accum as follows: +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) // // Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. 
+// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. // -// Returns Random values with specified shape. -func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { +// Returns the created operation. +func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -13750,89 +18449,81 @@ func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, opt a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessTruncatedNormal", + Type: "ResourceSparseApplyAdagrad", Input: []tf.Input{ - shape, seed, + var_, accum, lr, grad, indices, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// RestoreSliceAttr is an optional argument to RestoreSlice. -type RestoreSliceAttr func(optionalAttr) +// LoadTPUEmbeddingMomentumParametersAttr is an optional argument to LoadTPUEmbeddingMomentumParameters. +type LoadTPUEmbeddingMomentumParametersAttr func(optionalAttr) -// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. -// -// value: Index of file to open first if multiple files match -// `file_pattern`. See the documentation for `Restore`. +// LoadTPUEmbeddingMomentumParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingMomentumParametersTableId(value int64) LoadTPUEmbeddingMomentumParametersAttr { return func(m optionalAttr) { - m["preferred_shard"] = value + m["table_id"] = value } } -// Restores a tensor from checkpoint files. 
-// -// This is like `Restore` except that restored tensor can be listed as filling -// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the -// larger tensor and the slice that the restored tensor covers. +// LoadTPUEmbeddingMomentumParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingMomentumParametersTableName(value string) LoadTPUEmbeddingMomentumParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load Momentum embedding parameters. // -// The `shape_and_slice` input has the same format as the -// elements of the `shapes_and_slices` input of the `SaveSlices` op. +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: -// file_pattern: Must have a single element. The pattern of the files from -// which we read the tensor. -// tensor_name: Must have a single element. The name of the tensor to be -// restored. -// shape_and_slice: Scalar. The shapes and slice specifications to use when -// restoring a tensors. -// dt: The type of the tensor to be restored. +// parameters: Value of parameters used in the Momentum optimization algorithm. +// momenta: Value of momenta used in the Momentum optimization algorithm. // -// Returns The restored tensor. -func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { +// +// +// Returns the created operation. 
+func LoadTPUEmbeddingMomentumParameters(scope *Scope, parameters tf.Output, momenta tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMomentumParametersAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dt": dt} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RestoreSlice", + Type: "LoadTPUEmbeddingMomentumParameters", Input: []tf.Input{ - file_pattern, tensor_name, shape_and_slice, + parameters, momenta, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Divides sparse updates into the variable referenced by `resource`. +// Assigns sparse updates to the variable referenced by `resource`. // // This operation computes // // # Scalar indices -// ref[indices, ...] /= updates[...] +// ref[indices, ...] = updates[...] // // # Vector indices (for each i) -// ref[indices[i], ...] /= updates[i, ...] +// ref[indices[i], ...] = updates[i, ...] // // # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions multiply. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// -//
+// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] // // Arguments: // resource: Should be from a `Variable` node. @@ -13840,12 +18531,12 @@ func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, s // updates: A tensor of updated values to add to `ref`. // // Returns the created operation. -func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ResourceScatterDiv", + Type: "ResourceScatterUpdate", Input: []tf.Input{ resource, indices, updates, }, @@ -13853,253 +18544,123 @@ func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, upd return scope.AddOperation(opspec) } -// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. -type StatelessRandomNormalAttr func(optionalAttr) - -// StatelessRandomNormalDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom values from a normal distribution. -// -// The generated values will have mean 0 and standard deviation 1. -// -// The outputs are a deterministic function of `shape` and `seed`. +// Elementwise computes the bitwise right-shift of `x` and `y`. // -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// Performs a logical shift for unsigned integer types, and an arithmetic shift +// for signed integer types. // -// Returns Random values with specified shape. 
-func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { +// If `y` is negative, or greater than or equal to than the width of `x` in bits +// the result is implementation defined. +func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "StatelessRandomNormal", + Type: "RightShift", Input: []tf.Input{ - shape, seed, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// UnicodeDecodeAttr is an optional argument to UnicodeDecode. -type UnicodeDecodeAttr func(optionalAttr) - -// UnicodeDecodeErrors sets the optional errors attribute to value. -// -// value: Error handling policy when there is invalid formatting found in the input. -// The value of 'strict' will cause the operation to produce a InvalidArgument -// error on any invalid input formatting. A value of 'replace' (the default) will -// cause the operation to replace any invalid formatting in the input with the -// `replacement_char` codepoint. A value of 'ignore' will cause the operation to -// skip any invalid formatting in the input and produce no corresponding output -// character. -// If not specified, defaults to "replace" -func UnicodeDecodeErrors(value string) UnicodeDecodeAttr { - return func(m optionalAttr) { - m["errors"] = value - } -} - -// UnicodeDecodeReplacementChar sets the optional replacement_char attribute to value. -// -// value: The replacement character codepoint to be used in place of any invalid -// formatting in the input when `errors='replace'`. Any valid unicode codepoint may -// be used. The default value is the default unicode replacement character is -// 0xFFFD or U+65533.) 
-// If not specified, defaults to 65533 -func UnicodeDecodeReplacementChar(value int64) UnicodeDecodeAttr { - return func(m optionalAttr) { - m["replacement_char"] = value - } -} +// TensorListStackAttr is an optional argument to TensorListStack. +type TensorListStackAttr func(optionalAttr) -// UnicodeDecodeReplaceControlCharacters sets the optional replace_control_characters attribute to value. -// -// value: Whether to replace the C0 control characters (00-1F) with the -// `replacement_char`. Default is false. -// If not specified, defaults to false -func UnicodeDecodeReplaceControlCharacters(value bool) UnicodeDecodeAttr { +// TensorListStackNumElements sets the optional num_elements attribute to value. +// If not specified, defaults to -1 +func TensorListStackNumElements(value int64) TensorListStackAttr { return func(m optionalAttr) { - m["replace_control_characters"] = value - } -} - -// Decodes each string in `input` into a sequence of Unicode code points. -// -// The character codepoints for all strings are returned using a single vector -// `char_values`, with strings expanded to characters in row-major order. -// -// The `row_splits` tensor indicates where the codepoints for -// each input string begin and end within the `char_values` tensor. -// In particular, the values for the `i`th -// string (in row-major order) are stored in the slice -// `[row_splits[i]:row_splits[i+1]]`. Thus: -// -// * `char_values[row_splits[i]+j]` is the Unicode codepoint for the `j`th -// character in the `i`th string (in row-major order). -// * `row_splits[i+1] - row_splits[i]` is the number of characters in the `i`th -// string (in row-major order). -// -// Arguments: -// input: The text to be decoded. Can have any shape. Note that the output is flattened -// to a vector of char values. -// input_encoding: Text encoding of the input strings. This is any of the encodings supported -// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. 
-// -// Returns A 1D int32 tensor containing the row splits.A 1D int32 Tensor containing the decoded codepoints. -func UnicodeDecode(scope *Scope, input tf.Output, input_encoding string, optional ...UnicodeDecodeAttr) (row_splits tf.Output, char_values tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"input_encoding": input_encoding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UnicodeDecode", - Input: []tf.Input{ - input, - }, - Attrs: attrs, + m["num_elements"] = value } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) } -// Adds up a SparseTensor and a dense Tensor, using these special rules: -// -// (1) Broadcasts the dense side to have the same shape as the sparse side, if -// eligible; -// (2) Then, only the dense values pointed to by the indices of the SparseTensor -// participate in the cwise addition. -// -// By these rules, the result is a logical SparseTensor with exactly the same -// indices and shape, but possibly with different non-zero values. The output of -// this Op is the resultant non-zero values. +// Stacks all tensors in the list. // -// Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. +// Requires that all tensors have the same shape. // -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { +// input_handle: the input list +// tensor: the gathered result +// num_elements: optional. If not -1, the number of elements in the list. 
+// +func TensorListStack(scope *Scope, input_handle tf.Output, element_shape tf.Output, element_dtype tf.DataType, optional ...TensorListStackAttr) (tensor tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"element_dtype": element_dtype} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseDenseCwiseAdd", + Type: "TensorListStack", Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, + input_handle, element_shape, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the complementary error function of `x` element-wise. -func Erfc(scope *Scope, x tf.Output) (y tf.Output) { +// A placeholder op for a value that will be fed into the computation. +// +// Arguments: +// dtype: The type of elements in the tensor. +// shape: The shape of the tensor. +// +// Returns A tensor that will be provided using the infeed mechanism. +func InfeedDequeue(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} opspec := tf.OpSpec{ - Type: "Erfc", - Input: []tf.Input{ - x, - }, + Type: "InfeedDequeue", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// UnicodeEncodeAttr is an optional argument to UnicodeEncode. -type UnicodeEncodeAttr func(optionalAttr) - -// UnicodeEncodeErrors sets the optional errors attribute to value. -// -// value: Error handling policy when there is invalid formatting found in the input. -// The value of 'strict' will cause the operation to produce a InvalidArgument -// error on any invalid input formatting. A value of 'replace' (the default) will -// cause the operation to replace any invalid formatting in the input with the -// `replacement_char` codepoint. A value of 'ignore' will cause the operation to -// skip any invalid formatting in the input and produce no corresponding output -// character. 
-// If not specified, defaults to "replace" -func UnicodeEncodeErrors(value string) UnicodeEncodeAttr { - return func(m optionalAttr) { - m["errors"] = value - } -} +// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. +type StatelessRandomUniformAttr func(optionalAttr) -// UnicodeEncodeReplacementChar sets the optional replacement_char attribute to value. +// StatelessRandomUniformDtype sets the optional dtype attribute to value. // -// value: The replacement character codepoint to be used in place of any invalid -// formatting in the input when `errors='replace'`. Any valid unicode codepoint may -// be used. The default value is the default unicode replacement character is -// 0xFFFD (U+65533). -// If not specified, defaults to 65533 -func UnicodeEncodeReplacementChar(value int64) UnicodeEncodeAttr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { return func(m optionalAttr) { - m["replacement_char"] = value + m["dtype"] = value } } -// Encode a tensor of ints into unicode strings. -// -// Returns a vector of strings, where `output[i]` is constructed by encoding the -// Unicode codepoints in `input_values[input_splits[i]:input_splits[i+1]]` -// using `output_encoding`. -// -// --- -// -// Example: +// Outputs deterministic pseudorandom random values from a uniform distribution. // -// ``` -// input_values = [72, 101, 108, 108, 111, 87, 111, 114, 108, 100] -// input_splits = [0, 5, 10] -// output_encoding = 'UTF-8' +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. // -// output = ['Hello', 'World'] -// ``` +// The outputs are a deterministic function of `shape` and `seed`. // // Arguments: -// input_values: A 1D tensor containing the unicode codepoints that should be encoded. 
-// input_splits: A 1D tensor specifying how the unicode codepoints should be split into strings. -// In particular, `output[i]` is constructed by encoding the codepoints in the -// slice `input_values[input_splits[i]:input_splits[i+1]]`. -// output_encoding: Unicode encoding of the output strings. Valid encodings are: `"UTF-8", -// "UTF-16-BE", and "UTF-32-BE"`. +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). // -// Returns The 1-D Tensor of strings encoded from the provided unicode codepoints. -func UnicodeEncode(scope *Scope, input_values tf.Output, input_splits tf.Output, output_encoding string, optional ...UnicodeEncodeAttr) (output tf.Output) { +// Returns Random values with specified shape. +func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_encoding": output_encoding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "UnicodeEncode", + Type: "StatelessRandomUniform", Input: []tf.Input{ - input_values, input_splits, + shape, seed, }, Attrs: attrs, } @@ -14107,343 +18668,303 @@ func UnicodeEncode(scope *Scope, input_values tf.Output, input_splits tf.Output, return op.Output(0) } -// Returns the number of tensors in the input tensor list. +// Makes its input available to the next iteration. // -// input_handle: the input list -// length: the number of tensors in the list -func TensorListLength(scope *Scope, input_handle tf.Output) (length tf.Output) { +// Arguments: +// data: The tensor to be made available to the next iteration. +// +// Returns The same tensor as `data`. 
+func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListLength", + Type: "NextIteration", Input: []tf.Input{ - input_handle, + data, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Determine the script codes of a given tensor of Unicode integer code points. -// -// This operation converts Unicode code points to script codes corresponding to -// each code point. Script codes correspond to International Components for -// Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html. -// Returns -1 (USCRIPT_INVALID_CODE) for invalid codepoints. Output shape will -// match input shape. -// -// Arguments: -// input: A Tensor of int32 Unicode code points. -// -// Returns A Tensor of int32 script codes corresponding to each input code point. -func UnicodeScript(scope *Scope, input tf.Output) (output tf.Output) { +// Output a fact about factorials. +func Fact(scope *Scope) (fact tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "UnicodeScript", - Input: []tf.Input{ - input, - }, + Type: "Fact", } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a sequence of numbers. -// -// This operation creates a sequence of numbers that begins at `start` and -// extends by increments of `delta` up to but not including `limit`. -// -// For example: -// -// ``` -// # 'start' is 3 -// # 'limit' is 18 -// # 'delta' is 3 -// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] -// ``` -// -// Arguments: -// start: 0-D (scalar). First entry in the sequence. -// limit: 0-D (scalar). Upper limit of sequence, exclusive. -// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. +// Returns the truth value of x AND y element-wise. // -// Returns 1-D. -func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { +// *NOTE*: `LogicalAnd` supports broadcasting. 
More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func LogicalAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Range", + Type: "LogicalAnd", Input: []tf.Input{ - start, limit, delta, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// OrderedMapUnstageNoKeyAttr is an optional argument to OrderedMapUnstageNoKey. -type OrderedMapUnstageNoKeyAttr func(optionalAttr) +// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping. +type GenerateVocabRemappingAttr func(optionalAttr) -// OrderedMapUnstageNoKeyCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value. // -// REQUIRES: value >= 0 -func OrderedMapUnstageNoKeyCapacity(value int64) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// value: Number of entries in the old vocab file to consider. If -1, +// use the entire old vocabulary. +// If not specified, defaults to -1 // -// REQUIRES: value >= 0 -func OrderedMapUnstageNoKeyMemoryLimit(value int64) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapUnstageNoKeyContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageNoKeyContainer(value string) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// OrderedMapUnstageNoKeySharedName sets the optional shared_name attribute to value. 
-// If not specified, defaults to "" -func OrderedMapUnstageNoKeySharedName(value string) OrderedMapUnstageNoKeyAttr { +// REQUIRES: value >= -1 +func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["old_vocab_size"] = value } } -// Op removes and returns the (key, value) element with the smallest +// Given a path to new and old vocabulary files, returns a remapping Tensor of // -// key from the underlying container. If the underlying container -// does not contain elements, the op will block until it does. -func OrderedMapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { +// length `num_new_vocab`, where `remapping[i]` contains the row number in the old +// vocabulary that corresponds to row `i` in the new vocabulary (starting at line +// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i` +// in the new vocabulary is not in the old vocabulary. The old vocabulary is +// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the +// default value of -1. +// +// `num_vocab_offset` enables +// use in the partitioned variable case, and should generally be set through +// examining partitioning info. The format of the files should be a text file, +// with each line containing a single entity within the vocabulary. +// +// For example, with `new_vocab_file` a text file containing each of the following +// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3], +// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be +// `[0, -1, 2]`. 
+// +// The op also returns a count of how many entries in the new vocabulary +// were present in the old vocabulary, which is used to calculate the number of +// values to initialize in a weight matrix remapping +// +// This functionality can be used to remap both row vocabularies (typically, +// features) and column vocabularies (typically, classes) from TensorFlow +// checkpoints. Note that the partitioning logic relies on contiguous vocabularies +// corresponding to div-partitioned variables. Moreover, the underlying remapping +// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should +// use the corresponding index_table_from_file() as the FeatureColumn framework +// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). +// +// Arguments: +// new_vocab_file: Path to the new vocab file. +// old_vocab_file: Path to the old vocab file. +// new_vocab_offset: How many entries into the new vocab file to start reading. +// num_new_vocab: Number of entries in the new vocab file to remap. +// +// Returns A Tensor of length num_new_vocab where the element at index i +// is equal to the old ID that maps to the new ID i. This element is -1 for any +// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab. 
+func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapUnstageNoKey", + Type: "GenerateVocabRemapping", Input: []tf.Input{ - indices, + new_vocab_file, old_vocab_file, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - key = op.Output(idx) - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapUnstageNoKey", err) - return - } - return key, values + return op.Output(0), op.Output(1) } -// Returns element-wise integer closest to x. +// Says whether the targets are in the top `K` predictions. // -// If the result is midway between two representable values, -// the even representable is chosen. -// For example: +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. // -// ``` -// rint(-1.5) ==> -2.0 -// rint(0.5000001) ==> 1.0 -// rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] 
-// ``` -func Rint(scope *Scope, x tf.Output) (y tf.Output) { +// More formally, let +// +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, +// +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// +// Arguments: +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. +// +// Returns Computed Precision at `k` as a `bool Tensor`. +func InTopK(scope *Scope, predictions tf.Output, targets tf.Output, k int64) (precision tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"k": k} opspec := tf.OpSpec{ - Type: "Rint", + Type: "InTopK", Input: []tf.Input{ - x, + predictions, targets, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. -type ResourceApplyMomentumAttr func(optionalAttr) +// RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingAdagradParametersGradAccumDebug. +type RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr func(optionalAttr) -// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. +// RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. 
-// If not specified, defaults to false -func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr { +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["table_id"] = value } } -// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, the tensor passed to compute grad will be -// var - lr * momentum * accum, so in the end, the var you get is actually -// var - lr * momentum * accum. -// If not specified, defaults to false -func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr { +// RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["use_nesterov"] = value + m["table_name"] = value } } -// Update '*var' according to the momentum scheme. Set use_nesterov = True if you +// Retrieve Adagrad embedding parameters with debug support. // -// want to use Nesterov momentum. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// accum = accum * momentum + grad -// var -= lr * accum +// Returns Parameter parameters updated by the Adagrad optimization algorithm.Parameter accumulators updated by the Adagrad optimization algorithm.Parameter gradient_accumulators updated by the Adagrad optimization algorithm. 
+func RetrieveTPUEmbeddingAdagradParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingAdagradParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// This op consumes a lock created by `MutexLock`. +// +// This op exists to consume a tensor created by `MutexLock` (other than +// direct control dependencies). It should be the only that consumes the tensor, +// and will raise an error if it is not. Its only purpose is to keep the +// mutex lock tensor alive until it is consumed by this op. +// +// **NOTE**: This operation must run on the same device as its input. This may +// be enforced via the `colocate_with` mechanism. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. -// momentum: Momentum. Must be a scalar. +// mutex_lock: A tensor returned by `MutexLock`. // // Returns the created operation. 
-func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) { +func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceApplyMomentum", + Type: "ConsumeMutexLock", Input: []tf.Input{ - var_, accum, lr, grad, momentum, + mutex_lock, }, - Attrs: attrs, } return scope.AddOperation(opspec) } -// SubstrAttr is an optional argument to Substr. -type SubstrAttr func(optionalAttr) +// ResourceScatterNdAddAttr is an optional argument to ResourceScatterNdAdd. +type ResourceScatterNdAddAttr func(optionalAttr) -// SubstrUnit sets the optional unit attribute to value. +// ResourceScatterNdAddUseLocking sets the optional use_locking attribute to value. // -// value: The unit that is used to create the substring. One of: `"BYTE"` (for -// defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8 -// encoded Unicode code points). The default is `"BYTE"`. Results are undefined if -// `unit=UTF8_CHAR` and the `input` strings do not contain structurally valid -// UTF-8. -// If not specified, defaults to "BYTE" -func SubstrUnit(value string) SubstrAttr { +// value: An optional bool. Defaults to True. If True, the assignment will +// be protected by a lock; otherwise the behavior is undefined, +// but may exhibit less contention. +// If not specified, defaults to true +func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr { return func(m optionalAttr) { - m["unit"] = value + m["use_locking"] = value } } -// Return substrings from `Tensor` of strings. -// -// For each string in the input `Tensor`, creates a substring starting at index -// `pos` with a total length of `len`. 
-// -// If `len` defines a substring that would extend beyond the length of the input -// string, then as many characters as possible are used. -// -// A negative `pos` indicates distance within the string backwards from the end. -// -// If `pos` specifies an index which is out of range for any of the input strings, -// then an `InvalidArgumentError` is thrown. -// -// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on -// Op creation. -// -// *NOTE*: `Substr` supports broadcasting up to two dimensions. More about -// broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +// Applies sparse addition to individual values or slices in a Variable. // -// --- +// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. // -// Examples +// `indices` must be integer tensor, containing indices into `ref`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. // -// Using scalar `pos` and `len`: +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +// dimension of `ref`. 
// -// ```python -// input = [b'Hello', b'World'] -// position = 1 -// length = 3 +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: // -// output = [b'ell', b'orl'] // ``` -// -// Using `pos` and `len` with same shape as `input`: -// -// ```python -// input = [[b'ten', b'eleven', b'twelve'], -// [b'thirteen', b'fourteen', b'fifteen'], -// [b'sixteen', b'seventeen', b'eighteen']] -// position = [[1, 2, 3], -// [1, 2, 3], -// [1, 2, 3]] -// length = [[2, 3, 4], -// [4, 3, 2], -// [5, 5, 5]] -// -// output = [[b'en', b'eve', b'lve'], -// [b'hirt', b'urt', b'te'], -// [b'ixtee', b'vente', b'hteen']] +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] // ``` // -// Broadcasting `pos` and `len` onto `input`: -// -// ``` -// input = [[b'ten', b'eleven', b'twelve'], -// [b'thirteen', b'fourteen', b'fifteen'], -// [b'sixteen', b'seventeen', b'eighteen'], -// [b'nineteen', b'twenty', b'twentyone']] -// position = [1, 2, 3] -// length = [1, 2, 3] +// For example, say we want to add 4 scattered elements to a rank-1 tensor to +// 8 elements. In Python, that addition would look like this: // -// output = [[b'e', b'ev', b'lve'], -// [b'h', b'ur', b'tee'], -// [b'i', b've', b'hte'], -// [b'i', b'en', b'nty']] +// ```python +// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// add = tf.scatter_nd_add(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(add) // ``` // -// Broadcasting `input` onto `pos` and `len`: +// The resulting update to ref would look like this: // -// ``` -// input = b'thirteen' -// position = [1, 5, 7] -// length = [3, 2, 1] +// [1, 13, 3, 14, 14, 6, 7, 20] // -// output = [b'hir', b'ee', b'n'] -// ``` +// See `tf.scatter_nd` for more details about how to make updates to +// slices. 
// // Arguments: -// input: Tensor of strings -// pos: Scalar defining the position of first character in each substring -// len: Scalar defining the number of characters to include in each substring +// ref: A resource handle. Must be from a VarHandleOp. +// indices: A Tensor. Must be one of the following types: int32, int64. +// A tensor of indices into ref. +// updates: A Tensor. Must have the same type as ref. A tensor of +// values to add to ref. // -// Returns Tensor of substrings -func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output, optional ...SubstrAttr) (output tf.Output) { +// Returns the created operation. +func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdAddAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -14452,240 +18973,293 @@ func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output, optiona a(attrs) } opspec := tf.OpSpec{ - Type: "Substr", + Type: "ResourceScatterNdAdd", Input: []tf.Input{ - input, pos, len, + ref, indices, updates, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) +} + +// Replaces the contents of the table with the specified keys and values. +// +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableImportV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) } -// Exits the current frame to its parent frame. +// Worker heartbeat op. 
// -// Exit makes its input `data` available to the parent frame. +// Heartbeats may be sent periodically to indicate the coordinator is still active, +// to retrieve the current worker status and to expedite shutdown when necessary. // // Arguments: -// data: The tensor to be made available to the parent frame. +// request: A string tensor containing a serialized WorkerHeartbeatRequest // -// Returns The same tensor as `data`. -func Exit(scope *Scope, data tf.Output) (output tf.Output) { +// Returns A string tensor containing a serialized WorkerHeartbeatResponse +func WorkerHeartbeat(scope *Scope, request tf.Output) (response tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Exit", + Type: "WorkerHeartbeat", Input: []tf.Input{ - data, + request, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Produce a string tensor that encodes the state of a Reader. -// -// Not all Readers support being serialized, so this can produce an -// Unimplemented error. +// Returns the truth value of (x <= y) element-wise. // -// Arguments: -// reader_handle: Handle to a Reader. -func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { +// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func LessEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderSerializeStateV2", + Type: "LessEqual", Input: []tf.Input{ - reader_handle, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Concatenates quantized tensors along one dimension. +// EnqueueTPUEmbeddingIntegerBatchAttr is an optional argument to EnqueueTPUEmbeddingIntegerBatch. +type EnqueueTPUEmbeddingIntegerBatchAttr func(optionalAttr) + +// EnqueueTPUEmbeddingIntegerBatchDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. 
Should be >= 0 and less than the number +// of TPU cores in the task on which the node is placed. +// If not specified, defaults to -1 +func EnqueueTPUEmbeddingIntegerBatchDeviceOrdinal(value int64) EnqueueTPUEmbeddingIntegerBatchAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// An op that enqueues a list of input batch tensors to TPUEmbedding. // // Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// input_mins: The minimum scalar values for each of the input tensors. -// input_maxes: The maximum scalar values for each of the input tensors. +// batch: A list of 1D tensors, one for each embedding table, containing the +// indices into the tables. +// mode_override: A string input that overrides the mode specified in the +// TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', +// 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set +// in TPUEmbeddingConfiguration is used, otherwise mode_override is used. // -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. -func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// Returns the created operation. 
+func EnqueueTPUEmbeddingIntegerBatch(scope *Scope, batch []tf.Output, mode_override tf.Output, optional ...EnqueueTPUEmbeddingIntegerBatchAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "QuantizedConcat", + Type: "EnqueueTPUEmbeddingIntegerBatch", Input: []tf.Input{ - concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), + tf.OutputList(batch), mode_override, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// Slice a `SparseTensor` based on the `start` and `size`. -// -// For example, if the input is -// -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] -// -// Graphically the output tensors are: -// -// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] -// [ a ] -// [b c ] +// An op that receives embedding activations on the TPU. // -// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] -// [ d e ] -// [ ] +// The TPU system performs the embedding lookups and aggregations specified by +// the arguments to TPUEmbeddingEnqueue(Integer/Sparse/SparseTensor)Batch. The +// results of these aggregations are visible to the Tensorflow Graph as the +// outputs of a RecvTPUEmbeddingActivations op. This op returns a list containing +// one Tensor of activations per table specified in the model. There can be at +// most one RecvTPUEmbeddingActivations op in the TPU graph. // // Arguments: -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// start: 1-D. tensor represents the start of the slice. -// size: 1-D. tensor represents the size of the slice. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. 
+// num_outputs: The number of output activation tensors, equal to the number of +// embedding tables in the model. +// config: Serialized TPUEmbeddingConfiguration proto. // -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// Returns A TensorList of embedding activations containing one Tensor per +// embedding table in the model. +func RecvTPUEmbeddingActivations(scope *Scope, num_outputs int64, config string) (outputs []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_outputs": num_outputs, "config": config} opspec := tf.OpSpec{ - Type: "SparseSlice", - Input: []tf.Input{ - indices, values, shape, start, size, - }, + Type: "RecvTPUEmbeddingActivations", + + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("RecvTPUEmbeddingActivations", err) + return + } + return outputs } -// Reduces sparse updates into the variable referenced by `resource` using the `min` operation. +// Selects elements from `x` or `y`, depending on `condition`. // -// This operation computes +// The `x`, and `y` tensors must all have the same shape, and the +// output will also have that shape. // -// # Scalar indices -// ref[indices, ...] = min(ref[indices, ...], updates[...]) +// The `condition` tensor must be a scalar if `x` and `y` are scalars. +// If `x` and `y` are vectors or higher rank, then `condition` must be either a +// scalar, a vector with size matching the first dimension of `x`, or must have +// the same shape as `x`. 
// -// # Vector indices (for each i) -// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) +// The `condition` tensor acts as a mask that chooses, based on the value at each +// element, whether the corresponding element / row in the output should be +// taken from `x` (if true) or `y` (if false). // -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// If `condition` is a vector and `x` and `y` are higher rank matrices, then +// it chooses which row (outer dimension) to copy from `x` and `y`. +// If `condition` has the same shape as `x` and `y`, then it chooses which +// element to copy from `x` and `y`. // -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions are combined. +// For example: // -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// ```python +// # 'condition' tensor is [[True, False] +// # [False, True]] +// # 't' is [[1, 2], +// # [3, 4]] +// # 'e' is [[5, 6], +// # [7, 8]] +// select(condition, t, e) # => [[1, 6], [7, 4]] // -//
-// -//
+// +// # 'condition' tensor is [True, False] +// # 't' is [[1, 2], +// # [3, 4]] +// # 'e' is [[5, 6], +// # [7, 8]] +// select(condition, t, e) ==> [[1, 2], +// [7, 8]] +// +// ``` // // Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. // -// Returns the created operation. -func ResourceScatterMin(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +// x: = A `Tensor` which may have the same shape as `condition`. +// If `condition` is rank 1, `x` may have higher rank, +// but its first dimension must match the size of `condition`. +// y: = A `Tensor` with the same type and shape as `x`. +// +// Returns = A `Tensor` with the same type and shape as `x` and `y`. +func Select(scope *Scope, condition tf.Output, x tf.Output, y tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ResourceScatterMin", + Type: "Select", Input: []tf.Input{ - resource, indices, updates, + condition, x, y, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Reshapes a quantized tensor as per the Reshape op. +// Returns the set of files matching one or more glob patterns. // -// ``` +// Note that this routine only supports wildcard characters in the +// basename portion of the pattern, not in the directory portion. +// Note also that the order of filenames returned can be non-deterministic. // // Arguments: +// pattern: Shell wildcard pattern(s). Scalar or vector of type string. // -// shape: Defines the shape of the output tensor. -// input_min: The minimum value of the input. -// input_max: The maximum value of the input. -// -// Returns This value is copied from input_min.This value is copied from input_max. 
-func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// Returns A vector of matching filenames. +func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QuantizedReshape", + Type: "MatchingFiles", Input: []tf.Input{ - tensor, shape, input_min, input_max, + pattern, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// StringSplitAttr is an optional argument to StringSplit. -type StringSplitAttr func(optionalAttr) +// SqueezeAttr is an optional argument to Squeeze. +type SqueezeAttr func(optionalAttr) -// StringSplitSkipEmpty sets the optional skip_empty attribute to value. +// SqueezeAxis sets the optional axis attribute to value. // -// value: A `bool`. If `True`, skip the empty strings from the result. -// If not specified, defaults to true -func StringSplitSkipEmpty(value bool) StringSplitAttr { +// value: If specified, only squeezes the dimensions listed. The dimension +// index starts at 0. It is an error to squeeze a dimension that is not 1. Must +// be in the range `[-rank(input), rank(input))`. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func SqueezeAxis(value []int64) SqueezeAttr { return func(m optionalAttr) { - m["skip_empty"] = value + m["squeeze_dims"] = value } } -// Split elements of `input` based on `delimiter` into a `SparseTensor`. -// -// Let N be the size of source (typically N will be the batch size). Split each -// element of `input` based on `delimiter` and return a `SparseTensor` -// containing the splitted tokens. Empty tokens are ignored. +// Removes dimensions of size 1 from the shape of a tensor. // -// `delimiter` can be empty, or a string of split characters. 
If `delimiter` is an -// empty string, each element of `input` is split into individual single-byte -// character strings, including splitting of UTF-8 multibyte sequences. Otherwise -// every character of `delimiter` is a potential split point. +// Given a tensor `input`, this operation returns a tensor of the same type with +// all dimensions of size 1 removed. If you don't want to remove all size 1 +// dimensions, you can remove specific size 1 dimensions by specifying +// `axis`. // // For example: -// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output -// will be // -// indices = [0, 0; -// 0, 1; -// 1, 0; -// 1, 1; -// 1, 2] -// shape = [2, 3] -// values = ['hello', 'world', 'a', 'b', 'c'] +// ``` +// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +// shape(squeeze(t)) ==> [2, 3] +// ``` +// +// Or, to remove specific size 1 dimensions: +// +// ``` +// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +// shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] +// ``` // // Arguments: -// input: 1-D. Strings to split. -// delimiter: 0-D. Delimiter characters (bytes), or empty string. +// input: The `input` to squeeze. // -// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse -// tensor, where the first value is N and the second value is the maximum number -// of tokens in a single input entry. -func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { +// Returns Contains the same data as `input`, but has one or more dimensions of +// size 1 removed. +func Squeeze(scope *Scope, input tf.Output, optional ...SqueezeAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -14694,62 +19268,48 @@ func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional .. 
a(attrs) } opspec := tf.OpSpec{ - Type: "StringSplit", + Type: "Squeeze", Input: []tf.Input{ - input, delimiter, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum. -type ResourceSparseApplyMomentumAttr func(optionalAttr) +// ResourceApplyAdadeltaAttr is an optional argument to ResourceApplyAdadelta. +type ResourceApplyAdadeltaAttr func(optionalAttr) -// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value. +// ResourceApplyAdadeltaUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. +// value: If True, updating of the var, accum and update_accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. // If not specified, defaults to false -func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr { +func ResourceApplyAdadeltaUseLocking(value bool) ResourceApplyAdadeltaAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, the tensor passed to compute grad will be -// var - lr * momentum * accum, so in the end, the var you get is actually -// var - lr * momentum * accum. -// If not specified, defaults to false -func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update relevant entries in '*var' and '*accum' according to the momentum scheme. -// -// Set use_nesterov = True if you want to use Nesterov momentum. 
-// -// That is for rows we have grad for, we update var and accum as follows: +// Update '*var' according to the adadelta scheme. // -// accum = accum * momentum + grad -// var -= lr * accum +// accum = rho() * accum + (1 - rho()) * grad.square(); +// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; +// update_accum = rho() * update_accum + (1 - rho()) * update.square(); +// var -= update; // // Arguments: // var_: Should be from a Variable(). // accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. +// accum_update: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. // grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// momentum: Momentum. Must be a scalar. // // Returns the created operation. -func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) { +func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdadeltaAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -14758,137 +19318,170 @@ func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyMomentum", + Type: "ResourceApplyAdadelta", Input: []tf.Input{ - var_, accum, lr, grad, indices, momentum, + var_, accum, accum_update, lr, rho, epsilon, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Returns the complex conjugate of a complex number. +// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. 
+type NonMaxSuppressionAttr func(optionalAttr) + +// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. // -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// complex numbers that are the complex conjugate of each element in `input`. The -// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the -// real part and *b* is the imaginary part. +// value: A float representing the threshold for deciding whether boxes +// overlap too much with respect to IOU. +// If not specified, defaults to 0.5 +func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { + return func(m optionalAttr) { + m["iou_threshold"] = value + } +} + +// Greedily selects a subset of bounding boxes in descending order of score, // -// The complex conjugate returned by this operation is of the form \\(a - bj\\). +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system. Note that this +// algorithm is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. 
For example: +// selected_indices = tf.image.non_max_suppression( +// boxes, scores, max_output_size, iou_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) // -// For example: +// Arguments: +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. // -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] -// ``` -func Conj(scope *Scope, input tf.Output) (output tf.Output) { +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Conj", + Type: "NonMaxSuppression", Input: []tf.Input{ - input, + boxes, scores, max_output_size, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CudnnRNNBackpropAttr is an optional argument to CudnnRNNBackprop. -type CudnnRNNBackpropAttr func(optionalAttr) - -// CudnnRNNBackpropRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNBackpropRnnMode(value string) CudnnRNNBackpropAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value +// Creates a dataset that emits `components` as a tuple of tensors once. +func TensorDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return } -} - -// CudnnRNNBackpropInputMode sets the optional input_mode attribute to value. 
-// If not specified, defaults to "linear_input" -func CudnnRNNBackpropInputMode(value string) CudnnRNNBackpropAttr { - return func(m optionalAttr) { - m["input_mode"] = value + attrs := map[string]interface{}{"output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "TensorDataset", + Input: []tf.Input{ + tf.OutputList(components), + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// CudnnRNNBackpropDirection sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNBackpropDirection(value string) CudnnRNNBackpropAttr { - return func(m optionalAttr) { - m["direction"] = value +// 2D real-valued fast Fourier transform. +// +// Computes the 2-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 2 dimensions of `input`. +// +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. +// +// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// +// Arguments: +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. +// +// Returns A complex64 tensor of the same rank as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. 
+// +// @compatibility(numpy) +// Equivalent to np.fft.rfft2 +// @end_compatibility +func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return } -} - -// CudnnRNNBackpropDropout sets the optional dropout attribute to value. -// If not specified, defaults to 0 -func CudnnRNNBackpropDropout(value float32) CudnnRNNBackpropAttr { - return func(m optionalAttr) { - m["dropout"] = value + opspec := tf.OpSpec{ + Type: "RFFT2D", + Input: []tf.Input{ + input, fft_length, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// CudnnRNNBackpropSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNBackpropSeed(value int64) CudnnRNNBackpropAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} +// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. +type ResourceSparseApplyFtrlAttr func(optionalAttr) -// CudnnRNNBackpropSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNBackpropSeed2(value int64) CudnnRNNBackpropAttr { +// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { return func(m optionalAttr) { - m["seed2"] = value + m["use_locking"] = value } } -// Backprop step of CudnnRNN. +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. // -// Compute the backprop of both data and weights in a RNN. 
+// That is for rows we have grad for, we update var, accum and linear as follows: +// accum_new = accum + grad * grad +// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicate whether there is a linear projection between the input and -// the actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. Should be -// "unidirectional" or "bidirectional". -// dropout: Dropout probability. When set to 0., dropout is disabled. -// seed: The 1st part of a seed to initialize dropout. -// seed2: The 2nd part of a seed to initialize dropout. -// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. -// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, -// num_units]. -// input_c: For LSTM, a 3-D tensor with the shape of -// [num_layer * dir, batch, num_units]. For other models, it is ignored. -// params: A 1-D tensor that contains the weights and biases in an opaque layout. -// The size must be created through CudnnRNNParamsSize, and initialized -// separately. Note that they might not be compatible across different -// generations. So it is a good idea to save and restore -// output: A 3-D tensor with the shape of [seq_length, batch_size, -// dir * num_units]. -// output_h: The same shape has input_h. -// output_c: The same shape as input_c for LSTM. An empty tensor for other models. -// output_backprop: A 3-D tensor with the same shape as output in the forward pass. -// output_h_backprop: A 3-D tensor with the same shape as output_h in the forward -// pass. 
-// output_c_backprop: A 3-D tensor with the same shape as output_c in the forward -// pass. -// reserve_space: The same reserve_space produced in for forward operation. -// input_backprop: The backprop to input in the forward pass. Has the same shape -// as input. -// input_h_backprop: The backprop to input_h in the forward pass. Has the same -// shape as input_h. -// input_c_backprop: The backprop to input_c in the forward pass. Has the same -// shape as input_c. -// params_backprop: The backprop to the params buffer in the forward pass. Has the -// same shape as params. -func CudnnRNNBackprop(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, output tf.Output, output_h tf.Output, output_c tf.Output, output_backprop tf.Output, output_h_backprop tf.Output, output_c_backprop tf.Output, reserve_space tf.Output, optional ...CudnnRNNBackpropAttr) (input_backprop tf.Output, input_h_backprop tf.Output, input_c_backprop tf.Output, params_backprop tf.Output) { +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. 
+func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -14897,135 +19490,222 @@ func CudnnRNNBackprop(scope *Scope, input tf.Output, input_h tf.Output, input_c a(attrs) } opspec := tf.OpSpec{ - Type: "CudnnRNNBackprop", + Type: "ResourceSparseApplyFtrl", Input: []tf.Input{ - input, input_h, input_c, params, output, output_h, output_c, output_backprop, output_h_backprop, output_c_backprop, reserve_space, + var_, accum, linear, grad, indices, lr, l1, l2, lr_power, }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// Returns which elements of x are Inf. +// +// @compatibility(numpy) +// Equivalent to np.isinf +// @end_compatibility +func IsInf(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IsInf", + Input: []tf.Input{ + x, + }, + } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) + return op.Output(0) } -// Encode audio data using the WAV file format. +// Gather ragged slices from `params` axis `0` according to `indices`. // -// This operation will generate a string suitable to be saved out to create a .wav -// audio file. It will be encoded in the 16-bit PCM format. It takes in float -// values in the range -1.0f to 1.0f, and any outside that value will be clamped to -// that range. 
+// Outputs a `RaggedTensor` output composed from `output_dense_values` and +// `output_nested_splits`, such that: +// +// ```python +// output.shape = indices.shape + params.shape[1:] +// output.ragged_rank = indices.shape.ndims + params.ragged_rank +// output[i...j, d0...dn] = params[indices[i...j], d0...dn] +// ``` +// +// where +// +// * `params = +// ragged.from_nested_row_splits(params_dense_values, params_nested_splits)` +// provides the values that should be gathered. +// * `indices` ia a dense tensor with dtype `int32` or `int64`, indicating which +// values should be gathered. +// * `output = +// ragged.from_nested_row_splits(output_dense_values, output_nested_splits)` +// is the output tensor. +// +// (Note: This c++ op is used to implement the higher-level python +// `tf.ragged.gather` op, which also supports ragged indices.) // -// `audio` is a 2-D float Tensor of shape `[length, channels]`. -// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). // // Arguments: -// audio: 2-D with shape `[length, channels]`. -// sample_rate: Scalar containing the sample frequency. +// params_nested_splits: The `nested_row_splits` tensors that define the row-partitioning for the +// `params` RaggedTensor input. +// params_dense_values: The `flat_values` for the `params` RaggedTensor. There was a terminology change +// at the python level from dense_values to flat_values, so dense_values is the +// deprecated name. +// indices: Indices in the outermost dimension of `params` of the values that should be +// gathered. +// OUTPUT_RAGGED_RANK: The ragged rank of the output RaggedTensor. `output_nested_splits` will contain +// this number of `row_splits` tensors. This value should equal +// `indices.shape.ndims + params.ragged_rank - 1`. // -// Returns 0-D. WAV-encoded file contents. 
-func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { +// Returns The `nested_row_splits` tensors that define the row-partitioning for the +// returned RaggedTensor.The `flat_values` for the returned RaggedTensor. +func RaggedGather(scope *Scope, params_nested_splits []tf.Output, params_dense_values tf.Output, indices tf.Output, OUTPUT_RAGGED_RANK int64) (output_nested_splits []tf.Output, output_dense_values tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"OUTPUT_RAGGED_RANK": OUTPUT_RAGGED_RANK} opspec := tf.OpSpec{ - Type: "EncodeWav", + Type: "RaggedGather", Input: []tf.Input{ - audio, sample_rate, + tf.OutputList(params_nested_splits), params_dense_values, indices, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if output_nested_splits, idx, err = makeOutputList(op, idx, "output_nested_splits"); err != nil { + scope.UpdateErr("RaggedGather", err) + return + } + output_dense_values = op.Output(idx) + return output_nested_splits, output_dense_values } -// Computes atan of x element-wise. -func Atan(scope *Scope, x tf.Output) (y tf.Output) { +// Greedily selects a subset of bounding boxes in descending order of score, +// +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system. Note that this +// algorithm is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. 
+// +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// +// selected_indices = tf.image.non_max_suppression_v2( +// boxes, scores, max_output_size, iou_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) +// +// Arguments: +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// iou_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too much with respect to IOU. +// +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output) (selected_indices tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Atan", + Type: "NonMaxSuppressionV2", Input: []tf.Input{ - x, + boxes, scores, max_output_size, iou_threshold, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax. -type ResourceApplyAdaMaxAttr func(optionalAttr) +// TruncatedNormalAttr is an optional argument to TruncatedNormal. +type TruncatedNormalAttr func(optionalAttr) -// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value. +// TruncatedNormalSeed sets the optional seed attribute to value. 
// -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr { +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func TruncatedNormalSeed(value int64) TruncatedNormalAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["seed"] = value } } -// Update '*var' according to the AdaMax algorithm. +// TruncatedNormalSeed2 sets the optional seed2 attribute to value. // -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// v_t <- max(beta2 * v_{t-1}, abs(g)) -// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func TruncatedNormalSeed2(value int64) TruncatedNormalAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a truncated normal distribution. +// +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. // // Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// beta1_power: Must be a scalar. -// lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. +// shape: The shape of the output tensor. +// dtype: The type of the output. // -// Returns the created operation. 
-func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) { +// Returns A tensor of the specified shape filled with random truncated normal +// values. +func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdaMax", + Type: "TruncatedNormal", Input: []tf.Input{ - var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad, + shape, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// AssertAttr is an optional argument to Assert. -type AssertAttr func(optionalAttr) +// StringToNumberAttr is an optional argument to StringToNumber. +type StringToNumberAttr func(optionalAttr) -// AssertSummarize sets the optional summarize attribute to value. +// StringToNumberOutType sets the optional out_type attribute to value. // -// value: Print this many entries of each tensor. -// If not specified, defaults to 3 -func AssertSummarize(value int64) AssertAttr { +// value: The numeric type to interpret each string in `string_tensor` as. +// If not specified, defaults to DT_FLOAT +func StringToNumberOutType(value tf.DataType) StringToNumberAttr { return func(m optionalAttr) { - m["summarize"] = value + m["out_type"] = value } } -// Asserts that the given condition is true. -// -// If `condition` evaluates to false, print the list of tensors in `data`. -// `summarize` determines how many entries of the tensors to print. +// Converts each string in the input Tensor to the specified numeric type. // -// Arguments: -// condition: The condition to evaluate. 
-// data: The tensors to print out when condition is false. +// (Note that int32 overflow results in an error while float overflow +// results in a rounded value.) // -// Returns the created operation. -func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) { +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -15034,245 +19714,244 @@ func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...Ass a(attrs) } opspec := tf.OpSpec{ - Type: "Assert", + Type: "StringToNumber", Input: []tf.Input{ - condition, tf.OutputList(data), + string_tensor, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Split a `SparseTensor` into `num_split` tensors along one dimension. -// -// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices -// `[0 : shape[split_dim] % num_split]` gets one extra dimension. -// For example, if `split_dim = 1` and `num_split = 2` and the input is -// -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] -// -// Graphically the output tensors are: +// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2. +type ResourceApplyFtrlV2Attr func(optionalAttr) + +// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value. // -// output_tensor[0] = shape = [2, 4] -// [ a ] -// [b c ] +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the Ftrl-proximal scheme. 
// -// output_tensor[1] = shape = [2, 3] -// [ d e ] -// [ ] +// grad_with_shrinkage = grad + 2 * l2_shrinkage * var +// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage +// linear += grad_with_shrinkage + +// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // // Arguments: -// split_dim: 0-D. The dimension along which to split. Must be in the range -// `[0, rank(shape))`. -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. -// num_split: The number of ways to split. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regulariation. Must be a scalar. +// l2: L2 shrinkage regulariation. Must be a scalar. // -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) { +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. 
+func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_split": num_split} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseSplit", + Type: "ResourceApplyFtrlV2", Input: []tf.Input{ - split_dim, indices, values, shape, + var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil { - scope.UpdateErr("SparseSplit", err) - return + return scope.AddOperation(opspec) +} + +// SkipgramAttr is an optional argument to Skipgram. +type SkipgramAttr func(optionalAttr) + +// SkipgramWindowSize sets the optional window_size attribute to value. +// +// value: The number of words to predict to the left and right of the target. +// If not specified, defaults to 5 +func SkipgramWindowSize(value int64) SkipgramAttr { + return func(m optionalAttr) { + m["window_size"] = value } - return output_indices, output_values, output_shape } -// Computes numerical negative value element-wise. +// SkipgramMinCount sets the optional min_count attribute to value. // -// I.e., \\(y = -x\\). -func Neg(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return +// value: The minimum number of word occurrences for it to be included in the +// vocabulary. 
+// If not specified, defaults to 5 +func SkipgramMinCount(value int64) SkipgramAttr { + return func(m optionalAttr) { + m["min_count"] = value } - opspec := tf.OpSpec{ - Type: "Neg", - Input: []tf.Input{ - x, - }, +} + +// SkipgramSubsample sets the optional subsample attribute to value. +// +// value: Threshold for word occurrence. Words that appear with higher +// frequency will be randomly down-sampled. Set to 0 to disable. +// If not specified, defaults to 0.001 +func SkipgramSubsample(value float32) SkipgramAttr { + return func(m optionalAttr) { + m["subsample"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Execute a sub graph on a remote processor. +// Parses a text file and creates a batch of examples. // -// The graph specifications(such as graph itself, input tensors and output names) -// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo -// as serialized_remote_fused_graph_execute_info. -// The specifications will be passed to a dedicated registered -// remote fused graph executor. The executor will send the graph specifications -// to a remote processor and execute that graph. The execution results -// will be passed to consumer nodes as outputs of this node. +// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result // // Arguments: -// inputs: Arbitrary number of tensors with arbitrary data types -// -// serialized_remote_fused_graph_execute_info: Serialized protocol buffer -// of RemoteFusedGraphExecuteInfo which contains graph specifications. +// filename: The corpus's text file name. +// batch_size: The size of produced batch. // -// Returns Arbitrary number of tensors with arbitrary data types -func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { +// Returns A vector of words in the corpus.Frequencies of words. 
Sorted in the non-ascending order.Number of words per epoch in the data file.The current epoch number.The total number of words processed so far.A vector of word ids.A vector of word ids. +func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} + attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RemoteFusedGraphExecute", - Input: []tf.Input{ - tf.OutputList(inputs), - }, + Type: "Skipgram", + Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("RemoteFusedGraphExecute", err) - return - } - return outputs + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6) } -// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. -type MaxPool3DGradGradAttr func(optionalAttr) +// ThreadUnsafeUnigramCandidateSamplerAttr is an optional argument to ThreadUnsafeUnigramCandidateSampler. +type ThreadUnsafeUnigramCandidateSamplerAttr func(optionalAttr) -// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. +// ThreadUnsafeUnigramCandidateSamplerSeed sets the optional seed attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. 
-// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func ThreadUnsafeUnigramCandidateSamplerSeed(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["data_format"] = value + m["seed"] = value } } -// Computes second-order gradients of the maxpooling function. +// ThreadUnsafeUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func ThreadUnsafeUnigramCandidateSamplerSeed2(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Generates labels for candidate sampling with a learned unigram distribution. +// +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // // Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. 
Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). // -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. 
+func ThreadUnsafeUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...ThreadUnsafeUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPool3DGradGrad", + Type: "ThreadUnsafeUnigramCandidateSampler", Input: []tf.Input{ - orig_input, orig_output, grad, + true_classes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. -type Conv3DBackpropFilterV2Attr func(optionalAttr) - -// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} +// MaxPoolV2Attr is an optional argument to MaxPoolV2. +type MaxPoolV2Attr func(optionalAttr) -// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value. +// MaxPoolV2DataFormat sets the optional data_format attribute to value. // -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. 
If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolV2DataFormat(value string) MaxPoolV2Attr { return func(m optionalAttr) { - m["dilations"] = value + m["data_format"] = value } } -// Computes the gradients of 3-D convolution with respect to the filter. +// Performs max pooling on the input. // // Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 5-D -// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` -// tensor. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// input: 4-D input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. // padding: The type of padding algorithm to use. -func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { +// +// Returns The max pooled output tensor. 
+func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{"padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropFilterV2", + Type: "MaxPoolV2", Input: []tf.Input{ - input, filter_sizes, out_backprop, + input, ksize, strides, }, Attrs: attrs, } @@ -15280,49 +19959,71 @@ func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Outpu return op.Output(0) } -// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. -type FakeQuantWithMinMaxVarsAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["num_bits"] = value +// Does nothing. Serves as a control trigger for scheduling. +// +// Only useful as a placeholder for control edges. +// +// Returns the created operation. +func ControlTrigger(scope *Scope) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ControlTrigger", } + return scope.AddOperation(opspec) } -// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["narrow_range"] = value +// Deprecated. 
Use TensorArrayReadV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 +func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "TensorArrayReadV2", + Input: []tf.Input{ + handle, index, flow_in, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` +// Batch normalization. // -// and `max` to 'outputs' tensor of same shape as `inputs`. +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() // -// `[min; max]` define the clamping range for the `inputs` data. -// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -// then de-quantized and output as floats in `[min; max]` interval. -// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. +// This op is deprecated. Prefer `tf.nn.batch_normalization`. // -// This operation has a gradient and thus allows for training `min` and `max` -// values. -func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { +// Arguments: +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// beta: A 1D beta Tensor with size matching the last dimension of t. +// An offset to be added to the normalized tensor. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. 
+// If "scale_after_normalization" is true, this tensor will be multiplied +// with the normalized tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVars", + Type: "BatchNormWithGlobalNormalization", Input: []tf.Input{ - inputs, min, max, + t, m, v, beta, gamma, }, Attrs: attrs, } @@ -15330,68 +20031,65 @@ func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max return op.Output(0) } -// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate. -type ResourceScatterNdUpdateAttr func(optionalAttr) +// AddManySparseToTensorsMapAttr is an optional argument to AddManySparseToTensorsMap. +type AddManySparseToTensorsMapAttr func(optionalAttr) -// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value. +// AddManySparseToTensorsMapContainer sets the optional container attribute to value. // -// value: An optional bool. Defaults to True. If True, the assignment will -// be protected by a lock; otherwise the behavior is undefined, -// but may exhibit less contention. -// If not specified, defaults to true -func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr { +// value: The container name for the `SparseTensorsMap` created by this op. 
+// If not specified, defaults to "" +func AddManySparseToTensorsMapContainer(value string) AddManySparseToTensorsMapAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["container"] = value } } -// Applies sparse `updates` to individual values or slices within a given -// -// variable according to `indices`. -// -// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `ref`. -// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. -// -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th -// dimension of `ref`. -// -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// AddManySparseToTensorsMapSharedName sets the optional shared_name attribute to value. // -// ``` -// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. -// ``` +// value: The shared name for the `SparseTensorsMap` created by this op. +// If blank, the new Operation's unique name is used. +// If not specified, defaults to "" +func AddManySparseToTensorsMapSharedName(value string) AddManySparseToTensorsMapAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Add an `N`-minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles. // -// For example, say we want to update 4 scattered elements to a rank-1 tensor to -// 8 elements. 
In Python, that update would look like this: +// A `SparseTensor` of rank `R` is represented by three tensors: `sparse_indices`, +// `sparse_values`, and `sparse_shape`, where // -// ```python -// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) -// indices = tf.constant([[4], [3], [1] ,[7]]) -// updates = tf.constant([9, 10, 11, 12]) -// update = tf.scatter_nd_update(ref, indices, updates) -// with tf.Session() as sess: -// print sess.run(update) -// ``` +// ```sparse_indices.shape[1] == sparse_shape.shape[0] == R``` // -// The resulting update to ref would look like this: +// An `N`-minibatch of `SparseTensor` objects is represented as a `SparseTensor` +// having a first `sparse_indices` column taking values between `[0, N)`, where +// the minibatch size `N == sparse_shape[0]`. // -// [1, 11, 3, 10, 9, 6, 7, 12] +// The input `SparseTensor` must have rank `R` greater than 1, and the first +// dimension is treated as the minibatch dimension. Elements of the `SparseTensor` +// must be sorted in increasing order of this first dimension. The stored +// `SparseTensor` objects pointed to by each row of the output `sparse_handles` +// will have rank `R-1`. // -// See `tf.scatter_nd` for more details about how to make updates to -// slices. +// The `SparseTensor` values can then be read out as part of a minibatch by passing +// the given keys as vector elements to `TakeManySparseFromTensorsMap`. To ensure +// the correct `SparseTensorsMap` is accessed, ensure that the same +// `container` and `shared_name` are passed to that Op. If no `shared_name` +// is provided here, instead use the *name* of the Operation created by calling +// `AddManySparseToTensorsMap` as the `shared_name` passed to +// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. // // Arguments: -// ref: A resource handle. Must be from a VarHandleOp. -// indices: A Tensor. Must be one of the following types: int32, int64. -// A tensor of indices into ref. -// updates: A Tensor. 
Must have the same type as ref. A tensor of updated -// values to add to ref. +// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. +// `sparse_indices[:, 0]` must be ordered values in `[0, N)`. +// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. +// The minibatch size `N == sparse_shape[0]`. // -// Returns the created operation. -func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) { +// Returns 1-D. The handles of the `SparseTensor` now stored in the +// `SparseTensorsMap`. Shape: `[N]`. +func AddManySparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddManySparseToTensorsMapAttr) (sparse_handles tf.Output) { if scope.Err() != nil { return } @@ -15400,360 +20098,343 @@ func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, upd a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceScatterNdUpdate", + Type: "AddManySparseToTensorsMap", Input: []tf.Input{ - ref, indices, updates, + sparse_indices, sparse_values, sparse_shape, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Produces a string handle for the given MultiDeviceIterator. -// -// Arguments: -// multi_device_iterator: A MultiDeviceIterator resource. +// TPUReplicateMetadataAttr is an optional argument to TPUReplicateMetadata. +type TPUReplicateMetadataAttr func(optionalAttr) + +// TPUReplicateMetadataNumCoresPerReplica sets the optional num_cores_per_replica attribute to value. // -// Returns A string representing the resource. 
-func MultiDeviceIteratorToStringHandle(scope *Scope, multi_device_iterator tf.Output) (string_handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MultiDeviceIteratorToStringHandle", - Input: []tf.Input{ - multi_device_iterator, - }, +// value: Number of cores per replica. Used for model parallelism. +// If not specified, defaults to 1 +func TPUReplicateMetadataNumCoresPerReplica(value int64) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["num_cores_per_replica"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// CudnnRNNV3Attr is an optional argument to CudnnRNNV3. -type CudnnRNNV3Attr func(optionalAttr) - -// CudnnRNNV3RnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNV3RnnMode(value string) CudnnRNNV3Attr { +// TPUReplicateMetadataTopology sets the optional topology attribute to value. +// +// value: TopologyProto indicating the topology of the TPU pod slice. +// If not specified, defaults to "" +func TPUReplicateMetadataTopology(value string) TPUReplicateMetadataAttr { return func(m optionalAttr) { - m["rnn_mode"] = value + m["topology"] = value } } -// CudnnRNNV3InputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNV3InputMode(value string) CudnnRNNV3Attr { +// TPUReplicateMetadataUseTpu sets the optional use_tpu attribute to value. +// +// value: Whether to place the computation on the TPU. +// If not specified, defaults to true +func TPUReplicateMetadataUseTpu(value bool) TPUReplicateMetadataAttr { return func(m optionalAttr) { - m["input_mode"] = value + m["use_tpu"] = value } } -// CudnnRNNV3Direction sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNV3Direction(value string) CudnnRNNV3Attr { +// TPUReplicateMetadataDeviceAssignment sets the optional device_assignment attribute to value. 
+// +// value: The assignment of devices for the computation. +// If not specified, defaults to <> +func TPUReplicateMetadataDeviceAssignment(value []int64) TPUReplicateMetadataAttr { return func(m optionalAttr) { - m["direction"] = value + m["device_assignment"] = value } } -// CudnnRNNV3Dropout sets the optional dropout attribute to value. -// If not specified, defaults to 0 -func CudnnRNNV3Dropout(value float32) CudnnRNNV3Attr { +// TPUReplicateMetadataComputationShape sets the optional computation_shape attribute to value. +// +// value: DEPRECATED. Use num_cores_per_replica instead. +// If not specified, defaults to <> +func TPUReplicateMetadataComputationShape(value []int64) TPUReplicateMetadataAttr { return func(m optionalAttr) { - m["dropout"] = value + m["computation_shape"] = value } } -// CudnnRNNV3Seed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNV3Seed(value int64) CudnnRNNV3Attr { +// TPUReplicateMetadataHostComputeCore sets the optional host_compute_core attribute to value. +// If not specified, defaults to <> +func TPUReplicateMetadataHostComputeCore(value []string) TPUReplicateMetadataAttr { return func(m optionalAttr) { - m["seed"] = value + m["host_compute_core"] = value } } -// CudnnRNNV3Seed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNV3Seed2(value int64) CudnnRNNV3Attr { +// TPUReplicateMetadataPaddingMap sets the optional padding_map attribute to value. +// If not specified, defaults to <> +func TPUReplicateMetadataPaddingMap(value []string) TPUReplicateMetadataAttr { return func(m optionalAttr) { - m["seed2"] = value + m["padding_map"] = value } } -// CudnnRNNV3IsTraining sets the optional is_training attribute to value. -// If not specified, defaults to true -func CudnnRNNV3IsTraining(value bool) CudnnRNNV3Attr { +// TPUReplicateMetadataStepMarkerLocation sets the optional step_marker_location attribute to value. 
+// If not specified, defaults to "STEP_MARK_AT_ENTRY" +func TPUReplicateMetadataStepMarkerLocation(value string) TPUReplicateMetadataAttr { return func(m optionalAttr) { - m["is_training"] = value + m["step_marker_location"] = value } } -// A RNN backed by cuDNN. +// Metadata indicaitng how the TPU computation should be replicated. // -// Computes the RNN from the input and initial states, with respect to the params -// buffer. Accepts one extra input "sequence_lengths" than CudnnRNN. +// Arguments: +// num_replicas: Number of replicas of the computation // -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicates whether there is a linear projection between the input and -// the actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. Should be -// "unidirectional" or "bidirectional". -// dropout: Dropout probability. When set to 0., dropout is disabled. -// seed: The 1st part of a seed to initialize dropout. -// seed2: The 2nd part of a seed to initialize dropout. -// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. -// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, -// num_units]. -// input_c: For LSTM, a 3-D tensor with the shape of -// [num_layer * dir, batch, num_units]. For other models, it is ignored. -// params: A 1-D tensor that contains the weights and biases in an opaque layout. -// The size must be created through CudnnRNNParamsSize, and initialized -// separately. Note that they might not be compatible across different -// generations. So it is a good idea to save and restore -// sequence_lengths: a vector of lengths of each input sequence. -// output: A 3-D tensor with the shape of [seq_length, batch_size, -// dir * num_units]. 
-// output_h: The same shape has input_h. -// output_c: The same shape as input_c for LSTM. An empty tensor for other models. -// is_training: Indicates whether this operation is used for inferenece or -// training. -// reserve_space: An opaque tensor that can be used in backprop calculation. It -// is only produced if is_training is true. -func CudnnRNNV3(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, sequence_lengths tf.Output, optional ...CudnnRNNV3Attr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output, host_reserved tf.Output) { +// Returns the created operation. +func TPUReplicateMetadata(scope *Scope, num_replicas int64, optional ...TPUReplicateMetadataAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_replicas": num_replicas} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "CudnnRNNV3", - Input: []tf.Input{ - input, input_h, input_c, params, sequence_lengths, - }, + Type: "TPUReplicateMetadata", + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return scope.AddOperation(opspec) } -// Applies softmax to a batched N-D `SparseTensor`. +// LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingFTRLParametersGradAccumDebug. +type LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` -// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. 
+// REQUIRES: value >= -1 +func LoadTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingFTRLParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingFTRLParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load FTRL embedding parameters with debug support. // -// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost -// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly -// zero elements do not participate*. Specifically, the algorithm is equivalent -// to the following: +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // -// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix -// with shape `[B, C]`, along the size-C dimension; -// (2) Masks out the original implicitly-zero locations; -// (3) Renormalizes the remaining elements. +// Arguments: +// parameters: Value of parameters used in the FTRL optimization algorithm. +// accumulators: Value of accumulators used in the FTRL optimization algorithm. +// linears: Value of linears used in the FTRL optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the FTRL optimization algorithm. // -// Hence, the `SparseTensor` result has exactly the same non-zero indices and -// shape. // -// Arguments: -// sp_indices: 2-D. 
`NNZ x R` matrix with the indices of non-empty values in a -// SparseTensor, in canonical ordering. -// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. // -// Returns 1-D. The `NNZ` values for the result `SparseTensor`. -func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { +// Returns the created operation. +func LoadTPUEmbeddingFTRLParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, linears tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseSoftmax", + Type: "LoadTPUEmbeddingFTRLParametersGradAccumDebug", Input: []tf.Input{ - sp_indices, sp_values, sp_shape, + parameters, accumulators, linears, gradient_accumulators, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Partitions `data` into `num_partitions` tensors using indices from `partitions`. -// -// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` -// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` -// are placed in `outputs[i]` in lexicographic order of `js`, and the first -// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. -// In detail, -// -// ```python -// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] -// -// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) -// ``` -// -// `data.shape` must start with `partitions.shape`. -// -// For example: -// -// ```python -// # Scalar partitions. 
-// partitions = 1 -// num_partitions = 2 -// data = [10, 20] -// outputs[0] = [] # Empty with shape [0, 2] -// outputs[1] = [[10, 20]] -// -// # Vector partitions. -// partitions = [0, 0, 1, 1, 0] -// num_partitions = 2 -// data = [10, 20, 30, 40, 50] -// outputs[0] = [10, 20, 50] -// outputs[1] = [30, 40] -// ``` -// -// See `dynamic_stitch` for an example on how to merge partitions back. -// -//
-// -//
+// Concatenates tensors along one dimension. // // Arguments: +// values: List of `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// axis: 0-D. The dimension along which to concatenate. Must be in the +// range [-rank(values), rank(values)). // -// partitions: Any shape. Indices in the range `[0, num_partitions)`. -// num_partitions: The number of partitions to output. -func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_partitions": num_partitions} opspec := tf.OpSpec{ - Type: "DynamicPartition", + Type: "ConcatV2", Input: []tf.Input{ - data, partitions, + tf.OutputList(values), axis, }, - Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reads and outputs the entire contents of the input filename. +func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("DynamicPartition", err) - return + opspec := tf.OpSpec{ + Type: "ReadFile", + Input: []tf.Input{ + filename, + }, } - return outputs + op := scope.AddOperation(opspec) + return op.Output(0) } -// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. -type ResourceApplyAdagradAttr func(optionalAttr) +// AvgPoolGradAttr is an optional argument to AvgPoolGrad. +type AvgPoolGradAttr func(optionalAttr) -// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. 
+// AvgPoolGradDataFormat sets the optional data_format attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyAdagradUpdateSlots sets the optional update_slots attribute to value. -// If not specified, defaults to true -func ResourceApplyAdagradUpdateSlots(value bool) ResourceApplyAdagradAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { return func(m optionalAttr) { - m["update_slots"] = value + m["data_format"] = value } } -// Update '*var' according to the adagrad scheme. -// -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) +// Computes gradients of the average pooling function. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. +// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. +// the output of `avg_pool`. +// ksize: The size of the sliding window for each dimension of the input. +// strides: The stride of the sliding window for each dimension of the input. +// padding: The type of padding algorithm to use. // -// Returns the created operation. 
-func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { +// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. +func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdagrad", + Type: "AvgPoolGrad", Input: []tf.Input{ - var_, accum, lr, grad, + orig_input_shape, grad, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics. +// Greedily selects a subset of bounding boxes in descending order of score, +// +// pruning away boxes that have high overlaps +// with previously selected boxes. Bounding boxes with score less than +// `score_threshold` are removed. N-by-n overlap values are supplied as square matrix, +// which allows for defining a custom overlap criterium (eg. intersection over union, +// intersection over area, etc.). +// +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// +// selected_indices = tf.image.non_max_suppression_with_overlaps( +// overlaps, scores, max_output_size, overlap_threshold, score_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) // // Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. 
+// overlaps: A 2-D float tensor of shape `[num_boxes, num_boxes]` representing +// the n-by-n box overlap values. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// overlap_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too. +// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove +// boxes based on score. // -// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest -// layer. -func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) { +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppressionWithOverlaps(scope *Scope, overlaps tf.Output, scores tf.Output, max_output_size tf.Output, overlap_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BoostedTreesGetEnsembleStates", + Type: "NonMaxSuppressionWithOverlaps", Input: []tf.Input{ - tree_ensemble_handle, + overlaps, scores, max_output_size, overlap_threshold, score_threshold, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return op.Output(0) } -// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign. 
-type ResourceApplyPowerSignAttr func(optionalAttr) +// FractionalAvgPoolGradAttr is an optional argument to FractionalAvgPoolGrad. +type FractionalAvgPoolGradAttr func(optionalAttr) -// ResourceApplyPowerSignUseLocking sets the optional use_locking attribute to value. +// FractionalAvgPoolGradOverlapping sets the optional overlapping attribute to value. // -// value: If `True`, updating of the var and m tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [41/3, 26/3] for fractional avg pooling. // If not specified, defaults to false -func ResourceApplyPowerSignUseLocking(value bool) ResourceApplyPowerSignAttr { +func FractionalAvgPoolGradOverlapping(value bool) FractionalAvgPoolGradAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["overlapping"] = value } } -// Update '*var' according to the AddSign update. +// Computes gradient of the FractionalAvgPool function. // -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g -// variable <- variable - lr_t * update +// Unlike FractionalMaxPoolGrad, we don't need to find arg_max for +// FractionalAvgPoolGrad, we just need to evenly back-propagate each element of +// out_backprop to those indices that form the same pooling cell. Therefore, we +// just need to know the shape of original input tensor, instead of the whole +// tensor. // // Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// logbase: Must be a scalar. -// sign_decay: Must be a scalar. -// beta: Must be a scalar. -// grad: The gradient. 
+// orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool` +// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients +// w.r.t. the output of `fractional_avg_pool`. +// row_pooling_sequence: row pooling sequence, form pooling region with +// col_pooling_sequence. +// col_pooling_sequence: column pooling sequence, form pooling region with +// row_pooling sequence. // -// Returns the created operation. -func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, logbase tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyPowerSignAttr) (o *tf.Operation) { +// Returns 4-D. Gradients w.r.t. the input of `fractional_avg_pool`. +func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalAvgPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -15762,68 +20443,52 @@ func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Out a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyPowerSign", + Type: "FractionalAvgPoolGrad", Input: []tf.Input{ - var_, m, lr, logbase, sign_decay, beta, grad, + orig_input_tensor_shape, out_backprop, row_pooling_sequence, col_pooling_sequence, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// StringFormatAttr is an optional argument to StringFormat. -type StringFormatAttr func(optionalAttr) - -// StringFormatTemplate sets the optional template attribute to value. -// -// value: A string, the template to format tensor summaries into. -// If not specified, defaults to "%s" -func StringFormatTemplate(value string) StringFormatAttr { - return func(m optionalAttr) { - m["template"] = value - } + op := scope.AddOperation(opspec) + return op.Output(0) } -// StringFormatPlaceholder sets the optional placeholder attribute to value. 
-// -// value: A string, at each placeholder in the template a subsequent tensor summary will be inserted. -// If not specified, defaults to "%s" -func StringFormatPlaceholder(value string) StringFormatAttr { - return func(m optionalAttr) { - m["placeholder"] = value - } -} +// StaticRegexReplaceAttr is an optional argument to StaticRegexReplace. +type StaticRegexReplaceAttr func(optionalAttr) -// StringFormatSummarize sets the optional summarize attribute to value. +// StaticRegexReplaceReplaceGlobal sets the optional replace_global attribute to value. // -// value: When formatting the tensor summaries print the first and last summarize entries of each tensor dimension. -// If not specified, defaults to 3 -func StringFormatSummarize(value int64) StringFormatAttr { +// value: If True, the replacement is global, otherwise the replacement +// is done only on the first match. +// If not specified, defaults to true +func StaticRegexReplaceReplaceGlobal(value bool) StaticRegexReplaceAttr { return func(m optionalAttr) { - m["summarize"] = value + m["replace_global"] = value } } -// Formats a string template using a list of tensors. +// Replaces the match of pattern in input with rewrite. // -// Formats a string template using a list of tensors, pretty-printing tensor summaries. +// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) // // Arguments: -// inputs: The list of tensors to format into the placeholder string. +// input: The text to be processed. +// pattern: The regular expression to match the input. +// rewrite: The rewrite to be applied to the matched expression. // -// Returns = The resulting string scalar. -func StringFormat(scope *Scope, inputs []tf.Output, optional ...StringFormatAttr) (output tf.Output) { +// Returns The text after applying pattern and rewrite. 
+func StaticRegexReplace(scope *Scope, input tf.Output, pattern string, rewrite string, optional ...StaticRegexReplaceAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"pattern": pattern, "rewrite": rewrite} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StringFormat", + Type: "StaticRegexReplace", Input: []tf.Input{ - tf.OutputList(inputs), + input, }, Attrs: attrs, } @@ -15831,284 +20496,155 @@ func StringFormat(scope *Scope, inputs []tf.Output, optional ...StringFormatAttr return op.Output(0) } -// ShapeAttr is an optional argument to Shape. -type ShapeAttr func(optionalAttr) - -// ShapeOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func ShapeOutType(value tf.DataType) ShapeAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Returns the shape of a tensor. -// -// This operation returns a 1-D integer tensor representing the shape of `input`. +// Computes gradients for the exponential linear (Elu) operation. // -// For example: +// Arguments: +// gradients: The backpropagated gradients to the corresponding Elu operation. +// outputs: The outputs of the corresponding Elu operation. // -// ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// shape(t) ==> [2, 2, 3] -// ``` -func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) { +// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, +// `gradients` otherwise. 
+func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Shape", + Type: "EluGrad", Input: []tf.Input{ - input, + gradients, outputs, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the power of one value to another. +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for -// corresponding elements in `x` and `y`. For example: +// The hash function is deterministic on the content of the string within the +// process. // -// ``` -// # tensor 'x' is [[2, 2]], [3, 3]] -// # tensor 'y' is [[8, 16], [2, 3]] -// tf.pow(x, y) ==> [[256, 65536], [9, 27]] -// ``` -func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Note that the hash function may change from time to time. +// This functionality will be deprecated and it's recommended to use +// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. +// +// Arguments: +// +// num_buckets: The number of buckets. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "Pow", + Type: "StringToHashBucket", Input: []tf.Input{ - x, y, + string_tensor, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes fingerprints of the input strings. +// Creates a dataset that batches `batch_size` elements from `input_dataset`. // // Arguments: -// input: vector of strings to compute fingerprints on. // -// Returns a (N,2) shaped matrix where N is the number of elements in the input -// vector. 
Each row contains the low and high parts of the fingerprint. -func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) { +// batch_size: A scalar representing the number of elements to accumulate in a batch. +// drop_remainder: A scalar representing whether the last batch should be dropped in case its size +// is smaller than desired. +// +// +func BatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "SdcaFprint", + Type: "BatchDatasetV2", Input: []tf.Input{ - input, + input_dataset, batch_size, drop_remainder, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. -type ResourceSparseApplyRMSPropAttr func(optionalAttr) - -// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { - return func(m optionalAttr) { - m["use_locking"] = value +// Computes the gradient of `igamma(a, x)` wrt `a`. +func IgammaGradA(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IgammaGradA", + Input: []tf.Input{ + a, x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Update '*var' according to the RMSProp algorithm. 
-// -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms -// and mom will not update in iterations during which the grad is zero. -// -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) -// -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom +// Creates a dataset that contains `count` elements from the `input_dataset`. // // Arguments: -// var_: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. // -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var, ms and mom. +// count: A scalar representing the number of elements from the `input_dataset` +// that should be taken. A value of `-1` indicates that all of `input_dataset` +// is taken. // -// Returns the created operation. 
-func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { +// +func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ResourceSparseApplyRMSProp", + Type: "TakeDataset", Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, + input_dataset, count, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. -type SampleDistortedBoundingBoxAttr func(optionalAttr) - -// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to non-zero, the random number -// generator is seeded by the given `seed`. Otherwise, it is seeded by a random -// seed. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. -// -// value: The cropped area of the image must contain at least this -// fraction of any bounding box supplied. 
The value of this parameter should be -// non-negative. In the case of 0, the cropped area does not need to overlap -// any of the bounding boxes supplied. -// If not specified, defaults to 0.1 -func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["min_object_covered"] = value - } -} - -// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. -// -// value: The cropped area of the image must have an aspect ratio = -// width / height within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["aspect_ratio_range"] = value - } + op := scope.AddOperation(opspec) + return op.Output(0) } -// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. -// -// value: The cropped area of the image must contain a fraction of the -// supplied image within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["area_range"] = value - } -} +// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. +type FakeQuantWithMinMaxVarsAttr func(optionalAttr) -// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. -// -// value: Number of attempts at generating a cropped region of the image -// of the specified constraints. After `max_attempts` failures, return the entire -// image. -// If not specified, defaults to 100 -func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { +// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. 
+// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { return func(m optionalAttr) { - m["max_attempts"] = value + m["num_bits"] = value } } -// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. -// -// value: Controls behavior if no bounding boxes supplied. -// If true, assume an implicit bounding box covering the whole input. If false, -// raise an error. +// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. // If not specified, defaults to false -func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { +func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { return func(m optionalAttr) { - m["use_image_if_no_bounding_boxes"] = value + m["narrow_range"] = value } } -// Generate a single randomly distorted bounding box for an image. -// -// Bounding box annotations are often supplied in addition to ground-truth labels -// in image recognition or object localization tasks. A common technique for -// training such a system is to randomly distort an image while preserving -// its content, i.e. *data augmentation*. This Op outputs a randomly distorted -// localization of an object, i.e. bounding box, given an `image_size`, -// `bounding_boxes` and a series of constraints. -// -// The output of this Op is a single bounding box that may be used to crop the -// original image. The output is returned as 3 tensors: `begin`, `size` and -// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the -// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize -// what the bounding box looks like. -// -// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. 
-// -// For example, -// -// ```python -// # Generate a single distorted bounding box. -// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( -// tf.shape(image), -// bounding_boxes=bounding_boxes) -// -// # Draw the bounding box in an image summary. -// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), -// bbox_for_draw) -// tf.summary.image('images_with_box', image_with_box) -// -// # Employ the bounding box to distort the image. -// distorted_image = tf.slice(image, begin, size) -// ``` +// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` // -// Note that if no bounding box information is available, setting -// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit -// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is -// false and no bounding boxes are supplied, an error is raised. +// and `max` to 'outputs' tensor of same shape as `inputs`. // -// Arguments: -// image_size: 1-D, containing `[height, width, channels]`. -// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes -// associated with the image. +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. // -// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to -// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to -// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. -// Provide as input to `tf.image.draw_bounding_boxes`. 
-func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { +// This operation has a gradient and thus allows for training `min` and `max` +// values. +func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { if scope.Err() != nil { return } @@ -16117,149 +20653,115 @@ func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_box a(attrs) } opspec := tf.OpSpec{ - Type: "SampleDistortedBoundingBox", + Type: "FakeQuantWithMinMaxVars", Input: []tf.Input{ - image_size, bounding_boxes, + inputs, min, max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// LRNAttr is an optional argument to LRN. -type LRNAttr func(optionalAttr) - -// LRNDepthRadius sets the optional depth_radius attribute to value. -// -// value: 0-D. Half-width of the 1-D normalization window. -// If not specified, defaults to 5 -func LRNDepthRadius(value int64) LRNAttr { - return func(m optionalAttr) { - m["depth_radius"] = value - } + return op.Output(0) } -// LRNBias sets the optional bias attribute to value. -// -// value: An offset (usually positive to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNBias(value float32) LRNAttr { - return func(m optionalAttr) { - m["bias"] = value - } -} +// RetrieveTPUEmbeddingMomentumParametersAttr is an optional argument to RetrieveTPUEmbeddingMomentumParameters. +type RetrieveTPUEmbeddingMomentumParametersAttr func(optionalAttr) -// LRNAlpha sets the optional alpha attribute to value. +// RetrieveTPUEmbeddingMomentumParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: A scale factor, usually positive. 
-// If not specified, defaults to 1 -func LRNAlpha(value float32) LRNAttr { +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingMomentumParametersTableId(value int64) RetrieveTPUEmbeddingMomentumParametersAttr { return func(m optionalAttr) { - m["alpha"] = value + m["table_id"] = value } } -// LRNBeta sets the optional beta attribute to value. -// -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNBeta(value float32) LRNAttr { +// RetrieveTPUEmbeddingMomentumParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingMomentumParametersTableName(value string) RetrieveTPUEmbeddingMomentumParametersAttr { return func(m optionalAttr) { - m["beta"] = value + m["table_name"] = value } } -// Local Response Normalization. -// -// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last -// dimension), and each vector is normalized independently. Within a given vector, -// each component is divided by the weighted, squared sum of inputs within -// `depth_radius`. In detail, -// -// sqr_sum[a, b, c, d] = -// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) -// output = input / (bias + alpha * sqr_sum) ** beta +// Retrieve Momentum embedding parameters. // -// For details, see [Krizhevsky et al., ImageNet classification with deep -// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// Arguments: -// input: 4-D. 
-func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { +// Returns Parameter parameters updated by the Momentum optimization algorithm.Parameter momenta updated by the Momentum optimization algorithm. +func RetrieveTPUEmbeddingMomentumParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersAttr) (parameters tf.Output, momenta tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LRN", - Input: []tf.Input{ - input, - }, + Type: "RetrieveTPUEmbeddingMomentumParameters", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Creates a dataset that zips together `input_datasets`. -func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Forwards the value of an available tensor from `inputs` to `output`. +// +// `Merge` waits for at least one of the tensors in `inputs` to become available. +// It is usually combined with `Switch` to implement branching. +// +// `Merge` forwards the first tensor to become available to `output`, and sets +// `value_index` to its index in `inputs`. +// +// Arguments: +// inputs: The input tensors, exactly one of which will become available. +// +// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. 
+func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ZipDataset", + Type: "Merge", Input: []tf.Input{ - tf.OutputList(input_datasets), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. -type ResourceSparseApplyAdagradAttr func(optionalAttr) - -// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value + tf.OutputList(inputs), + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// ResourceSparseApplyAdagradUpdateSlots sets the optional update_slots attribute to value. -// If not specified, defaults to true -func ResourceSparseApplyAdagradUpdateSlots(value bool) ResourceSparseApplyAdagradAttr { +// QueueCloseV2Attr is an optional argument to QueueCloseV2. +type QueueCloseV2Attr func(optionalAttr) + +// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. +// +// value: If true, all pending enqueue requests that are +// blocked on the given queue will be canceled. +// If not specified, defaults to false +func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { return func(m optionalAttr) { - m["update_slots"] = value + m["cancel_pending_enqueues"] = value } } -// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. +// Closes the given queue. 
// -// That is for rows we have grad for, we update var and accum as follows: -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) +// This operation signals that no more elements will be enqueued in the +// given queue. Subsequent Enqueue(Many) operations will fail. +// Subsequent Dequeue(Many) operations will continue to succeed if +// sufficient elements remain in the queue. Subsequent Dequeue(Many) +// operations that would block will fail immediately. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. +// handle: The handle to a queue. // // Returns the created operation. -func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { +func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -16268,67 +20770,78 @@ func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, l a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagrad", + Type: "QueueCloseV2", Input: []tf.Input{ - var_, accum, lr, grad, indices, + handle, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Elementwise computes the bitwise right-shift of `x` and `y`. +// Writes the given dataset to the given file using the TFRecord format. // -// Performs a logical shift for unsigned integer types, and an arithmetic shift -// for signed integer types. +// Arguments: +// input_dataset: A variant tensor representing the dataset to write. +// filename: A scalar string tensor representing the filename to use. +// compression_type: A scalar string tensor containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". 
// -// If `y` is negative, or greater than or equal to than the width of `x` in bits -// the result is implementation defined. -func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Returns the created operation. +func ExperimentalDatasetToTFRecord(scope *Scope, input_dataset tf.Output, filename tf.Output, compression_type tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RightShift", + Type: "ExperimentalDatasetToTFRecord", Input: []tf.Input{ - x, y, + input_dataset, filename, compression_type, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// TensorListStackAttr is an optional argument to TensorListStack. -type TensorListStackAttr func(optionalAttr) +// BiasAddGradAttr is an optional argument to BiasAddGrad. +type BiasAddGradAttr func(optionalAttr) -// TensorListStackNumElements sets the optional num_elements attribute to value. -// If not specified, defaults to -1 -func TensorListStackNumElements(value int64) TensorListStackAttr { +// BiasAddGradDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the bias tensor will be added to the last dimension +// of the value tensor. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// The tensor will be added to "in_channels", the third-to-the-last +// dimension. +// If not specified, defaults to "NHWC" +func BiasAddGradDataFormat(value string) BiasAddGradAttr { return func(m optionalAttr) { - m["num_elements"] = value + m["data_format"] = value } } -// Stacks all tensors in the list. +// The backward operation for "BiasAdd" on the "bias" tensor. // -// Requires that all tensors have the same shape. +// It accumulates all the values from out_backprop into the feature dimension. 
+// For NHWC data format, the feature dimension is the last. For NCHW data format, +// the feature dimension is the third-to-last. // -// input_handle: the input list -// tensor: the gathered result -// num_elements: optional. If not -1, the number of elements in the list. +// Arguments: +// out_backprop: Any number of dimensions. // -func TensorListStack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListStackAttr) (tensor tf.Output) { +// Returns 1-D with size the feature dimension of `out_backprop`. +func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorListStack", + Type: "BiasAddGrad", Input: []tf.Input{ - input_handle, + out_backprop, }, Attrs: attrs, } @@ -16336,43 +20849,25 @@ func TensorListStack(scope *Scope, input_handle tf.Output, element_dtype tf.Data return op.Output(0) } -// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. -type StatelessRandomUniformAttr func(optionalAttr) - -// StatelessRandomUniformDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom random values from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// Reduces `input` from `num_devices` using `reduction` to a single device. // -// The outputs are a deterministic function of `shape` and `seed`. 
+// Reduces `input` from `num_devices` using `reduction` to a single device. // -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// The graph should be constructed so that all inputs have a valid device +// assignment, and the op itself is assigned one of these devices. // -// Returns Random values with specified shape. -func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { +// input: The input to the reduction. +// data: the value of the reduction across all `num_devices` devices. +// reduction: the reduction operation to perform. +func NcclReduce(scope *Scope, input []tf.Output, reduction string) (data tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"reduction": reduction} opspec := tf.OpSpec{ - Type: "StatelessRandomUniform", + Type: "NcclReduce", Input: []tf.Input{ - shape, seed, + tf.OutputList(input), }, Attrs: attrs, } @@ -16380,256 +20875,275 @@ func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optio return op.Output(0) } -// Makes its input available to the next iteration. +// Computes the gradient of morphological 2-D dilation with respect to the input. // // Arguments: -// data: The tensor to be made available to the next iteration. +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. +// strides: 1-D of length 4. The stride of the sliding window for each dimension of +// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: 1-D of length 4. The input stride for atrous morphological dilation. +// Must be: `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. 
// -// Returns The same tensor as `data`. -func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { +// Returns 4-D with shape `[batch, in_height, in_width, depth]`. +func Dilation2DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (in_backprop tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "NextIteration", + Type: "Dilation2DBackpropInput", Input: []tf.Input{ - data, + input, filter, out_backprop, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Output a fact about factorials. -func Fact(scope *Scope) (fact tf.Output) { +// An Op to sum inputs across replicated TPU instances. +// +// Each instance supplies its own input. +// +// For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`. +// Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, +// and `B, D, F, H` as group 1. Thus we get the outputs: +// `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`. +// +// Arguments: +// input: The local input to the sum. +// group_assignment: An int32 tensor with shape +// [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the +// replica ids in the ith subgroup. +// +// Returns The sum of all the distributed inputs. +func CrossReplicaSum(scope *Scope, input tf.Output, group_assignment tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Fact", + Type: "CrossReplicaSum", + Input: []tf.Input{ + input, group_assignment, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deserialize `SparseTensor` objects. 
-// -// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where -// the last dimension stores serialized `SparseTensor` objects and the other N -// dimensions (N >= 0) correspond to a batch. The ranks of the original -// `SparseTensor` objects must all match. When the final `SparseTensor` is -// created, its rank is the rank of the incoming `SparseTensor` objects plus N; -// the sparse tensors have been concatenated along new dimensions, one for each -// batch. -// -// The output `SparseTensor` object's shape values for the original dimensions -// are the max across the input `SparseTensor` objects' shape values for the -// corresponding dimensions. The new dimensions match the size of the batch. -// -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. -// -// For example, if the serialized input is a `[2 x 3]` matrix representing two -// original `SparseTensor` objects: +// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum. +type ResourceSparseApplyMomentumAttr func(optionalAttr) + +// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value. // -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. 
// -// and +// value: If `True`, the tensor passed to compute grad will be +// var - lr * momentum * accum, so in the end, the var you get is actually +// var - lr * momentum * accum. +// If not specified, defaults to false +func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update relevant entries in '*var' and '*accum' according to the momentum scheme. // -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] +// Set use_nesterov = True if you want to use Nesterov momentum. // -// then the final deserialized `SparseTensor` will be: +// That is for rows we have grad for, we update var and accum as follows: // -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] +// accum = accum * momentum + grad +// var -= lr * accum // // Arguments: -// serialized_sparse: The serialized `SparseTensor` objects. The last dimension -// must have 3 columns. -// dtype: The `dtype` of the serialized `SparseTensor` objects. -func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// momentum: Momentum. Must be a scalar. +// +// Returns the created operation. 
+func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DeserializeSparse", + Type: "ResourceSparseApplyMomentum", Input: []tf.Input{ - serialized_sparse, + var_, accum, lr, grad, indices, momentum, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// SqueezeAttr is an optional argument to Squeeze. -type SqueezeAttr func(optionalAttr) - -// SqueezeAxis sets the optional axis attribute to value. +// An Op to permute tensors across replicated TPU instances. // -// value: If specified, only squeezes the dimensions listed. The dimension -// index starts at 0. It is an error to squeeze a dimension that is not 1. Must -// be in the range `[-rank(input), rank(input))`. -// If not specified, defaults to <> +// Each instance supplies its own input. // -// REQUIRES: len(value) >= 0 -func SqueezeAxis(value []int64) SqueezeAttr { - return func(m optionalAttr) { - m["squeeze_dims"] = value +// For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing +// source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: +// `[D, A, B, C]`. +// +// Arguments: +// input: The local input to be permuted. Currently only supports float and +// bfloat16. +// source_target_pairs: A tensor with shape [num_pairs, 2]. +// +// Returns The permuted input. 
+func CollectivePermute(scope *Scope, input tf.Output, source_target_pairs tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "CollectivePermute", + Input: []tf.Input{ + input, source_target_pairs, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Removes dimensions of size 1 from the shape of a tensor. -// -// Given a tensor `input`, this operation returns a tensor of the same type with -// all dimensions of size 1 removed. If you don't want to remove all size 1 -// dimensions, you can remove specific size 1 dimensions by specifying -// `axis`. +// Returns the complex conjugate of a complex number. // -// For example: +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// complex numbers that are the complex conjugate of each element in `input`. The +// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the +// real part and *b* is the imaginary part. // -// ``` -// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -// shape(squeeze(t)) ==> [2, 3] -// ``` +// The complex conjugate returned by this operation is of the form \\(a - bj\\). // -// Or, to remove specific size 1 dimensions: +// For example: // // ``` -// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -// shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] // ``` -// -// Arguments: -// input: The `input` to squeeze. -// -// Returns Contains the same data as `input`, but has one or more dimensions of -// size 1 removed. 
-func Squeeze(scope *Scope, input tf.Output, optional ...SqueezeAttr) (output tf.Output) { +func Conj(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Squeeze", + Type: "Conj", Input: []tf.Input{ input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyAdadeltaAttr is an optional argument to ResourceApplyAdadelta. -type ResourceApplyAdadeltaAttr func(optionalAttr) +// RetrieveTPUEmbeddingCenteredRMSPropParametersAttr is an optional argument to RetrieveTPUEmbeddingCenteredRMSPropParameters. +type RetrieveTPUEmbeddingCenteredRMSPropParametersAttr func(optionalAttr) -// ResourceApplyAdadeltaUseLocking sets the optional use_locking attribute to value. +// RetrieveTPUEmbeddingCenteredRMSPropParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: If True, updating of the var, accum and update_accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyAdadeltaUseLocking(value bool) ResourceApplyAdadeltaAttr { +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingCenteredRMSPropParametersTableId(value int64) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["table_id"] = value } } -// Update '*var' according to the adadelta scheme. -// -// accum = rho() * accum + (1 - rho()) * grad.square(); -// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; -// update_accum = rho() * update_accum + (1 - rho()) * update.square(); -// var -= update; +// RetrieveTPUEmbeddingCenteredRMSPropParametersTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func RetrieveTPUEmbeddingCenteredRMSPropParametersTableName(value string) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve centered RMSProp embedding parameters. // -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// accum_update: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. -// grad: The gradient. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// Returns the created operation. -func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdadeltaAttr) (o *tf.Operation) { +// Returns Parameter parameters updated by the centered RMSProp optimization algorithm.Parameter ms updated by the centered RMSProp optimization algorithm.Parameter mom updated by the centered RMSProp optimization algorithm.Parameter mg updated by the centered RMSProp optimization algorithm. 
+func RetrieveTPUEmbeddingCenteredRMSPropParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingCenteredRMSPropParametersAttr) (parameters tf.Output, ms tf.Output, mom tf.Output, mg tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdadelta", - Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, - }, + Type: "RetrieveTPUEmbeddingCenteredRMSPropParameters", + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. -type NonMaxSuppressionAttr func(optionalAttr) +// StringSplitAttr is an optional argument to StringSplit. +type StringSplitAttr func(optionalAttr) -// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. +// StringSplitSkipEmpty sets the optional skip_empty attribute to value. // -// value: A float representing the threshold for deciding whether boxes -// overlap too much with respect to IOU. -// If not specified, defaults to 0.5 -func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { +// value: A `bool`. If `True`, skip the empty strings from the result. +// If not specified, defaults to true +func StringSplitSkipEmpty(value bool) StringSplitAttr { return func(m optionalAttr) { - m["iou_threshold"] = value + m["skip_empty"] = value } } -// Greedily selects a subset of bounding boxes in descending order of score, +// Split elements of `input` based on `delimiter` into a `SparseTensor`. // -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. 
Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system. Note that this -// algorithm is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// selected_indices = tf.image.non_max_suppression( -// boxes, scores, max_output_size, iou_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) +// Let N be the size of source (typically N will be the batch size). Split each +// element of `input` based on `delimiter` and return a `SparseTensor` +// containing the splitted tokens. Empty tokens are ignored. +// +// `delimiter` can be empty, or a string of split characters. If `delimiter` is an +// empty string, each element of `input` is split into individual single-byte +// character strings, including splitting of UTF-8 multibyte sequences. Otherwise +// every character of `delimiter` is a potential split point. +// +// For example: +// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output +// will be +// +// indices = [0, 0; +// 0, 1; +// 1, 0; +// 1, 1; +// 1, 2] +// shape = [2, 3] +// values = ['hello', 'world', 'a', 'b', 'c'] // // Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). 
-// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. +// input: 1-D. Strings to split. +// delimiter: 0-D. Delimiter characters (bytes), or empty string. // -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { +// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse +// tensor, where the first value is N and the second value is the maximum number +// of tokens in a single input entry. +func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { if scope.Err() != nil { return } @@ -16638,201 +21152,211 @@ func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_outp a(attrs) } opspec := tf.OpSpec{ - Type: "NonMaxSuppression", + Type: "StringSplit", Input: []tf.Input{ - boxes, scores, max_output_size, + input, delimiter, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Creates a dataset that emits `components` as a tuple of tensors once. -func TensorDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return +// RetrieveTPUEmbeddingFTRLParametersAttr is an optional argument to RetrieveTPUEmbeddingFTRLParameters. +type RetrieveTPUEmbeddingFTRLParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingFTRLParametersTableId sets the optional table_id attribute to value. 
+// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingFTRLParametersTableId(value int64) RetrieveTPUEmbeddingFTRLParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value } - attrs := map[string]interface{}{"output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "TensorDataset", - Input: []tf.Input{ - tf.OutputList(components), - }, - Attrs: attrs, +} + +// RetrieveTPUEmbeddingFTRLParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingFTRLParametersTableName(value string) RetrieveTPUEmbeddingFTRLParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Component-wise multiplies a SparseTensor by a dense Tensor. -// -// The output locations corresponding to the implicitly zero elements in the sparse -// tensor will be zero (i.e., will not take up storage space), regardless of the -// contents of the dense tensor (even if it's +/-INF and that INF*0 == NaN). -// -// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not -// the other direction. +// Retrieve FTRL embedding parameters. // -// Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// Returns 1-D. The `N` values that are operated on. 
-func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { +// Returns Parameter parameters updated by the FTRL optimization algorithm.Parameter accumulators updated by the FTRL optimization algorithm.Parameter linears updated by the FTRL optimization algorithm. +func RetrieveTPUEmbeddingFTRLParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseDenseCwiseMul", - Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, - }, + Type: "RetrieveTPUEmbeddingFTRLParameters", + + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// 2D real-valued fast Fourier transform. -// -// Computes the 2-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 2 dimensions of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. +// MaxPool3DAttr is an optional argument to MaxPool3D. +type MaxPool3DAttr func(optionalAttr) + +// MaxPool3DDataFormat sets the optional data_format attribute to value. // -// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. +// value: The data format of the input and output data. 
With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DDataFormat(value string) MaxPool3DAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Performs 3D max pooling on the input. // // Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. +// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. // -// @compatibility(numpy) -// Equivalent to np.fft.rfft2 -// @end_compatibility -func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// Returns The max pooled output tensor. +func MaxPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RFFT2D", + Type: "MaxPool3D", Input: []tf.Input{ - input, fft_length, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Pads a tensor with zeros. 
-// -// This operation pads a `input` with zeros according to the `paddings` you -// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the -// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many zeros to add before the contents of `input` in that dimension, and -// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` -// in that dimension. -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// Convert JSON-encoded Example records to binary protocol buffer strings. // -// For example: +// This op translates a tensor containing Example records, encoded using +// the [standard JSON +// mapping](https://developers.google.com/protocol-buffers/docs/proto3#json), +// into a tensor containing the same records encoded as binary protocol +// buffers. The resulting tensor can then be fed to any of the other +// Example-parsing ops. // -// ``` -// # 't' is [[1, 1], [2, 2]] -// # 'paddings' is [[1, 1], [2, 2]] -// # rank of 't' is 2 -// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] -// [0, 0, 1, 1, 0, 0] -// [0, 0, 2, 2, 0, 0] -// [0, 0, 0, 0, 0, 0]] -// ``` +// Arguments: +// json_examples: Each string is a JSON object serialized according to the JSON +// mapping of the Example proto. // -func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { +// Returns Each string is a binary Example protocol buffer corresponding +// to the respective element of `json_examples`. +func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Pad", + Type: "DecodeJSONExample", Input: []tf.Input{ - input, paddings, + json_examples, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Checks whether a resource handle-based variable has been initialized. 
+// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2. +type QueueEnqueueManyV2Attr func(optionalAttr) + +// QueueEnqueueManyV2TimeoutMs sets the optional timeout_ms attribute to value. +// +// value: If the queue is too full, this operation will block for up +// to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueEnqueueManyV2TimeoutMs(value int64) QueueEnqueueManyV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value + } +} + +// Enqueues zero or more tuples of one or more tensors in the given queue. +// +// This operation slices each component tensor along the 0th dimension to +// make multiple queue elements. All of the tuple components must have the +// same size in the 0th dimension. +// +// The components input has k elements, which correspond to the components of +// tuples stored in the given queue. +// +// N.B. If the queue is full, this operation will block until the given +// elements have been enqueued (or 'timeout_ms' elapses, if specified). // // Arguments: -// resource: the input resource handle. +// handle: The handle to a queue. +// components: One or more tensors from which the enqueued tensors should +// be taken. // -// Returns a scalar boolean which is true if the variable has been -// initialized. -func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { +// Returns the created operation. 
+func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueManyV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "VarIsInitializedOp", + Type: "QueueEnqueueManyV2", Input: []tf.Input{ - resource, + handle, tf.OutputList(components), }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. -type ResourceSparseApplyFtrlAttr func(optionalAttr) +// PrintV2Attr is an optional argument to PrintV2. +type PrintV2Attr func(optionalAttr) -// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. +// PrintV2OutputStream sets the optional output_stream attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { +// value: A string specifying the output stream or logging level to print to. +// If not specified, defaults to "stderr" +func PrintV2OutputStream(value string) PrintV2Attr { return func(m optionalAttr) { - m["use_locking"] = value + m["output_stream"] = value } } -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// Prints a string scalar. // -// That is for rows we have grad for, we update var, accum and linear as follows: -// accum_new = accum + grad * grad -// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// Prints a string scalar to the desired output_stream. 
// // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. +// input: The string scalar to print. // // Returns the created operation. -func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { +func PrintV2(scope *Scope, input tf.Output, optional ...PrintV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -16841,171 +21365,187 @@ func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, line a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrl", + Type: "PrintV2", Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, lr_power, + input, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Returns which elements of x are Inf. +// The gradient operator for the SparseSlice op. // -// @compatibility(numpy) -// Equivalent to np.isinf -// @end_compatibility -func IsInf(scope *Scope, x tf.Output) (y tf.Output) { +// This op takes in the upstream gradient w.r.t. non-empty values of +// the sliced `SparseTensor`, and outputs the gradients w.r.t. +// the non-empty values of input `SparseTensor`. +// +// Arguments: +// backprop_val_grad: 1-D. The gradient with respect to +// the non-empty values of the sliced `SparseTensor`. +// input_indices: 2-D. The `indices` of the input `SparseTensor`. +// input_start: 1-D. tensor represents the start of the slice. +// output_indices: 2-D. The `indices` of the sliced `SparseTensor`. +// +// Returns 1-D. 
The gradient with respect to the non-empty values of input `SparseTensor`. +func SparseSliceGrad(scope *Scope, backprop_val_grad tf.Output, input_indices tf.Output, input_start tf.Output, output_indices tf.Output) (val_grad tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IsInf", + Type: "SparseSliceGrad", Input: []tf.Input{ - x, + backprop_val_grad, input_indices, input_start, output_indices, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// TruncatedNormalAttr is an optional argument to TruncatedNormal. -type TruncatedNormalAttr func(optionalAttr) - -// TruncatedNormalSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func TruncatedNormalSeed(value int64) TruncatedNormalAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} +// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. +type ResourceApplyProximalAdagradAttr func(optionalAttr) -// TruncatedNormalSeed2 sets the optional seed2 attribute to value. +// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func TruncatedNormalSeed2(value int64) TruncatedNormalAttr { +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { return func(m optionalAttr) { - m["seed2"] = value + m["use_locking"] = value } } -// Outputs random values from a truncated normal distribution. +// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. 
// -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. +// accum += grad * grad +// prox_v = var - lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} // // Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. // -// Returns A tensor of the specified shape filled with random truncated normal -// values. -func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) { +// Returns the created operation. +func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TruncatedNormal", + Type: "ResourceApplyProximalAdagrad", Input: []tf.Input{ - shape, + var_, accum, lr, l1, l2, grad, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// SkipgramAttr is an optional argument to Skipgram. -type SkipgramAttr func(optionalAttr) +// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. +type MutableHashTableOfTensorsV2Attr func(optionalAttr) -// SkipgramWindowSize sets the optional window_size attribute to value. +// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. 
// -// value: The number of words to predict to the left and right of the target. -// If not specified, defaults to 5 -func SkipgramWindowSize(value int64) SkipgramAttr { +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { return func(m optionalAttr) { - m["window_size"] = value + m["container"] = value } } -// SkipgramMinCount sets the optional min_count attribute to value. +// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. // -// value: The minimum number of word occurrences for it to be included in the -// vocabulary. -// If not specified, defaults to 5 -func SkipgramMinCount(value int64) SkipgramAttr { +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { return func(m optionalAttr) { - m["min_count"] = value + m["shared_name"] = value } } -// SkipgramSubsample sets the optional subsample attribute to value. -// -// value: Threshold for word occurrence. Words that appear with higher -// frequency will be randomly down-sampled. Set to 0 to disable. -// If not specified, defaults to 0.001 -func SkipgramSubsample(value float32) SkipgramAttr { +// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { return func(m optionalAttr) { - m["subsample"] = value + m["use_node_name_sharing"] = value } } -// Parses a text file and creates a batch of examples. +// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. 
+// If not specified, defaults to <> +func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// Creates an empty hash table. // -// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a vector. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. // // Arguments: -// filename: The corpus's text file name. -// batch_size: The size of produced batch. +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. // -// Returns A vector of words in the corpus.Frequencies of words. Sorted in the non-ascending order.Number of words per epoch in the data file.The current epoch number.The total number of words processed so far.A vector of word ids.A vector of word ids. -func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) { +// Returns Handle to a table. 
+func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size} + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Skipgram", + Type: "MutableHashTableOfTensorsV2", Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6) + return op.Output(0) } -// StringToNumberAttr is an optional argument to StringToNumber. -type StringToNumberAttr func(optionalAttr) +// ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. +type ResourceApplyProximalGradientDescentAttr func(optionalAttr) -// StringToNumberOutType sets the optional out_type attribute to value. +// ResourceApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. // -// value: The numeric type to interpret each string in `string_tensor` as. -// If not specified, defaults to DT_FLOAT -func StringToNumberOutType(value tf.DataType) StringToNumberAttr { +// value: If True, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyProximalGradientDescentUseLocking(value bool) ResourceApplyProximalGradientDescentAttr { return func(m optionalAttr) { - m["out_type"] = value + m["use_locking"] = value } } -// Converts each string in the input Tensor to the specified numeric type. +// Update '*var' as FOBOS algorithm with fixed learning rate. // -// (Note that int32 overflow results in an error while float overflow -// results in a rounded value.) 
+// prox_v = var - alpha * delta +// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) { +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// delta: The change. +// +// Returns the created operation. +func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, delta tf.Output, optional ...ResourceApplyProximalGradientDescentAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -17014,200 +21554,228 @@ func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToN a(attrs) } opspec := tf.OpSpec{ - Type: "StringToNumber", + Type: "ResourceApplyProximalGradientDescent", Input: []tf.Input{ - string_tensor, + var_, alpha, l1, l2, delta, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2. -type ResourceApplyFtrlV2Attr func(optionalAttr) - -// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value. +// Returns 0 if the denominator is zero. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr { - return func(m optionalAttr) { - m["use_locking"] = value +// +// *NOTE*: `DivNoNan` supports broadcasting. 
More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func DivNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DivNoNan", + Input: []tf.Input{ + x, y, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Update '*var' according to the Ftrl-proximal scheme. +// Subtracts a value from the current value of a variable. // -// grad_with_shrinkage = grad + 2 * l2_shrinkage * var -// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage -// linear += grad_with_shrinkage + -// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// Any ReadVariableOp with a control dependency on this op is guaranteed to +// see the decremented value or a subsequent newer one. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regulariation. Must be a scalar. -// l2: L2 shrinkage regulariation. Must be a scalar. -// -// lr_power: Scaling factor. Must be a scalar. +// resource: handle to the resource in which to store the variable. +// value: the value by which the variable will be incremented. // // Returns the created operation. 
-func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) { +func AssignSubVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceApplyFtrlV2", + Type: "AssignSubVariableOp", Input: []tf.Input{ - var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power, + resource, value, }, - Attrs: attrs, } return scope.AddOperation(opspec) } -// EncodeJpegAttr is an optional argument to EncodeJpeg. -type EncodeJpegAttr func(optionalAttr) +// RestoreAttr is an optional argument to Restore. +type RestoreAttr func(optionalAttr) -// EncodeJpegFormat sets the optional format attribute to value. +// RestorePreferredShard sets the optional preferred_shard attribute to value. // -// value: Per pixel image format. -// If not specified, defaults to "" -func EncodeJpegFormat(value string) EncodeJpegAttr { +// value: Index of file to open first if multiple files match +// `file_pattern`. +// If not specified, defaults to -1 +func RestorePreferredShard(value int64) RestoreAttr { return func(m optionalAttr) { - m["format"] = value + m["preferred_shard"] = value } } -// EncodeJpegQuality sets the optional quality attribute to value. +// Restores a tensor from checkpoint files. // -// value: Quality of the compression from 0 to 100 (higher is better and slower). -// If not specified, defaults to 95 -func EncodeJpegQuality(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["quality"] = value +// Reads a tensor stored in one or several files. 
If there are several files (for +// instance because a tensor was saved as slices), `file_pattern` may contain +// wildcard symbols (`*` and `?`) in the filename portion only, not in the +// directory portion. +// +// If a `file_pattern` matches several files, `preferred_shard` can be used to hint +// in which file the requested tensor is likely to be found. This op will first +// open the file at index `preferred_shard` in the list of matching files and try +// to restore tensors from that file. Only if some tensors or tensor slices are +// not found in that first file, then the Op opens all the files. Setting +// `preferred_shard` to match the value passed as the `shard` input +// of a matching `Save` Op may speed up Restore. This attribute only affects +// performance, not correctness. The default value -1 means files are processed in +// order. +// +// See also `RestoreSlice`. +// +// Arguments: +// file_pattern: Must have a single element. The pattern of the files from +// which we read the tensor. +// tensor_name: Must have a single element. The name of the tensor to be +// restored. +// dt: The type of the tensor to be restored. +// +// Returns The restored tensor. +func Restore(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, dt tf.DataType, optional ...RestoreAttr) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dt": dt} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Restore", + Input: []tf.Input{ + file_pattern, tensor_name, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// EncodeJpegProgressive sets the optional progressive attribute to value. +// QuantizedResizeBilinearAttr is an optional argument to QuantizedResizeBilinear. +type QuantizedResizeBilinearAttr func(optionalAttr) + +// QuantizedResizeBilinearAlignCorners sets the optional align_corners attribute to value. 
// -// value: If True, create a JPEG that loads progressively (coarse to fine). +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. // If not specified, defaults to false -func EncodeJpegProgressive(value bool) EncodeJpegAttr { +func QuantizedResizeBilinearAlignCorners(value bool) QuantizedResizeBilinearAttr { return func(m optionalAttr) { - m["progressive"] = value + m["align_corners"] = value } } -// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. -// -// value: If True, spend CPU/RAM to reduce size with no quality change. +// QuantizedResizeBilinearHalfPixelCenters sets the optional half_pixel_centers attribute to value. // If not specified, defaults to false -func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { +func QuantizedResizeBilinearHalfPixelCenters(value bool) QuantizedResizeBilinearAttr { return func(m optionalAttr) { - m["optimize_size"] = value + m["half_pixel_centers"] = value + } +} + +// Resize quantized `images` to `size` using quantized bilinear interpolation. +// +// Input images and output images must be quantized types. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. 
+func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min tf.Output, max tf.Output, optional ...QuantizedResizeBilinearAttr) (resized_images tf.Output, out_min tf.Output, out_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedResizeBilinear", + Input: []tf.Input{ + images, size, min, max, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. +// Creates a dataset that uses a custom thread pool to compute `input_dataset`. // -// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. -// If not specified, defaults to true -func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["chroma_downsampling"] = value - } -} - -// EncodeJpegDensityUnit sets the optional density_unit attribute to value. +// Arguments: // -// value: Unit used to specify `x_density` and `y_density`: -// pixels per inch (`'in'`) or centimeter (`'cm'`). -// If not specified, defaults to "in" -func EncodeJpegDensityUnit(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["density_unit"] = value - } -} - -// EncodeJpegXDensity sets the optional x_density attribute to value. +// num_threads: Identifies the number of threads to use for the private threadpool. // -// value: Horizontal pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegXDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["x_density"] = value - } -} - -// EncodeJpegYDensity sets the optional y_density attribute to value. // -// value: Vertical pixels per density unit. 
-// If not specified, defaults to 300 -func EncodeJpegYDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["y_density"] = value +func ExperimentalPrivateThreadPoolDataset(scope *Scope, input_dataset tf.Output, num_threads tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalPrivateThreadPoolDataset", + Input: []tf.Input{ + input_dataset, num_threads, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. +// StackV2Attr is an optional argument to StackV2. +type StackV2Attr func(optionalAttr) + +// StackV2StackName sets the optional stack_name attribute to value. // -// value: If not empty, embed this XMP metadata in the image header. +// value: Overrides the name used for the temporary stack resource. Default +// value is the name of the 'Stack' op (which is guaranteed unique). // If not specified, defaults to "" -func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { +func StackV2StackName(value string) StackV2Attr { return func(m optionalAttr) { - m["xmp_metadata"] = value + m["stack_name"] = value } } -// JPEG-encode an image. -// -// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. -// -// The attr `format` can be used to override the color format of the encoded -// output. Values can be: -// -// * `''`: Use a default format based on the number of channels in the image. -// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension -// of `image` must be 1. -// * `rgb`: Output an RGB JPEG image. The `channels` dimension -// of `image` must be 3. 
-// -// If `format` is not specified or is the empty string, a default format is picked -// in function of the number of channels in `image`: -// -// * 1: Output a grayscale image. -// * 3: Output an RGB image. +// A stack that produces elements in first-in last-out order. // // Arguments: -// image: 3-D with shape `[height, width, channels]`. +// max_size: The maximum size of the stack if non-negative. If negative, the stack +// size is unlimited. +// elem_type: The type of the elements on the stack. // -// Returns 0-D. JPEG-encoded image. -func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { +// Returns The handle to the stack. +func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"elem_type": elem_type} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "EncodeJpeg", + Type: "StackV2", Input: []tf.Input{ - image, + max_size, }, Attrs: attrs, } @@ -17215,48 +21783,99 @@ func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (cont return op.Output(0) } -// MultinomialAttr is an optional argument to Multinomial. -type MultinomialAttr func(optionalAttr) +// CudnnRNNBackpropAttr is an optional argument to CudnnRNNBackprop. +type CudnnRNNBackpropAttr func(optionalAttr) -// MultinomialSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 is set to be non-zero, the internal random number -// generator is seeded by the given seed. Otherwise, a random seed is used. +// CudnnRNNBackpropRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNBackpropRnnMode(value string) CudnnRNNBackpropAttr { + return func(m optionalAttr) { + m["rnn_mode"] = value + } +} + +// CudnnRNNBackpropInputMode sets the optional input_mode attribute to value. 
+// If not specified, defaults to "linear_input" +func CudnnRNNBackpropInputMode(value string) CudnnRNNBackpropAttr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNBackpropDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNBackpropDirection(value string) CudnnRNNBackpropAttr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNBackpropDropout sets the optional dropout attribute to value. // If not specified, defaults to 0 -func MultinomialSeed(value int64) MultinomialAttr { +func CudnnRNNBackpropDropout(value float32) CudnnRNNBackpropAttr { return func(m optionalAttr) { - m["seed"] = value + m["dropout"] = value } } -// MultinomialSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. +// CudnnRNNBackpropSeed sets the optional seed attribute to value. // If not specified, defaults to 0 -func MultinomialSeed2(value int64) MultinomialAttr { +func CudnnRNNBackpropSeed(value int64) CudnnRNNBackpropAttr { return func(m optionalAttr) { - m["seed2"] = value + m["seed"] = value } } -// MultinomialOutputDtype sets the optional output_dtype attribute to value. -// If not specified, defaults to DT_INT64 -func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { +// CudnnRNNBackpropSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNBackpropSeed2(value int64) CudnnRNNBackpropAttr { return func(m optionalAttr) { - m["output_dtype"] = value + m["seed2"] = value } } -// Draws samples from a multinomial distribution. +// Backprop step of CudnnRNN. // -// Arguments: -// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` -// represents the unnormalized log probabilities for all classes. -// num_samples: 0-D. Number of independent samples to draw for each row slice. +// Compute the backprop of both data and weights in a RNN. 
// -// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` -// contains the drawn class labels with range `[0, num_classes)`. -func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// the actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. Should be +// "unidirectional" or "bidirectional". +// dropout: Dropout probability. When set to 0., dropout is disabled. +// seed: The 1st part of a seed to initialize dropout. +// seed2: The 2nd part of a seed to initialize dropout. +// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. +// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, +// num_units]. +// input_c: For LSTM, a 3-D tensor with the shape of +// [num_layer * dir, batch, num_units]. For other models, it is ignored. +// params: A 1-D tensor that contains the weights and biases in an opaque layout. +// The size must be created through CudnnRNNParamsSize, and initialized +// separately. Note that they might not be compatible across different +// generations. So it is a good idea to save and restore +// output: A 3-D tensor with the shape of [seq_length, batch_size, +// dir * num_units]. +// output_h: The same shape has input_h. +// output_c: The same shape as input_c for LSTM. An empty tensor for other models. +// output_backprop: A 3-D tensor with the same shape as output in the forward pass. +// output_h_backprop: A 3-D tensor with the same shape as output_h in the forward +// pass. 
+// output_c_backprop: A 3-D tensor with the same shape as output_c in the forward +// pass. +// reserve_space: The same reserve_space produced in for forward operation. +// input_backprop: The backprop to input in the forward pass. Has the same shape +// as input. +// input_h_backprop: The backprop to input_h in the forward pass. Has the same +// shape as input_h. +// input_c_backprop: The backprop to input_c in the forward pass. Has the same +// shape as input_c. +// params_backprop: The backprop to the params buffer in the forward pass. Has the +// same shape as params. +func CudnnRNNBackprop(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, output tf.Output, output_h tf.Output, output_c tf.Output, output_backprop tf.Output, output_h_backprop tf.Output, output_c_backprop tf.Output, reserve_space tf.Output, optional ...CudnnRNNBackpropAttr) (input_backprop tf.Output, input_h_backprop tf.Output, input_c_backprop tf.Output, params_backprop tf.Output) { if scope.Err() != nil { return } @@ -17265,45 +21884,60 @@ func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional a(attrs) } opspec := tf.OpSpec{ - Type: "Multinomial", + Type: "CudnnRNNBackprop", Input: []tf.Input{ - logits, num_samples, + input, input_h, input_c, params, output, output_h, output_c, output_backprop, output_h_backprop, output_c_backprop, reserve_space, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. -type ResourceSparseApplyAdagradDAAttr func(optionalAttr) +// InfeedEnqueueAttr is an optional argument to InfeedEnqueue. +type InfeedEnqueueAttr func(optionalAttr) -// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. +// InfeedEnqueueShape sets the optional shape attribute to value. 
// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { +// value: The shape of the tensor. +// If not specified, defaults to <> +func InfeedEnqueueShape(value tf.Shape) InfeedEnqueueAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["shape"] = value } } -// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. +// InfeedEnqueueLayout sets the optional layout attribute to value. +// +// value: A vector holding the requested layout in minor-to-major sequence. +// If a layout attribute is passed, but its values are all -1, the layout will +// be computed by the infeed operation. +// If not specified, defaults to <> +func InfeedEnqueueLayout(value []int64) InfeedEnqueueAttr { + return func(m optionalAttr) { + m["layout"] = value + } +} + +// InfeedEnqueueDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. This should be -1 when the Op +// is running on a TPU device, and >= 0 when the Op is running on the CPU +// device. +// If not specified, defaults to -1 +func InfeedEnqueueDeviceOrdinal(value int64) InfeedEnqueueAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// An op which feeds a single Tensor value into the computation. // // Arguments: -// var_: Should be from a Variable(). -// gradient_accumulator: Should be from a Variable(). -// gradient_squared_accumulator: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Learning rate. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// global_step: Training step number. Must be a scalar. 
+// input: A tensor that will be provided using the infeed mechanism. // // Returns the created operation. -func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { +func InfeedEnqueue(scope *Scope, input tf.Output, optional ...InfeedEnqueueAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -17312,221 +21946,354 @@ func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumul a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagradDA", + Type: "InfeedEnqueue", Input: []tf.Input{ - var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, + input, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. +// Computes softmax cross entropy cost and gradients to backpropagate. // -// The hash function is deterministic on the content of the string within the -// process and will never change. However, it is not suitable for cryptography. -// This function may be used when CPU time is scarce and inputs are trusted or -// unimportant. There is a risk of adversaries constructing inputs that all hash -// to the same bucket. To prevent this problem, use a strong hash function with -// `tf.string_to_hash_bucket_strong`. +// Inputs are the logits, not probabilities. // // Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. +// features: batch_size x num_classes matrix +// labels: batch_size x num_classes matrix +// The caller must ensure that each batch of labels represents a valid +// probability distribution. // -// Returns A Tensor of the same shape as the input `string_tensor`. 
-func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { +// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). +func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "StringToHashBucketFast", + Type: "SoftmaxCrossEntropyWithLogits", Input: []tf.Input{ - input, + features, labels, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Returns the last element of the input list as well as a list with all but that element. +// ReduceJoinAttr is an optional argument to ReduceJoin. +type ReduceJoinAttr func(optionalAttr) + +// ReduceJoinKeepDims sets the optional keep_dims attribute to value. // -// Fails if the list is empty. +// value: If `True`, retain reduced dimensions with length `1`. +// If not specified, defaults to false +func ReduceJoinKeepDims(value bool) ReduceJoinAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// ReduceJoinSeparator sets the optional separator attribute to value. // -// input_handle: the input list -// tensor: the withdrawn last element of the list -// element_dtype: the type of elements in the list -// element_shape: the shape of the output tensor -func TensorListPopBack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType) (output_handle tf.Output, tensor tf.Output) { +// value: The separator to use when joining. +// If not specified, defaults to "" +func ReduceJoinSeparator(value string) ReduceJoinAttr { + return func(m optionalAttr) { + m["separator"] = value + } +} + +// Joins a string Tensor across the given dimensions. +// +// Computes the string join across dimensions in the given string Tensor of shape +// `[\\(d_0, d_1, ..., d_{n-1}\\)]`. 
Returns a new Tensor created by joining the input +// strings with the given separator (default: empty string). Negative indices are +// counted backwards from the end, with `-1` being equivalent to `n - 1`. If +// indices are not specified, joins across all dimensions beginning from `n - 1` +// through `0`. +// +// For example: +// +// ```python +// # tensor `a` is [["a", "b"], ["c", "d"]] +// tf.reduce_join(a, 0) ==> ["ac", "bd"] +// tf.reduce_join(a, 1) ==> ["ab", "cd"] +// tf.reduce_join(a, -2) = tf.reduce_join(a, 0) ==> ["ac", "bd"] +// tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==> ["ab", "cd"] +// tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]] +// tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]] +// tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"] +// tf.reduce_join(a, [0, 1]) ==> "acbd" +// tf.reduce_join(a, [1, 0]) ==> "abcd" +// tf.reduce_join(a, []) ==> [["a", "b"], ["c", "d"]] +// tf.reduce_join(a) = tf.reduce_join(a, [1, 0]) ==> "abcd" +// ``` +// +// Arguments: +// inputs: The input to be joined. All reduced indices must have non-zero size. +// reduction_indices: The dimensions to reduce over. Dimensions are reduced in the +// order specified. Omitting `reduction_indices` is equivalent to passing +// `[n-1, n-2, ..., 0]`. Negative indices from `-n` to `-1` are supported. +// +// Returns Has shape equal to that of the input with reduced dimensions removed or +// set to `1` depending on `keep_dims`. 
+func ReduceJoin(scope *Scope, inputs tf.Output, reduction_indices tf.Output, optional ...ReduceJoinAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorListPopBack", + Type: "ReduceJoin", Input: []tf.Input{ - input_handle, + inputs, reduction_indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. -type MaxPoolGradGradAttr func(optionalAttr) +// TopKAttr is an optional argument to TopK. +type TopKAttr func(optionalAttr) -// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. +// TopKSorted sets the optional sorted attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { +// value: If true the resulting `k` elements will be sorted by the values in +// descending order. +// If not specified, defaults to true +func TopKSorted(value bool) TopKAttr { return func(m optionalAttr) { - m["data_format"] = value + m["sorted"] = value } } -// Computes second-order gradients of the maxpooling function. +// Finds values and indices of the `k` largest elements for the last dimension. +// +// DEPRECATED at GraphDef version 7: Use TopKV2 instead +// +// If the input is a vector (rank-1), finds the `k` largest entries in the vector +// and outputs their values and indices as vectors. 
Thus `values[j]` is the +// `j`-th largest entry in `input`, and its index is `indices[j]`. +// +// For matrices (resp. higher rank input), computes the top `k` entries in each +// row (resp. vector along the last dimension). Thus, +// +// values.shape = indices.shape = input.shape[:-1] + [k] +// +// If two elements are equal, the lower-index element appears first. +// +// If `k` varies dynamically, use `TopKV2` below. // // Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// input: 1-D or higher with last dimension at least `k`. +// k: Number of top elements to look for along the last dimension (along each +// row for matrices). // -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { +// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. 
+func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values tf.Output, indices tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"k": k} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPoolGradGrad", + Type: "TopK", Input: []tf.Input{ - orig_input, orig_output, grad, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. -type TensorArrayGatherV3Attr func(optionalAttr) - -// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// BatchToSpace for N-D tensors of type T. // -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// Gather specific elements from the TensorArray into output `value`. +// This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape +// `block_shape + [batch]`, interleaves these blocks back into the grid defined by +// the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as +// the input. The spatial dimensions of this intermediate result are then +// optionally cropped according to `crops` to produce the output. This is the +// reverse of SpaceToBatch. See below for a precise description. +// +// Arguments: +// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, +// where spatial_shape has M dimensions. +// block_shape: 1-D with shape `[M]`, all values must be >= 1. 
+// crops: 2-D with shape `[M, 2]`, all values must be >= 0. +// `crops[i] = [crop_start, crop_end]` specifies the amount to crop from input +// dimension `i + 1`, which corresponds to spatial dimension `i`. It is +// required that +// `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`. +// +// This operation is equivalent to the following steps: +// +// 1. Reshape `input` to `reshaped` of shape: +// [block_shape[0], ..., block_shape[M-1], +// batch / prod(block_shape), +// input_shape[1], ..., input_shape[N-1]] +// +// 2. Permute dimensions of `reshaped` to produce `permuted` of shape +// [batch / prod(block_shape), +// +// input_shape[1], block_shape[0], +// ..., +// input_shape[M], block_shape[M-1], +// +// input_shape[M+1], ..., input_shape[N-1]] +// +// 3. Reshape `permuted` to produce `reshaped_permuted` of shape +// [batch / prod(block_shape), +// +// input_shape[1] * block_shape[0], +// ..., +// input_shape[M] * block_shape[M-1], +// +// input_shape[M+1], +// ..., +// input_shape[N-1]] +// +// 4. 
Crop the start and end of dimensions `[1, ..., M]` of +// `reshaped_permuted` according to `crops` to produce the output of shape: +// [batch / prod(block_shape), +// +// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], +// ..., +// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], +// +// input_shape[M+1], ..., input_shape[N-1]] +// +// Some examples: +// +// (1) For the following input of shape `[4, 1, 1, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: +// +// ``` +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// (2) For the following input of shape `[4, 1, 1, 3]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: +// +// ``` +// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 3]` and value: +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// (3) For the following input of shape `[4, 2, 2, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: +// +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +// +// The output tensor has shape `[1, 4, 4, 1]` and value: +// +// ``` +// x = [[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]] +// ``` // -// All elements selected by `indices` must have the same shape. +// (4) For the following input of shape `[8, 1, 3, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [2, 0]]`: // -// Arguments: -// handle: The handle to a TensorArray. -// indices: The locations in the TensorArray from which to read tensor elements. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. 
+// ``` +// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], +// [[[0], [2], [4]]], [[[0], [10], [12]]], +// [[[0], [5], [7]]], [[[0], [13], [15]]], +// [[[0], [6], [8]]], [[[0], [14], [16]]]] +// ``` // -// Returns All of the elements in the TensorArray, concatenated along a new -// axis (the new dimension 0). -func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { +// The output tensor has shape `[2, 2, 4, 1]` and value: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]]], +// [[[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +func BatchToSpaceND(scope *Scope, input tf.Output, block_shape tf.Output, crops tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "TensorArrayGatherV3", + Type: "BatchToSpaceND", Input: []tf.Input{ - handle, indices, flow_in, + input, block_shape, crops, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns x / y element-wise for integer types. -// -// Truncation designates that negative numbers will round fractional quantities -// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different -// than Python semantics. See `FloorDiv` for a division function that matches -// Python Semantics. +// UnpackAttr is an optional argument to Unpack. +type UnpackAttr func(optionalAttr) + +// UnpackAxis sets the optional axis attribute to value. // -// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TruncateDiv", - Input: []tf.Input{ - x, y, - }, +// value: Dimension along which to unpack. 
Negative values wrap around, so the +// valid range is `[-R, R)`. +// If not specified, defaults to 0 +func UnpackAxis(value int64) UnpackAttr { + return func(m optionalAttr) { + m["axis"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Restores tensors from a V2 checkpoint. +// Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. // -// For backward compatibility with the V1 format, this Op currently allows -// restoring from a V1 checkpoint as well: -// - This Op first attempts to find the V2 index file pointed to by "prefix", and -// if found proceed to read it as a V2 checkpoint; -// - Otherwise the V1 read path is invoked. -// Relying on this behavior is not recommended, as the ability to fall back to read -// V1 might be deprecated and eventually removed. +// Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. +// For example, given a tensor of shape `(A, B, C, D)`; // -// By default, restores the named tensors in full. If the caller wishes to restore -// specific slices of stored tensors, "shape_and_slices" should be non-empty -// strings and correspondingly well-formed. +// If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` +// and each tensor in `output` will have shape `(B, C, D)`. (Note that the +// dimension unpacked along is gone, unlike `split`). // -// Callers must ensure all the named tensors are indeed stored in the checkpoint. +// If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` +// and each tensor in `output` will have shape `(A, C, D)`. +// Etc. +// +// This is the opposite of `pack`. // // Arguments: -// prefix: Must have a single element. The prefix of a V2 checkpoint. -// tensor_names: shape {N}. The names of the tensors to be restored. -// shape_and_slices: shape {N}. The slice specs of the tensors to be restored. -// Empty strings indicate that they are non-partitioned tensors. -// dtypes: shape {N}. 
The list of expected dtype for the tensors. Must match -// those stored in the checkpoint. +// value: 1-D or higher, with `axis` dimension size equal to `num`. // -// Returns shape {N}. The restored tensors, whose shapes are read from the -// checkpoint directly. -func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) { +// +// Returns The list of tensors unpacked from `value`. +func Unpack(scope *Scope, value tf.Output, num int64, optional ...UnpackAttr) (output []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"num": num} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RestoreV2", + Type: "Unpack", Input: []tf.Input{ - prefix, tensor_names, shape_and_slices, + value, }, Attrs: attrs, } @@ -17536,200 +22303,141 @@ func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and } var idx int var err error - if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil { - scope.UpdateErr("RestoreV2", err) + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("Unpack", err) return } - return tensors + return output } -// Receives a tensor value broadcast from another device. -func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { +// Delete the stack from its resource container. +// +// Arguments: +// handle: The handle to a stack. +// +// Returns the created operation. 
+func StackCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"T": T, "group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} opspec := tf.OpSpec{ - Type: "CollectiveBcastRecv", - - Attrs: attrs, + Type: "StackCloseV2", + Input: []tf.Input{ + handle, + }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Decode web-safe base64-encoded strings. -// -// Input may or may not have padding at the end. See EncodeBase64 for padding. -// Web-safe means that input must use - and _ instead of + and /. +// Increments variable pointed to by 'resource' until it reaches 'limit'. // // Arguments: -// input: Base64 strings to decode. +// resource: Should be from a scalar `Variable` node. +// limit: If incrementing ref would bring it above limit, instead generates an +// 'OutOfRange' error. // -// Returns Decoded strings. -func DecodeBase64(scope *Scope, input tf.Output) (output tf.Output) { +// +// Returns A copy of the input before increment. If nothing else modifies the +// input, the values produced will all be distinct. +func ResourceCountUpTo(scope *Scope, resource tf.Output, limit int64, T tf.DataType) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"limit": limit, "T": T} opspec := tf.OpSpec{ - Type: "DecodeBase64", + Type: "ResourceCountUpTo", Input: []tf.Input{ - input, + resource, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceSparseApplyProximalAdagradAttr is an optional argument to ResourceSparseApplyProximalAdagrad. -type ResourceSparseApplyProximalAdagradAttr func(optionalAttr) - -// ResourceSparseApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. 
-// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyProximalAdagradUseLocking(value bool) ResourceSparseApplyProximalAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. -// -// That is for rows we have grad for, we update var and accum as follows: -// accum += grad * grad -// prox_v = var -// prox_v -= lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// Computes softsign gradients for a softsign operation. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. +// gradients: The backpropagated gradients to the corresponding softsign operation. +// features: The features passed as input to the corresponding softsign operation. // -// Returns the created operation. -func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalAdagradAttr) (o *tf.Operation) { +// Returns The gradients: `gradients / (1 + abs(features)) ** 2`. 
+func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyProximalAdagrad", + Type: "SoftsignGrad", Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, indices, + gradients, features, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. -type MaxPool3DGradAttr func(optionalAttr) - -// MaxPool3DGradDataFormat sets the optional data_format attribute to value. +// Provides the time since epoch in seconds. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of max pooling function. +// Returns the timestamp as a `float64` for seconds since the Unix epoch. // -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. 
-func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { +// Note: the timestamp is computed when the op is executed, not when it is added +// to the graph. +func Timestamp(scope *Scope) (ts tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MaxPool3DGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, + Type: "Timestamp", } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the name of the device on which `resource` has been placed. -func ExperimentalIteratorGetDevice(scope *Scope, resource tf.Output) (device tf.Output) { +// Returns immutable tensor from memory region. +// +// The current implementation memmaps the tensor from a file. +// +// Arguments: +// dtype: Type of the returned tensor. +// shape: Shape of the returned tensor. +// memory_region_name: Name of readonly memory region used by the tensor, see +// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. +func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} opspec := tf.OpSpec{ - Type: "ExperimentalIteratorGetDevice", - Input: []tf.Input{ - resource, - }, + Type: "ImmutableConst", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// SparseReduceSumAttr is an optional argument to SparseReduceSum. -type SparseReduceSumAttr func(optionalAttr) +// StringJoinAttr is an optional argument to StringJoin. +type StringJoinAttr func(optionalAttr) -// SparseReduceSumKeepDims sets the optional keep_dims attribute to value. 
+// StringJoinSeparator sets the optional separator attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceSumKeepDims(value bool) SparseReduceSumAttr { +// value: string, an optional join separator. +// If not specified, defaults to "" +func StringJoinSeparator(value string) StringJoinAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["separator"] = value } } -// Computes the sum of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor` -// instead of a sparse one. -// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. +// Joins the strings in the given list of string tensors into one tensor; // -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +// with the given separator (default is an empty separator). // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. -// -// Returns `R-K`-D. The reduced Tensor. -func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { +// inputs: A list of string tensors. 
The tensors must all have the same shape, +// or be scalars. Scalars may be mixed in; these will be broadcast to the shape +// of non-scalar inputs. +func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -17738,9 +22446,9 @@ func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Outp a(attrs) } opspec := tf.OpSpec{ - Type: "SparseReduceSum", + Type: "StringJoin", Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, + tf.OutputList(inputs), }, Attrs: attrs, } @@ -17748,16 +22456,23 @@ func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Outp return op.Output(0) } -// Records the latency of producing `input_dataset` elements in a StatsAggregator. -func ExperimentalLatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Creates and returns an empty tensor list. +// +// All list elements must be tensors of dtype element_dtype and shape compatible +// with element_shape. +// +// handle: an empty tensor list. +// element_dtype: the type of elements in the list. +// element_shape: a shape compatible with that of elements in the list. 
+func EmptyTensorList(scope *Scope, element_shape tf.Output, max_num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "ExperimentalLatencyStatsDataset", + Type: "EmptyTensorList", Input: []tf.Input{ - input_dataset, tag, + element_shape, max_num_elements, }, Attrs: attrs, } @@ -17765,97 +22480,85 @@ func ExperimentalLatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag return op.Output(0) } -// SparseTensorDenseMatMulAttr is an optional argument to SparseTensorDenseMatMul. -type SparseTensorDenseMatMulAttr func(optionalAttr) - -// SparseTensorDenseMatMulAdjointA sets the optional adjoint_a attribute to value. -// -// value: Use the adjoint of A in the matrix multiply. If A is complex, this -// is transpose(conj(A)). Otherwise it's transpose(A). -// If not specified, defaults to false -func SparseTensorDenseMatMulAdjointA(value bool) SparseTensorDenseMatMulAttr { - return func(m optionalAttr) { - m["adjoint_a"] = value - } -} - -// SparseTensorDenseMatMulAdjointB sets the optional adjoint_b attribute to value. +// Returns a list of tensors with the same shapes and contents as the input // -// value: Use the adjoint of B in the matrix multiply. If B is complex, this -// is transpose(conj(B)). Otherwise it's transpose(B). -// If not specified, defaults to false -func SparseTensorDenseMatMulAdjointB(value bool) SparseTensorDenseMatMulAttr { - return func(m optionalAttr) { - m["adjoint_b"] = value - } -} - -// Multiply SparseTensor (of rank 2) "A" by dense matrix "B". +// tensors. // -// No validity checking is performed on the indices of A. However, the following -// input format is recommended for optimal behavior: +// This op can be used to override the gradient for complicated functions. 
For +// example, suppose y = f(x) and we wish to apply a custom function g for backprop +// such that dx = g(dy). In Python, // -// if adjoint_a == false: -// A should be sorted in lexicographically increasing order. Use SparseReorder -// if you're not sure. -// if adjoint_a == true: -// A should be sorted in order of increasing dimension 1 (i.e., "column major" -// order instead of "row major" order). +// ```python +// with tf.get_default_graph().gradient_override_map( +// {'IdentityN': 'OverrideGradientWithG'}): +// y, _ = identity_n([f(x), x]) // -// Arguments: -// a_indices: 2-D. The `indices` of the `SparseTensor`, size `[nnz, 2]` Matrix. -// a_values: 1-D. The `values` of the `SparseTensor`, size `[nnz]` Vector. -// a_shape: 1-D. The `shape` of the `SparseTensor`, size `[2]` Vector. -// b: 2-D. A dense Matrix. -func SparseTensorDenseMatMul(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output, optional ...SparseTensorDenseMatMulAttr) (product tf.Output) { +// @tf.RegisterGradient('OverrideGradientWithG') +// def ApplyG(op, dy, _): +// return [None, g(dy)] # Do not backprop to f(x). +// ``` +func IdentityN(scope *Scope, input []tf.Output) (output []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SparseTensorDenseMatMul", + Type: "IdentityN", Input: []tf.Input{ - a_indices, a_values, a_shape, b, + tf.OutputList(input), }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("IdentityN", err) + return + } + return output } -// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. -type ResourceApplyRMSPropAttr func(optionalAttr) +// ResourceApplyCenteredRMSPropAttr is an optional argument to ResourceApplyCenteredRMSProp. 
+type ResourceApplyCenteredRMSPropAttr func(optionalAttr) -// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// ResourceApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less +// value: If `True`, updating of the var, mg, ms, and mom tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { +func ResourceApplyCenteredRMSPropUseLocking(value bool) ResourceApplyCenteredRMSPropAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the RMSProp algorithm. +// Update '*var' according to the centered RMSProp algorithm. // -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms +// The centered RMSProp algorithm uses an estimate of the centered second moment +// (i.e., the variance) for normalization, as opposed to regular RMSProp, which +// uses the (uncentered) second moment. This often helps with training, but is +// slightly more expensive in terms of computation and memory. +// +// Note that in dense implementation of this algorithm, mg, ms, and mom will +// update even if the grad is zero, but in this sparse implementation, mg, ms, // and mom will not update in iterations during which the grad is zero. 
// // mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// mean_grad = decay * mean_grad + (1-decay) * gradient +// +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) // +// mg <- rho * mg_{t-1} + (1-rho) * grad // ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) // var <- var - mom // // Arguments: // var_: Should be from a Variable(). +// mg: Should be from a Variable(). // ms: Should be from a Variable(). // mom: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. @@ -17865,7 +22568,7 @@ func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { // grad: The gradient. // // Returns the created operation. -func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { +func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyCenteredRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -17874,82 +22577,63 @@ func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Out a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyRMSProp", + Type: "ResourceApplyCenteredRMSProp", Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, + var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Store the input tensor in the state of the current session. -// -// Arguments: -// value: The tensor to be stored. 
-// -// Returns The handle for the tensor stored in the session state, represented -// as a ResourceHandle object. -func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "GetSessionHandleV2", - Input: []tf.Input{ - value, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. -type ResourceApplyAdamAttr func(optionalAttr) +// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. +type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) -// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. +// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less +// value: If `True`, updating of the var, mg, ms, and mom tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { +func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. +// Update '*var' according to the centered RMSProp algorithm. // -// value: If `True`, uses the nesterov update. -// If not specified, defaults to false -func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update '*var' according to the Adam algorithm. 
+// The centered RMSProp algorithm uses an estimate of the centered second moment +// (i.e., the variance) for normalization, as opposed to regular RMSProp, which +// uses the (uncentered) second moment. This often helps with training, but is +// slightly more expensive in terms of computation and memory. // -// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ -// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ -// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ -// $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ +// Note that in dense implementation of this algorithm, mg, ms, and mom will +// update even if the grad is zero, but in this sparse implementation, mg, ms, +// and mom will not update in iterations during which the grad is zero. // -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// beta1_power: Must be a scalar. -// beta2_power: Must be a scalar. +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// mean_grad = decay * mean_grad + (1-decay) * gradient +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom +// +// Arguments: +// var_: Should be from a Variable(). +// mg: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// // epsilon: Ridge term. Must be a scalar. // grad: The gradient. +// indices: A vector of indices into the first dimension of var, ms and mom. // // Returns the created operation. 
-func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { +func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -17958,127 +22642,160 @@ func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, b a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdam", + Type: "ResourceSparseApplyCenteredRMSProp", Input: []tf.Input{ - var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// SerializeManySparseAttr is an optional argument to SerializeManySparse. -type SerializeManySparseAttr func(optionalAttr) +// Creates a dataset that batches `batch_size` elements from `input_dataset`. +// +// Arguments: +// +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// +// +func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "BatchDataset", + Input: []tf.Input{ + input_dataset, batch_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// SerializeManySparseOutType sets the optional out_type attribute to value. +// LoadTPUEmbeddingAdadeltaParametersAttr is an optional argument to LoadTPUEmbeddingAdadeltaParameters. 
+type LoadTPUEmbeddingAdadeltaParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingAdadeltaParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: The `dtype` to use for serialization; the supported types are `string` -// (default) and `variant`. -// If not specified, defaults to DT_STRING -func SerializeManySparseOutType(value tf.DataType) SerializeManySparseAttr { +// REQUIRES: value >= -1 +func LoadTPUEmbeddingAdadeltaParametersTableId(value int64) LoadTPUEmbeddingAdadeltaParametersAttr { return func(m optionalAttr) { - m["out_type"] = value + m["table_id"] = value } } -// Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object. -// -// The `SparseTensor` must have rank `R` greater than 1, and the first dimension -// is treated as the minibatch dimension. Elements of the `SparseTensor` -// must be sorted in increasing order of this first dimension. The serialized -// `SparseTensor` objects going into each row of `serialized_sparse` will have -// rank `R-1`. +// LoadTPUEmbeddingAdadeltaParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingAdadeltaParametersTableName(value string) LoadTPUEmbeddingAdadeltaParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load Adadelta embedding parameters. // -// The minibatch size `N` is extracted from `sparse_shape[0]`. +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: -// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. -// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. -// sparse_shape: 1-D. 
The `shape` of the minibatch `SparseTensor`. -func SerializeManySparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeManySparseAttr) (serialized_sparse tf.Output) { +// parameters: Value of parameters used in the Adadelta optimization algorithm. +// accumulators: Value of accumulators used in the Adadelta optimization algorithm. +// updates: Value of updates used in the Adadelta optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingAdadeltaParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, updates tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdadeltaParametersAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SerializeManySparse", + Type: "LoadTPUEmbeddingAdadeltaParameters", Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, + parameters, accumulators, updates, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Computes inverse hyperbolic cosine of x element-wise. -func Acosh(scope *Scope, x tf.Output) (y tf.Output) { +// Converts each string in the input Tensor to its hash mod by a number of buckets. +// +// The hash function is deterministic on the content of the string within the +// process and will never change. However, it is not suitable for cryptography. +// This function may be used when CPU time is scarce and inputs are trusted or +// unimportant. There is a risk of adversaries constructing inputs that all hash +// to the same bucket. To prevent this problem, use a strong hash function with +// `tf.string_to_hash_bucket_strong`. +// +// Arguments: +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. 
+// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "Acosh", + Type: "StringToHashBucketFast", Input: []tf.Input{ - x, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorArrayV2Attr is an optional argument to TensorArrayV2. -type TensorArrayV2Attr func(optionalAttr) - -// TensorArrayV2ElementShape sets the optional element_shape attribute to value. -// If not specified, defaults to -func TensorArrayV2ElementShape(value tf.Shape) TensorArrayV2Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// TensorArrayV2DynamicSize sets the optional dynamic_size attribute to value. -// If not specified, defaults to false -func TensorArrayV2DynamicSize(value bool) TensorArrayV2Attr { - return func(m optionalAttr) { - m["dynamic_size"] = value - } -} - -// TensorArrayV2ClearAfterRead sets the optional clear_after_read attribute to value. -// If not specified, defaults to true -func TensorArrayV2ClearAfterRead(value bool) TensorArrayV2Attr { - return func(m optionalAttr) { - m["clear_after_read"] = value - } -} +// RealAttr is an optional argument to Real. +type RealAttr func(optionalAttr) -// TensorArrayV2TensorArrayName sets the optional tensor_array_name attribute to value. -// If not specified, defaults to "" -func TensorArrayV2TensorArrayName(value string) TensorArrayV2Attr { +// RealTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func RealTout(value tf.DataType) RealAttr { return func(m optionalAttr) { - m["tensor_array_name"] = value + m["Tout"] = value } } -// Deprecated. Use TensorArrayV3 +// Returns the real part of a complex number. 
// -// DEPRECATED at GraphDef version 26: Use TensorArrayV3 -func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV2Attr) (handle tf.Output) { +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the real part of each element in `input`. All elements in +// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real +// part returned by this operation and *b* is the imaginary part. +// +// For example: +// +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.real(input) ==> [-2.25, 3.25] +// ``` +func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayV2", + Type: "Real", Input: []tf.Input{ - size, + input, }, Attrs: attrs, } @@ -18086,389 +22803,410 @@ func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ... return op.Output(0) } -// ThreadUnsafeUnigramCandidateSamplerAttr is an optional argument to ThreadUnsafeUnigramCandidateSampler. -type ThreadUnsafeUnigramCandidateSamplerAttr func(optionalAttr) +// AudioSummaryAttr is an optional argument to AudioSummary. +type AudioSummaryAttr func(optionalAttr) -// ThreadUnsafeUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// AudioSummaryMaxOutputs sets the optional max_outputs attribute to value. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. 
-// If not specified, defaults to 0 -func ThreadUnsafeUnigramCandidateSamplerSeed(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// ThreadUnsafeUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// value: Max number of batch elements to generate audio for. +// If not specified, defaults to 3 // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func ThreadUnsafeUnigramCandidateSamplerSeed2(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { +// REQUIRES: value >= 1 +func AudioSummaryMaxOutputs(value int64) AudioSummaryAttr { return func(m optionalAttr) { - m["seed2"] = value + m["max_outputs"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. +// Outputs a `Summary` protocol buffer with audio. // -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. +// DEPRECATED at GraphDef version 15: Use AudioSummaryV2. // -// For each batch, this op picks a single set of sampled candidate labels. +// The summary has up to `max_outputs` summary values containing audio. The +// audio is built from `tensor` which must be 3-D with shape `[batch_size, +// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are +// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. // -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// The `tag` argument is a scalar `Tensor` of type `string`. It is used to +// build the `tag` of the summary values: +// +// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. 
+// * If `max_outputs` is greater than 1, the summary value tags are +// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. // // Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). +// tag: Scalar. Used to build the `tag` attribute of the summary values. +// tensor: 2-D of shape `[batch_size, frames]`. +// sample_rate: The sample rate of the signal in hertz. // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func ThreadUnsafeUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...ThreadUnsafeUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. 
+func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate float32, optional ...AudioSummaryAttr) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + attrs := map[string]interface{}{"sample_rate": sample_rate} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ThreadUnsafeUnigramCandidateSampler", + Type: "AudioSummary", Input: []tf.Input{ - true_classes, + tag, tensor, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// MaxPoolV2Attr is an optional argument to MaxPoolV2. -type MaxPoolV2Attr func(optionalAttr) +// QrAttr is an optional argument to Qr. +type QrAttr func(optionalAttr) -// MaxPoolV2DataFormat sets the optional data_format attribute to value. +// QrFullMatrices sets the optional full_matrices attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolV2DataFormat(value string) MaxPoolV2Attr { +// value: If true, compute full-sized `q` and `r`. If false +// (the default), compute only the leading `P` columns of `q`. +// If not specified, defaults to false +func QrFullMatrices(value bool) QrAttr { return func(m optionalAttr) { - m["data_format"] = value + m["full_matrices"] = value } } -// Performs max pooling on the input. +// Computes the QR decompositions of one or more matrices. +// +// Computes the QR decomposition of each inner matrix in `tensor` such that +// `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` +// +// ```python +// # a is a tensor. 
+// # q is a tensor of orthonormal matrices. +// # r is a tensor of upper triangular matrices. +// q, r = qr(a) +// q_full, r_full = qr(a, full_matrices=True) +// ``` // // Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. // -// Returns The max pooled output tensor. -func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolV2Attr) (output tf.Output) { +// Returns Orthonormal basis for range of `a`. If `full_matrices` is `False` then +// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is +// `[..., M, M]`.Triangular factor. If `full_matrices` is `False` then shape is +// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`. +func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPoolV2", + Type: "Qr", Input: []tf.Input{ - input, ksize, strides, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// AddManySparseToTensorsMapAttr is an optional argument to AddManySparseToTensorsMap. -type AddManySparseToTensorsMapAttr func(optionalAttr) +// TensorArrayV3Attr is an optional argument to TensorArrayV3. +type TensorArrayV3Attr func(optionalAttr) -// AddManySparseToTensorsMapContainer sets the optional container attribute to value. +// TensorArrayV3ElementShape sets the optional element_shape attribute to value. 
// -// value: The container name for the `SparseTensorsMap` created by this op. -// If not specified, defaults to "" -func AddManySparseToTensorsMapContainer(value string) AddManySparseToTensorsMapAttr { +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { return func(m optionalAttr) { - m["container"] = value + m["element_shape"] = value } } -// AddManySparseToTensorsMapSharedName sets the optional shared_name attribute to value. +// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. // -// value: The shared name for the `SparseTensorsMap` created by this op. -// If blank, the new Operation's unique name is used. -// If not specified, defaults to "" -func AddManySparseToTensorsMapSharedName(value string) AddManySparseToTensorsMapAttr { +// value: A boolean that determines whether writes to the TensorArray +// are allowed to grow the size. By default, this is not allowed. +// If not specified, defaults to false +func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { return func(m optionalAttr) { - m["shared_name"] = value + m["dynamic_size"] = value } } -// Add an `N`-minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles. -// -// A `SparseTensor` of rank `R` is represented by three tensors: `sparse_indices`, -// `sparse_values`, and `sparse_shape`, where +// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. // -// ```sparse_indices.shape[1] == sparse_shape.shape[0] == R``` +// value: If true (default), Tensors in the TensorArray are cleared +// after being read. This disables multiple read semantics but allows early +// release of memory. 
+// If not specified, defaults to true +func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["clear_after_read"] = value + } +} + +// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. // -// An `N`-minibatch of `SparseTensor` objects is represented as a `SparseTensor` -// having a first `sparse_indices` column taking values between `[0, N)`, where -// the minibatch size `N == sparse_shape[0]`. +// value: If true (default is false), then all +// elements in the TensorArray will be expected to have have identical shapes. +// This allows certain behaviors, like dynamically checking for +// consistent shapes on write, and being able to fill in properly +// shaped zero tensors on stack -- even if the element_shape attribute +// is not fully defined. +// If not specified, defaults to false +func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["identical_element_shapes"] = value + } +} + +// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. // -// The input `SparseTensor` must have rank `R` greater than 1, and the first -// dimension is treated as the minibatch dimension. Elements of the `SparseTensor` -// must be sorted in increasing order of this first dimension. The stored -// `SparseTensor` objects pointed to by each row of the output `sparse_handles` -// will have rank `R-1`. +// value: Overrides the name used for the temporary tensor_array +// resource. Default value is the name of the 'TensorArray' op (which +// is guaranteed unique). +// If not specified, defaults to "" +func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr { + return func(m optionalAttr) { + m["tensor_array_name"] = value + } +} + +// An array of Tensors of given size. 
// -// The `SparseTensor` values can then be read out as part of a minibatch by passing -// the given keys as vector elements to `TakeManySparseFromTensorsMap`. To ensure -// the correct `SparseTensorsMap` is accessed, ensure that the same -// `container` and `shared_name` are passed to that Op. If no `shared_name` -// is provided here, instead use the *name* of the Operation created by calling -// `AddManySparseToTensorsMap` as the `shared_name` passed to -// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. +// Write data via Write and read via Read or Pack. // // Arguments: -// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. -// `sparse_indices[:, 0]` must be ordered values in `[0, N)`. -// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. -// The minibatch size `N == sparse_shape[0]`. +// size: The size of the array. +// dtype: The type of the elements on the tensor_array. // -// Returns 1-D. The handles of the `SparseTensor` now stored in the -// `SparseTensorsMap`. Shape: `[N]`. -func AddManySparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddManySparseToTensorsMapAttr) (sparse_handles tf.Output) { +// Returns The handle to the TensorArray.A scalar used to control gradient flow. +func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AddManySparseToTensorsMap", + Type: "TensorArrayV3", Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, + size, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Concatenates tensors along one dimension. 
-// -// Arguments: -// values: List of `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// axis: 0-D. The dimension along which to concatenate. Must be in the -// range [-rank(values), rank(values)). -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConcatV2", - Input: []tf.Input{ - tf.OutputList(values), axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Reads and outputs the entire contents of the input filename. -func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { +// Returns the truth value of NOT x element-wise. +func LogicalNot(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReadFile", + Type: "LogicalNot", Input: []tf.Input{ - filename, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns immutable tensor from memory region. -// -// The current implementation memmaps the tensor from a file. +// 3D real-valued fast Fourier transform. // -// Arguments: -// dtype: Type of the returned tensor. -// shape: Shape of the returned tensor. -// memory_region_name: Name of readonly memory region used by the tensor, see -// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. 
-func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} - opspec := tf.OpSpec{ - Type: "ImmutableConst", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StringJoinAttr is an optional argument to StringJoin. -type StringJoinAttr func(optionalAttr) - -// StringJoinSeparator sets the optional separator attribute to value. +// Computes the 3-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 3 dimensions of `input`. // -// value: string, an optional join separator. -// If not specified, defaults to "" -func StringJoinSeparator(value string) StringJoinAttr { - return func(m optionalAttr) { - m["separator"] = value - } -} - -// Joins the strings in the given list of string tensors into one tensor; +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. // -// with the given separator (default is an empty separator). +// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// inputs: A list of string tensors. The tensors must all have the same shape, -// or be scalars. Scalars may be mixed in; these will be broadcast to the shape -// of non-scalar inputs. -func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. 
+// +// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the their 3D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfftn with 3 dimensions. +// @end_compatibility +func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "StringJoin", + Type: "RFFT3D", Input: []tf.Input{ - tf.OutputList(inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates and returns an empty tensor list. -// -// All list elements must be tensors of dtype element_dtype and shape compatible -// with element_shape. -// -// handle: an empty tensor list. -// element_dtype: the type of elements in the list. -// element_shape: a shape compatible with that of elements in the list. -func EmptyTensorList(scope *Scope, element_shape tf.Output, max_num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes rectified linear: `max(features, 0)`. +func Relu(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "EmptyTensorList", + Type: "Relu", Input: []tf.Input{ - element_shape, max_num_elements, + features, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes softsign gradients for a softsign operation. +// ResourceApplyAddSignAttr is an optional argument to ResourceApplyAddSign. +type ResourceApplyAddSignAttr func(optionalAttr) + +// ResourceApplyAddSignUseLocking sets the optional use_locking attribute to value. 
+// +// value: If `True`, updating of the var and m tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAddSignUseLocking(value bool) ResourceApplyAddSignAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the AddSign update. +// +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// update <- (alpha + sign_decay * sign(g) *sign(m)) * g +// variable <- variable - lr_t * update // // Arguments: -// gradients: The backpropagated gradients to the corresponding softsign operation. -// features: The features passed as input to the corresponding softsign operation. +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// alpha: Must be a scalar. +// sign_decay: Must be a scalar. +// beta: Must be a scalar. +// grad: The gradient. // -// Returns The gradients: `gradients / (1 + abs(features)) ** 2`. -func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { +// Returns the created operation. +func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, alpha tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyAddSignAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SoftsignGrad", + Type: "ResourceApplyAddSign", Input: []tf.Input{ - gradients, features, + var_, m, lr, alpha, sign_decay, beta, grad, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Provides the time since epoch in seconds. +// Divides sparse updates into the variable referenced by `resource`. // -// Returns the timestamp as a `float64` for seconds since the Unix epoch. 
+// This operation computes // -// Note: the timestamp is computed when the op is executed, not when it is added -// to the graph. -func Timestamp(scope *Scope) (ts tf.Output) { +// # Scalar indices +// ref[indices, ...] /= updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] /= updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions multiply. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Timestamp", + Type: "ResourceScatterDiv", + Input: []tf.Input{ + resource, indices, updates, + }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// VariableShapeAttr is an optional argument to VariableShape. -type VariableShapeAttr func(optionalAttr) +// ListDiffAttr is an optional argument to ListDiff. +type ListDiffAttr func(optionalAttr) -// VariableShapeOutType sets the optional out_type attribute to value. +// ListDiffOutIdx sets the optional out_idx attribute to value. // If not specified, defaults to DT_INT32 -func VariableShapeOutType(value tf.DataType) VariableShapeAttr { +func ListDiffOutIdx(value tf.DataType) ListDiffAttr { return func(m optionalAttr) { - m["out_type"] = value + m["out_idx"] = value } } -// Returns the shape of the variable pointed to by `resource`. +// Computes the difference between two lists of numbers or strings. // -// This operation returns a 1-D integer tensor representing the shape of `input`. +// Given a list `x` and a list `y`, this operation returns a list `out` that +// represents all values that are in `x` but not in `y`. The returned list `out` +// is sorted in the same order that the numbers appear in `x` (duplicates are +// preserved). This operation also returns a list `idx` that represents the +// position of each `out` element in `x`. 
In other words: // -// For example: +// `out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` +// +// For example, given this input: // // ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// shape(t) ==> [2, 2, 3] +// x = [1, 2, 3, 4, 5, 6] +// y = [1, 3, 5] // ``` -func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) (output tf.Output) { +// +// This operation would return: +// +// ``` +// out ==> [2, 4, 6] +// idx ==> [1, 3, 5] +// ``` +// +// Arguments: +// x: 1-D. Values to keep. +// y: 1-D. Values to remove. +// +// Returns 1-D. Values present in `x` but not in `y`.1-D. Positions of `x` values preserved in `out`. +func ListDiff(scope *Scope, x tf.Output, y tf.Output, optional ...ListDiffAttr) (out tf.Output, idx tf.Output) { if scope.Err() != nil { return } @@ -18477,495 +23215,391 @@ func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) a(attrs) } opspec := tf.OpSpec{ - Type: "VariableShape", + Type: "ListDiff", Input: []tf.Input{ - input, + x, y, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// AvgPoolGradAttr is an optional argument to AvgPoolGrad. -type AvgPoolGradAttr func(optionalAttr) +// LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingAdadeltaParametersGradAccumDebug. +type LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr func(optionalAttr) -// AvgPoolGradDataFormat sets the optional data_format attribute to value. +// LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. 
-// If not specified, defaults to "NHWC" -func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { +// REQUIRES: value >= -1 +func LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["data_format"] = value + m["table_id"] = value } } -// Computes gradients of the average pooling function. +// LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load Adadelta parameters with debug support. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: -// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. -// the output of `avg_pool`. -// ksize: The size of the sliding window for each dimension of the input. -// strides: The stride of the sliding window for each dimension of the input. -// padding: The type of padding algorithm to use. +// parameters: Value of parameters used in the Adadelta optimization algorithm. +// accumulators: Value of accumulators used in the Adadelta optimization algorithm. +// updates: Value of updates used in the Adadelta optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the Adadelta optimization algorithm. // -// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. 
-func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { +// +// +// Returns the created operation. +func LoadTPUEmbeddingAdadeltaParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, updates tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPoolGrad", + Type: "LoadTPUEmbeddingAdadeltaParametersGradAccumDebug", Input: []tf.Input{ - orig_input_shape, grad, + parameters, accumulators, updates, gradient_accumulators, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high overlaps -// with previously selected boxes. Bounding boxes with score less than -// `score_threshold` are removed. N-by-n overlap values are supplied as square matrix, -// which allows for defining a custom overlap criterium (eg. intersection over union, -// intersection over area, etc.). -// -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: +// Returns a serialized GraphDef representing `input_dataset`. 
// -// selected_indices = tf.image.non_max_suppression_with_overlaps( -// overlaps, scores, max_output_size, overlap_threshold, score_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) +// Returns a graph representation for `input_dataset`. // // Arguments: -// overlaps: A 2-D float tensor of shape `[num_boxes, num_boxes]` representing -// the n-by-n box overlap values. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// overlap_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too. -// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove -// boxes based on score. +// input_dataset: A variant tensor representing the dataset to return the graph representation for. // -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppressionWithOverlaps(scope *Scope, overlaps tf.Output, scores tf.Output, max_output_size tf.Output, overlap_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { +// Returns The graph representation of the dataset (as serialized GraphDef). +func DatasetToGraph(scope *Scope, input_dataset tf.Output) (graph tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "NonMaxSuppressionWithOverlaps", + Type: "DatasetToGraph", Input: []tf.Input{ - overlaps, scores, max_output_size, overlap_threshold, score_threshold, + input_dataset, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes softmax cross entropy cost and gradients to backpropagate. 
-// -// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept -// a matrix of label probabilities, but rather a single label per row -// of features. This label is considered to have probability 1.0 for the -// given row. -// -// Inputs are the logits, not probabilities. -// -// Arguments: -// features: batch_size x num_classes matrix -// labels: batch_size vector with values in [0, num_classes). -// This is the label for the given minibatch entry. -// -// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). -func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSoftmaxCrossEntropyWithLogits", - Input: []tf.Input{ - features, labels, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} +// MatrixSolveAttr is an optional argument to MatrixSolve. +type MatrixSolveAttr func(optionalAttr) -// Returns the truth value of NOT x element-wise. -func LogicalNot(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogicalNot", - Input: []tf.Input{ - x, - }, +// MatrixSolveAdjoint sets the optional adjoint attribute to value. +// +// value: Boolean indicating whether to solve with `matrix` or its (block-wise) +// adjoint. +// If not specified, defaults to false +func MatrixSolveAdjoint(value bool) MatrixSolveAttr { + return func(m optionalAttr) { + m["adjoint"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// 3D real-valued fast Fourier transform. -// -// Computes the 3-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 3 dimensions of `input`. 
-// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. +// Solves systems of linear equations. // -// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. +// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is +// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix +// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +// If `adjoint` is `True` then each output matrix satisfies +// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. // // Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the their 3D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. +// matrix: Shape is `[..., M, M]`. +// rhs: Shape is `[..., M, K]`. // -// @compatibility(numpy) -// Equivalent to np.fft.rfftn with 3 dimensions. -// @end_compatibility -func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// Returns Shape is `[..., M, K]`. 
+func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RFFT3D", + Type: "MatrixSolve", Input: []tf.Input{ - input, fft_length, + matrix, rhs, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorArrayV3Attr is an optional argument to TensorArrayV3. -type TensorArrayV3Attr func(optionalAttr) - -// TensorArrayV3ElementShape sets the optional element_shape attribute to value. -// -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} +// ResourceApplyKerasMomentumAttr is an optional argument to ResourceApplyKerasMomentum. +type ResourceApplyKerasMomentumAttr func(optionalAttr) -// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. +// ResourceApplyKerasMomentumUseLocking sets the optional use_locking attribute to value. // -// value: A boolean that determines whether writes to the TensorArray -// are allowed to grow the size. By default, this is not allowed. +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. // If not specified, defaults to false -func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { - return func(m optionalAttr) { - m["dynamic_size"] = value - } -} - -// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. -// -// value: If true (default), Tensors in the TensorArray are cleared -// after being read. 
This disables multiple read semantics but allows early -// release of memory. -// If not specified, defaults to true -func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { +func ResourceApplyKerasMomentumUseLocking(value bool) ResourceApplyKerasMomentumAttr { return func(m optionalAttr) { - m["clear_after_read"] = value + m["use_locking"] = value } } -// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. +// ResourceApplyKerasMomentumUseNesterov sets the optional use_nesterov attribute to value. // -// value: If true (default is false), then all -// elements in the TensorArray will be expected to have have identical shapes. -// This allows certain behaviors, like dynamically checking for -// consistent shapes on write, and being able to fill in properly -// shaped zero tensors on stack -- even if the element_shape attribute -// is not fully defined. +// value: If `True`, the tensor passed to compute grad will be +// var + momentum * accum, so in the end, the var you get is actually +// var + momentum * accum. // If not specified, defaults to false -func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { +func ResourceApplyKerasMomentumUseNesterov(value bool) ResourceApplyKerasMomentumAttr { return func(m optionalAttr) { - m["identical_element_shapes"] = value + m["use_nesterov"] = value } } -// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. +// Update '*var' according to the momentum scheme. Set use_nesterov = True if you // -// value: Overrides the name used for the temporary tensor_array -// resource. Default value is the name of the 'TensorArray' op (which -// is guaranteed unique). -// If not specified, defaults to "" -func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr { - return func(m optionalAttr) { - m["tensor_array_name"] = value - } -} - -// An array of Tensors of given size. +// want to use Nesterov momentum. 
// -// Write data via Write and read via Read or Pack. +// accum = accum * momentum - lr * grad +// var += accum // // Arguments: -// size: The size of the array. -// dtype: The type of the elements on the tensor_array. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// grad: The gradient. +// momentum: Momentum. Must be a scalar. // -// Returns The handle to the TensorArray.A scalar used to control gradient flow. -func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) { +// Returns the created operation. +func ResourceApplyKerasMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyKerasMomentumAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayV3", + Type: "ResourceApplyKerasMomentum", Input: []tf.Input{ - size, + var_, accum, lr, grad, momentum, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// Runs multiple additive regression ensemble predictors on input instances and -// -// computes the logits. It is designed to be used during prediction. -// It traverses all the trees and calculates the final score for each instance. -// -// Arguments: -// -// bucketized_features: A list of rank 1 Tensors containing bucket id for each -// feature. -// logits_dimension: scalar, dimension of the logits, to be used for partial logits -// shape. -// -// Returns Output rank 2 Tensor containing logits for each example. 
-func BoostedTreesPredict(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (logits tf.Output) { +// Return a tensor with the same shape and contents as the input tensor or value. +func Identity(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"logits_dimension": logits_dimension} opspec := tf.OpSpec{ - Type: "BoostedTreesPredict", + Type: "Identity", Input: []tf.Input{ - tree_ensemble_handle, tf.OutputList(bucketized_features), + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Elementwise computes the bitwise OR of `x` and `y`. +// Computes arctangent of `y/x` element-wise, respecting signs of the arguments. // -// The result will have those bits set, that are set in `x`, `y` or both. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// This is the angle \( \theta \in [-\pi, \pi] \) such that +// \[ x = r \cos(\theta) \] +// and +// \[ y = r \sin(\theta) \] +// where \(r = \sqrt(x^2 + y^2) \). +func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BitwiseOr", + Type: "Atan2", Input: []tf.Input{ - x, y, + y, x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. -type MatrixSolveLsAttr func(optionalAttr) - -// MatrixSolveLsFast sets the optional fast attribute to value. -// If not specified, defaults to true -func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { - return func(m optionalAttr) { - m["fast"] = value - } -} - -// Solves one or more linear least-squares problems. -// -// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form real or complex matrices of size `[M, N]`. 
`Rhs` is a tensor of the same -// type as `matrix` and shape `[..., M, K]`. -// The output is a tensor shape `[..., N, K]` where each output matrix solves -// each of the equations -// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` -// in the least squares sense. -// -// We use the following notation for (complex) matrix and right-hand sides -// in the batch: -// -// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), -// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), -// `output`=\\(X \in \mathbb{C}^{n \times k}\\), -// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). -// -// If `fast` is `True`, then the solution is computed by solving the normal -// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then -// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares -// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). -// If \\(m \lt n\\) then `output` is computed as -// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the -// minimum-norm solution to the under-determined linear system, i.e. -// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), -// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable -// when \\(A\\) is numerically full rank and has a condition number -// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is -// sufficiently large. +// Updates specified rows with values in `v`. // -// If `fast` is `False` an algorithm based on the numerically robust complete -// orthogonal decomposition is used. This computes the minimum-norm -// least-squares solution, even when \\(A\\) is rank deficient. This path is -// typically 6-7 times slower than the fast path. If `fast` is `False` then -// `l2_regularizer` is ignored. +// Computes `x[i, :] = v; return x`. // // Arguments: -// matrix: Shape is `[..., M, N]`. -// rhs: Shape is `[..., M, K]`. 
-// l2_regularizer: Scalar tensor. -// -// @compatibility(numpy) -// Equivalent to np.linalg.lstsq -// @end_compatibility +// x: A tensor of type `T`. +// i: A vector. Indices into the left-most dimension of `x`. +// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. // -// Returns Shape is `[..., N, K]`. -func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { +// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. +func InplaceUpdate(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MatrixSolveLs", + Type: "InplaceUpdate", Input: []tf.Input{ - matrix, rhs, l2_regularizer, + x, i, v, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MaxPool3DAttr is an optional argument to MaxPool3D. -type MaxPool3DAttr func(optionalAttr) +// OutfeedDequeueTupleAttr is an optional argument to OutfeedDequeueTuple. +type OutfeedDequeueTupleAttr func(optionalAttr) -// MaxPool3DDataFormat sets the optional data_format attribute to value. +// OutfeedDequeueTupleDeviceOrdinal sets the optional device_ordinal attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DDataFormat(value string) MaxPool3DAttr { +// value: The TPU device to use. This should be -1 when the Op +// is running on a TPU device, and >= 0 when the Op is running on the CPU +// device. 
+// If not specified, defaults to -1 +func OutfeedDequeueTupleDeviceOrdinal(value int64) OutfeedDequeueTupleAttr { return func(m optionalAttr) { - m["data_format"] = value + m["device_ordinal"] = value } } -// Performs 3D max pooling on the input. +// Retrieve multiple values from the computation outfeed. +// +// This operation will block indefinitely until data is available. Output `i` +// corresponds to XLA tuple element `i`. // // Arguments: -// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. +// dtypes: The element types of each element in `outputs`. +// shapes: The shapes of each tensor in `outputs`. // -// Returns The max pooled output tensor. -func MaxPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DAttr) (output tf.Output) { +// Returns A list of tensors that will be read from the outfeed. +func OutfeedDequeueTuple(scope *Scope, dtypes []tf.DataType, shapes []tf.Shape, optional ...OutfeedDequeueTupleAttr) (outputs []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"dtypes": dtypes, "shapes": shapes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPool3D", + Type: "OutfeedDequeueTuple", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("OutfeedDequeueTuple", err) + return + } + return outputs +} + +// Identity op for gradient debugging. 
+// +// This op is hidden from public in Python. It is used by TensorFlow Debugger to +// register gradient tensors for gradient debugging. +// This op operates on non-reference-type tensors. +func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DebugGradientIdentity", Input: []tf.Input{ input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. -type Conv3DBackpropInputAttr func(optionalAttr) +// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. +type ResourceSparseApplyAdadeltaAttr func(optionalAttr) -// Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to -func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { +// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { return func(m optionalAttr) { - m["dilations"] = value + m["use_locking"] = value } } -// Computes the gradients of 3-D convolution with respect to the input. -// -// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 +// var: Should be from a Variable(). // // Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. 
Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { +// +// accum: Should be from a Variable(). +// accum_update: : Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropInput", + Type: "ResourceSparseApplyAdadelta", Input: []tf.Input{ - input, filter, out_backprop, + var_, accum, accum_update, lr, rho, epsilon, grad, indices, }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// Returns which elements of x are NaN. +// +// @compatibility(numpy) +// Equivalent to np.isnan +// @end_compatibility +func IsNan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IsNan", + Input: []tf.Input{ + x, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } @@ -19040,80 +23674,206 @@ func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_s return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process. 
The hash function is a keyed hash function, where attribute `key` -// defines the key of the hash function. `key` is an array of 2 elements. +// MapUnstageAttr is an optional argument to MapUnstage. +type MapUnstageAttr func(optionalAttr) + +// MapUnstageCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// A strong hash is important when inputs may be malicious, e.g. URLs with -// additional components. Adversaries could try to make their inputs hash to the -// same bucket for a denial-of-service attack or to skew the results. A strong -// hash prevents this by making it difficult, if not infeasible, to compute inputs -// that hash to the same bucket. This comes at a cost of roughly 4x higher compute -// time than `tf.string_to_hash_bucket_fast`. +// REQUIRES: value >= 0 +func MapUnstageCapacity(value int64) MapUnstageAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapUnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. -// key: The key for the keyed hash function passed as a list of two uint64 -// elements. +// REQUIRES: value >= 0 +func MapUnstageMemoryLimit(value int64) MapUnstageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapUnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapUnstageContainer(value string) MapUnstageAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapUnstageSharedName sets the optional shared_name attribute to value. 
+// If not specified, defaults to "" +func MapUnstageSharedName(value string) MapUnstageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns the values associated with the key // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { +// from the underlying container. If the underlying container +// does not contain this key, the op will block until it does. +func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "StringToHashBucketStrong", + Type: "MapUnstage", Input: []tf.Input{ - input, + key, indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapUnstage", err) + return + } + return values } -// StringLengthAttr is an optional argument to StringLength. -type StringLengthAttr func(optionalAttr) - -// StringLengthUnit sets the optional unit attribute to value. +// An op enabling differentiation of TPU Embeddings. // -// value: The unit that is counted to compute string length. One of: `"BYTE"` (for -// the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8 -// encoded Unicode code points in each string). Results are undefined -// if `unit=UTF8_CHAR` and the `input` strings do not contain structurally -// valid UTF-8. 
-// If not specified, defaults to "BYTE" -func StringLengthUnit(value string) StringLengthAttr { - return func(m optionalAttr) { - m["unit"] = value +// This op simply returns its first input, which is assumed to have been sliced +// from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of +// this op, and its first argument being a trainable Variable, enables automatic +// differentiation of graphs containing embeddings via the TPU Embedding Python +// libraries. +// +// Arguments: +// embedding_variable: A trainable variable, enabling optimizers to find this op. +// sliced_activations: The embedding activations Tensor to return. +// table_id: The id of the table in the embedding layer configuration from which +// these activations were computed. +// lookup_id: Identifier of the set of embedding indices which produced these +// activations. +func TPUEmbeddingActivations(scope *Scope, embedding_variable tf.Output, sliced_activations tf.Output, table_id int64, lookup_id int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"table_id": table_id, "lookup_id": lookup_id} + opspec := tf.OpSpec{ + Type: "TPUEmbeddingActivations", + Input: []tf.Input{ + embedding_variable, sliced_activations, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// String lengths of `input`. +// BatchToSpace for 4-D tensors of type T. // -// Computes the length of each string given in the input tensor. +// This is a legacy version of the more general BatchToSpaceND. +// +// Rearranges (permutes) data from batch into blocks of spatial data, followed by +// cropping. This is the reverse transformation of SpaceToBatch. More specifically, +// this op outputs a copy of the input tensor where values from the `batch` +// dimension are moved in spatial blocks to the `height` and `width` dimensions, +// followed by cropping along the `height` and `width` dimensions. 
// // Arguments: -// input: The string for which to compute the length. +// input: 4-D tensor with shape +// `[batch*block_size*block_size, height_pad/block_size, width_pad/block_size, +// depth]`. Note that the batch size of the input tensor must be divisible by +// `block_size * block_size`. +// crops: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies +// how many elements to crop from the intermediate result across the spatial +// dimensions as follows: // -// Returns Integer tensor that has the same shape as `input`. The output contains the -// element-wise string lengths of `input`. -func StringLength(scope *Scope, input tf.Output, optional ...StringLengthAttr) (output tf.Output) { +// crops = [[crop_top, crop_bottom], [crop_left, crop_right]] +// +// +// Returns 4-D with shape `[batch, height, width, depth]`, where: +// +// height = height_pad - crop_top - crop_bottom +// width = width_pad - crop_left - crop_right +// +// The attr `block_size` must be greater than one. It indicates the block size. 
+// +// Some examples: +// +// (1) For the following input of shape `[4, 1, 1, 1]` and block_size of 2: +// +// ``` +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// (2) For the following input of shape `[4, 1, 1, 3]` and block_size of 2: +// +// ``` +// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 3]` and value: +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// (3) For the following input of shape `[4, 2, 2, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +// +// The output tensor has shape `[1, 4, 4, 1]` and value: +// +// ``` +// x = [[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]] +// ``` +// +// (4) For the following input of shape `[8, 1, 2, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], +// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] +// ``` +// +// The output tensor has shape `[2, 2, 4, 1]` and value: +// +// ``` +// x = [[[[1], [3]], [[5], [7]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"block_size": block_size} opspec := tf.OpSpec{ - Type: "StringLength", + Type: "BatchToSpace", Input: []tf.Input{ - input, + input, crops, }, Attrs: attrs, } @@ -19121,212 +23881,197 @@ func StringLength(scope *Scope, input tf.Output, optional ...StringLengthAttr) ( return 
op.Output(0) } -// ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. -type ResourceApplyProximalGradientDescentAttr func(optionalAttr) - -// ResourceApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. -// -// value: If True, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyProximalGradientDescentUseLocking(value bool) ResourceApplyProximalGradientDescentAttr { - return func(m optionalAttr) { - m["use_locking"] = value +// Produces a summary of any statistics recorded by the given statistics manager. +func ExperimentalStatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ExperimentalStatsAggregatorSummary", + Input: []tf.Input{ + iterator, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Update '*var' as FOBOS algorithm with fixed learning rate. -// -// prox_v = var - alpha * delta -// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// Makes a new iterator from the given `dataset` and stores it in `iterator`. // -// Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// delta: The change. +// This operation may be executed multiple times. Each execution will reset the +// iterator in `iterator` to the first element of `dataset`. // // Returns the created operation. 
-func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, delta tf.Output, optional ...ResourceApplyProximalGradientDescentAttr) (o *tf.Operation) { +func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceApplyProximalGradientDescent", + Type: "MakeIterator", Input: []tf.Input{ - var_, alpha, l1, l2, delta, + dataset, iterator, }, - Attrs: attrs, } return scope.AddOperation(opspec) } -// Returns 0 if the denominator is zero. +// Component-wise divides a SparseTensor by a dense Tensor. // +// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not +// the other direction. // -// *NOTE*: `DivNoNan` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func DivNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. +// +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DivNoNan", + Type: "SparseDenseCwiseDiv", Input: []tf.Input{ - x, y, + sp_indices, sp_values, sp_shape, dense, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// UnicodeDecodeWithOffsetsAttr is an optional argument to UnicodeDecodeWithOffsets. 
-type UnicodeDecodeWithOffsetsAttr func(optionalAttr) - -// UnicodeDecodeWithOffsetsErrors sets the optional errors attribute to value. +// Creates a dataset that batches and pads `batch_size` elements from the input. // -// value: Error handling policy when there is invalid formatting found in the input. -// The value of 'strict' will cause the operation to produce a InvalidArgument -// error on any invalid input formatting. A value of 'replace' (the default) will -// cause the operation to replace any invalid formatting in the input with the -// `replacement_char` codepoint. A value of 'ignore' will cause the operation to -// skip any invalid formatting in the input and produce no corresponding output -// character. -// If not specified, defaults to "replace" -func UnicodeDecodeWithOffsetsErrors(value string) UnicodeDecodeWithOffsetsAttr { - return func(m optionalAttr) { - m["errors"] = value +// Arguments: +// +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// padded_shapes: A list of int64 tensors representing the desired padded shapes +// of the corresponding output components. These shapes may be partially +// specified, using `-1` to indicate that a particular dimension should be +// padded to the maximum size of all batch elements. +// padding_values: A list of scalars containing the padding value to use for +// each of the outputs. 
+// +func PaddedBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "PaddedBatchDataset", + Input: []tf.Input{ + input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// UnicodeDecodeWithOffsetsReplacementChar sets the optional replacement_char attribute to value. +// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. +type ResourceApplyMomentumAttr func(optionalAttr) + +// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. // -// value: The replacement character codepoint to be used in place of any invalid -// formatting in the input when `errors='replace'`. Any valid unicode codepoint may -// be used. The default value is the default unicode replacement character is -// 0xFFFD or U+65533.) -// If not specified, defaults to 65533 -func UnicodeDecodeWithOffsetsReplacementChar(value int64) UnicodeDecodeWithOffsetsAttr { +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr { return func(m optionalAttr) { - m["replacement_char"] = value + m["use_locking"] = value } } -// UnicodeDecodeWithOffsetsReplaceControlCharacters sets the optional replace_control_characters attribute to value. +// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. // -// value: Whether to replace the C0 control characters (00-1F) with the -// `replacement_char`. Default is false. 
+// value: If `True`, the tensor passed to compute grad will be +// var - lr * momentum * accum, so in the end, the var you get is actually +// var - lr * momentum * accum. // If not specified, defaults to false -func UnicodeDecodeWithOffsetsReplaceControlCharacters(value bool) UnicodeDecodeWithOffsetsAttr { +func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr { return func(m optionalAttr) { - m["replace_control_characters"] = value + m["use_nesterov"] = value } } -// Decodes each string in `input` into a sequence of Unicode code points. -// -// The character codepoints for all strings are returned using a single vector -// `char_values`, with strings expanded to characters in row-major order. -// Similarly, the character start byte offsets are returned using a single vector -// `char_to_byte_starts`, with strings expanded in row-major order. +// Update '*var' according to the momentum scheme. Set use_nesterov = True if you // -// The `row_splits` tensor indicates where the codepoints and start offsets for -// each input string begin and end within the `char_values` and -// `char_to_byte_starts` tensors. In particular, the values for the `i`th -// string (in row-major order) are stored in the slice -// `[row_splits[i]:row_splits[i+1]]`. Thus: +// want to use Nesterov momentum. // -// * `char_values[row_splits[i]+j]` is the Unicode codepoint for the `j`th -// character in the `i`th string (in row-major order). -// * `char_to_bytes_starts[row_splits[i]+j]` is the start byte offset for the `j`th -// character in the `i`th string (in row-major order). -// * `row_splits[i+1] - row_splits[i]` is the number of characters in the `i`th -// string (in row-major order). +// accum = accum * momentum + grad +// var -= lr * accum // // Arguments: -// input: The text to be decoded. Can have any shape. Note that the output is flattened -// to a vector of char values. -// input_encoding: Text encoding of the input strings. 
This is any of the encodings supported -// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// grad: The gradient. +// momentum: Momentum. Must be a scalar. // -// Returns A 1D int32 tensor containing the row splits.A 1D int32 Tensor containing the decoded codepoints.A 1D int32 Tensor containing the byte index in the input string where each -// character in `char_values` starts. -func UnicodeDecodeWithOffsets(scope *Scope, input tf.Output, input_encoding string, optional ...UnicodeDecodeWithOffsetsAttr) (row_splits tf.Output, char_values tf.Output, char_to_byte_starts tf.Output) { +// Returns the created operation. +func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"input_encoding": input_encoding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UnicodeDecodeWithOffsets", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Returns x - y element-wise. -// -// *NOTE*: `Subtract` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } opspec := tf.OpSpec{ - Type: "Sub", + Type: "ResourceApplyMomentum", Input: []tf.Input{ - x, y, + var_, accum, lr, grad, momentum, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// RandomPoissonAttr is an optional argument to RandomPoisson. 
-type RandomPoissonAttr func(optionalAttr) - -// RandomPoissonSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed(value int64) RandomPoissonAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} +// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. +type MaxPoolGradGradAttr func(optionalAttr) -// RandomPoissonSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed2(value int64) RandomPoissonAttr { +// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { return func(m optionalAttr) { - m["seed2"] = value + m["data_format"] = value } } -// Use RandomPoissonV2 instead. +// Computes second-order gradients of the maxpooling function. // -// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 -func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. the input to `max_pool`. 
+func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomPoisson", + Type: "MaxPoolGradGrad", Input: []tf.Input{ - shape, rate, + orig_input, orig_output, grad, }, Attrs: attrs, } @@ -19334,479 +24079,397 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra return op.Output(0) } -// Returns the max of x and y (i.e. x > y ? x : y) element-wise. +// Returns the last element of the input list as well as a list with all but that element. // -// *NOTE*: `Maximum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Fails if the list is empty. +// +// input_handle: the input list +// tensor: the withdrawn last element of the list +// element_dtype: the type of elements in the list +// element_shape: the shape of the output tensor +func TensorListPopBack(scope *Scope, input_handle tf.Output, element_shape tf.Output, element_dtype tf.DataType) (output_handle tf.Output, tensor tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "Maximum", + Type: "TensorListPopBack", Input: []tf.Input{ - x, y, + input_handle, element_shape, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Computes softmax cross entropy cost and gradients to backpropagate. +// Determine the script codes of a given tensor of Unicode integer code points. // -// Inputs are the logits, not probabilities. 
+// This operation converts Unicode code points to script codes corresponding to +// each code point. Script codes correspond to International Components for +// Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html. +// Returns -1 (USCRIPT_INVALID_CODE) for invalid codepoints. Output shape will +// match input shape. // // Arguments: -// features: batch_size x num_classes matrix -// labels: batch_size x num_classes matrix -// The caller must ensure that each batch of labels represents a valid -// probability distribution. +// input: A Tensor of int32 Unicode code points. // -// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). -func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { +// Returns A Tensor of int32 script codes corresponding to each input code point. +func UnicodeScript(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SoftmaxCrossEntropyWithLogits", + Type: "UnicodeScript", Input: []tf.Input{ - features, labels, + input, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// ReduceJoinAttr is an optional argument to ReduceJoin. -type ReduceJoinAttr func(optionalAttr) - -// ReduceJoinKeepDims sets the optional keep_dims attribute to value. -// -// value: If `True`, retain reduced dimensions with length `1`. -// If not specified, defaults to false -func ReduceJoinKeepDims(value bool) ReduceJoinAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// ReduceJoinSeparator sets the optional separator attribute to value. -// -// value: The separator to use when joining. 
-// If not specified, defaults to "" -func ReduceJoinSeparator(value string) ReduceJoinAttr { - return func(m optionalAttr) { - m["separator"] = value - } + return op.Output(0) } -// Joins a string Tensor across the given dimensions. +// Creates a sequence of numbers. // -// Computes the string join across dimensions in the given string Tensor of shape -// `[\\(d_0, d_1, ..., d_{n-1}\\)]`. Returns a new Tensor created by joining the input -// strings with the given separator (default: empty string). Negative indices are -// counted backwards from the end, with `-1` being equivalent to `n - 1`. If -// indices are not specified, joins across all dimensions beginning from `n - 1` -// through `0`. +// This operation creates a sequence of numbers that begins at `start` and +// extends by increments of `delta` up to but not including `limit`. // // For example: // -// ```python -// # tensor `a` is [["a", "b"], ["c", "d"]] -// tf.reduce_join(a, 0) ==> ["ac", "bd"] -// tf.reduce_join(a, 1) ==> ["ab", "cd"] -// tf.reduce_join(a, -2) = tf.reduce_join(a, 0) ==> ["ac", "bd"] -// tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==> ["ab", "cd"] -// tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]] -// tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]] -// tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"] -// tf.reduce_join(a, [0, 1]) ==> "acbd" -// tf.reduce_join(a, [1, 0]) ==> "abcd" -// tf.reduce_join(a, []) ==> [["a", "b"], ["c", "d"]] -// tf.reduce_join(a) = tf.reduce_join(a, [1, 0]) ==> "abcd" +// ``` +// # 'start' is 3 +// # 'limit' is 18 +// # 'delta' is 3 +// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] // ``` // // Arguments: -// inputs: The input to be joined. All reduced indices must have non-zero size. -// reduction_indices: The dimensions to reduce over. Dimensions are reduced in the -// order specified. Omitting `reduction_indices` is equivalent to passing -// `[n-1, n-2, ..., 0]`. Negative indices from `-n` to `-1` are supported. 
+// start: 0-D (scalar). First entry in the sequence. +// limit: 0-D (scalar). Upper limit of sequence, exclusive. +// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. // -// Returns Has shape equal to that of the input with reduced dimensions removed or -// set to `1` depending on `keep_dims`. -func ReduceJoin(scope *Scope, inputs tf.Output, reduction_indices tf.Output, optional ...ReduceJoinAttr) (output tf.Output) { +// Returns 1-D. +func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ReduceJoin", + Type: "Range", Input: []tf.Input{ - inputs, reduction_indices, + start, limit, delta, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DenseToSparseSetOperationAttr is an optional argument to DenseToSparseSetOperation. -type DenseToSparseSetOperationAttr func(optionalAttr) +// MaxPoolGradGradWithArgmaxAttr is an optional argument to MaxPoolGradGradWithArgmax. +type MaxPoolGradGradWithArgmaxAttr func(optionalAttr) -// DenseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func DenseToSparseSetOperationValidateIndices(value bool) DenseToSparseSetOperationAttr { +// MaxPoolGradGradWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. +// +// value: Whether to include batch dimension in flattened index of `argmax`. +// If not specified, defaults to false +func MaxPoolGradGradWithArgmaxIncludeBatchInIndex(value bool) MaxPoolGradGradWithArgmaxAttr { return func(m optionalAttr) { - m["validate_indices"] = value + m["include_batch_in_index"] = value } } -// Applies set operation along last dimension of `Tensor` and `SparseTensor`. -// -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. 
-// -// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, -// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same -// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. -// -// If `validate_indices` is `True`, this op validates the order and range of `set2` -// indices. -// -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. +// Computes second-order gradients of the maxpooling function. // // Arguments: -// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must -// be the same as the 1st `n-1` dimensions of `set1`, `result_shape[n]` is the -// max set size across `n-1` dimensions. -// +// input: The original input. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the +// input of `max_pool`. +// argmax: The indices of the maximum values chosen for each output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. 
`result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func DenseToSparseSetOperation(scope *Scope, set1 tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...DenseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { +// Returns Gradients of gradients w.r.t. the input of `max_pool`. +func MaxPoolGradGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradWithArgmaxAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"set_operation": set_operation} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DenseToSparseSetOperation", + Type: "MaxPoolGradGradWithArgmax", Input: []tf.Input{ - set1, set2_indices, set2_values, set2_shape, + input, grad, argmax, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// L2 Loss. +// Return a slice from 'input'. // -// Computes half the L2 norm of a tensor without the `sqrt`: +// The output tensor is a tensor with dimensions described by 'size' +// whose values are extracted from 'input' starting at the offsets in +// 'begin'. // -// output = sum(t ** 2) / 2 +// *Requirements*: +// 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) // // Arguments: -// t: Typically 2-D, but may have any dimensions. // -// Returns 0-D. -func L2Loss(scope *Scope, t tf.Output) (output tf.Output) { +// begin: begin[i] specifies the offset into the 'i'th dimension of +// 'input' to slice from. +// size: size[i] specifies the number of elements of the 'i'th dimension +// of 'input' to slice. 
If size[i] is -1, all remaining elements in dimension +// i are included in the slice (i.e. this is equivalent to setting +// size[i] = input.dim_size(i) - begin[i]). +func Slice(scope *Scope, input tf.Output, begin tf.Output, size tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "L2Loss", + Type: "Slice", Input: []tf.Input{ - t, + input, begin, size, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes cos of x element-wise. -func Cos(scope *Scope, x tf.Output) (y tf.Output) { +// Compute the Hurwitz zeta function \\(\zeta(x, q)\\). +// +// The Hurwitz zeta function is defined as: +// +// +// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\) +func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Cos", + Type: "Zeta", Input: []tf.Input{ - x, + x, q, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// FusedBatchNormGradAttr is an optional argument to FusedBatchNormGrad. -type FusedBatchNormGradAttr func(optionalAttr) - -// FusedBatchNormGradEpsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormGradEpsilon(value float32) FusedBatchNormGradAttr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format for y_backprop, x, x_backprop. -// Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormGradDataFormat(value string) FusedBatchNormGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormGradIsTraining sets the optional is_training attribute to value. -// -// value: A bool value to indicate the operation is for training (default) -// or inference. 
-// If not specified, defaults to true -func FusedBatchNormGradIsTraining(value bool) FusedBatchNormGradAttr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Gradient for batch normalization. +// Returns the cardinality of `input_dataset`. // -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// Returns the cardinality of `input_dataset`. // // Arguments: -// y_backprop: A 4D Tensor for the gradient with respect to y. -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch -// mean to be reused in gradient computation. When is_training is -// False, a 1D Tensor for the population mean to be reused in both -// 1st and 2nd order gradient computation. -// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch -// variance (inverted variance in the cuDNN case) to be reused in -// gradient computation. When is_training is False, a 1D Tensor -// for the population variance to be reused in both 1st and 2nd -// order gradient computation. +// input_dataset: A variant tensor representing the dataset to return cardinality for. // -// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input -// in FusedBatchNorm. -func FusedBatchNormGrad(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradAttr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { +// Returns The cardinality of `input_dataset`. 
Named constants are used to represent +// infinite and unknown cardinality. +func ExperimentalDatasetCardinality(scope *Scope, input_dataset tf.Output) (cardinality tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "FusedBatchNormGrad", + Type: "ExperimentalDatasetCardinality", Input: []tf.Input{ - y_backprop, x, scale, reserve_space_1, reserve_space_2, + input_dataset, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return op.Output(0) } -// TopKAttr is an optional argument to TopK. -type TopKAttr func(optionalAttr) +// TakeManySparseFromTensorsMapAttr is an optional argument to TakeManySparseFromTensorsMap. +type TakeManySparseFromTensorsMapAttr func(optionalAttr) -// TopKSorted sets the optional sorted attribute to value. +// TakeManySparseFromTensorsMapContainer sets the optional container attribute to value. // -// value: If true the resulting `k` elements will be sorted by the values in -// descending order. -// If not specified, defaults to true -func TopKSorted(value bool) TopKAttr { +// value: The container name for the `SparseTensorsMap` read by this op. +// If not specified, defaults to "" +func TakeManySparseFromTensorsMapContainer(value string) TakeManySparseFromTensorsMapAttr { return func(m optionalAttr) { - m["sorted"] = value + m["container"] = value } } -// Finds values and indices of the `k` largest elements for the last dimension. +// TakeManySparseFromTensorsMapSharedName sets the optional shared_name attribute to value. // -// DEPRECATED at GraphDef version 7: Use TopKV2 instead +// value: The shared name for the `SparseTensorsMap` read by this op. +// It should not be blank; rather the `shared_name` or unique Operation name +// of the Op that created the original `SparseTensorsMap` should be used. 
+// If not specified, defaults to "" +func TakeManySparseFromTensorsMapSharedName(value string) TakeManySparseFromTensorsMapAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Read `SparseTensors` from a `SparseTensorsMap` and concatenate them. // -// If the input is a vector (rank-1), finds the `k` largest entries in the vector -// and outputs their values and indices as vectors. Thus `values[j]` is the -// `j`-th largest entry in `input`, and its index is `indices[j]`. +// The input `sparse_handles` must be an `int64` matrix of shape `[N, 1]` where +// `N` is the minibatch size and the rows correspond to the output handles of +// `AddSparseToTensorsMap` or `AddManySparseToTensorsMap`. The ranks of the +// original `SparseTensor` objects that went into the given input ops must all +// match. When the final `SparseTensor` is created, it has rank one +// higher than the ranks of the incoming `SparseTensor` objects +// (they have been concatenated along a new row dimension on the left). // -// For matrices (resp. higher rank input), computes the top `k` entries in each -// row (resp. vector along the last dimension). Thus, +// The output `SparseTensor` object's shape values for all dimensions but the +// first are the max across the input `SparseTensor` objects' shape values +// for the corresponding dimensions. Its first shape value is `N`, the minibatch +// size. // -// values.shape = indices.shape = input.shape[:-1] + [k] +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering. // -// If two elements are equal, the lower-index element appears first. +// For example, if the handles represent an input, which is a `[2, 3]` matrix +// representing two original `SparseTensor` objects: // -// If `k` varies dynamically, use `TopKV2` below. 
+// ``` +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// ``` +// +// and +// +// ``` +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// ``` +// +// then the final `SparseTensor` will be: +// +// ``` +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// ``` // // Arguments: -// input: 1-D or higher with last dimension at least `k`. -// k: Number of top elements to look for along the last dimension (along each -// row for matrices). +// sparse_handles: 1-D, The `N` serialized `SparseTensor` objects. +// Shape: `[N]`. +// dtype: The `dtype` of the `SparseTensor` objects stored in the +// `SparseTensorsMap`. // -// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. -func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values tf.Output, indices tf.Output) { +// Returns 2-D. The `indices` of the minibatch `SparseTensor`.1-D. The `values` of the minibatch `SparseTensor`.1-D. The `shape` of the minibatch `SparseTensor`. +func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype tf.DataType, optional ...TakeManySparseFromTensorsMapAttr) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"k": k} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TopK", + Type: "TakeManySparseFromTensorsMap", Input: []tf.Input{ - input, + sparse_handles, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0), op.Output(1), op.Output(2) } -// The gradient operator for the SparseAdd op. -// -// The SparseAdd op calculates A + B, where A, B, and the sum are all represented -// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. 
-// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty -// values of A and B. -// -// Arguments: -// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to -// the non-empty values of the sum. -// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. -// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. -// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size -// `[nnz(sum), ndims]`. +// NonDeterministicIntsAttr is an optional argument to NonDeterministicInts. +type NonDeterministicIntsAttr func(optionalAttr) + +// NonDeterministicIntsDtype sets the optional dtype attribute to value. // -// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the -// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the -// non-empty values of B. -func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseAddGrad", - Input: []tf.Input{ - backprop_val_grad, a_indices, b_indices, sum_indices, - }, +// value: The type of the output. +// If not specified, defaults to DT_INT64 +func NonDeterministicIntsDtype(value tf.DataType) NonDeterministicIntsAttr { + return func(m optionalAttr) { + m["dtype"] = value } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) } -// Returns a list of tensors with the same shapes and contents as the input +// Non-deterministically generates some integers. // -// tensors. -// -// This op can be used to override the gradient for complicated functions. For -// example, suppose y = f(x) and we wish to apply a custom function g for backprop -// such that dx = g(dy). In Python, +// This op may use some OS-provided source of non-determinism (e.g. an RNG), so each execution will give different results. 
// -// ```python -// with tf.get_default_graph().gradient_override_map( -// {'IdentityN': 'OverrideGradientWithG'}): -// y, _ = identity_n([f(x), x]) +// Arguments: +// shape: The shape of the output tensor. // -// @tf.RegisterGradient('OverrideGradientWithG') -// def ApplyG(op, dy, _): -// return [None, g(dy)] # Do not backprop to f(x). -// ``` -func IdentityN(scope *Scope, input []tf.Output) (output []tf.Output) { +// Returns Non-deterministic integer values with specified shape. +func NonDeterministicInts(scope *Scope, shape tf.Output, optional ...NonDeterministicIntsAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IdentityN", + Type: "NonDeterministicInts", Input: []tf.Input{ - tf.OutputList(input), + shape, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("IdentityN", err) - return - } - return output + return op.Output(0) } -// ResourceApplyCenteredRMSPropAttr is an optional argument to ResourceApplyCenteredRMSProp. -type ResourceApplyCenteredRMSPropAttr func(optionalAttr) +// ResourceSparseApplyKerasMomentumAttr is an optional argument to ResourceSparseApplyKerasMomentum. +type ResourceSparseApplyKerasMomentumAttr func(optionalAttr) -// ResourceApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. +// ResourceSparseApplyKerasMomentumUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var, mg, ms, and mom tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less // contention. 
// If not specified, defaults to false -func ResourceApplyCenteredRMSPropUseLocking(value bool) ResourceApplyCenteredRMSPropAttr { +func ResourceSparseApplyKerasMomentumUseLocking(value bool) ResourceSparseApplyKerasMomentumAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the centered RMSProp algorithm. -// -// The centered RMSProp algorithm uses an estimate of the centered second moment -// (i.e., the variance) for normalization, as opposed to regular RMSProp, which -// uses the (uncentered) second moment. This often helps with training, but is -// slightly more expensive in terms of computation and memory. +// ResourceSparseApplyKerasMomentumUseNesterov sets the optional use_nesterov attribute to value. // -// Note that in dense implementation of this algorithm, mg, ms, and mom will -// update even if the grad is zero, but in this sparse implementation, mg, ms, -// and mom will not update in iterations during which the grad is zero. +// value: If `True`, the tensor passed to compute grad will be +// var + momentum * accum, so in the end, the var you get is actually +// var + momentum * accum. +// If not specified, defaults to false +func ResourceSparseApplyKerasMomentumUseNesterov(value bool) ResourceSparseApplyKerasMomentumAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update relevant entries in '*var' and '*accum' according to the momentum scheme. // -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// mean_grad = decay * mean_grad + (1-decay) * gradient +// Set use_nesterov = True if you want to use Nesterov momentum. 
// -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// That is for rows we have grad for, we update var and accum as follows: // -// mg <- rho * mg_{t-1} + (1-rho) * grad -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) -// var <- var - mom +// accum = accum * momentum - lr * grad +// var += accum // // Arguments: // var_: Should be from a Variable(). -// mg: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. // grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// momentum: Momentum. Must be a scalar. // // Returns the created operation. -func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyCenteredRMSPropAttr) (o *tf.Operation) { +func ResourceSparseApplyKerasMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyKerasMomentumAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -19815,63 +24478,53 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyCenteredRMSProp", + Type: "ResourceSparseApplyKerasMomentum", Input: []tf.Input{ - var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, + var_, accum, lr, grad, indices, momentum, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. 
-type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) +// ResourceApplyAdamWithAmsgradAttr is an optional argument to ResourceApplyAdamWithAmsgrad. +type ResourceApplyAdamWithAmsgradAttr func(optionalAttr) -// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. +// ResourceApplyAdamWithAmsgradUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var, mg, ms, and mom tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { +func ResourceApplyAdamWithAmsgradUseLocking(value bool) ResourceApplyAdamWithAmsgradAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the centered RMSProp algorithm. -// -// The centered RMSProp algorithm uses an estimate of the centered second moment -// (i.e., the variance) for normalization, as opposed to regular RMSProp, which -// uses the (uncentered) second moment. This often helps with training, but is -// slightly more expensive in terms of computation and memory. -// -// Note that in dense implementation of this algorithm, mg, ms, and mom will -// update even if the grad is zero, but in this sparse implementation, mg, ms, -// and mom will not update in iterations during which the grad is zero. -// -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// mean_grad = decay * mean_grad + (1-decay) * gradient -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// Update '*var' according to the Adam algorithm. 
// -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom +// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ +// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ +// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ +// $$vhat_t := max{vhat_{t-1}, v_t}$$ +// $$variable := variable - lr_t * m_t / (\sqrt{vhat_t} + \epsilon)$$ // // Arguments: // var_: Should be from a Variable(). -// mg: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// vhat: Should be from a Variable(). +// beta1_power: Must be a scalar. +// beta2_power: Must be a scalar. // lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. // epsilon: Ridge term. Must be a scalar. // grad: The gradient. -// indices: A vector of indices into the first dimension of var, ms and mom. // // Returns the created operation. 
-func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { +func ResourceApplyAdamWithAmsgrad(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, vhat tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamWithAmsgradAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -19880,221 +24533,336 @@ func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Outp a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyCenteredRMSProp", + Type: "ResourceApplyAdamWithAmsgrad", Input: []tf.Input{ - var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, + var_, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Creates a dataset that batches `batch_size` elements from `input_dataset`. -// -// Arguments: +// MapUnstageNoKeyAttr is an optional argument to MapUnstageNoKey. +type MapUnstageNoKeyAttr func(optionalAttr) + +// MapUnstageNoKeyCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. +// REQUIRES: value >= 0 +func MapUnstageNoKeyCapacity(value int64) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. 
+// If not specified, defaults to 0 // +// REQUIRES: value >= 0 +func MapUnstageNoKeyMemoryLimit(value int64) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapUnstageNoKeyContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapUnstageNoKeyContainer(value string) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapUnstageNoKeySharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapUnstageNoKeySharedName(value string) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns a random (key, value) // -func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// from the underlying container. If the underlying container +// does not contain elements, the op will block until it does. +func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BatchDataset", + Type: "MapUnstageNoKey", Input: []tf.Input{ - input_dataset, batch_size, + indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + key = op.Output(idx) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapUnstageNoKey", err) + return + } + return key, values } -// RandomPoissonV2Attr is an optional argument to RandomPoissonV2. 
-type RandomPoissonV2Attr func(optionalAttr) +// HashTableV2Attr is an optional argument to HashTableV2. +type HashTableV2Attr func(optionalAttr) -// RandomPoissonV2Seed sets the optional seed attribute to value. +// HashTableV2Container sets the optional container attribute to value. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr { +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func HashTableV2Container(value string) HashTableV2Attr { return func(m optionalAttr) { - m["seed"] = value + m["container"] = value } } -// RandomPoissonV2Seed2 sets the optional seed2 attribute to value. +// HashTableV2SharedName sets the optional shared_name attribute to value. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr { +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func HashTableV2SharedName(value string) HashTableV2Attr { return func(m optionalAttr) { - m["seed2"] = value + m["shared_name"] = value } } -// RandomPoissonV2Dtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_INT64 -func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr { +// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. 
+// If not specified, defaults to false +func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { return func(m optionalAttr) { - m["dtype"] = value + m["use_node_name_sharing"] = value } } -// Outputs random values from the Poisson distribution(s) described by rate. -// -// This op uses two algorithms, depending on rate. If rate >= 10, then -// the algorithm by Hormann is used to acquire samples via -// transformation-rejection. -// See http://www.sciencedirect.com/science/article/pii/0167668793909974. +// Creates a non-initialized hash table. // -// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform -// random variables. -// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer -// Programming, Volume 2. Addison Wesley +// This op creates a hash table, specifying the type of its keys and values. +// Before using the table you will have to initialize it. After initialization the +// table will be immutable. // // Arguments: -// shape: 1-D integer tensor. Shape of independent samples to draw from each -// distribution described by the shape parameters given in rate. -// rate: A tensor in which each scalar is a "rate" parameter describing the -// associated poisson distribution. +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. // -// Returns A tensor with shape `shape + shape(rate)`. Each slice -// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for -// `rate[i0, i1, ...iN]`. -func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) { +// Returns Handle to a table. 
+func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomPoissonV2", - Input: []tf.Input{ - shape, rate, - }, + Type: "HashTableV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. -type DecodeAndCropJpegAttr func(optionalAttr) +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingMomentumParametersGradAccumDebug. +type RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) -// DecodeAndCropJpegChannels sets the optional channels attribute to value. +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["channels"] = value + m["table_id"] = value } } -// DecodeAndCropJpegRatio sets the optional ratio attribute to value. -// -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["ratio"] = value + m["table_name"] = value } } -// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. +// Retrieve Momentum embedding parameters with debug support. // -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). -// If not specified, defaults to true -func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the Momentum optimization algorithm.Parameter momenta updated by the Momentum optimization algorithm.Parameter gradient_accumulators updated by the Momentum optimization algorithm. +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingMomentumParametersGradAccumDebug", + + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. +// Enqueue a Tensor on the computation outfeed. 
// -// value: If true try to recover an image from truncated input. -// If not specified, defaults to false -func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value +// Arguments: +// input: A tensor that will be inserted into the outfeed queue. +// +// Returns the created operation. +func OutfeedEnqueue(scope *Scope, input tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "OutfeedEnqueue", + Input: []tf.Input{ + input, + }, } + return scope.AddOperation(opspec) } -// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. +// Outputs a `Summary` protocol buffer with a histogram. // -// value: The minimum required fraction of lines before a truncated -// input is accepted. -// If not specified, defaults to 1 -func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["acceptable_fraction"] = value +// The generated +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// has one summary value containing a histogram for `values`. +// +// This op reports an `InvalidArgument` error if any value is not finite. +// +// Arguments: +// tag: Scalar. Tag to use for the `Summary.Value`. +// values: Any shape. Values to use to build the histogram. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "HistogramSummary", + Input: []tf.Input{ + tag, values, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. +// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2. 
+type MutableDenseHashTableV2Attr func(optionalAttr) + +// MutableDenseHashTableV2Container sets the optional container attribute to value. // -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. // If not specified, defaults to "" -func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { +func MutableDenseHashTableV2Container(value string) MutableDenseHashTableV2Attr { return func(m optionalAttr) { - m["dct_method"] = value + m["container"] = value } } -// Decode and Crop a JPEG-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. +// MutableDenseHashTableV2SharedName sets the optional shared_name attribute to value. // -// Accepted values are: +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableDenseHashTableV2SharedName(value string) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableDenseHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableDenseHashTableV2UseNodeNameSharing(value bool) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// MutableDenseHashTableV2ValueShape sets the optional value_shape attribute to value. // -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. 
+// value: The shape of each value. +// If not specified, defaults to <> +func MutableDenseHashTableV2ValueShape(value tf.Shape) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// MutableDenseHashTableV2InitialNumBuckets sets the optional initial_num_buckets attribute to value. // -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. +// value: The initial number of hash table buckets. Must be a power +// to 2. +// If not specified, defaults to 131072 +func MutableDenseHashTableV2InitialNumBuckets(value int64) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["initial_num_buckets"] = value + } +} + +// MutableDenseHashTableV2MaxLoadFactor sets the optional max_load_factor attribute to value. // -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -// downscaling the image later. +// value: The maximum ratio between number of entries and number of +// buckets before growing the table. Must be between 0 and 1. +// If not specified, defaults to 0.8 +func MutableDenseHashTableV2MaxLoadFactor(value float32) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["max_load_factor"] = value + } +} + +// Creates an empty hash table that uses tensors as the backing store. // +// It uses "open addressing" with quadratic reprobing to resolve +// collisions. // -// It is equivalent to a combination of decode and crop, but much faster by only -// decoding partial jpeg image. +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. // // Arguments: -// contents: 0-D. The JPEG-encoded image. -// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. 
+// empty_key: The key used to represent empty key buckets internally. Must not +// be used in insert or lookup operations. // -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, deleted_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"value_dtype": value_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeAndCropJpeg", + Type: "MutableDenseHashTableV2", Input: []tf.Input{ - contents, crop_window, + empty_key, deleted_key, }, Attrs: attrs, } @@ -20102,265 +24870,297 @@ func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, return op.Output(0) } -// Adds two `SparseTensor` objects to produce another `SparseTensor`. +// Deprecated. Use TensorArraySplitV3 // -// The input `SparseTensor` objects' indices are assumed ordered in standard -// lexicographic order. If this is not the case, before this step run -// `SparseReorder` to restore index ordering. +// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3 +func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArraySplitV2", + Input: []tf.Input{ + handle, value, lengths, flow_in, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reshapes a SparseTensor to represent values in a new dense shape. 
// -// By default, if two values sum to zero at some index, the output `SparseTensor` -// would still include that particular location in its index, storing a zero in the -// corresponding value slot. To override this, callers can specify `thresh`, -// indicating that if the sum has a magnitude strictly smaller than `thresh`, its -// corresponding value and index would then not be included. In particular, -// `thresh == 0` (default) means everything is kept and actual thresholding happens -// only for a positive value. +// This operation has the same semantics as reshape on the represented dense +// tensor. The `input_indices` are recomputed based on the requested `new_shape`. +// +// If one component of `new_shape` is the special value -1, the size of that +// dimension is computed so that the total dense size remains constant. At +// most one component of `new_shape` can be -1. The number of dense elements +// implied by `new_shape` must be the same as the number of dense elements +// originally implied by `input_shape`. +// +// Reshaping does not affect the order of values in the SparseTensor. +// +// If the input tensor has rank `R_in` and `N` non-empty values, and `new_shape` +// has length `R_out`, then `input_indices` has shape `[N, R_in]`, +// `input_shape` has length `R_in`, `output_indices` has shape `[N, R_out]`, and +// `output_shape` has length `R_out`. +// +// Arguments: +// input_indices: 2-D. `N x R_in` matrix with the indices of non-empty values in a +// SparseTensor. +// input_shape: 1-D. `R_in` vector with the input SparseTensor's dense shape. +// new_shape: 1-D. `R_out` vector with the requested new dense shape. +// +// Returns 2-D. `N x R_out` matrix with the updated indices of non-empty +// values in the output SparseTensor.1-D. `R_out` vector with the full dense shape of the output +// SparseTensor. This is the same as `new_shape` but with any -1 dimensions +// filled in. 
+func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, new_shape tf.Output) (output_indices tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseReshape", + Input: []tf.Input{ + input_indices, input_shape, new_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Computes the product along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// Computes a tensor such that +// \\(output_i = \prod_j data_j\\) where the product is over `j` such +// that `segment_ids[j] == i`. +// +// If the product is empty for a given segment ID `i`, `output[i] = 1`. +// +//
+// +//
+// +// For example: +// +// ``` +// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +// tf.segment_prod(c, tf.constant([0, 0, 1])) +// # ==> [[4, 6, 6, 4], +// # [5, 6, 7, 8]] +// ``` // -// In the following shapes, `nnz` is the count after taking `thresh` into account. // // Arguments: -// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. -// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. -// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. -// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix. -// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. -// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. -// thresh: 0-D. The magnitude threshold that determines if an output value/index -// pair takes space. -func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { +// +// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseAdd", + Type: "SegmentProd", Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, + data, segment_ids, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. 
-type QuantizedRelu6Attr func(optionalAttr) +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingADAMParametersGradAccumDebug. +type RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) -// QuantizedRelu6OutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["out_type"] = value + m["table_id"] = value } } -// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` -// -// Arguments: +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve ADAM embedding parameters with debug support. // -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. 
-func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { +// Returns Parameter parameters updated by the ADAM optimization algorithm.Parameter momenta updated by the ADAM optimization algorithm.Parameter velocities updated by the ADAM optimization algorithm.Parameter gradient_accumulators updated by the ADAM optimization algorithm. +func RetrieveTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedRelu6", - Input: []tf.Input{ - features, min_features, max_features, - }, + Type: "RetrieveTPUEmbeddingADAMParametersGradAccumDebug", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. -type FixedLengthRecordReaderV2Attr func(optionalAttr) +// CudnnRNNAttr is an optional argument to CudnnRNN. +type CudnnRNNAttr func(optionalAttr) -// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. -// -// value: Number of bytes in the header, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { +// CudnnRNNRnnMode sets the optional rnn_mode attribute to value. 
+// If not specified, defaults to "lstm" +func CudnnRNNRnnMode(value string) CudnnRNNAttr { return func(m optionalAttr) { - m["header_bytes"] = value + m["rnn_mode"] = value } } -// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. -// -// value: Number of bytes in the footer, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { +// CudnnRNNInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNInputMode(value string) CudnnRNNAttr { return func(m optionalAttr) { - m["footer_bytes"] = value + m["input_mode"] = value } } -// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. -// -// value: Number of bytes to hop before each read. Default of 0 means using -// record_bytes. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { +// CudnnRNNDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNDirection(value string) CudnnRNNAttr { return func(m optionalAttr) { - m["hop_bytes"] = value + m["direction"] = value } } -// FixedLengthRecordReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { +// CudnnRNNDropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNDropout(value float32) CudnnRNNAttr { return func(m optionalAttr) { - m["container"] = value + m["dropout"] = value } } -// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. 
-// -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { +// CudnnRNNSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNSeed(value int64) CudnnRNNAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["seed"] = value } } -// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. -// -// value: The type of encoding for the file. Currently ZLIB and GZIP -// are supported. Defaults to none. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { +// CudnnRNNSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNSeed2(value int64) CudnnRNNAttr { return func(m optionalAttr) { - m["encoding"] = value - } -} - -// A Reader that outputs fixed-length records from a file. -// -// Arguments: -// record_bytes: Number of bytes in the record. -// -// Returns The handle to reference the Reader. -func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"record_bytes": record_bytes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FixedLengthRecordReaderV2", - - Attrs: attrs, + m["seed2"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// AudioSummaryAttr is an optional argument to AudioSummary. -type AudioSummaryAttr func(optionalAttr) - -// AudioSummaryMaxOutputs sets the optional max_outputs attribute to value. -// -// value: Max number of batch elements to generate audio for. 
-// If not specified, defaults to 3 -// -// REQUIRES: value >= 1 -func AudioSummaryMaxOutputs(value int64) AudioSummaryAttr { +// CudnnRNNIsTraining sets the optional is_training attribute to value. +// If not specified, defaults to true +func CudnnRNNIsTraining(value bool) CudnnRNNAttr { return func(m optionalAttr) { - m["max_outputs"] = value + m["is_training"] = value } } -// Outputs a `Summary` protocol buffer with audio. -// -// DEPRECATED at GraphDef version 15: Use AudioSummaryV2. -// -// The summary has up to `max_outputs` summary values containing audio. The -// audio is built from `tensor` which must be 3-D with shape `[batch_size, -// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are -// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. -// -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: -// -// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. -// * If `max_outputs` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. +// A RNN backed by cuDNN. // -// Arguments: -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 2-D of shape `[batch_size, frames]`. -// sample_rate: The sample rate of the signal in hertz. +// Computes the RNN from the input and initial states, with respect to the params +// buffer. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate float32, optional ...AudioSummaryAttr) (summary tf.Output) { +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// the actual computation before the first layer. 
'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. Should be +// "unidirectional" or "bidirectional". +// dropout: Dropout probability. When set to 0., dropout is disabled. +// seed: The 1st part of a seed to initialize dropout. +// seed2: The 2nd part of a seed to initialize dropout. +// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. +// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, +// num_units]. +// input_c: For LSTM, a 3-D tensor with the shape of +// [num_layer * dir, batch, num_units]. For other models, it is ignored. +// params: A 1-D tensor that contains the weights and biases in an opaque layout. +// The size must be created through CudnnRNNParamsSize, and initialized +// separately. Note that they might not be compatible across different +// generations. So it is a good idea to save and restore +// output: A 3-D tensor with the shape of [seq_length, batch_size, +// dir * num_units]. +// output_h: The same shape has input_h. +// output_c: The same shape as input_c for LSTM. An empty tensor for other models. +// is_training: Indicates whether this operation is used for inferenece or +// training. +// reserve_space: An opaque tensor that can be used in backprop calculation. It +// is only produced if is_training is false. 
+func CudnnRNN(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, optional ...CudnnRNNAttr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"sample_rate": sample_rate} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AudioSummary", + Type: "CudnnRNN", Input: []tf.Input{ - tag, tensor, + input, input_h, input_c, params, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// QrAttr is an optional argument to Qr. -type QrAttr func(optionalAttr) +// DecodeCompressedAttr is an optional argument to DecodeCompressed. +type DecodeCompressedAttr func(optionalAttr) -// QrFullMatrices sets the optional full_matrices attribute to value. +// DecodeCompressedCompressionType sets the optional compression_type attribute to value. // -// value: If true, compute full-sized `q` and `r`. If false -// (the default), compute only the leading `P` columns of `q`. -// If not specified, defaults to false -func QrFullMatrices(value bool) QrAttr { +// value: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// If not specified, defaults to "" +func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { return func(m optionalAttr) { - m["full_matrices"] = value + m["compression_type"] = value } } -// Computes the QR decompositions of one or more matrices. +// Decompress strings. // -// Computes the QR decomposition of each inner matrix in `tensor` such that -// `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` +// This op decompresses each element of the `bytes` input `Tensor`, which +// is assumed to be compressed using the given `compression_type`. // -// ```python -// # a is a tensor. -// # q is a tensor of orthonormal matrices. 
-// # r is a tensor of upper triangular matrices. -// q, r = qr(a) -// q_full, r_full = qr(a, full_matrices=True) -// ``` +// The `output` is a string `Tensor` of the same shape as `bytes`, +// each element containing the decompressed data from the corresponding +// element in `bytes`. // // Arguments: -// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. +// bytes: A Tensor of string which is compressed. // -// Returns Orthonormal basis for range of `a`. If `full_matrices` is `False` then -// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is -// `[..., M, M]`.Triangular factor. If `full_matrices` is `False` then shape is -// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`. -func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) { +// Returns A Tensor with the same shape as input `bytes`, uncompressed +// from bytes. +func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -20369,331 +25169,224 @@ func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Ou a(attrs) } opspec := tf.OpSpec{ - Type: "Qr", + Type: "DecodeCompressed", Input: []tf.Input{ - input, + bytes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Check if the input matches the regex pattern. -// -// The input is a string tensor of any shape. The pattern is the -// regular expression to be matched with every element of the input tensor. -// The boolean values (True or False) of the output tensor indicate -// if the input matches the regex pattern provided. -// -// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) -// -// Arguments: -// input: A string tensor of the text to be processed. 
-// pattern: The regular expression to match the input. +// RetrieveTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to RetrieveTPUEmbeddingMDLAdagradLightParameters. +type RetrieveTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingMDLAdagradLightParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// Returns A bool tensor with the same shape as `input`. -func StaticRegexFullMatch(scope *Scope, input tf.Output, pattern string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"pattern": pattern} - opspec := tf.OpSpec{ - Type: "StaticRegexFullMatch", - Input: []tf.Input{ - input, - }, - Attrs: attrs, +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingMDLAdagradLightParametersTableId(value int64) RetrieveTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. -type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) - -// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. -// -// value: If True, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { +// RetrieveTPUEmbeddingMDLAdagradLightParametersTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func RetrieveTPUEmbeddingMDLAdagradLightParametersTableName(value string) RetrieveTPUEmbeddingMDLAdagradLightParametersAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["table_name"] = value } } -// Sparse update '*var' as FOBOS algorithm with fixed learning rate. -// -// That is for rows we have grad for, we update var as follows: -// prox_v = var - alpha * grad -// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// Retrieve MDL Adagrad Light embedding parameters. // -// Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// Returns the created operation. -func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { +// Returns Parameter parameters updated by the MDL Adagrad Light optimization algorithm.Parameter accumulators updated by the MDL Adagrad Light optimization algorithm.Parameter weights updated by the MDL Adagrad Light optimization algorithm.Parameter benefits updated by the MDL Adagrad Light optimization algorithm. 
+func RetrieveTPUEmbeddingMDLAdagradLightParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMDLAdagradLightParametersAttr) (parameters tf.Output, accumulators tf.Output, weights tf.Output, benefits tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyProximalGradientDescent", - Input: []tf.Input{ - var_, alpha, l1, l2, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} + Type: "RetrieveTPUEmbeddingMDLAdagradLightParameters", -// Real-valued fast Fourier transform. -// -// Computes the 1-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most dimension of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the -// `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term, -// followed by the `fft_length / 2` positive-frequency terms. -// -// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [1]. The FFT length. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most -// dimension of `input` is replaced with the `fft_length / 2 + 1` unique -// frequency components of its 1D Fourier transform. 
-// -// @compatibility(numpy) -// Equivalent to np.fft.rfft -// @end_compatibility -func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RFFT", - Input: []tf.Input{ - input, fft_length, - }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Adds a value to the current value of a variable. -// -// Any ReadVariableOp with a control dependency on this op is guaranteed to -// see the incremented value or a subsequent newer one. -// -// Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value by which the variable will be incremented. +// RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug. +type RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// Returns the created operation. -func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AssignAddVariableOp", - Input: []tf.Input{ - resource, value, - }, +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value } - return scope.AddOperation(opspec) } -// QuantizedReluAttr is an optional argument to QuantizedRelu. -type QuantizedReluAttr func(optionalAttr) - -// QuantizedReluOutType sets the optional out_type attribute to value. 
-// If not specified, defaults to DT_QUINT8 -func QuantizedReluOutType(value tf.DataType) QuantizedReluAttr { +// RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["out_type"] = value + m["table_name"] = value } } -// Computes Quantized Rectified Linear: `max(features, 0)` +// Retrieve Adadelta embedding parameters with debug support. // -// Arguments: -// -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { +// Returns Parameter parameters updated by the Adadelta optimization algorithm.Parameter accumulators updated by the Adadelta optimization algorithm.Parameter updates updated by the Adadelta optimization algorithm.Parameter gradient_accumulators updated by the Adadelta optimization algorithm. 
+func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, updates tf.Output, gradient_accumulators tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedRelu", - Input: []tf.Input{ - features, min_features, max_features, - }, + Type: "RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Reshapes a SparseTensor to represent values in a new dense shape. -// -// This operation has the same semantics as reshape on the represented dense -// tensor. The `input_indices` are recomputed based on the requested `new_shape`. -// -// If one component of `new_shape` is the special value -1, the size of that -// dimension is computed so that the total dense size remains constant. At -// most one component of `new_shape` can be -1. The number of dense elements -// implied by `new_shape` must be the same as the number of dense elements -// originally implied by `input_shape`. -// -// Reshaping does not affect the order of values in the SparseTensor. -// -// If the input tensor has rank `R_in` and `N` non-empty values, and `new_shape` -// has length `R_out`, then `input_indices` has shape `[N, R_in]`, -// `input_shape` has length `R_in`, `output_indices` has shape `[N, R_out]`, and -// `output_shape` has length `R_out`. -// -// Arguments: -// input_indices: 2-D. `N x R_in` matrix with the indices of non-empty values in a -// SparseTensor. -// input_shape: 1-D. `R_in` vector with the input SparseTensor's dense shape. -// new_shape: 1-D. 
`R_out` vector with the requested new dense shape. +// MapClearAttr is an optional argument to MapClear. +type MapClearAttr func(optionalAttr) + +// MapClearCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// Returns 2-D. `N x R_out` matrix with the updated indices of non-empty -// values in the output SparseTensor.1-D. `R_out` vector with the full dense shape of the output -// SparseTensor. This is the same as `new_shape` but with any -1 dimensions -// filled in. -func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, new_shape tf.Output) (output_indices tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseReshape", - Input: []tf.Input{ - input_indices, input_shape, new_shape, - }, +// REQUIRES: value >= 0 +func MapClearCapacity(value int64) MapClearAttr { + return func(m optionalAttr) { + m["capacity"] = value } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) } -// Deprecated. Use TensorArraySplitV3 +// MapClearMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3 -func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { - if scope.Err() != nil { - return +// REQUIRES: value >= 0 +func MapClearMemoryLimit(value int64) MapClearAttr { + return func(m optionalAttr) { + m["memory_limit"] = value } - opspec := tf.OpSpec{ - Type: "TensorArraySplitV2", - Input: []tf.Input{ - handle, value, lengths, flow_in, - }, +} + +// MapClearContainer sets the optional container attribute to value. 
+// If not specified, defaults to "" +func MapClearContainer(value string) MapClearAttr { + return func(m optionalAttr) { + m["container"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Reorders a SparseTensor into the canonical, row-major ordering. -// -// Note that by convention, all sparse ops preserve the canonical ordering along -// increasing dimension number. The only time ordering can be violated is during -// manual manipulation of the indices and values vectors to add entries. -// -// Reordering does not affect the shape of the SparseTensor. -// -// If the tensor has rank `R` and `N` non-empty values, `input_indices` has -// shape `[N, R]`, input_values has length `N`, and input_shape has length `R`. -// -// Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. +// MapClearSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapClearSharedName(value string) MapClearAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes all elements in the underlying container. // -// Returns 2-D. `N x R` matrix with the same indices as input_indices, but -// in canonical row-major ordering.1-D. `N` non-empty values corresponding to `output_indices`. -func SparseReorder(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { +// Returns the created operation. 
+func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseReorder", - Input: []tf.Input{ - input_indices, input_values, input_shape, - }, + Type: "MapClear", + + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// Computes rectified linear: `max(features, 0)`. -func Relu(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Relu", - Input: []tf.Input{ - features, - }, +// DecodeCSVAttr is an optional argument to DecodeCSV. +type DecodeCSVAttr func(optionalAttr) + +// DecodeCSVFieldDelim sets the optional field_delim attribute to value. +// +// value: char delimiter to separate fields in a record. +// If not specified, defaults to "," +func DecodeCSVFieldDelim(value string) DecodeCSVAttr { + return func(m optionalAttr) { + m["field_delim"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// ResourceApplyAddSignAttr is an optional argument to ResourceApplyAddSign. -type ResourceApplyAddSignAttr func(optionalAttr) +// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value. +// +// value: If false, treats double quotation marks as regular +// characters inside of the string fields (ignoring RFC 4180, Section 2, +// Bullet 5). +// If not specified, defaults to true +func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr { + return func(m optionalAttr) { + m["use_quote_delim"] = value + } +} -// ResourceApplyAddSignUseLocking sets the optional use_locking attribute to value. +// DecodeCSVNaValue sets the optional na_value attribute to value. 
// -// value: If `True`, updating of the var and m tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAddSignUseLocking(value bool) ResourceApplyAddSignAttr { +// value: Additional string to recognize as NA/NaN. +// If not specified, defaults to "" +func DecodeCSVNaValue(value string) DecodeCSVAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["na_value"] = value } } -// Update '*var' according to the AddSign update. +// DecodeCSVSelectCols sets the optional select_cols attribute to value. +// If not specified, defaults to <> +func DecodeCSVSelectCols(value []int64) DecodeCSVAttr { + return func(m optionalAttr) { + m["select_cols"] = value + } +} + +// Convert CSV records to tensors. Each column maps to one tensor. // -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// update <- (alpha + sign_decay * sign(g) *sign(m)) * g -// variable <- variable - lr_t * update +// RFC 4180 format is expected for the CSV records. +// (https://tools.ietf.org/html/rfc4180) +// Note that we allow leading and trailing spaces with int or float field. // // Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// alpha: Must be a scalar. -// sign_decay: Must be a scalar. -// beta: Must be a scalar. -// grad: The gradient. +// records: Each string is a record/row in the csv and all records should have +// the same format. +// record_defaults: One tensor per column of the input record, with either a +// scalar default value for that column or an empty vector if the column is +// required. // -// Returns the created operation. 
-func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, alpha tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyAddSignAttr) (o *tf.Operation) { +// Returns Each tensor will have the same shape as records. +func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) { if scope.Err() != nil { return } @@ -20702,82 +25395,97 @@ func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Outpu a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAddSign", + Type: "DecodeCSV", Input: []tf.Input{ - var_, m, lr, alpha, sign_decay, beta, grad, + records, tf.OutputList(record_defaults), }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("DecodeCSV", err) + return + } + return output } -// Component-wise divides a SparseTensor by a dense Tensor. -// -// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not -// the other direction. +// Produces the max pool of the input tensor for quantized types. // // Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. +// input: The 4D (batch x rows x cols x depth) Tensor to MaxReduce over. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// ksize: The size of the window for each dimension of the input tensor. +// The length must be 4 to match the number of dimensions of the input. 
+// strides: The stride of the sliding window for each dimension of the input +// tensor. The length must be 4 to match the number of dimensions of the input. +// padding: The type of padding algorithm to use. // -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedMaxPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "SparseDenseCwiseDiv", + Type: "QuantizedMaxPool", Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, + input, min_input, max_input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// FractionalAvgPoolGradAttr is an optional argument to FractionalAvgPoolGrad. -type FractionalAvgPoolGradAttr func(optionalAttr) +// RandomShuffleAttr is an optional argument to RandomShuffle. +type RandomShuffleAttr func(optionalAttr) -// FractionalAvgPoolGradOverlapping sets the optional overlapping attribute to value. -// -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: -// -// `index 0 1 2 3 4` +// RandomShuffleSeed sets the optional seed attribute to value. // -// `value 20 5 16 3 7` +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. 
+// If not specified, defaults to 0 +func RandomShuffleSeed(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomShuffleSeed2 sets the optional seed2 attribute to value. // -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [41/3, 26/3] for fractional avg pooling. -// If not specified, defaults to false -func FractionalAvgPoolGradOverlapping(value bool) FractionalAvgPoolGradAttr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleSeed2(value int64) RandomShuffleAttr { return func(m optionalAttr) { - m["overlapping"] = value + m["seed2"] = value } } -// Computes gradient of the FractionalAvgPool function. +// Randomly shuffles a tensor along its first dimension. // -// Unlike FractionalMaxPoolGrad, we don't need to find arg_max for -// FractionalAvgPoolGrad, we just need to evenly back-propagate each element of -// out_backprop to those indices that form the same pooling cell. Therefore, we -// just need to know the shape of original input tensor, instead of the whole -// tensor. +// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped +// to one and only one `output[i]`. For example, a mapping that might occur for a +// 3x2 tensor is: +// +// ``` +// [[1, 2], [[5, 6], +// [3, 4], ==> [1, 2], +// [5, 6]] [3, 4]] +// ``` // // Arguments: -// orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool` -// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients -// w.r.t. the output of `fractional_avg_pool`. -// row_pooling_sequence: row pooling sequence, form pooling region with -// col_pooling_sequence. -// col_pooling_sequence: column pooling sequence, form pooling region with -// row_pooling sequence. +// value: The tensor to be shuffled. // -// Returns 4-D. Gradients w.r.t. the input of `fractional_avg_pool`. 
-func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalAvgPoolGradAttr) (output tf.Output) { +// Returns A tensor of same shape and type as `value`, shuffled along its first +// dimension. +func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -20786,9 +25494,9 @@ func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_ a(attrs) } opspec := tf.OpSpec{ - Type: "FractionalAvgPoolGrad", + Type: "RandomShuffle", Input: []tf.Input{ - orig_input_tensor_shape, out_backprop, row_pooling_sequence, col_pooling_sequence, + value, }, Attrs: attrs, } @@ -20796,107 +25504,116 @@ func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_ return op.Output(0) } -// QuantizedConv2DAttr is an optional argument to QuantizedConv2D. -type QuantizedConv2DAttr func(optionalAttr) +// EnqueueTPUEmbeddingSparseBatchAttr is an optional argument to EnqueueTPUEmbeddingSparseBatch. +type EnqueueTPUEmbeddingSparseBatchAttr func(optionalAttr) -// QuantizedConv2DOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { +// EnqueueTPUEmbeddingSparseBatchDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. Should be >= 0 and less than the number +// of TPU cores in the task on which the node is placed. +// If not specified, defaults to -1 +func EnqueueTPUEmbeddingSparseBatchDeviceOrdinal(value int64) EnqueueTPUEmbeddingSparseBatchAttr { return func(m optionalAttr) { - m["out_type"] = value + m["device_ordinal"] = value } } -// QuantizedConv2DDilations sets the optional dilations attribute to value. 
+// EnqueueTPUEmbeddingSparseBatchCombiners sets the optional combiners attribute to value. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { +// value: A list of string scalars, one for each embedding table that specify +// how to normalize the embedding activations after weighted summation. +// Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have +// the sum of the weights be 0 for 'mean' or the sum of the squared weights be +// 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for +// all tables. +// If not specified, defaults to <> +func EnqueueTPUEmbeddingSparseBatchCombiners(value []string) EnqueueTPUEmbeddingSparseBatchAttr { return func(m optionalAttr) { - m["dilations"] = value + m["combiners"] = value } } -// Computes a 2D convolution given quantized 4D input and filter tensors. +// An op that enqueues TPUEmbedding input indices from a SparseTensor. // -// The inputs are quantized tensors where the lowest value represents the real -// number of the associated minimum, and the highest represents the maximum. -// This means that you can only interpret the quantized output in the same way, by -// taking the returned minimum and maximum values into account. +// This Op eases the porting of code that uses embedding_lookup_sparse(), +// although some Python preprocessing of the SparseTensor arguments to +// embedding_lookup_sparse() is required to produce the arguments to this Op, +// since only a single EnqueueTPUEmbeddingSparseBatch Op is allowed per training +// step. 
// -// Arguments: +// The tensors at corresponding positions in the three input lists +// must have the same shape, i.e. rank 1 with dim_size() equal to the total +// number of lookups into the table described by the corresponding table_id. // -// filter: filter's input_depth dimension must match input's depth dimensions. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// min_filter: The float value that the lowest quantized filter value represents. -// max_filter: The float value that the highest quantized filter value represents. -// strides: The stride of the sliding window for each dimension of the input -// tensor. -// padding: The type of padding algorithm to use. +// Arguments: +// sample_indices: A list of rank 1 Tensors specifying the training example and +// feature to which the corresponding embedding_indices and aggregation_weights +// values belong. sample_indices[i] must equal b * nf + f, where nf is the +// number of features from the corresponding table, f is in [0, nf), and +// b is in [0, batch size). +// embedding_indices: A list of rank 1 Tensors, indices into the embedding tables. +// aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e. per +// (training example, feature) -- aggregation weights. +// mode_override: A string input that overrides the mode specified in the +// TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', +// 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set +// in TPUEmbeddingConfiguration is used, otherwise mode_override is used. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. 
-func QuantizedConv2D(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedConv2DAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { +// Returns the created operation. +func EnqueueTPUEmbeddingSparseBatch(scope *Scope, sample_indices []tf.Output, embedding_indices []tf.Output, aggregation_weights []tf.Output, mode_override tf.Output, optional ...EnqueueTPUEmbeddingSparseBatchAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedConv2D", + Type: "EnqueueTPUEmbeddingSparseBatch", Input: []tf.Input{ - input, filter, min_input, max_input, min_filter, max_filter, + tf.OutputList(sample_indices), tf.OutputList(embedding_indices), tf.OutputList(aggregation_weights), mode_override, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// ResourceGatherAttr is an optional argument to ResourceGather. -type ResourceGatherAttr func(optionalAttr) +// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. +type StatelessRandomNormalAttr func(optionalAttr) -// ResourceGatherValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func ResourceGatherValidateIndices(value bool) ResourceGatherAttr { +// StatelessRandomNormalDtype sets the optional dtype attribute to value. +// +// value: The type of the output. 
+// If not specified, defaults to DT_FLOAT +func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { return func(m optionalAttr) { - m["validate_indices"] = value + m["dtype"] = value } } -// Gather slices from the variable pointed to by `resource` according to `indices`. +// Outputs deterministic pseudorandom values from a normal distribution. // -// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: +// The generated values will have mean 0 and standard deviation 1. // -// ```python -// # Scalar indices -// output[:, ..., :] = params[indices, :, ... :] +// The outputs are a deterministic function of `shape` and `seed`. // -// # Vector indices -// output[i, :, ..., :] = params[indices[i], :, ... :] +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). // -// # Higher rank indices -// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] -// ``` -func ResourceGather(scope *Scope, resource tf.Output, indices tf.Output, dtype tf.DataType, optional ...ResourceGatherAttr) (output tf.Output) { +// Returns Random values with specified shape. +func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceGather", + Type: "StatelessRandomNormal", Input: []tf.Input{ - resource, indices, + shape, seed, }, Attrs: attrs, } @@ -20904,61 +25621,45 @@ func ResourceGather(scope *Scope, resource tf.Output, indices tf.Output, dtype t return op.Output(0) } -// Delete the TensorArray from its resource container. +// An Op to exchange data across TPU replicas. 
// -// This enables the user to close and release the resource in the middle -// of a step/run. +// On each replica, the input is split into `split_count` blocks along +// `split_dimension` and send to the other replicas given group_assignment. After +// receiving `split_count` - 1 blocks from other replicas, we concatenate the +// blocks along `concat_dimension` as the output. // -// Arguments: -// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). +// For example, suppose there are 2 TPU replicas: +// replica 0 receives input: `[[A, B]]` +// replica 1 receives input: `[[C, D]]` // -// Returns the created operation. -func TensorArrayCloseV3(scope *Scope, handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArrayCloseV3", - Input: []tf.Input{ - handle, - }, - } - return scope.AddOperation(opspec) -} - -// StatelessMultinomialAttr is an optional argument to StatelessMultinomial. -type StatelessMultinomialAttr func(optionalAttr) - -// StatelessMultinomialOutputDtype sets the optional output_dtype attribute to value. -// If not specified, defaults to DT_INT64 -func StatelessMultinomialOutputDtype(value tf.DataType) StatelessMultinomialAttr { - return func(m optionalAttr) { - m["output_dtype"] = value - } -} - -// Draws samples from a multinomial distribution. +// group_assignment=`[[0, 1]]` +// concat_dimension=0 +// split_dimension=1 +// split_count=2 +// +// replica 0's output: `[[A], [C]]` +// replica 1's output: `[[B], [D]]` // // Arguments: -// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` -// represents the unnormalized log probabilities for all classes. -// num_samples: 0-D. Number of independent samples to draw for each row slice. -// seed: 2 seeds (shape [2]). +// input: The local input to the sum. +// group_assignment: An int32 tensor with shape +// [num_groups, num_replicas_per_group]. 
`group_assignment[i]` represents the +// replica ids in the ith subgroup. +// concat_dimension: The dimension number to concatenate. +// split_dimension: The dimension number to split. +// split_count: The number of splits, this number must equal to the sub-group +// size(group_assignment.get_shape()[1]) // -// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` -// contains the drawn class labels with range `[0, num_classes)`. -func StatelessMultinomial(scope *Scope, logits tf.Output, num_samples tf.Output, seed tf.Output, optional ...StatelessMultinomialAttr) (output tf.Output) { +// Returns The exchanged result. +func AllToAll(scope *Scope, input tf.Output, group_assignment tf.Output, concat_dimension int64, split_dimension int64, split_count int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"concat_dimension": concat_dimension, "split_dimension": split_dimension, "split_count": split_count} opspec := tf.OpSpec{ - Type: "StatelessMultinomial", + Type: "AllToAll", Input: []tf.Input{ - logits, num_samples, seed, + input, group_assignment, }, Attrs: attrs, } @@ -20966,203 +25667,111 @@ func StatelessMultinomial(scope *Scope, logits tf.Output, num_samples tf.Output, return op.Output(0) } -// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. +// Adds a value to the current value of a variable. // -// This Op does not require `a_indices` be sorted in standard lexicographic order. +// Any ReadVariableOp with a control dependency on this op is guaranteed to +// see the incremented value or a subsequent newer one. // // Arguments: -// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. -// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. -// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. -// b: `ndims`-D Tensor. 
With shape `a_shape`. -func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { +// resource: handle to the resource in which to store the variable. +// value: the value by which the variable will be incremented. +// +// Returns the created operation. +func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseTensorDenseAdd", + Type: "AssignAddVariableOp", Input: []tf.Input{ - a_indices, a_values, a_shape, b, + resource, value, }, } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. -type SparseToSparseSetOperationAttr func(optionalAttr) - -// SparseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func SparseToSparseSetOperationValidateIndices(value bool) SparseToSparseSetOperationAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } + return scope.AddOperation(opspec) } -// Applies set operation along last dimension of 2 `SparseTensor` inputs. -// -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. -// -// If `validate_indices` is `True`, `SparseToSparseSetOperation` validates the -// order and range of `set1` and `set2` indices. -// -// Input `set1` is a `SparseTensor` represented by `set1_indices`, `set1_values`, -// and `set1_shape`. For `set1` ranked `n`, 1st `n-1` dimensions must be the same -// as `set2`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. +// Real-valued fast Fourier transform. // -// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, -// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same -// as `set1`. 
Dimension `n` contains values in a set, duplicates are allowed but -// ignored. +// Computes the 1-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most dimension of `input`. // -// If `validate_indices` is `True`, this op validates the order and range of `set1` -// and `set2` indices. +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the +// `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term, +// followed by the `fft_length / 2` positive-frequency terms. // -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. +// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// set1_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set1_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set1_shape: 1D `Tensor`, shape of a `SparseTensor`. `set1_shape[0...n-1]` must -// be the same as `set2_shape[0...n-1]`, `set1_shape[n]` is the -// max set size across `0...n-1` dimensions. -// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must -// be the same as `set1_shape[0...n-1]`, `set2_shape[n]` is the -// max set size across `0...n-1` dimensions. +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [1]. The FFT length. 
// +// Returns A complex64 tensor of the same rank as `input`. The inner-most +// dimension of `input` is replaced with the `fft_length / 2 + 1` unique +// frequency components of its 1D Fourier transform. // -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_values tf.Output, set1_shape tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...SparseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { +// @compatibility(numpy) +// Equivalent to np.fft.rfft +// @end_compatibility +func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"set_operation": set_operation} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SparseToSparseSetOperation", + Type: "RFFT", Input: []tf.Input{ - set1_indices, set1_values, set1_shape, set2_indices, set2_values, set2_shape, + input, fft_length, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2. -type MutableDenseHashTableV2Attr func(optionalAttr) +// RetrieveTPUEmbeddingAdadeltaParametersAttr is an optional argument to RetrieveTPUEmbeddingAdadeltaParameters. +type RetrieveTPUEmbeddingAdadeltaParametersAttr func(optionalAttr) -// MutableDenseHashTableV2Container sets the optional container attribute to value. +// RetrieveTPUEmbeddingAdadeltaParametersTableId sets the optional table_id attribute to value. 
+// If not specified, defaults to -1 // -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableDenseHashTableV2Container(value string) MutableDenseHashTableV2Attr { +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingAdadeltaParametersTableId(value int64) RetrieveTPUEmbeddingAdadeltaParametersAttr { return func(m optionalAttr) { - m["container"] = value + m["table_id"] = value } } -// MutableDenseHashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. +// RetrieveTPUEmbeddingAdadeltaParametersTableName sets the optional table_name attribute to value. // If not specified, defaults to "" -func MutableDenseHashTableV2SharedName(value string) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableDenseHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// If not specified, defaults to false -func MutableDenseHashTableV2UseNodeNameSharing(value bool) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// MutableDenseHashTableV2ValueShape sets the optional value_shape attribute to value. -// -// value: The shape of each value. -// If not specified, defaults to <> -func MutableDenseHashTableV2ValueShape(value tf.Shape) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["value_shape"] = value - } -} - -// MutableDenseHashTableV2InitialNumBuckets sets the optional initial_num_buckets attribute to value. -// -// value: The initial number of hash table buckets. Must be a power -// to 2. 
-// If not specified, defaults to 131072 -func MutableDenseHashTableV2InitialNumBuckets(value int64) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["initial_num_buckets"] = value - } -} - -// MutableDenseHashTableV2MaxLoadFactor sets the optional max_load_factor attribute to value. -// -// value: The maximum ratio between number of entries and number of -// buckets before growing the table. Must be between 0 and 1. -// If not specified, defaults to 0.8 -func MutableDenseHashTableV2MaxLoadFactor(value float32) MutableDenseHashTableV2Attr { +func RetrieveTPUEmbeddingAdadeltaParametersTableName(value string) RetrieveTPUEmbeddingAdadeltaParametersAttr { return func(m optionalAttr) { - m["max_load_factor"] = value + m["table_name"] = value } -} - -// Creates an empty hash table that uses tensors as the backing store. -// -// It uses "open addressing" with quadratic reprobing to resolve -// collisions. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. -// -// Arguments: -// empty_key: The key used to represent empty key buckets internally. Must not -// be used in insert or lookup operations. +} + +// Retrieve Adadelta embedding parameters. // -// value_dtype: Type of the table values. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// Returns Handle to a table. 
-func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, deleted_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) { +// Returns Parameter parameters updated by the Adadelta optimization algorithm.Parameter accumulators updated by the Adadelta optimization algorithm.Parameter updates updated by the Adadelta optimization algorithm. +func RetrieveTPUEmbeddingAdadeltaParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdadeltaParametersAttr) (parameters tf.Output, accumulators tf.Output, updates tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"value_dtype": value_dtype} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MutableDenseHashTableV2", - Input: []tf.Input{ - empty_key, deleted_key, - }, + Type: "RetrieveTPUEmbeddingAdadeltaParameters", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } // UpperBoundAttr is an optional argument to UpperBound. @@ -21276,51 +25885,6 @@ func FractionalMaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Ou return op.Output(0) } -// ResourceApplyAdagradDAAttr is an optional argument to ResourceApplyAdagradDA. -type ResourceApplyAdagradDAAttr func(optionalAttr) - -// ResourceApplyAdagradDAUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyAdagradDAUseLocking(value bool) ResourceApplyAdagradDAAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the proximal adagrad scheme. -// -// Arguments: -// var_: Should be from a Variable(). 
-// gradient_accumulator: Should be from a Variable(). -// gradient_squared_accumulator: Should be from a Variable(). -// grad: The gradient. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// global_step: Training step number. Must be a scalar. -// -// Returns the created operation. -func ResourceApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceApplyAdagradDAAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAdagradDA", - Input: []tf.Input{ - var_, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // SparseReduceMaxSparseAttr is an optional argument to SparseReduceMaxSparse. type SparseReduceMaxSparseAttr func(optionalAttr) @@ -21374,6 +25938,78 @@ func SparseReduceMaxSparse(scope *Scope, input_indices tf.Output, input_values t return op.Output(0), op.Output(1), op.Output(2) } +// Convert one or more images from HSV to RGB. +// +// Outputs a tensor of the same shape as the `images` tensor, containing the RGB +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// See `rgb_to_hsv` for a description of the HSV encoding. +// +// Arguments: +// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// +// Returns `images` converted to RGB. 
+func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "HSVToRGB", + Input: []tf.Input{ + images, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the gradient of the sigmoid of `x` wrt its input. +// +// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and +// `dy` is the corresponding input gradient. +func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SigmoidGrad", + Input: []tf.Input{ + y, dy, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that changes the batch size. +// +// Creates a dataset that changes the batch size of the dataset to current batch +// size // num_workers. +// +// Arguments: +// input_dataset: A variant tensor representing the input dataset. +// num_workers: A scalar representing the number of workers to distribute this batch across. As +// a result of this transformation the current batch size would end up being +// divided by this parameter. +// +// +func ExperimentalRebatchDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalRebatchDataset", + Input: []tf.Input{ + input_dataset, num_workers, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Creates a dataset that emits the outputs of `input_dataset` `count` times. // // Arguments: @@ -21398,236 +26034,415 @@ func RepeatDataset(scope *Scope, input_dataset tf.Output, count tf.Output, outpu return op.Output(0) } -// Computes the gradient for the inverse of `x` wrt its input. 
+// ResourceApplyAdagradDAAttr is an optional argument to ResourceApplyAdagradDA. +type ResourceApplyAdagradDAAttr func(optionalAttr) + +// ResourceApplyAdagradDAUseLocking sets the optional use_locking attribute to value. // -// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` -// is the corresponding input gradient. -func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyAdagradDAUseLocking(value bool) ResourceApplyAdagradDAAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the proximal adagrad scheme. +// +// Arguments: +// var_: Should be from a Variable(). +// gradient_accumulator: Should be from a Variable(). +// gradient_squared_accumulator: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// global_step: Training step number. Must be a scalar. +// +// Returns the created operation. +func ResourceApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceApplyAdagradDAAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ReciprocalGrad", + Type: "ResourceApplyAdagradDA", Input: []tf.Input{ - y, dy, + var_, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Returns the min of x and y (i.e. x < y ? 
x : y) element-wise. +// Creates a TensorList which, when stacked, has the value of `tensor`. // -// *NOTE*: `Minimum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Each tensor in the result list corresponds to one row of the input tensor. +// +// tensor: The input tensor. +// output_handle: The list. +func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Minimum", + Type: "TensorListFromTensor", Input: []tf.Input{ - x, y, + tensor, element_shape, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MfccAttr is an optional argument to Mfcc. -type MfccAttr func(optionalAttr) +// ConfigureDistributedTPUAttr is an optional argument to ConfigureDistributedTPU. +type ConfigureDistributedTPUAttr func(optionalAttr) -// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. +// ConfigureDistributedTPUEmbeddingConfig sets the optional embedding_config attribute to value. // -// value: The highest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 4000 -func MfccUpperFrequencyLimit(value float32) MfccAttr { +// value: Reserved. Do not use. +// If not specified, defaults to "" +func ConfigureDistributedTPUEmbeddingConfig(value string) ConfigureDistributedTPUAttr { return func(m optionalAttr) { - m["upper_frequency_limit"] = value + m["embedding_config"] = value } } -// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. +// ConfigureDistributedTPUTpuEmbeddingConfig sets the optional tpu_embedding_config attribute to value. // -// value: The lowest frequency to use when calculating the -// ceptstrum. 
-// If not specified, defaults to 20 -func MfccLowerFrequencyLimit(value float32) MfccAttr { +// value: Serialized tensorflow.tpu.TPUEmbeddingConfiguration that +// describes the embedding lookups of the program. +// If not specified, defaults to "" +func ConfigureDistributedTPUTpuEmbeddingConfig(value string) ConfigureDistributedTPUAttr { return func(m optionalAttr) { - m["lower_frequency_limit"] = value + m["tpu_embedding_config"] = value } } -// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. +// ConfigureDistributedTPUIsGlobalInit sets the optional is_global_init attribute to value. // -// value: Resolution of the Mel bank used internally. -// If not specified, defaults to 40 -func MfccFilterbankChannelCount(value int64) MfccAttr { +// value: Reserved. Do not use. +// If not specified, defaults to false +func ConfigureDistributedTPUIsGlobalInit(value bool) ConfigureDistributedTPUAttr { return func(m optionalAttr) { - m["filterbank_channel_count"] = value + m["is_global_init"] = value } } -// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. +// Sets up the centralized structures for a distributed TPU system. // -// value: How many output channels to produce per time slice. -// If not specified, defaults to 13 -func MfccDctCoefficientCount(value int64) MfccAttr { - return func(m optionalAttr) { - m["dct_coefficient_count"] = value +// Returns A serialized tensorflow.tpu.TopologyProto that describes the TPU +// topology. +func ConfigureDistributedTPU(scope *Scope, optional ...ConfigureDistributedTPUAttr) (topology tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ConfigureDistributedTPU", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Transforms a spectrogram into a form that's useful for speech recognition. 
+// Reshapes a quantized tensor as per the Reshape op. // -// Mel Frequency Cepstral Coefficients are a way of representing audio data that's -// been effective as an input feature for machine learning. They are created by -// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the -// higher frequencies that are less significant to the human ear. They have a long -// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum -// is a good resource to learn more. +// ``` // // Arguments: -// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared -// set to true. -// sample_rate: How many samples per second the source audio used. -func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { +// +// shape: Defines the shape of the output tensor. +// input_min: The minimum value of the input. +// input_max: The maximum value of the input. +// +// Returns This value is copied from input_min.This value is copied from input_max. +func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Mfcc", + Type: "QuantizedReshape", Input: []tf.Input{ - spectrogram, sample_rate, + tensor, shape, input_min, input_max, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Returns the element-wise sum of a list of tensors. +// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. +type PriorityQueueV2Attr func(optionalAttr) + +// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. 
// -// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not -// wait for all of its inputs to be ready before beginning to sum. This can -// save memory if inputs are ready at different times, since minimum temporary -// storage is proportional to the output size rather than the inputs size. +// value: The type of each component in a value. +// If not specified, defaults to <> // -// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. +// REQUIRES: len(value) >= 0 +func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { + return func(m optionalAttr) { + m["component_types"] = value + } +} + +// PriorityQueueV2Capacity sets the optional capacity attribute to value. // -// Returns a `Tensor` of same shape and type as the elements of `inputs`. +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// PriorityQueueV2Container sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func PriorityQueueV2Container(value string) PriorityQueueV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// PriorityQueueV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A queue that produces elements sorted by the first component value. 
+// +// Note that the PriorityQueue requires the first component of any element +// to be a scalar int64, in addition to the other elements declared by +// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue +// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra +// entry in their input (resp. output) lists. // // Arguments: -// inputs: A list of `Tensor` objects, each with same shape and type. -// shape: Shape of elements of `inputs`. -func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { +// shapes: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// +// Returns The handle to the queue. +func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} + attrs := map[string]interface{}{"shapes": shapes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "AccumulateNV2", - Input: []tf.Input{ - tf.OutputList(inputs), - }, + Type: "PriorityQueueV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Outputs deterministic pseudorandom random integers from a uniform distribution. +// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. +type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) + +// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. // -// The generated values follow a uniform distribution in the range `[minval, maxval)`. 
+// value: If True, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Sparse update '*var' as FOBOS algorithm with fixed learning rate. // -// The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. +// That is for rows we have grad for, we update var as follows: +// prox_v = var - alpha * grad +// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} // // Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// minval: Minimum value (inclusive, scalar). -// maxval: Maximum value (exclusive, scalar). +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. // -// Returns Random values with specified shape. -func StatelessRandomUniformInt(scope *Scope, shape tf.Output, seed tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) { +// Returns the created operation. 
+func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "StatelessRandomUniformInt", + Type: "ResourceSparseApplyProximalGradientDescent", Input: []tf.Input{ - shape, seed, minval, maxval, + var_, alpha, l1, l2, grad, indices, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Compute the Hurwitz zeta function \\(\zeta(x, q)\\). +// Check if the input matches the regex pattern. // -// The Hurwitz zeta function is defined as: +// The input is a string tensor of any shape. The pattern is the +// regular expression to be matched with every element of the input tensor. +// The boolean values (True or False) of the output tensor indicate +// if the input matches the regex pattern provided. // +// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) // -// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\) -func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) { +// Arguments: +// input: A string tensor of the text to be processed. +// pattern: The regular expression to match the input. +// +// Returns A bool tensor with the same shape as `input`. +func StaticRegexFullMatch(scope *Scope, input tf.Output, pattern string) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"pattern": pattern} opspec := tf.OpSpec{ - Type: "Zeta", + Type: "StaticRegexFullMatch", Input: []tf.Input{ - x, q, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the cardinality of `input_dataset`. +// OutfeedDequeueAttr is an optional argument to OutfeedDequeue. 
+type OutfeedDequeueAttr func(optionalAttr) + +// OutfeedDequeueDeviceOrdinal sets the optional device_ordinal attribute to value. // -// Returns the cardinality of `input_dataset`. +// value: The TPU device to use. This should be -1 when the Op +// is running on a TPU device, and >= 0 when the Op is running on the CPU +// device. +// If not specified, defaults to -1 +func OutfeedDequeueDeviceOrdinal(value int64) OutfeedDequeueAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// Retrieves a single tensor from the computation outfeed. +// +// This operation will block indefinitely until data is available. // // Arguments: -// input_dataset: A variant tensor representing the dataset to return cardinality for. +// dtype: The type of elements in the tensor. +// shape: The shape of the tensor. // -// Returns The cardinality of `input_dataset`. Named constants are used to represent -// infinite and unknown cardinality. -func ExperimentalDatasetCardinality(scope *Scope, input_dataset tf.Output) (cardinality tf.Output) { +// Returns A tensor that will be read from the device outfeed. +func OutfeedDequeue(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...OutfeedDequeueAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ExperimentalDatasetCardinality", - Input: []tf.Input{ - input_dataset, - }, + Type: "OutfeedDequeue", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that executes a SQL query and emits rows of the result set. +// RandomPoissonV2Attr is an optional argument to RandomPoissonV2. +type RandomPoissonV2Attr func(optionalAttr) + +// RandomPoissonV2Seed sets the optional seed attribute to value. // -// Arguments: -// driver_name: The database type. Currently, the only supported type is 'sqlite'. 
-// data_source_name: A connection string to connect to the database. -// query: A SQL query to execute. +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomPoissonV2Seed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// RandomPoissonV2Dtype sets the optional dtype attribute to value. +// If not specified, defaults to DT_INT64 +func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs random values from the Poisson distribution(s) described by rate. // +// This op uses two algorithms, depending on rate. If rate >= 10, then +// the algorithm by Hormann is used to acquire samples via +// transformation-rejection. +// See http://www.sciencedirect.com/science/article/pii/0167668793909974. // -func ExperimentalSqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform +// random variables. +// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer +// Programming, Volume 2. Addison Wesley +// +// Arguments: +// shape: 1-D integer tensor. Shape of independent samples to draw from each +// distribution described by the shape parameters given in rate. +// rate: A tensor in which each scalar is a "rate" parameter describing the +// associated poisson distribution. 
+// +// Returns A tensor with shape `shape + shape(rate)`. Each slice +// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for +// `rate[i0, i1, ...iN]`. +func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ExperimentalSqlDataset", + Type: "RandomPoissonV2", Input: []tf.Input{ - driver_name, data_source_name, query, + shape, rate, }, Attrs: attrs, } @@ -21635,586 +26450,531 @@ func ExperimentalSqlDataset(scope *Scope, driver_name tf.Output, data_source_nam return op.Output(0) } -// Inverse fast Fourier transform. -// -// Computes the inverse 1-dimensional discrete Fourier transform over the -// inner-most dimension of `input`. +// RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug. +type RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// Arguments: -// input: A complex tensor. +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve RMSProp embedding parameters with debug support. // -// Returns A complex tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its inverse 1D Fourier transform. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// @compatibility(numpy) -// Equivalent to np.fft.ifft -// @end_compatibility -func IFFT(scope *Scope, input tf.Output) (output tf.Output) { +// Returns Parameter parameters updated by the RMSProp optimization algorithm.Parameter ms updated by the RMSProp optimization algorithm.Parameter mom updated by the RMSProp optimization algorithm.Parameter gradient_accumulators updated by the RMSProp optimization algorithm. +func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr) (parameters tf.Output, ms tf.Output, mom tf.Output, gradient_accumulators tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IFFT", - Input: []tf.Input{ - input, - }, + Type: "RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug", + + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// 2D fast Fourier transform. 
-// -// Computes the 2-dimensional discrete Fourier transform over the inner-most -// 2 dimensions of `input`. -// -// Arguments: -// input: A complex tensor. -// -// Returns A complex tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. +// Computes the gradient for the rsqrt of `x` wrt its input. // -// @compatibility(numpy) -// Equivalent to np.fft.fft2 -// @end_compatibility -func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { +// Specifically, `grad = dy * -0.5 * y^3`, where `y = rsqrt(x)`, and `dy` +// is the corresponding input gradient. +func RsqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FFT2D", + Type: "RsqrtGrad", Input: []tf.Input{ - input, + y, dy, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Inverse 2D fast Fourier transform. +// Encode audio data using the WAV file format. // -// Computes the inverse 2-dimensional discrete Fourier transform over the -// inner-most 2 dimensions of `input`. +// This operation will generate a string suitable to be saved out to create a .wav +// audio file. It will be encoded in the 16-bit PCM format. It takes in float +// values in the range -1.0f to 1.0f, and any outside that value will be clamped to +// that range. // -// Arguments: -// input: A complex tensor. +// `audio` is a 2-D float Tensor of shape `[length, channels]`. +// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). // -// Returns A complex tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their inverse 2D Fourier transform. +// Arguments: +// audio: 2-D with shape `[length, channels]`. +// sample_rate: Scalar containing the sample frequency. // -// @compatibility(numpy) -// Equivalent to np.fft.ifft2 -// @end_compatibility -func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { +// Returns 0-D. 
WAV-encoded file contents. +func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IFFT2D", + Type: "EncodeWav", Input: []tf.Input{ - input, + audio, sample_rate, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns element-wise remainder of division. This emulates C semantics in that +// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax. +type ResourceApplyAdaMaxAttr func(optionalAttr) + +// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value. // -// the result here is consistent with a truncating divide. E.g. `truncate(x / y) * -// y + truncate_mod(x, y) = x`. +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the AdaMax algorithm. // -// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// v_t <- max(beta2 * v_{t-1}, abs(g)) +// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// beta1_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. 
+func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TruncateMod", + Type: "ResourceApplyAdaMax", Input: []tf.Input{ - x, y, + var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Inverse 2D real-valued fast Fourier transform. -// -// Computes the inverse 2-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most 2 dimensions of `input`. -// -// The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`: -// The inner-most dimension contains the `fft_length / 2 + 1` unique components of -// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed -// from the size of the inner-most 2 dimensions of `input`. If the FFT length used -// to compute `input` is odd, it should be provided since it cannot be inferred -// properly. -// -// Along each axis `IRFFT2D` is computed on, if `fft_length` (or -// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. -// -// Returns A float32 tensor of the same rank as `input`. The inner-most 2 -// dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 2D Fourier transform. 
-// -// @compatibility(numpy) -// Equivalent to np.fft.irfft2 -// @end_compatibility -func IRFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// Computes atan of x element-wise. +func Atan(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IRFFT2D", + Type: "Atan", Input: []tf.Input{ - input, fft_length, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeJpegAttr is an optional argument to DecodeJpeg. -type DecodeJpegAttr func(optionalAttr) +// AssertAttr is an optional argument to Assert. +type AssertAttr func(optionalAttr) -// DecodeJpegChannels sets the optional channels attribute to value. +// AssertSummarize sets the optional summarize attribute to value. // -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeJpegChannels(value int64) DecodeJpegAttr { +// value: Print this many entries of each tensor. +// If not specified, defaults to 3 +func AssertSummarize(value int64) AssertAttr { return func(m optionalAttr) { - m["channels"] = value + m["summarize"] = value } } -// DecodeJpegRatio sets the optional ratio attribute to value. +// Asserts that the given condition is true. // -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeJpegRatio(value int64) DecodeJpegAttr { - return func(m optionalAttr) { - m["ratio"] = value - } -} - -// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. +// If `condition` evaluates to false, print the list of tensors in `data`. +// `summarize` determines how many entries of the tensors to print. // -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). 
-// If not specified, defaults to true -func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value - } -} - -// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. +// Arguments: +// condition: The condition to evaluate. +// data: The tensors to print out when condition is false. // -// value: If true try to recover an image from truncated input. -// If not specified, defaults to false -func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value +// Returns the created operation. +func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Assert", + Input: []tf.Input{ + condition, tf.OutputList(data), + }, + Attrs: attrs, } + return scope.AddOperation(opspec) } -// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. +// LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingAdagradParametersGradAccumDebug. +type LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: The minimum required fraction of lines before a truncated -// input is accepted. 
-// If not specified, defaults to 1 -func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { +// REQUIRES: value >= -1 +func LoadTPUEmbeddingAdagradParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["acceptable_fraction"] = value + m["table_id"] = value } } -// DecodeJpegDctMethod sets the optional dct_method attribute to value. -// -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) +// LoadTPUEmbeddingAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. // If not specified, defaults to "" -func DecodeJpegDctMethod(value string) DecodeJpegAttr { +func LoadTPUEmbeddingAdagradParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["dct_method"] = value + m["table_name"] = value } } -// Decode a JPEG-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. +// Load Adagrad embedding parameters with debug support. // -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -// downscaling the image later. +// An op that loads optimization parameters into HBM for embedding. 
Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // +// Arguments: +// parameters: Value of parameters used in the Adagrad optimization algorithm. +// accumulators: Value of accumulators used in the Adagrad optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the Adagrad optimization algorithm. // -// This op also supports decoding PNGs and non-animated GIFs since the interface is -// the same, though it is cleaner to use `tf.image.decode_image`. // -// Arguments: -// contents: 0-D. The JPEG-encoded image. // -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { +// Returns the created operation. +func LoadTPUEmbeddingAdagradParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeJpeg", + Type: "LoadTPUEmbeddingAdagradParametersGradAccumDebug", Input: []tf.Input{ - contents, + parameters, accumulators, gradient_accumulators, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Inverse 3D real-valued fast Fourier transform. -// -// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most 3 dimensions of `input`. 
-// -// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: -// The inner-most dimension contains the `fft_length / 2 + 1` unique components of -// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed -// from the size of the inner-most 3 dimensions of `input`. If the FFT length used -// to compute `input` is odd, it should be provided since it cannot be inferred -// properly. -// -// Along each axis `IRFFT3D` is computed on, if `fft_length` (or -// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingFTRLParametersGradAccumDebug. +type RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve FTRL embedding parameters with debug support. // -// Returns A float32 tensor of the same rank as `input`. 
The inner-most 3 -// dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 3D real Fourier transform. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// @compatibility(numpy) -// Equivalent to np.irfftn with 3 dimensions. -// @end_compatibility -func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// Returns Parameter parameters updated by the FTRL optimization algorithm.Parameter accumulators updated by the FTRL optimization algorithm.Parameter linears updated by the FTRL optimization algorithm.Parameter gradient_accumulators updated by the FTRL optimization algorithm. +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output, gradient_accumulators tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IRFFT3D", - Input: []tf.Input{ - input, fft_length, - }, + Type: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug", + + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Returns the truth value of (x != y) element-wise. -// -// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// A dataset that splits the elements of its input into multiple elements. 
+func ExperimentalUnbatchDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "NotEqual", + Type: "ExperimentalUnbatchDataset", Input: []tf.Input{ - x, y, + input_dataset, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Produces the max pool of the input tensor for quantized types. +// StringFormatAttr is an optional argument to StringFormat. +type StringFormatAttr func(optionalAttr) + +// StringFormatTemplate sets the optional template attribute to value. // -// Arguments: -// input: The 4D (batch x rows x cols x depth) Tensor to MaxReduce over. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// ksize: The size of the window for each dimension of the input tensor. -// The length must be 4 to match the number of dimensions of the input. -// strides: The stride of the sliding window for each dimension of the input -// tensor. The length must be 4 to match the number of dimensions of the input. -// padding: The type of padding algorithm to use. +// value: A string, the template to format tensor summaries into. +// If not specified, defaults to "%s" +func StringFormatTemplate(value string) StringFormatAttr { + return func(m optionalAttr) { + m["template"] = value + } +} + +// StringFormatPlaceholder sets the optional placeholder attribute to value. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. 
-func QuantizedMaxPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { - if scope.Err() != nil { - return +// value: A string, at each placeholder in the template a subsequent tensor summary will be inserted. +// If not specified, defaults to "%s" +func StringFormatPlaceholder(value string) StringFormatAttr { + return func(m optionalAttr) { + m["placeholder"] = value } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "QuantizedMaxPool", - Input: []tf.Input{ - input, min_input, max_input, - }, - Attrs: attrs, +} + +// StringFormatSummarize sets the optional summarize attribute to value. +// +// value: When formatting the tensor summaries print the first and last summarize entries of each tensor dimension. +// If not specified, defaults to 3 +func StringFormatSummarize(value int64) StringFormatAttr { + return func(m optionalAttr) { + m["summarize"] = value } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) } -// Computes softplus: `log(exp(features) + 1)`. -func Softplus(scope *Scope, features tf.Output) (activations tf.Output) { +// Formats a string template using a list of tensors. +// +// Formats a string template using a list of tensors, pretty-printing tensor summaries. +// +// Arguments: +// inputs: The list of tensors to format into the placeholder string. +// +// Returns = The resulting string scalar. 
+func StringFormat(scope *Scope, inputs []tf.Output, optional ...StringFormatAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Softplus", + Type: "StringFormat", Input: []tf.Input{ - features, + tf.OutputList(inputs), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes exponential of x - 1 element-wise. +// Returns true if queue is closed. // -// I.e., \\(y = (\exp x) - 1\\). -func Expm1(scope *Scope, x tf.Output) (y tf.Output) { +// This operation returns true if the queue is closed and false if the queue +// is open. +// +// Arguments: +// handle: The handle to a queue. +func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Expm1", + Type: "QueueIsClosedV2", Input: []tf.Input{ - x, + handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the number of records this Reader has produced. -// -// This is the same as the number of ReaderRead executions that have -// succeeded. -// -// Arguments: -// reader_handle: Handle to a Reader. -func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_produced tf.Output) { +// Computes inverse hyperbolic tangent of x element-wise. +func Atanh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderNumRecordsProducedV2", + Type: "Atanh", Input: []tf.Input{ - reader_handle, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorListConcatAttr is an optional argument to TensorListConcat. -type TensorListConcatAttr func(optionalAttr) - -// TensorListConcatElementShape sets the optional element_shape attribute to value. 
-// If not specified, defaults to -func TensorListConcatElementShape(value tf.Shape) TensorListConcatAttr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// Concats all tensors in the list along the 0th dimension. +// Computes the reverse mode backpropagated gradient of the Cholesky algorithm. // -// Requires that all tensors have the same shape except the first dimension. +// For an explanation see "Differentiation of the Cholesky algorithm" by +// Iain Murray http://arxiv.org/abs/1602.07527. // -// input_handle: The input list. -// tensor: The concated result. -// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. +// Arguments: +// l: Output of batch Cholesky algorithm l = cholesky(A). Shape is `[..., M, M]`. +// Algorithm depends only on lower triangular part of the innermost matrices of +// this tensor. +// grad: df/dl where f is some scalar function. Shape is `[..., M, M]`. +// Algorithm depends only on lower triangular part of the innermost matrices of +// this tensor. // -func TensorListConcat(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListConcatAttr) (tensor tf.Output, lengths tf.Output) { +// Returns Symmetrized version of df/dA . Shape is `[..., M, M]` +func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "TensorListConcat", + Type: "CholeskyGrad", Input: []tf.Input{ - input_handle, + l, grad, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Returns the set of files matching one or more glob patterns. +// Assigns a new value to a variable. 
// -// Note that this routine only supports wildcard characters in the -// basename portion of the pattern, not in the directory portion. -// Note also that the order of filenames returned can be non-deterministic. +// Any ReadVariableOp with a control dependency on this op is guaranteed to return +// this value or a subsequent newer value of the variable. // // Arguments: -// pattern: Shell wildcard pattern(s). Scalar or vector of type string. +// resource: handle to the resource in which to store the variable. +// value: the value to set the new tensor to use. // -// Returns A vector of matching filenames. -func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { +// Returns the created operation. +func AssignVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatchingFiles", + Type: "AssignVariableOp", Input: []tf.Input{ - pattern, + resource, value, }, } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// HistogramFixedWidthAttr is an optional argument to HistogramFixedWidth. -type HistogramFixedWidthAttr func(optionalAttr) - -// HistogramFixedWidthDtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_INT32 -func HistogramFixedWidthDtype(value tf.DataType) HistogramFixedWidthAttr { - return func(m optionalAttr) { - m["dtype"] = value - } + return scope.AddOperation(opspec) } -// Return histogram of values. -// -// Given the tensor `values`, this operation returns a rank 1 histogram counting -// the number of entries in `values` that fall into every bin. The bins are -// equal width and determined by the arguments `value_range` and `nbins`. 
-// -// ```python -// # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) -// nbins = 5 -// value_range = [0.0, 5.0] -// new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] -// -// with tf.get_default_session() as sess: -// hist = tf.histogram_fixed_width(new_values, value_range, nbins=5) -// variables.global_variables_initializer().run() -// sess.run(hist) => [2, 1, 1, 0, 2] -// ``` -// -// Arguments: -// values: Numeric `Tensor`. -// value_range: Shape [2] `Tensor` of same `dtype` as `values`. -// values <= value_range[0] will be mapped to hist[0], -// values >= value_range[1] will be mapped to hist[-1]. -// nbins: Scalar `int32 Tensor`. Number of histogram bins. +// Returns a tensor of ones with the same shape and type as x. // -// Returns A 1-D `Tensor` holding histogram of values. -func HistogramFixedWidth(scope *Scope, values tf.Output, value_range tf.Output, nbins tf.Output, optional ...HistogramFixedWidthAttr) (out tf.Output) { +// Arguments: +// x: a tensor of type T. +// +// Returns a tensor of the same shape and type as x but filled with ones. +func OnesLike(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "HistogramFixedWidth", + Type: "OnesLike", Input: []tf.Input{ - values, value_range, nbins, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Conv3DAttr is an optional argument to Conv3D. -type Conv3DAttr func(optionalAttr) - -// Conv3DDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. 
-// If not specified, defaults to "NDHWC" -func Conv3DDataFormat(value string) Conv3DAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Conv3DDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv3DDilations(value []int64) Conv3DAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes a 3-D convolution given 5-D `input` and `filter` tensors. +// The gradient of SparseFillEmptyRows. // -// In signal processing, cross-correlation is a measure of similarity of -// two waveforms as a function of a time-lag applied to one of them. This -// is also known as a sliding dot product or sliding inner-product. +// Takes vectors reverse_index_map, shaped `[N]`, and grad_values, +// shaped `[N_full]`, where `N_full >= N` and copies data into either +// `d_values` or `d_default_value`. Here `d_values` is shaped `[N]` and +// `d_default_value` is a scalar. // -// Our Conv3D implements a form of cross-correlation. +// d_values[j] = grad_values[reverse_index_map[j]] +// d_default_value = sum_{k : 0 .. N_full - 1} ( +// grad_values[k] * 1{k not in reverse_index_map}) // // Arguments: -// input: Shape `[batch, in_depth, in_height, in_width, in_channels]`. -// filter: Shape `[filter_depth, filter_height, filter_width, in_channels, -// out_channels]`. `in_channels` must match between `input` and `filter`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. 
-func Conv3D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv3DAttr) (output tf.Output) { +// reverse_index_map: 1-D. The reverse index map from SparseFillEmptyRows. +// grad_values: 1-D. The gradients from backprop. +// +// Returns 1-D. The backprop into values.0-D. The backprop into default_value. +func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_values tf.Output) (d_values tf.Output, d_default_value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) + opspec := tf.OpSpec{ + Type: "SparseFillEmptyRowsGrad", + Input: []tf.Input{ + reverse_index_map, grad_values, + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Creates a dataset that zips together `input_datasets`. +func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Conv3D", + Type: "ZipDataset", Input: []tf.Input{ - input, filter, + tf.OutputList(input_datasets), }, Attrs: attrs, } @@ -22222,191 +26982,143 @@ func Conv3D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, pa return op.Output(0) } -// QuantizeV2Attr is an optional argument to QuantizeV2. -type QuantizeV2Attr func(optionalAttr) +// LoadTPUEmbeddingAdagradParametersAttr is an optional argument to LoadTPUEmbeddingAdagradParameters. +type LoadTPUEmbeddingAdagradParametersAttr func(optionalAttr) -// QuantizeV2Mode sets the optional mode attribute to value. -// If not specified, defaults to "MIN_COMBINED" -func QuantizeV2Mode(value string) QuantizeV2Attr { +// LoadTPUEmbeddingAdagradParametersTableId sets the optional table_id attribute to value. 
+// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingAdagradParametersTableId(value int64) LoadTPUEmbeddingAdagradParametersAttr { return func(m optionalAttr) { - m["mode"] = value + m["table_id"] = value } } -// QuantizeV2RoundMode sets the optional round_mode attribute to value. -// If not specified, defaults to "HALF_AWAY_FROM_ZERO" -func QuantizeV2RoundMode(value string) QuantizeV2Attr { +// LoadTPUEmbeddingAdagradParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingAdagradParametersTableName(value string) LoadTPUEmbeddingAdagradParametersAttr { return func(m optionalAttr) { - m["round_mode"] = value + m["table_name"] = value } } -// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. -// -// [min_range, max_range] are scalar floats that specify the range for -// the 'input' data. The 'mode' attribute controls exactly which calculations are -// used to convert the float values to their quantized equivalents. The -// 'round_mode' attribute controls which rounding tie-breaking algorithm is used -// when rounding float values to their quantized equivalents. -// -// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: -// -// ``` -// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) -// if T == qint8: out[i] -= (range(T) + 1) / 2.0 -// ``` -// -// here `range(T) = numeric_limits::max() - numeric_limits::min()` -// -// *MIN_COMBINED Mode Example* -// -// Assume the input is type float and has a possible range of [0.0, 6.0] and the -// output type is quint8 ([0, 255]). The min_range and max_range values should be -// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each -// value of the input by 255/6 and cast to quint8. 
-// -// If the output type was qint8 ([-128, 127]), the operation will additionally -// subtract each value by 128 prior to casting, so that the range of values aligns -// with the range of qint8. -// -// If the mode is 'MIN_FIRST', then this approach is used: -// -// ``` -// num_discrete_values = 1 << (# of bits in T) -// range_adjust = num_discrete_values / (num_discrete_values - 1) -// range = (range_max - range_min) * range_adjust -// range_scale = num_discrete_values / range -// quantized = round(input * range_scale) - round(range_min * range_scale) + -// numeric_limits::min() -// quantized = max(quantized, numeric_limits::min()) -// quantized = min(quantized, numeric_limits::max()) -// ``` -// -// The biggest difference between this and MIN_COMBINED is that the minimum range -// is rounded first, before it's subtracted from the rounded value. With -// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing -// and dequantizing will introduce a larger and larger error. -// -// *SCALED mode Example* -// -// `SCALED` mode matches the quantization approach used in -// `QuantizeAndDequantize{V2|V3}`. -// -// If the mode is `SCALED`, we do not use the full range of the output type, -// choosing to elide the lowest possible value for symmetry (e.g., output range is -// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to -// 0. -// -// We first find the range of values in our tensor. The -// range we use is always centered on 0, so we find m such that -// -// ```c++ -// m = max(abs(input_min), abs(input_max)) -// ``` -// -// Our input tensor range is then `[-m, m]`. -// -// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. 
-// If T is signed, this is -// -// ``` -// num_bits = sizeof(T) * 8 -// [min_fixed, max_fixed] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] -// ``` -// -// Otherwise, if T is unsigned, the fixed-point range is -// -// ``` -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] -// ``` -// -// From this we compute our scaling factor, s: -// -// ```c++ -// s = (max_fixed - min_fixed) / (2 * m) -// ``` -// -// Now we can quantize the elements of our tensor: -// -// ```c++ -// result = round(input * s) -// ``` +// Load Adagrad embedding parameters. // -// One thing to watch out for is that the operator may choose to adjust the -// requested minimum and maximum values slightly during the quantization process, -// so you should always use the output ports as the range for further calculations. -// For example, if the requested minimum and maximum values are close to equal, -// they will be separated by a small epsilon value to prevent ill-formed quantized -// buffers from being created. Otherwise, you can end up with buffers where all the -// quantized values map to the same float value, which causes problems for -// operations that have to perform further calculations on them. +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: +// parameters: Value of parameters used in the Adagrad optimization algorithm. +// accumulators: Value of accumulators used in the Adagrad optimization algorithm. // -// min_range: The minimum scalar value possibly produced for the input. -// max_range: The maximum scalar value possibly produced for the input. 
// // -// Returns The quantized data produced from the float input.The actual minimum scalar value used for the output.The actual maximum scalar value used for the output. -func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// Returns the created operation. +func LoadTPUEmbeddingAdagradParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdagradParametersAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"T": T} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeV2", + Type: "LoadTPUEmbeddingAdagradParameters", Input: []tf.Input{ - input, min_range, max_range, + parameters, accumulators, }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// Strip leading and trailing whitespaces from the Tensor. +// +// Arguments: +// input: A string `Tensor` of any shape. +// +// Returns A string `Tensor` of the same shape as the input. +func StringStrip(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StringStrip", + Input: []tf.Input{ + input, + }, + } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Returns the truth value of (x < y) element-wise. +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// *NOTE*: `Less` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Less(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// The hash function is deterministic on the content of the string within the +// process. 
The hash function is a keyed hash function, where attribute `key` +// defines the key of the hash function. `key` is an array of 2 elements. +// +// A strong hash is important when inputs may be malicious, e.g. URLs with +// additional components. Adversaries could try to make their inputs hash to the +// same bucket for a denial-of-service attack or to skew the results. A strong +// hash prevents this by making it difficult, if not infeasible, to compute inputs +// that hash to the same bucket. This comes at a cost of roughly 4x higher compute +// time than `tf.string_to_hash_bucket_fast`. +// +// Arguments: +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. +// key: The key for the keyed hash function passed as a list of two uint64 +// elements. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} opspec := tf.OpSpec{ - Type: "Less", + Type: "StringToHashBucketStrong", Input: []tf.Input{ - x, y, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizedReluXAttr is an optional argument to QuantizedReluX. -type QuantizedReluXAttr func(optionalAttr) +// StringLengthAttr is an optional argument to StringLength. +type StringLengthAttr func(optionalAttr) -// QuantizedReluXOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedReluXOutType(value tf.DataType) QuantizedReluXAttr { +// StringLengthUnit sets the optional unit attribute to value. +// +// value: The unit that is counted to compute string length. One of: `"BYTE"` (for +// the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8 +// encoded Unicode code points in each string). 
Results are undefined +// if `unit=UTF8_CHAR` and the `input` strings do not contain structurally +// valid UTF-8. +// If not specified, defaults to "BYTE" +func StringLengthUnit(value string) StringLengthAttr { return func(m optionalAttr) { - m["out_type"] = value + m["unit"] = value } } -// Computes Quantized Rectified Linear X: `min(max(features, 0), max_value)` -// -// Arguments: +// String lengths of `input`. // +// Computes the length of each string given in the input tensor. // -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. +// Arguments: +// input: The string for which to compute the length. // -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluXAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { +// Returns Integer tensor that has the same shape as `input`. The output contains the +// element-wise string lengths of `input`. +func StringLength(scope *Scope, input tf.Output, optional ...StringLengthAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -22415,111 +27127,217 @@ func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_f a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedReluX", + Type: "StringLength", Input: []tf.Input{ - features, max_value, min_features, max_features, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Creates a dataset that batches `batch_size` elements from `input_dataset`. +// Performs gradient updates of embedding tables. 
// // Arguments: +// inputs: A TensorList of gradients with which to update embedding tables. +// This argument has the same length and shapes as the return value of +// RecvTPUEmbeddingActivations, but contains gradients of the model's loss +// with respect to the embedding activations. The embedding tables are updated +// from these gradients via the optimizer specified in the TPU embedding +// configuration given to tpu.initialize_system. +// learning_rates: A TensorList of float32 scalars, one for each dynamic learning +// rate tag: see the comments in +// //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto. +// Multiple tables can share the same dynamic learning rate tag as specified +// in the configuration. If the learning rates for all tables are constant, +// this list should be empty. +// config: Serialized TPUEmbeddingConfiguration proto. // -// batch_size: A scalar representing the number of elements to accumulate in a batch. -// drop_remainder: A scalar representing whether the last batch should be dropped in case its size -// is smaller than desired. -// -// -func BatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns the created operation. +func SendTPUEmbeddingGradients(scope *Scope, inputs []tf.Output, learning_rates []tf.Output, config string) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"config": config} opspec := tf.OpSpec{ - Type: "BatchDatasetV2", + Type: "SendTPUEmbeddingGradients", Input: []tf.Input{ - input_dataset, batch_size, drop_remainder, + tf.OutputList(inputs), tf.OutputList(learning_rates), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Saves the input tensors to disk. 
-// -// The size of `tensor_names` must match the number of tensors in `data`. `data[i]` -// is written to `filename` with name `tensor_names[i]`. -// -// See also `SaveSlices`. -// -// Arguments: -// filename: Must have a single element. The name of the file to which we write -// the tensor. -// tensor_names: Shape `[N]`. The names of the tensors to be saved. -// data: `N` tensors to save. +// Computes numerical negative value element-wise. // -// Returns the created operation. -func Save(scope *Scope, filename tf.Output, tensor_names tf.Output, data []tf.Output) (o *tf.Operation) { +// I.e., \\(y = -x\\). +func Neg(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Save", + Type: "Neg", Input: []tf.Input{ - filename, tensor_names, tf.OutputList(data), + x, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns element-wise remainder of division. When `x < 0` xor `y < 0` is +// Receives a tensor value broadcast from another device. +func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T, "group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} + opspec := tf.OpSpec{ + Type: "CollectiveBcastRecv", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Decode web-safe base64-encoded strings. // -// true, this follows Python semantics in that the result here is consistent -// with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. +// Input may or may not have padding at the end. See EncodeBase64 for padding. +// Web-safe means that input must use - and _ instead of + and /. // -// *NOTE*: `FloorMod` supports broadcasting. 
More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func FloorMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// input: Base64 strings to decode. +// +// Returns Decoded strings. +func DecodeBase64(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FloorMod", + Type: "DecodeBase64", Input: []tf.Input{ - x, y, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the gradient of morphological 2-D dilation with respect to the filter. +// SubstrAttr is an optional argument to Substr. +type SubstrAttr func(optionalAttr) + +// SubstrUnit sets the optional unit attribute to value. +// +// value: The unit that is used to create the substring. One of: `"BYTE"` (for +// defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8 +// encoded Unicode code points). The default is `"BYTE"`. Results are undefined if +// `unit=UTF8_CHAR` and the `input` strings do not contain structurally valid +// UTF-8. +// If not specified, defaults to "BYTE" +func SubstrUnit(value string) SubstrAttr { + return func(m optionalAttr) { + m["unit"] = value + } +} + +// Return substrings from `Tensor` of strings. +// +// For each string in the input `Tensor`, creates a substring starting at index +// `pos` with a total length of `len`. +// +// If `len` defines a substring that would extend beyond the length of the input +// string, then as many characters as possible are used. +// +// A negative `pos` indicates distance within the string backwards from the end. +// +// If `pos` specifies an index which is out of range for any of the input strings, +// then an `InvalidArgumentError` is thrown. +// +// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on +// Op creation. +// +// *NOTE*: `Substr` supports broadcasting up to two dimensions. 
More about +// broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +// +// --- +// +// Examples +// +// Using scalar `pos` and `len`: +// +// ```python +// input = [b'Hello', b'World'] +// position = 1 +// length = 3 +// +// output = [b'ell', b'orl'] +// ``` +// +// Using `pos` and `len` with same shape as `input`: +// +// ```python +// input = [[b'ten', b'eleven', b'twelve'], +// [b'thirteen', b'fourteen', b'fifteen'], +// [b'sixteen', b'seventeen', b'eighteen']] +// position = [[1, 2, 3], +// [1, 2, 3], +// [1, 2, 3]] +// length = [[2, 3, 4], +// [4, 3, 2], +// [5, 5, 5]] +// +// output = [[b'en', b'eve', b'lve'], +// [b'hirt', b'urt', b'te'], +// [b'ixtee', b'vente', b'hteen']] +// ``` +// +// Broadcasting `pos` and `len` onto `input`: +// +// ``` +// input = [[b'ten', b'eleven', b'twelve'], +// [b'thirteen', b'fourteen', b'fifteen'], +// [b'sixteen', b'seventeen', b'eighteen'], +// [b'nineteen', b'twenty', b'twentyone']] +// position = [1, 2, 3] +// length = [1, 2, 3] +// +// output = [[b'e', b'ev', b'lve'], +// [b'h', b'ur', b'tee'], +// [b'i', b've', b'hte'], +// [b'i', b'en', b'nty']] +// ``` +// +// Broadcasting `input` onto `pos` and `len`: +// +// ``` +// input = b'thirteen' +// position = [1, 5, 7] +// length = [3, 2, 1] +// +// output = [b'hir', b'ee', b'n'] +// ``` // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. -// strides: 1-D of length 4. The stride of the sliding window for each dimension of -// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: 1-D of length 4. The input stride for atrous morphological dilation. -// Must be: `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. 
+// input: Tensor of strings +// pos: Scalar defining the position of first character in each substring +// len: Scalar defining the number of characters to include in each substring // -// Returns 3-D with shape `[filter_height, filter_width, depth]`. -func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (filter_backprop tf.Output) { +// Returns Tensor of substrings +func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output, optional ...SubstrAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Dilation2DBackpropFilter", + Type: "Substr", Input: []tf.Input{ - input, filter, out_backprop, + input, pos, len, }, Attrs: attrs, } @@ -22527,386 +27345,294 @@ func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, o return op.Output(0) } -// Returns a list list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. +// Exits the current frame to its parent frame. // -// tensor: The tensor to put on the list. -// input_handle: The old list. -// output_handle: A list with the elements of the old list followed by tensor. -// element_dtype: the type of elements in the list. -// element_shape: a shape compatible with that of elements in the list. -func TensorListPushBack(scope *Scope, input_handle tf.Output, tensor tf.Output) (output_handle tf.Output) { +// Exit makes its input `data` available to the parent frame. +// +// Arguments: +// data: The tensor to be made available to the parent frame. +// +// Returns The same tensor as `data`. 
+func Exit(scope *Scope, data tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListPushBack", + Type: "Exit", Input: []tf.Input{ - input_handle, tensor, + data, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// AddSparseToTensorsMapAttr is an optional argument to AddSparseToTensorsMap. -type AddSparseToTensorsMapAttr func(optionalAttr) +// RetrieveTPUEmbeddingProximalAdagradParametersAttr is an optional argument to RetrieveTPUEmbeddingProximalAdagradParameters. +type RetrieveTPUEmbeddingProximalAdagradParametersAttr func(optionalAttr) -// AddSparseToTensorsMapContainer sets the optional container attribute to value. +// RetrieveTPUEmbeddingProximalAdagradParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: The container name for the `SparseTensorsMap` created by this op. -// If not specified, defaults to "" -func AddSparseToTensorsMapContainer(value string) AddSparseToTensorsMapAttr { +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingProximalAdagradParametersTableId(value int64) RetrieveTPUEmbeddingProximalAdagradParametersAttr { return func(m optionalAttr) { - m["container"] = value + m["table_id"] = value } } -// AddSparseToTensorsMapSharedName sets the optional shared_name attribute to value. -// -// value: The shared name for the `SparseTensorsMap` created by this op. -// If blank, the new Operation's unique name is used. +// RetrieveTPUEmbeddingProximalAdagradParametersTableName sets the optional table_name attribute to value. // If not specified, defaults to "" -func AddSparseToTensorsMapSharedName(value string) AddSparseToTensorsMapAttr { +func RetrieveTPUEmbeddingProximalAdagradParametersTableName(value string) RetrieveTPUEmbeddingProximalAdagradParametersAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["table_name"] = value } } -// Add a `SparseTensor` to a `SparseTensorsMap` return its handle. 
-// -// A `SparseTensor` is represented by three tensors: `sparse_indices`, -// `sparse_values`, and `sparse_shape`. -// -// This operator takes the given `SparseTensor` and adds it to a container -// object (a `SparseTensorsMap`). A unique key within this container is generated -// in the form of an `int64`, and this is the value that is returned. -// -// The `SparseTensor` can then be read out as part of a minibatch by passing -// the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure -// the correct `SparseTensorsMap` is accessed, ensure that the same -// `container` and `shared_name` are passed to that Op. If no `shared_name` -// is provided here, instead use the *name* of the Operation created by calling -// `AddSparseToTensorsMap` as the `shared_name` passed to -// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. +// Retrieve proximal Adagrad embedding parameters. // -// Arguments: -// sparse_indices: 2-D. The `indices` of the `SparseTensor`. -// sparse_values: 1-D. The `values` of the `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the `SparseTensor`. +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// Returns 0-D. The handle of the `SparseTensor` now stored in the -// `SparseTensorsMap`. -func AddSparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddSparseToTensorsMapAttr) (sparse_handle tf.Output) { +// Returns Parameter parameters updated by the proximal Adagrad optimization algorithm.Parameter accumulators updated by the proximal Adagrad optimization algorithm. 
+func RetrieveTPUEmbeddingProximalAdagradParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingProximalAdagradParametersAttr) (parameters tf.Output, accumulators tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AddSparseToTensorsMap", - Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, - }, + Type: "RetrieveTPUEmbeddingProximalAdagradParameters", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Deserialize and concatenate `SparseTensors` from a serialized minibatch. -// -// The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where -// `N` is the minibatch size and the rows correspond to packed outputs of -// `SerializeSparse`. The ranks of the original `SparseTensor` objects -// must all match. When the final `SparseTensor` is created, it has rank one -// higher than the ranks of the incoming `SparseTensor` objects -// (they have been concatenated along a new row dimension). -// -// The output `SparseTensor` object's shape values for all dimensions but the -// first are the max across the input `SparseTensor` objects' shape values -// for the corresponding dimensions. Its first shape value is `N`, the minibatch -// size. -// -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. 
-// -// For example, if the serialized input is a `[2 x 3]` matrix representing two -// original `SparseTensor` objects: -// -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] -// -// and -// -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] -// -// then the final deserialized `SparseTensor` will be: +// Produce a string tensor that encodes the state of a Reader. // -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] +// Not all Readers support being serialized, so this can produce an +// Unimplemented error. // // Arguments: -// serialized_sparse: 2-D, The `N` serialized `SparseTensor` objects. -// Must have 3 columns. -// dtype: The `dtype` of the serialized `SparseTensor` objects. -func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { +// reader_handle: Handle to a Reader. +func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "DeserializeManySparse", + Type: "ReaderSerializeStateV2", Input: []tf.Input{ - serialized_sparse, + reader_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Elementwise computes the bitwise AND of `x` and `y`. +// Returns the number of tensors in the input tensor list. // -// The result will have those bits set, that are set in both `x` and `y`. The -// computation is performed on the underlying representations of `x` and `y`. 
-func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// input_handle: the input list +// length: the number of tensors in the list +func TensorListLength(scope *Scope, input_handle tf.Output) (length tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BitwiseAnd", + Type: "TensorListLength", Input: []tf.Input{ - x, y, + input_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Inverse real-valued fast Fourier transform. -// -// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most dimension of `input`. -// -// The inner-most dimension of `input` is assumed to be the result of `RFFT`: the -// `fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If -// `fft_length` is not provided, it is computed from the size of the inner-most -// dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to -// compute `input` is odd, it should be provided since it cannot be inferred -// properly. -// -// Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller -// than the corresponding dimension of `input`, the dimension is cropped. If it is -// larger, the dimension is padded with zeros. +// Creates a dataset with a range of values. Corresponds to python's xrange. // // Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [1]. The FFT length. +// start: corresponds to start in python's xrange(). +// stop: corresponds to stop in python's xrange(). +// step: corresponds to step in python's xrange(). // -// Returns A float32 tensor of the same rank as `input`. The inner-most -// dimension of `input` is replaced with the `fft_length` samples of its inverse -// 1D Fourier transform. 
// -// @compatibility(numpy) -// Equivalent to np.fft.irfft -// @end_compatibility -func IRFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +func RangeDataset(scope *Scope, start tf.Output, stop tf.Output, step tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "IRFFT", + Type: "RangeDataset", Input: []tf.Input{ - input, fft_length, + start, stop, step, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Concatenates a list of `SparseTensor` along the specified dimension. -// -// Concatenation is with respect to the dense versions of these sparse tensors. -// It is assumed that each input is a `SparseTensor` whose elements are ordered -// along increasing dimension number. -// -// All inputs' shapes must match, except for the concat dimension. The -// `indices`, `values`, and `shapes` lists must have the same length. -// -// The output shape is identical to the inputs', except along the concat -// dimension, where it is the sum of the inputs' sizes along that dimension. -// -// The output elements will be resorted to preserve the sort order along -// increasing dimension number. -// -// This op runs in `O(M log M)` time, where `M` is the total number of non-empty -// values across all inputs. This is due to the need for an internal sort in -// order to concatenate efficiently across an arbitrary dimension. 
-// -// For example, if `concat_dim = 1` and the inputs are -// -// sp_inputs[0]: shape = [2, 3] -// [0, 2]: "a" -// [1, 0]: "b" -// [1, 1]: "c" -// -// sp_inputs[1]: shape = [2, 4] -// [0, 1]: "d" -// [0, 2]: "e" -// -// then the output will be -// -// shape = [2, 7] -// [0, 2]: "a" -// [0, 4]: "d" -// [0, 5]: "e" -// [1, 0]: "b" -// [1, 1]: "c" -// -// Graphically this is equivalent to doing -// -// [ a] concat [ d e ] = [ a d e ] -// [b c ] [ ] [b c ] -// -// Arguments: -// indices: 2-D. Indices of each input `SparseTensor`. -// values: 1-D. Non-empty values of each `SparseTensor`. -// shapes: 1-D. Shapes of each `SparseTensor`. -// concat_dim: Dimension to concatenate along. Must be in range [-rank, rank), -// where rank is the number of dimensions in each input `SparseTensor`. -// -// Returns 2-D. Indices of the concatenated `SparseTensor`.1-D. Non-empty values of the concatenated `SparseTensor`.1-D. Shape of the concatenated `SparseTensor`. -func SparseConcat(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, concat_dim int64) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// Computes inverse hyperbolic sine of x element-wise. +func Asinh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"concat_dim": concat_dim} opspec := tf.OpSpec{ - Type: "SparseConcat", + Type: "Asinh", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Generates sparse cross from a list of sparse and dense tensors. -// -// The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each -// representing features of one feature column. It outputs a 2D `SparseTensor` with -// the batchwise crosses of these features. +// UnicodeTranscodeAttr is an optional argument to UnicodeTranscode. 
+type UnicodeTranscodeAttr func(optionalAttr) + +// UnicodeTranscodeErrors sets the optional errors attribute to value. // -// For example, if the inputs are +// value: Error handling policy when there is invalid formatting found in the input. +// The value of 'strict' will cause the operation to produce a InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. +// If not specified, defaults to "replace" +func UnicodeTranscodeErrors(value string) UnicodeTranscodeAttr { + return func(m optionalAttr) { + m["errors"] = value + } +} + +// UnicodeTranscodeReplacementChar sets the optional replacement_char attribute to value. // -// inputs[0]: SparseTensor with shape = [2, 2] -// [0, 0]: "a" -// [1, 0]: "b" -// [1, 1]: "c" +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. The default value is the default unicode replacement character is +// 0xFFFD or U+65533.) // -// inputs[1]: SparseTensor with shape = [2, 1] -// [0, 0]: "d" -// [1, 0]: "e" +// Note that for UTF-8, passing a replacement character expressible in 1 byte, such +// as ' ', will preserve string alignment to the source since invalid bytes will be +// replaced with a 1-byte replacement. For UTF-16-BE and UTF-16-LE, any 1 or 2 byte +// replacement character will preserve byte alignment to the source. 
+// If not specified, defaults to 65533 +func UnicodeTranscodeReplacementChar(value int64) UnicodeTranscodeAttr { + return func(m optionalAttr) { + m["replacement_char"] = value + } +} + +// UnicodeTranscodeReplaceControlCharacters sets the optional replace_control_characters attribute to value. // -// inputs[2]: Tensor [["f"], ["g"]] +// value: Whether to replace the C0 control characters (00-1F) with the +// `replacement_char`. Default is false. +// If not specified, defaults to false +func UnicodeTranscodeReplaceControlCharacters(value bool) UnicodeTranscodeAttr { + return func(m optionalAttr) { + m["replace_control_characters"] = value + } +} + +// Transcode the input text from a source encoding to a destination encoding. // -// then the output will be +// The input is a string tensor of any shape. The output is a string tensor of +// the same shape containing the transcoded strings. Output strings are always +// valid unicode. If the input contains invalid encoding positions, the +// `errors` attribute sets the policy for how to deal with them. If the default +// error-handling policy is used, invalid formatting will be substituted in the +// output by the `replacement_char`. If the errors policy is to `ignore`, any +// invalid encoding positions in the input are skipped and not included in the +// output. If it set to `strict` then any invalid formatting will result in an +// InvalidArgument error. // -// shape = [2, 2] -// [0, 0]: "a_X_d_X_f" -// [1, 0]: "b_X_e_X_g" -// [1, 1]: "c_X_e_X_g" +// This operation can be used with `output_encoding = input_encoding` to enforce +// correct formatting for inputs even if they are already in the desired encoding. // -// if hashed_output=true then the output will be +// If the input is prefixed by a Byte Order Mark needed to determine encoding +// (e.g. if the encoding is UTF-16 and the BOM indicates big-endian), then that +// BOM will be consumed and not emitted into the output. 
If the input encoding +// is marked with an explicit endianness (e.g. UTF-16-BE), then the BOM is +// interpreted as a non-breaking-space and is preserved in the output (including +// always for UTF-8). // -// shape = [2, 2] -// [0, 0]: FingerprintCat64( -// Fingerprint64("f"), FingerprintCat64( -// Fingerprint64("d"), Fingerprint64("a"))) -// [1, 0]: FingerprintCat64( -// Fingerprint64("g"), FingerprintCat64( -// Fingerprint64("e"), Fingerprint64("b"))) -// [1, 1]: FingerprintCat64( -// Fingerprint64("g"), FingerprintCat64( -// Fingerprint64("e"), Fingerprint64("c"))) +// The end result is that if the input is marked as an explicit endianness the +// transcoding is faithful to all codepoints in the source. If it is not marked +// with an explicit endianness, the BOM is not considered part of the string itself +// but as metadata, and so is not preserved in the output. // // Arguments: -// indices: 2-D. Indices of each input `SparseTensor`. -// values: 1-D. values of each `SparseTensor`. -// shapes: 1-D. Shapes of each `SparseTensor`. -// dense_inputs: 2-D. Columns represented by dense `Tensor`. -// hashed_output: If true, returns the hash of the cross instead of the string. -// This will allow us avoiding string manipulations. -// num_buckets: It is used if hashed_output is true. -// output = hashed_value%num_buckets if num_buckets > 0 else hashed_value. -// hash_key: Specify the hash_key that will be used by the `FingerprintCat64` -// function to combine the crosses fingerprints. -// -// +// input: The text to be processed. Can have any shape. +// input_encoding: Text encoding of the input strings. This is any of the encodings supported +// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. +// output_encoding: The unicode encoding to use in the output. Must be one of +// `"UTF-8", "UTF-16-BE", "UTF-32-BE"`. Multi-byte encodings will be big-endian. // -// Returns 2-D. Indices of the concatenated `SparseTensor`.1-D. 
Non-empty values of the concatenated or hashed -// `SparseTensor`.1-D. Shape of the concatenated `SparseTensor`. -func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, dense_inputs []tf.Output, hashed_output bool, num_buckets int64, hash_key int64, out_type tf.DataType, internal_type tf.DataType) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// Returns A string tensor containing unicode text encoded using `output_encoding`. +func UnicodeTranscode(scope *Scope, input tf.Output, input_encoding string, output_encoding string, optional ...UnicodeTranscodeAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"hashed_output": hashed_output, "num_buckets": num_buckets, "hash_key": hash_key, "out_type": out_type, "internal_type": internal_type} + attrs := map[string]interface{}{"input_encoding": input_encoding, "output_encoding": output_encoding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseCross", + Type: "UnicodeTranscode", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), tf.OutputList(dense_inputs), + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. -type ResourceApplyProximalAdagradAttr func(optionalAttr) +// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. +type ResourceApplyRMSPropAttr func(optionalAttr) -// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. // -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. 
+// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. // If not specified, defaults to false -func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { +func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. +// Update '*var' according to the RMSProp algorithm. // -// accum += grad * grad -// prox_v = var - lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom // // Arguments: // var_: Should be from a Variable(). -// accum: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// +// epsilon: Ridge term. Must be a scalar. // grad: The gradient. // // Returns the created operation. 
-func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { +func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -22915,247 +27641,279 @@ func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyProximalAdagrad", + Type: "ResourceApplyRMSProp", Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, + var_, ms, mom, lr, rho, momentum, epsilon, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. -type MutableHashTableOfTensorsV2Attr func(optionalAttr) +// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. +type StatelessTruncatedNormalAttr func(optionalAttr) -// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. +// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. // -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { return func(m optionalAttr) { - m["container"] = value + m["dtype"] = value } } -// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. +// Outputs deterministic pseudorandom values from a truncated normal distribution. 
// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// If not specified, defaults to false -func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. -// If not specified, defaults to <> -func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["value_shape"] = value - } -} - -// Creates an empty hash table. +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. // -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a vector. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. +// The outputs are a deterministic function of `shape` and `seed`. // // Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). // -// Returns Handle to a table. -func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { +// Returns Random values with specified shape. 
+func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MutableHashTableOfTensorsV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// The gradient operator for the SparseSlice op. -// -// This op takes in the upstream gradient w.r.t. non-empty values of -// the sliced `SparseTensor`, and outputs the gradients w.r.t. -// the non-empty values of input `SparseTensor`. -// -// Arguments: -// backprop_val_grad: 1-D. The gradient with respect to -// the non-empty values of the sliced `SparseTensor`. -// input_indices: 2-D. The `indices` of the input `SparseTensor`. -// input_start: 1-D. tensor represents the start of the slice. -// output_indices: 2-D. The `indices` of the sliced `SparseTensor`. -// -// Returns 1-D. The gradient with respect to the non-empty values of input `SparseTensor`. -func SparseSliceGrad(scope *Scope, backprop_val_grad tf.Output, input_indices tf.Output, input_start tf.Output, output_indices tf.Output) (val_grad tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSliceGrad", + Type: "StatelessTruncatedNormal", Input: []tf.Input{ - backprop_val_grad, input_indices, input_start, output_indices, + shape, seed, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the gradient of the sigmoid of `x` wrt its input. +// RestoreSliceAttr is an optional argument to RestoreSlice. +type RestoreSliceAttr func(optionalAttr) + +// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. // -// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and -// `dy` is the corresponding input gradient. 
-func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SigmoidGrad", - Input: []tf.Input{ - y, dy, - }, +// value: Index of file to open first if multiple files match +// `file_pattern`. See the documentation for `Restore`. +// If not specified, defaults to -1 +func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { + return func(m optionalAttr) { + m["preferred_shard"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Convert one or more images from HSV to RGB. +// Restores a tensor from checkpoint files. // -// Outputs a tensor of the same shape as the `images` tensor, containing the RGB -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. +// This is like `Restore` except that restored tensor can be listed as filling +// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the +// larger tensor and the slice that the restored tensor covers. // -// See `rgb_to_hsv` for a description of the HSV encoding. +// The `shape_and_slice` input has the same format as the +// elements of the `shapes_and_slices` input of the `SaveSlices` op. // // Arguments: -// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// file_pattern: Must have a single element. The pattern of the files from +// which we read the tensor. +// tensor_name: Must have a single element. The name of the tensor to be +// restored. +// shape_and_slice: Scalar. The shapes and slice specifications to use when +// restoring a tensors. +// dt: The type of the tensor to be restored. // -// Returns `images` converted to RGB. -func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { +// Returns The restored tensor. 
+func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dt": dt} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "HSVToRGB", + Type: "RestoreSlice", Input: []tf.Input{ - images, + file_pattern, tensor_name, shape_and_slice, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset by applying optimizations to `input_dataset`. +// Convert the quantized 'input' tensor into a lower-precision 'output', using the // -// Creates a dataset by applying optimizations to `input_dataset`. +// actual distribution of the values to maximize the usage of the lower bit depth +// and adjusting the output min and max ranges accordingly. +// +// [input_min, input_max] are scalar floats that specify the range for the float +// interpretation of the 'input' data. For example, if input_min is -1.0f and +// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 +// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. +// +// This operator tries to squeeze as much precision as possible into an output with +// a lower bit depth by calculating the actual min and max values found in the +// data. For example, maybe that quint16 input has no values lower than 16,384 and +// none higher than 49,152. That means only half the range is actually needed, all +// the float interpretations are between -0.5f and 0.5f, so if we want to compress +// the data into a quint8 output, we can use that range rather than the theoretical +// -1.0f to 1.0f that is suggested by the input min and max. 
+// +// In practice, this is most useful for taking output from operations like +// QuantizedMatMul that can produce higher bit-depth outputs than their inputs and +// may have large potential output ranges, but in practice have a distribution of +// input values that only uses a small fraction of the possible range. By feeding +// that output into this operator, we can reduce it from 32 bits down to 8 with +// minimal loss of accuracy. // // Arguments: -// input_dataset: A variant tensor representing the input dataset. -// optimizations: A `tf.string` vector `tf.Tensor` identifying optimizations to use. // +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// out_type: The type of the output. Should be a lower bit depth than Tinput. // -func OptimizeDataset(scope *Scope, input_dataset tf.Output, optimizations tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"out_type": out_type} opspec := tf.OpSpec{ - Type: "OptimizeDataset", + Type: "QuantizeDownAndShrinkRange", Input: []tf.Input{ - input_dataset, optimizations, + input, input_min, input_max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Returns the element-wise min of two SparseTensors. +// RandomGammaAttr is an optional argument to RandomGamma. 
+type RandomGammaAttr func(optionalAttr) + +// RandomGammaSeed sets the optional seed attribute to value. // -// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomGammaSeed(value int64) RandomGammaAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomGammaSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomGammaSeed2(value int64) RandomGammaAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from the Gamma distribution(s) described by alpha. +// +// This op uses the algorithm by Marsaglia et al. to acquire samples via +// transformation-rejection from pairs of uniform and normal random variables. +// See http://dl.acm.org/citation.cfm?id=358414 // // Arguments: -// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, in the canonical lexicographic ordering. -// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. -// a_shape: 1-D. Shape of the input SparseTensor. -// b_indices: counterpart to `a_indices` for the other operand. -// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. -// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. +// shape: 1-D integer tensor. Shape of independent samples to draw from each +// distribution described by the shape parameters given in alpha. +// alpha: A tensor in which each scalar is a "shape" parameter describing the +// associated gamma distribution. // -// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. 
-func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { +// Returns A tensor with shape `shape + shape(alpha)`. Each slice +// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for +// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. +func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseSparseMinimum", + Type: "RandomGamma", Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, + shape, alpha, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// ResourceApplyAdamWithAmsgradAttr is an optional argument to ResourceApplyAdamWithAmsgrad. -type ResourceApplyAdamWithAmsgradAttr func(optionalAttr) +// ResourceScatterNdSubAttr is an optional argument to ResourceScatterNdSub. +type ResourceScatterNdSubAttr func(optionalAttr) -// ResourceApplyAdamWithAmsgradUseLocking sets the optional use_locking attribute to value. +// ResourceScatterNdSubUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdamWithAmsgradUseLocking(value bool) ResourceApplyAdamWithAmsgradAttr { +// value: An optional bool. Defaults to True. If True, the assignment will +// be protected by a lock; otherwise the behavior is undefined, +// but may exhibit less contention. 
+// If not specified, defaults to true +func ResourceScatterNdSubUseLocking(value bool) ResourceScatterNdSubAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the Adam algorithm. +// Applies sparse subtraction to individual values or slices in a Variable. // -// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ -// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ -// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ -// $$vhat_t := max{vhat_{t-1}, v_t}$$ -// $$variable := variable - lr_t * m_t / (\sqrt{vhat_t} + \epsilon)$$ +// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// +// `indices` must be integer tensor, containing indices into `ref`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. +// +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +// dimension of `ref`. +// +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// ``` +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] +// ``` +// +// For example, say we want to subtract 4 scattered elements from a rank-1 tensor +// with 8 elements. In Python, that subtraction would look like this: +// +// ```python +// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// sub = tf.scatter_nd_sub(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(sub) +// ``` +// +// The resulting update to ref would look like this: +// +// [1, -9, 3, -6, -4, 6, 7, -4] +// +// See `tf.scatter_nd` for more details about how to make updates to +// slices. // // Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// vhat: Should be from a Variable(). -// beta1_power: Must be a scalar. 
-// beta2_power: Must be a scalar. -// lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. +// ref: A resource handle. Must be from a VarHandleOp. +// indices: A Tensor. Must be one of the following types: int32, int64. +// A tensor of indices into ref. +// updates: A Tensor. Must have the same type as ref. A tensor of +// values to add to ref. // // Returns the created operation. -func ResourceApplyAdamWithAmsgrad(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, vhat tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamWithAmsgradAttr) (o *tf.Operation) { +func ResourceScatterNdSub(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdSubAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -23164,461 +27922,474 @@ func ResourceApplyAdamWithAmsgrad(scope *Scope, var_ tf.Output, m tf.Output, v t a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdamWithAmsgrad", + Type: "ResourceScatterNdSub", Input: []tf.Input{ - var_, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + ref, indices, updates, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// MapUnstageNoKeyAttr is an optional argument to MapUnstageNoKey. -type MapUnstageNoKeyAttr func(optionalAttr) - -// MapUnstageNoKeyCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Outputs deterministic pseudorandom random integers from a uniform distribution. // -// REQUIRES: value >= 0 -func MapUnstageNoKeyCapacity(value int64) MapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. 
-// If not specified, defaults to 0 +// The generated values follow a uniform distribution in the range `[minval, maxval)`. // -// REQUIRES: value >= 0 -func MapUnstageNoKeyMemoryLimit(value int64) MapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["memory_limit"] = value +// The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// minval: Minimum value (inclusive, scalar). +// maxval: Maximum value (exclusive, scalar). +// +// Returns Random values with specified shape. +func StatelessRandomUniformInt(scope *Scope, shape tf.Output, seed tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StatelessRandomUniformInt", + Input: []tf.Input{ + shape, seed, minval, maxval, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapUnstageNoKeyContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapUnstageNoKeyContainer(value string) MapUnstageNoKeyAttr { +// QuantizedConv2DAttr is an optional argument to QuantizedConv2D. +type QuantizedConv2DAttr func(optionalAttr) + +// QuantizedConv2DOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { return func(m optionalAttr) { - m["container"] = value + m["out_type"] = value } } -// MapUnstageNoKeySharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapUnstageNoKeySharedName(value string) MapUnstageNoKeyAttr { +// QuantizedConv2DDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. 
The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["dilations"] = value } } -// Op removes and returns a random (key, value) +// Computes a 2D convolution given quantized 4D input and filter tensors. // -// from the underlying container. If the underlying container -// does not contain elements, the op will block until it does. -func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { +// The inputs are quantized tensors where the lowest value represents the real +// number of the associated minimum, and the highest represents the maximum. +// This means that you can only interpret the quantized output in the same way, by +// taking the returned minimum and maximum values into account. +// +// Arguments: +// +// filter: filter's input_depth dimension must match input's depth dimensions. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// min_filter: The float value that the lowest quantized filter value represents. +// max_filter: The float value that the highest quantized filter value represents. +// strides: The stride of the sliding window for each dimension of the input +// tensor. +// padding: The type of padding algorithm to use. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. 
+func QuantizedConv2D(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedConv2DAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapUnstageNoKey", + Type: "QuantizedConv2D", Input: []tf.Input{ - indices, + input, filter, min_input, max_input, min_filter, max_filter, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - key = op.Output(idx) - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapUnstageNoKey", err) - return - } - return key, values + return op.Output(0), op.Output(1), op.Output(2) } -// HashTableV2Attr is an optional argument to HashTableV2. -type HashTableV2Attr func(optionalAttr) +// ResourceGatherAttr is an optional argument to ResourceGather. +type ResourceGatherAttr func(optionalAttr) -// HashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func HashTableV2Container(value string) HashTableV2Attr { +// ResourceGatherBatchDims sets the optional batch_dims attribute to value. +// If not specified, defaults to 0 +func ResourceGatherBatchDims(value int64) ResourceGatherAttr { return func(m optionalAttr) { - m["container"] = value + m["batch_dims"] = value } } -// HashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. 
-// If not specified, defaults to "" -func HashTableV2SharedName(value string) HashTableV2Attr { +// ResourceGatherValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func ResourceGatherValidateIndices(value bool) ResourceGatherAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["validate_indices"] = value } } -// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// Gather slices from the variable pointed to by `resource` according to `indices`. // -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates a non-initialized hash table. +// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: // -// This op creates a hash table, specifying the type of its keys and values. -// Before using the table you will have to initialize it. After initialization the -// table will be immutable. +// ```python +// # Scalar indices +// output[:, ..., :] = params[indices, :, ... :] // -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. +// # Vector indices +// output[i, :, ..., :] = params[indices[i], :, ... :] // -// Returns Handle to a table. -func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { +// # Higher rank indices +// output[i, ..., j, :, ... 
:] = params[indices[i, ..., j], :, ..., :] +// ``` +func ResourceGather(scope *Scope, resource tf.Output, indices tf.Output, dtype tf.DataType, optional ...ResourceGatherAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "HashTableV2", - + Type: "ResourceGather", + Input: []tf.Input{ + resource, indices, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// TakeManySparseFromTensorsMapAttr is an optional argument to TakeManySparseFromTensorsMap. -type TakeManySparseFromTensorsMapAttr func(optionalAttr) +// StatelessMultinomialAttr is an optional argument to StatelessMultinomial. +type StatelessMultinomialAttr func(optionalAttr) -// TakeManySparseFromTensorsMapContainer sets the optional container attribute to value. -// -// value: The container name for the `SparseTensorsMap` read by this op. -// If not specified, defaults to "" -func TakeManySparseFromTensorsMapContainer(value string) TakeManySparseFromTensorsMapAttr { +// StatelessMultinomialOutputDtype sets the optional output_dtype attribute to value. +// If not specified, defaults to DT_INT64 +func StatelessMultinomialOutputDtype(value tf.DataType) StatelessMultinomialAttr { return func(m optionalAttr) { - m["container"] = value + m["output_dtype"] = value } } -// TakeManySparseFromTensorsMapSharedName sets the optional shared_name attribute to value. +// Draws samples from a multinomial distribution. // -// value: The shared name for the `SparseTensorsMap` read by this op. -// It should not be blank; rather the `shared_name` or unique Operation name -// of the Op that created the original `SparseTensorsMap` should be used. 
-// If not specified, defaults to "" -func TakeManySparseFromTensorsMapSharedName(value string) TakeManySparseFromTensorsMapAttr { - return func(m optionalAttr) { - m["shared_name"] = value +// Arguments: +// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` +// represents the unnormalized log probabilities for all classes. +// num_samples: 0-D. Number of independent samples to draw for each row slice. +// seed: 2 seeds (shape [2]). +// +// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` +// contains the drawn class labels with range `[0, num_classes)`. +func StatelessMultinomial(scope *Scope, logits tf.Output, num_samples tf.Output, seed tf.Output, optional ...StatelessMultinomialAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatelessMultinomial", + Input: []tf.Input{ + logits, num_samples, seed, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Read `SparseTensors` from a `SparseTensorsMap` and concatenate them. -// -// The input `sparse_handles` must be an `int64` matrix of shape `[N, 1]` where -// `N` is the minibatch size and the rows correspond to the output handles of -// `AddSparseToTensorsMap` or `AddManySparseToTensorsMap`. The ranks of the -// original `SparseTensor` objects that went into the given input ops must all -// match. When the final `SparseTensor` is created, it has rank one -// higher than the ranks of the incoming `SparseTensor` objects -// (they have been concatenated along a new row dimension on the left). -// -// The output `SparseTensor` object's shape values for all dimensions but the -// first are the max across the input `SparseTensor` objects' shape values -// for the corresponding dimensions. Its first shape value is `N`, the minibatch -// size. 
+// Returns a batched matrix tensor with new batched diagonal values. // -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. +// Given `input` and `diagonal`, this operation returns a tensor with the +// same shape and values as `input`, except for the main diagonal of the +// innermost matrices. These will be overwritten by the values in `diagonal`. // -// For example, if the handles represent an input, which is a `[2, 3]` matrix -// representing two original `SparseTensor` objects: +// The output is computed as follows: // -// ``` -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] -// ``` +// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has +// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a +// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: // -// and +// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. +// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. // -// ``` -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] -// ``` +// Arguments: +// input: Rank `k+1`, where `k >= 1`. +// diagonal: Rank `k`, where `k >= 1`. // -// then the final `SparseTensor` will be: +// Returns Rank `k+1`, with `output.shape = input.shape`. +func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixSetDiag", + Input: []tf.Input{ + input, diagonal, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the element-wise max of two SparseTensors. // -// ``` -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] -// ``` +// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. 
// // Arguments: -// sparse_handles: 1-D, The `N` serialized `SparseTensor` objects. -// Shape: `[N]`. -// dtype: The `dtype` of the `SparseTensor` objects stored in the -// `SparseTensorsMap`. +// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, in the canonical lexicographic ordering. +// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. +// a_shape: 1-D. Shape of the input SparseTensor. +// b_indices: counterpart to `a_indices` for the other operand. +// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. +// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. // -// Returns 2-D. The `indices` of the minibatch `SparseTensor`.1-D. The `values` of the minibatch `SparseTensor`.1-D. The `shape` of the minibatch `SparseTensor`. -func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype tf.DataType, optional ...TakeManySparseFromTensorsMapAttr) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { +// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. +func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "TakeManySparseFromTensorsMap", + Type: "SparseSparseMaximum", Input: []tf.Input{ - sparse_handles, + a_indices, a_values, a_shape, b_indices, b_values, b_shape, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0), op.Output(1) } -// ResourceSparseApplyKerasMomentumAttr is an optional argument to ResourceSparseApplyKerasMomentum. 
-type ResourceSparseApplyKerasMomentumAttr func(optionalAttr) +// LoadTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to LoadTPUEmbeddingMDLAdagradLightParameters. +type LoadTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr) -// ResourceSparseApplyKerasMomentumUseLocking sets the optional use_locking attribute to value. +// LoadTPUEmbeddingMDLAdagradLightParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyKerasMomentumUseLocking(value bool) ResourceSparseApplyKerasMomentumAttr { +// REQUIRES: value >= -1 +func LoadTPUEmbeddingMDLAdagradLightParametersTableId(value int64) LoadTPUEmbeddingMDLAdagradLightParametersAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["table_id"] = value } } -// ResourceSparseApplyKerasMomentumUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, the tensor passed to compute grad will be -// var + momentum * accum, so in the end, the var you get is actually -// var + momentum * accum. -// If not specified, defaults to false -func ResourceSparseApplyKerasMomentumUseNesterov(value bool) ResourceSparseApplyKerasMomentumAttr { +// LoadTPUEmbeddingMDLAdagradLightParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingMDLAdagradLightParametersTableName(value string) LoadTPUEmbeddingMDLAdagradLightParametersAttr { return func(m optionalAttr) { - m["use_nesterov"] = value + m["table_name"] = value } } -// Update relevant entries in '*var' and '*accum' according to the momentum scheme. +// Load MDL Adagrad Light embedding parameters. // -// Set use_nesterov = True if you want to use Nesterov momentum. 
+// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // -// That is for rows we have grad for, we update var and accum as follows: +// Arguments: +// parameters: Value of parameters used in the MDL Adagrad Light optimization algorithm. +// accumulators: Value of accumulators used in the MDL Adagrad Light optimization algorithm. +// weights: Value of weights used in the MDL Adagrad Light optimization algorithm. +// benefits: Value of benefits used in the MDL Adagrad Light optimization algorithm. // -// accum = accum * momentum - lr * grad -// var += accum // -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// momentum: Momentum. Must be a scalar. // // Returns the created operation. 
-func ResourceSparseApplyKerasMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyKerasMomentumAttr) (o *tf.Operation) { +func LoadTPUEmbeddingMDLAdagradLightParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, weights tf.Output, benefits tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMDLAdagradLightParametersAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyKerasMomentum", + Type: "LoadTPUEmbeddingMDLAdagradLightParameters", Input: []tf.Input{ - var_, accum, lr, grad, indices, momentum, + parameters, accumulators, weights, benefits, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Assigns a new value to a variable. -// -// Any ReadVariableOp with a control dependency on this op is guaranteed to return -// this value or a subsequent newer value of the variable. -// -// Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value to set the new tensor to use. -// -// Returns the created operation. -func AssignVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AssignVariableOp", - Input: []tf.Input{ - resource, value, - }, - } - return scope.AddOperation(opspec) -} - -// Strip leading and trailing whitespaces from the Tensor. -// -// Arguments: -// input: A string `Tensor` of any shape. +// List of the given size with empty elements. // -// Returns A string `Tensor` of the same shape as the input. 
-func StringStrip(scope *Scope, input tf.Output) (output tf.Output) { +// element_shape: the shape of the future elements of the list +// num_elements: the number of elements to reserve +// handle: the output list +// element_dtype: the desired type of elements in the list. +func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "StringStrip", + Type: "TensorListReserve", Input: []tf.Input{ - input, + element_shape, num_elements, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a tensor of ones with the same shape and type as x. -// -// Arguments: -// x: a tensor of type T. +// Computes the gradient for the inverse of `x` wrt its input. // -// Returns a tensor of the same shape and type as x but filled with ones. -func OnesLike(scope *Scope, x tf.Output) (y tf.Output) { +// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +// is the corresponding input gradient. +func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "OnesLike", + Type: "InvGrad", Input: []tf.Input{ - x, + y, dy, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// The gradient of SparseFillEmptyRows. +// Reduces sparse updates into the variable referenced by `resource` using the `min` operation. // -// Takes vectors reverse_index_map, shaped `[N]`, and grad_values, -// shaped `[N_full]`, where `N_full >= N` and copies data into either -// `d_values` or `d_default_value`. Here `d_values` is shaped `[N]` and -// `d_default_value` is a scalar. +// This operation computes // -// d_values[j] = grad_values[reverse_index_map[j]] -// d_default_value = sum_{k : 0 .. 
N_full - 1} ( -// grad_values[k] * 1{k not in reverse_index_map}) +// # Scalar indices +// ref[indices, ...] = min(ref[indices, ...], updates[...]) +// +// # Vector indices (for each i) +// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions are combined. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
// // Arguments: -// reverse_index_map: 1-D. The reverse index map from SparseFillEmptyRows. -// grad_values: 1-D. The gradients from backprop. +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. // -// Returns 1-D. The backprop into values.0-D. The backprop into default_value. -func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_values tf.Output) (d_values tf.Output, d_default_value tf.Output) { +// Returns the created operation. +func ResourceScatterMin(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseFillEmptyRowsGrad", + Type: "ResourceScatterMin", Input: []tf.Input{ - reverse_index_map, grad_values, + resource, indices, updates, }, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` -// -// if < 0, `scale * features` otherwise. -// -// To be used together with -// `initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')`. -// For correct dropout, use `tf.contrib.nn.alpha_dropout`. +// Elementwise computes the bitwise OR of `x` and `y`. // -// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) -func Selu(scope *Scope, features tf.Output) (activations tf.Output) { +// The result will have those bits set, that are set in `x`, `y` or both. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Selu", + Type: "BitwiseOr", Input: []tf.Input{ - features, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SetSizeAttr is an optional argument to SetSize. 
-type SetSizeAttr func(optionalAttr) +// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. +type MatrixSolveLsAttr func(optionalAttr) -// SetSizeValidateIndices sets the optional validate_indices attribute to value. +// MatrixSolveLsFast sets the optional fast attribute to value. // If not specified, defaults to true -func SetSizeValidateIndices(value bool) SetSizeAttr { +func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { return func(m optionalAttr) { - m["validate_indices"] = value + m["fast"] = value } } -// Number of unique elements along last dimension of input `set`. +// Solves one or more linear least-squares problems. // -// Input `set` is a `SparseTensor` represented by `set_indices`, `set_values`, -// and `set_shape`. The last dimension contains values in a set, duplicates are -// allowed but ignored. +// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same +// type as `matrix` and shape `[..., M, K]`. +// The output is a tensor shape `[..., N, K]` where each output matrix solves +// each of the equations +// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` +// in the least squares sense. // -// If `validate_indices` is `True`, this op validates the order and range of `set` -// indices. +// We use the following notation for (complex) matrix and right-hand sides +// in the batch: +// +// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), +// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), +// `output`=\\(X \in \mathbb{C}^{n \times k}\\), +// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). +// +// If `fast` is `True`, then the solution is computed by solving the normal +// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares +// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). 
+// If \\(m \lt n\\) then `output` is computed as +// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the +// minimum-norm solution to the under-determined linear system, i.e. +// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), +// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable +// when \\(A\\) is numerically full rank and has a condition number +// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is +// sufficiently large. +// +// If `fast` is `False` an algorithm based on the numerically robust complete +// orthogonal decomposition is used. This computes the minimum-norm +// least-squares solution, even when \\(A\\) is rank deficient. This path is +// typically 6-7 times slower than the fast path. If `fast` is `False` then +// `l2_regularizer` is ignored. // // Arguments: -// set_indices: 2D `Tensor`, indices of a `SparseTensor`. -// set_values: 1D `Tensor`, values of a `SparseTensor`. -// set_shape: 1D `Tensor`, shape of a `SparseTensor`. +// matrix: Shape is `[..., M, N]`. +// rhs: Shape is `[..., M, K]`. +// l2_regularizer: Scalar tensor. // -// Returns For `set` ranked `n`, this is a `Tensor` with rank `n-1`, and the same 1st -// `n-1` dimensions as `set`. Each value is the number of unique elements in -// the corresponding `[0...n-1]` dimension of `set`. -func SetSize(scope *Scope, set_indices tf.Output, set_values tf.Output, set_shape tf.Output, optional ...SetSizeAttr) (size tf.Output) { +// @compatibility(numpy) +// Equivalent to np.linalg.lstsq +// @end_compatibility +// +// Returns Shape is `[..., N, K]`. 
+func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -23627,9 +28398,9 @@ func SetSize(scope *Scope, set_indices tf.Output, set_values tf.Output, set_shap a(attrs) } opspec := tf.OpSpec{ - Type: "SetSize", + Type: "MatrixSolveLs", Input: []tf.Input{ - set_indices, set_values, set_shape, + matrix, rhs, l2_regularizer, }, Attrs: attrs, } @@ -23637,264 +28408,117 @@ func SetSize(scope *Scope, set_indices tf.Output, set_values tf.Output, set_shap return op.Output(0) } -// Adds sparse `updates` to an existing tensor according to `indices`. -// -// This operation creates a new tensor by adding sparse `updates` to the passed -// in `tensor`. -// This operation is very similar to `tf.scatter_nd_add`, except that the updates -// are added onto an existing tensor (as opposed to a variable). If the memory -// for the existing tensor cannot be re-used, a copy is made and updated. -// -// `indices` is an integer tensor containing indices into a new tensor of shape -// `shape`. The last dimension of `indices` can be at most the rank of `shape`: -// -// indices.shape[-1] <= shape.rank -// -// The last dimension of `indices` corresponds to indices into elements -// (if `indices.shape[-1] = shape.rank`) or slices -// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of -// `shape`. `updates` is a tensor with shape -// -// indices.shape[:-1] + shape[indices.shape[-1]:] -// -// The simplest form of tensor_scatter_add is to add individual elements to a -// tensor by index. For example, say we want to add 4 elements in a rank-1 -// tensor with 8 elements. +// Interleave the values from the `data` tensors into a single tensor. 
// -// In Python, this scatter add operation would look like this: +// Builds a merged tensor such that // // ```python -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// tensor = tf.ones([8], dtype=tf.int32) -// updated = tf.tensor_scatter_add(tensor, indices, updates) -// with tf.Session() as sess: -// print(sess.run(scatter)) +// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] // ``` // -// The resulting tensor would look like this: -// -// [1, 12, 1, 11, 10, 1, 1, 13] -// -// We can also, insert entire slices of a higher rank tensor all at once. For -// example, if we wanted to insert two slices in the first dimension of a -// rank-3 tensor with two matrices of new values. -// -// In Python, this scatter add operation would look like this: +// For example, if each `indices[m]` is scalar or vector, we have // // ```python -// indices = tf.constant([[0], [2]]) -// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]], -// [[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]]]) -// tensor = tf.ones([4, 4, 4]) -// updated = tf.tensor_scatter_add(tensor, indices, updates) -// with tf.Session() as sess: -// print(sess.run(scatter)) -// ``` -// -// The resulting tensor would look like this: -// -// [[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], -// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], -// [[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], -// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]] -// -// Note that on CPU, if an out of bound index is found, an error is returned. -// On GPU, if an out of bound index is found, the index is ignored. -// -// Arguments: -// tensor: Tensor to copy/update. -// indices: Index tensor. -// updates: Updates to scatter into output. -// -// Returns A new tensor copied from tensor and updates added according to the indices. 
-func TensorScatterAdd(scope *Scope, tensor tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorScatterAdd", - Input: []tf.Input{ - tensor, indices, updates, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the sign and the log of the absolute value of the determinant of -// -// one or more square matrices. -// -// The input is a tensor of shape `[N, M, M]` whose inner-most 2 dimensions -// form square matrices. The outputs are two tensors containing the signs and -// absolute values of the log determinants for all N input submatrices -// `[..., :, :]` such that the determinant = sign*exp(log_abs_determinant). -// The log_abs_determinant is computed as det(P)*sum(log(diag(LU))) where LU -// is the LU decomposition of the input and P is the corresponding -// permutation matrix. -// -// Arguments: -// input: Shape is `[N, M, M]`. -// -// Returns The signs of the log determinants of the inputs. Shape is `[N]`.The logs of the absolute values of the determinants -// of the N input matrices. Shape is `[N]`. -func LogMatrixDeterminant(scope *Scope, input tf.Output) (sign tf.Output, log_abs_determinant tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogMatrixDeterminant", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Copy a tensor setting everything outside a central band in each innermost matrix -// -// to zero. +// # Scalar indices: +// merged[indices[m], ...] = data[m][...] // -// The `band` part is computed as follows: -// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a -// tensor with the same shape where +// # Vector indices: +// merged[indices[m][i], ...] = data[m][i, ...] +// ``` // -// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. 
+// Each `data[i].shape` must start with the corresponding `indices[i].shape`, +// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we +// must have `data[i].shape = indices[i].shape + constant`. In terms of this +// `constant`, the output shape is // -// The indicator function +// merged.shape = [max(indices)] + constant // -// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && -// (num_upper < 0 || (n-m) <= num_upper)`. +// Values are merged in order, so if an index appears in both `indices[m][i]` and +// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the +// merged result. If you do not need this guarantee, ParallelDynamicStitch might +// perform better on some devices. // // For example: // -// ``` -// # if 'input' is [[ 0, 1, 2, 3] -// [-1, 0, 1, 2] -// [-2, -1, 0, 1] -// [-3, -2, -1, 0]], -// -// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] -// [-1, 0, 1, 2] -// [ 0, -1, 0, 1] -// [ 0, 0, -1, 0]], -// -// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] -// [-1, 0, 1, 0] -// [-2, -1, 0, 1] -// [ 0, -2, -1, 0]] +// ```python +// indices[0] = 6 +// indices[1] = [4, 1] +// indices[2] = [[5, 2], [0, 3]] +// data[0] = [61, 62] +// data[1] = [[41, 42], [11, 12]] +// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] +// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], +// [51, 52], [61, 62]] // ``` // -// Useful special cases: +// This method can be used to merge partitions created by `dynamic_partition` +// as illustrated on the following example: // -// ``` -// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part. -// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part. -// tf.matrix_band_part(input, 0, 0) ==> Diagonal. +// ```python +// # Apply function (increments x_i) on elements for which a certain condition +// # apply (x_i != -1 in this example). 
+// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) +// condition_mask=tf.not_equal(x,tf.constant(-1.)) +// partitioned_data = tf.dynamic_partition( +// x, tf.cast(condition_mask, tf.int32) , 2) +// partitioned_data[1] = partitioned_data[1] + 1.0 +// condition_indices = tf.dynamic_partition( +// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) +// x = tf.dynamic_stitch(condition_indices, partitioned_data) +// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain +// # unchanged. // ``` // -// Arguments: -// input: Rank `k` tensor. -// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire -// lower triangle. -// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep -// entire upper triangle. -// -// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor. -func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) { +//
+// +//
+func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixBandPart", + Type: "DynamicStitch", Input: []tf.Input{ - input, num_lower, num_upper, + tf.OutputList(indices), tf.OutputList(data), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Subtracts a value from the current value of a variable. -// -// Any ReadVariableOp with a control dependency on this op is guaranteed to -// see the decremented value or a subsequent newer one. -// -// Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value by which the variable will be incremented. -// -// Returns the created operation. -func AssignSubVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AssignSubVariableOp", - Input: []tf.Input{ - resource, value, - }, - } - return scope.AddOperation(opspec) -} - -// RestoreAttr is an optional argument to Restore. -type RestoreAttr func(optionalAttr) - -// RestorePreferredShard sets the optional preferred_shard attribute to value. -// -// value: Index of file to open first if multiple files match -// `file_pattern`. -// If not specified, defaults to -1 -func RestorePreferredShard(value int64) RestoreAttr { - return func(m optionalAttr) { - m["preferred_shard"] = value - } -} - -// Restores a tensor from checkpoint files. -// -// Reads a tensor stored in one or several files. If there are several files (for -// instance because a tensor was saved as slices), `file_pattern` may contain -// wildcard symbols (`*` and `?`) in the filename portion only, not in the -// directory portion. -// -// If a `file_pattern` matches several files, `preferred_shard` can be used to hint -// in which file the requested tensor is likely to be found. 
This op will first -// open the file at index `preferred_shard` in the list of matching files and try -// to restore tensors from that file. Only if some tensors or tensor slices are -// not found in that first file, then the Op opens all the files. Setting -// `preferred_shard` to match the value passed as the `shard` input -// of a matching `Save` Op may speed up Restore. This attribute only affects -// performance, not correctness. The default value -1 means files are processed in -// order. +// Performs a padding as a preprocess during a convolution. // -// See also `RestoreSlice`. +// Similar to FusedResizeAndPadConv2d, this op allows for an optimized +// implementation where the spatial padding transformation stage is fused with the +// im2col lookup, but in this case without the bilinear filtering required for +// resizing. Fusing the padding prevents the need to write out the intermediate +// results as whole tensors, reducing memory pressure, and we can get some latency +// gains by merging the transformation calculations. +// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' +// order is used instead. +// Internally this op uses a single per-graph scratch buffer, which means that it +// will block if multiple versions are being run in parallel. This is because this +// operator is primarily an optimization to minimize memory usage. // // Arguments: -// file_pattern: Must have a single element. The pattern of the files from -// which we read the tensor. -// tensor_name: Must have a single element. The name of the tensor to be -// restored. -// dt: The type of the tensor to be restored. +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. // -// Returns The restored tensor. 
-func Restore(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, dt tf.DataType, optional ...RestoreAttr) (tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dt": dt} - for _, a := range optional { - a(attrs) +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. Must be in the same order as the dimension specified with format. +// padding: The type of padding algorithm to use. +func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "Restore", + Type: "FusedPadConv2D", Input: []tf.Input{ - file_pattern, tensor_name, + input, paddings, filter, }, Attrs: attrs, } @@ -23902,68 +28526,110 @@ func Restore(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, dt tf. return op.Output(0) } -// QuantizedResizeBilinearAttr is an optional argument to QuantizedResizeBilinear. -type QuantizedResizeBilinearAttr func(optionalAttr) +// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. +type Conv2DBackpropInputAttr func(optionalAttr) -// QuantizedResizeBilinearAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func QuantizedResizeBilinearAlignCorners(value bool) QuantizedResizeBilinearAttr { +// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. 
+// If not specified, defaults to true +func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["use_cudnn_on_gpu"] = value } } -// Resize quantized `images` to `size` using quantized bilinear interpolation. +// Conv2DBackpropInputExplicitPaddings sets the optional explicit_paddings attribute to value. // -// Input images and output images must be quantized types. +// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith +// dimension, the amount of padding inserted before and after the dimension is +// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If +// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. +// If not specified, defaults to <> +func Conv2DBackpropInputExplicitPaddings(value []int64) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["explicit_paddings"] = value + } +} + +// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. // -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv2DBackpropInputDilations sets the optional dilations attribute to value. // +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. 
If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of convolution with respect to the input. // +// Arguments: +// input_sizes: An integer vector representing the shape of `input`, +// where `input` is a 4-D `[batch, height, width, channels]` tensor. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. Must be in the same order as the dimension specified with +// format. +// padding: The type of padding algorithm to use. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min tf.Output, max tf.Output, optional ...QuantizedResizeBilinearAttr) (resized_images tf.Output, out_min tf.Output, out_max tf.Output) { +// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient +// w.r.t. the input of the convolution. 
+func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedResizeBilinear", + Type: "Conv2DBackpropInput", Input: []tf.Input{ - images, size, min, max, + input_sizes, filter, out_backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Creates a dataset that uses a custom thread pool to compute `input_dataset`. +// Creates a dataset that executes a SQL query and emits rows of the result set. // // Arguments: -// -// num_threads: Identifies the number of threads to use for the private threadpool. +// driver_name: The database type. Currently, the only supported type is 'sqlite'. +// data_source_name: A connection string to connect to the database. +// query: A SQL query to execute. 
// // -func ExperimentalPrivateThreadPoolDataset(scope *Scope, input_dataset tf.Output, num_threads tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +func ExperimentalSqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ExperimentalPrivateThreadPoolDataset", + Type: "ExperimentalSqlDataset", Input: []tf.Input{ - input_dataset, num_threads, + driver_name, data_source_name, query, }, Attrs: attrs, } @@ -23971,209 +28637,213 @@ func ExperimentalPrivateThreadPoolDataset(scope *Scope, input_dataset tf.Output, return op.Output(0) } -// ExperimentalParseExampleDatasetAttr is an optional argument to ExperimentalParseExampleDataset. -type ExperimentalParseExampleDatasetAttr func(optionalAttr) +// LoadTPUEmbeddingCenteredRMSPropParametersAttr is an optional argument to LoadTPUEmbeddingCenteredRMSPropParameters. +type LoadTPUEmbeddingCenteredRMSPropParametersAttr func(optionalAttr) -// ExperimentalParseExampleDatasetSloppy sets the optional sloppy attribute to value. -// If not specified, defaults to false -func ExperimentalParseExampleDatasetSloppy(value bool) ExperimentalParseExampleDatasetAttr { +// LoadTPUEmbeddingCenteredRMSPropParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingCenteredRMSPropParametersTableId(value int64) LoadTPUEmbeddingCenteredRMSPropParametersAttr { return func(m optionalAttr) { - m["sloppy"] = value + m["table_id"] = value } } -// Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features. 
+// LoadTPUEmbeddingCenteredRMSPropParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingCenteredRMSPropParametersTableName(value string) LoadTPUEmbeddingCenteredRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load centered RMSProp embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: +// parameters: Value of parameters used in the centered RMSProp optimization algorithm. +// ms: Value of ms used in the centered RMSProp optimization algorithm. +// mom: Value of mom used in the centered RMSProp optimization algorithm. +// mg: Value of mg used in the centered RMSProp optimization algorithm. // // -// dense_defaults: A dict mapping string keys to `Tensor`s. -// The keys of the dict must match the dense_keys of the feature. -// sparse_keys: A list of string keys in the examples features. -// The results for these keys will be returned as `SparseTensor` objects. -// dense_keys: A list of Ndense string Tensors (scalars). -// The keys expected in the Examples features associated with dense values. -// sparse_types: A list of `DTypes` of the same length as `sparse_keys`. -// Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), -// and `tf.string` (`BytesList`) are supported. -// dense_shapes: List of tuples with the same length as `dense_keys`. -// The shape of the data for each dense feature referenced by `dense_keys`. -// Required for any input tensors identified by `dense_keys`. Must be -// either fully defined, or may contain an unknown first dimension. 
-// An unknown first dimension means the feature is treated as having -// a variable number of blocks, and the output shape along this dimension -// is considered unknown at graph build time. Padding is applied for -// minibatch elements smaller than the maximum number of blocks for the -// given feature along this dimension. -// output_types: The type list for the return values. -// output_shapes: The list of shapes being produced. -func ExperimentalParseExampleDataset(scope *Scope, input_dataset tf.Output, num_parallel_calls tf.Output, dense_defaults []tf.Output, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ExperimentalParseExampleDatasetAttr) (handle tf.Output) { +// +// Returns the created operation. +func LoadTPUEmbeddingCenteredRMSPropParameters(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, mg tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingCenteredRMSPropParametersAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes, "output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ExperimentalParseExampleDataset", + Type: "LoadTPUEmbeddingCenteredRMSPropParameters", Input: []tf.Input{ - input_dataset, num_parallel_calls, tf.OutputList(dense_defaults), + parameters, ms, mom, mg, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// SdcaOptimizerAttr is an optional argument to SdcaOptimizer. -type SdcaOptimizerAttr func(optionalAttr) +// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute. 
+type DataFormatVecPermuteAttr func(optionalAttr) -// SdcaOptimizerAdaptative sets the optional adaptative attribute to value. +// DataFormatVecPermuteSrcFormat sets the optional src_format attribute to value. // -// value: Whether to use Adaptive SDCA for the inner loop. -// If not specified, defaults to true -func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { +// value: source data format. +// If not specified, defaults to "NHWC" +func DataFormatVecPermuteSrcFormat(value string) DataFormatVecPermuteAttr { return func(m optionalAttr) { - m["adaptative"] = value + m["src_format"] = value } } -// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for -// -// linear models with L1 + L2 regularization. As global optimization objective is -// strongly-convex, the optimizer optimizes the dual objective at each step. The -// optimizer applies each update one example at a time. Examples are sampled -// uniformly, and the optimizer is learning rate free and enjoys linear convergence -// rate. -// -// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
-// Shai Shalev-Shwartz, Tong Zhang. 2012 -// -// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ +// DataFormatVecPermuteDstFormat sets the optional dst_format attribute to value. // -// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
-// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, -// Peter Richtarik, Martin Takac. 2015 +// value: destination data format. +// If not specified, defaults to "NCHW" +func DataFormatVecPermuteDstFormat(value string) DataFormatVecPermuteAttr { + return func(m optionalAttr) { + m["dst_format"] = value + } +} + +// Returns the permuted vector/tensor in the destination data format given the // -// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
-// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 +// one in the source data format. // // Arguments: -// sparse_example_indices: a list of vectors which contain example indices. -// sparse_feature_indices: a list of vectors which contain feature indices. -// sparse_feature_values: a list of vectors which contains feature value -// associated with each feature group. -// dense_features: a list of matrices which contains the dense feature values. -// example_weights: a vector which contains the weight associated with each -// example. -// example_labels: a vector which contains the label/target associated with each -// example. -// sparse_indices: a list of vectors where each value is the indices which has -// corresponding weights in sparse_weights. This field maybe omitted for the -// dense approach. -// sparse_weights: a list of vectors where each value is the weight associated with -// a sparse feature group. -// dense_weights: a list of vectors where the values are the weights associated -// with a dense feature group. -// example_state_data: a list of vectors containing the example state data. -// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, -// squared and hinge losses. -// l1: Symmetric l1 regularization strength. -// l2: Symmetric l2 regularization strength. -// num_loss_partitions: Number of partitions of the global loss function. -// num_inner_iterations: Number of iterations per mini-batch. +// x: Vector of size 4 or Tensor of shape (4, 2) in source data format. // -// Returns a list of vectors containing the updated example state -// data.a list of vectors where each value is the delta -// weights associated with a sparse feature group.a list of vectors where the values are the delta -// weights associated with a dense feature group. 
-func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { +// Returns Vector of size 4 or Tensor of shape (4, 2) in destination data format. +func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPermuteAttr) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SdcaOptimizer", + Type: "DataFormatVecPermute", Input: []tf.Input{ - tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, + x, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x / y element-wise. +// +// *NOTE*: `Div` supports broadcasting. 
More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Div(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - out_example_state_data = op.Output(idx) - if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return + opspec := tf.OpSpec{ + Type: "Div", + Input: []tf.Input{ + x, y, + }, } - if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// CudnnRNNCanonicalToParamsAttr is an optional argument to CudnnRNNCanonicalToParams. +type CudnnRNNCanonicalToParamsAttr func(optionalAttr) + +// CudnnRNNCanonicalToParamsRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNCanonicalToParamsRnnMode(value string) CudnnRNNCanonicalToParamsAttr { + return func(m optionalAttr) { + m["rnn_mode"] = value } - return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights } -// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve. -type MatrixTriangularSolveAttr func(optionalAttr) +// CudnnRNNCanonicalToParamsInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNCanonicalToParamsInputMode(value string) CudnnRNNCanonicalToParamsAttr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} -// MatrixTriangularSolveLower sets the optional lower attribute to value. -// -// value: Boolean indicating whether the innermost matrices in `matrix` are -// lower or upper triangular. 
-// If not specified, defaults to true -func MatrixTriangularSolveLower(value bool) MatrixTriangularSolveAttr { +// CudnnRNNCanonicalToParamsDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNCanonicalToParamsDirection(value string) CudnnRNNCanonicalToParamsAttr { return func(m optionalAttr) { - m["lower"] = value + m["direction"] = value } } -// MatrixTriangularSolveAdjoint sets the optional adjoint attribute to value. -// -// value: Boolean indicating whether to solve with `matrix` or its (block-wise) -// adjoint. -// -// @compatibility(numpy) -// Equivalent to scipy.linalg.solve_triangular -// @end_compatibility -// If not specified, defaults to false -func MatrixTriangularSolveAdjoint(value bool) MatrixTriangularSolveAttr { +// CudnnRNNCanonicalToParamsDropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNCanonicalToParamsDropout(value float32) CudnnRNNCanonicalToParamsAttr { return func(m optionalAttr) { - m["adjoint"] = value + m["dropout"] = value } } -// Solves systems of linear equations with upper or lower triangular matrices by -// -// backsubstitution. -// -// `matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form -// square matrices. If `lower` is `True` then the strictly upper triangular part -// of each inner-most matrix is assumed to be zero and not accessed. -// If `lower` is False then the strictly lower triangular part of each inner-most -// matrix is assumed to be zero and not accessed. -// `rhs` is a tensor of shape `[..., M, K]`. +// CudnnRNNCanonicalToParamsSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNCanonicalToParamsSeed(value int64) CudnnRNNCanonicalToParamsAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// CudnnRNNCanonicalToParamsSeed2 sets the optional seed2 attribute to value. 
+// If not specified, defaults to 0 +func CudnnRNNCanonicalToParamsSeed2(value int64) CudnnRNNCanonicalToParamsAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Converts CudnnRNN params from canonical form to usable form. // -// The output is a tensor of shape `[..., M, K]`. If `adjoint` is -// `True` then the innermost matrices in `output` satisfy matrix equations -// `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. -// If `adjoint` is `False` then the strictly then the innermost matrices in -// `output` satisfy matrix equations -// `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. +// Writes a set of weights into the opaque params buffer so they can be used in +// upcoming training or inferences. // -// Arguments: -// matrix: Shape is `[..., M, M]`. -// rhs: Shape is `[..., M, K]`. +// Note that the params buffer may not be compatible across different GPUs. So any +// save and restoration should be converted to and from the canonical weights and +// biases. // -// Returns Shape is `[..., M, K]`. -func MatrixTriangularSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixTriangularSolveAttr) (output tf.Output) { +// num_layers: Specifies the number of layers in the RNN model. +// num_units: Specifies the size of the hidden state. +// input_size: Specifies the size of the input state. +// weights: the canonical form of weights that can be used for saving +// and restoration. They are more likely to be compatible across different +// generations. +// biases: the canonical form of biases that can be used for saving +// and restoration. They are more likely to be compatible across different +// generations. +// num_params: number of parameter sets for all layers. +// Each layer may contain multiple parameter sets, with each set consisting of +// a weight matrix and a bias vector. +// rnn_mode: Indicates the type of the RNN model. 
+// input_mode: Indicate whether there is a linear projection between the input and +// The actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. +// dir = (direction == bidirectional) ? 2 : 1 +// dropout: dropout probability. When set to 0., dropout is disabled. +// seed: the 1st part of a seed to initialize dropout. +// seed2: the 2nd part of a seed to initialize dropout. +func CudnnRNNCanonicalToParams(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, weights []tf.Output, biases []tf.Output, optional ...CudnnRNNCanonicalToParamsAttr) (params tf.Output) { if scope.Err() != nil { return } @@ -24182,9 +28852,9 @@ func MatrixTriangularSolve(scope *Scope, matrix tf.Output, rhs tf.Output, option a(attrs) } opspec := tf.OpSpec{ - Type: "MatrixTriangularSolve", + Type: "CudnnRNNCanonicalToParams", Input: []tf.Input{ - matrix, rhs, + num_layers, num_units, input_size, tf.OutputList(weights), tf.OutputList(biases), }, Attrs: attrs, } @@ -24192,129 +28862,83 @@ func MatrixTriangularSolve(scope *Scope, matrix tf.Output, rhs tf.Output, option return op.Output(0) } -// Saves tensors in V2 checkpoint format. -// -// By default, saves the named tensors in full. If the caller wishes to save -// specific slices of full tensors, "shape_and_slices" should be non-empty strings -// and correspondingly well-formed. -// -// Arguments: -// prefix: Must have a single element. The prefix of the V2 checkpoint to which we -// write the tensors. -// tensor_names: shape {N}. The names of the tensors to be saved. -// shape_and_slices: shape {N}. The slice specs of the tensors to be saved. -// Empty strings indicate that they are non-partitioned tensors. -// tensors: `N` tensors to save. -// -// Returns the created operation. 
-func SaveV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, tensors []tf.Output) (o *tf.Operation) { +// Creates a dataset containing elements of first component of `input_dataset` having true in the last component. +func FilterByLastComponentDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "SaveV2", + Type: "FilterByLastComponentDataset", Input: []tf.Input{ - prefix, tensor_names, shape_and_slices, tf.OutputList(tensors), + input_dataset, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// UnicodeTranscodeAttr is an optional argument to UnicodeTranscode. -type UnicodeTranscodeAttr func(optionalAttr) - -// UnicodeTranscodeErrors sets the optional errors attribute to value. +// Computes the absolute value of a tensor. // -// value: Error handling policy when there is invalid formatting found in the input. -// The value of 'strict' will cause the operation to produce a InvalidArgument -// error on any invalid input formatting. A value of 'replace' (the default) will -// cause the operation to replace any invalid formatting in the input with the -// `replacement_char` codepoint. A value of 'ignore' will cause the operation to -// skip any invalid formatting in the input and produce no corresponding output -// character. -// If not specified, defaults to "replace" -func UnicodeTranscodeErrors(value string) UnicodeTranscodeAttr { - return func(m optionalAttr) { - m["errors"] = value +// Given a tensor `x`, this operation returns a tensor containing the absolute +// value of each element in `x`. For example, if x is an input element and y is +// an output element, this operation computes \\(y = |x|\\). 
+func Abs(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return } -} - -// UnicodeTranscodeReplacementChar sets the optional replacement_char attribute to value. -// -// value: The replacement character codepoint to be used in place of any invalid -// formatting in the input when `errors='replace'`. Any valid unicode codepoint may -// be used. The default value is the default unicode replacement character is -// 0xFFFD or U+65533.) -// -// Note that for UTF-8, passing a replacement character expressible in 1 byte, such -// as ' ', will preserve string alignment to the source since invalid bytes will be -// replaced with a 1-byte replacement. For UTF-16-BE and UTF-16-LE, any 1 or 2 byte -// replacement character will preserve byte alignment to the source. -// If not specified, defaults to 65533 -func UnicodeTranscodeReplacementChar(value int64) UnicodeTranscodeAttr { - return func(m optionalAttr) { - m["replacement_char"] = value + opspec := tf.OpSpec{ + Type: "Abs", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// UnicodeTranscodeReplaceControlCharacters sets the optional replace_control_characters attribute to value. +// MaxPoolGradV2Attr is an optional argument to MaxPoolGradV2. +type MaxPoolGradV2Attr func(optionalAttr) + +// MaxPoolGradV2DataFormat sets the optional data_format attribute to value. // -// value: Whether to replace the C0 control characters (00-1F) with the -// `replacement_char`. Default is false. -// If not specified, defaults to false -func UnicodeTranscodeReplaceControlCharacters(value bool) UnicodeTranscodeAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. 
+// If not specified, defaults to "NHWC" +func MaxPoolGradV2DataFormat(value string) MaxPoolGradV2Attr { return func(m optionalAttr) { - m["replace_control_characters"] = value + m["data_format"] = value } } -// Transcode the input text from a source encoding to a destination encoding. -// -// The input is a string tensor of any shape. The output is a string tensor of -// the same shape containing the transcoded strings. Output strings are always -// valid unicode. If the input contains invalid encoding positions, the -// `errors` attribute sets the policy for how to deal with them. If the default -// error-handling policy is used, invalid formatting will be substituted in the -// output by the `replacement_char`. If the errors policy is to `ignore`, any -// invalid encoding positions in the input are skipped and not included in the -// output. If it set to `strict` then any invalid formatting will result in an -// InvalidArgument error. -// -// This operation can be used with `output_encoding = input_encoding` to enforce -// correct formatting for inputs even if they are already in the desired encoding. -// -// If the input is prefixed by a Byte Order Mark needed to determine encoding -// (e.g. if the encoding is UTF-16 and the BOM indicates big-endian), then that -// BOM will be consumed and not emitted into the output. If the input encoding -// is marked with an explicit endianness (e.g. UTF-16-BE), then the BOM is -// interpreted as a non-breaking-space and is preserved in the output (including -// always for UTF-8). -// -// The end result is that if the input is marked as an explicit endianness the -// transcoding is faithful to all codepoints in the source. If it is not marked -// with an explicit endianness, the BOM is not considered part of the string itself -// but as metadata, and so is not preserved in the output. +// Computes gradients of the maxpooling function. // // Arguments: -// input: The text to be processed. Can have any shape. 
-// input_encoding: Text encoding of the input strings. This is any of the encodings supported -// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. -// output_encoding: The unicode encoding to use in the output. Must be one of -// `"UTF-8", "UTF-16-BE", "UTF-32-BE"`. Multi-byte encodings will be big-endian. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients w.r.t. the output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns A string tensor containing unicode text encoded using `output_encoding`. -func UnicodeTranscode(scope *Scope, input tf.Output, input_encoding string, output_encoding string, optional ...UnicodeTranscodeAttr) (output tf.Output) { +// Returns Gradients w.r.t. the input to `max_pool`. +func MaxPoolGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"input_encoding": input_encoding, "output_encoding": output_encoding} + attrs := map[string]interface{}{"padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "UnicodeTranscode", + Type: "MaxPoolGradV2", Input: []tf.Input{ - input, + orig_input, orig_output, grad, ksize, strides, }, Attrs: attrs, } @@ -24322,72 +28946,78 @@ func UnicodeTranscode(scope *Scope, input tf.Output, input_encoding string, outp return op.Output(0) } -// Computes inverse hyperbolic sine of x element-wise. -func Asinh(scope *Scope, x tf.Output) (y tf.Output) { +// Restore a reader to a previously saved state. +// +// Not all Readers support being restored, so this can produce an +// Unimplemented error. 
+// +// Arguments: +// reader_handle: Handle to a Reader. +// state: Result of a ReaderSerializeState of a Reader with type +// matching reader_handle. +// +// Returns the created operation. +func ReaderRestoreStateV2(scope *Scope, reader_handle tf.Output, state tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Asinh", + Type: "ReaderRestoreStateV2", Input: []tf.Input{ - x, + reader_handle, state, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Creates a dataset with a range of values. Corresponds to python's xrange. +// Inverse fast Fourier transform. +// +// Computes the inverse 1-dimensional discrete Fourier transform over the +// inner-most dimension of `input`. // // Arguments: -// start: corresponds to start in python's xrange(). -// stop: corresponds to stop in python's xrange(). -// step: corresponds to step in python's xrange(). +// input: A complex tensor. // +// Returns A complex tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its inverse 1D Fourier transform. // -func RangeDataset(scope *Scope, start tf.Output, stop tf.Output, step tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// @compatibility(numpy) +// Equivalent to np.fft.ifft +// @end_compatibility +func IFFT(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "RangeDataset", + Type: "IFFT", Input: []tf.Input{ - start, stop, step, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Stops gradient computation. +// 2D fast Fourier transform. // -// When executed in a graph, this op outputs its input tensor as-is. +// Computes the 2-dimensional discrete Fourier transform over the inner-most +// 2 dimensions of `input`. 
// -// When building ops to compute gradients, this op prevents the contribution of -// its inputs to be taken into account. Normally, the gradient generator adds ops -// to a graph to compute the derivatives of a specified 'loss' by recursively -// finding out inputs that contributed to its computation. If you insert this op -// in the graph it inputs are masked from the gradient generator. They are not -// taken into account for computing gradients. +// Arguments: +// input: A complex tensor. // -// This is useful any time you want to compute a value with TensorFlow but need -// to pretend that the value was a constant. Some examples include: +// Returns A complex tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. // -// * The *EM* algorithm where the *M-step* should not involve backpropagation -// through the output of the *E-step*. -// * Contrastive divergence training of Boltzmann machines where, when -// differentiating the energy function, the training must not backpropagate -// through the graph that generated the samples from the model. -// * Adversarial training, where no backprop should happen through the adversarial -// example generation process. -func StopGradient(scope *Scope, input tf.Output) (output tf.Output) { +// @compatibility(numpy) +// Equivalent to np.fft.fft2 +// @end_compatibility +func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "StopGradient", + Type: "FFT2D", Input: []tf.Input{ input, }, @@ -24396,84 +29026,86 @@ func StopGradient(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// Eagerly executes a python function to compute func(input)->output. The +// Inverse 2D fast Fourier transform. // -// semantics of the input, output, and attributes are the same as those for -// PyFunc. 
-func EagerPyFunc(scope *Scope, input []tf.Output, token string, Tout []tf.DataType) (output []tf.Output) { +// Computes the inverse 2-dimensional discrete Fourier transform over the +// inner-most 2 dimensions of `input`. +// +// Arguments: +// input: A complex tensor. +// +// Returns A complex tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their inverse 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft2 +// @end_compatibility +func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"token": token, "Tout": Tout} opspec := tf.OpSpec{ - Type: "EagerPyFunc", + Type: "IFFT2D", Input: []tf.Input{ - tf.OutputList(input), + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("EagerPyFunc", err) - return - } - return output + return op.Output(0) } -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. +// Inverse 3D real-valued fast Fourier transform. // -// More formally, let +// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most 3 dimensions of `input`. 
// -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, +// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: +// The inner-most dimension contains the `fft_length / 2 + 1` unique components of +// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed +// from the size of the inner-most 3 dimensions of `input`. If the FFT length used +// to compute `input` is odd, it should be provided since it cannot be inferred +// properly. // -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// Along each axis `IRFFT3D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. -// k: Number of top elements to look at for computing precision. +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. // -// Returns Computed Precision at `k` as a `bool Tensor`. -func InTopK(scope *Scope, predictions tf.Output, targets tf.Output, k int64) (precision tf.Output) { +// Returns A float32 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the `fft_length` samples of their +// inverse 3D real Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.irfftn with 3 dimensions. 
+// @end_compatibility +func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"k": k} opspec := tf.OpSpec{ - Type: "InTopK", + Type: "IRFFT3D", Input: []tf.Input{ - predictions, targets, + input, fft_length, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns (x - y)(x - y) element-wise. +// Returns the truth value of (x != y) element-wise. // -// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting +// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SquaredDifference", + Type: "NotEqual", Input: []tf.Input{ x, y, }, @@ -24482,144 +29114,84 @@ func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// RandomGammaAttr is an optional argument to RandomGamma. -type RandomGammaAttr func(optionalAttr) +// LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingMomentumParametersGradAccumDebug. +type LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) -// RandomGammaSeed sets the optional seed attribute to value. +// LoadTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. 
-// If not specified, defaults to 0 -func RandomGammaSeed(value int64) RandomGammaAttr { +// REQUIRES: value >= -1 +func LoadTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["seed"] = value + m["table_id"] = value } } -// RandomGammaSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomGammaSeed2(value int64) RandomGammaAttr { +// LoadTPUEmbeddingMomentumParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingMomentumParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["seed2"] = value + m["table_name"] = value } } -// Outputs random values from the Gamma distribution(s) described by alpha. +// Load Momentum embedding parameters with debug support. // -// This op uses the algorithm by Marsaglia et al. to acquire samples via -// transformation-rejection from pairs of uniform and normal random variables. -// See http://dl.acm.org/citation.cfm?id=358414 +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: -// shape: 1-D integer tensor. Shape of independent samples to draw from each -// distribution described by the shape parameters given in alpha. -// alpha: A tensor in which each scalar is a "shape" parameter describing the -// associated gamma distribution. +// parameters: Value of parameters used in the Momentum optimization algorithm. +// momenta: Value of momenta used in the Momentum optimization algorithm. 
+// gradient_accumulators: Value of gradient_accumulators used in the Momentum optimization algorithm. // -// Returns A tensor with shape `shape + shape(alpha)`. Each slice -// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for -// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. -func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) { +// +// +// Returns the created operation. +func LoadTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, parameters tf.Output, momenta tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomGamma", + Type: "LoadTPUEmbeddingMomentumParametersGradAccumDebug", Input: []tf.Input{ - shape, alpha, + parameters, momenta, gradient_accumulators, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Convert the quantized 'input' tensor into a lower-precision 'output', using the -// -// actual distribution of the values to maximize the usage of the lower bit depth -// and adjusting the output min and max ranges accordingly. -// -// [input_min, input_max] are scalar floats that specify the range for the float -// interpretation of the 'input' data. For example, if input_min is -1.0f and -// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 -// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. -// -// This operator tries to squeeze as much precision as possible into an output with -// a lower bit depth by calculating the actual min and max values found in the -// data. 
For example, maybe that quint16 input has no values lower than 16,384 and -// none higher than 49,152. That means only half the range is actually needed, all -// the float interpretations are between -0.5f and 0.5f, so if we want to compress -// the data into a quint8 output, we can use that range rather than the theoretical -// -1.0f to 1.0f that is suggested by the input min and max. -// -// In practice, this is most useful for taking output from operations like -// QuantizedMatMul that can produce higher bit-depth outputs than their inputs and -// may have large potential output ranges, but in practice have a distribution of -// input values that only uses a small fraction of the possible range. By feeding -// that output into this operator, we can reduce it from 32 bits down to 8 with -// minimal loss of accuracy. -// -// Arguments: -// -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// out_type: The type of the output. Should be a lower bit depth than Tinput. -// -// Returns The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. -func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"out_type": out_type} - opspec := tf.OpSpec{ - Type: "QuantizeDownAndShrinkRange", - Input: []tf.Input{ - input, input_min, input_max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} +// StatefulStandardNormalAttr is an optional argument to StatefulStandardNormal. +type StatefulStandardNormalAttr func(optionalAttr) -// Returns the truth value of (x >= y) element-wise. 
+// StatefulStandardNormalDtype sets the optional dtype attribute to value. // -// *NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func GreaterEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "GreaterEqual", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ApproximateEqualAttr is an optional argument to ApproximateEqual. -type ApproximateEqualAttr func(optionalAttr) - -// ApproximateEqualTolerance sets the optional tolerance attribute to value. -// If not specified, defaults to 1e-05 -func ApproximateEqualTolerance(value float32) ApproximateEqualAttr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatefulStandardNormalDtype(value tf.DataType) StatefulStandardNormalAttr { return func(m optionalAttr) { - m["tolerance"] = value + m["dtype"] = value } } -// Returns the truth value of abs(x-y) < tolerance element-wise. -func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...ApproximateEqualAttr) (z tf.Output) { +// Outputs random values from a normal distribution. +// +// The generated values will have mean 0 and standard deviation 1. +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// shape: The shape of the output tensor. +// +// Returns A tensor of the specified shape filled with random normal values. 
+func StatefulStandardNormal(scope *Scope, resource tf.Output, shape tf.Output, optional ...StatefulStandardNormalAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -24628,9 +29200,9 @@ func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...Approx a(attrs) } opspec := tf.OpSpec{ - Type: "ApproximateEqual", + Type: "StatefulStandardNormal", Input: []tf.Input{ - x, y, + resource, shape, }, Attrs: attrs, } @@ -24638,374 +29210,187 @@ func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...Approx return op.Output(0) } -// Returns the truth value of x OR y element-wise. -// -// *NOTE*: `LogicalOr` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LogicalOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Computes the Gauss error function of `x` element-wise. +func Erf(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LogicalOr", + Type: "Erf", Input: []tf.Input{ - x, y, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Selects elements from `x` or `y`, depending on `condition`. -// -// The `x`, and `y` tensors must all have the same shape, and the -// output will also have that shape. -// -// The `condition` tensor must be a scalar if `x` and `y` are scalars. -// If `x` and `y` are vectors or higher rank, then `condition` must be either a -// scalar, a vector with size matching the first dimension of `x`, or must have -// the same shape as `x`. -// -// The `condition` tensor acts as a mask that chooses, based on the value at each -// element, whether the corresponding element / row in the output should be -// taken from `x` (if true) or `y` (if false). -// -// If `condition` is a vector and `x` and `y` are higher rank matrices, then -// it chooses which row (outer dimension) to copy from `x` and `y`. 
-// If `condition` has the same shape as `x` and `y`, then it chooses which -// element to copy from `x` and `y`. -// -// For example: -// -// ```python -// # 'condition' tensor is [[True, False] -// # [False, True]] -// # 't' is [[1, 2], -// # [3, 4]] -// # 'e' is [[5, 6], -// # [7, 8]] -// select(condition, t, e) # => [[1, 6], [7, 4]] -// -// -// # 'condition' tensor is [True, False] -// # 't' is [[1, 2], -// # [3, 4]] -// # 'e' is [[5, 6], -// # [7, 8]] -// select(condition, t, e) ==> [[1, 2], -// [7, 8]] -// -// ``` -// -// Arguments: -// -// x: = A `Tensor` which may have the same shape as `condition`. -// If `condition` is rank 1, `x` may have higher rank, -// but its first dimension must match the size of `condition`. -// y: = A `Tensor` with the same type and shape as `x`. -// -// Returns = A `Tensor` with the same type and shape as `x` and `y`. -func Select(scope *Scope, condition tf.Output, x tf.Output, y tf.Output) (output tf.Output) { +// Returns element-wise largest integer not greater than x. +func Floor(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Select", + Type: "Floor", Input: []tf.Input{ - condition, x, y, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MatMulAttr is an optional argument to MatMul. -type MatMulAttr func(optionalAttr) - -// MatMulTransposeA sets the optional transpose_a attribute to value. -// -// value: If true, "a" is transposed before multiplication. -// If not specified, defaults to false -func MatMulTransposeA(value bool) MatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// MatMulTransposeB sets the optional transpose_b attribute to value. -// -// value: If true, "b" is transposed before multiplication. -// If not specified, defaults to false -func MatMulTransposeB(value bool) MatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// Multiply the matrix "a" by the matrix "b". 
-// -// The inputs must be two-dimensional matrices and the inner dimension of -// "a" (after being transposed if transpose_a is true) must match the -// outer dimension of "b" (after being transposed if transposed_b is -// true). +// Returns the number of records this Reader has produced. // -// *Note*: The default kernel implementation for MatMul on GPUs uses -// cublas. -func MatMul(scope *Scope, a tf.Output, b tf.Output, optional ...MatMulAttr) (product tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatMul", - Input: []tf.Input{ - a, b, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Serializes the tree handle to a proto +// This is the same as the number of ReaderRead executions that have +// succeeded. // // Arguments: -// tree_handle: Handle to the tree resource to be serialized. -// -// Returns Serialied proto string of the tree resource. -func TensorForestTreeSerialize(scope *Scope, tree_handle tf.Output) (tree_config tf.Output) { +// reader_handle: Handle to a Reader. +func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_produced tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorForestTreeSerialize", + Type: "ReaderNumRecordsProducedV2", Input: []tf.Input{ - tree_handle, + reader_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SparseMatMulAttr is an optional argument to SparseMatMul. -type SparseMatMulAttr func(optionalAttr) - -// SparseMatMulTransposeA sets the optional transpose_a attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeA(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// SparseMatMulTransposeB sets the optional transpose_b attribute to value. 
-// If not specified, defaults to false -func SparseMatMulTransposeB(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["a_is_sparse"] = value - } -} +// TensorListConcatAttr is an optional argument to TensorListConcat. +type TensorListConcatAttr func(optionalAttr) -// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { +// TensorListConcatElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to +func TensorListConcatElementShape(value tf.Shape) TensorListConcatAttr { return func(m optionalAttr) { - m["b_is_sparse"] = value + m["element_shape"] = value } } -// Multiply matrix "a" by matrix "b". +// Concats all tensors in the list along the 0th dimension. // -// The inputs must be two-dimensional matrices and the inner dimension of "a" must -// match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not -// `SparseTensor`s. This op is optimized for the case where at least one of "a" or -// "b" is sparse, in the sense that they have a large proportion of zero values. -// The breakeven for using this versus a dense matrix multiply on one platform was -// 30% zero values in the sparse matrix. +// Requires that all tensors have the same shape except the first dimension. // -// The gradient computation of this operation will only take advantage of sparsity -// in the input gradient when that gradient comes from a Relu. -func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { +// input_handle: The input list. +// tensor: The concated result. 
+// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. +// +func TensorListConcat(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListConcatAttr) (tensor tf.Output, lengths tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"element_dtype": element_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SparseMatMul", + Type: "TensorListConcat", Input: []tf.Input{ - a, b, + input_handle, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// ExperimentalThreadPoolHandleAttr is an optional argument to ExperimentalThreadPoolHandle. -type ExperimentalThreadPoolHandleAttr func(optionalAttr) - -// ExperimentalThreadPoolHandleMaxIntraOpParallelism sets the optional max_intra_op_parallelism attribute to value. +// Elementwise computes the bitwise AND of `x` and `y`. // -// value: The maximum degree of parallelism to use within operations that execute on this -// threadpool. -// If not specified, defaults to 1 -func ExperimentalThreadPoolHandleMaxIntraOpParallelism(value int64) ExperimentalThreadPoolHandleAttr { - return func(m optionalAttr) { - m["max_intra_op_parallelism"] = value +// The result will have those bits set, that are set in both `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return } -} - -// ExperimentalThreadPoolHandleContainer sets the optional container attribute to value. 
-// If not specified, defaults to "" -func ExperimentalThreadPoolHandleContainer(value string) ExperimentalThreadPoolHandleAttr { - return func(m optionalAttr) { - m["container"] = value + opspec := tf.OpSpec{ + Type: "BitwiseAnd", + Input: []tf.Input{ + x, y, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// ExperimentalThreadPoolHandleSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func ExperimentalThreadPoolHandleSharedName(value string) ExperimentalThreadPoolHandleAttr { +// ResizeAreaAttr is an optional argument to ResizeArea. +type ResizeAreaAttr func(optionalAttr) + +// ResizeAreaAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeAreaAlignCorners(value bool) ResizeAreaAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["align_corners"] = value } } -// Creates a dataset that uses a custom thread pool to compute `input_dataset`. +// Resize `images` to `size` using area interpolation. +// +// Input images can be of different types but output images are always float. +// +// The range of pixel values for the output image might be slightly different +// from the range for the input image because of limited numerical precision. +// To guarantee an output range, for example `[0.0, 1.0]`, apply +// `tf.clip_by_value` to the output. +// +// Each output pixel is computed by first transforming the pixel's footprint into +// the input tensor and then averaging the pixels that intersect the footprint. An +// input pixel's contribution to the average is weighted by the fraction of its +// area that intersects the footprint. This is the same as OpenCV's INTER_AREA. // // Arguments: -// num_threads: The number of threads in the thread pool. 
-// display_name: A human-readable name for the threads that may be visible in some -// visualizations. -// threadpool. +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// Returns A resource that can be consumed by one or more ExperimentalThreadPoolDataset -// ops. -func ExperimentalThreadPoolHandle(scope *Scope, num_threads int64, display_name string, optional ...ExperimentalThreadPoolHandleAttr) (handle tf.Output) { +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_threads": num_threads, "display_name": display_name} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ExperimentalThreadPoolHandle", - + Type: "ResizeArea", + Input: []tf.Input{ + images, size, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CudnnRNNCanonicalToParamsAttr is an optional argument to CudnnRNNCanonicalToParams. -type CudnnRNNCanonicalToParamsAttr func(optionalAttr) - -// CudnnRNNCanonicalToParamsRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNCanonicalToParamsRnnMode(value string) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } -} - -// CudnnRNNCanonicalToParamsInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNCanonicalToParamsInputMode(value string) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["input_mode"] = value - } -} - -// CudnnRNNCanonicalToParamsDirection sets the optional direction attribute to value. 
-// If not specified, defaults to "unidirectional" -func CudnnRNNCanonicalToParamsDirection(value string) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["direction"] = value - } -} - -// CudnnRNNCanonicalToParamsDropout sets the optional dropout attribute to value. -// If not specified, defaults to 0 -func CudnnRNNCanonicalToParamsDropout(value float32) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["dropout"] = value - } -} - -// CudnnRNNCanonicalToParamsSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNCanonicalToParamsSeed(value int64) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// CudnnRNNCanonicalToParamsSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNCanonicalToParamsSeed2(value int64) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Converts CudnnRNN params from canonical form to usable form. +// Sends `input` to all devices that are connected to the output. // -// Writes a set of weights into the opaque params buffer so they can be used in -// upcoming training or inferences. +// Sends `input` to all devices that are connected to the output. // -// Note that the params buffer may not be compatible across different GPUs. So any -// save and restoration should be converted to and from the canonical weights and -// biases. +// The graph should be constructed so that all ops connected to the output have a +// valid device assignment, and the op itself is assigned one of these devices. // -// num_layers: Specifies the number of layers in the RNN model. -// num_units: Specifies the size of the hidden state. -// input_size: Specifies the size of the input state. -// weights: the canonical form of weights that can be used for saving -// and restoration. They are more likely to be compatible across different -// generations. 
-// biases: the canonical form of biases that can be used for saving -// and restoration. They are more likely to be compatible across different -// generations. -// num_params: number of parameter sets for all layers. -// Each layer may contain multiple parameter sets, with each set consisting of -// a weight matrix and a bias vector. -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicate whether there is a linear projection between the input and -// The actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. -// dir = (direction == bidirectional) ? 2 : 1 -// dropout: dropout probability. When set to 0., dropout is disabled. -// seed: the 1st part of a seed to initialize dropout. -// seed2: the 2nd part of a seed to initialize dropout. -func CudnnRNNCanonicalToParams(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, weights []tf.Output, biases []tf.Output, optional ...CudnnRNNCanonicalToParamsAttr) (params tf.Output) { +// input: The input to the broadcast. +// output: The same as input. +// shape: The shape of the input tensor. +// +func NcclBroadcast(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "CudnnRNNCanonicalToParams", + Type: "NcclBroadcast", Input: []tf.Input{ - num_layers, num_units, input_size, tf.OutputList(weights), tf.OutputList(biases), + input, }, Attrs: attrs, } @@ -25013,16 +29398,28 @@ func CudnnRNNCanonicalToParams(scope *Scope, num_layers tf.Output, num_units tf. 
return op.Output(0) } -// Creates a dataset containing elements of first component of `input_dataset` having true in the last component. -func FilterByLastComponentDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (output tf.Output) { +// Computes the gradient of morphological 2-D dilation with respect to the filter. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. +// strides: 1-D of length 4. The stride of the sliding window for each dimension of +// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: 1-D of length 4. The input stride for atrous morphological dilation. +// Must be: `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. +// +// Returns 3-D with shape `[filter_height, filter_width, depth]`. +func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (filter_backprop tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "FilterByLastComponentDataset", + Type: "Dilation2DBackpropFilter", Input: []tf.Input{ - input_dataset, + input, filter, out_backprop, }, Attrs: attrs, } @@ -25030,33 +29427,55 @@ func FilterByLastComponentDataset(scope *Scope, input_dataset tf.Output, output_ return op.Output(0) } -// SumAttr is an optional argument to Sum. -type SumAttr func(optionalAttr) +// AddSparseToTensorsMapAttr is an optional argument to AddSparseToTensorsMap. +type AddSparseToTensorsMapAttr func(optionalAttr) -// SumKeepDims sets the optional keep_dims attribute to value. 
+// AddSparseToTensorsMapContainer sets the optional container attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SumKeepDims(value bool) SumAttr { +// value: The container name for the `SparseTensorsMap` created by this op. +// If not specified, defaults to "" +func AddSparseToTensorsMapContainer(value string) AddSparseToTensorsMapAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["container"] = value } } -// Computes the sum of elements across dimensions of a tensor. +// AddSparseToTensorsMapSharedName sets the optional shared_name attribute to value. // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// value: The shared name for the `SparseTensorsMap` created by this op. +// If blank, the new Operation's unique name is used. +// If not specified, defaults to "" +func AddSparseToTensorsMapSharedName(value string) AddSparseToTensorsMapAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Add a `SparseTensor` to a `SparseTensorsMap` return its handle. +// +// A `SparseTensor` is represented by three tensors: `sparse_indices`, +// `sparse_values`, and `sparse_shape`. +// +// This operator takes the given `SparseTensor` and adds it to a container +// object (a `SparseTensorsMap`). A unique key within this container is generated +// in the form of an `int64`, and this is the value that is returned. +// +// The `SparseTensor` can then be read out as part of a minibatch by passing +// the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure +// the correct `SparseTensorsMap` is accessed, ensure that the same +// `container` and `shared_name` are passed to that Op. 
If no `shared_name` +// is provided here, instead use the *name* of the Operation created by calling +// `AddSparseToTensorsMap` as the `shared_name` passed to +// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// sparse_indices: 2-D. The `indices` of the `SparseTensor`. +// sparse_values: 1-D. The `values` of the `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the `SparseTensor`. // -// Returns The reduced tensor. -func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (output tf.Output) { +// Returns 0-D. The handle of the `SparseTensor` now stored in the +// `SparseTensorsMap`. +func AddSparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddSparseToTensorsMapAttr) (sparse_handle tf.Output) { if scope.Err() != nil { return } @@ -25065,9 +29484,9 @@ func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (ou a(attrs) } opspec := tf.OpSpec{ - Type: "Sum", + Type: "AddSparseToTensorsMap", Input: []tf.Input{ - input, axis, + sparse_indices, sparse_values, sparse_shape, }, Attrs: attrs, } @@ -25075,6 +29494,27 @@ func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (ou return op.Output(0) } +// Returns a list list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. +// +// tensor: The tensor to put on the list. +// input_handle: The old list. +// output_handle: A list with the elements of the old list followed by tensor. +// element_dtype: the type of elements in the list. +// element_shape: a shape compatible with that of elements in the list. 
+func TensorListPushBack(scope *Scope, input_handle tf.Output, tensor tf.Output) (output_handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorListPushBack", + Input: []tf.Input{ + input_handle, tensor, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // EnterAttr is an optional argument to Enter. type EnterAttr func(optionalAttr) @@ -25130,24 +29570,6 @@ func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAtt return op.Output(0) } -// Add all input tensors element wise. -// -// Arguments: -// inputs: Must all be the same size and shape. -func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AddN", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // TryRpcAttr is an optional argument to TryRpc. type TryRpcAttr func(optionalAttr) @@ -25270,259 +29692,47 @@ func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output return op.Output(0), op.Output(1), op.Output(2) } -// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. -type InitializeTableFromTextFileV2Attr func(optionalAttr) - -// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. -// -// value: Number of elements of the file, use -1 if unknown. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { - return func(m optionalAttr) { - m["vocab_size"] = value - } -} - -// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. -// -// value: Delimiter to separate fields in a line. 
-// If not specified, defaults to "\t" -func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { - return func(m optionalAttr) { - m["delimiter"] = value - } -} - -// Initializes a table from a text file. -// -// It inserts one key-value pair into the table for each line of the file. -// The key and value is extracted from the whole line content, elements from the -// split line based on `delimiter` or the line number (starting from zero). -// Where to extract the key and value from a line is specified by `key_index` and -// `value_index`. -// -// - A value of -1 means use the line number(starting from zero), expects `int64`. -// - A value of -2 means use the whole line content, expects `string`. -// - A value >= 0 means use the index (starting at zero) of the split line based -// on `delimiter`. -// -// Arguments: -// table_handle: Handle to a table which will be initialized. -// filename: Filename of a vocabulary text file. -// key_index: Column index in a line to get the table `key` values from. -// value_index: Column index that represents information of a line to get the table -// `value` values from. -// -// Returns the created operation. -func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "InitializeTableFromTextFileV2", - Input: []tf.Input{ - table_handle, filename, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// MeanAttr is an optional argument to Mean. -type MeanAttr func(optionalAttr) - -// MeanKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. 
-// If not specified, defaults to false -func MeanKeepDims(value bool) MeanAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the mean of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Mean(scope *Scope, input tf.Output, axis tf.Output, optional ...MeanAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Mean", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ProdAttr is an optional argument to Prod. -type ProdAttr func(optionalAttr) - -// ProdKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func ProdKeepDims(value bool) ProdAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the product of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// Add all input tensors element wise. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. 
-func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) { +// inputs: Must all be the same size and shape. +func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Prod", + Type: "AddN", Input: []tf.Input{ - input, axis, + tf.OutputList(inputs), }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeBilinearAttr is an optional argument to ResizeBilinear. -type ResizeBilinearAttr func(optionalAttr) - -// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Resize `images` to `size` using bilinear interpolation. -// -// Input images can be of different types but output images are always float. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// Returns the element-wise sum of a list of tensors. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeBilinear", - Input: []tf.Input{ - images, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxAttr is an optional argument to Max. 
-type MaxAttr func(optionalAttr) - -// MaxKeepDims sets the optional keep_dims attribute to value. +// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not +// wait for all of its inputs to be ready before beginning to sum. This can +// save memory if inputs are ready at different times, since minimum temporary +// storage is proportional to the output size rather than the inputs size. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func MaxKeepDims(value bool) MaxAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the maximum of elements across dimensions of a tensor. +// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// Returns a `Tensor` of same shape and type as the elements of `inputs`. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Max(scope *Scope, input tf.Output, axis tf.Output, optional ...MaxAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Max", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that contains the unique elements of `input_dataset`. -func ExperimentalUniqueDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// inputs: A list of `Tensor` objects, each with same shape and type. +// shape: Shape of elements of `inputs`. 
+func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "ExperimentalUniqueDataset", + Type: "AccumulateNV2", Input: []tf.Input{ - input_dataset, + tf.OutputList(inputs), }, Attrs: attrs, } @@ -25530,476 +29740,323 @@ func ExperimentalUniqueDataset(scope *Scope, input_dataset tf.Output, output_typ return op.Output(0) } -// ArgMinAttr is an optional argument to ArgMin. -type ArgMinAttr func(optionalAttr) - -// ArgMinOutputType sets the optional output_type attribute to value. -// If not specified, defaults to DT_INT64 -func ArgMinOutputType(value tf.DataType) ArgMinAttr { - return func(m optionalAttr) { - m["output_type"] = value - } -} - -// Returns the index with the smallest value across dimensions of a tensor. -// -// Note that in case of ties the identity of the return value is not guaranteed. -// -// Arguments: +// Computes the reciprocal of x element-wise. // -// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. -// Describes which dimension of the input Tensor to reduce across. For vectors, -// use dimension = 0. -func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) { +// I.e., \\(y = 1 / x\\). +func Inv(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ArgMin", + Type: "Inv", Input: []tf.Input{ - input, dimension, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Convert the quantized 'input' tensor into a lower-precision 'output', using the -// -// output range specified with 'requested_output_min' and 'requested_output_max'. 
-// -// [input_min, input_max] are scalar floats that specify the range for the float -// interpretation of the 'input' data. For example, if input_min is -1.0f and -// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 -// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. +// Creates a dataset that batches input elements into a SparseTensor. // // Arguments: +// input_dataset: A handle to an input dataset. Must have a single component. +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// row_shape: A vector representing the dense shape of each row in the produced +// SparseTensor. The shape may be partially specified, using `-1` to indicate +// that a particular dimension should use the maximum size of all batch elements. // -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// requested_output_min: The float value that the minimum quantized output value represents. -// requested_output_max: The float value that the maximum quantized output value represents. -// out_type: The type of the output. Should be a lower bit depth than Tinput. // -// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output. 
-func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { +func ExperimentalDenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Requantize", + Type: "ExperimentalDenseToSparseBatchDataset", Input: []tf.Input{ - input, input_min, input_max, requested_output_min, requested_output_max, + input_dataset, batch_size, row_shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Creates a dataset that emits the lines of one or more text files. +// Computes the reciprocal of x element-wise. // -// Arguments: -// filenames: A scalar or a vector containing the name(s) of the file(s) to be -// read. -// compression_type: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// buffer_size: A scalar containing the number of bytes to buffer. -func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { +// I.e., \\(y = 1 / x\\). +func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TextLineDataset", + Type: "Reciprocal", Input: []tf.Input{ - filenames, compression_type, buffer_size, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the sum along segments of a tensor. 
-// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output_i = \sum_j data_j\\) where sum is over `j` such -// that `segment_ids[j] == i`. -// -// If the sum is empty for a given segment ID `i`, `output[i] = 0`. +// Conv3DBackpropFilterAttr is an optional argument to Conv3DBackpropFilter. +type Conv3DBackpropFilterAttr func(optionalAttr) + +// Conv3DBackpropFilterDilations sets the optional dilations attribute to value. +// If not specified, defaults to +func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of 3-D convolution with respect to the filter. // -//
-// -//
+// DEPRECATED at GraphDef version 10: Use Conv3DBackpropFilterV2 // // Arguments: -// -// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SegmentSum", + Type: "Conv3DBackpropFilter", Input: []tf.Input{ - data, segment_ids, + input, filter, out_backprop, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the mean along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is -// over `j` such that `segment_ids[j] == i` and `N` is the total number of -// values summed. -// -// If the mean is empty for a given segment ID `i`, `output[i] = 0`. -// -//
-// -//
-// -// Arguments: -// -// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s -// first dimension. Values should be sorted and can be repeated. +// Computes square root of x element-wise. // -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// I.e., \\(y = \sqrt{x} = x^{1/2}\\). +func Sqrt(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SegmentMean", + Type: "Sqrt", Input: []tf.Input{ - data, segment_ids, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the minimum along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such -// that `segment_ids[j] == i`. -// -// If the min is empty for a given segment ID `i`, `output[i] = 0`. -// -//
-// -//
+// Get the value of the tensor specified by its handle. // // Arguments: +// handle: The handle for a tensor stored in the session state. +// dtype: The type of the output value. // -// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// Returns The tensor for the given handle. +func GetSessionTensor(scope *Scope, handle tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "SegmentMin", + Type: "GetSessionTensor", Input: []tf.Input{ - data, segment_ids, + handle, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Removes keys and its associated values from a table. -// -// The tensor `keys` must of the same type as the keys of the table. Keys not -// already in the table are silently ignored. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys of the elements to remove. +// Computes the gradient for the sqrt of `x` wrt its input. // -// Returns the created operation. -func LookupTableRemoveV2(scope *Scope, table_handle tf.Output, keys tf.Output) (o *tf.Operation) { +// Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` +// is the corresponding input gradient. +func SqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LookupTableRemoveV2", + Type: "SqrtGrad", Input: []tf.Input{ - table_handle, keys, + y, dy, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the sum along segments of a tensor. 
-// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. +// MatrixInverseAttr is an optional argument to MatrixInverse. +type MatrixInverseAttr func(optionalAttr) + +// MatrixInverseAdjoint sets the optional adjoint attribute to value. +// If not specified, defaults to false +func MatrixInverseAdjoint(value bool) MatrixInverseAttr { + return func(m optionalAttr) { + m["adjoint"] = value + } +} + +// Computes the inverse of one or more square invertible matrices or their // -// Computes a tensor such that -// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such -// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` -// need not be sorted and need not cover all values in the full -// range of valid values. +// adjoints (conjugate transposes). // -// If the sum is empty for a given segment ID `i`, `output[i] = 0`. -// If the given segment ID `i` is negative, the value is dropped and will not be -// added to the sum of the segment. +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor of the same shape as the input +// containing the inverse for all input submatrices `[..., :, :]`. // -// `num_segments` should equal the number of distinct segment IDs. +// The op uses LU decomposition with partial pivoting to compute the inverses. // -//
-// -//
+// If a matrix is not invertible there is no guarantee what the op does. It +// may detect the condition and raise an exception or it may simply return a +// garbage result. // // Arguments: +// input: Shape is `[..., M, M]`. // -// segment_ids: A tensor whose shape is a prefix of `data.shape`. -// +// Returns Shape is `[..., M, M]`. // -// Returns Has same shape as data, except for the first `segment_ids.rank` -// dimensions, which are replaced with a single dimension which has size -// `num_segments`. -func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// @compatibility(numpy) +// Equivalent to np.linalg.inv +// @end_compatibility +func MatrixInverse(scope *Scope, input tf.Output, optional ...MatrixInverseAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "UnsortedSegmentSum", + Type: "MatrixInverse", Input: []tf.Input{ - data, segment_ids, num_segments, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the product along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation) -// for an explanation of segments. -// -// This operator is similar to the unsorted segment sum operator found -// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). -// Instead of computing the sum over segments, it computes the product of all -// entries belonging to a segment such that: -// -// \\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples -// `j...` such that `segment_ids[j...] == i`. -// -// If there is no entry for a given segment ID `i`, it outputs 1. -// -// If the given segment ID `i` is negative, then the corresponding value is -// dropped, and will not be included in the result. 
-// -// Arguments: -// -// segment_ids: A tensor whose shape is a prefix of `data.shape`. -// +// Computes reciprocal of square root of x element-wise. // -// Returns Has same shape as data, except for the first `segment_ids.rank` -// dimensions, which are replaced with a single dimension which has size -// `num_segments`. -func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// I.e., \\(y = 1 / \sqrt{x}\\). +func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "UnsortedSegmentProd", + Type: "Rsqrt", Input: []tf.Input{ - data, segment_ids, num_segments, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the mean along sparse segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. -// -// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first -// dimension, selecting a subset of dimension 0, specified by `indices`. -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// Rounds the values of a tensor to the nearest integer, element-wise. // -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { +// Rounds half to even. Also known as bankers rounding. If you want to round +// according to the current system rounding mode use std::cint. 
+func Round(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentMean", + Type: "Round", Input: []tf.Input{ - data, indices, segment_ids, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deserializes a serialized tree ensemble config and replaces current tree +// Delete the TensorArray from its resource container. // -// ensemble. +// This enables the user to close and release the resource in the middle +// of a step/run. // // Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. -// stamp_token: Token to use as the new value of the resource stamp. -// tree_ensemble_serialized: Serialized proto of the ensemble. +// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). // // Returns the created operation. -func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { +func TensorArrayCloseV3(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BoostedTreesDeserializeEnsemble", + Type: "TensorArrayCloseV3", Input: []tf.Input{ - tree_ensemble_handle, stamp_token, tree_ensemble_serialized, + handle, }, } return scope.AddOperation(opspec) } -// Transforms a tf.Example proto (as a string) into typed tensors. -// -// Arguments: -// serialized: A vector containing a batch of binary serialized Example protos. -// dense_defaults: A list of Tensors (some may be empty), whose length matches -// the length of `dense_keys`. dense_defaults[j] provides default values -// when the example's feature_map lacks dense_key[j]. If an empty Tensor is -// provided for dense_defaults[j], then the Feature dense_keys[j] is required. -// The input type is inferred from dense_defaults[j], even when it's empty. 
-// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, -// then the shape of dense_defaults[j] must match that of dense_shapes[j]. -// If dense_shapes[j] has an undefined major dimension (variable strides dense -// feature), dense_defaults[j] must contain a single element: -// the padding element. -// num_sparse: The number of sparse features to be parsed from the example. This -// must match the lengths of `sparse_keys` and `sparse_types`. -// sparse_keys: A list of `num_sparse` strings. -// The keys expected in the Examples' features associated with sparse values. -// dense_keys: The keys expected in the Examples' features associated with dense -// values. -// sparse_types: A list of `num_sparse` types; the data types of data in each -// Feature given in sparse_keys. -// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// dense_shapes: The shapes of data in each Feature given in dense_keys. -// The length of this list must match the length of `dense_keys`. The -// number of elements in the Feature corresponding to dense_key[j] must -// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == -// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] -// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, -// ..., DN), the shape of the output Tensor dense_values[j] will be (M, -// D1, .., DN), where M is the number of blocks of elements of length -// D1 * .... * DN, in the input. -func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { +// Computes exponential of x element-wise. \\(y = e^x\\). 
+func Exp(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} opspec := tf.OpSpec{ - Type: "ParseSingleExample", + Type: "Exp", Input: []tf.Input{ - serialized, tf.OutputList(dense_defaults), + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - return sparse_indices, sparse_values, sparse_shapes, dense_values + return op.Output(0) } -// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. -type WholeFileReaderV2Attr func(optionalAttr) +// NthElementAttr is an optional argument to NthElement. +type NthElementAttr func(optionalAttr) -// WholeFileReaderV2Container sets the optional container attribute to value. +// NthElementReverse sets the optional reverse attribute to value. // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { +// value: When set to True, find the nth-largest value in the vector and vice +// versa. 
+// If not specified, defaults to false +func NthElementReverse(value bool) NthElementAttr { return func(m optionalAttr) { - m["container"] = value + m["reverse"] = value } } -// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. +// Finds values of the `n`-th order statistic for the last dimension. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A Reader that outputs the entire contents of a file as a value. +// If the input is a vector (rank-1), finds the entries which is the nth-smallest +// value in the vector and outputs their values as scalar tensor. // -// To use, enqueue filenames in a Queue. The output of ReaderRead will -// be a filename (key) and the contents of that file (value). +// For matrices (resp. higher rank input), computes the entries which is the +// nth-smallest value in each row (resp. vector along the last dimension). Thus, // -// Returns The handle to reference the Reader. -func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { +// values.shape = input.shape[:-1] +// +// Arguments: +// input: 1-D or higher with last dimension at least `n+1`. +// n: 0-D. Position of sorted vector to select along the last dimension (along +// each row for matrices). Valid range of n is `[0, input.shape[:-1])` +// +// Returns The `n`-th order statistic along each last dimensional slice. 
+func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) { if scope.Err() != nil { return } @@ -26008,44 +30065,96 @@ func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_ a(attrs) } opspec := tf.OpSpec{ - Type: "WholeFileReaderV2", - + Type: "NthElement", + Input: []tf.Input{ + input, n, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Pop the element at the top of the stack. +// Computes the maximum along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// This operator is similar to the unsorted segment sum operator found +// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). +// Instead of computing the sum over segments, it computes the maximum such that: +// +// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such +// that `segment_ids[j...] == i`. +// +// If the maximum is empty for a given segment ID `i`, it outputs the smallest +// possible value for the specific numeric type, +// `output[i] = numeric_limits::lowest()`. +// +// If the given segment ID `i` is negative, then the corresponding value is +// dropped, and will not be included in the result. +// +//
+// +//
+// +// For example: +// +// ``` python +// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) +// tf.unsorted_segment_max(c, tf.constant([0, 1, 0]), num_segments=2) +// # ==> [[ 4, 3, 3, 4], +// # [5, 6, 7, 8]] +// ``` +// // // Arguments: -// handle: The handle to a stack. -// elem_type: The type of the elem that is popped. // -// Returns The tensor that is popped from the top of the stack. -func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) { +// segment_ids: A tensor whose shape is a prefix of `data.shape`. +// +// +// Returns Has same shape as data, except for the first `segment_ids.rank` +// dimensions, which are replaced with a single dimension which has size +// `num_segments`. +func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "UnsortedSegmentMax", + Input: []tf.Input{ + data, segment_ids, num_segments, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes softplus: `log(exp(features) + 1)`. +func Softplus(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"elem_type": elem_type} opspec := tf.OpSpec{ - Type: "StackPopV2", + Type: "Softplus", Input: []tf.Input{ - handle, + features, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes hyperbolic cosine of x element-wise. -func Cosh(scope *Scope, x tf.Output) (y tf.Output) { +// Computes exponential of x - 1 element-wise. +// +// I.e., \\(y = (\exp x) - 1\\). +func Expm1(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Cosh", + Type: "Expm1", Input: []tf.Input{ x, }, @@ -26054,162 +30163,97 @@ func Cosh(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Computes the mean along sparse segments of a tensor. 
-// -// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is -// misisng, the `output` tensor at that position will be zeroed. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// num_segments: Should equal the number of distinct segment IDs. +// Computes natural logarithm of x element-wise. // -// Returns Has same shape as data, except for dimension 0 which has size -// `num_segments`. -func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// I.e., \\(y = \log_e x\\). +func Log(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentMeanWithNumSegments", + Type: "Log", Input: []tf.Input{ - data, indices, segment_ids, num_segments, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// CudnnRNNParamsSizeAttr is an optional argument to CudnnRNNParamsSize. -type CudnnRNNParamsSizeAttr func(optionalAttr) - -// CudnnRNNParamsSizeRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNParamsSizeRnnMode(value string) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } -} - -// CudnnRNNParamsSizeInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNParamsSizeInputMode(value string) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["input_mode"] = value - } -} - -// CudnnRNNParamsSizeDirection sets the optional direction attribute to value. 
-// If not specified, defaults to "unidirectional" -func CudnnRNNParamsSizeDirection(value string) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["direction"] = value - } -} - -// CudnnRNNParamsSizeDropout sets the optional dropout attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsSizeDropout(value float32) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["dropout"] = value - } -} - -// CudnnRNNParamsSizeSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsSizeSeed(value int64) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// CudnnRNNParamsSizeSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsSizeSeed2(value int64) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Computes size of weights that can be used by a Cudnn RNN model. +// Returns the index of a data point that should be added to the seed set. // -// Return the params size that can be used by the Cudnn RNN model. Subsequent -// weight allocation and initialization should use this size. +// Entries in distances are assumed to be squared distances of candidate points to +// the already sampled centers in the seed set. The op constructs one Markov chain +// of the k-MC^2 algorithm and returns the index of one candidate point to be added +// as an additional cluster center. // -// num_layers: Specifies the number of layers in the RNN model. -// num_units: Specifies the size of the hidden state. -// input_size: Specifies the size of the input state. -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicate whether there is a linear projection between the input and -// The actual computation before the first layer. 
'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. -// dir = (direction == bidirectional) ? 2 : 1 -// dropout: dropout probability. When set to 0., dropout is disabled. -// seed: the 1st part of a seed to initialize dropout. -// seed2: the 2nd part of a seed to initialize dropout. -// params_size: The size of the params buffer that should be allocated and -// initialized for this RNN model. Note that this params buffer may not be -// compatible across GPUs. Please use CudnnRNNParamsWeights and -// CudnnRNNParamsBiases to save and restore them in a way that is compatible -// across different runs. -func CudnnRNNParamsSize(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, T tf.DataType, S tf.DataType, optional ...CudnnRNNParamsSizeAttr) (params_size tf.Output) { +// Arguments: +// distances: Vector with squared distances to the closest previously sampled cluster center +// for each candidate point. +// seed: Scalar. Seed for initializing the random number generator. +// +// Returns Scalar with the index of the sampled point. +func KMC2ChainInitialization(scope *Scope, distances tf.Output, seed tf.Output) (index tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"T": T, "S": S} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "CudnnRNNParamsSize", + Type: "KMC2ChainInitialization", Input: []tf.Input{ - num_layers, num_units, input_size, + distances, seed, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes gradients for SparseSegmentMean. -// -// Returns tensor "output" with same shape as grad, except for dimension 0 whose -// value is output_dim0. -// -// Arguments: -// grad: gradient propagated to the SparseSegmentMean op. 
-// indices: indices passed to the corresponding SparseSegmentMean op. -// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. -// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. -func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { +// Computes hyperbolic sine of x element-wise. +func Sinh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentMeanGrad", + Type: "Sinh", Input: []tf.Input{ - grad, indices, segment_ids, output_dim0, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the sum along sparse segments of a tensor divided by the sqrt of N. -// -// N is the size of the segment being reduced. +// Computes the sum along sparse segments of a tensor. // // Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) // for an explanation of segments. // +// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first +// dimension, selecting a subset of dimension 0, specified by `indices`. +// +// For example: +// +// ```python +// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) +// +// # Select two rows, one segment. +// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) +// # => [[0 0 0 0]] +// +// # Select two rows, two segment. +// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) +// # => [[ 1 2 3 4] +// # [-1 -2 -3 -4]] +// +// # Select all rows, two segments. +// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) +// # => [[0 0 0 0] +// # [5 6 7 8]] +// +// # Which is equivalent to: +// tf.segment_sum(c, tf.constant([0, 0, 1])) +// ``` +// // Arguments: // // indices: A 1-D tensor. Has same rank as `segment_ids`. 
@@ -26217,12 +30261,12 @@ func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segm // // Returns Has same shape as data, except for dimension 0 which // has size `k`, the number of segments. -func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { +func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtN", + Type: "SparseSegmentSum", Input: []tf.Input{ data, indices, segment_ids, }, @@ -26231,154 +30275,123 @@ func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment return op.Output(0) } -// Compute the upper regularized incomplete Gamma function `Q(a, x)`. -// -// The upper regularized incomplete Gamma function is defined as: -// -// \\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) -// -// where -// -// \\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) -// -// is the upper incomplete Gama function. -// -// Note, above `P(a, x)` (`Igamma`) is the lower regularized complete -// Gamma function. -func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Igammac", - Input: []tf.Input{ - a, x, - }, +// CastAttr is an optional argument to Cast. +type CastAttr func(optionalAttr) + +// CastTruncate sets the optional Truncate attribute to value. +// If not specified, defaults to false +func CastTruncate(value bool) CastAttr { + return func(m optionalAttr) { + m["Truncate"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes the sum along sparse segments of a tensor divided by the sqrt of N. -// -// N is the size of the segment being reduced. -// -// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is -// misisng, the `output` tensor at that position will be zeroed. 
-// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// num_segments: Should equal the number of distinct segment IDs. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSqrtNWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// Cast x of type SrcT to y of DstT. +func Cast(scope *Scope, x tf.Output, DstT tf.DataType, optional ...CastAttr) (y tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"DstT": DstT} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtNWithNumSegments", + Type: "Cast", Input: []tf.Input{ - data, indices, segment_ids, num_segments, + x, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes gradients for SparseSegmentSqrtN. -// -// Returns tensor "output" with same shape as grad, except for dimension 0 whose -// value is output_dim0. -// -// Arguments: -// grad: gradient propagated to the SparseSegmentSqrtN op. -// indices: indices passed to the corresponding SparseSegmentSqrtN op. -// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op. -// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op. -func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { +// Computes the log of the absolute value of `Gamma(x)` element-wise. 
+func Lgamma(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtNGrad", + Type: "Lgamma", Input: []tf.Input{ - grad, indices, segment_ids, output_dim0, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// LRNGradAttr is an optional argument to LRNGrad. -type LRNGradAttr func(optionalAttr) +// UnicodeEncodeAttr is an optional argument to UnicodeEncode. +type UnicodeEncodeAttr func(optionalAttr) -// LRNGradDepthRadius sets the optional depth_radius attribute to value. +// UnicodeEncodeErrors sets the optional errors attribute to value. // -// value: A depth radius. -// If not specified, defaults to 5 -func LRNGradDepthRadius(value int64) LRNGradAttr { +// value: Error handling policy when there is invalid formatting found in the input. +// The value of 'strict' will cause the operation to produce a InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. +// If not specified, defaults to "replace" +func UnicodeEncodeErrors(value string) UnicodeEncodeAttr { return func(m optionalAttr) { - m["depth_radius"] = value + m["errors"] = value } } -// LRNGradBias sets the optional bias attribute to value. +// UnicodeEncodeReplacementChar sets the optional replacement_char attribute to value. // -// value: An offset (usually > 0 to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNGradBias(value float32) LRNGradAttr { +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. 
The default value is the default unicode replacement character is +// 0xFFFD (U+65533). +// If not specified, defaults to 65533 +func UnicodeEncodeReplacementChar(value int64) UnicodeEncodeAttr { return func(m optionalAttr) { - m["bias"] = value + m["replacement_char"] = value } } -// LRNGradAlpha sets the optional alpha attribute to value. +// Encode a tensor of ints into unicode strings. // -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNGradAlpha(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// LRNGradBeta sets the optional beta attribute to value. +// Returns a vector of strings, where `output[i]` is constructed by encoding the +// Unicode codepoints in `input_values[input_splits[i]:input_splits[i+1]]` +// using `output_encoding`. // -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNGradBeta(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["beta"] = value - } -} - -// Gradients for Local Response Normalization. +// --- +// +// Example: +// +// ``` +// input_values = [72, 101, 108, 108, 111, 87, 111, 114, 108, 100] +// input_splits = [0, 5, 10] +// output_encoding = 'UTF-8' +// +// output = ['Hello', 'World'] +// ``` // // Arguments: -// input_grads: 4-D with shape `[batch, height, width, channels]`. -// input_image: 4-D with shape `[batch, height, width, channels]`. -// output_image: 4-D with shape `[batch, height, width, channels]`. +// input_values: A 1D tensor containing the unicode codepoints that should be encoded. +// input_splits: A 1D tensor specifying how the unicode codepoints should be split into strings. +// In particular, `output[i]` is constructed by encoding the codepoints in the +// slice `input_values[input_splits[i]:input_splits[i+1]]`. +// output_encoding: Unicode encoding of the output strings. Valid encodings are: `"UTF-8", +// "UTF-16-BE", and "UTF-32-BE"`. // -// Returns The gradients for LRN. 
-func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { +// Returns The 1-D Tensor of strings encoded from the provided unicode codepoints. +func UnicodeEncode(scope *Scope, input_values tf.Output, input_splits tf.Output, output_encoding string, optional ...UnicodeEncodeAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"output_encoding": output_encoding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LRNGrad", + Type: "UnicodeEncode", Input: []tf.Input{ - input_grads, input_image, output_image, + input_values, input_splits, }, Attrs: attrs, } @@ -26386,195 +30399,110 @@ func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_ return op.Output(0) } -// AnyAttr is an optional argument to Any. -type AnyAttr func(optionalAttr) - -// AnyKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func AnyKeepDims(value bool) AnyAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the "logical or" of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { +// Computes the complementary error function of `x` element-wise. 
+func Erfc(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Any", + Type: "Erfc", Input: []tf.Input{ - input, axis, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. -type DestroyResourceOpAttr func(optionalAttr) - -// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. -// -// value: whether to ignore the error when the resource -// doesn't exist. -// If not specified, defaults to true -func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { - return func(m optionalAttr) { - m["ignore_lookup_error"] = value - } -} - -// Deletes the resource specified by the handle. -// -// All subsequent operations using the resource will result in a NotFound -// error status. -// -// Arguments: -// resource: handle to the resource to delete. +// Computes sigmoid of `x` element-wise. // -// Returns the created operation. -func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { +// Specifically, `y = 1 / (1 + exp(-x))`. +func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "DestroyResourceOp", + Type: "Sigmoid", Input: []tf.Input{ - resource, + x, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Generates values in an interval. -// -// A sequence of `num` evenly-spaced values are generated beginning at `start`. -// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, -// so that the last one is exactly `stop`. 
-// -// For example: -// -// ``` -// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] -// ``` -// -// Arguments: -// start: 0-D tensor. First entry in the range. -// stop: 0-D tensor. Last entry in the range. -// num: 0-D tensor. Number of values to generate. -// -// Returns 1-D. The generated values. -func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { +// Computes sin of x element-wise. +func Sin(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LinSpace", + Type: "Sin", Input: []tf.Input{ - start, stop, num, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ComplexAttr is an optional argument to Complex. -type ComplexAttr func(optionalAttr) +// FusedBatchNormGradAttr is an optional argument to FusedBatchNormGrad. +type FusedBatchNormGradAttr func(optionalAttr) -// ComplexTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_COMPLEX64 -func ComplexTout(value tf.DataType) ComplexAttr { +// FusedBatchNormGradEpsilon sets the optional epsilon attribute to value. +// +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormGradEpsilon(value float32) FusedBatchNormGradAttr { return func(m optionalAttr) { - m["Tout"] = value + m["epsilon"] = value } } -// Converts two real numbers to a complex number. -// -// Given a tensor `real` representing the real part of a complex number, and a -// tensor `imag` representing the imaginary part of a complex number, this -// operation returns complex numbers elementwise of the form \\(a + bj\\), where -// *a* represents the `real` part and *b* represents the `imag` part. -// -// The input tensors `real` and `imag` must have the same shape. -// -// For example: +// FusedBatchNormGradDataFormat sets the optional data_format attribute to value. 
// -// ``` -// # tensor 'real' is [2.25, 3.25] -// # tensor `imag` is [4.75, 5.75] -// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] -// ``` -func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Complex", - Input: []tf.Input{ - real, imag, - }, - Attrs: attrs, +// value: The data format for y_backprop, x, x_backprop. +// Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormGradDataFormat(value string) FusedBatchNormGradAttr { + return func(m optionalAttr) { + m["data_format"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// ImagAttr is an optional argument to Imag. -type ImagAttr func(optionalAttr) - -// ImagTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func ImagTout(value tf.DataType) ImagAttr { +// FusedBatchNormGradIsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormGradIsTraining(value bool) FusedBatchNormGradAttr { return func(m optionalAttr) { - m["Tout"] = value + m["is_training"] = value } } -// Returns the imaginary part of a complex number. +// Gradient for batch normalization. // -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the imaginary part of each element in `input`. All -// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a* -// is the real part and *b* is the imaginary part returned by this operation. +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. 
// -// For example: +// Arguments: +// y_backprop: A 4D Tensor for the gradient with respect to y. +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch +// mean to be reused in gradient computation. When is_training is +// False, a 1D Tensor for the population mean to be reused in both +// 1st and 2nd order gradient computation. +// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch +// variance (inverted variance in the cuDNN case) to be reused in +// gradient computation. When is_training is False, a 1D Tensor +// for the population variance to be reused in both 1st and 2nd +// order gradient computation. // -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.imag(input) ==> [4.75, 5.75] -// ``` -func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) { +// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input +// in FusedBatchNorm. 
+func FusedBatchNormGrad(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradAttr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { if scope.Err() != nil { return } @@ -26583,23 +30511,23 @@ func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output a(attrs) } opspec := tf.OpSpec{ - Type: "Imag", + Type: "FusedBatchNormGrad", Input: []tf.Input{ - input, + y_backprop, x, scale, reserve_space_1, reserve_space_2, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Computes hyperbolic tangent of `x` element-wise. -func Tanh(scope *Scope, x tf.Output) (y tf.Output) { +// Computes cos of x element-wise. +func Cos(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Tanh", + Type: "Cos", Input: []tf.Input{ x, }, @@ -26608,103 +30536,95 @@ func Tanh(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Computes the maximum along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such -// that `segment_ids[j] == i`. -// -// If the max is empty for a given segment ID `i`, `output[i] = 0`. +// Computes the determinant of one or more square matrices. // -//
-// -//
+// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor containing the determinants +// for all input submatrices `[..., :, :]`. // // Arguments: +// input: Shape is `[..., M, M]`. // -// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// Returns Shape is `[...]`. +func MatrixDeterminant(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SegmentMax", + Type: "MatrixDeterminant", Input: []tf.Input{ - data, segment_ids, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that skips `count` elements from the `input_dataset`. -// -// Arguments: +// Updates the tree ensemble by either adding a layer to the last tree being grown // -// count: A scalar representing the number of elements from the `input_dataset` -// that should be skipped. If count is -1, skips everything. +// or by starting a new tree. // +// Arguments: +// tree_ensemble_handle: Handle to the ensemble variable. +// feature_ids: Rank 1 tensor with ids for each feature. This is the real id of +// the feature that will be used in the split. +// node_ids: List of rank 1 tensors representing the nodes for which this feature +// has a split. +// gains: List of rank 1 tensors representing the gains for each of the feature's +// split. +// thresholds: List of rank 1 tensors representing the thesholds for each of the +// feature's split. +// left_node_contribs: List of rank 2 tensors with left leaf contribs for each of +// the feature's splits. Will be added to the previous node values to constitute +// the values of the left nodes. 
+// right_node_contribs: List of rank 2 tensors with right leaf contribs for each +// of the feature's splits. Will be added to the previous node values to constitute +// the values of the right nodes. +// max_depth: Max depth of the tree to build. +// learning_rate: shrinkage const for each new tree. +// pruning_mode: 0-No pruning, 1-Pre-pruning, 2-Post-pruning. // -func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns the created operation. +func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, feature_ids tf.Output, node_ids []tf.Output, gains []tf.Output, thresholds []tf.Output, left_node_contribs []tf.Output, right_node_contribs []tf.Output, max_depth tf.Output, learning_rate tf.Output, pruning_mode int64) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"pruning_mode": pruning_mode} opspec := tf.OpSpec{ - Type: "SkipDataset", + Type: "BoostedTreesUpdateEnsemble", Input: []tf.Input{ - input_dataset, count, + tree_ensemble_handle, feature_ids, tf.OutputList(node_ids), tf.OutputList(gains), tf.OutputList(thresholds), tf.OutputList(left_node_contribs), tf.OutputList(right_node_contribs), max_depth, learning_rate, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// RealAttr is an optional argument to Real. -type RealAttr func(optionalAttr) - -// RealTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func RealTout(value tf.DataType) RealAttr { - return func(m optionalAttr) { - m["Tout"] = value +// Computes tan of x element-wise. 
+func Tan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Tan", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns the real part of a complex number. -// -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the real part of each element in `input`. All elements in -// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real -// part returned by this operation and *b* is the imaginary part. -// -// For example: -// -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.real(input) ==> [-2.25, 3.25] -// ``` -func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) { +// Creates a dataset that emits each dim-0 slice of `components` once. +func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Real", + Type: "TensorSliceDataset", Input: []tf.Input{ - input, + tf.OutputList(components), }, Attrs: attrs, } @@ -26712,541 +30632,384 @@ func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output return op.Output(0) } -// Sends `input` to all devices that are connected to the output. -// -// Sends `input` to all devices that are connected to the output. -// -// The graph should be constructed so that all ops connected to the output have a -// valid device assignment, and the op itself is assigned one of these devices. -// -// input: The input to the broadcast. -// output: The same as input. -// shape: The shape of the input tensor. -// -func NcclBroadcast(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { +// Computes acos of x element-wise. 
+func Acos(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "NcclBroadcast", + Type: "Acos", Input: []tf.Input{ - input, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeAreaAttr is an optional argument to ResizeArea. -type ResizeAreaAttr func(optionalAttr) - -// ResizeAreaAlignCorners sets the optional align_corners attribute to value. +// Computes the Bessel i0e function of `x` element-wise. // -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeAreaAlignCorners(value bool) ResizeAreaAttr { - return func(m optionalAttr) { - m["align_corners"] = value +// Exponentially scaled modified Bessel function of order 0 defined as +// `bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. +// +// This function is faster and numerically stabler than `bessel_i0(x)`. +func BesselI0e(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BesselI0e", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Resize `images` to `size` using area interpolation. -// -// Input images can be of different types but output images are always float. -// -// The range of pixel values for the output image might be slightly different -// from the range for the input image because of limited numerical precision. -// To guarantee an output range, for example `[0.0, 1.0]`, apply -// `tf.clip_by_value` to the output. -// -// Each output pixel is computed by first transforming the pixel's footprint into -// the input tensor and then averaging the pixels that intersect the footprint. 
An -// input pixel's contribution to the average is weighted by the fraction of its -// area that intersects the footprint. This is the same as OpenCV's INTER_AREA. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// Shuffle dimensions of x according to a permutation. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) { +// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: +// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` +func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResizeArea", + Type: "Transpose", Input: []tf.Input{ - images, size, + x, perm, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// VarHandleOpAttr is an optional argument to VarHandleOp. -type VarHandleOpAttr func(optionalAttr) +// MinAttr is an optional argument to Min. +type MinAttr func(optionalAttr) -// VarHandleOpContainer sets the optional container attribute to value. +// MinKeepDims sets the optional keep_dims attribute to value. // -// value: the container this variable is placed in. -// If not specified, defaults to "" -func VarHandleOpContainer(value string) VarHandleOpAttr { +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func MinKeepDims(value bool) MinAttr { return func(m optionalAttr) { - m["container"] = value + m["keep_dims"] = value } } -// VarHandleOpSharedName sets the optional shared_name attribute to value. +// Computes the minimum of elements across dimensions of a tensor. 
// -// value: the name by which this variable is referred to. -// If not specified, defaults to "" -func VarHandleOpSharedName(value string) VarHandleOpAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a handle to a Variable resource. +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// dtype: the type of this variable. Must agree with the dtypes -// of all ops using this variable. -// shape: The (possibly partially specified) shape of this variable. -func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) { +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "VarHandleOp", - + Type: "Min", + Input: []tf.Input{ + input, axis, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// AngleAttr is an optional argument to Angle. -type AngleAttr func(optionalAttr) - -// AngleTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func AngleTout(value tf.DataType) AngleAttr { - return func(m optionalAttr) { - m["Tout"] = value - } -} - -// Returns the argument of a complex number. -// -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the argument of each element in `input`. 
All elements in -// `input` must be complex numbers of the form \\(a + bj\\), where *a* -// is the real part and *b* is the imaginary part. -// -// The argument returned by this operation is of the form \\(atan2(b, a)\\). -// -// For example: +// Computes the Bessel i1e function of `x` element-wise. // -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.angle(input) ==> [2.0132, 1.056] -// ``` +// Exponentially scaled modified Bessel function of order 0 defined as +// `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. // -// @compatibility(numpy) -// Equivalent to np.angle. -// @end_compatibility -func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) { +// This function is faster and numerically stabler than `bessel_i1(x)`. +func BesselI1e(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Angle", + Type: "BesselI1e", Input: []tf.Input{ - input, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Clips tensor values to a specified min and max. -// -// Given a tensor `t`, this operation returns a tensor of the same type and -// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. -// Any values less than `clip_value_min` are set to `clip_value_min`. Any values -// greater than `clip_value_max` are set to `clip_value_max`. +// Returns an element-wise indication of the sign of a number. // -// Arguments: -// t: A `Tensor`. -// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape -// as `t`. The minimum value to clip by. -// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape -// as `t`. The maximum value to clip by. +// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`. // -// Returns A clipped `Tensor` with the same shape as input 't'. 
-func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) { +// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`. +func Sign(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ClipByValue", + Type: "Sign", Input: []tf.Input{ - t, clip_value_min, clip_value_max, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Counts the number of occurrences of each value in an integer array. +// Creates a dataset that passes a sliding window over `input_dataset`. // -// Outputs a vector with length `size` and the same dtype as `weights`. If -// `weights` are empty, then index `i` stores the number of times the value `i` is -// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of -// the value in `weights` at each index where the corresponding value in `arr` is -// `i`. +// Arguments: // -// Values in `arr` outside of the range [0, size) are ignored. +// window_size: A scalar representing the number of elements in the +// sliding window. +// window_shift: A scalar representing the steps moving the sliding window +// forward in one iteration. It must be positive. +// window_stride: A scalar representing the stride of the input elements of the sliding window. +// It must be positive. // -// Arguments: -// arr: int32 `Tensor`. -// size: non-negative int32 scalar `Tensor`. -// weights: is an int32, int64, float32, or float64 `Tensor` with the same -// shape as `arr`, or a length-0 `Tensor`, in which case it acts as all weights -// equal to 1. // -// Returns 1D `Tensor` with length equal to `size`. The counts or summed weights for -// each value in the range [0, size). 
-func Bincount(scope *Scope, arr tf.Output, size tf.Output, weights tf.Output) (bins tf.Output) { +func ExperimentalSlidingWindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, window_shift tf.Output, window_stride tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Bincount", + Type: "ExperimentalSlidingWindowDataset", Input: []tf.Input{ - arr, size, weights, + input_dataset, window_size, window_shift, window_stride, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CumsumAttr is an optional argument to Cumsum. -type CumsumAttr func(optionalAttr) +// OrderedMapUnstageNoKeyAttr is an optional argument to OrderedMapUnstageNoKey. +type OrderedMapUnstageNoKeyAttr func(optionalAttr) -// CumsumExclusive sets the optional exclusive attribute to value. +// OrderedMapUnstageNoKeyCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: If `True`, perform exclusive cumsum. -// If not specified, defaults to false -func CumsumExclusive(value bool) CumsumAttr { +// REQUIRES: value >= 0 +func OrderedMapUnstageNoKeyCapacity(value int64) OrderedMapUnstageNoKeyAttr { return func(m optionalAttr) { - m["exclusive"] = value + m["capacity"] = value } } -// CumsumReverse sets the optional reverse attribute to value. +// OrderedMapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: A `bool` (default: False). 
-// If not specified, defaults to false -func CumsumReverse(value bool) CumsumAttr { +// REQUIRES: value >= 0 +func OrderedMapUnstageNoKeyMemoryLimit(value int64) OrderedMapUnstageNoKeyAttr { return func(m optionalAttr) { - m["reverse"] = value + m["memory_limit"] = value } } -// Compute the cumulative sum of the tensor `x` along `axis`. -// -// By default, this op performs an inclusive cumsum, which means that the first -// element of the input is identical to the first element of the output: -// -// ```python -// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] -// ``` -// -// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is -// performed instead: -// -// ```python -// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] -// ``` -// -// By setting the `reverse` kwarg to `True`, the cumsum is performed in the -// opposite direction: -// -// ```python -// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] -// ``` -// -// This is more efficient than using separate `tf.reverse` ops. -// -// The `reverse` and `exclusive` kwargs can also be combined: -// -// ```python -// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] -// ``` +// OrderedMapUnstageNoKeyContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapUnstageNoKeyContainer(value string) OrderedMapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapUnstageNoKeySharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapUnstageNoKeySharedName(value string) OrderedMapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns the (key, value) element with the smallest // -// Arguments: -// x: A `Tensor`. 
Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) { +// key from the underlying container. If the underlying container +// does not contain elements, the op will block until it does. +func OrderedMapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Cumsum", + Type: "OrderedMapUnstageNoKey", Input: []tf.Input{ - x, axis, + indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Return the shape of s0 op s1 with broadcast. -// -// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the -// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. -func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "BroadcastArgs", - Input: []tf.Input{ - s0, s1, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. -type DataFormatDimMapAttr func(optionalAttr) - -// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. -// -// value: source data format. -// If not specified, defaults to "NHWC" -func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { - return func(m optionalAttr) { - m["src_format"] = value - } -} - -// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. 
-// -// value: destination data format. -// If not specified, defaults to "NCHW" -func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { - return func(m optionalAttr) { - m["dst_format"] = value - } + var idx int + var err error + key = op.Output(idx) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("OrderedMapUnstageNoKey", err) + return + } + return key, values } -// Returns the dimension index in the destination data format given the one in -// -// the source data format. +// Returns element-wise integer closest to x. // -// Arguments: -// x: A Tensor with each element as a dimension index in source data format. -// Must be in the range [-4, 4). +// If the result is midway between two representable values, +// the even representable is chosen. +// For example: // -// Returns A Tensor with each element as a dimension index in destination data format. -func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { +// ``` +// rint(-1.5) ==> -2.0 +// rint(0.5000001) ==> 1.0 +// rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] +// ``` +func Rint(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "DataFormatDimMap", + Type: "Rint", Input: []tf.Input{ x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CumprodAttr is an optional argument to Cumprod. -type CumprodAttr func(optionalAttr) - -// CumprodExclusive sets the optional exclusive attribute to value. -// -// value: If `True`, perform exclusive cumprod. -// If not specified, defaults to false -func CumprodExclusive(value bool) CumprodAttr { - return func(m optionalAttr) { - m["exclusive"] = value +// Computes the derivative of a Gamma random sample w.r.t. `alpha`. 
+func RandomGammaGrad(scope *Scope, alpha tf.Output, sample tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RandomGammaGrad", + Input: []tf.Input{ + alpha, sample, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// CumprodReverse sets the optional reverse attribute to value. +// Returns x + y element-wise. // -// value: A `bool` (default: False). -// If not specified, defaults to false -func CumprodReverse(value bool) CumprodAttr { - return func(m optionalAttr) { - m["reverse"] = value +// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Add", + Input: []tf.Input{ + x, y, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Compute the cumulative product of the tensor `x` along `axis`. -// -// By default, this op performs an inclusive cumprod, which means that the first -// element of the input is identical to the first element of the output: -// -// ```python -// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] -// ``` -// -// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is -// performed instead: -// -// ```python -// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] -// ``` -// -// By setting the `reverse` kwarg to `True`, the cumprod is performed in the -// opposite direction: -// -// ```python -// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] -// ``` -// -// This is more efficient than using separate `tf.reverse` ops. -// -// The `reverse` and `exclusive` kwargs can also be combined: -// -// ```python -// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] -// ``` +// Returns x + y element-wise. // -// Arguments: -// x: A `Tensor`. 
Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { +// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Cumprod", + Type: "AddV2", Input: []tf.Input{ - x, axis, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. -type QuantizedMatMulAttr func(optionalAttr) - -// QuantizedMatMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["Toutput"] = value - } -} +// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. +type AllCandidateSamplerAttr func(optionalAttr) -// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. +// AllCandidateSamplerSeed sets the optional seed attribute to value. // -// value: If true, `a` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. 
+// If not specified, defaults to 0 +func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr { return func(m optionalAttr) { - m["transpose_a"] = value + m["seed"] = value } } -// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. +// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value. // -// value: If true, `b` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr { return func(m optionalAttr) { - m["transpose_b"] = value + m["seed2"] = value } } -// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. +// Generates labels for candidate sampling with a learned unigram distribution. // -// value: The type of output produced by activation function -// following this operation. -// If not specified, defaults to DT_QUINT8 -func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["Tactivation"] = value - } -} - -// Perform a quantized matrix multiplication of `a` by the matrix `b`. +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. // -// The inputs must be two-dimensional matrices and the inner dimension of -// `a` (after being transposed if `transpose_a` is non-zero) must match the -// outer dimension of `b` (after being transposed if `transposed_b` is -// non-zero). +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // // Arguments: -// a: Must be a two-dimensional tensor. 
-// b: Must be a two-dimensional tensor. -// min_a: The float value that the lowest quantized `a` value represents. -// max_a: The float value that the highest quantized `a` value represents. -// min_b: The float value that the lowest quantized `b` value represents. -// max_b: The float value that the highest quantized `b` value represents. +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to produce. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. 
+func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedMatMul", + Type: "AllCandidateSampler", Input: []tf.Input{ - a, b, min_a, max_a, min_b, max_b, + true_classes, }, Attrs: attrs, } @@ -27254,148 +31017,119 @@ func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, ma return op.Output(0), op.Output(1), op.Output(2) } -// Does nothing. Serves as a control trigger for scheduling. +// Returns element-wise remainder of division. When `x < 0` xor `y < 0` is // -// Only useful as a placeholder for control edges. +// true, this follows Python semantics in that the result here is consistent +// with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. // -// Returns the created operation. -func ControlTrigger(scope *Scope) (o *tf.Operation) { +// *NOTE*: `FloorMod` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func FloorMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ControlTrigger", + Type: "FloorMod", + Input: []tf.Input{ + x, y, + }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Batch normalization. +// Saves the input tensors to disk. // -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() +// The size of `tensor_names` must match the number of tensors in `data`. `data[i]` +// is written to `filename` with name `tensor_names[i]`. // -// This op is deprecated. 
Prefer `tf.nn.batch_normalization`. +// See also `SaveSlices`. // // Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// beta: A 1D beta Tensor with size matching the last dimension of t. -// An offset to be added to the normalized tensor. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this tensor will be multiplied -// with the normalized tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { +// filename: Must have a single element. The name of the file to which we write +// the tensor. +// tensor_names: Shape `[N]`. The names of the tensors to be saved. +// data: `N` tensors to save. +// +// Returns the created operation. +func Save(scope *Scope, filename tf.Output, tensor_names tf.Output, data []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalization", + Type: "Save", Input: []tf.Input{ - t, m, v, beta, gamma, + filename, tensor_names, tf.OutputList(data), }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Deprecated. Use TensorArrayReadV3 +// Returns x * y element-wise. 
Returns zero if y is zero, even if x if infinite or NaN. // -// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 -func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { +// *NOTE*: `Mul` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func MulNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "TensorArrayReadV2", + Type: "MulNoNan", Input: []tf.Input{ - handle, index, flow_in, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizedMulAttr is an optional argument to QuantizedMul. -type QuantizedMulAttr func(optionalAttr) - -// QuantizedMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { - return func(m optionalAttr) { - m["Toutput"] = value - } -} - -// Returns x * y element-wise, working on quantized buffers. -// -// Arguments: -// -// -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. +// Returns x / y element-wise for integer types. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// Truncation designates that negative numbers will round fractional quantities +// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different +// than Python semantics. See `FloorDiv` for a division function that matches +// Python Semantics. 
// -// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { +// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "QuantizedMul", + Type: "TruncateDiv", Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// QuantizedAddAttr is an optional argument to QuantizedAdd. -type QuantizedAddAttr func(optionalAttr) +// RequantizePerChannelAttr is an optional argument to RequantizePerChannel. +type RequantizePerChannelAttr func(optionalAttr) -// QuantizedAddToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { +// RequantizePerChannelOutType sets the optional out_type attribute to value. +// +// value: The quantized type of output tensor that needs to be converted. +// If not specified, defaults to DT_QUINT8 +func RequantizePerChannelOutType(value tf.DataType) RequantizePerChannelAttr { return func(m optionalAttr) { - m["Toutput"] = value + m["out_type"] = value } } -// Returns x + y element-wise, working on quantized buffers. +// Requantizes input with min and max values known per channel. // // Arguments: +// input: The original input tensor. 
+// input_min: The minimum value of the input tensor +// input_max: The maximum value of the input tensor. +// requested_output_min: The minimum value of the output tensor requested. +// requested_output_max: The maximum value of the output tensor requested. // -// -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -// -// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { +// Returns Output tensor.The minimum value of the final output tensorThe maximum value of the final output tensor. 
+func RequantizePerChannel(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, optional ...RequantizePerChannelAttr) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } @@ -27404,9 +31138,9 @@ func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedAdd", + Type: "RequantizePerChannel", Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, + input, input_min, input_max, requested_output_min, requested_output_max, }, Attrs: attrs, } @@ -27414,501 +31148,620 @@ func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x return op.Output(0), op.Output(1), op.Output(2) } -// Given a quantized tensor described by (input, input_min, input_max), outputs a +// Restores tensors from a V2 checkpoint. +// +// For backward compatibility with the V1 format, this Op currently allows +// restoring from a V1 checkpoint as well: +// - This Op first attempts to find the V2 index file pointed to by "prefix", and +// if found proceed to read it as a V2 checkpoint; +// - Otherwise the V1 read path is invoked. +// Relying on this behavior is not recommended, as the ability to fall back to read +// V1 might be deprecated and eventually removed. // -// range that covers the actual values present in that tensor. This op is -// typically used to produce the requested_output_min and requested_output_max for -// Requantize. +// By default, restores the named tensors in full. If the caller wishes to restore +// specific slices of stored tensors, "shape_and_slices" should be non-empty +// strings and correspondingly well-formed. // -// Arguments: +// Callers must ensure all the named tensors are indeed stored in the checkpoint. // -// input_min: The float value that the minimum quantized input value represents. 
-// input_max: The float value that the maximum quantized input value represents. +// Arguments: +// prefix: Must have a single element. The prefix of a V2 checkpoint. +// tensor_names: shape {N}. The names of the tensors to be restored. +// shape_and_slices: shape {N}. The slice specs of the tensors to be restored. +// Empty strings indicate that they are non-partitioned tensors. +// dtypes: shape {N}. The list of expected dtype for the tensors. Must match +// those stored in the checkpoint. // -// Returns The computed min output.the computed max output. -func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { +// Returns shape {N}. The restored tensors, whose shapes are read from the +// checkpoint directly. +func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} opspec := tf.OpSpec{ - Type: "RequantizationRange", + Type: "RestoreV2", Input: []tf.Input{ - input, input_min, input_max, + prefix, tensor_names, shape_and_slices, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + if scope.Err() != nil { + return + } + var idx int + var err error + if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil { + scope.UpdateErr("RestoreV2", err) + return + } + return tensors } -// Rolls the elements of a tensor along an axis. +// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. +type FIFOQueueV2Attr func(optionalAttr) + +// FIFOQueueV2Shapes sets the optional shapes attribute to value. // -// The elements are shifted positively (towards larger indices) by the offset of -// `shift` along the dimension of `axis`. Negative `shift` values will shift -// elements in the opposite direction. 
Elements that roll passed the last position -// will wrap around to the first and vice versa. Multiple shifts along multiple -// axes may be specified. +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// If not specified, defaults to <> // -// For example: +// REQUIRES: len(value) >= 0 +func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { + return func(m optionalAttr) { + m["shapes"] = value + } +} + +// FIFOQueueV2Capacity sets the optional capacity attribute to value. // -// ``` -// # 't' is [0, 1, 2, 3, 4] -// roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2] +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// FIFOQueueV2Container sets the optional container attribute to value. // -// # shifting along multiple dimensions -// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] -// roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func FIFOQueueV2Container(value string) FIFOQueueV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// FIFOQueueV2SharedName sets the optional shared_name attribute to value. // -// # shifting along the same axis multiple times -// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] -// roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] -// ``` +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. 
+// If not specified, defaults to "" +func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A queue that produces elements in first-in first-out order. // // Arguments: +// component_types: The type of each component in a value. // -// shift: Dimension must be 0-D or 1-D. `shift[i]` specifies the number of places by which -// elements are shifted positively (towards larger indices) along the dimension -// specified by `axis[i]`. Negative shifts will roll the elements in the opposite -// direction. -// axis: Dimension must be 0-D or 1-D. `axis[i]` specifies the dimension that the shift -// `shift[i]` should occur. If the same axis is referenced more than once, the -// total shift for that axis will be the sum of all the shifts that belong to that -// axis. -// -// Returns Has the same shape and size as the input. The elements are shifted -// positively (towards larger indices) by the offsets of `shift` along the -// dimensions of `axis`. -func Roll(scope *Scope, input tf.Output, shift tf.Output, axis tf.Output) (output tf.Output) { +// Returns The handle to the queue. +func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Roll", + Type: "FIFOQueueV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that contains the elements of `input_dataset` ignoring errors. 
+func ExperimentalIgnoreErrorsDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalIgnoreErrorsDataset", Input: []tf.Input{ - input, shift, axis, + input_dataset, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Looks up keys in a table, outputs the corresponding values. -// -// The tensor `keys` must of the same type as the keys of the table. -// The output `values` is of the type of the table values. -// -// The scalar `default_value` is the value output for keys not present in the -// table. It must also be of the same type as the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// -// -// Returns Same shape as `keys`. Values found in the table, or `default_values` -// for missing keys. -func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { +// Returns 0 if x == 0, and x / y otherwise, elementwise. +func Xdivy(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LookupTableFindV2", + Type: "Xdivy", Input: []tf.Input{ - table_handle, keys, default_value, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Updates the table to associates keys with values. +// Bucketizes 'input' based on 'boundaries'. // -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. 
+// For example, if the inputs are +// boundaries = [0, 10, 100] +// input = [[-5, 10000] +// [150, 10] +// [5, 100]] +// +// then the output will be +// output = [[0, 3] +// [3, 2] +// [1, 3]] // // Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. +// input: Any shape of Tensor contains with int or float type. +// boundaries: A sorted list of floats gives the boundary of the buckets. // -// Returns the created operation. -func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { +// Returns Same shape with 'input', each value of input replaced with bucket index. +// +// @compatibility(numpy) +// Equivalent to np.digitize. +// @end_compatibility +func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"boundaries": boundaries} opspec := tf.OpSpec{ - Type: "LookupTableInsertV2", + Type: "Bucketize", Input: []tf.Input{ - table_handle, keys, values, + input, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Creates a dataset that batches and pads `batch_size` elements from the input. +// Calculates gains for each feature and returns the best possible split information for the feature. // -// Arguments: +// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature. // -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// padded_shapes: A list of int64 tensors representing the desired padded shapes -// of the corresponding output components. These shapes may be partially -// specified, using `-1` to indicate that a particular dimension should be -// padded to the maximum size of all batch elements. 
-// padding_values: A list of scalars containing the padding value to use for -// each of the outputs. -// drop_remainder: A scalar representing whether the last batch should be dropped in case its size -// is smaller than desired. +// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split. // -func PaddedBatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, drop_remainder tf.Output, output_shapes []tf.Shape) (handle tf.Output) { +// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features). +// +// The length of output lists are all of the same length, `num_features`. +// The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature. +// +// Arguments: +// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive). +// stats_summary_list: A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used. +// l1: l1 regularization factor on leaf weights, per instance based. +// l2: l2 regularization factor on leaf weights, per instance based. 
+// tree_complexity: adjustment to the gain, per leaf based. +// min_node_weight: mininum avg of hessians in a node before required for the node to be considered for splitting. +// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors. +// +// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node. 
+func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_shapes": output_shapes} + attrs := map[string]interface{}{"max_splits": max_splits} opspec := tf.OpSpec{ - Type: "PaddedBatchDatasetV2", + Type: "BoostedTreesCalculateBestGainsPerFeature", Input: []tf.Input{ - input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), drop_remainder, + node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns element-wise smallest integer not less than x. -func Ceil(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "Ceil", - Input: []tf.Input{ - x, - }, + var idx int + var err error + if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if right_node_contribs_list, idx, err = makeOutputList(op, idx, 
"right_node_contribs_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list } -// Computes the number of elements in the given table. +// EncodePngAttr is an optional argument to EncodePng. +type EncodePngAttr func(optionalAttr) + +// EncodePngCompression sets the optional compression attribute to value. +// +// value: Compression level. +// If not specified, defaults to -1 +func EncodePngCompression(value int64) EncodePngAttr { + return func(m optionalAttr) { + m["compression"] = value + } +} + +// PNG-encode an image. +// +// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` +// where `channels` is: +// +// * 1: for grayscale. +// * 2: for grayscale + alpha. +// * 3: for RGB. +// * 4: for RGBA. +// +// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder +// default or a value from 0 to 9. 9 is the highest compression level, generating +// the smallest output, but is slower. // // Arguments: -// table_handle: Handle to the table. +// image: 3-D with shape `[height, width, channels]`. // -// Returns Scalar that contains number of elements in the table. -func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { +// Returns 0-D. PNG-encoded image. +func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LookupTableSizeV2", + Type: "EncodePng", Input: []tf.Input{ - table_handle, + image, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. -type ResizeBilinearGradAttr func(optionalAttr) +// QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2. 
+type QueueDequeueUpToV2Attr func(optionalAttr) -// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. +// QueueDequeueUpToV2TimeoutMs sets the optional timeout_ms attribute to value. // -// value: If true, the centers of the 4 corner pixels of the input and grad tensors are -// aligned. Defaults to false. -// If not specified, defaults to false -func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { +// value: If the queue has fewer than n elements, this operation +// will block for up to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueDequeueUpToV2TimeoutMs(value int64) QueueDequeueUpToV2Attr { return func(m optionalAttr) { - m["align_corners"] = value + m["timeout_ms"] = value } } -// Computes the gradient of bilinear interpolation. +// Dequeues `n` tuples of one or more tensors from the given queue. +// +// This operation is not supported by all queues. If a queue does not support +// DequeueUpTo, then an Unimplemented error is returned. +// +// If the queue is closed and there are more than 0 but less than `n` +// elements remaining, then instead of returning an OutOfRange error like +// QueueDequeueMany, less than `n` elements are returned immediately. If +// the queue is closed and there are 0 elements left in the queue, then +// an OutOfRange error is returned just like in QueueDequeueMany. +// Otherwise the behavior is identical to QueueDequeueMany: +// +// This operation concatenates queue-element component tensors along the +// 0th dimension to make a single component tensor. All of the components +// in the dequeued tuple will have size n in the 0th dimension. +// +// This operation has `k` outputs, where `k` is the number of components in +// the tuples stored in the given queue, and output `i` is the ith +// component of the dequeued tuple. // // Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. 
-// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. +// handle: The handle to a queue. +// n: The number of tuples to dequeue. +// component_types: The type of each component in a tuple. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. -func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { +// Returns One or more tensors that were dequeued as a tuple. +func QueueDequeueUpToV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueUpToV2Attr) (components []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"component_types": component_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeBilinearGrad", + Type: "QueueDequeueUpToV2", Input: []tf.Input{ - grads, original_image, + handle, n, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("QueueDequeueUpToV2", err) + return + } + return components } -// Outputs all keys and values in the table. -// -// Arguments: -// table_handle: Handle to the table. -// -// +// Returns the max of x and y (i.e. x > y ? x : y) element-wise. // -// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. -func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { +// *NOTE*: `Maximum` supports broadcasting. 
More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} opspec := tf.OpSpec{ - Type: "LookupTableExportV2", + Type: "Maximum", Input: []tf.Input{ - table_handle, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Replaces the contents of the table with the specified keys and values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. +// Returns element-wise remainder of division. This emulates C semantics in that // -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. +// the result here is consistent with a truncating divide. E.g. +// `tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`. // -// Returns the created operation. -func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { +// *NOTE*: `Mod` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LookupTableImportV2", + Type: "Mod", Input: []tf.Input{ - table_handle, keys, values, + x, y, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MultiDeviceIteratorFromStringHandleAttr is an optional argument to MultiDeviceIteratorFromStringHandle. -type MultiDeviceIteratorFromStringHandleAttr func(optionalAttr) - -// MultiDeviceIteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. 
+// Returns element-wise remainder of division. This emulates C semantics in that // -// value: The type list for the return values. -// If not specified, defaults to <> +// the result here is consistent with a truncating divide. E.g. `truncate(x / y) * +// y + truncate_mod(x, y) = x`. // -// REQUIRES: len(value) >= 0 -func MultiDeviceIteratorFromStringHandleOutputTypes(value []tf.DataType) MultiDeviceIteratorFromStringHandleAttr { - return func(m optionalAttr) { - m["output_types"] = value +// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TruncateMod", + Input: []tf.Input{ + x, y, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MultiDeviceIteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. +// Computes offsets of concat inputs within its output. // -// value: The list of shapes being produced. -// If not specified, defaults to <> +// For example: // -// REQUIRES: len(value) >= 0 -func MultiDeviceIteratorFromStringHandleOutputShapes(value []tf.Shape) MultiDeviceIteratorFromStringHandleAttr { - return func(m optionalAttr) { - m["output_shapes"] = value - } -} - -// Generates a MultiDeviceIterator resource from its provided string handle. +// ``` +// # 'x' is [2, 2, 7] +// # 'y' is [2, 3, 7] +// # 'z' is [2, 5, 7] +// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] +// ``` +// +// This is typically used by gradient computations for a concat operation. // // Arguments: -// string_handle: String representing the resource. +// concat_dim: The dimension along which to concatenate. +// shape: The `N` int32 vectors representing shape of tensors being concatenated. // -// Returns A MultiDeviceIterator resource. 
-func MultiDeviceIteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...MultiDeviceIteratorFromStringHandleAttr) (multi_device_iterator tf.Output) { +// Returns The `N` int32 vectors representing the starting offset +// of input tensors within the concatenated output. +func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MultiDeviceIteratorFromStringHandle", + Type: "ConcatOffset", Input: []tf.Input{ - string_handle, + concat_dim, tf.OutputList(shape), }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { + scope.UpdateErr("ConcatOffset", err) + return + } + return offset } -// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. -type MutableHashTableV2Attr func(optionalAttr) +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingRMSPropParametersGradAccumDebug. +type LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr func(optionalAttr) -// MutableHashTableV2Container sets the optional container attribute to value. +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. 
-// If not specified, defaults to "" -func MutableHashTableV2Container(value string) MutableHashTableV2Attr { +// REQUIRES: value >= -1 +func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["container"] = value + m["table_id"] = value } } -// MutableHashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName sets the optional table_name attribute to value. // If not specified, defaults to "" -func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { +func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["use_node_name_sharing"] = value + m["table_name"] = value } } -// Creates an empty hash table. +// Load RMSProp embedding parameters with debug support. // -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. 
For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. +// parameters: Value of parameters used in the RMSProp optimization algorithm. +// ms: Value of ms used in the RMSProp optimization algorithm. +// mom: Value of mom used in the RMSProp optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the RMSProp optimization algorithm. // -// Returns Handle to a table. -func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { +// +// +// Returns the created operation. +func LoadTPUEmbeddingRMSPropParametersGradAccumDebug(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MutableHashTableV2", - + Type: "LoadTPUEmbeddingRMSPropParametersGradAccumDebug", + Input: []tf.Input{ + parameters, ms, mom, gradient_accumulators, + }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DequantizeAttr is an optional argument to Dequantize. -type DequantizeAttr func(optionalAttr) - -// DequantizeMode sets the optional mode attribute to value. -// If not specified, defaults to "MIN_COMBINED" -func DequantizeMode(value string) DequantizeAttr { - return func(m optionalAttr) { - m["mode"] = value - } + return scope.AddOperation(opspec) } -// Dequantize the 'input' tensor into a float Tensor. 
+// Compute the lower regularized incomplete Gamma function `P(a, x)`. // -// [min_range, max_range] are scalar floats that specify the range for -// the 'input' data. The 'mode' attribute controls exactly which calculations are -// used to convert the float values to their quantized equivalents. +// The lower regularized incomplete Gamma function is defined as: // -// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: // -// ``` -// if T == qint8: in[i] += (range(T) + 1)/ 2.0 -// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) -// ``` -// here `range(T) = numeric_limits::max() - numeric_limits::min()` +// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) // -// *MIN_COMBINED Mode Example* +// where // -// If the input comes from a QuantizedRelu6, the output type is -// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is -// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. -// Dequantize on quint8 will take each value, cast to float, and multiply -// by 6 / 255. -// Note that if quantizedtype is qint8, the operation will additionally add -// each value by 128 prior to casting. +// \\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) // -// If the mode is 'MIN_FIRST', then this approach is used: +// is the lower incomplete Gamma function. // -// ```c++ -// num_discrete_values = 1 << (# of bits in T) -// range_adjust = num_discrete_values / (num_discrete_values - 1) -// range = (range_max - range_min) * range_adjust -// range_scale = range / num_discrete_values -// const double offset_input = static_cast(input) - lowest_quantized; -// result = range_min + ((input - numeric_limits::min()) * range_scale) -// ``` +// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete +// Gamma function. 
+func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Igamma", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Compute the regularized incomplete beta integral \\(I_x(a, b)\\). // -// *SCALED mode Example* +// The regularized incomplete beta integral is defined as: // -// `SCALED` mode matches the quantization approach used in -// `QuantizeAndDequantize{V2|V3}`. // -// If the mode is `SCALED`, we do not use the full range of the output type, -// choosing to elide the lowest possible value for symmetry (e.g., output range is -// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to -// 0. +// \\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) // -// We first find the range of values in our tensor. The -// range we use is always centered on 0, so we find m such that -// ```c++ -// m = max(abs(input_min), abs(input_max)) -// ``` +// where // -// Our input tensor range is then `[-m, m]`. // -// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. -// If T is signed, this is -// ``` -// num_bits = sizeof(T) * 8 -// [min_fixed, max_fixed] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] -// ``` +// \\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) // -// Otherwise, if T is unsigned, the fixed-point range is -// ``` -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] -// ``` // -// From this we compute our scaling factor, s: -// ```c++ -// s = (2 * m) / (max_fixed - min_fixed) -// ``` +// is the incomplete beta function and \\(B(a, b)\\) is the *complete* +// beta function. 
+func Betainc(scope *Scope, a tf.Output, b tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Betainc", + Input: []tf.Input{ + a, b, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ShapeAttr is an optional argument to Shape. +type ShapeAttr func(optionalAttr) + +// ShapeOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func ShapeOutType(value tf.DataType) ShapeAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Returns the shape of a tensor. // -// Now we can dequantize the elements of our tensor: -// ```c++ -// result = input * s -// ``` +// This operation returns a 1-D integer tensor representing the shape of `input`. // -// Arguments: +// For example: // -// min_range: The minimum scalar value possibly produced for the input. -// max_range: The maximum scalar value possibly produced for the input. -func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, optional ...DequantizeAttr) (output tf.Output) { +// ``` +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// shape(t) ==> [2, 2, 3] +// ``` +func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -27917,9 +31770,9 @@ func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf a(attrs) } opspec := tf.OpSpec{ - Type: "Dequantize", + Type: "Shape", Input: []tf.Input{ - input, min_range, max_range, + input, }, Attrs: attrs, } @@ -27927,155 +31780,131 @@ func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf return op.Output(0) } -// Flips all bits elementwise. +// Computes fingerprints of the input strings. // -// The result will have exactly those bits set, that are not set in `x`. The -// computation is performed on the underlying representation of x. 
-func Invert(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// input: vector of strings to compute fingerprints on. +// +// Returns a (N,2) shaped matrix where N is the number of elements in the input +// vector. Each row contains the low and high parts of the fingerprint. +func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Invert", + Type: "SdcaFprint", Input: []tf.Input{ - x, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deserialize bucket boundaries and ready flag into current QuantileAccumulator. -// -// An op that deserializes bucket boundaries and are boundaries ready flag into current QuantileAccumulator. +// Computes the power of one value to another. // -// Arguments: -// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. -// bucket_boundaries: float; List of Rank 1 Tensors each containing the bucket boundaries for a feature. +// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for +// corresponding elements in `x` and `y`. For example: // -// Returns the created operation. -func BoostedTreesQuantileStreamResourceDeserialize(scope *Scope, quantile_stream_resource_handle tf.Output, bucket_boundaries []tf.Output) (o *tf.Operation) { +// ``` +// # tensor 'x' is [[2, 2]], [3, 3]] +// # tensor 'y' is [[8, 16], [2, 3]] +// tf.pow(x, y) ==> [[256, 65536], [9, 27]] +// ``` +func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BoostedTreesQuantileStreamResourceDeserialize", + Type: "Pow", Input: []tf.Input{ - quantile_stream_resource_handle, tf.OutputList(bucket_boundaries), + x, y, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Inverse 3D fast Fourier transform. 
-// -// Computes the inverse 3-dimensional discrete Fourier transform over the -// inner-most 3 dimensions of `input`. +// QuantizedReluXAttr is an optional argument to QuantizedReluX. +type QuantizedReluXAttr func(optionalAttr) + +// QuantizedReluXOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedReluXOutType(value tf.DataType) QuantizedReluXAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Computes Quantized Rectified Linear X: `min(max(features, 0), max_value)` // // Arguments: -// input: A complex64 tensor. // -// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their inverse 3D Fourier transform. // -// @compatibility(numpy) -// Equivalent to np.fft.ifftn with 3 dimensions. -// @end_compatibility -func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. +// +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. +func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluXAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IFFT3D", + Type: "QuantizedReluX", Input: []tf.Input{ - input, + features, max_value, min_features, max_features, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Deprecated. Disallowed in GraphDef version >= 2. +// Returns the truth value of (x < y) element-wise. 
// -// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead -func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { +// *NOTE*: `Less` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Less(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AdjustContrast", + Type: "Less", Input: []tf.Input{ - images, contrast_factor, min_value, max_value, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Table initializer that takes two tensors for keys and values respectively. -// -// Arguments: -// table_handle: Handle to a table which will be initialized. -// keys: Keys of type Tkey. -// values: Values of type Tval. -// -// Returns the created operation. -func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InitializeTableV2", - Input: []tf.Input{ - table_handle, keys, values, - }, - } - return scope.AddOperation(opspec) -} - -// PrintAttr is an optional argument to Print. -type PrintAttr func(optionalAttr) - -// PrintMessage sets the optional message attribute to value. -// -// value: A string, prefix of the error message. -// If not specified, defaults to "" -func PrintMessage(value string) PrintAttr { - return func(m optionalAttr) { - m["message"] = value - } -} +// RandomPoissonAttr is an optional argument to RandomPoisson. +type RandomPoissonAttr func(optionalAttr) -// PrintFirstN sets the optional first_n attribute to value. -// -// value: Only log `first_n` number of times. -1 disables logging. -// If not specified, defaults to -1 -func PrintFirstN(value int64) PrintAttr { +// RandomPoissonSeed sets the optional seed attribute to value. 
+// If not specified, defaults to 0 +func RandomPoissonSeed(value int64) RandomPoissonAttr { return func(m optionalAttr) { - m["first_n"] = value + m["seed"] = value } } -// PrintSummarize sets the optional summarize attribute to value. -// -// value: Only print this many entries of each tensor. -// If not specified, defaults to 3 -func PrintSummarize(value int64) PrintAttr { +// RandomPoissonSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed2(value int64) RandomPoissonAttr { return func(m optionalAttr) { - m["summarize"] = value + m["seed2"] = value } } -// Prints a list of tensors. -// -// Passes `input` through to `output` and prints `data` when evaluating. -// -// Arguments: -// input: The tensor passed to `output` -// data: A list of tensors to print out when op is evaluated. +// Use RandomPoissonV2 instead. // -// Returns = The unmodified `input` tensor -func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAttr) (output tf.Output) { +// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 +func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -28084,9 +31913,9 @@ func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAtt a(attrs) } opspec := tf.OpSpec{ - Type: "Print", + Type: "RandomPoisson", Input: []tf.Input{ - input, tf.OutputList(data), + shape, rate, }, Attrs: attrs, } @@ -28094,93 +31923,68 @@ func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAtt return op.Output(0) } -// Outputs a `Summary` protocol buffer with a tensor and per-plugin data. +// Gets the next output from the given iterator. // -// Arguments: -// tag: A string attached to this summary. Used for organization in TensorBoard. -// tensor: A tensor to serialize. -// serialized_summary_metadata: A serialized SummaryMetadata proto. 
Contains plugin -// data. -func TensorSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, serialized_summary_metadata tf.Output) (summary tf.Output) { +// This operation is a synchronous version IteratorGetNext. It should only be used +// in situations where the iterator does not block the calling thread, or where +// the calling thread is not a member of the thread pool used to execute parallel +// operations (e.g. in eager mode). +func IteratorGetNextSync(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorSummaryV2", + Type: "IteratorGetNextSync", Input: []tf.Input{ - tag, tensor, serialized_summary_metadata, + iterator, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("IteratorGetNextSync", err) + return + } + return components } -// Creates a dataset that asynchronously prefetches elements from `input_dataset`. -// -// Arguments: -// -// buffer_size: The maximum number of elements to buffer in an iterator over -// this dataset. -// +// Returns the truth value of (x >= y) element-wise. // -func PrefetchDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// *NOTE*: `GreaterEqual` supports broadcasting. 
More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func GreaterEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "PrefetchDataset", + Type: "GreaterEqual", Input: []tf.Input{ - input_dataset, buffer_size, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorSummaryAttr is an optional argument to TensorSummary. -type TensorSummaryAttr func(optionalAttr) - -// TensorSummaryDescription sets the optional description attribute to value. -// -// value: A json-encoded SummaryDescription proto. -// If not specified, defaults to "" -func TensorSummaryDescription(value string) TensorSummaryAttr { - return func(m optionalAttr) { - m["description"] = value - } -} - -// TensorSummaryLabels sets the optional labels attribute to value. -// -// value: An unused list of strings. -// If not specified, defaults to <> -func TensorSummaryLabels(value []string) TensorSummaryAttr { - return func(m optionalAttr) { - m["labels"] = value - } -} +// ApproximateEqualAttr is an optional argument to ApproximateEqual. +type ApproximateEqualAttr func(optionalAttr) -// TensorSummaryDisplayName sets the optional display_name attribute to value. -// -// value: An unused string. -// If not specified, defaults to "" -func TensorSummaryDisplayName(value string) TensorSummaryAttr { +// ApproximateEqualTolerance sets the optional tolerance attribute to value. +// If not specified, defaults to 1e-05 +func ApproximateEqualTolerance(value float32) ApproximateEqualAttr { return func(m optionalAttr) { - m["display_name"] = value + m["tolerance"] = value } } -// Outputs a `Summary` protocol buffer with a tensor. 
-// -// This op is being phased out in favor of TensorSummaryV2, which lets callers pass -// a tag as well as a serialized SummaryMetadata proto string that contains -// plugin-specific data. We will keep this op to maintain backwards compatibility. -// -// Arguments: -// tensor: A tensor to serialize. -func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr) (summary tf.Output) { +// Returns the truth value of abs(x-y) < tolerance element-wise. +func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...ApproximateEqualAttr) (z tf.Output) { if scope.Err() != nil { return } @@ -28189,9 +31993,9 @@ func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr a(attrs) } opspec := tf.OpSpec{ - Type: "TensorSummary", + Type: "ApproximateEqual", Input: []tf.Input{ - tensor, + x, y, }, Attrs: attrs, } @@ -28199,229 +32003,263 @@ func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr return op.Output(0) } -// Read an element from the TensorArray into output `value`. -// -// Arguments: -// handle: The handle to a TensorArray. -// -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. +// Returns the truth value of x OR y element-wise. // -// Returns The tensor that is read from the TensorArray. -func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { +// *NOTE*: `LogicalOr` supports broadcasting. 
More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func LogicalOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "TensorArrayReadV3", + Type: "LogicalOr", Input: []tf.Input{ - handle, index, flow_in, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Reduces sparse updates into the variable referenced by `resource` using the `max` operation. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] = max(ref[indices, ...], updates[...]) -// -// # Vector indices (for each i) -// ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions are combined. +// MatMulAttr is an optional argument to MatMul. +type MatMulAttr func(optionalAttr) + +// MatMulTransposeA sets the optional transpose_a attribute to value. // -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// value: If true, "a" is transposed before multiplication. +// If not specified, defaults to false +func MatMulTransposeA(value bool) MatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// MatMulTransposeB sets the optional transpose_b attribute to value. // -//
-// -//
+// value: If true, "b" is transposed before multiplication. +// If not specified, defaults to false +func MatMulTransposeB(value bool) MatMulAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// Multiply the matrix "a" by the matrix "b". // -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. +// The inputs must be two-dimensional matrices and the inner dimension of +// "a" (after being transposed if transpose_a is true) must match the +// outer dimension of "b" (after being transposed if transposed_b is +// true). // -// Returns the created operation. -func ResourceScatterMax(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +// *Note*: The default kernel implementation for MatMul on GPUs uses +// cublas. +func MatMul(scope *Scope, a tf.Output, b tf.Output, optional ...MatMulAttr) (product tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ResourceScatterMax", + Type: "MatMul", Input: []tf.Input{ - resource, indices, updates, + a, b, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the gradient for the tanh of `x` wrt its input. +// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. +type InitializeTableFromTextFileV2Attr func(optionalAttr) + +// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. // -// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy` -// is the corresponding input gradient. -func TanhGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return +// value: Number of elements of the file, use -1 if unknown. 
+// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["vocab_size"] = value } - opspec := tf.OpSpec{ - Type: "TanhGrad", - Input: []tf.Input{ - y, dy, - }, +} + +// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. +// +// value: Delimiter to separate fields in a line. +// If not specified, defaults to "\t" +func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["delimiter"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Outputs a `Summary` protocol buffer with scalar values. +// Initializes a table from a text file. // -// The input `tags` and `values` must have the same shape. The generated summary -// has a summary value for each tag-value pair in `tags` and `values`. +// It inserts one key-value pair into the table for each line of the file. +// The key and value is extracted from the whole line content, elements from the +// split line based on `delimiter` or the line number (starting from zero). +// Where to extract the key and value from a line is specified by `key_index` and +// `value_index`. +// +// - A value of -1 means use the line number(starting from zero), expects `int64`. +// - A value of -2 means use the whole line content, expects `string`. +// - A value >= 0 means use the index (starting at zero) of the split line based +// on `delimiter`. // // Arguments: -// tags: Tags for the summary. -// values: Same shape as `tags. Values for the summary. +// table_handle: Handle to a table which will be initialized. +// filename: Filename of a vocabulary text file. +// key_index: Column index in a line to get the table `key` values from. +// value_index: Column index that represents information of a line to get the table +// `value` values from. // -// Returns Scalar. 
Serialized `Summary` protocol buffer. -func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { +// Returns the created operation. +func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ScalarSummary", + Type: "InitializeTableFromTextFileV2", Input: []tf.Input{ - tags, values, + table_handle, filename, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Outputs a `Summary` protocol buffer with a histogram. +// MeanAttr is an optional argument to Mean. +type MeanAttr func(optionalAttr) + +// MeanKeepDims sets the optional keep_dims attribute to value. // -// The generated -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// has one summary value containing a histogram for `values`. +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func MeanKeepDims(value bool) MeanAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the mean of elements across dimensions of a tensor. // -// This op reports an `InvalidArgument` error if any value is not finite. +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// tag: Scalar. Tag to use for the `Summary.Value`. -// values: Any shape. Values to use to build the histogram. +// input: The tensor to reduce. +// axis: The dimensions to reduce. 
Must be in the range +// `[-rank(input), rank(input))`. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { +// Returns The reduced tensor. +func Mean(scope *Scope, input tf.Output, axis tf.Output, optional ...MeanAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "HistogramSummary", + Type: "Mean", Input: []tf.Input{ - tag, values, + input, axis, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the number of elements in the given queue. +// ProdAttr is an optional argument to Prod. +type ProdAttr func(optionalAttr) + +// ProdKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func ProdKeepDims(value bool) ProdAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the product of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// handle: The handle to a queue. +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns The number of elements in the given queue. -func QueueSizeV2(scope *Scope, handle tf.Output) (size tf.Output) { +// Returns The reduced tensor. 
+func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "QueueSizeV2", + Type: "Prod", Input: []tf.Input{ - handle, + input, axis, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ImageSummaryAttr is an optional argument to ImageSummary. -type ImageSummaryAttr func(optionalAttr) +// ResizeBilinearAttr is an optional argument to ResizeBilinear. +type ResizeBilinearAttr func(optionalAttr) -// ImageSummaryMaxImages sets the optional max_images attribute to value. -// -// value: Max number of batch elements to generate images for. -// If not specified, defaults to 3 +// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. // -// REQUIRES: value >= 1 -func ImageSummaryMaxImages(value int64) ImageSummaryAttr { +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { return func(m optionalAttr) { - m["max_images"] = value + m["align_corners"] = value } } -// ImageSummaryBadColor sets the optional bad_color attribute to value. -// -// value: Color to use for pixels with non-finite values. -// If not specified, defaults to > int_val:255 int_val:0 int_val:0 int_val:255 > -func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { +// ResizeBilinearHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeBilinearHalfPixelCenters(value bool) ResizeBilinearAttr { return func(m optionalAttr) { - m["bad_color"] = value + m["half_pixel_centers"] = value } } -// Outputs a `Summary` protocol buffer with images. 
-// -// The summary has up to `max_images` summary values containing images. The -// images are built from `tensor` which must be 4-D with shape `[batch_size, -// height, width, channels]` and where `channels` can be: -// -// * 1: `tensor` is interpreted as Grayscale. -// * 3: `tensor` is interpreted as RGB. -// * 4: `tensor` is interpreted as RGBA. -// -// The images have the same number of channels as the input tensor. For float -// input, the values are normalized one image at a time to fit in the range -// `[0, 255]`. `uint8` values are unchanged. The op uses two different -// normalization algorithms: -// -// * If the input values are all positive, they are rescaled so the largest one -// is 255. -// -// * If any input value is negative, the values are shifted so input value 0.0 -// is at 127. They are then rescaled so that either the smallest value is 0, -// or the largest one is 255. -// -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: -// -// * If `max_images` is 1, the summary value tag is '*tag*/image'. -// * If `max_images` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. +// Resize `images` to `size` using bilinear interpolation. // -// The `bad_color` argument is the color to use in the generated images for -// non-finite input values. It is a `uint8` 1-D tensor of length `channels`. -// Each element must be in the range `[0, 255]` (It represents the value of a -// pixel in the output image). Non-finite values in the input tensor are -// replaced by this tensor in the output image. The default value is the color -// red. +// Input images can be of different types but output images are always float. // // Arguments: -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 4-D of shape `[batch_size, height, width, channels]` where -// `channels` is 1, 3, or 4. 
+// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...ImageSummaryAttr) (summary tf.Output) { +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { if scope.Err() != nil { return } @@ -28430,9 +32268,9 @@ func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...Ima a(attrs) } opspec := tf.OpSpec{ - Type: "ImageSummary", + Type: "ResizeBilinear", Input: []tf.Input{ - tag, tensor, + images, size, }, Attrs: attrs, } @@ -28440,42 +32278,33 @@ func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...Ima return op.Output(0) } -// AudioSummaryV2Attr is an optional argument to AudioSummaryV2. -type AudioSummaryV2Attr func(optionalAttr) +// MaxAttr is an optional argument to Max. +type MaxAttr func(optionalAttr) -// AudioSummaryV2MaxOutputs sets the optional max_outputs attribute to value. -// -// value: Max number of batch elements to generate audio for. -// If not specified, defaults to 3 +// MaxKeepDims sets the optional keep_dims attribute to value. // -// REQUIRES: value >= 1 -func AudioSummaryV2MaxOutputs(value int64) AudioSummaryV2Attr { +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func MaxKeepDims(value bool) MaxAttr { return func(m optionalAttr) { - m["max_outputs"] = value + m["keep_dims"] = value } } -// Outputs a `Summary` protocol buffer with audio. -// -// The summary has up to `max_outputs` summary values containing audio. The -// audio is built from `tensor` which must be 3-D with shape `[batch_size, -// frames, channels]` or 2-D with shape `[batch_size, frames]`. 
The values are -// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. -// -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: +// Computes the maximum of elements across dimensions of a tensor. // -// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. -// * If `max_outputs` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 2-D of shape `[batch_size, frames]`. -// sample_rate: The sample rate of the signal in hertz. +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate tf.Output, optional ...AudioSummaryV2Attr) (summary tf.Output) { +// Returns The reduced tensor. +func Max(scope *Scope, input tf.Output, axis tf.Output, optional ...MaxAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -28484,9 +32313,9 @@ func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate t a(attrs) } opspec := tf.OpSpec{ - Type: "AudioSummaryV2", + Type: "Max", Input: []tf.Input{ - tag, tensor, sample_rate, + input, axis, }, Attrs: attrs, } @@ -28494,70 +32323,55 @@ func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate t return op.Output(0) } -// Splits a tensor into a list. -// -// list[i] corresponds to lengths[i] tensors from the input tensor. 
-// The tensor must have rank at least 1 and contain exactly sum(lengths) elements. -// -// tensor: The input tensor. -// element_shape: A shape compatible with that of elements in the tensor. -// lengths: Vector of sizes of the 0th dimension of tensors in the list. -// output_handle: The list. -func TensorListSplit(scope *Scope, tensor tf.Output, element_shape tf.Output, lengths tf.Output) (output_handle tf.Output) { +// Creates a dataset that contains the unique elements of `input_dataset`. +func ExperimentalUniqueDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorListSplit", + Type: "ExperimentalUniqueDataset", Input: []tf.Input{ - tensor, element_shape, lengths, + input_dataset, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// AvgPoolAttr is an optional argument to AvgPool. -type AvgPoolAttr func(optionalAttr) +// ArgMinAttr is an optional argument to ArgMin. +type ArgMinAttr func(optionalAttr) -// AvgPoolDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func AvgPoolDataFormat(value string) AvgPoolAttr { +// ArgMinOutputType sets the optional output_type attribute to value. +// If not specified, defaults to DT_INT64 +func ArgMinOutputType(value tf.DataType) ArgMinAttr { return func(m optionalAttr) { - m["data_format"] = value + m["output_type"] = value } } -// Performs average pooling on the input. 
+// Returns the index with the smallest value across dimensions of a tensor. // -// Each entry in `output` is the mean of the corresponding size `ksize` -// window in `value`. +// Note that in case of ties the identity of the return value is not guaranteed. // // Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// ksize: The size of the sliding window for each dimension of `value`. -// strides: The stride of the sliding window for each dimension of `value`. -// padding: The type of padding algorithm to use. // -// Returns The average pooled output tensor. -func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolAttr) (output tf.Output) { +// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. +// Describes which dimension of the input Tensor to reduce across. For vectors, +// use dimension = 0. +func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPool", + Type: "ArgMin", Input: []tf.Input{ - value, + input, dimension, }, Attrs: attrs, } @@ -28565,574 +32379,634 @@ func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padd return op.Output(0) } -// Merges summaries. +// Converts the quantized `input` tensor into a lower-precision `output`. // -// This op creates a -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// protocol buffer that contains the union of all the values in the input -// summaries. +// Converts the quantized `input` tensor into a lower-precision `output`, using the +// output range specified with `requested_output_min` and `requested_output_max`. 
// -// When the Op is run, it reports an `InvalidArgument` error if multiple values -// in the summaries to merge use the same tag. +// `[input_min, input_max]` are scalar floats that specify the range for the float +// interpretation of the `input` data. For example, if `input_min` is -1.0f and +// `input_max` is 1.0f, and we are dealing with `quint16` quantized data, then a 0 +// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. // // Arguments: -// inputs: Can be of any shape. Each must contain serialized `Summary` protocol -// buffers. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// requested_output_min: The float value that the minimum quantized output value represents. +// requested_output_max: The float value that the maximum quantized output value represents. +// out_type: The type of the output. Should be a lower bit depth than Tinput. +// +// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output. +func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"out_type": out_type} opspec := tf.OpSpec{ - Type: "MergeSummary", + Type: "Requantize", Input: []tf.Input{ - tf.OutputList(inputs), + input, input_min, input_max, requested_output_min, requested_output_max, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// The shape of the elements of the given list, as a tensor. 
+// Creates a dataset that emits the lines of one or more text files. // -// input_handle: the list -// element_shape: the shape of elements of the list -func TensorListElementShape(scope *Scope, input_handle tf.Output, shape_type tf.DataType) (element_shape tf.Output) { +// Arguments: +// filenames: A scalar or a vector containing the name(s) of the file(s) to be +// read. +// compression_type: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// buffer_size: A scalar containing the number of bytes to buffer. +func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape_type": shape_type} opspec := tf.OpSpec{ - Type: "TensorListElementShape", + Type: "TextLineDataset", Input: []tf.Input{ - input_handle, + filenames, compression_type, buffer_size, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the item in the list with the given index. +// Computes the sum along segments of a tensor. // -// input_handle: the list -// index: the position in the list from which an element will be retrieved -// item: the element at that position +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// Computes a tensor such that +// \\(output_i = \sum_j data_j\\) where sum is over `j` such +// that `segment_ids[j] == i`. // +// If the sum is empty for a given segment ID `i`, `output[i] = 0`. 
// -func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, element_dtype tf.DataType) (item tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "TensorListGetItem", - Input: []tf.Input{ - input_handle, index, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Resizes the list. +//
+// +//
// +// For example: // -// input_handle: the input list -// size: size of the output list +// ``` +// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +// tf.segment_sum(c, tf.constant([0, 0, 1])) +// # ==> [[5, 5, 5, 5], +// # [5, 6, 7, 8]] +// ``` // -func TensorListResize(scope *Scope, input_handle tf.Output, size tf.Output) (output_handle tf.Output) { +// +// Arguments: +// +// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListResize", + Type: "SegmentSum", Input: []tf.Input{ - input_handle, size, + data, segment_ids, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a diagonal tensor with a given diagonal values. +// Computes the mean along segments of a tensor. // -// Given a `diagonal`, this operation returns a tensor with the `diagonal` and -// everything else padded with zeros. The diagonal is computed as follows: +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. // -// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of -// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: +// Computes a tensor such that +// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is +// over `j` such that `segment_ids[j] == i` and `N` is the total number of +// values summed. // -// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. +// If the mean is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +//
// // For example: // // ``` -// # 'diagonal' is [1, 2, 3, 4] -// tf.diag(diagonal) ==> [[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]] +// c = tf.constant([[1.0,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +// tf.segment_mean(c, tf.constant([0, 0, 1])) +// # ==> [[2.5, 2.5, 2.5, 2.5], +// # [5, 6, 7, 8]] // ``` // +// // Arguments: -// diagonal: Rank k tensor where k is at most 1. -func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) { +// +// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Diag", + Type: "SegmentMean", Input: []tf.Input{ - diagonal, + data, segment_ids, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal. -type ParameterizedTruncatedNormalAttr func(optionalAttr) - -// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value. +// Computes the minimum along segments of a tensor. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value. +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. // -// value: A second seed to avoid seed collision. 
-// If not specified, defaults to 0 -func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from a normal distribution. The parameters may each be a +// Computes a tensor such that +// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such +// that `segment_ids[j] == i`. // -// scalar which applies to the entire output, or a vector of length shape[0] which -// stores the parameters for each batch. +// If the min is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +//
+// +// For example: +// +// ``` +// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +// tf.segment_min(c, tf.constant([0, 0, 1])) +// # ==> [[1, 2, 2, 1], +// # [5, 6, 7, 8]] +// ``` // // Arguments: -// shape: The shape of the output tensor. Batches are indexed by the 0th dimension. -// means: The mean parameter of each batch. -// stdevs: The standard deviation parameter of each batch. Must be greater than 0. -// minvals: The minimum cutoff. May be -infinity. -// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval -// for each batch. // -// Returns A matrix of shape num_batches x samples_per_batch, filled with random -// truncated normal values using the parameters for each row. -func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) { +// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ParameterizedTruncatedNormal", + Type: "SegmentMin", Input: []tf.Input{ - shape, means, stdevs, minvals, maxvals, + data, segment_ids, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Sets the index-th position of the list to contain the given tensor. +// Computes the sum along segments of a tensor. 
// -// input_handle: the list -// index: the position in the list to which the tensor will be assigned -// item: the element to be assigned to that position -// output_handle: the new list, with the element in the proper position +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. // -func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, item tf.Output) (output_handle tf.Output) { +// Computes a tensor such that +// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such +// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` +// need not be sorted and need not cover all values in the full +// range of valid values. +// +// If the sum is empty for a given segment ID `i`, `output[i] = 0`. +// If the given segment ID `i` is negative, the value is dropped and will not be +// added to the sum of the segment. +// +// `num_segments` should equal the number of distinct segment IDs. +// +//
+// +//
+// +// ``` python +// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) +// tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2) +// # ==> [[ 5, 5, 5, 5], +// # [5, 6, 7, 8]] +// ``` +// +// +// Arguments: +// +// segment_ids: A tensor whose shape is a prefix of `data.shape`. +// +// +// Returns Has same shape as data, except for the first `segment_ids.rank` +// dimensions, which are replaced with a single dimension which has size +// `num_segments`. +func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListSetItem", + Type: "UnsortedSegmentSum", Input: []tf.Input{ - input_handle, index, item, + data, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a Tensor by indexing into the TensorList. +// Computes the product along segments of a tensor. // -// Each row in the produced Tensor corresponds to the element in the TensorList -// specified by the given index (see `tf.gather`). +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. // -// input_handle: The input tensor list. -// indices: The indices used to index into the list. -// values: The tensor. -func TensorListGather(scope *Scope, input_handle tf.Output, indices tf.Output, element_dtype tf.DataType) (values tf.Output) { +// This operator is similar to the unsorted segment sum operator found +// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). +// Instead of computing the sum over segments, it computes the product of all +// entries belonging to a segment such that: +// +// \\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples +// `j...` such that `segment_ids[j...] == i`. 
+// +// For example: +// +// ``` python +// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) +// tf.unsorted_segment_prod(c, tf.constant([0, 1, 0]), num_segments=2) +// # ==> [[ 4, 6, 6, 4], +// # [5, 6, 7, 8]] +// ``` +// +// If there is no entry for a given segment ID `i`, it outputs 1. +// +// If the given segment ID `i` is negative, then the corresponding value is +// dropped, and will not be included in the result. +// +// Arguments: +// +// segment_ids: A tensor whose shape is a prefix of `data.shape`. +// +// +// Returns Has same shape as data, except for the first `segment_ids.rank` +// dimensions, which are replaced with a single dimension which has size +// `num_segments`. +func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "TensorListGather", + Type: "UnsortedSegmentProd", Input: []tf.Input{ - input_handle, indices, + data, segment_ids, num_segments, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a TensorList by indexing into a Tensor. -// -// Each member of the TensorList corresponds to one row of the input tensor, -// specified by the given index (see `tf.gather`). -// -// tensor: The input tensor. -// indices: The indices used to index into the list. -// element_shape: The shape of the elements in the list (can be less specified than -// the shape of the tensor). -// output_handle: The TensorList. -func TensorListScatter(scope *Scope, tensor tf.Output, indices tf.Output, element_shape tf.Output) (output_handle tf.Output) { +// Computes hyperbolic cosine of x element-wise. 
+func Cosh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListScatter", + Type: "Cosh", Input: []tf.Input{ - tensor, indices, element_shape, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a `RaggedTensor` containing the specified sequences of numbers. -// -// -// Returns a `RaggedTensor` `result` composed from `rt_dense_values` and -// `rt_nested_splits`, such that -// `result[i] = range(starts[i], limits[i], deltas[i])`. +// Computes the mean along sparse segments of a tensor. // -// ```python -// >>> (rt_nested_splits, rt_dense_values) = gen_ragged_ops.ragged_range( -// ... starts=[2, 5, 8], limits=[3, 5, 12], deltas=1) -// >>> result = ragged.from_nested_row_splits(rt_dense_values, rt_nested_splits) -// >>> print result.eval().tolist() -// [[2], # result[0] = range(2, 3) -// [], # result[1] = range(5, 5) -// [8, 9, 10, 11]] # result[2] = range(8, 12) -// ``` +// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is +// missing, the `output` tensor at that position will be zeroed. // -// The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors. -// The vector inputs must all have the same size. Scalar inputs are broadcast -// to match the size of the vector inputs. +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. // // Arguments: -// starts: The starts of each range. -// limits: The limits of each range. -// deltas: The deltas of each range. // -// Returns The `row_splits` for the returned `RaggedTensor`.The `flat_values` for the returned `RaggedTensor`.
-func RaggedRange(scope *Scope, starts tf.Output, limits tf.Output, deltas tf.Output) (rt_nested_splits tf.Output, rt_dense_values tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RaggedRange", - Input: []tf.Input{ - starts, limits, deltas, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Deprecated, use python implementation tf.linalg.matrix_exponential. +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs. // -// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead. -func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) { +// Returns Has same shape as data, except for dimension 0 which has size +// `num_segments`. +func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixExponential", + Type: "SparseSegmentMeanWithNumSegments", Input: []tf.Input{ - input, + data, indices, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2. -type QueueDequeueUpToV2Attr func(optionalAttr) +// CudnnRNNParamsSizeAttr is an optional argument to CudnnRNNParamsSize. +type CudnnRNNParamsSizeAttr func(optionalAttr) -// QueueDequeueUpToV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue has fewer than n elements, this operation -// will block for up to timeout_ms milliseconds. -// Note: This option is not supported yet. 
-// If not specified, defaults to -1 -func QueueDequeueUpToV2TimeoutMs(value int64) QueueDequeueUpToV2Attr { +// CudnnRNNParamsSizeRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNParamsSizeRnnMode(value string) CudnnRNNParamsSizeAttr { return func(m optionalAttr) { - m["timeout_ms"] = value + m["rnn_mode"] = value } } -// Dequeues `n` tuples of one or more tensors from the given queue. -// -// This operation is not supported by all queues. If a queue does not support -// DequeueUpTo, then an Unimplemented error is returned. -// -// If the queue is closed and there are more than 0 but less than `n` -// elements remaining, then instead of returning an OutOfRange error like -// QueueDequeueMany, less than `n` elements are returned immediately. If -// the queue is closed and there are 0 elements left in the queue, then -// an OutOfRange error is returned just like in QueueDequeueMany. -// Otherwise the behavior is identical to QueueDequeueMany: -// -// This operation concatenates queue-element component tensors along the -// 0th dimension to make a single component tensor. All of the components -// in the dequeued tuple will have size n in the 0th dimension. -// -// This operation has `k` outputs, where `k` is the number of components in -// the tuples stored in the given queue, and output `i` is the ith -// component of the dequeued tuple. +// CudnnRNNParamsSizeInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNParamsSizeInputMode(value string) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNParamsSizeDirection sets the optional direction attribute to value. 
+// If not specified, defaults to "unidirectional" +func CudnnRNNParamsSizeDirection(value string) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNParamsSizeDropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsSizeDropout(value float32) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["dropout"] = value + } +} + +// CudnnRNNParamsSizeSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsSizeSeed(value int64) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// CudnnRNNParamsSizeSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsSizeSeed2(value int64) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Computes size of weights that can be used by a Cudnn RNN model. // -// Arguments: -// handle: The handle to a queue. -// n: The number of tuples to dequeue. -// component_types: The type of each component in a tuple. +// Return the params size that can be used by the Cudnn RNN model. Subsequent +// weight allocation and initialization should use this size. // -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueUpToV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueUpToV2Attr) (components []tf.Output) { +// num_layers: Specifies the number of layers in the RNN model. +// num_units: Specifies the size of the hidden state. +// input_size: Specifies the size of the input state. +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// The actual computation before the first layer. 
'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. +// dir = (direction == bidirectional) ? 2 : 1 +// dropout: dropout probability. When set to 0., dropout is disabled. +// seed: the 1st part of a seed to initialize dropout. +// seed2: the 2nd part of a seed to initialize dropout. +// params_size: The size of the params buffer that should be allocated and +// initialized for this RNN model. Note that this params buffer may not be +// compatible across GPUs. Please use CudnnRNNParamsWeights and +// CudnnRNNParamsBiases to save and restore them in a way that is compatible +// across different runs. +func CudnnRNNParamsSize(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, T tf.DataType, S tf.DataType, optional ...CudnnRNNParamsSizeAttr) (params_size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{"T": T, "S": S} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QueueDequeueUpToV2", + Type: "CudnnRNNParamsSize", Input: []tf.Input{ - handle, n, + num_layers, num_units, input_size, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueUpToV2", err) - return - } - return components + return op.Output(0) } -// Computes the Cholesky decomposition of one or more square matrices. -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. -// -// The input has to be symmetric and positive definite. Only the lower-triangular -// part of the input will be used for this operation. 
The upper-triangular part -// will not be read. -// -// The output is a tensor of the same shape as the input -// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. +// Computes gradients for SparseSegmentMean. // -// **Note**: The gradient computation on GPU is faster for large matrices but -// not for large batch dimensions when the submatrices are small. In this -// case it might be faster to use the CPU. +// Returns tensor "output" with same shape as grad, except for dimension 0 whose +// value is output_dim0. // // Arguments: -// input: Shape is `[..., M, M]`. -// -// Returns Shape is `[..., M, M]`. -func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { +// grad: gradient propagated to the SparseSegmentMean op. +// indices: indices passed to the corresponding SparseSegmentMean op. +// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. +// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. +func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Cholesky", + Type: "SparseSegmentMeanGrad", Input: []tf.Input{ - input, + grad, indices, segment_ids, output_dim0, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Writes contents to the file at input filename. Creates file and recursively +// Computes the sum along sparse segments of a tensor divided by the sqrt of N. +// +// N is the size of the segment being reduced. +// +// See `tf.sparse.segment_sum` for usage examples. // -// creates directory if not existing. // // Arguments: -// filename: scalar. The name of the file to which we write the contents. -// contents: scalar. The content to be written to the output file. // -// Returns the created operation. -func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { +// indices: A 1-D tensor. 
Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "WriteFile", + Type: "SparseSegmentSqrtN", Input: []tf.Input{ - filename, contents, + data, indices, segment_ids, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// AllAttr is an optional argument to All. -type AllAttr func(optionalAttr) - -// AllKeepDims sets the optional keep_dims attribute to value. +// Compute the upper regularized incomplete Gamma function `Q(a, x)`. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func AllKeepDims(value bool) AllAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the "logical and" of elements across dimensions of a tensor. +// The upper regularized incomplete Gamma function is defined as: // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// \\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) // -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// where // -// Returns The reduced tensor. -func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (output tf.Output) { +// \\(Gamma(a, x) = \int_{x}^{\infty} t^{a-1} exp(-t) dt\\) +// +// is the upper incomplete Gamma function. +// +// Note, above `P(a, x)` (`Igamma`) is the lower regularized incomplete +// Gamma function.
+func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "All", + Type: "Igammac", Input: []tf.Input{ - input, axis, + a, x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the Eigen Decomposition of a batch of square self-adjoint matrices. +// Computes the sum along sparse segments of a tensor divided by the sqrt of N. // -// DEPRECATED at GraphDef version 11: Use SelfAdjointEigV2 instead. +// N is the size of the segment being reduced. // -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices, with the same constraints as the single matrix -// SelfAdjointEig. +// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is +// missing, the `output` tensor at that position will be zeroed. // -// The result is a [..., M+1, M] matrix with [..., 0,:] containing the -// eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. The eigenvalues -// are sorted in non-decreasing order. +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. // // Arguments: -// input: Shape is `[..., M, M]`. // -// Returns Shape is `[..., M+1, M]`. -func SelfAdjointEig(scope *Scope, input tf.Output) (output tf.Output) { +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments.
+func SparseSegmentSqrtNWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SelfAdjointEig", + Type: "SparseSegmentSqrtNWithNumSegments", Input: []tf.Input{ - input, + data, indices, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes softplus gradients for a softplus operation. +// Computes gradients for SparseSegmentSqrtN. // -// Arguments: -// gradients: The backpropagated gradients to the corresponding softplus operation. -// features: The features passed as input to the corresponding softplus operation. +// Returns tensor "output" with same shape as grad, except for dimension 0 whose +// value is output_dim0. // -// Returns The gradients: `gradients / (1 + exp(-features))`. -func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { +// Arguments: +// grad: gradient propagated to the SparseSegmentSqrtN op. +// indices: indices passed to the corresponding SparseSegmentSqrtN op. +// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op. +// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op. +func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SoftplusGrad", + Type: "SparseSegmentSqrtNGrad", Input: []tf.Input{ - gradients, features, + grad, indices, segment_ids, output_dim0, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. -type SelfAdjointEigV2Attr func(optionalAttr) +// LRNGradAttr is an optional argument to LRNGrad. +type LRNGradAttr func(optionalAttr) -// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. 
+// LRNGradDepthRadius sets the optional depth_radius attribute to value. // -// value: If `True` then eigenvectors will be computed and returned in `v`. -// Otherwise, only the eigenvalues will be computed. -// If not specified, defaults to true -func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { +// value: A depth radius. +// If not specified, defaults to 5 +func LRNGradDepthRadius(value int64) LRNGradAttr { return func(m optionalAttr) { - m["compute_v"] = value + m["depth_radius"] = value } } -// Computes the eigen decomposition of one or more square self-adjoint matrices. +// LRNGradBias sets the optional bias attribute to value. // -// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in -// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues -// are sorted in non-decreasing order. +// value: An offset (usually > 0 to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNGradBias(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["bias"] = value + } +} + +// LRNGradAlpha sets the optional alpha attribute to value. // -// ```python -// # a is a tensor. -// # e is a tensor of eigenvalues. -// # v is a tensor of eigenvectors. -// e, v = self_adjoint_eig(a) -// e = self_adjoint_eig(a, compute_v=False) -// ``` +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNGradAlpha(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// LRNGradBeta sets the optional beta attribute to value. +// +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNGradBeta(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Gradients for Local Response Normalization. // // Arguments: -// input: `Tensor` input of shape `[N, N]`. +// input_grads: 4-D with shape `[batch, height, width, channels]`. 
+// input_image: 4-D with shape `[batch, height, width, channels]`. +// output_image: 4-D with shape `[batch, height, width, channels]`. // -// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. -func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { +// Returns The gradients for LRN. +func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -29141,87 +33015,85 @@ func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV a(attrs) } opspec := tf.OpSpec{ - Type: "SelfAdjointEigV2", + Type: "LRNGrad", Input: []tf.Input{ - input, + input_grads, input_image, output_image, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Adjust the saturation of one or more images. +// AnyAttr is an optional argument to Any. +type AnyAttr func(optionalAttr) + +// AnyKeepDims sets the optional keep_dims attribute to value. // -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AnyKeepDims(value bool) AnyAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the "logical or" of elements across dimensions of a tensor. // -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A scale is then applied all the saturation -// values, and then remapped back to RGB colorspace. +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. 
// // Arguments: -// images: Images to adjust. At least 3-D. -// scale: A float scale to add to the saturation. +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns The hue-adjusted image or images. -func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { +// Returns The reduced tensor. +func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "AdjustSaturation", + Type: "Any", Input: []tf.Input{ - images, scale, + input, axis, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyKerasMomentumAttr is an optional argument to ResourceApplyKerasMomentum. -type ResourceApplyKerasMomentumAttr func(optionalAttr) - -// ResourceApplyKerasMomentumUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyKerasMomentumUseLocking(value bool) ResourceApplyKerasMomentumAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} +// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. +type DestroyResourceOpAttr func(optionalAttr) -// ResourceApplyKerasMomentumUseNesterov sets the optional use_nesterov attribute to value. +// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. // -// value: If `True`, the tensor passed to compute grad will be -// var + momentum * accum, so in the end, the var you get is actually -// var + momentum * accum. 
-// If not specified, defaults to false -func ResourceApplyKerasMomentumUseNesterov(value bool) ResourceApplyKerasMomentumAttr { +// value: whether to ignore the error when the resource +// doesn't exist. +// If not specified, defaults to true +func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { return func(m optionalAttr) { - m["use_nesterov"] = value + m["ignore_lookup_error"] = value } } -// Update '*var' according to the momentum scheme. Set use_nesterov = True if you -// -// want to use Nesterov momentum. +// Deletes the resource specified by the handle. // -// accum = accum * momentum - lr * grad -// var += accum +// All subsequent operations using the resource will result in a NotFound +// error status. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. -// momentum: Momentum. Must be a scalar. +// resource: handle to the resource to delete. // // Returns the created operation. -func ResourceApplyKerasMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyKerasMomentumAttr) (o *tf.Operation) { +func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -29230,130 +33102,118 @@ func ResourceApplyKerasMomentum(scope *Scope, var_ tf.Output, accum tf.Output, l a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyKerasMomentum", + Type: "DestroyResourceOp", Input: []tf.Input{ - var_, accum, lr, grad, momentum, + resource, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// MatrixSolveAttr is an optional argument to MatrixSolve. -type MatrixSolveAttr func(optionalAttr) - -// MatrixSolveAdjoint sets the optional adjoint attribute to value. +// Generates values in an interval. 
// -// value: Boolean indicating whether to solve with `matrix` or its (block-wise) -// adjoint. -// If not specified, defaults to false -func MatrixSolveAdjoint(value bool) MatrixSolveAttr { - return func(m optionalAttr) { - m["adjoint"] = value - } -} - -// Solves systems of linear equations. +// A sequence of `num` evenly-spaced values are generated beginning at `start`. +// If `num > 1`, the values in the sequence increase by `(stop - start) / (num - 1)`, +// so that the last one is exactly `stop`. // -// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is -// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix -// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. -// If `adjoint` is `True` then each output matrix satisfies -// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. +// For example: +// +// ``` +// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] +// ``` // // Arguments: -// matrix: Shape is `[..., M, M]`. -// rhs: Shape is `[..., M, K]`. +// start: 0-D tensor. First entry in the range. +// stop: 0-D tensor. Last entry in the range. +// num: 0-D tensor. Number of values to generate. // -// Returns Shape is `[..., M, K]`. -func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { +// Returns 1-D. The generated values. +func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MatrixSolve", + Type: "LinSpace", Input: []tf.Input{ - matrix, rhs, + start, stop, num, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a serialized GraphDef representing `input_dataset`.
+// ComplexAttr is an optional argument to Complex. +type ComplexAttr func(optionalAttr) + +// ComplexTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_COMPLEX64 +func ComplexTout(value tf.DataType) ComplexAttr { + return func(m optionalAttr) { + m["Tout"] = value + } +} + +// Converts two real numbers to a complex number. // -// Returns a graph representation for `input_dataset`. +// Given a tensor `real` representing the real part of a complex number, and a +// tensor `imag` representing the imaginary part of a complex number, this +// operation returns complex numbers elementwise of the form \\(a + bj\\), where +// *a* represents the `real` part and *b* represents the `imag` part. // -// Arguments: -// input_dataset: A variant tensor representing the dataset to return the graph representation for. +// The input tensors `real` and `imag` must have the same shape. // -// Returns The graph representation of the dataset (as serialized GraphDef). -func DatasetToGraph(scope *Scope, input_dataset tf.Output) (graph tf.Output) { +// For example: +// +// ``` +// # tensor 'real' is [2.25, 3.25] +// # tensor `imag` is [4.75, 5.75] +// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] +// ``` +func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DatasetToGraph", + Type: "Complex", Input: []tf.Input{ - input_dataset, + real, imag, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// LuAttr is an optional argument to Lu. -type LuAttr func(optionalAttr) +// ImagAttr is an optional argument to Imag. +type ImagAttr func(optionalAttr) -// LuOutputIdxType sets the optional output_idx_type attribute to value. 
-// If not specified, defaults to DT_INT32 -func LuOutputIdxType(value tf.DataType) LuAttr { +// ImagTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func ImagTout(value tf.DataType) ImagAttr { return func(m optionalAttr) { - m["output_idx_type"] = value + m["Tout"] = value } } -// Computes the LU decomposition of one or more square matrices. -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. -// -// The input has to be invertible. -// -// The output consists of two tensors LU and P containing the LU decomposition -// of all input submatrices `[..., :, :]`. LU encodes the lower triangular and -// upper triangular factors. -// -// For each input submatrix of shape `[M, M]`, L is a lower triangular matrix of -// shape `[M, M]` with unit diagonal whose entries correspond to the strictly lower -// triangular part of LU. U is a upper triangular matrix of shape `[M, M]` whose -// entries correspond to the upper triangular part, including the diagonal, of LU. +// Returns the imaginary part of a complex number. // -// P represents a permutation matrix encoded as a list of indices each between `0` -// and `M-1`, inclusive. If P_mat denotes the permutation matrix corresponding to -// P, then the L, U and P satisfies P_mat * input = L * U. +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the imaginary part of each element in `input`. All +// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a* +// is the real part and *b* is the imaginary part returned by this operation. // -// Arguments: -// input: A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form matrices of -// size `[M, M]`. 
+// For example: // -// Returns A tensor of shape `[..., M, M]` whose strictly lower triangular part denotes the -// lower triangular factor `L` with unit diagonal, and whose upper triangular part -// denotes the upper triangular factor `U`.Permutation of the rows encoded as a list of indices in `0..M-1`. Shape is -// `[..., M]`. -// @compatibility(scipy) -// Similar to `scipy.linalg.lu`, except the triangular factors `L` and `U` are -// packed into a single tensor, the permutation is applied to `input` instead of -// the right hand side and the permutation `P` is returned as a list of indices -// instead of a permutation matrix. -// @end_compatibility -func Lu(scope *Scope, input tf.Output, optional ...LuAttr) (lu tf.Output, p tf.Output) { +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.imag(input) ==> [4.75, 5.75] +// ``` +func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -29362,198 +33222,179 @@ func Lu(scope *Scope, input tf.Output, optional ...LuAttr) (lu tf.Output, p tf.O a(attrs) } opspec := tf.OpSpec{ - Type: "Lu", + Type: "Imag", Input: []tf.Input{ input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Computes the matrix square root of one or more square matrices: +// Computes hyperbolic tangent of `x` element-wise. +func Tanh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Tanh", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the maximum along segments of a tensor. // -// matmul(sqrtm(A), sqrtm(A)) = A +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. // -// The input matrix should be invertible. 
If the input matrix is real, it should -// have no eigenvalues which are real and negative (pairs of complex conjugate -// eigenvalues are allowed). +// Computes a tensor such that +// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such +// that `segment_ids[j] == i`. // -// The matrix square root is computed by first reducing the matrix to -// quasi-triangular form with the real Schur decomposition. The square root -// of the quasi-triangular matrix is then computed directly. Details of -// the algorithm can be found in: Nicholas J. Higham, "Computing real -// square roots of a real matrix", Linear Algebra Appl., 1987. +// If the max is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +//
+// +// For example: +// +// ``` +// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +// tf.segment_max(c, tf.constant([0, 0, 1])) +// # ==> [[4, 3, 3, 4], +// # [5, 6, 7, 8]] +// ``` // -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor of the same shape as the input -// containing the matrix square root for all input submatrices `[..., :, :]`. // // Arguments: -// input: Shape is `[..., M, M]`. // -// Returns Shape is `[..., M, M]`. +// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s +// first dimension. Values should be sorted and can be repeated. // -// @compatibility(scipy) -// Equivalent to scipy.linalg.sqrtm -// @end_compatibility -func MatrixSquareRoot(scope *Scope, input tf.Output) (output tf.Output) { +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixSquareRoot", + Type: "SegmentMax", Input: []tf.Input{ - input, + data, segment_ids, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SvdAttr is an optional argument to Svd. -type SvdAttr func(optionalAttr) - -// SvdComputeUv sets the optional compute_uv attribute to value. -// -// value: If true, left and right singular vectors will be -// computed and returned in `u` and `v`, respectively. -// If false, `u` and `v` are not set and should never referenced. -// If not specified, defaults to true -func SvdComputeUv(value bool) SvdAttr { - return func(m optionalAttr) { - m["compute_uv"] = value - } -} - -// SvdFullMatrices sets the optional full_matrices attribute to value. -// -// value: If true, compute full-sized `u` and `v`. If false -// (the default), compute only the leading `P` singular vectors. -// Ignored if `compute_uv` is `False`. 
-// If not specified, defaults to false -func SvdFullMatrices(value bool) SvdAttr { - return func(m optionalAttr) { - m["full_matrices"] = value - } -} - -// Computes the singular value decompositions of one or more matrices. +// Creates a dataset that skips `count` elements from the `input_dataset`. // -// Computes the SVD of each inner matrix in `input` such that -// `input[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :, :])` +// Arguments: // -// ```python -// # a is a tensor containing a batch of matrices. -// # s is a tensor of singular values for each matrix. -// # u is the tensor containing of left singular vectors for each matrix. -// # v is the tensor containing of right singular vectors for each matrix. -// s, u, v = svd(a) -// s, _, _ = svd(a, compute_uv=False) -// ``` +// count: A scalar representing the number of elements from the `input_dataset` +// that should be skipped. If count is -1, skips everything. // -// Arguments: -// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. // -// Returns Singular values. Shape is `[..., P]`.Left singular vectors. If `full_matrices` is `False` then shape is -// `[..., M, P]`; if `full_matrices` is `True` then shape is -// `[..., M, M]`. Undefined if `compute_uv` is `False`.Left singular vectors. If `full_matrices` is `False` then shape is -// `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`. -// Undefined if `compute_uv` is false. 
-func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.Output, v tf.Output) { +func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Svd", + Type: "SkipDataset", Input: []tf.Input{ - input, + input_dataset, count, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// PrintV2Attr is an optional argument to PrintV2. -type PrintV2Attr func(optionalAttr) +// VarHandleOpAttr is an optional argument to VarHandleOp. +type VarHandleOpAttr func(optionalAttr) -// PrintV2OutputStream sets the optional output_stream attribute to value. +// VarHandleOpContainer sets the optional container attribute to value. // -// value: A string specifying the output stream or logging level to print to. -// If not specified, defaults to "stderr" -func PrintV2OutputStream(value string) PrintV2Attr { +// value: the container this variable is placed in. +// If not specified, defaults to "" +func VarHandleOpContainer(value string) VarHandleOpAttr { return func(m optionalAttr) { - m["output_stream"] = value + m["container"] = value } } -// Prints a string scalar. +// VarHandleOpSharedName sets the optional shared_name attribute to value. // -// Prints a string scalar to the desired output_stream. +// value: the name by which this variable is referred to. +// If not specified, defaults to "" +func VarHandleOpSharedName(value string) VarHandleOpAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a handle to a Variable resource. // // Arguments: -// input: The string scalar to print. -// -// Returns the created operation. 
-func PrintV2(scope *Scope, input tf.Output, optional ...PrintV2Attr) (o *tf.Operation) { +// dtype: the type of this variable. Must agree with the dtypes +// of all ops using this variable. +// shape: The (possibly partially specified) shape of this variable. +func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "PrintV2", - Input: []tf.Input{ - input, - }, + Type: "VarHandleOp", + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2. -type QueueEnqueueManyV2Attr func(optionalAttr) +// AngleAttr is an optional argument to Angle. +type AngleAttr func(optionalAttr) -// QueueEnqueueManyV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue is too full, this operation will block for up -// to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueEnqueueManyV2TimeoutMs(value int64) QueueEnqueueManyV2Attr { +// AngleTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func AngleTout(value tf.DataType) AngleAttr { return func(m optionalAttr) { - m["timeout_ms"] = value + m["Tout"] = value } } -// Enqueues zero or more tuples of one or more tensors in the given queue. +// Returns the argument of a complex number. // -// This operation slices each component tensor along the 0th dimension to -// make multiple queue elements. All of the tuple components must have the -// same size in the 0th dimension. +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the argument of each element in `input`. 
All elements in +// `input` must be complex numbers of the form \\(a + bj\\), where *a* +// is the real part and *b* is the imaginary part. // -// The components input has k elements, which correspond to the components of -// tuples stored in the given queue. +// The argument returned by this operation is of the form \\(atan2(b, a)\\). // -// N.B. If the queue is full, this operation will block until the given -// elements have been enqueued (or 'timeout_ms' elapses, if specified). +// For example: // -// Arguments: -// handle: The handle to a queue. -// components: One or more tensors from which the enqueued tensors should -// be taken. +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.angle(input) ==> [2.0132, 1.056] +// ``` // -// Returns the created operation. -func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueManyV2Attr) (o *tf.Operation) { +// @compatibility(numpy) +// Equivalent to np.angle. +// @end_compatibility +func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -29562,124 +33403,139 @@ func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "QueueEnqueueManyV2", + Type: "Angle", Input: []tf.Input{ - handle, tf.OutputList(components), + input, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the product along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output_i = \prod_j data_j\\) where the product is over `j` such -// that `segment_ids[j] == i`. -// -// If the product is empty for a given segment ID `i`, `output[i] = 1`. +// Clips tensor values to a specified min and max. // -//
-// -//
+// Given a tensor `t`, this operation returns a tensor of the same type and +// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. +// Any values less than `clip_value_min` are set to `clip_value_min`. Any values +// greater than `clip_value_max` are set to `clip_value_max`. // // Arguments: +// t: A `Tensor`. +// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape +// as `t`. The minimum value to clip by. +// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape +// as `t`. The maximum value to clip by. // -// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// Returns A clipped `Tensor` with the same shape as input 't'. +func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SegmentProd", + Type: "ClipByValue", Input: []tf.Input{ - data, segment_ids, + t, clip_value_min, clip_value_max, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts one or more images from RGB to HSV. +// Counts the number of occurrences of each value in an integer array. // -// Outputs a tensor of the same shape as the `images` tensor, containing the HSV -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. +// Outputs a vector with length `size` and the same dtype as `weights`. If +// `weights` are empty, then index `i` stores the number of times the value `i` is +// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of +// the value in `weights` at each index where the corresponding value in `arr` is +// `i`. 
// -// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and -// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 -// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. +// Values in `arr` outside of the range [0, size) are ignored. // // Arguments: -// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. +// arr: int32 `Tensor`. +// size: non-negative int32 scalar `Tensor`. +// weights: is an int32, int64, float32, or float64 `Tensor` with the same +// shape as `arr`, or a length-0 `Tensor`, in which case it acts as all weights +// equal to 1. // -// Returns `images` converted to HSV. -func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { +// Returns 1D `Tensor` with length equal to `size`. The counts or summed weights for +// each value in the range [0, size). +func Bincount(scope *Scope, arr tf.Output, size tf.Output, weights tf.Output) (bins tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RGBToHSV", + Type: "Bincount", Input: []tf.Input{ - images, + arr, size, weights, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Does nothing. Only useful as a placeholder for control edges. +// CumsumAttr is an optional argument to Cumsum. +type CumsumAttr func(optionalAttr) + +// CumsumExclusive sets the optional exclusive attribute to value. // -// Returns the created operation. -func NoOp(scope *Scope) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NoOp", +// value: If `True`, perform exclusive cumsum. +// If not specified, defaults to false +func CumsumExclusive(value bool) CumsumAttr { + return func(m optionalAttr) { + m["exclusive"] = value } - return scope.AddOperation(opspec) } -// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints. 
-type MergeV2CheckpointsAttr func(optionalAttr) - -// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value. +// CumsumReverse sets the optional reverse attribute to value. // -// value: see above. -// If not specified, defaults to true -func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr { +// value: A `bool` (default: False). +// If not specified, defaults to false +func CumsumReverse(value bool) CumsumAttr { return func(m optionalAttr) { - m["delete_old_dirs"] = value + m["reverse"] = value } } -// V2 format specific: merges the metadata files of sharded checkpoints. The +// Compute the cumulative sum of the tensor `x` along `axis`. // -// result is one logical checkpoint, with one physical metadata file and renamed -// data files. +// By default, this op performs an inclusive cumsum, which means that the first +// element of the input is identical to the first element of the output: // -// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup. +// ```python +// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] +// ``` // -// If delete_old_dirs is true, attempts to delete recursively the dirname of each -// path in the input checkpoint_prefixes. This is useful when those paths are non -// user-facing temporary locations. +// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is +// performed instead: // -// Arguments: -// checkpoint_prefixes: prefixes of V2 checkpoints to merge. -// destination_prefix: scalar. The desired final prefix. Allowed to be the same -// as one of the checkpoint_prefixes. +// ```python +// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] +// ``` // -// Returns the created operation. 
-func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) { +// By setting the `reverse` kwarg to `True`, the cumsum is performed in the +// opposite direction: +// +// ```python +// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] +// ``` +// +// This is more efficient than using separate `tf.reverse` ops. +// +// The `reverse` and `exclusive` kwargs can also be combined: +// +// ```python +// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] +// ``` +// +// Arguments: +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). Must be in the range +// `[-rank(x), rank(x))`. +func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) { if scope.Err() != nil { return } @@ -29688,524 +33544,471 @@ func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination a(attrs) } opspec := tf.OpSpec{ - Type: "MergeV2Checkpoints", + Type: "Cumsum", Input: []tf.Input{ - checkpoint_prefixes, destination_prefix, + x, axis, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Saves input tensors slices to disk. -// -// This is like `Save` except that tensors can be listed in the saved file as being -// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the -// larger tensor and the slice that this tensor covers. `shapes_and_slices` must -// have as many elements as `tensor_names`. -// -// Elements of the `shapes_and_slices` input must either be: -// -// * The empty string, in which case the corresponding tensor is -// saved normally. -// * A string of the form `dim0 dim1 ... 
dimN-1 slice-spec` where the -// `dimI` are the dimensions of the larger tensor and `slice-spec` -// specifies what part is covered by the tensor to save. -// -// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` -// where each `sliceI` is either: -// -// * The string `-` meaning that the slice covers all indices of this dimension -// * `start,length` where `start` and `length` are integers. In that -// case the slice covers `length` indices starting at `start`. -// -// See also `Save`. -// -// Arguments: -// filename: Must have a single element. The name of the file to which we write the -// tensor. -// tensor_names: Shape `[N]`. The names of the tensors to be saved. -// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when -// saving the tensors. -// data: `N` tensors to save. +// Return the shape of s0 op s1 with broadcast. // -// Returns the created operation. -func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) { +// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the +// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. +func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SaveSlices", + Type: "BroadcastArgs", Input: []tf.Input{ - filename, tensor_names, shapes_and_slices, tf.OutputList(data), + s0, s1, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// DenseToDenseSetOperationAttr is an optional argument to DenseToDenseSetOperation. -type DenseToDenseSetOperationAttr func(optionalAttr) +// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. +type DataFormatDimMapAttr func(optionalAttr) -// DenseToDenseSetOperationValidateIndices sets the optional validate_indices attribute to value. 
-// If not specified, defaults to true -func DenseToDenseSetOperationValidateIndices(value bool) DenseToDenseSetOperationAttr { +// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. +// +// value: source data format. +// If not specified, defaults to "NHWC" +func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { return func(m optionalAttr) { - m["validate_indices"] = value + m["src_format"] = value } } -// Applies set operation along last dimension of 2 `Tensor` inputs. +// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. // -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// value: destination data format. +// If not specified, defaults to "NCHW" +func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { + return func(m optionalAttr) { + m["dst_format"] = value + } +} + +// Returns the dimension index in the destination data format given the one in // -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. +// the source data format. // // Arguments: -// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// set2: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set1`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// +// x: A Tensor with each element as a dimension index in source data format. +// Must be in the range [-4, 4). // -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. 
`result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func DenseToDenseSetOperation(scope *Scope, set1 tf.Output, set2 tf.Output, set_operation string, optional ...DenseToDenseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { +// Returns A Tensor with each element as a dimension index in destination data format. +func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"set_operation": set_operation} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DenseToDenseSetOperation", + Type: "DataFormatDimMap", Input: []tf.Input{ - set1, set2, + x, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Generate a sharded filename. The filename is printf formatted as +// CumprodAttr is an optional argument to Cumprod. +type CumprodAttr func(optionalAttr) + +// CumprodExclusive sets the optional exclusive attribute to value. // -// %s-%05d-of-%05d, basename, shard, num_shards. -func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shards tf.Output) (filename tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ShardedFilename", - Input: []tf.Input{ - basename, shard, num_shards, - }, +// value: If `True`, perform exclusive cumprod. +// If not specified, defaults to false +func CumprodExclusive(value bool) CumprodAttr { + return func(m optionalAttr) { + m["exclusive"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// BatchToSpace for N-D tensors of type T. 
-// -// This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape -// `block_shape + [batch]`, interleaves these blocks back into the grid defined by -// the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as -// the input. The spatial dimensions of this intermediate result are then -// optionally cropped according to `crops` to produce the output. This is the -// reverse of SpaceToBatch. See below for a precise description. -// -// Arguments: -// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, -// where spatial_shape has M dimensions. -// block_shape: 1-D with shape `[M]`, all values must be >= 1. -// crops: 2-D with shape `[M, 2]`, all values must be >= 0. -// `crops[i] = [crop_start, crop_end]` specifies the amount to crop from input -// dimension `i + 1`, which corresponds to spatial dimension `i`. It is -// required that -// `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`. -// -// This operation is equivalent to the following steps: -// -// 1. Reshape `input` to `reshaped` of shape: -// [block_shape[0], ..., block_shape[M-1], -// batch / prod(block_shape), -// input_shape[1], ..., input_shape[N-1]] -// -// 2. Permute dimensions of `reshaped` to produce `permuted` of shape -// [batch / prod(block_shape), -// -// input_shape[1], block_shape[0], -// ..., -// input_shape[M], block_shape[M-1], -// -// input_shape[M+1], ..., input_shape[N-1]] -// -// 3. Reshape `permuted` to produce `reshaped_permuted` of shape -// [batch / prod(block_shape), -// -// input_shape[1] * block_shape[0], -// ..., -// input_shape[M] * block_shape[M-1], -// -// input_shape[M+1], -// ..., -// input_shape[N-1]] -// -// 4. 
Crop the start and end of dimensions `[1, ..., M]` of -// `reshaped_permuted` according to `crops` to produce the output of shape: -// [batch / prod(block_shape), -// -// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], -// ..., -// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], -// -// input_shape[M+1], ..., input_shape[N-1]] -// -// Some examples: -// -// (1) For the following input of shape `[4, 1, 1, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: -// -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 1]` and value: +// CumprodReverse sets the optional reverse attribute to value. // -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` +// value: A `bool` (default: False). +// If not specified, defaults to false +func CumprodReverse(value bool) CumprodAttr { + return func(m optionalAttr) { + m["reverse"] = value + } +} + +// Compute the cumulative product of the tensor `x` along `axis`. // -// (2) For the following input of shape `[4, 1, 1, 3]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: +// By default, this op performs an inclusive cumprod, which means that the first +// element of the input is identical to the first element of the output: // -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// ```python +// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] // ``` // -// The output tensor has shape `[1, 2, 2, 3]` and value: +// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +// performed instead: // -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] +// ```python +// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] // ``` // -// (3) For the following input of shape `[4, 2, 2, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: +// By setting the `reverse` kwarg to `True`, the cumprod is performed in the +// opposite direction: // +// ```python +// tf.cumprod([a, b, c], 
reverse=True) # => [a * b * c, b * c, c] // ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -// -// The output tensor has shape `[1, 4, 4, 1]` and value: // -// ``` -// x = [[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]] -// ``` +// This is more efficient than using separate `tf.reverse` ops. // -// (4) For the following input of shape `[8, 1, 3, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [2, 0]]`: +// The `reverse` and `exclusive` kwargs can also be combined: // +// ```python +// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] // ``` -// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], -// [[[0], [2], [4]]], [[[0], [10], [12]]], -// [[[0], [5], [7]]], [[[0], [13], [15]]], -// [[[0], [6], [8]]], [[[0], [14], [16]]]] -// ``` -// -// The output tensor has shape `[2, 2, 4, 1]` and value: // -// ``` -// x = [[[[1], [2], [3], [4]], -// [[5], [6], [7], [8]]], -// [[[9], [10], [11], [12]], -// [[13], [14], [15], [16]]]] -// ``` -func BatchToSpaceND(scope *Scope, input tf.Output, block_shape tf.Output, crops tf.Output) (output tf.Output) { +// Arguments: +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). Must be in the range +// `[-rank(x), rank(x))`. 
+func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BatchToSpaceND", + Type: "Cumprod", Input: []tf.Input{ - input, block_shape, crops, + x, axis, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// UnpackAttr is an optional argument to Unpack. -type UnpackAttr func(optionalAttr) +// RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr is an optional argument to RetrieveTPUEmbeddingStochasticGradientDescentParameters. +type RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr func(optionalAttr) -// UnpackAxis sets the optional axis attribute to value. +// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: Dimension along which to unpack. Negative values wrap around, so the -// valid range is `[-R, R)`. -// If not specified, defaults to 0 -func UnpackAxis(value int64) UnpackAttr { +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId(value int64) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { return func(m optionalAttr) { - m["axis"] = value + m["table_id"] = value } } -// Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. +// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName(value string) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve SGD embedding parameters. // -// Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. 
-// For example, given a tensor of shape `(A, B, C, D)`; +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. // -// If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` -// and each tensor in `output` will have shape `(B, C, D)`. (Note that the -// dimension unpacked along is gone, unlike `split`). +// Returns Parameter parameters updated by the stochastic gradient descent optimization algorithm. +func RetrieveTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr) (parameters tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingStochasticGradientDescentParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizedMulAttr is an optional argument to QuantizedMul. +type QuantizedMulAttr func(optionalAttr) + +// QuantizedMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { + return func(m optionalAttr) { + m["Toutput"] = value + } +} + +// Returns x * y element-wise, working on quantized buffers. // -// If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` -// and each tensor in `output` will have shape `(A, C, D)`. -// Etc. +// Arguments: // -// This is the opposite of `pack`. // -// Arguments: -// value: 1-D or higher, with `axis` dimension size equal to `num`. +// min_x: The float value that the lowest quantized `x` value represents. 
+// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. // +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. // -// Returns The list of tensors unpacked from `value`. -func Unpack(scope *Scope, value tf.Output, num int64, optional ...UnpackAttr) (output []tf.Output) { +// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num": num} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Unpack", + Type: "QuantizedMul", Input: []tf.Input{ - value, + x, y, min_x, max_x, min_y, max_y, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("Unpack", err) - return + return op.Output(0), op.Output(1), op.Output(2) +} + +// QuantizedAddAttr is an optional argument to QuantizedAdd. +type QuantizedAddAttr func(optionalAttr) + +// QuantizedAddToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { + return func(m optionalAttr) { + m["Toutput"] = value } - return output } -// Increments variable pointed to by 'resource' until it reaches 'limit'. +// Returns x + y element-wise, working on quantized buffers. 
// // Arguments: -// resource: Should be from a scalar `Variable` node. -// limit: If incrementing ref would bring it above limit, instead generates an -// 'OutOfRange' error. // // -// Returns A copy of the input before increment. If nothing else modifies the -// input, the values produced will all be distinct. -func ResourceCountUpTo(scope *Scope, resource tf.Output, limit int64, T tf.DataType) (output tf.Output) { +// min_x: The float value that the lowest quantized `x` value represents. +// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// +// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"limit": limit, "T": T} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ResourceCountUpTo", + Type: "QuantizedAdd", Input: []tf.Input{ - resource, + x, y, min_x, max_x, min_y, max_y, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Delete the stack from its resource container. +// Scatters tensor at indices in an input list. // -// Arguments: -// handle: The handle to a stack. +// Each member of the TensorList corresponds to one row of the input tensor, +// specified by the given index (see `tf.gather`). 
// -// Returns the created operation. -func StackCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { +// input_handle: The list to scatter into. +// tensor: The input tensor. +// indices: The indices used to index into the list. +// output_handle: The TensorList. +func TensorListScatterIntoExistingList(scope *Scope, input_handle tf.Output, tensor tf.Output, indices tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "StackCloseV2", + Type: "TensorListScatterIntoExistingList", Input: []tf.Input{ - handle, + input_handle, tensor, indices, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Generate a glob pattern matching all sharded file names. -func ShardedFilespec(scope *Scope, basename tf.Output, num_shards tf.Output) (filename tf.Output) { +// Computes a range that covers the actual values present in a quantized tensor. +// +// Given a quantized tensor described by `(input, input_min, input_max)`, outputs a +// range that covers the actual values present in that tensor. This op is typically +// used to produce the `requested_output_min` and `requested_output_max` for +// `Requantize`. +// +// Arguments: +// +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// +// Returns The computed min output.the computed max output. +func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ShardedFilespec", + Type: "RequantizationRange", Input: []tf.Input{ - basename, num_shards, + input, input_min, input_max, }, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. 
-type TextLineReaderV2Attr func(optionalAttr) - -// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. +// Rolls the elements of a tensor along an axis. // -// value: Number of lines to skip from the beginning of every file. -// If not specified, defaults to 0 -func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["skip_header_lines"] = value - } -} - -// TextLineReaderV2Container sets the optional container attribute to value. +// The elements are shifted positively (towards larger indices) by the offset of +// `shift` along the dimension of `axis`. Negative `shift` values will shift +// elements in the opposite direction. Elements that roll passed the last position +// will wrap around to the first and vice versa. Multiple shifts along multiple +// axes may be specified. // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func TextLineReaderV2Container(value string) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// TextLineReaderV2SharedName sets the optional shared_name attribute to value. +// For example: // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A Reader that outputs the lines of a file delimited by '\n'. +// ``` +// # 't' is [0, 1, 2, 3, 4] +// roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2] // -// Returns The handle to reference the Reader. 
-func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { +// # shifting along multiple dimensions +// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +// roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] +// +// # shifting along the same axis multiple times +// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +// roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] +// ``` +// +// Arguments: +// +// shift: Dimension must be 0-D or 1-D. `shift[i]` specifies the number of places by which +// elements are shifted positively (towards larger indices) along the dimension +// specified by `axis[i]`. Negative shifts will roll the elements in the opposite +// direction. +// axis: Dimension must be 0-D or 1-D. `axis[i]` specifies the dimension that the shift +// `shift[i]` should occur. If the same axis is referenced more than once, the +// total shift for that axis will be the sum of all the shifts that belong to that +// axis. +// +// Returns Has the same shape and size as the input. The elements are shifted +// positively (towards larger indices) by the offsets of `shift` along the +// dimensions of `axis`. +func Roll(scope *Scope, input tf.Output, shift tf.Output, axis tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "TextLineReaderV2", - - Attrs: attrs, + Type: "Roll", + Input: []tf.Input{ + input, shift, axis, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// LoadAndRemapMatrixAttr is an optional argument to LoadAndRemapMatrix. -type LoadAndRemapMatrixAttr func(optionalAttr) - -// LoadAndRemapMatrixMaxRowsInMemory sets the optional max_rows_in_memory attribute to value. +// Updates the table to associates keys with values. // -// value: The maximum number of rows to load from the checkpoint at -// once. 
If less than or equal to 0, the entire matrix will be loaded into -// memory. Setting this arg trades increased disk reads for lower memory usage. -// If not specified, defaults to -1 -func LoadAndRemapMatrixMaxRowsInMemory(value int64) LoadAndRemapMatrixAttr { - return func(m optionalAttr) { - m["max_rows_in_memory"] = value +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableInsertV2", + Input: []tf.Input{ + table_handle, keys, values, + }, } + return scope.AddOperation(opspec) } -// Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint -// -// at `ckpt_path` and potentially reorders its rows and columns using the -// specified remappings. -// -// Most users should use one of the wrapper initializers (such as -// `tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this -// function directly. -// -// The remappings are 1-D tensors with the following properties: -// -// * `row_remapping` must have exactly `num_rows` entries. Row `i` of the output -// matrix will be initialized from the row corresponding to index -// `row_remapping[i]` in the old `Tensor` from the checkpoint. -// * `col_remapping` must have either 0 entries (indicating that no column -// reordering is needed) or `num_cols` entries. If specified, column `j` of the -// output matrix will be initialized from the column corresponding to index -// `col_remapping[j]` in the old `Tensor` from the checkpoint. -// * A value of -1 in either of the remappings signifies a "missing" entry. 
In that -// case, values from the `initializing_values` tensor will be used to fill that -// missing row or column. If `row_remapping` has `r` missing entries and -// `col_remapping` has `c` missing entries, then the following condition must be -// true: +// Creates a `Dataset` that includes only 1/`num_shards` of this dataset. // -// `(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)` +// Arguments: // -// The remapping tensors can be generated using the GenerateVocabRemapping op. +// num_shards: An integer representing the number of shards operating in parallel. +// index: An integer representing the current worker index. // -// As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1], -// initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing -// the value from row i, column j of the old tensor in the checkpoint, the output -// matrix will look like the following: // -// [[w(1, 0), w(1, 2), 0.5], -// [w(0, 0), w(0, 2), -0.5], -// [0.25, -0.25, 42]] +func ShardDataset(scope *Scope, input_dataset tf.Output, num_shards tf.Output, index tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ShardDataset", + Input: []tf.Input{ + input_dataset, num_shards, index, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that batches and pads `batch_size` elements from the input. // // Arguments: -// ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from -// which the old matrix `Tensor` will be loaded. -// old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. -// row_remapping: An int `Tensor` of row remappings (generally created by -// `generate_vocab_remapping`). 
Even if no row remapping is needed, this must -// still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted -// index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`). -// col_remapping: An int `Tensor` of column remappings (generally created by -// `generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping -// is to be done (e.g. column ordering is the same). -// initializing_values: A float `Tensor` containing values to fill in for cells -// in the output matrix that are not loaded from the checkpoint. Length must be -// exactly the same as the number of missing / new cells. -// num_rows: Number of rows (length of the 1st dimension) in the output matrix. -// num_cols: Number of columns (length of the 2nd dimension) in the output matrix. // -// Returns Output matrix containing existing values loaded from the -// checkpoint, and with any missing values filled in from initializing_values. -func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Output, row_remapping tf.Output, col_remapping tf.Output, initializing_values tf.Output, num_rows int64, num_cols int64, optional ...LoadAndRemapMatrixAttr) (output_matrix tf.Output) { +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// padded_shapes: A list of int64 tensors representing the desired padded shapes +// of the corresponding output components. These shapes may be partially +// specified, using `-1` to indicate that a particular dimension should be +// padded to the maximum size of all batch elements. +// padding_values: A list of scalars containing the padding value to use for +// each of the outputs. +// drop_remainder: A scalar representing whether the last batch should be dropped in case its size +// is smaller than desired. 
+// +func PaddedBatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, drop_remainder tf.Output, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_rows": num_rows, "num_cols": num_cols} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "LoadAndRemapMatrix", + Type: "PaddedBatchDatasetV2", Input: []tf.Input{ - ckpt_path, old_tensor_name, row_remapping, col_remapping, initializing_values, + input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), drop_remainder, }, Attrs: attrs, } @@ -30213,43 +34016,74 @@ func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Ou return op.Output(0) } -// TFRecordReaderV2Attr is an optional argument to TFRecordReaderV2. -type TFRecordReaderV2Attr func(optionalAttr) +// Returns element-wise smallest integer not less than x. +func Ceil(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Ceil", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// TFRecordReaderV2Container sets the optional container attribute to value. +// Computes the number of elements in the given table. // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func TFRecordReaderV2Container(value string) TFRecordReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value +// Arguments: +// table_handle: Handle to the table. +// +// Returns Scalar that contains number of elements in the table. 
+func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableSizeV2", + Input: []tf.Input{ + table_handle, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// TFRecordReaderV2SharedName sets the optional shared_name attribute to value. +// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. +type ResizeBilinearGradAttr func(optionalAttr) + +// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func TFRecordReaderV2SharedName(value string) TFRecordReaderV2Attr { +// value: If true, the centers of the 4 corner pixels of the input and grad tensors are +// aligned. Defaults to false. +// If not specified, defaults to false +func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["align_corners"] = value } } -// TFRecordReaderV2CompressionType sets the optional compression_type attribute to value. -// If not specified, defaults to "" -func TFRecordReaderV2CompressionType(value string) TFRecordReaderV2Attr { +// ResizeBilinearGradHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeBilinearGradHalfPixelCenters(value bool) ResizeBilinearGradAttr { return func(m optionalAttr) { - m["compression_type"] = value + m["half_pixel_centers"] = value } } -// A Reader that outputs the records from a TensorFlow Records file. +// Computes the gradient of bilinear interpolation. // -// Returns The handle to reference the Reader. 
-func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_handle tf.Output) { +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. +func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -30258,38 +34092,74 @@ func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_ha a(attrs) } opspec := tf.OpSpec{ - Type: "TFRecordReaderV2", - + Type: "ResizeBilinearGrad", + Input: []tf.Input{ + grads, original_image, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. -type QuantizeAndDequantizeV3Attr func(optionalAttr) +// Outputs all keys and values in the table. +// +// Arguments: +// table_handle: Handle to the table. +// +// +// +// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. +func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} + opspec := tf.OpSpec{ + Type: "LookupTableExportV2", + Input: []tf.Input{ + table_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} -// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. 
-// If not specified, defaults to true -func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { +// MultiDeviceIteratorFromStringHandleAttr is an optional argument to MultiDeviceIteratorFromStringHandle. +type MultiDeviceIteratorFromStringHandleAttr func(optionalAttr) + +// MultiDeviceIteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. +// +// value: The type list for the return values. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func MultiDeviceIteratorFromStringHandleOutputTypes(value []tf.DataType) MultiDeviceIteratorFromStringHandleAttr { return func(m optionalAttr) { - m["signed_input"] = value + m["output_types"] = value } } -// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { +// MultiDeviceIteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. +// +// value: The list of shapes being produced. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func MultiDeviceIteratorFromStringHandleOutputShapes(value []tf.Shape) MultiDeviceIteratorFromStringHandleAttr { return func(m optionalAttr) { - m["range_given"] = value + m["output_shapes"] = value } } -// Quantizes then dequantizes a tensor. +// Generates a MultiDeviceIterator resource from its provided string handle. // -// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a -// tensor, so its value can change during training. -func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { +// Arguments: +// string_handle: String representing the resource. +// +// Returns A MultiDeviceIterator resource. 
+func MultiDeviceIteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...MultiDeviceIteratorFromStringHandleAttr) (multi_device_iterator tf.Output) { if scope.Err() != nil { return } @@ -30298,9 +34168,9 @@ func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeAndDequantizeV3", + Type: "MultiDeviceIteratorFromStringHandle", Input: []tf.Input{ - input, input_min, input_max, num_bits, + string_handle, }, Attrs: attrs, } @@ -30308,47 +34178,63 @@ func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, return op.Output(0) } -// IdentityReaderV2Attr is an optional argument to IdentityReaderV2. -type IdentityReaderV2Attr func(optionalAttr) +// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. +type MutableHashTableV2Attr func(optionalAttr) -// IdentityReaderV2Container sets the optional container attribute to value. +// MutableHashTableV2Container sets the optional container attribute to value. // -// value: If non-empty, this reader is placed in the given container. +// value: If non-empty, this table is placed in the given container. // Otherwise, a default container is used. // If not specified, defaults to "" -func IdentityReaderV2Container(value string) IdentityReaderV2Attr { +func MutableHashTableV2Container(value string) MutableHashTableV2Attr { return func(m optionalAttr) { m["container"] = value } } -// IdentityReaderV2SharedName sets the optional shared_name attribute to value. +// MutableHashTableV2SharedName sets the optional shared_name attribute to value. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. +// value: If non-empty, this table is shared under the given name across +// multiple sessions. 
// If not specified, defaults to "" -func IdentityReaderV2SharedName(value string) IdentityReaderV2Attr { +func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// A Reader that outputs the queued work as both the key and value. +// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. // -// To use, enqueue strings in a Queue. ReaderRead will take the front -// work string and output (work, work). +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates an empty hash table. // -// Returns The handle to reference the Reader. -func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_handle tf.Output) { +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. 
+func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "IdentityReaderV2", + Type: "MutableHashTableV2", Attrs: attrs, } @@ -30356,29 +34242,98 @@ func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_ha return op.Output(0) } -// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent. -type ResourceApplyGradientDescentAttr func(optionalAttr) +// DequantizeAttr is an optional argument to Dequantize. +type DequantizeAttr func(optionalAttr) -// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr { +// DequantizeMode sets the optional mode attribute to value. +// If not specified, defaults to "MIN_COMBINED" +func DequantizeMode(value string) DequantizeAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["mode"] = value } } -// Update '*var' by subtracting 'alpha' * 'delta' from it. +// Dequantize the 'input' tensor into a float Tensor. +// +// [min_range, max_range] are scalar floats that specify the range for +// the 'input' data. The 'mode' attribute controls exactly which calculations are +// used to convert the float values to their quantized equivalents. 
+// +// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: +// +// ``` +// if T == qint8: in[i] += (range(T) + 1)/ 2.0 +// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) +// ``` +// here `range(T) = numeric_limits::max() - numeric_limits::min()` +// +// *MIN_COMBINED Mode Example* +// +// If the input comes from a QuantizedRelu6, the output type is +// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is +// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. +// Dequantize on quint8 will take each value, cast to float, and multiply +// by 6 / 255. +// Note that if quantizedtype is qint8, the operation will additionally add +// each value by 128 prior to casting. +// +// If the mode is 'MIN_FIRST', then this approach is used: +// +// ```c++ +// num_discrete_values = 1 << (# of bits in T) +// range_adjust = num_discrete_values / (num_discrete_values - 1) +// range = (range_max - range_min) * range_adjust +// range_scale = range / num_discrete_values +// const double offset_input = static_cast(input) - lowest_quantized; +// result = range_min + ((input - numeric_limits::min()) * range_scale) +// ``` +// +// *SCALED mode Example* +// +// `SCALED` mode matches the quantization approach used in +// `QuantizeAndDequantize{V2|V3}`. +// +// If the mode is `SCALED`, we do not use the full range of the output type, +// choosing to elide the lowest possible value for symmetry (e.g., output range is +// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to +// 0. +// +// We first find the range of values in our tensor. The +// range we use is always centered on 0, so we find m such that +// ```c++ +// m = max(abs(input_min), abs(input_max)) +// ``` +// +// Our input tensor range is then `[-m, m]`. +// +// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. 
+// If T is signed, this is +// ``` +// num_bits = sizeof(T) * 8 +// [min_fixed, max_fixed] = +// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] +// ``` +// +// Otherwise, if T is unsigned, the fixed-point range is +// ``` +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] +// ``` +// +// From this we compute our scaling factor, s: +// ```c++ +// s = (2 * m) / (max_fixed - min_fixed) +// ``` +// +// Now we can dequantize the elements of our tensor: +// ```c++ +// result = input * s +// ``` // // Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// delta: The change. // -// Returns the created operation. -func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) { +// min_range: The minimum scalar value possibly produced for the input. +// max_range: The maximum scalar value possibly produced for the input. +func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, optional ...DequantizeAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -30387,274 +34342,285 @@ func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyGradientDescent", + Type: "Dequantize", Input: []tf.Input{ - var_, alpha, delta, + input, min_range, max_range, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns the next record (key, value pair) produced by a Reader. -// -// Will dequeue from the input queue if necessary (e.g. when the -// Reader needs to start reading from a new file since it has finished -// with the previous file). -// -// Arguments: -// reader_handle: Handle to a Reader. -// queue_handle: Handle to a Queue, with string work items. +// Flips all bits elementwise. // -// Returns A scalar.A scalar. 
-func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { +// The result will have exactly those bits set, that are not set in `x`. The +// computation is performed on the underlying representation of x. +func Invert(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderReadV2", + Type: "Invert", Input: []tf.Input{ - reader_handle, queue_handle, + x, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Returns up to `num_records` (key, value) pairs produced by a Reader. +// Deserialize bucket boundaries and ready flag into current QuantileAccumulator. // -// Will dequeue from the input queue if necessary (e.g. when the -// Reader needs to start reading from a new file since it has finished -// with the previous file). -// It may return less than `num_records` even before the last batch. +// An op that deserializes bucket boundaries and are boundaries ready flag into current QuantileAccumulator. // // Arguments: -// reader_handle: Handle to a `Reader`. -// queue_handle: Handle to a `Queue`, with string work items. -// num_records: number of records to read from `Reader`. +// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. +// bucket_boundaries: float; List of Rank 1 Tensors each containing the bucket boundaries for a feature. // -// Returns A 1-D tensor.A 1-D tensor. -func ReaderReadUpToV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output, num_records tf.Output) (keys tf.Output, values tf.Output) { +// Returns the created operation. 
+func BoostedTreesQuantileStreamResourceDeserialize(scope *Scope, quantile_stream_resource_handle tf.Output, bucket_boundaries []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderReadUpToV2", + Type: "BoostedTreesQuantileStreamResourceDeserialize", Input: []tf.Input{ - reader_handle, queue_handle, num_records, + quantile_stream_resource_handle, tf.OutputList(bucket_boundaries), }, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// Adds v into specified rows of x. +// Inverse 3D fast Fourier transform. // -// Computes y = x; y[i, :] += v; return y. +// Computes the inverse 3-dimensional discrete Fourier transform over the +// inner-most 3 dimensions of `input`. // // Arguments: -// x: A `Tensor` of type T. -// i: A vector. Indices into the left-most dimension of `x`. -// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. +// input: A complex64 tensor. // -// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. -func InplaceAdd(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { +// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 +// dimensions of `input` are replaced with their inverse 3D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifftn with 3 dimensions. +// @end_compatibility +func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "InplaceAdd", + Type: "IFFT3D", Input: []tf.Input{ - x, i, v, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Restore a Reader to its initial clean state. +// Shuts down a running distributed TPU system. // -// Arguments: -// reader_handle: Handle to a Reader. +// The op returns an error if no system is running. 
// // Returns the created operation. -func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { +func ShutdownDistributedTPU(scope *Scope) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderResetV2", - Input: []tf.Input{ - reader_handle, - }, + Type: "ShutdownDistributedTPU", } return scope.AddOperation(opspec) } -// BatchAttr is an optional argument to Batch. -type BatchAttr func(optionalAttr) - -// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value. -// If not specified, defaults to 10 -func BatchMaxEnqueuedBatches(value int64) BatchAttr { - return func(m optionalAttr) { - m["max_enqueued_batches"] = value +// Deprecated. Disallowed in GraphDef version >= 2. +// +// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead +func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { + if scope.Err() != nil { + return } -} - -// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value. -// If not specified, defaults to <> -func BatchAllowedBatchSizes(value []int64) BatchAttr { - return func(m optionalAttr) { - m["allowed_batch_sizes"] = value + opspec := tf.OpSpec{ + Type: "AdjustContrast", + Input: []tf.Input{ + images, contrast_factor, min_value, max_value, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// BatchContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func BatchContainer(value string) BatchAttr { - return func(m optionalAttr) { - m["container"] = value +// Table initializer that takes two tensors for keys and values respectively. +// +// Arguments: +// table_handle: Handle to a table which will be initialized. +// keys: Keys of type Tkey. +// values: Values of type Tval. +// +// Returns the created operation. 
+func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InitializeTableV2", + Input: []tf.Input{ + table_handle, keys, values, + }, } + return scope.AddOperation(opspec) } -// BatchSharedName sets the optional shared_name attribute to value. +// PrintAttr is an optional argument to Print. +type PrintAttr func(optionalAttr) + +// PrintMessage sets the optional message attribute to value. +// +// value: A string, prefix of the error message. // If not specified, defaults to "" -func BatchSharedName(value string) BatchAttr { +func PrintMessage(value string) PrintAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["message"] = value } } -// BatchBatchingQueue sets the optional batching_queue attribute to value. -// If not specified, defaults to "" -func BatchBatchingQueue(value string) BatchAttr { +// PrintFirstN sets the optional first_n attribute to value. +// +// value: Only log `first_n` number of times. -1 disables logging. +// If not specified, defaults to -1 +func PrintFirstN(value int64) PrintAttr { return func(m optionalAttr) { - m["batching_queue"] = value + m["first_n"] = value } } -// Batches all input tensors nondeterministically. -// -// When many instances of this Op are being run concurrently with the same -// container/shared_name in the same device, some will output zero-shaped Tensors -// and others will output Tensors of size up to max_batch_size. -// -// All Tensors in in_tensors are batched together (so, for example, labels and -// features should be batched with a single instance of this operation. -// -// Each invocation of batch emits an `id` scalar which will be used to identify -// this particular invocation when doing unbatch or its gradient. 
-// -// Each op which emits a non-empty batch will also emit a non-empty batch_index -// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, -// start, and length of elements of each set of Tensors present in batched_tensors. -// -// Batched tensors are concatenated along the first dimension, and all tensors in -// in_tensors must have the first dimension of the same size. +// PrintSummarize sets the optional summarize attribute to value. // -// in_tensors: The tensors to be batched. -// num_batch_threads: Number of scheduling threads for processing batches of work. -// Determines the number of batches processed in parallel. -// max_batch_size: Batch sizes will never be bigger than this. -// batch_timeout_micros: Maximum number of microseconds to wait before outputting -// an incomplete batch. -// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does -// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad -// batches up to one of those sizes. The entries must increase monotonically, and -// the final entry must equal max_batch_size. -// grad_timeout_micros: The timeout to use for the gradient. See Unbatch. -// batched_tensors: Either empty tensors or a batch of concatenated Tensors. -// batch_index: If out_tensors is non-empty, has information to invert it. -// container: Controls the scope of sharing of this batch. -// id: always contains a scalar with a unique ID for this invocation of Batch. -// shared_name: Concurrently running instances of batch in the same device with the -// same container and shared_name will batch their elements together. If left -// empty, the op name will be used as the shared name. -// T: the types of tensors to be batched. 
-func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) { +// value: Only print this many entries of each tensor. +// If not specified, defaults to 3 +func PrintSummarize(value int64) PrintAttr { + return func(m optionalAttr) { + m["summarize"] = value + } +} + +// Prints a list of tensors. +// +// Passes `input` through to `output` and prints `data` when evaluating. +// +// Arguments: +// input: The tensor passed to `output` +// data: A list of tensors to print out when op is evaluated. +// +// Returns = The unmodified `input` tensor +func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Batch", + Type: "Print", Input: []tf.Input{ - tf.OutputList(in_tensors), + input, tf.OutputList(data), }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Outputs a `Summary` protocol buffer with a tensor and per-plugin data. +// +// Arguments: +// tag: A string attached to this summary. Used for organization in TensorBoard. +// tensor: A tensor to serialize. +// serialized_summary_metadata: A serialized SummaryMetadata proto. Contains plugin +// data. 
+func TensorSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, serialized_summary_metadata tf.Output) (summary tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if batched_tensors, idx, err = makeOutputList(op, idx, "batched_tensors"); err != nil { - scope.UpdateErr("Batch", err) - return + opspec := tf.OpSpec{ + Type: "TensorSummaryV2", + Input: []tf.Input{ + tag, tensor, serialized_summary_metadata, + }, } - batch_index = op.Output(idx) - id = op.Output(idx) - return batched_tensors, batch_index, id + op := scope.AddOperation(opspec) + return op.Output(0) } -// Adjust the hue of one or more images. +// Creates a dataset that asynchronously prefetches elements from `input_dataset`. // -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. +// Arguments: // -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A delta is then applied all the hue values, -// and then remapped back to RGB colorspace. +// buffer_size: The maximum number of elements to buffer in an iterator over +// this dataset. // -// Arguments: -// images: Images to adjust. At least 3-D. -// delta: A float delta to add to the hue. // -// Returns The hue-adjusted image or images. -func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { +func PrefetchDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "AdjustHue", + Type: "PrefetchDataset", Input: []tf.Input{ - images, delta, + input_dataset, buffer_size, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. 
-type ResizeBicubicGradAttr func(optionalAttr) +// TensorSummaryAttr is an optional argument to TensorSummary. +type TensorSummaryAttr func(optionalAttr) -// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. +// TensorSummaryDescription sets the optional description attribute to value. // -// value: If true, the centers of the 4 corner pixels of the input and grad tensors are -// aligned. Defaults to false. -// If not specified, defaults to false -func ResizeBicubicGradAlignCorners(value bool) ResizeBicubicGradAttr { +// value: A json-encoded SummaryDescription proto. +// If not specified, defaults to "" +func TensorSummaryDescription(value string) TensorSummaryAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["description"] = value } } -// Computes the gradient of bicubic interpolation. +// TensorSummaryLabels sets the optional labels attribute to value. // -// Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. +// value: An unused list of strings. +// If not specified, defaults to <> +func TensorSummaryLabels(value []string) TensorSummaryAttr { + return func(m optionalAttr) { + m["labels"] = value + } +} + +// TensorSummaryDisplayName sets the optional display_name attribute to value. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. -func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBicubicGradAttr) (output tf.Output) { +// value: An unused string. +// If not specified, defaults to "" +func TensorSummaryDisplayName(value string) TensorSummaryAttr { + return func(m optionalAttr) { + m["display_name"] = value + } +} + +// Outputs a `Summary` protocol buffer with a tensor. 
+// +// This op is being phased out in favor of TensorSummaryV2, which lets callers pass +// a tag as well as a serialized SummaryMetadata proto string that contains +// plugin-specific data. We will keep this op to maintain backwards compatibility. +// +// Arguments: +// tensor: A tensor to serialize. +func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr) (summary tf.Output) { if scope.Err() != nil { return } @@ -30663,9 +34629,9 @@ func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeBicubicGrad", + Type: "TensorSummary", Input: []tf.Input{ - grads, original_image, + tensor, }, Attrs: attrs, } @@ -30673,41 +34639,24 @@ func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, return op.Output(0) } -// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. -type ResizeNearestNeighborAttr func(optionalAttr) - -// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Resize `images` to `size` using nearest neighbor interpolation. +// Read an element from the TensorArray into output `value`. // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// handle: The handle to a TensorArray. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. 
-func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. +// +// Returns The tensor that is read from the TensorArray. +func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "ResizeNearestNeighbor", + Type: "TensorArrayReadV3", Input: []tf.Input{ - images, size, + handle, index, flow_in, }, Attrs: attrs, } @@ -30715,215 +34664,211 @@ func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optio return op.Output(0) } -// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. -type ResizeNearestNeighborGradAttr func(optionalAttr) - -// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. +// Reduces sparse updates into the variable referenced by `resource` using the `max` operation. // -// value: If true, the centers of the 4 corner pixels of the input and grad tensors are -// aligned. Defaults to false. -// If not specified, defaults to false -func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Computes the gradient of nearest neighbor interpolation. +// This operation computes +// +// # Scalar indices +// ref[indices, ...] = max(ref[indices, ...], updates[...]) +// +// # Vector indices (for each i) +// ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] 
= max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions are combined. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
// // Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The -// original input size. +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients -// with respect to the input image. -func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { +// Returns the created operation. +func ResourceScatterMax(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResizeNearestNeighborGrad", + Type: "ResourceScatterMax", Input: []tf.Input{ - grads, size, + resource, indices, updates, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ExtractJpegShapeAttr is an optional argument to ExtractJpegShape. -type ExtractJpegShapeAttr func(optionalAttr) - -// ExtractJpegShapeOutputType sets the optional output_type attribute to value. +// Computes the gradient for the tanh of `x` wrt its input. // -// value: (Optional) The output type of the operation (int32 or int64). -// Defaults to int32. -// If not specified, defaults to DT_INT32 -func ExtractJpegShapeOutputType(value tf.DataType) ExtractJpegShapeAttr { - return func(m optionalAttr) { - m["output_type"] = value +// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy` +// is the corresponding input gradient. 
+func TanhGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TanhGrad", + Input: []tf.Input{ + y, dy, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Extract the shape information of a JPEG-encoded image. +// Outputs a `Summary` protocol buffer with scalar values. // -// This op only parses the image header, so it is much faster than DecodeJpeg. +// The input `tags` and `values` must have the same shape. The generated summary +// has a summary value for each tag-value pair in `tags` and `values`. // // Arguments: -// contents: 0-D. The JPEG-encoded image. +// tags: Tags for the summary. +// values: Same shape as `tags. Values for the summary. // -// Returns 1-D. The image shape with format [height, width, channels]. -func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegShapeAttr) (image_shape tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ExtractJpegShape", + Type: "ScalarSummary", Input: []tf.Input{ - contents, + tags, values, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// PaddingFIFOQueueV2Attr is an optional argument to PaddingFIFOQueueV2. -type PaddingFIFOQueueV2Attr func(optionalAttr) +// ImageSummaryAttr is an optional argument to ImageSummary. +type ImageSummaryAttr func(optionalAttr) -// PaddingFIFOQueueV2Shapes sets the optional shapes attribute to value. +// ImageSummaryMaxImages sets the optional max_images attribute to value. // -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. 
-// Shapes of fixed rank but variable size are allowed by setting -// any shape dimension to -1. In this case, the inputs' shape may vary along -// the given dimension, and DequeueMany will pad the given dimension with -// zeros up to the maximum shape of all elements in the given batch. -// If the length of this attr is 0, different queue elements may have -// different ranks and shapes, but only one element may be dequeued at a time. -// If not specified, defaults to <> +// value: Max number of batch elements to generate images for. +// If not specified, defaults to 3 // -// REQUIRES: len(value) >= 0 -func PaddingFIFOQueueV2Shapes(value []tf.Shape) PaddingFIFOQueueV2Attr { +// REQUIRES: value >= 1 +func ImageSummaryMaxImages(value int64) ImageSummaryAttr { return func(m optionalAttr) { - m["shapes"] = value + m["max_images"] = value } } -// PaddingFIFOQueueV2Capacity sets the optional capacity attribute to value. +// ImageSummaryBadColor sets the optional bad_color attribute to value. // -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func PaddingFIFOQueueV2Capacity(value int64) PaddingFIFOQueueV2Attr { +// value: Color to use for pixels with non-finite values. +// If not specified, defaults to > int_val:255 int_val:0 int_val:0 int_val:255 > +func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { - m["capacity"] = value + m["bad_color"] = value } } -// PaddingFIFOQueueV2Container sets the optional container attribute to value. +// Outputs a `Summary` protocol buffer with images. // -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. 
-// If not specified, defaults to "" -func PaddingFIFOQueueV2Container(value string) PaddingFIFOQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// PaddingFIFOQueueV2SharedName sets the optional shared_name attribute to value. +// The summary has up to `max_images` summary values containing images. The +// images are built from `tensor` which must be 4-D with shape `[batch_size, +// height, width, channels]` and where `channels` can be: // -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func PaddingFIFOQueueV2SharedName(value string) PaddingFIFOQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that produces elements in first-in first-out order. +// * 1: `tensor` is interpreted as Grayscale. +// * 3: `tensor` is interpreted as RGB. +// * 4: `tensor` is interpreted as RGBA. // -// Variable-size shapes are allowed by setting the corresponding shape dimensions -// to 0 in the shape attr. In this case DequeueMany will pad up to the maximum -// size of any given element in the minibatch. See below for details. +// The images have the same number of channels as the input tensor. For float +// input, the values are normalized one image at a time to fit in the range +// `[0, 255]`. `uint8` values are unchanged. The op uses two different +// normalization algorithms: +// +// * If the input values are all positive, they are rescaled so the largest one +// is 255. +// +// * If any input value is negative, the values are shifted so input value 0.0 +// is at 127. They are then rescaled so that either the smallest value is 0, +// or the largest one is 255. +// +// The `tag` argument is a scalar `Tensor` of type `string`. It is used to +// build the `tag` of the summary values: +// +// * If `max_images` is 1, the summary value tag is '*tag*/image'. 
+// * If `max_images` is greater than 1, the summary value tags are +// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. +// +// The `bad_color` argument is the color to use in the generated images for +// non-finite input values. It is a `uint8` 1-D tensor of length `channels`. +// Each element must be in the range `[0, 255]` (It represents the value of a +// pixel in the output image). Non-finite values in the input tensor are +// replaced by this tensor in the output image. The default value is the color +// red. // // Arguments: -// component_types: The type of each component in a value. +// tag: Scalar. Used to build the `tag` attribute of the summary values. +// tensor: 4-D of shape `[batch_size, height, width, channels]` where +// `channels` is 1, 3, or 4. // -// Returns The handle to the queue. -func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...PaddingFIFOQueueV2Attr) (handle tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...ImageSummaryAttr) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "PaddingFIFOQueueV2", - + Type: "ImageSummary", + Input: []tf.Input{ + tag, tensor, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodePngAttr is an optional argument to DecodePng. -type DecodePngAttr func(optionalAttr) +// AudioSummaryV2Attr is an optional argument to AudioSummaryV2. +type AudioSummaryV2Attr func(optionalAttr) -// DecodePngChannels sets the optional channels attribute to value. +// AudioSummaryV2MaxOutputs sets the optional max_outputs attribute to value. // -// value: Number of color channels for the decoded image. 
-// If not specified, defaults to 0 -func DecodePngChannels(value int64) DecodePngAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodePngDtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_UINT8 -func DecodePngDtype(value tf.DataType) DecodePngAttr { +// value: Max number of batch elements to generate audio for. +// If not specified, defaults to 3 +// +// REQUIRES: value >= 1 +func AudioSummaryV2MaxOutputs(value int64) AudioSummaryV2Attr { return func(m optionalAttr) { - m["dtype"] = value + m["max_outputs"] = value } } -// Decode a PNG-encoded image to a uint8 or uint16 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: +// Outputs a `Summary` protocol buffer with audio. // -// * 0: Use the number of channels in the PNG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// * 4: output an RGBA image. +// The summary has up to `max_outputs` summary values containing audio. The +// audio is built from `tensor` which must be 3-D with shape `[batch_size, +// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are +// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. // -// If needed, the PNG-encoded image is transformed to match the requested number -// of color channels. +// The `tag` argument is a scalar `Tensor` of type `string`. It is used to +// build the `tag` of the summary values: // -// This op also supports decoding JPEGs and non-animated GIFs since the interface -// is the same, though it is cleaner to use `tf.image.decode_image`. +// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. +// * If `max_outputs` is greater than 1, the summary value tags are +// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. // // Arguments: -// contents: 0-D. The PNG-encoded image. +// tag: Scalar. 
Used to build the `tag` attribute of the summary values. +// tensor: 2-D of shape `[batch_size, frames]`. +// sample_rate: The sample rate of the signal in hertz. // -// Returns 3-D with shape `[height, width, channels]`. -func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate tf.Output, optional ...AudioSummaryV2Attr) (summary tf.Output) { if scope.Err() != nil { return } @@ -30932,9 +34877,9 @@ func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (ima a(attrs) } opspec := tf.OpSpec{ - Type: "DecodePng", + Type: "AudioSummaryV2", Input: []tf.Input{ - contents, + tag, tensor, sample_rate, }, Attrs: attrs, } @@ -30942,139 +34887,119 @@ func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (ima return op.Output(0) } -// Decode the first frame of a GIF-encoded image to a uint8 tensor. -// -// GIF with frame or transparency compression are not supported -// convert animated GIF from compressed to uncompressed by: -// -// convert $src.gif -coalesce $dst.gif -// -// This op also supports decoding JPEGs and PNGs, though it is cleaner to use -// `tf.image.decode_image`. +// Splits a tensor into a list. // -// Arguments: -// contents: 0-D. The GIF-encoded image. +// list[i] corresponds to lengths[i] tensors from the input tensor. +// The tensor must have rank at least 1 and contain exactly sum(lengths) elements. // -// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order -func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { +// tensor: The input tensor. +// element_shape: A shape compatible with that of elements in the tensor. +// lengths: Vector of sizes of the 0th dimension of tensors in the list. +// output_handle: The list. 
+func TensorListSplit(scope *Scope, tensor tf.Output, element_shape tf.Output, lengths tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DecodeGif", + Type: "TensorListSplit", Input: []tf.Input{ - contents, + tensor, element_shape, lengths, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. -type LearnedUnigramCandidateSamplerAttr func(optionalAttr) - -// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} +// AvgPoolAttr is an optional argument to AvgPool. +type AvgPoolAttr func(optionalAttr) -// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// AvgPoolDataFormat sets the optional data_format attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func AvgPoolDataFormat(value string) AvgPoolAttr { return func(m optionalAttr) { - m["seed2"] = value + m["data_format"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. 
-// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. +// Performs average pooling on the input. // -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// Each entry in `output` is the mean of the corresponding size `ksize` +// window in `value`. // // Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). -// -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. 
-func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// value: 4-D with shape `[batch, height, width, channels]`. +// ksize: The size of the sliding window for each dimension of `value`. +// strides: The stride of the sliding window for each dimension of `value`. +// padding: The type of padding algorithm to use. +// +// Returns The average pooled output tensor. +func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LearnedUnigramCandidateSampler", + Type: "AvgPool", Input: []tf.Input{ - true_classes, + value, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// SerializeSparseAttr is an optional argument to SerializeSparse. -type SerializeSparseAttr func(optionalAttr) - -// SerializeSparseOutType sets the optional out_type attribute to value. +// Merges summaries. // -// value: The `dtype` to use for serialization; the supported types are `string` -// (default) and `variant`. -// If not specified, defaults to DT_STRING -func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { - return func(m optionalAttr) { - m["out_type"] = value +// This op creates a +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// protocol buffer that contains the union of all the values in the input +// summaries. 
+// +// When the Op is run, it reports an `InvalidArgument` error if multiple values +// in the summaries to merge use the same tag. +// +// Arguments: +// inputs: Can be of any shape. Each must contain serialized `Summary` protocol +// buffers. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MergeSummary", + Input: []tf.Input{ + tf.OutputList(inputs), + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Serialize a `SparseTensor` into a `[3]` `Tensor` object. +// The shape of the elements of the given list, as a tensor. // -// Arguments: -// sparse_indices: 2-D. The `indices` of the `SparseTensor`. -// sparse_values: 1-D. The `values` of the `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the `SparseTensor`. -func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { +// input_handle: the list +// element_shape: the shape of elements of the list +func TensorListElementShape(scope *Scope, input_handle tf.Output, shape_type tf.DataType) (element_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"shape_type": shape_type} opspec := tf.OpSpec{ - Type: "SerializeSparse", + Type: "TensorListElementShape", Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, + input_handle, }, Attrs: attrs, } @@ -31082,305 +35007,268 @@ func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Ou return op.Output(0) } -// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. -type RandomShuffleQueueV2Attr func(optionalAttr) - -// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. 
+// Returns the item in the list with the given index. // -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// If not specified, defaults to <> +// input_handle: the list +// index: the position in the list from which an element will be retrieved +// item: the element at that position // -// REQUIRES: len(value) >= 0 -func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["shapes"] = value +// +func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, element_shape tf.Output, element_dtype tf.DataType) (item tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"element_dtype": element_dtype} + opspec := tf.OpSpec{ + Type: "TensorListGetItem", + Input: []tf.Input{ + input_handle, index, element_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. +// Resizes the list. // -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value +// +// input_handle: the input list +// size: size of the output list +// +func TensorListResize(scope *Scope, input_handle tf.Output, size tf.Output) (output_handle tf.Output) { + if scope.Err() != nil { + return } + opspec := tf.OpSpec{ + Type: "TensorListResize", + Input: []tf.Input{ + input_handle, size, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. 
+// Returns a diagonal tensor with a given diagonal values. // -// value: Dequeue will block unless there would be this -// many elements after the dequeue or the queue is closed. This -// ensures a minimum level of mixing of elements. -// If not specified, defaults to 0 -func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["min_after_dequeue"] = value +// Given a `diagonal`, this operation returns a tensor with the `diagonal` and +// everything else padded with zeros. The diagonal is computed as follows: +// +// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of +// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: +// +// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. +// +// For example: +// +// ``` +// # 'diagonal' is [1, 2, 3, 4] +// tf.diag(diagonal) ==> [[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]] +// ``` +// +// Arguments: +// diagonal: Rank k tensor where k is at most 1. +func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) { + if scope.Err() != nil { + return } + opspec := tf.OpSpec{ + Type: "Diag", + Input: []tf.Input{ + diagonal, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// RandomShuffleQueueV2Seed sets the optional seed attribute to value. +// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal. +type ParameterizedTruncatedNormalAttr func(optionalAttr) + +// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value. // -// value: If either seed or seed2 is set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, a random seed is used. +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. 
// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { +func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr { return func(m optionalAttr) { m["seed"] = value } } -// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. +// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value. // // value: A second seed to avoid seed collision. // If not specified, defaults to 0 -func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { +func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr { return func(m optionalAttr) { m["seed2"] = value } } -// RandomShuffleQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. +// Outputs random values from a normal distribution. The parameters may each be a // -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that randomizes the order of elements. +// scalar which applies to the entire output, or a vector of length shape[0] which +// stores the parameters for each batch. // // Arguments: -// component_types: The type of each component in a value. +// shape: The shape of the output tensor. Batches are indexed by the 0th dimension. +// means: The mean parameter of each batch. +// stdevs: The standard deviation parameter of each batch. 
Must be greater than 0. +// minvals: The minimum cutoff. May be -infinity. +// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval +// for each batch. // -// Returns The handle to the queue. -func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { +// Returns A matrix of shape num_batches x samples_per_batch, filled with random +// truncated normal values using the parameters for each row. +func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomShuffleQueueV2", - + Type: "ParameterizedTruncatedNormal", + Input: []tf.Input{ + shape, means, stdevs, minvals, maxvals, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Draw bounding boxes on a batch of images. -// -// Outputs a copy of `images` but draws on top of the pixels zero or more bounding -// boxes specified by the locations in `boxes`. The coordinates of the each -// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. -// -// For example, if an image is 100 x 200 pixels (height x width) and the bounding -// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of -// the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates). -// -// Parts of the bounding box may fall outside the image. +// Sets the index-th position of the list to contain the given tensor. // -// Arguments: -// images: 4-D with shape `[batch, height, width, depth]`. 
A batch of images. -// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding -// boxes. +// input_handle: the list +// index: the position in the list to which the tensor will be assigned +// item: the element to be assigned to that position +// output_handle: the new list, with the element in the proper position // -// Returns 4-D with the same shape as `images`. The batch of input images with -// bounding boxes drawn on the images. -func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { +func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, item tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DrawBoundingBoxes", + Type: "TensorListSetItem", Input: []tf.Input{ - images, boxes, + input_handle, index, item, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Gets the next output from the given iterator. +// Creates a TensorList by indexing into a Tensor. // -// This operation is a synchronous version IteratorGetNext. It should only be used -// in situations where the iterator does not block the calling thread, or where -// the calling thread is not a member of the thread pool used to execute parallel -// operations (e.g. in eager mode). -func IteratorGetNextSync(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { +// Each member of the TensorList corresponds to one row of the input tensor, +// specified by the given index (see `tf.gather`). +// +// tensor: The input tensor. +// indices: The indices used to index into the list. +// element_shape: The shape of the elements in the list (can be less specified than +// the shape of the tensor). +// output_handle: The TensorList. 
+func TensorListScatter(scope *Scope, tensor tf.Output, indices tf.Output, element_shape tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "IteratorGetNextSync", + Type: "TensorListScatter", Input: []tf.Input{ - iterator, + tensor, indices, element_shape, }, - Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Use TensorArrayScatterV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 +func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("IteratorGetNextSync", err) - return + opspec := tf.OpSpec{ + Type: "TensorArrayScatterV2", + Input: []tf.Input{ + handle, indices, value, flow_in, + }, } - return components + op := scope.AddOperation(opspec) + return op.Output(0) } -// SampleDistortedBoundingBoxV2Attr is an optional argument to SampleDistortedBoundingBoxV2. -type SampleDistortedBoundingBoxV2Attr func(optionalAttr) - -// SampleDistortedBoundingBoxV2Seed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to non-zero, the random number -// generator is seeded by the given `seed`. Otherwise, it is seeded by a random -// seed. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxV2Seed(value int64) SampleDistortedBoundingBoxV2Attr { - return func(m optionalAttr) { - m["seed"] = value - } -} +// AsStringAttr is an optional argument to AsString. +type AsStringAttr func(optionalAttr) -// SampleDistortedBoundingBoxV2Seed2 sets the optional seed2 attribute to value. +// AsStringPrecision sets the optional precision attribute to value. 
// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2Attr { +// value: The post-decimal precision to use for floating point numbers. +// Only used if precision > -1. +// If not specified, defaults to -1 +func AsStringPrecision(value int64) AsStringAttr { return func(m optionalAttr) { - m["seed2"] = value + m["precision"] = value } } -// SampleDistortedBoundingBoxV2AspectRatioRange sets the optional aspect_ratio_range attribute to value. +// AsStringScientific sets the optional scientific attribute to value. // -// value: The cropped area of the image must have an aspect ratio = -// width / height within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { +// value: Use scientific notation for floating point numbers. +// If not specified, defaults to false +func AsStringScientific(value bool) AsStringAttr { return func(m optionalAttr) { - m["aspect_ratio_range"] = value + m["scientific"] = value } } -// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value. +// AsStringShortest sets the optional shortest attribute to value. // -// value: The cropped area of the image must contain a fraction of the -// supplied image within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { +// value: Use shortest representation (either scientific or standard) for +// floating point numbers. +// If not specified, defaults to false +func AsStringShortest(value bool) AsStringAttr { return func(m optionalAttr) { - m["area_range"] = value + m["shortest"] = value } } -// SampleDistortedBoundingBoxV2MaxAttempts sets the optional max_attempts attribute to value. +// AsStringWidth sets the optional width attribute to value. 
// -// value: Number of attempts at generating a cropped region of the image -// of the specified constraints. After `max_attempts` failures, return the entire -// image. -// If not specified, defaults to 100 -func SampleDistortedBoundingBoxV2MaxAttempts(value int64) SampleDistortedBoundingBoxV2Attr { +// value: Pad pre-decimal numbers to this width. +// Applies to both floating point and integer numbers. +// Only used if width > -1. +// If not specified, defaults to -1 +func AsStringWidth(value int64) AsStringAttr { return func(m optionalAttr) { - m["max_attempts"] = value + m["width"] = value } } -// SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// AsStringFill sets the optional fill attribute to value. // -// value: Controls behavior if no bounding boxes supplied. -// If true, assume an implicit bounding box covering the whole input. If false, -// raise an error. -// If not specified, defaults to false -func SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxV2Attr { +// value: The value to pad if width > -1. If empty, pads with spaces. +// Another typical value is '0'. String cannot be longer than 1 character. +// If not specified, defaults to "" +func AsStringFill(value string) AsStringAttr { return func(m optionalAttr) { - m["use_image_if_no_bounding_boxes"] = value + m["fill"] = value } } -// Generate a single randomly distorted bounding box for an image. -// -// Bounding box annotations are often supplied in addition to ground-truth labels -// in image recognition or object localization tasks. A common technique for -// training such a system is to randomly distort an image while preserving -// its content, i.e. *data augmentation*. This Op outputs a randomly distorted -// localization of an object, i.e. bounding box, given an `image_size`, -// `bounding_boxes` and a series of constraints. 
-// -// The output of this Op is a single bounding box that may be used to crop the -// original image. The output is returned as 3 tensors: `begin`, `size` and -// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the -// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize -// what the bounding box looks like. -// -// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. -// -// For example, -// -// ```python -// # Generate a single distorted bounding box. -// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( -// tf.shape(image), -// bounding_boxes=bounding_boxes) -// -// # Draw the bounding box in an image summary. -// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), -// bbox_for_draw) -// tf.summary.image('images_with_box', image_with_box) -// -// # Employ the bounding box to distort the image. -// distorted_image = tf.slice(image, begin, size) -// ``` -// -// Note that if no bounding box information is available, setting -// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit -// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is -// false and no bounding boxes are supplied, an error is raised. -// -// Arguments: -// image_size: 1-D, containing `[height, width, channels]`. -// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes -// associated with the image. -// min_object_covered: The cropped area of the image must contain at least this -// fraction of any bounding box supplied. The value of this parameter should be -// non-negative. In the case of 0, the cropped area does not need to overlap -// any of the bounding boxes supplied. +// Converts each entry in the given tensor to strings. 
Supports many numeric // -// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to -// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to -// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. -// Provide as input to `tf.image.draw_bounding_boxes`. -func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, min_object_covered tf.Output, optional ...SampleDistortedBoundingBoxV2Attr) (begin tf.Output, size tf.Output, bboxes tf.Output) { +// types and boolean. +func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -31389,258 +35277,168 @@ func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_b a(attrs) } opspec := tf.OpSpec{ - Type: "SampleDistortedBoundingBoxV2", + Type: "AsString", Input: []tf.Input{ - image_size, bounding_boxes, min_object_covered, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// ExtractGlimpseAttr is an optional argument to ExtractGlimpse. -type ExtractGlimpseAttr func(optionalAttr) - -// ExtractGlimpseCentered sets the optional centered attribute to value. -// -// value: indicates if the offset coordinates are centered relative to -// the image, in which case the (0, 0) offset is relative to the center -// of the input images. If false, the (0,0) offset corresponds to the -// upper left corner of the input images. -// If not specified, defaults to true -func ExtractGlimpseCentered(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["centered"] = value - } -} - -// ExtractGlimpseNormalized sets the optional normalized attribute to value. -// -// value: indicates if the offset coordinates are normalized. 
-// If not specified, defaults to true -func ExtractGlimpseNormalized(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["normalized"] = value - } -} - -// ExtractGlimpseUniformNoise sets the optional uniform_noise attribute to value. -// -// value: indicates if the noise should be generated using a -// uniform distribution or a Gaussian distribution. -// If not specified, defaults to true -func ExtractGlimpseUniformNoise(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["uniform_noise"] = value - } + return op.Output(0) } -// Extracts a glimpse from the input tensor. +// Returns a `RaggedTensor` containing the specified sequences of numbers. // -// Returns a set of windows called glimpses extracted at location -// `offsets` from the input tensor. If the windows only partially -// overlaps the inputs, the non overlapping areas will be filled with -// random noise. // -// The result is a 4-D tensor of shape `[batch_size, glimpse_height, -// glimpse_width, channels]`. The channels and batch dimensions are the -// same as that of the input tensor. The height and width of the output -// windows are specified in the `size` parameter. +// Returns a `RaggedTensor` `result` composed from `rt_dense_values` and +// `rt_nested_splits`, such that +// `result[i] = range(starts[i], limits[i], deltas[i])`. // -// The argument `normalized` and `centered` controls how the windows are built: +// ```python +// >>> (rt_nested_splits, rt_dense_values) = gen_ragged_ops.ragged_range( +// ... starts=[2, 5, 8], limits=[3, 5, 12], deltas=1) +// >>> result = ragged.from_nested_row_splits(rt_dense_values, rt_nested_splits) +// >>> print result.eval().tolist() +// [[2], # result[0] = range(2, 3) +// [], # result[1] = range(5, 5) +// [8, 9, 10, 11]] # result[2] = range(8, 12) +// ``` // -// * If the coordinates are normalized but not centered, 0.0 and 1.0 -// correspond to the minimum and maximum of each height and width -// dimension. 
-// * If the coordinates are both normalized and centered, they range from -// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper -// left corner, the lower right corner is located at (1.0, 1.0) and the -// center is at (0, 0). -// * If the coordinates are not normalized they are interpreted as -// numbers of pixels. +// The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors. +// The vector inputs must all have the same size. Scalar inputs are broadcast +// to match the size of the vector inputs. // // Arguments: -// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. -// size: A 1-D tensor of 2 elements containing the size of the glimpses -// to extract. The glimpse height must be specified first, following -// by the glimpse width. -// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing -// the y, x locations of the center of each window. +// starts: The starts of each range. +// limits: The limits of each range. +// deltas: The deltas of each range. // -// Returns A tensor representing the glimpses `[batch_size, -// glimpse_height, glimpse_width, channels]`. -func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseAttr) (glimpse tf.Output) { +// Returns The `row_splits` for the returned `RaggedTensor`.The `flat_values` for the returned `RaggedTensor`. +func RaggedRange(scope *Scope, starts tf.Output, limits tf.Output, deltas tf.Output) (rt_nested_splits tf.Output, rt_dense_values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ExtractGlimpse", + Type: "RaggedRange", Input: []tf.Input{ - input, size, offsets, + starts, limits, deltas, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// A container for an iterator resource. 
+// Deprecated, use python implementation tf.linalg.matrix_exponential. // -// Returns A handle to the iterator that can be passed to a "MakeIterator" -// or "IteratorGetNext" op. -func Iterator(scope *Scope, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead. +func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Iterator", - - Attrs: attrs, + Type: "MatrixExponential", + Input: []tf.Input{ + input, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorForestTreeResourceHandleOpAttr is an optional argument to TensorForestTreeResourceHandleOp. -type TensorForestTreeResourceHandleOpAttr func(optionalAttr) - -// TensorForestTreeResourceHandleOpContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func TensorForestTreeResourceHandleOpContainer(value string) TensorForestTreeResourceHandleOpAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// TensorForestTreeResourceHandleOpSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func TensorForestTreeResourceHandleOpSharedName(value string) TensorForestTreeResourceHandleOpAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a handle to a TensorForestTreeResource -func TensorForestTreeResourceHandleOp(scope *Scope, optional ...TensorForestTreeResourceHandleOpAttr) (resource tf.Output) { +// Computes the Cholesky decomposition of one or more square matrices. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. 
+// +// The input has to be symmetric and positive definite. Only the lower-triangular +// part of the input will be used for this operation. The upper-triangular part +// will not be read. +// +// The output is a tensor of the same shape as the input +// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. +// +// **Note**: The gradient computation on GPU is faster for large matrices but +// not for large batch dimensions when the submatrices are small. In this +// case it might be faster to use the CPU. +// +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M, M]`. +func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "TensorForestTreeResourceHandleOp", - - Attrs: attrs, + Type: "Cholesky", + Input: []tf.Input{ + input, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. -type CropAndResizeGradImageAttr func(optionalAttr) - -// CropAndResizeGradImageMethod sets the optional method attribute to value. +// Writes contents to the file at input filename. Creates file and recursively // -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// Computes the gradient of the crop_and_resize op wrt the input image tensor. +// creates directory if not existing. // // Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. 
The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` -// containing the original image size. Both `image_height` and `image_width` need -// to be positive. -// +// filename: scalar. The name of the file to which we write the contents. +// contents: scalar. The content to be written to the output file. // -// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { +// Returns the created operation. 
+func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"T": T} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "CropAndResizeGradImage", + Type: "WriteFile", Input: []tf.Input{ - grads, boxes, box_ind, image_size, + filename, contents, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ShuffleDatasetAttr is an optional argument to ShuffleDataset. -type ShuffleDatasetAttr func(optionalAttr) +// AllAttr is an optional argument to All. +type AllAttr func(optionalAttr) -// ShuffleDatasetReshuffleEachIteration sets the optional reshuffle_each_iteration attribute to value. +// AllKeepDims sets the optional keep_dims attribute to value. // -// value: If true, each iterator over this dataset will be given -// a different pseudorandomly generated seed, based on a sequence seeded by the -// `seed` and `seed2` inputs. If false, each iterator will be given the same -// seed, and repeated iteration over this dataset will yield the exact same -// sequence of results. -// If not specified, defaults to true -func ShuffleDatasetReshuffleEachIteration(value bool) ShuffleDatasetAttr { +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AllKeepDims(value bool) AllAttr { return func(m optionalAttr) { - m["reshuffle_each_iteration"] = value + m["keep_dims"] = value } } -// Creates a dataset that shuffles elements from `input_dataset` pseudorandomly. -// -// Arguments: +// Computes the "logical and" of elements across dimensions of a tensor. // -// buffer_size: The number of output elements to buffer in an iterator over -// this dataset. Compare with the `min_after_dequeue` attr when creating a -// `RandomShuffleQueue`. -// seed: A scalar seed for the random number generator. 
If either `seed` or -// `seed2` is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShuffleDatasetAttr) (handle tf.Output) { +// Returns The reduced tensor. +func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ShuffleDataset", + Type: "All", Input: []tf.Input{ - input_dataset, buffer_size, seed, seed2, + input, axis, }, Attrs: attrs, } @@ -31648,26 +35446,28 @@ func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output return op.Output(0) } -// 3D fast Fourier transform. +// Computes the Eigen Decomposition of a batch of square self-adjoint matrices. // -// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 -// dimensions of `input`. +// DEPRECATED at GraphDef version 11: Use SelfAdjointEigV2 instead. // -// Arguments: -// input: A complex64 tensor. +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices, with the same constraints as the single matrix +// SelfAdjointEig. 
// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their 3D Fourier transform. +// The result is a [..., M+1, M] matrix with [..., 0,:] containing the +// eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. The eigenvalues +// are sorted in non-decreasing order. // -// @compatibility(numpy) -// Equivalent to np.fft.fftn with 3 dimensions. -// @end_compatibility -func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M+1, M]`. +func SelfAdjointEig(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FFT3D", + Type: "SelfAdjointEig", Input: []tf.Input{ input, }, @@ -31676,155 +35476,89 @@ func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. -type CropAndResizeGradBoxesAttr func(optionalAttr) - -// CropAndResizeGradBoxesMethod sets the optional method attribute to value. -// -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. +// Computes softplus gradients for a softplus operation. // // Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -// Both `image_height` and `image_width` need to be positive. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. 
The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// gradients: The backpropagated gradients to the corresponding softplus operation. +// features: The features passed as input to the corresponding softplus operation. // -// Returns A 2-D tensor of shape `[num_boxes, 4]`. -func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { +// Returns The gradients: `gradients / (1 + exp(-features))`. +func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "CropAndResizeGradBoxes", + Type: "SoftplusGrad", Input: []tf.Input{ - grads, image, boxes, box_ind, + gradients, features, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Greedily selects a subset of bounding boxes in descending order of score, +// Solves tridiagonal systems of equations. 
// -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes with score less than -// `score_threshold` are removed. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system and more -// generally is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// selected_indices = tf.image.non_max_suppression_v2( -// boxes, scores, max_output_size, iou_threshold, score_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) +// `diagonals` is a tensor of shape `[..., 3, M]` whose inner-most 2 dimensions +// represent matrices with three rows being the superdiagonal, diagonals, and +// subdiagonals, in order. The last element of the superdiagonal and the first +// element of the subdiagonal is ignored. +// `rhs` is a tensor of shape `[..., M, K]`, representing K right-hand sides per +// each left-hand side. +// The output is a tensor of shape `[..., M, K]` containing the solutions. // // Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). 
-// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// iou_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too much with respect to IOU. -// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove -// boxes based on score. +// diagonals: Shape is `[..., 3, M]`. +// rhs: Shape is `[..., M, K]`. // -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppressionV3(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { +// Returns Shape is `[..., M, K]`. +func TridiagonalSolve(scope *Scope, diagonals tf.Output, rhs tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "NonMaxSuppressionV3", + Type: "TridiagonalSolve", Input: []tf.Input{ - boxes, scores, max_output_size, iou_threshold, score_threshold, + diagonals, rhs, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// NonMaxSuppressionV4Attr is an optional argument to NonMaxSuppressionV4. -type NonMaxSuppressionV4Attr func(optionalAttr) +// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. +type SelfAdjointEigV2Attr func(optionalAttr) -// NonMaxSuppressionV4PadToMaxOutputSize sets the optional pad_to_max_output_size attribute to value. +// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. // -// value: If true, the output `selected_indices` is padded to be of length -// `max_output_size`. Defaults to false. -// If not specified, defaults to false -func NonMaxSuppressionV4PadToMaxOutputSize(value bool) NonMaxSuppressionV4Attr { +// value: If `True` then eigenvectors will be computed and returned in `v`. +// Otherwise, only the eigenvalues will be computed. 
+// If not specified, defaults to true +func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { return func(m optionalAttr) { - m["pad_to_max_output_size"] = value + m["compute_v"] = value } } -// Greedily selects a subset of bounding boxes in descending order of score, +// Computes the eigen decomposition of one or more square self-adjoint matrices. // -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes with score less than -// `score_threshold` are removed. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system and more -// generally is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// selected_indices = tf.image.non_max_suppression_v2( -// boxes, scores, max_output_size, iou_threshold, score_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) +// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in +// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues +// are sorted in non-decreasing order. +// +// ```python +// # a is a tensor. +// # e is a tensor of eigenvalues. +// # v is a tensor of eigenvectors. 
+// e, v = self_adjoint_eig(a) +// e = self_adjoint_eig(a, compute_v=False) +// ``` // // Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// iou_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too much with respect to IOU. -// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove -// boxes based on score. +// input: `Tensor` input of shape `[N, N]`. // -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`.A 0-D integer tensor representing the number of valid elements in -// `selected_indices`, with the valid elements appearing first. -func NonMaxSuppressionV4(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output, optional ...NonMaxSuppressionV4Attr) (selected_indices tf.Output, valid_outputs tf.Output) { +// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. 
+func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { if scope.Err() != nil { return } @@ -31833,9 +35567,9 @@ func NonMaxSuppressionV4(scope *Scope, boxes tf.Output, scores tf.Output, max_ou a(attrs) } opspec := tf.OpSpec{ - Type: "NonMaxSuppressionV4", + Type: "SelfAdjointEigV2", Input: []tf.Input{ - boxes, scores, max_output_size, iou_threshold, score_threshold, + input, }, Attrs: attrs, } @@ -31843,127 +35577,183 @@ func NonMaxSuppressionV4(scope *Scope, boxes tf.Output, scores tf.Output, max_ou return op.Output(0), op.Output(1) } -// Computes the matrix logarithm of one or more square matrices: -// -// -// \\(log(exp(A)) = A\\) -// -// This op is only defined for complex matrices. If A is positive-definite and -// real, then casting to a complex matrix, taking the logarithm and casting back -// to a real matrix will give the correct result. +// Adjust the saturation of one or more images. // -// This function computes the matrix logarithm using the Schur-Parlett algorithm. -// Details of the algorithm can be found in Section 11.6.2 of: -// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008. -// ISBN 978-0-898716-46-7. +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. // -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor of the same shape as the input -// containing the exponential for all input submatrices `[..., :, :]`. +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A scale is then applied all the saturation +// values, and then remapped back to RGB colorspace. // // Arguments: -// input: Shape is `[..., M, M]`. -// -// Returns Shape is `[..., M, M]`. +// images: Images to adjust. At least 3-D. +// scale: A float scale to add to the saturation. 
// -// @compatibility(scipy) -// Equivalent to scipy.linalg.logm -// @end_compatibility -func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) { +// Returns The hue-adjusted image or images. +func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixLogarithm", + Type: "AdjustSaturation", Input: []tf.Input{ - input, + images, scale, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// This op is used as a placeholder in If branch functions. It doesn't provide a -// valid output when run, so must either be removed (e.g. replaced with a -// function input) or guaranteed not to be used (e.g. if mirroring an -// intermediate output needed for the gradient computation of the other branch). +// LuAttr is an optional argument to Lu. +type LuAttr func(optionalAttr) + +// LuOutputIdxType sets the optional output_idx_type attribute to value. +// If not specified, defaults to DT_INT32 +func LuOutputIdxType(value tf.DataType) LuAttr { + return func(m optionalAttr) { + m["output_idx_type"] = value + } +} + +// Computes the LU decomposition of one or more square matrices. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. +// +// The input has to be invertible. +// +// The output consists of two tensors LU and P containing the LU decomposition +// of all input submatrices `[..., :, :]`. LU encodes the lower triangular and +// upper triangular factors. +// +// For each input submatrix of shape `[M, M]`, L is a lower triangular matrix of +// shape `[M, M]` with unit diagonal whose entries correspond to the strictly lower +// triangular part of LU. U is a upper triangular matrix of shape `[M, M]` whose +// entries correspond to the upper triangular part, including the diagonal, of LU. +// +// P represents a permutation matrix encoded as a list of indices each between `0` +// and `M-1`, inclusive. 
If P_mat denotes the permutation matrix corresponding to +// P, then the L, U and P satisfies P_mat * input = L * U. // // Arguments: -// dtype: The type of the output. -// shape: The purported shape of the output. This is only used for shape inference; -// the output will not necessarily have this shape. Can be a partial shape. +// input: A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form matrices of +// size `[M, M]`. // -// Returns \"Fake\" output value. This should not be consumed by another op. -func FakeParam(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { +// Returns A tensor of shape `[..., M, M]` whose strictly lower triangular part denotes the +// lower triangular factor `L` with unit diagonal, and whose upper triangular part +// denotes the upper triangular factor `U`.Permutation of the rows encoded as a list of indices in `0..M-1`. Shape is +// `[..., M]`. +// @compatibility(scipy) +// Similar to `scipy.linalg.lu`, except the triangular factors `L` and `U` are +// packed into a single tensor, the permutation is applied to `input` instead of +// the right hand side and the permutation `P` is returned as a list of indices +// instead of a permutation matrix. +// @end_compatibility +func Lu(scope *Scope, input tf.Output, optional ...LuAttr) (lu tf.Output, p tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "FakeParam", - + Type: "Lu", + Input: []tf.Input{ + input, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Returns the next representable value of `x1` in the direction of `x2`, element-wise. -// -// This operation returns the same result as the C++ std::nextafter function. +// Deprecated. Use TensorArrayCloseV3 // -// It can also return a subnormal number. 
+// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 // -// @compatibility(cpp) -// Equivalent to C++ std::nextafter function. -// @end_compatibility -func NextAfter(scope *Scope, x1 tf.Output, x2 tf.Output) (output tf.Output) { +// Returns the created operation. +func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "NextAfter", + Type: "TensorArrayCloseV2", Input: []tf.Input{ - x1, x2, + handle, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Computes the gradient for the inverse of `x` wrt its input. +// EncodeBase64Attr is an optional argument to EncodeBase64. +type EncodeBase64Attr func(optionalAttr) + +// EncodeBase64Pad sets the optional pad attribute to value. // -// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` -// is the corresponding input gradient. -func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { +// value: Bool whether padding is applied at the ends. +// If not specified, defaults to false +func EncodeBase64Pad(value bool) EncodeBase64Attr { + return func(m optionalAttr) { + m["pad"] = value + } +} + +// Encode strings into web-safe base64 format. +// +// Refer to the following article for more information on base64 format: +// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the +// end so that the encoded has length multiple of 4. See Padding section of the +// link above. +// +// Web-safe means that the encoder uses - and _ instead of + and /. +// +// Arguments: +// input: Strings to be encoded. +// +// Returns Input strings encoded in base64. 
+func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "InvGrad", + Type: "EncodeBase64", Input: []tf.Input{ - y, dy, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// List of the given size with empty elements. +// A dataset that creates window datasets from the input dataset. // -// element_shape: the shape of the future elements of the list -// num_elements: the number of elements to reserve -// handle: the output list -// element_dtype: the desired type of elements in the list. -func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { +// Arguments: +// +// size: A scalar representing the number of elements to accumulate in a window. +// shift: A scalar representing the steps moving the sliding window forward in one +// iteration. It must be positive. +// stride: A scalar representing the stride of the input elements of the sliding window. +// It must be positive. +// drop_remainder: A scalar representing whether a window should be dropped in case its size is +// smaller than desired. 
+// +// +func WindowDataset(scope *Scope, input_dataset tf.Output, size tf.Output, shift tf.Output, stride tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorListReserve", + Type: "WindowDataset", Input: []tf.Input{ - element_shape, num_elements, + input_dataset, size, shift, stride, drop_remainder, }, Attrs: attrs, } @@ -31971,73 +35761,97 @@ func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Ou return op.Output(0) } -// A substitute for `InterleaveDataset` on a fixed list of `N` datasets. +// Computes the matrix square root of one or more square matrices: +// +// matmul(sqrtm(A), sqrtm(A)) = A +// +// The input matrix should be invertible. If the input matrix is real, it should +// have no eigenvalues which are real and negative (pairs of complex conjugate +// eigenvalues are allowed). +// +// The matrix square root is computed by first reducing the matrix to +// quasi-triangular form with the real Schur decomposition. The square root +// of the quasi-triangular matrix is then computed directly. Details of +// the algorithm can be found in: Nicholas J. Higham, "Computing real +// square roots of a real matrix", Linear Algebra Appl., 1987. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor of the same shape as the input +// containing the matrix square root for all input submatrices `[..., :, :]`. // // Arguments: -// selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines which of the -// `N` data inputs should produce the next output element. 
-// data_input_datasets: `N` datasets with the same type that will be interleaved according to -// the values of `selector_input_dataset`. +// input: Shape is `[..., M, M]`. // +// Returns Shape is `[..., M, M]`. // -func ExperimentalDirectedInterleaveDataset(scope *Scope, selector_input_dataset tf.Output, data_input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// @compatibility(scipy) +// Equivalent to scipy.linalg.sqrtm +// @end_compatibility +func MatrixSquareRoot(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ExperimentalDirectedInterleaveDataset", + Type: "MatrixSquareRoot", Input: []tf.Input{ - selector_input_dataset, tf.OutputList(data_input_datasets), + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// RandomUniformIntAttr is an optional argument to RandomUniformInt. -type RandomUniformIntAttr func(optionalAttr) +// SvdAttr is an optional argument to Svd. +type SvdAttr func(optionalAttr) -// RandomUniformIntSeed sets the optional seed attribute to value. +// SvdComputeUv sets the optional compute_uv attribute to value. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomUniformIntSeed(value int64) RandomUniformIntAttr { +// value: If true, left and right singular vectors will be +// computed and returned in `u` and `v`, respectively. +// If false, `u` and `v` are not set and should never referenced. +// If not specified, defaults to true +func SvdComputeUv(value bool) SvdAttr { return func(m optionalAttr) { - m["seed"] = value + m["compute_uv"] = value } } -// RandomUniformIntSeed2 sets the optional seed2 attribute to value. 
+// SvdFullMatrices sets the optional full_matrices attribute to value. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomUniformIntSeed2(value int64) RandomUniformIntAttr { +// value: If true, compute full-sized `u` and `v`. If false +// (the default), compute only the leading `P` singular vectors. +// Ignored if `compute_uv` is `False`. +// If not specified, defaults to false +func SvdFullMatrices(value bool) SvdAttr { return func(m optionalAttr) { - m["seed2"] = value + m["full_matrices"] = value } } -// Outputs random integers from a uniform distribution. +// Computes the singular value decompositions of one or more matrices. // -// The generated values are uniform integers in the range `[minval, maxval)`. -// The lower bound `minval` is included in the range, while the upper bound -// `maxval` is excluded. +// Computes the SVD of each inner matrix in `input` such that +// `input[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :, :])` // -// The random integers are slightly biased unless `maxval - minval` is an exact -// power of two. The bias is small for values of `maxval - minval` significantly -// smaller than the range of the output (either `2^32` or `2^64`). +// ```python +// # a is a tensor containing a batch of matrices. +// # s is a tensor of singular values for each matrix. +// # u is the tensor containing of left singular vectors for each matrix. +// # v is the tensor containing of right singular vectors for each matrix. +// s, u, v = svd(a) +// s, _, _ = svd(a, compute_uv=False) +// ``` // // Arguments: -// shape: The shape of the output tensor. -// minval: 0-D. Inclusive lower bound on the generated integers. -// maxval: 0-D. Exclusive upper bound on the generated integers. +// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. 
// -// Returns A tensor of the specified shape filled with uniform random integers. -func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) { +// Returns Singular values. Shape is `[..., P]`.Left singular vectors. If `full_matrices` is `False` then shape is +// `[..., M, P]`; if `full_matrices` is `True` then shape is +// `[..., M, M]`. Undefined if `compute_uv` is `False`.Left singular vectors. If `full_matrices` is `False` then shape is +// `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`. +// Undefined if `compute_uv` is false. +func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.Output, v tf.Output) { if scope.Err() != nil { return } @@ -32046,160 +35860,194 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf a(attrs) } opspec := tf.OpSpec{ - Type: "RandomUniformInt", + Type: "Svd", Input: []tf.Input{ - shape, minval, maxval, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Add the quantile summaries to each quantile stream resource. +// Converts one or more images from RGB to HSV. // -// An op that adds a list of quantile summaries to a quantile stream resource. Each -// summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank) -// for a single feature. +// Outputs a tensor of the same shape as the `images` tensor, containing the HSV +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and +// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 +// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. // // Arguments: -// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. 
-// summaries: string; List of Rank 2 Tensor each containing the summaries for a single feature. +// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. // -// Returns the created operation. -func BoostedTreesQuantileStreamResourceAddSummaries(scope *Scope, quantile_stream_resource_handle tf.Output, summaries []tf.Output) (o *tf.Operation) { +// Returns `images` converted to HSV. +func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BoostedTreesQuantileStreamResourceAddSummaries", + Type: "RGBToHSV", Input: []tf.Input{ - quantile_stream_resource_handle, tf.OutputList(summaries), + images, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Creates a Dataset that returns pseudorandom numbers. -// -// Arguments: -// seed: A scalar seed for the random number generator. If either seed or -// seed2 is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. -// +// Does nothing. Only useful as a placeholder for control edges. // -func ExperimentalRandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns the created operation. +func NoOp(scope *Scope) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ExperimentalRandomDataset", - Input: []tf.Input{ - seed, seed2, - }, - Attrs: attrs, + Type: "NoOp", } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// A dataset that splits the elements of its input into multiple elements. 
-func ExperimentalUnbatchDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints. +type MergeV2CheckpointsAttr func(optionalAttr) + +// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value. +// +// value: see above. +// If not specified, defaults to true +func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr { + return func(m optionalAttr) { + m["delete_old_dirs"] = value + } +} + +// V2 format specific: merges the metadata files of sharded checkpoints. The +// +// result is one logical checkpoint, with one physical metadata file and renamed +// data files. +// +// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup. +// +// If delete_old_dirs is true, attempts to delete recursively the dirname of each +// path in the input checkpoint_prefixes. This is useful when those paths are non +// user-facing temporary locations. +// +// Arguments: +// checkpoint_prefixes: prefixes of V2 checkpoints to merge. +// destination_prefix: scalar. The desired final prefix. Allowed to be the same +// as one of the checkpoint_prefixes. +// +// Returns the created operation. 
+func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ExperimentalUnbatchDataset", + Type: "MergeV2Checkpoints", Input: []tf.Input{ - input_dataset, + checkpoint_prefixes, destination_prefix, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Creates a dataset that overrides the maximum intra-op parallelism. +// Saves input tensors slices to disk. // -// Arguments: +// This is like `Save` except that tensors can be listed in the saved file as being +// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the +// larger tensor and the slice that this tensor covers. `shapes_and_slices` must +// have as many elements as `tensor_names`. // -// max_intra_op_parallelism: Identifies the maximum intra-op parallelism to use. +// Elements of the `shapes_and_slices` input must either be: +// +// * The empty string, in which case the corresponding tensor is +// saved normally. +// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the +// `dimI` are the dimensions of the larger tensor and `slice-spec` +// specifies what part is covered by the tensor to save. // +// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` +// where each `sliceI` is either: // -func ExperimentalMaxIntraOpParallelismDataset(scope *Scope, input_dataset tf.Output, max_intra_op_parallelism tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// * The string `-` meaning that the slice covers all indices of this dimension +// * `start,length` where `start` and `length` are integers. 
In that +// case the slice covers `length` indices starting at `start`. +// +// See also `Save`. +// +// Arguments: +// filename: Must have a single element. The name of the file to which we write the +// tensor. +// tensor_names: Shape `[N]`. The names of the tensors to be saved. +// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when +// saving the tensors. +// data: `N` tensors to save. +// +// Returns the created operation. +func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ExperimentalMaxIntraOpParallelismDataset", + Type: "SaveSlices", Input: []tf.Input{ - input_dataset, max_intra_op_parallelism, + filename, tensor_names, shapes_and_slices, tf.OutputList(data), }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// StringSplitV2Attr is an optional argument to StringSplitV2. -type StringSplitV2Attr func(optionalAttr) +// DenseToDenseSetOperationAttr is an optional argument to DenseToDenseSetOperation. +type DenseToDenseSetOperationAttr func(optionalAttr) -// StringSplitV2Maxsplit sets the optional maxsplit attribute to value. -// -// value: An `int`. If `maxsplit > 0`, limit of the split of the result. -// If not specified, defaults to -1 -func StringSplitV2Maxsplit(value int64) StringSplitV2Attr { +// DenseToDenseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func DenseToDenseSetOperationValidateIndices(value bool) DenseToDenseSetOperationAttr { return func(m optionalAttr) { - m["maxsplit"] = value + m["validate_indices"] = value } } -// Split elements of `source` based on `sep` into a `SparseTensor`. 
+// Applies set operation along last dimension of 2 `Tensor` inputs. // -// Let N be the size of source (typically N will be the batch size). Split each -// element of `source` based on `sep` and return a `SparseTensor` -// containing the split tokens. Empty tokens are ignored. +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. // -// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', -// then the output will be -// ``` -// st.indices = [0, 0; -// 0, 1; -// 1, 0; -// 1, 1; -// 1, 2] -// st.shape = [2, 3] -// st.values = ['hello', 'world', 'a', 'b', 'c'] -// ``` +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. // -// If `sep` is given, consecutive delimiters are not grouped together and are -// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and -// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty -// string, consecutive whitespace are regarded as a single separator, and the -// result will contain no empty strings at the startor end if the string has -// leading or trailing whitespace. +// Arguments: +// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. +// set2: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set1`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. // -// Note that the above mentioned behavior matches python's str.split. // -// Arguments: -// input: `1-D` string `Tensor`, the strings to split. -// sep: `0-D` string `Tensor`, the delimiter character. 
-func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) { +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func DenseToDenseSetOperation(scope *Scope, set1 tf.Output, set2 tf.Output, set_operation string, optional ...DenseToDenseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"set_operation": set_operation} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StringSplitV2", + Type: "DenseToDenseSetOperation", Input: []tf.Input{ - input, sep, + set1, set2, }, Attrs: attrs, } @@ -32207,278 +36055,271 @@ func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...Str return op.Output(0), op.Output(1), op.Output(2) } -// Creates a dataset that uses a custom thread pool to compute `input_dataset`. -// -// Arguments: -// -// thread_pool: A resource produced by the ThreadPoolHandle op. -// +// Generate a sharded filename. The filename is printf formatted as // -func ExperimentalThreadPoolDataset(scope *Scope, input_dataset tf.Output, thread_pool tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// %s-%05d-of-%05d, basename, shard, num_shards. 
+func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shards tf.Output) (filename tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ExperimentalThreadPoolDataset", + Type: "ShardedFilename", Input: []tf.Input{ - input_dataset, thread_pool, + basename, shard, num_shards, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes softsign: `features / (abs(features) + 1)`. -func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { +// Generate a glob pattern matching all sharded file names. +func ShardedFilespec(scope *Scope, basename tf.Output, num_shards tf.Output) (filename tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Softsign", + Type: "ShardedFilespec", Input: []tf.Input{ - features, + basename, num_shards, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// EncodeProtoAttr is an optional argument to EncodeProto. -type EncodeProtoAttr func(optionalAttr) +// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. +type TextLineReaderV2Attr func(optionalAttr) -// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value. -// If not specified, defaults to "local://" -func EncodeProtoDescriptorSource(value string) EncodeProtoAttr { +// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. +// +// value: Number of lines to skip from the beginning of every file. +// If not specified, defaults to 0 +func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { return func(m optionalAttr) { - m["descriptor_source"] = value + m["skip_header_lines"] = value } } -// The op serializes protobuf messages provided in the input tensors. -// -// The types of the tensors in `values` must match the schema for the -// fields specified in `field_names`. 
All the tensors in `values` must -// have a common shape prefix, *batch_shape*. -// -// The `sizes` tensor specifies repeat counts for each field. The repeat -// count (last dimension) of a each tensor in `values` must be greater -// than or equal to corresponding repeat count in `sizes`. -// -// A `message_type` name must be provided to give context for the field -// names. The actual message descriptor can be looked up either in the -// linked-in descriptor pool or a filename provided by the caller using -// the `descriptor_source` attribute. -// -// The `descriptor_source` attribute selects a source of protocol -// descriptors to consult when looking up `message_type`. This may be a -// filename containing a serialized `FileDescriptorSet` message, -// or the special value `local://`, in which case only descriptors linked -// into the code will be searched; the filename can be on any filesystem -// accessible to TensorFlow. -// -// You can build a `descriptor_source` file using the `--descriptor_set_out` -// and `--include_imports` options to the protocol compiler `protoc`. -// -// The `local://` database only covers descriptors linked into the -// code via C++ libraries, not Python imports. You can link in a proto descriptor -// by creating a cc_library target with alwayslink=1. -// -// There are a few special cases in the value mapping: -// -// Submessage and group fields must be pre-serialized as TensorFlow strings. -// -// TensorFlow lacks support for unsigned int64s, so they must be -// represented as `tf.int64` with the same twos-complement bit pattern -// (the obvious way). +// TextLineReaderV2Container sets the optional container attribute to value. // -// Unsigned int32 values can be represented exactly with `tf.int64`, or -// with sign wrapping if the input is of type `tf.int32`. +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. 
+// If not specified, defaults to "" +func TextLineReaderV2Container(value string) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// TextLineReaderV2SharedName sets the optional shared_name attribute to value. // -// Arguments: -// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`. -// values: List of tensors containing values for the corresponding field. -// field_names: List of strings containing proto field names. -// message_type: Name of the proto message type to decode. +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A Reader that outputs the lines of a file delimited by '\n'. // -// Returns Tensor of serialized protos with shape `batch_shape`. -func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) { +// Returns The handle to reference the Reader. +func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "EncodeProto", - Input: []tf.Input{ - sizes, tf.OutputList(values), - }, + Type: "TextLineReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a TensorArray for storing the gradients of values in the given handle. +// LoadAndRemapMatrixAttr is an optional argument to LoadAndRemapMatrix. +type LoadAndRemapMatrixAttr func(optionalAttr) + +// LoadAndRemapMatrixMaxRowsInMemory sets the optional max_rows_in_memory attribute to value. 
// -// If the given TensorArray gradient already exists, returns a reference to it. +// value: The maximum number of rows to load from the checkpoint at +// once. If less than or equal to 0, the entire matrix will be loaded into +// memory. Setting this arg trades increased disk reads for lower memory usage. +// If not specified, defaults to -1 +func LoadAndRemapMatrixMaxRowsInMemory(value int64) LoadAndRemapMatrixAttr { + return func(m optionalAttr) { + m["max_rows_in_memory"] = value + } +} + +// Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint // -// Locks the size of the original TensorArray by disabling its dynamic size flag. +// at `ckpt_path` and potentially reorders its rows and columns using the +// specified remappings. // -// **A note about the input flow_in:** +// Most users should use one of the wrapper initializers (such as +// `tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this +// function directly. // -// The handle flow_in forces the execution of the gradient lookup to occur -// only after certain other operations have occurred. For example, when -// the forward TensorArray is dynamically sized, writes to this TensorArray -// may resize the object. The gradient TensorArray is statically sized based -// on the size of the forward TensorArray when this operation executes. -// Furthermore, the size of the forward TensorArray is frozen by this call. -// As a result, the flow is used to ensure that the call to generate the gradient -// TensorArray only happens after all writes are executed. +// The remappings are 1-D tensors with the following properties: // -// In the case of dynamically sized TensorArrays, gradient computation should -// only be performed on read operations that have themselves been chained via -// flow to occur only after all writes have executed. That way the final size -// of the forward TensorArray is known when this operation is called. 
+// * `row_remapping` must have exactly `num_rows` entries. Row `i` of the output +// matrix will be initialized from the row corresponding to index +// `row_remapping[i]` in the old `Tensor` from the checkpoint. +// * `col_remapping` must have either 0 entries (indicating that no column +// reordering is needed) or `num_cols` entries. If specified, column `j` of the +// output matrix will be initialized from the column corresponding to index +// `col_remapping[j]` in the old `Tensor` from the checkpoint. +// * A value of -1 in either of the remappings signifies a "missing" entry. In that +// case, values from the `initializing_values` tensor will be used to fill that +// missing row or column. If `row_remapping` has `r` missing entries and +// `col_remapping` has `c` missing entries, then the following condition must be +// true: // -// **A note about the source attribute:** +// `(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)` // -// TensorArray gradient calls use an accumulator TensorArray object. If -// multiple gradients are calculated and run in the same session, the multiple -// gradient nodes may accidentally flow through the same accumulator TensorArray. -// This double counts and generally breaks the TensorArray gradient flow. +// The remapping tensors can be generated using the GenerateVocabRemapping op. // -// The solution is to identify which gradient call this particular -// TensorArray gradient is being called in. This is performed by identifying -// a unique string (e.g. "gradients", "gradients_1", ...) from the input -// gradient Tensor's name. This string is used as a suffix when creating -// the TensorArray gradient object here (the attribute `source`). 
+// As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1], +// initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing +// the value from row i, column j of the old tensor in the checkpoint, the output +// matrix will look like the following: // -// The attribute `source` is added as a suffix to the forward TensorArray's -// name when performing the creation / lookup, so that each separate gradient -// calculation gets its own TensorArray accumulator. +// [[w(1, 0), w(1, 2), 0.5], +// [w(0, 0), w(0, 2), -0.5], +// [0.25, -0.25, 42]] // // Arguments: -// handle: The handle to the forward TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// source: The gradient source string, used to decide which gradient TensorArray -// to return. -func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) { +// ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from +// which the old matrix `Tensor` will be loaded. +// old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. +// row_remapping: An int `Tensor` of row remappings (generally created by +// `generate_vocab_remapping`). Even if no row remapping is needed, this must +// still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted +// index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`). +// col_remapping: An int `Tensor` of column remappings (generally created by +// `generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping +// is to be done (e.g. column ordering is the same). +// initializing_values: A float `Tensor` containing values to fill in for cells +// in the output matrix that are not loaded from the checkpoint. Length must be +// exactly the same as the number of missing / new cells. +// num_rows: Number of rows (length of the 1st dimension) in the output matrix. 
+// num_cols: Number of columns (length of the 2nd dimension) in the output matrix. +// +// Returns Output matrix containing existing values loaded from the +// checkpoint, and with any missing values filled in from initializing_values. +func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Output, row_remapping tf.Output, col_remapping tf.Output, initializing_values tf.Output, num_rows int64, num_cols int64, optional ...LoadAndRemapMatrixAttr) (output_matrix tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"source": source} + attrs := map[string]interface{}{"num_rows": num_rows, "num_cols": num_cols} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArrayGradV3", + Type: "LoadAndRemapMatrix", Input: []tf.Input{ - handle, flow_in, + ckpt_path, old_tensor_name, row_remapping, col_remapping, initializing_values, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Creates a dataset that splits a SparseTensor into elements row-wise. -func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseTensorSliceDataset", - Input: []tf.Input{ - indices, values, dense_shape, - }, +// TFRecordReaderV2Attr is an optional argument to TFRecordReaderV2. +type TFRecordReaderV2Attr func(optionalAttr) + +// TFRecordReaderV2Container sets the optional container attribute to value. +// +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func TFRecordReaderV2Container(value string) TFRecordReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Returns x / y element-wise for real types. 
-// -// If `x` and `y` are reals, this will return the floating-point division. +// TFRecordReaderV2SharedName sets the optional shared_name attribute to value. // -// *NOTE*: `Div` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func TFRecordReaderV2SharedName(value string) TFRecordReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value } - opspec := tf.OpSpec{ - Type: "RealDiv", - Input: []tf.Input{ - x, y, - }, +} + +// TFRecordReaderV2CompressionType sets the optional compression_type attribute to value. +// If not specified, defaults to "" +func TFRecordReaderV2CompressionType(value string) TFRecordReaderV2Attr { + return func(m optionalAttr) { + m["compression_type"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Creates a dataset that concatenates `input_dataset` with `another_dataset`. -func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// A Reader that outputs the records from a TensorFlow Records file. +// +// Returns The handle to reference the Reader. 
+func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ConcatenateDataset", - Input: []tf.Input{ - input_dataset, another_dataset, - }, + Type: "TFRecordReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. -// -// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the -// `filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each -// input channel is processed independently of the others with its own structuring -// function. The `output` tensor has shape -// `[batch, out_height, out_width, depth]`. The spatial dimensions of the output -// tensor depend on the `padding` algorithm. We currently only support the default -// "NHWC" `data_format`. -// -// In detail, the grayscale morphological 2-D dilation is the max-sum correlation -// (for consistency with `conv2d`, we use unmirrored filters): -// -// output[b, y, x, c] = -// max_{dy, dx} input[b, -// strides[1] * y + rates[1] * dy, -// strides[2] * x + rates[2] * dx, -// c] + -// filter[dy, dx, c] -// -// Max-pooling is a special case when the filter has size equal to the pooling -// kernel size and contains all zeros. -// -// Note on duality: The dilation of `input` by the `filter` is equal to the -// negation of the erosion of `-input` by the reflected `filter`. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// strides: The stride of the sliding window for each dimension of the input -// tensor. Must be: `[1, stride_height, stride_width, 1]`. 
-// rates: The input stride for atrous morphological dilation. Must be: -// `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. +// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. +type QuantizeAndDequantizeV3Attr func(optionalAttr) + +// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["signed_input"] = value + } +} + +// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// Quantizes then dequantizes a tensor. // -// Returns 4-D with shape `[batch, out_height, out_width, depth]`. -func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, rates []int64, padding string) (output tf.Output) { +// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a +// tensor, so its value can change during training. 
+func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Dilation2D", + Type: "QuantizeAndDequantizeV3", Input: []tf.Input{ - input, filter, + input, input_min, input_max, num_bits, }, Attrs: attrs, } @@ -32486,87 +36327,38 @@ func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64 return op.Output(0) } -// Converts the given variant tensor to an iterator and stores it in the given resource. -// -// Arguments: -// resource_handle: A handle to an iterator resource. -// serialized: A variant tensor storing the state of the iterator contained in the -// resource. +// IdentityReaderV2Attr is an optional argument to IdentityReaderV2. +type IdentityReaderV2Attr func(optionalAttr) + +// IdentityReaderV2Container sets the optional container attribute to value. // -// Returns the created operation. -func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DeserializeIterator", - Input: []tf.Input{ - resource_handle, serialized, - }, +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func IdentityReaderV2Container(value string) IdentityReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value } - return scope.AddOperation(opspec) } -// ResourceScatterNdSubAttr is an optional argument to ResourceScatterNdSub. -type ResourceScatterNdSubAttr func(optionalAttr) - -// ResourceScatterNdSubUseLocking sets the optional use_locking attribute to value. 
+// IdentityReaderV2SharedName sets the optional shared_name attribute to value. // -// value: An optional bool. Defaults to True. If True, the assignment will -// be protected by a lock; otherwise the behavior is undefined, -// but may exhibit less contention. -// If not specified, defaults to true -func ResourceScatterNdSubUseLocking(value bool) ResourceScatterNdSubAttr { +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func IdentityReaderV2SharedName(value string) IdentityReaderV2Attr { return func(m optionalAttr) { - m["use_locking"] = value + m["shared_name"] = value } } -// Applies sparse subtraction to individual values or slices in a Variable. -// -// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `ref`. -// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. -// -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th -// dimension of `ref`. -// -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: -// -// ``` -// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] -// ``` -// -// For example, say we want to subtract 4 scattered elements from a rank-1 tensor -// with 8 elements. In Python, that subtraction would look like this: -// -// ```python -// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// sub = tf.scatter_nd_sub(ref, indices, updates) -// with tf.Session() as sess: -// print sess.run(sub) -// ``` -// -// The resulting update to ref would look like this: -// -// [1, -9, 3, -6, -4, 6, 7, -4] -// -// See `tf.scatter_nd` for more details about how to make updates to -// slices. 
+// A Reader that outputs the queued work as both the key and value. // -// Arguments: -// ref: A resource handle. Must be from a VarHandleOp. -// indices: A Tensor. Must be one of the following types: int32, int64. -// A tensor of indices into ref. -// updates: A Tensor. Must have the same type as ref. A tensor of -// values to add to ref. +// To use, enqueue strings in a Queue. ReaderRead will take the front +// work string and output (work, work). // -// Returns the created operation. -func ResourceScatterNdSub(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdSubAttr) (o *tf.Operation) { +// Returns The handle to reference the Reader. +func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } @@ -32575,245 +36367,321 @@ func ResourceScatterNdSub(scope *Scope, ref tf.Output, indices tf.Output, update a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceScatterNdSub", - Input: []tf.Input{ - ref, indices, updates, - }, + Type: "IdentityReaderV2", + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// TensorArrayConcatV2Attr is an optional argument to TensorArrayConcatV2. -type TensorArrayConcatV2Attr func(optionalAttr) +// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent. +type ResourceApplyGradientDescentAttr func(optionalAttr) -// TensorArrayConcatV2ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. -// If not specified, defaults to -func TensorArrayConcatV2ElementShapeExcept0(value tf.Shape) TensorArrayConcatV2Attr { +// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. 
+// If not specified, defaults to false +func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr { return func(m optionalAttr) { - m["element_shape_except0"] = value + m["use_locking"] = value } } -// Deprecated. Use TensorArrayConcatV3 -func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV2Attr) (value tf.Output, lengths tf.Output) { +// Update '*var' by subtracting 'alpha' * 'delta' from it. +// +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// delta: The change. +// +// Returns the created operation. +func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayConcatV2", + Type: "ResourceApplyGradientDescent", Input: []tf.Input{ - handle, flow_in, + var_, alpha, delta, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// Creates a dataset that batches and pads `batch_size` elements from the input. +// Returns the next record (key, value pair) produced by a Reader. // -// Arguments: +// Will dequeue from the input queue if necessary (e.g. when the +// Reader needs to start reading from a new file since it has finished +// with the previous file). // -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// padded_shapes: A list of int64 tensors representing the desired padded shapes -// of the corresponding output components. These shapes may be partially -// specified, using `-1` to indicate that a particular dimension should be -// padded to the maximum size of all batch elements. 
-// padding_values: A list of scalars containing the padding value to use for -// each of the outputs. +// Arguments: +// reader_handle: Handle to a Reader. +// queue_handle: Handle to a Queue, with string work items. // -func PaddedBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { +// Returns A scalar.A scalar. +func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "PaddedBatchDataset", + Type: "ReaderReadV2", Input: []tf.Input{ - input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), + reader_handle, queue_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Creates a dataset that shuffles and repeats elements from `input_dataset` +// Returns up to `num_records` (key, value) pairs produced by a Reader. // -// pseudorandomly. +// Will dequeue from the input queue if necessary (e.g. when the +// Reader needs to start reading from a new file since it has finished +// with the previous file). +// It may return less than `num_records` even before the last batch. // // Arguments: +// reader_handle: Handle to a `Reader`. +// queue_handle: Handle to a `Queue`, with string work items. +// num_records: number of records to read from `Reader`. // -// buffer_size: The number of output elements to buffer in an iterator over -// this dataset. Compare with the `min_after_dequeue` attr when creating a -// `RandomShuffleQueue`. -// seed: A scalar seed for the random number generator. If either `seed` or -// `seed2` is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. 
-// seed2: A second scalar seed to avoid seed collision. -// count: A scalar representing the number of times the underlying dataset -// should be repeated. The default is `-1`, which results in infinite repetition. -// -// -func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns A 1-D tensor.A 1-D tensor. +func ReaderReadUpToV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output, num_records tf.Output) (keys tf.Output, values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ShuffleAndRepeatDataset", + Type: "ReaderReadUpToV2", Input: []tf.Input{ - input_dataset, buffer_size, seed, seed2, count, + reader_handle, queue_handle, num_records, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Creates a dataset that caches elements from `input_dataset`. +// Adds v into specified rows of x. // -// A CacheDataset will iterate over the input_dataset, and store tensors. If the -// cache already exists, the cache will be used. If the cache is inappropriate -// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error -// will the returned when used. +// Computes y = x; y[i, :] += v; return y. // // Arguments: +// x: A `Tensor` of type T. +// i: A vector. Indices into the left-most dimension of `x`. +// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. // -// filename: A path on the filesystem where we should cache the dataset. Note: this -// will be a directory. 
-// -// -func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. +func InplaceAdd(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "CacheDataset", + Type: "InplaceAdd", Input: []tf.Input{ - input_dataset, filename, + x, i, v, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that emits the records from one or more binary files. +// Restore a Reader to its initial clean state. // // Arguments: -// filenames: A scalar or a vector containing the name(s) of the file(s) to be -// read. -// header_bytes: A scalar representing the number of bytes to skip at the -// beginning of a file. -// record_bytes: A scalar representing the number of bytes in each record. -// footer_bytes: A scalar representing the number of bytes to skip at the end -// of a file. -// buffer_size: A scalar representing the number of bytes to buffer. Must be > 0. -func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf.Output, record_bytes tf.Output, footer_bytes tf.Output, buffer_size tf.Output) (handle tf.Output) { +// reader_handle: Handle to a Reader. +// +// Returns the created operation. +func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FixedLengthRecordDataset", + Type: "ReaderResetV2", Input: []tf.Input{ - filenames, header_bytes, record_bytes, footer_bytes, buffer_size, + reader_handle, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Gradients for batch normalization. 
+// BatchAttr is an optional argument to Batch. +type BatchAttr func(optionalAttr) + +// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value. +// If not specified, defaults to 10 +func BatchMaxEnqueuedBatches(value int64) BatchAttr { + return func(m optionalAttr) { + m["max_enqueued_batches"] = value + } +} + +// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value. +// If not specified, defaults to <> +func BatchAllowedBatchSizes(value []int64) BatchAttr { + return func(m optionalAttr) { + m["allowed_batch_sizes"] = value + } +} + +// BatchContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func BatchContainer(value string) BatchAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// BatchSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func BatchSharedName(value string) BatchAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// BatchBatchingQueue sets the optional batching_queue attribute to value. +// If not specified, defaults to "" +func BatchBatchingQueue(value string) BatchAttr { + return func(m optionalAttr) { + m["batching_queue"] = value + } +} + +// Batches all input tensors nondeterministically. // -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() +// When many instances of this Op are being run concurrently with the same +// container/shared_name in the same device, some will output zero-shaped Tensors +// and others will output Tensors of size up to max_batch_size. // -// This op is deprecated. See `tf.nn.batch_normalization`. +// All Tensors in in_tensors are batched together (so, for example, labels and +// features should be batched with a single instance of this operation. // -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. 
-// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this Tensor will be multiplied -// with the normalized Tensor. -// backprop: 4D backprop Tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. +// Each invocation of batch emits an `id` scalar which will be used to identify +// this particular invocation when doing unbatch or its gradient. // -// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. -func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { +// Each op which emits a non-empty batch will also emit a non-empty batch_index +// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, +// start, and length of elements of each set of Tensors present in batched_tensors. +// +// Batched tensors are concatenated along the first dimension, and all tensors in +// in_tensors must have the first dimension of the same size. +// +// in_tensors: The tensors to be batched. +// num_batch_threads: Number of scheduling threads for processing batches of work. +// Determines the number of batches processed in parallel. +// max_batch_size: Batch sizes will never be bigger than this. +// batch_timeout_micros: Maximum number of microseconds to wait before outputting +// an incomplete batch. 
+// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does +// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad +// batches up to one of those sizes. The entries must increase monotonically, and +// the final entry must equal max_batch_size. +// grad_timeout_micros: The timeout to use for the gradient. See Unbatch. +// batched_tensors: Either empty tensors or a batch of concatenated Tensors. +// batch_index: If out_tensors is non-empty, has information to invert it. +// container: Controls the scope of sharing of this batch. +// id: always contains a scalar with a unique ID for this invocation of Batch. +// shared_name: Concurrently running instances of batch in the same device with the +// same container and shared_name will batch their elements together. If left +// empty, the op name will be used as the shared name. +// T: the types of tensors to be batched. +func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} + attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalizationGrad", + Type: "Batch", Input: []tf.Input{ - t, m, v, gamma, backprop, + tf.OutputList(in_tensors), }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + if scope.Err() != nil { + return + } + var idx int + var err error + if batched_tensors, idx, err = makeOutputList(op, idx, 
"batched_tensors"); err != nil { + scope.UpdateErr("Batch", err) + return + } + batch_index = op.Output(idx) + id = op.Output(idx) + return batched_tensors, batch_index, id } -// Creates a dataset that emits the records from one or more TFRecord files. +// Adjust the hue of one or more images. +// +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. +// +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A delta is then applied all the hue values, +// and then remapped back to RGB colorspace. // // Arguments: -// filenames: A scalar or vector containing the name(s) of the file(s) to be -// read. -// compression_type: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// buffer_size: A scalar representing the number of bytes to buffer. A value of -// 0 means no buffering will be performed. -func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { +// images: Images to adjust. At least 3-D. +// delta: A float delta to add to the hue. +// +// Returns The hue-adjusted image or images. +func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TFRecordDataset", + Type: "AdjustHue", Input: []tf.Input{ - filenames, compression_type, buffer_size, + images, delta, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ExperimentalStatsAggregatorHandleAttr is an optional argument to ExperimentalStatsAggregatorHandle. -type ExperimentalStatsAggregatorHandleAttr func(optionalAttr) +// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. +type ResizeBicubicGradAttr func(optionalAttr) -// ExperimentalStatsAggregatorHandleContainer sets the optional container attribute to value. 
-// If not specified, defaults to "" -func ExperimentalStatsAggregatorHandleContainer(value string) ExperimentalStatsAggregatorHandleAttr { +// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and grad tensors are +// aligned. Defaults to false. +// If not specified, defaults to false +func ResizeBicubicGradAlignCorners(value bool) ResizeBicubicGradAttr { return func(m optionalAttr) { - m["container"] = value + m["align_corners"] = value } } -// ExperimentalStatsAggregatorHandleSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func ExperimentalStatsAggregatorHandleSharedName(value string) ExperimentalStatsAggregatorHandleAttr { +// ResizeBicubicGradHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeBicubicGradHalfPixelCenters(value bool) ResizeBicubicGradAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["half_pixel_centers"] = value } } -// Creates a statistics manager resource. -func ExperimentalStatsAggregatorHandle(scope *Scope, optional ...ExperimentalStatsAggregatorHandleAttr) (handle tf.Output) { +// Computes the gradient of bicubic interpolation. +// +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. 
+func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBicubicGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -32822,132 +36690,59 @@ func ExperimentalStatsAggregatorHandle(scope *Scope, optional ...ExperimentalSta a(attrs) } opspec := tf.OpSpec{ - Type: "ExperimentalStatsAggregatorHandle", - + Type: "ResizeBicubicGrad", + Input: []tf.Input{ + grads, original_image, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// A container for an iterator resource. +// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. +type ResizeNearestNeighborAttr func(optionalAttr) + +// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. // -// Returns A handle to the iterator that can be passed to a "MakeIterator" or -// "IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents -// resource sharing by name, and does not keep a reference to the resource -// container. -func AnonymousIterator(scope *Scope, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { + return func(m optionalAttr) { + m["align_corners"] = value } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "AnonymousIterator", +} - Attrs: attrs, +// ResizeNearestNeighborHalfPixelCenters sets the optional half_pixel_centers attribute to value. 
+// If not specified, defaults to false +func ResizeNearestNeighborHalfPixelCenters(value bool) ResizeNearestNeighborAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// BatchToSpace for 4-D tensors of type T. -// -// This is a legacy version of the more general BatchToSpaceND. -// -// Rearranges (permutes) data from batch into blocks of spatial data, followed by -// cropping. This is the reverse transformation of SpaceToBatch. More specifically, -// this op outputs a copy of the input tensor where values from the `batch` -// dimension are moved in spatial blocks to the `height` and `width` dimensions, -// followed by cropping along the `height` and `width` dimensions. +// Resize `images` to `size` using nearest neighbor interpolation. // // Arguments: -// input: 4-D tensor with shape -// `[batch*block_size*block_size, height_pad/block_size, width_pad/block_size, -// depth]`. Note that the batch size of the input tensor must be divisible by -// `block_size * block_size`. -// crops: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies -// how many elements to crop from the intermediate result across the spatial -// dimensions as follows: -// -// crops = [[crop_top, crop_bottom], [crop_left, crop_right]] -// -// -// Returns 4-D with shape `[batch, height, width, depth]`, where: -// -// height = height_pad - crop_top - crop_bottom -// width = width_pad - crop_left - crop_right -// -// The attr `block_size` must be greater than one. It indicates the block size. 
-// -// Some examples: -// -// (1) For the following input of shape `[4, 1, 1, 1]` and block_size of 2: -// -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 1]` and value: -// -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` -// -// (2) For the following input of shape `[4, 1, 1, 3]` and block_size of 2: -// -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 3]` and value: -// -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` -// -// (3) For the following input of shape `[4, 2, 2, 1]` and block_size of 2: -// -// ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -// -// The output tensor has shape `[1, 4, 4, 1]` and value: -// -// ``` -// x = [[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]] -// ``` -// -// (4) For the following input of shape `[8, 1, 2, 1]` and block_size of 2: -// -// ``` -// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], -// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] -// ``` -// -// The output tensor has shape `[2, 2, 4, 1]` and value: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// ``` -// x = [[[[1], [3]], [[5], [7]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int64) (output tf.Output) { +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. 
+func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"block_size": block_size} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BatchToSpace", + Type: "ResizeNearestNeighbor", Input: []tf.Input{ - input, crops, + images, size, }, Attrs: attrs, } @@ -32955,186 +36750,141 @@ func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int return op.Output(0) } -// Produces a summary of any statistics recorded by the given statistics manager. -func ExperimentalStatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ExperimentalStatsAggregatorSummary", - Input: []tf.Input{ - iterator, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. +type ResizeNearestNeighborGradAttr func(optionalAttr) -// Makes a new iterator from the given `dataset` and stores it in `iterator`. -// -// This operation may be executed multiple times. Each execution will reset the -// iterator in `iterator` to the first element of `dataset`. +// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. // -// Returns the created operation. -func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return +// value: If true, the centers of the 4 corner pixels of the input and grad tensors are +// aligned. Defaults to false. 
+// If not specified, defaults to false +func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value } - opspec := tf.OpSpec{ - Type: "MakeIterator", - Input: []tf.Input{ - dataset, iterator, - }, +} + +// ResizeNearestNeighborGradHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeNearestNeighborGradHalfPixelCenters(value bool) ResizeNearestNeighborGradAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value } - return scope.AddOperation(opspec) } -// Adjust the contrast of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are -// interpreted as `[height, width, channels]`. The other dimensions only -// represent a collection of images, such as `[batch, height, width, channels].` -// -// Contrast is adjusted independently for each channel of each image. -// -// For each channel, the Op first computes the mean of the image pixels in the -// channel and then adjusts each component of each pixel to -// `(x - mean) * contrast_factor + mean`. +// Computes the gradient of nearest neighbor interpolation. // // Arguments: -// images: Images to adjust. At least 3-D. -// contrast_factor: A float multiplier for adjusting contrast. +// grads: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The +// original input size. // -// Returns The contrast-adjusted image or images. -func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients +// with respect to the input image. 
+func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "AdjustContrastv2", - Input: []tf.Input{ - images, contrast_factor, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Gets the next output from the given iterator . -func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { - if scope.Err() != nil { - return + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "IteratorGetNext", + Type: "ResizeNearestNeighborGrad", Input: []tf.Input{ - iterator, + grads, size, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("IteratorGetNext", err) - return - } - return components + return op.Output(0) } -// Outputs the single element from the given dataset. -// -// Arguments: -// dataset: A handle to a dataset that contains a single element. -// -// +// ExtractJpegShapeAttr is an optional argument to ExtractJpegShape. +type ExtractJpegShapeAttr func(optionalAttr) + +// ExtractJpegShapeOutputType sets the optional output_type attribute to value. // -// Returns The components of the single element of `input`. 
-func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "DatasetToSingleElement", - Input: []tf.Input{ - dataset, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("DatasetToSingleElement", err) - return +// value: (Optional) The output type of the operation (int32 or int64). +// Defaults to int32. +// If not specified, defaults to DT_INT32 +func ExtractJpegShapeOutputType(value tf.DataType) ExtractJpegShapeAttr { + return func(m optionalAttr) { + m["output_type"] = value } - return components } -// Converts the given `resource_handle` representing an iterator to a string. +// Extract the shape information of a JPEG-encoded image. +// +// This op only parses the image header, so it is much faster than DecodeJpeg. // // Arguments: -// resource_handle: A handle to an iterator resource. +// contents: 0-D. The JPEG-encoded image. // -// Returns A string representation of the given handle. -func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_handle tf.Output) { +// Returns 1-D. The image shape with format [height, width, channels]. 
+func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegShapeAttr) (image_shape tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IteratorToStringHandle", + Type: "ExtractJpegShape", Input: []tf.Input{ - resource_handle, + contents, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// IteratorFromStringHandleAttr is an optional argument to IteratorFromStringHandle. -type IteratorFromStringHandleAttr func(optionalAttr) +// DecodePngAttr is an optional argument to DecodePng. +type DecodePngAttr func(optionalAttr) -// IteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. -// -// value: If specified, defines the type of each tuple component in an -// element produced by the resulting iterator. -// If not specified, defaults to <> +// DecodePngChannels sets the optional channels attribute to value. // -// REQUIRES: len(value) >= 0 -func IteratorFromStringHandleOutputTypes(value []tf.DataType) IteratorFromStringHandleAttr { +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodePngChannels(value int64) DecodePngAttr { return func(m optionalAttr) { - m["output_types"] = value + m["channels"] = value } } -// IteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. -// -// value: If specified, defines the shape of each tuple component in an -// element produced by the resulting iterator. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func IteratorFromStringHandleOutputShapes(value []tf.Shape) IteratorFromStringHandleAttr { +// DecodePngDtype sets the optional dtype attribute to value. 
+// If not specified, defaults to DT_UINT8 +func DecodePngDtype(value tf.DataType) DecodePngAttr { return func(m optionalAttr) { - m["output_shapes"] = value + m["dtype"] = value } } -// Converts the given string representing a handle to an iterator to a resource. +// Decode a PNG-encoded image to a uint8 or uint16 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the PNG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// * 4: output an RGBA image. +// +// If needed, the PNG-encoded image is transformed to match the requested number +// of color channels. +// +// This op also supports decoding JPEGs and non-animated GIFs since the interface +// is the same, though it is cleaner to use `tf.image.decode_image`. // // Arguments: -// string_handle: A string representation of the given handle. +// contents: 0-D. The PNG-encoded image. // -// Returns A handle to an iterator resource. -func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...IteratorFromStringHandleAttr) (resource_handle tf.Output) { +// Returns 3-D with shape `[height, width, channels]`. +func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { if scope.Err() != nil { return } @@ -33143,9 +36893,9 @@ func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional .. a(attrs) } opspec := tf.OpSpec{ - Type: "IteratorFromStringHandle", + Type: "DecodePng", Input: []tf.Input{ - string_handle, + contents, }, Attrs: attrs, } @@ -33153,84 +36903,111 @@ func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional .. return op.Output(0) } -// Gather slices from `params` axis `axis` according to `indices`. -// -// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). 
-// Produces an output tensor with shape `params.shape[:axis] + indices.shape + -// params.shape[axis + 1:]` where: -// -// ```python -// # Scalar indices (output is rank(params) - 1). -// output[a_0, ..., a_n, b_0, ..., b_n] = -// params[a_0, ..., a_n, indices, b_0, ..., b_n] -// -// # Vector indices (output is rank(params)). -// output[a_0, ..., a_n, i, b_0, ..., b_n] = -// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] -// -// # Higher rank indices (output is rank(params) + rank(indices) - 1). -// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = -// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] -// ``` +// Decode the first frame of a GIF-encoded image to a uint8 tensor. // -//
-// -//
+// GIF with frame or transparency compression are not supported +// convert animated GIF from compressed to uncompressed by: // -// Note that on CPU, if an out of bound index is found, an error is returned. -// On GPU, if an out of bound index is found, a 0 is stored in the -// corresponding output value. +// convert $src.gif -coalesce $dst.gif // -// See also `tf.batch_gather` and `tf.gather_nd`. +// This op also supports decoding JPEGs and PNGs, though it is cleaner to use +// `tf.image.decode_image`. // // Arguments: -// params: The tensor from which to gather values. Must be at least rank -// `axis + 1`. -// indices: Index tensor. Must be in range `[0, params.shape[axis])`. -// axis: The axis in `params` to gather `indices` from. Defaults to the first -// dimension. Supports negative indexes. +// contents: 0-D. The GIF-encoded image. // -// Returns Values from `params` gathered from indices given by `indices`, with -// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`. -func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output) (output tf.Output) { +// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order +func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "GatherV2", + Type: "DecodeGif", Input: []tf.Input{ - params, indices, axis, + contents, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts the given `resource_handle` representing an iterator to a variant tensor. +// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. +type LearnedUnigramCandidateSamplerAttr func(optionalAttr) + +// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. 
+// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Generates labels for candidate sampling with a learned unigram distribution. +// +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // // Arguments: -// resource_handle: A handle to an iterator resource. +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). // -// Returns A variant tensor storing the state of the iterator contained in the -// resource. 
-func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) { +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SerializeIterator", + Type: "LearnedUnigramCandidateSampler", Input: []tf.Input{ - resource_handle, + true_classes, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. -type FIFOQueueV2Attr func(optionalAttr) +// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. +type RandomShuffleQueueV2Attr func(optionalAttr) -// FIFOQueueV2Shapes sets the optional shapes attribute to value. +// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. // // value: The shape of each component in a value. The length of this attr must // be either 0 or the same as the length of component_types. 
If the length of @@ -33239,52 +37016,85 @@ type FIFOQueueV2Attr func(optionalAttr) // If not specified, defaults to <> // // REQUIRES: len(value) >= 0 -func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { +func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["shapes"] = value } } -// FIFOQueueV2Capacity sets the optional capacity attribute to value. +// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. // // value: The upper bound on the number of elements in this queue. // Negative numbers mean no limit. // If not specified, defaults to -1 -func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { +func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["capacity"] = value } } -// FIFOQueueV2Container sets the optional container attribute to value. +// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. +// +// value: Dequeue will block unless there would be this +// many elements after the dequeue or the queue is closed. This +// ensures a minimum level of mixing of elements. +// If not specified, defaults to 0 +func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["min_after_dequeue"] = value + } +} + +// RandomShuffleQueueV2Seed sets the optional seed attribute to value. +// +// value: If either seed or seed2 is set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, a random seed is used. +// If not specified, defaults to 0 +func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. 
+// If not specified, defaults to 0 +func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// RandomShuffleQueueV2Container sets the optional container attribute to value. // // value: If non-empty, this queue is placed in the given container. // Otherwise, a default container is used. // If not specified, defaults to "" -func FIFOQueueV2Container(value string) FIFOQueueV2Attr { +func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["container"] = value } } -// FIFOQueueV2SharedName sets the optional shared_name attribute to value. +// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. // // value: If non-empty, this queue will be shared under the given name // across multiple sessions. // If not specified, defaults to "" -func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { +func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// A queue that produces elements in first-in first-out order. +// A queue that randomizes the order of elements. // // Arguments: // component_types: The type of each component in a value. // // Returns The handle to the queue. 
-func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { +func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } @@ -33293,7 +37103,7 @@ func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQu a(attrs) } opspec := tf.OpSpec{ - Type: "FIFOQueueV2", + Type: "RandomShuffleQueueV2", Attrs: attrs, } @@ -33301,230 +37111,38 @@ func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQu return op.Output(0) } -// Deserializes a proto into the tree handle -// -// Arguments: -// tree_handle: Handle to the tree resource to be restored. -// tree_config: Serialied proto string of the boosted_trees.Tree proto. -// -// Returns the created operation. -func TensorForestTreeDeserialize(scope *Scope, tree_handle tf.Output, tree_config tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorForestTreeDeserialize", - Input: []tf.Input{ - tree_handle, tree_config, - }, - } - return scope.AddOperation(opspec) -} - -// Constructs an Optional variant from a tuple of tensors. -func OptionalFromValue(scope *Scope, components []tf.Output) (optional tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "OptionalFromValue", - Input: []tf.Input{ - tf.OutputList(components), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DecodeProtoV2Attr is an optional argument to DecodeProtoV2. -type DecodeProtoV2Attr func(optionalAttr) - -// DecodeProtoV2DescriptorSource sets the optional descriptor_source attribute to value. -// -// value: Either the special value `local://` or a path to a file containing -// a serialized `FileDescriptorSet`. 
-// If not specified, defaults to "local://" -func DecodeProtoV2DescriptorSource(value string) DecodeProtoV2Attr { - return func(m optionalAttr) { - m["descriptor_source"] = value - } -} - -// DecodeProtoV2MessageFormat sets the optional message_format attribute to value. -// -// value: Either `binary` or `text`. -// If not specified, defaults to "binary" -func DecodeProtoV2MessageFormat(value string) DecodeProtoV2Attr { - return func(m optionalAttr) { - m["message_format"] = value - } -} +// SerializeSparseAttr is an optional argument to SerializeSparse. +type SerializeSparseAttr func(optionalAttr) -// DecodeProtoV2Sanitize sets the optional sanitize attribute to value. +// SerializeSparseOutType sets the optional out_type attribute to value. // -// value: Whether to sanitize the result or not. -// If not specified, defaults to false -func DecodeProtoV2Sanitize(value bool) DecodeProtoV2Attr { +// value: The `dtype` to use for serialization; the supported types are `string` +// (default) and `variant`. +// If not specified, defaults to DT_STRING +func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { return func(m optionalAttr) { - m["sanitize"] = value + m["out_type"] = value } } -// The op extracts fields from a serialized protocol buffers message into tensors. -// -// The `decode_proto` op extracts fields from a serialized protocol buffers -// message into tensors. The fields in `field_names` are decoded and converted -// to the corresponding `output_types` if possible. -// -// A `message_type` name must be provided to give context for the field -// names. The actual message descriptor can be looked up either in the -// linked-in descriptor pool or a filename provided by the caller using -// the `descriptor_source` attribute. -// -// Each output tensor is a dense tensor. This means that it is padded to -// hold the largest number of repeated elements seen in the input -// minibatch. 
(The shape is also padded by one to prevent zero-sized -// dimensions). The actual repeat counts for each example in the -// minibatch can be found in the `sizes` output. In many cases the output -// of `decode_proto` is fed immediately into tf.squeeze if missing values -// are not a concern. When using tf.squeeze, always pass the squeeze -// dimension explicitly to avoid surprises. -// -// For the most part, the mapping between Proto field types and -// TensorFlow dtypes is straightforward. However, there are a few -// special cases: -// -// - A proto field that contains a submessage or group can only be converted -// to `DT_STRING` (the serialized submessage). This is to reduce the -// complexity of the API. The resulting string can be used as input -// to another instance of the decode_proto op. -// -// - TensorFlow lacks support for unsigned integers. The ops represent uint64 -// types as a `DT_INT64` with the same twos-complement bit pattern -// (the obvious way). Unsigned int32 values can be represented exactly by -// specifying type `DT_INT64`, or using twos-complement if the caller -// specifies `DT_INT32` in the `output_types` attribute. -// -// The `descriptor_source` attribute selects a source of protocol -// descriptors to consult when looking up `message_type`. This may be a -// filename containing a serialized `FileDescriptorSet` message, -// or the special value `local://`, in which case only descriptors linked -// into the code will be searched; the filename can be on any filesystem -// accessible to TensorFlow. -// -// You can build a `descriptor_source` file using the `--descriptor_set_out` -// and `--include_imports` options to the protocol compiler `protoc`. -// -// The `local://` database only covers descriptors linked into the -// code via C++ libraries, not Python imports. You can link in a proto descriptor -// by creating a cc_library target with alwayslink=1. 
-// -// Both binary and text proto serializations are supported, and can be -// chosen using the `format` attribute. +// Serialize a `SparseTensor` into a `[3]` `Tensor` object. // // Arguments: -// bytes: Tensor of serialized protos with shape `batch_shape`. -// message_type: Name of the proto message type to decode. -// field_names: List of strings containing proto field names. -// output_types: List of TF types to use for the respective field in field_names. -// -// Returns Tensor of int32 with shape `[batch_shape, len(field_names)]`. -// Each entry is the number of values found for the corresponding field. -// Optional fields may have 0 or 1 values.List of tensors containing values for the corresponding field. -// `values[i]` has datatype `output_types[i]` -// and shape `[batch_shape, max(sizes[...,i])]`. -func DecodeProtoV2(scope *Scope, bytes tf.Output, message_type string, field_names []string, output_types []tf.DataType, optional ...DecodeProtoV2Attr) (sizes tf.Output, values []tf.Output) { +// sparse_indices: 2-D. The `indices` of the `SparseTensor`. +// sparse_values: 1-D. The `values` of the `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the `SparseTensor`. 
+func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"message_type": message_type, "field_names": field_names, "output_types": output_types} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeProtoV2", - Input: []tf.Input{ - bytes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - sizes = op.Output(idx) - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("DecodeProtoV2", err) - return - } - return sizes, values -} - -// Creates an Optional variant with no value. -func OptionalNone(scope *Scope) (optional tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "OptionalNone", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns true if and only if the given Optional variant has a value. -func OptionalHasValue(scope *Scope, optional tf.Output) (has_value tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "OptionalHasValue", - Input: []tf.Input{ - optional, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the value stored in an Optional variant or raises an error if none exists. 
-func OptionalGetValue(scope *Scope, optional tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "OptionalGetValue", - Input: []tf.Input{ - optional, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("OptionalGetValue", err) - return - } - return components -} - -// Gets the next output from the given iterator as an Optional variant. -func IteratorGetNextAsOptional(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (optional tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "IteratorGetNextAsOptional", + Type: "SerializeSparse", Input: []tf.Input{ - iterator, + sparse_indices, sparse_values, sparse_shape, }, Attrs: attrs, } @@ -33532,482 +37150,361 @@ func IteratorGetNextAsOptional(scope *Scope, iterator tf.Output, output_types [] return op.Output(0) } -// Fast Fourier transform. -// -// Computes the 1-dimensional discrete Fourier transform over the inner-most -// dimension of `input`. -// -// Arguments: -// input: A complex tensor. +// Draw bounding boxes on a batch of images. // -// Returns A complex tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its 1D Fourier transform. +// Outputs a copy of `images` but draws on top of the pixels zero or more bounding +// boxes specified by the locations in `boxes`. The coordinates of the each +// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. 
The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. // -// @compatibility(numpy) -// Equivalent to np.fft.fft -// @end_compatibility -func FFT(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Identity transformation that models performance. +// For example, if an image is 100 x 200 pixels (height x width) and the bounding +// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of +// the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates). // -// Identity transformation that models performance. +// Parts of the bounding box may fall outside the image. // // Arguments: -// input_dataset: A variant tensor representing the input dataset. -// +// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. +// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding +// boxes. // -func ModelDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns 4-D with the same shape as `images`. The batch of input images with +// bounding boxes drawn on the images. +func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ModelDataset", + Type: "DrawBoundingBoxes", Input: []tf.Input{ - input_dataset, + images, boxes, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the truth value of (x > y) element-wise. -// -// *NOTE*: `Greater` supports broadcasting. 
More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Greater", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// SampleDistortedBoundingBoxV2Attr is an optional argument to SampleDistortedBoundingBoxV2. +type SampleDistortedBoundingBoxV2Attr func(optionalAttr) -// Performs a padding as a preprocess during a convolution. -// -// Similar to FusedResizeAndPadConv2d, this op allows for an optimized -// implementation where the spatial padding transformation stage is fused with the -// im2col lookup, but in this case without the bilinear filtering required for -// resizing. Fusing the padding prevents the need to write out the intermediate -// results as whole tensors, reducing memory pressure, and we can get some latency -// gains by merging the transformation calculations. -// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' -// order is used instead. -// Internally this op uses a single per-graph scratch buffer, which means that it -// will block if multiple versions are being run in parallel. This is because this -// operator is primarily an optimization to minimize memory usage. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. +// SampleDistortedBoundingBoxV2Seed sets the optional seed attribute to value. // -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. Must be in the same order as the dimension specified with format. -// padding: The type of padding algorithm to use. 
-func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "FusedPadConv2D", - Input: []tf.Input{ - input, paddings, filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. -type Conv2DBackpropInputAttr func(optionalAttr) - -// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { +// value: If either `seed` or `seed2` are set to non-zero, the random number +// generator is seeded by the given `seed`. Otherwise, it is seeded by a random +// seed. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxV2Seed(value int64) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value + m["seed"] = value } } -// Conv2DBackpropInputExplicitPaddings sets the optional explicit_paddings attribute to value. +// SampleDistortedBoundingBoxV2Seed2 sets the optional seed2 attribute to value. // -// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith -// dimension, the amount of padding inserted before and after the dimension is -// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If -// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. -// If not specified, defaults to <> -func Conv2DBackpropInputExplicitPaddings(value []int64) Conv2DBackpropInputAttr { +// value: A second seed to avoid seed collision. 
+// If not specified, defaults to 0 +func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { - m["explicit_paddings"] = value + m["seed2"] = value } } -// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. +// SampleDistortedBoundingBoxV2AspectRatioRange sets the optional aspect_ratio_range attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { +// value: The cropped area of the image must have an aspect ratio = +// width / height within this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["aspect_ratio_range"] = value } } -// Conv2DBackpropInputDilations sets the optional dilations attribute to value. +// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { +// value: The cropped area of the image must contain a fraction of the +// supplied image within this range. 
+// If not specified, defaults to +func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { - m["dilations"] = value + m["area_range"] = value } } -// Computes the gradients of convolution with respect to the input. -// -// Arguments: -// input_sizes: An integer vector representing the shape of `input`, -// where `input` is a 4-D `[batch, height, width, channels]` tensor. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. Must be in the same order as the dimension specified with -// format. -// padding: The type of padding algorithm to use. +// SampleDistortedBoundingBoxV2MaxAttempts sets the optional max_attempts attribute to value. // -// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient -// w.r.t. the input of the convolution. -func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Conv2DBackpropInput", - Input: []tf.Input{ - input_sizes, filter, out_backprop, - }, - Attrs: attrs, +// value: Number of attempts at generating a cropped region of the image +// of the specified constraints. After `max_attempts` failures, return the entire +// image. 
+// If not specified, defaults to 100 +func SampleDistortedBoundingBoxV2MaxAttempts(value int64) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["max_attempts"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Interleave the values from the `data` tensors into a single tensor. -// -// Builds a merged tensor such that -// -// ```python -// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] -// ``` +// SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. // -// For example, if each `indices[m]` is scalar or vector, we have +// value: Controls behavior if no bounding boxes supplied. +// If true, assume an implicit bounding box covering the whole input. If false, +// raise an error. +// If not specified, defaults to false +func SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["use_image_if_no_bounding_boxes"] = value + } +} + +// Generate a single randomly distorted bounding box for an image. // -// ```python -// # Scalar indices: -// merged[indices[m], ...] = data[m][...] +// Bounding box annotations are often supplied in addition to ground-truth labels +// in image recognition or object localization tasks. A common technique for +// training such a system is to randomly distort an image while preserving +// its content, i.e. *data augmentation*. This Op outputs a randomly distorted +// localization of an object, i.e. bounding box, given an `image_size`, +// `bounding_boxes` and a series of constraints. // -// # Vector indices: -// merged[indices[m][i], ...] = data[m][i, ...] -// ``` +// The output of this Op is a single bounding box that may be used to crop the +// original image. The output is returned as 3 tensors: `begin`, `size` and +// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the +// image. 
The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize +// what the bounding box looks like. // -// Each `data[i].shape` must start with the corresponding `indices[i].shape`, -// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we -// must have `data[i].shape = indices[i].shape + constant`. In terms of this -// `constant`, the output shape is +// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. // -// merged.shape = [max(indices)] + constant +// For example, // -// Values are merged in order, so if an index appears in both `indices[m][i]` and -// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the -// merged result. If you do not need this guarantee, ParallelDynamicStitch might -// perform better on some devices. +// ```python +// # Generate a single distorted bounding box. +// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( +// tf.shape(image), +// bounding_boxes=bounding_boxes) // -// For example: +// # Draw the bounding box in an image summary. +// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), +// bbox_for_draw) +// tf.summary.image('images_with_box', image_with_box) // -// ```python -// indices[0] = 6 -// indices[1] = [4, 1] -// indices[2] = [[5, 2], [0, 3]] -// data[0] = [61, 62] -// data[1] = [[41, 42], [11, 12]] -// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] -// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], -// [51, 52], [61, 62]] +// # Employ the bounding box to distort the image. 
+// distorted_image = tf.slice(image, begin, size) // ``` // -// This method can be used to merge partitions created by `dynamic_partition` -// as illustrated on the following example: +// Note that if no bounding box information is available, setting +// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit +// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is +// false and no bounding boxes are supplied, an error is raised. // -// ```python -// # Apply function (increments x_i) on elements for which a certain condition -// # apply (x_i != -1 in this example). -// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) -// condition_mask=tf.not_equal(x,tf.constant(-1.)) -// partitioned_data = tf.dynamic_partition( -// x, tf.cast(condition_mask, tf.int32) , 2) -// partitioned_data[1] = partitioned_data[1] + 1.0 -// condition_indices = tf.dynamic_partition( -// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) -// x = tf.dynamic_stitch(condition_indices, partitioned_data) -// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain -// # unchanged. -// ``` +// Arguments: +// image_size: 1-D, containing `[height, width, channels]`. +// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes +// associated with the image. +// min_object_covered: The cropped area of the image must contain at least this +// fraction of any bounding box supplied. The value of this parameter should be +// non-negative. In the case of 0, the cropped area does not need to overlap +// any of the bounding boxes supplied. // -//
-// -//
-func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { +// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to +// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to +// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. +// Provide as input to `tf.image.draw_bounding_boxes`. +func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, min_object_covered tf.Output, optional ...SampleDistortedBoundingBoxV2Attr) (begin tf.Output, size tf.Output, bboxes tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DynamicStitch", + Type: "SampleDistortedBoundingBoxV2", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(data), + image_size, bounding_boxes, min_object_covered, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Returns the truth value of (x == y) element-wise. +// Computes requantization range per channel. // -// *NOTE*: `Equal` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// input: The original input tensor. +// input_min: The minimum value of the input tensor +// input_max: The maximum value of the input tensor. +// clip_value_max: The maximum value of the output that needs to be clipped. +// Example: set this to 6 for Relu6. +// +// Returns The minimum value of the final output tensorThe maximum value of the final output tensor. 
+func RequantizationRangePerChannel(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, clip_value_max float32) (output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"clip_value_max": clip_value_max} opspec := tf.OpSpec{ - Type: "Equal", + Type: "RequantizationRangePerChannel", Input: []tf.Input{ - x, y, + input, input_min, input_max, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. -type TensorArrayGatherV2Attr func(optionalAttr) +// ExtractGlimpseAttr is an optional argument to ExtractGlimpse. +type ExtractGlimpseAttr func(optionalAttr) -// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. -// If not specified, defaults to -func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { +// ExtractGlimpseCentered sets the optional centered attribute to value. +// +// value: indicates if the offset coordinates are centered relative to +// the image, in which case the (0, 0) offset is relative to the center +// of the input images. If false, the (0,0) offset corresponds to the +// upper left corner of the input images. +// If not specified, defaults to true +func ExtractGlimpseCentered(value bool) ExtractGlimpseAttr { return func(m optionalAttr) { - m["element_shape"] = value + m["centered"] = value } } -// Deprecated. Use TensorArrayGatherV3 +// ExtractGlimpseNormalized sets the optional normalized attribute to value. 
// -// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 -func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorArrayGatherV2", - Input: []tf.Input{ - handle, indices, flow_in, - }, - Attrs: attrs, +// value: indicates if the offset coordinates are normalized. +// If not specified, defaults to true +func ExtractGlimpseNormalized(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["normalized"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Interleave the values from the `data` tensors into a single tensor. -// -// Builds a merged tensor such that -// -// ```python -// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] -// ``` -// -// For example, if each `indices[m]` is scalar or vector, we have -// -// ```python -// # Scalar indices: -// merged[indices[m], ...] = data[m][...] -// -// # Vector indices: -// merged[indices[m][i], ...] = data[m][i, ...] -// ``` +// ExtractGlimpseUniformNoise sets the optional uniform_noise attribute to value. // -// Each `data[i].shape` must start with the corresponding `indices[i].shape`, -// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we -// must have `data[i].shape = indices[i].shape + constant`. In terms of this -// `constant`, the output shape is +// value: indicates if the noise should be generated using a +// uniform distribution or a Gaussian distribution. +// If not specified, defaults to true +func ExtractGlimpseUniformNoise(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["uniform_noise"] = value + } +} + +// ExtractGlimpseNoise sets the optional noise attribute to value. 
// -// merged.shape = [max(indices)] + constant +// value: indicates if the noise should `uniform`, `gaussian`, or +// `zero`. The default is `uniform` which means the the noise type +// will be decided by `uniform_noise`. +// If not specified, defaults to "uniform" +func ExtractGlimpseNoise(value string) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["noise"] = value + } +} + +// Extracts a glimpse from the input tensor. // -// Values may be merged in parallel, so if an index appears in both `indices[m][i]` -// and `indices[n][j]`, the result may be invalid. This differs from the normal -// DynamicStitch operator that defines the behavior in that case. +// Returns a set of windows called glimpses extracted at location +// `offsets` from the input tensor. If the windows only partially +// overlaps the inputs, the non overlapping areas will be filled with +// random noise. // -// For example: +// The result is a 4-D tensor of shape `[batch_size, glimpse_height, +// glimpse_width, channels]`. The channels and batch dimensions are the +// same as that of the input tensor. The height and width of the output +// windows are specified in the `size` parameter. // -// ```python -// indices[0] = 6 -// indices[1] = [4, 1] -// indices[2] = [[5, 2], [0, 3]] -// data[0] = [61, 62] -// data[1] = [[41, 42], [11, 12]] -// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] -// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], -// [51, 52], [61, 62]] -// ``` +// The argument `normalized` and `centered` controls how the windows are built: // -// This method can be used to merge partitions created by `dynamic_partition` -// as illustrated on the following example: +// * If the coordinates are normalized but not centered, 0.0 and 1.0 +// correspond to the minimum and maximum of each height and width +// dimension. +// * If the coordinates are both normalized and centered, they range from +// -1.0 to 1.0. 
The coordinates (-1.0, -1.0) correspond to the upper +// left corner, the lower right corner is located at (1.0, 1.0) and the +// center is at (0, 0). +// * If the coordinates are not normalized they are interpreted as +// numbers of pixels. // -// ```python -// # Apply function (increments x_i) on elements for which a certain condition -// # apply (x_i != -1 in this example). -// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) -// condition_mask=tf.not_equal(x,tf.constant(-1.)) -// partitioned_data = tf.dynamic_partition( -// x, tf.cast(condition_mask, tf.int32) , 2) -// partitioned_data[1] = partitioned_data[1] + 1.0 -// condition_indices = tf.dynamic_partition( -// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) -// x = tf.dynamic_stitch(condition_indices, partitioned_data) -// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain -// # unchanged. -// ``` +// Arguments: +// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. +// size: A 1-D tensor of 2 elements containing the size of the glimpses +// to extract. The glimpse height must be specified first, following +// by the glimpse width. +// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing +// the y, x locations of the center of each window. // -//
-// -//
-func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { +// Returns A tensor representing the glimpses `[batch_size, +// glimpse_height, glimpse_width, channels]`. +func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseAttr) (glimpse tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ParallelDynamicStitch", + Type: "ExtractGlimpse", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(data), + input, size, offsets, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. -type PriorityQueueV2Attr func(optionalAttr) - -// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. -// -// value: The type of each component in a value. -// If not specified, defaults to <> +// A container for an iterator resource. // -// REQUIRES: len(value) >= 0 -func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["component_types"] = value +// Returns A handle to the iterator that can be passed to a "MakeIterator" +// or "IteratorGetNext" op. +func Iterator(scope *Scope, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return } -} + attrs := map[string]interface{}{"shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "Iterator", -// PriorityQueueV2Capacity sets the optional capacity attribute to value. -// -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. 
-// If not specified, defaults to -1 -func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// PriorityQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. +// TensorForestTreeResourceHandleOpAttr is an optional argument to TensorForestTreeResourceHandleOp. +type TensorForestTreeResourceHandleOpAttr func(optionalAttr) + +// TensorForestTreeResourceHandleOpContainer sets the optional container attribute to value. // If not specified, defaults to "" -func PriorityQueueV2Container(value string) PriorityQueueV2Attr { +func TensorForestTreeResourceHandleOpContainer(value string) TensorForestTreeResourceHandleOpAttr { return func(m optionalAttr) { m["container"] = value } } -// PriorityQueueV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. +// TensorForestTreeResourceHandleOpSharedName sets the optional shared_name attribute to value. // If not specified, defaults to "" -func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { +func TensorForestTreeResourceHandleOpSharedName(value string) TensorForestTreeResourceHandleOpAttr { return func(m optionalAttr) { m["shared_name"] = value } } -// A queue that produces elements sorted by the first component value. -// -// Note that the PriorityQueue requires the first component of any element -// to be a scalar int64, in addition to the other elements declared by -// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue -// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra -// entry in their input (resp. output) lists. 
-// -// Arguments: -// shapes: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// -// Returns The handle to the queue. -func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { +// Creates a handle to a TensorForestTreeResource +func TensorForestTreeResourceHandleOp(scope *Scope, optional ...TensorForestTreeResourceHandleOpAttr) (resource tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shapes": shapes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "PriorityQueueV2", + Type: "TensorForestTreeResourceHandleOp", Attrs: attrs, } @@ -34015,164 +37512,173 @@ func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV return op.Output(0) } -// QueueEnqueueV2Attr is an optional argument to QueueEnqueueV2. -type QueueEnqueueV2Attr func(optionalAttr) +// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. +type CropAndResizeGradImageAttr func(optionalAttr) -// QueueEnqueueV2TimeoutMs sets the optional timeout_ms attribute to value. +// CropAndResizeGradImageMethod sets the optional method attribute to value. // -// value: If the queue is full, this operation will block for up to -// timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueEnqueueV2TimeoutMs(value int64) QueueEnqueueV2Attr { +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. 
+// If not specified, defaults to "bilinear" +func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { return func(m optionalAttr) { - m["timeout_ms"] = value + m["method"] = value } } -// Enqueues a tuple of one or more tensors in the given queue. -// -// The components input has k elements, which correspond to the components of -// tuples stored in the given queue. -// -// N.B. If the queue is full, this operation will block until the given -// element has been enqueued (or 'timeout_ms' elapses, if specified). +// Computes the gradient of the crop_and_resize op wrt the input image tensor. // // Arguments: -// handle: The handle to a queue. -// components: One or more tensors from which the enqueued tensors should be taken. +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` +// containing the original image size. Both `image_height` and `image_width` need +// to be positive. // -// Returns the created operation. 
-func QueueEnqueueV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueV2Attr) (o *tf.Operation) { +// +// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"T": T} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QueueEnqueueV2", + Type: "CropAndResizeGradImage", Input: []tf.Input{ - handle, tf.OutputList(components), + grads, boxes, box_ind, image_size, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// Computes the Bessel i0e function of `x` element-wise. -// -// Exponentially scaled modified Bessel function of order 0 defined as -// `bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. -// -// This function is faster and numerically stabler than `bessel_i0(x)`. -func BesselI0e(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BesselI0e", - Input: []tf.Input{ - x, - }, - } op := scope.AddOperation(opspec) return op.Output(0) } -// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2. -type QueueDequeueManyV2Attr func(optionalAttr) +// ShuffleDatasetAttr is an optional argument to ShuffleDataset. +type ShuffleDatasetAttr func(optionalAttr) -// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value. +// ShuffleDatasetReshuffleEachIteration sets the optional reshuffle_each_iteration attribute to value. // -// value: If the queue has fewer than n elements, this operation -// will block for up to timeout_ms milliseconds. -// Note: This option is not supported yet. 
-// If not specified, defaults to -1 -func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr { +// value: If true, each iterator over this dataset will be given +// a different pseudorandomly generated seed, based on a sequence seeded by the +// `seed` and `seed2` inputs. If false, each iterator will be given the same +// seed, and repeated iteration over this dataset will yield the exact same +// sequence of results. +// If not specified, defaults to true +func ShuffleDatasetReshuffleEachIteration(value bool) ShuffleDatasetAttr { return func(m optionalAttr) { - m["timeout_ms"] = value + m["reshuffle_each_iteration"] = value } } -// Dequeues `n` tuples of one or more tensors from the given queue. -// -// If the queue is closed and there are fewer than `n` elements, then an -// OutOfRange error is returned. -// -// This operation concatenates queue-element component tensors along the -// 0th dimension to make a single component tensor. All of the components -// in the dequeued tuple will have size `n` in the 0th dimension. +// Creates a dataset that shuffles elements from `input_dataset` pseudorandomly. // -// This operation has `k` outputs, where `k` is the number of components in -// the tuples stored in the given queue, and output `i` is the ith -// component of the dequeued tuple. +// Arguments: // -// N.B. If the queue is empty, this operation will block until `n` elements -// have been dequeued (or 'timeout_ms' elapses, if specified). +// buffer_size: The number of output elements to buffer in an iterator over +// this dataset. Compare with the `min_after_dequeue` attr when creating a +// `RandomShuffleQueue`. +// seed: A scalar seed for the random number generator. If either `seed` or +// `seed2` is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. // -// Arguments: -// handle: The handle to a queue. 
-// n: The number of tuples to dequeue. -// component_types: The type of each component in a tuple. // -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) { +func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShuffleDatasetAttr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QueueDequeueManyV2", + Type: "ShuffleDataset", Input: []tf.Input{ - handle, n, + input_dataset, buffer_size, seed, seed2, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// 3D fast Fourier transform. +// +// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 +// dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 +// dimensions of `input` are replaced with their 3D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fftn with 3 dimensions. +// @end_compatibility +func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueManyV2", err) - return + opspec := tf.OpSpec{ + Type: "FFT3D", + Input: []tf.Input{ + input, + }, } - return components + op := scope.AddOperation(opspec) + return op.Output(0) } -// EncodeBase64Attr is an optional argument to EncodeBase64. 
-type EncodeBase64Attr func(optionalAttr) +// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. +type CropAndResizeGradBoxesAttr func(optionalAttr) -// EncodeBase64Pad sets the optional pad attribute to value. +// CropAndResizeGradBoxesMethod sets the optional method attribute to value. // -// value: Bool whether padding is applied at the ends. -// If not specified, defaults to false -func EncodeBase64Pad(value bool) EncodeBase64Attr { +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { return func(m optionalAttr) { - m["pad"] = value + m["method"] = value } } -// Encode strings into web-safe base64 format. -// -// Refer to the following article for more information on base64 format: -// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the -// end so that the encoded has length multiple of 4. See Padding section of the -// link above. -// -// Web-safe means that the encoder uses - and _ instead of + and /. +// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. // // Arguments: -// input: Strings to be encoded. +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +// Both `image_height` and `image_width` need to be positive. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. 
We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. // -// Returns Input strings encoded in base64. -func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { +// Returns A 2-D tensor of shape `[num_boxes, 4]`. +func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -34181,9 +37687,9 @@ func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) ( a(attrs) } opspec := tf.OpSpec{ - Type: "EncodeBase64", + Type: "CropAndResizeGradBoxes", Input: []tf.Input{ - input, + grads, image, boxes, box_ind, }, Attrs: attrs, } @@ -34191,128 +37697,198 @@ func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) ( return op.Output(0) } -// A dataset that creates window datasets from the input dataset. -// -// Arguments: +// Greedily selects a subset of bounding boxes in descending order of score, // -// size: A scalar representing the number of elements to accumulate in a window. -// shift: A scalar representing the steps moving the sliding window forward in one -// iteration. It must be positive. -// stride: A scalar representing the stride of the input elements of the sliding window. -// It must be positive. -// drop_remainder: A scalar representing whether a window should be dropped in case its size is -// smaller than desired. 
+// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes with score less than +// `score_threshold` are removed. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system and more +// generally is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// selected_indices = tf.image.non_max_suppression_v2( +// boxes, scores, max_output_size, iou_threshold, score_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) // +// Arguments: +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// iou_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too much with respect to IOU. +// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove +// boxes based on score. 
// -func WindowDataset(scope *Scope, input_dataset tf.Output, size tf.Output, shift tf.Output, stride tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppressionV3(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "WindowDataset", + Type: "NonMaxSuppressionV3", Input: []tf.Input{ - input_dataset, size, shift, stride, drop_remainder, + boxes, scores, max_output_size, iou_threshold, score_threshold, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deprecated. Use TensorArrayCloseV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 +// NonMaxSuppressionV4Attr is an optional argument to NonMaxSuppressionV4. +type NonMaxSuppressionV4Attr func(optionalAttr) + +// NonMaxSuppressionV4PadToMaxOutputSize sets the optional pad_to_max_output_size attribute to value. // -// Returns the created operation. -func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArrayCloseV2", - Input: []tf.Input{ - handle, - }, +// value: If true, the output `selected_indices` is padded to be of length +// `max_output_size`. Defaults to false. +// If not specified, defaults to false +func NonMaxSuppressionV4PadToMaxOutputSize(value bool) NonMaxSuppressionV4Attr { + return func(m optionalAttr) { + m["pad_to_max_output_size"] = value } - return scope.AddOperation(opspec) } -// Forwards the value of an available tensor from `inputs` to `output`. 
-// -// `Merge` waits for at least one of the tensors in `inputs` to become available. -// It is usually combined with `Switch` to implement branching. +// Greedily selects a subset of bounding boxes in descending order of score, // -// `Merge` forwards the first tensor to become available to `output`, and sets -// `value_index` to its index in `inputs`. +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes with score less than +// `score_threshold` are removed. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system and more +// generally is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// selected_indices = tf.image.non_max_suppression_v2( +// boxes, scores, max_output_size, iou_threshold, score_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) // // Arguments: -// inputs: The input tensors, exactly one of which will become available. +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. 
+// iou_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too much with respect to IOU. +// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove +// boxes based on score. // -// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. -func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`.A 0-D integer tensor representing the number of valid elements in +// `selected_indices`, with the valid elements appearing first. +func NonMaxSuppressionV4(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output, optional ...NonMaxSuppressionV4Attr) (selected_indices tf.Output, valid_outputs tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Merge", + Type: "NonMaxSuppressionV4", Input: []tf.Input{ - tf.OutputList(inputs), + boxes, scores, max_output_size, iou_threshold, score_threshold, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0), op.Output(1) } -// Writes the given dataset to the given file using the TFRecord format. +// Removes keys and its associated values from a table. +// +// The tensor `keys` must of the same type as the keys of the table. Keys not +// already in the table are silently ignored. // // Arguments: -// input_dataset: A variant tensor representing the dataset to write. -// filename: A scalar string tensor representing the filename to use. -// compression_type: A scalar string tensor containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". +// table_handle: Handle to the table. +// keys: Any shape. Keys of the elements to remove. 
// // Returns the created operation. -func ExperimentalDatasetToTFRecord(scope *Scope, input_dataset tf.Output, filename tf.Output, compression_type tf.Output) (o *tf.Operation) { +func LookupTableRemoveV2(scope *Scope, table_handle tf.Output, keys tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ExperimentalDatasetToTFRecord", + Type: "LookupTableRemoveV2", Input: []tf.Input{ - input_dataset, filename, compression_type, + table_handle, keys, }, } return scope.AddOperation(opspec) } -// QueueCloseV2Attr is an optional argument to QueueCloseV2. -type QueueCloseV2Attr func(optionalAttr) +// CombinedNonMaxSuppressionAttr is an optional argument to CombinedNonMaxSuppression. +type CombinedNonMaxSuppressionAttr func(optionalAttr) -// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. +// CombinedNonMaxSuppressionPadPerClass sets the optional pad_per_class attribute to value. // -// value: If true, all pending enqueue requests that are -// blocked on the given queue will be canceled. +// value: If false, the output nmsed boxes, scores and classes +// are padded/clipped to `max_total_size`. If true, the +// output nmsed boxes, scores and classes are padded to be of length +// `max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in +// which case it is clipped to `max_total_size`. Defaults to false. // If not specified, defaults to false -func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { +func CombinedNonMaxSuppressionPadPerClass(value bool) CombinedNonMaxSuppressionAttr { return func(m optionalAttr) { - m["cancel_pending_enqueues"] = value + m["pad_per_class"] = value } } -// Closes the given queue. -// -// This operation signals that no more elements will be enqueued in the -// given queue. Subsequent Enqueue(Many) operations will fail. 
-// Subsequent Dequeue(Many) operations will continue to succeed if -// sufficient elements remain in the queue. Subsequent Dequeue(Many) -// operations that would block will fail immediately. +// Greedily selects a subset of bounding boxes in descending order of score, // -// Arguments: -// handle: The handle to a queue. +// This operation performs non_max_suppression on the inputs per batch, across +// all classes. +// Prunes away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system. Also note that +// this algorithm is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is the final boxes, scores and classes tensor +// returned after performing non_max_suppression. +// +// Arguments: +// boxes: A 4-D float tensor of shape `[batch_size, num_boxes, q, 4]`. If `q` is 1 then +// same boxes are used for all classes otherwise, if `q` is equal to number of +// classes, class-specific boxes are used. +// scores: A 3-D float tensor of shape `[batch_size, num_boxes, num_classes]` +// representing a single score corresponding to each box (each row of boxes). +// max_output_size_per_class: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression per class +// max_total_size: A scalar representing maximum number of boxes retained over all classes. 
+// iou_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too much with respect to IOU. +// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove +// boxes based on score. // -// Returns the created operation. -func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { +// Returns A [batch_size, max_detections, 4] float32 tensor +// containing the non-max suppressed boxes.A [batch_size, max_detections] float32 tensor +// containing the scores for the boxes.A [batch_size, max_detections] float32 tensor +// containing the classes for the boxes.A [batch_size] int32 tensor indicating the number of +// valid detections per batch item. Only the top num_detections[i] entries in +// nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the +// entries are zero paddings. +func CombinedNonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size_per_class tf.Output, max_total_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output, optional ...CombinedNonMaxSuppressionAttr) (nmsed_boxes tf.Output, nmsed_scores tf.Output, nmsed_classes tf.Output, valid_detections tf.Output) { if scope.Err() != nil { return } @@ -34321,106 +37897,99 @@ func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) a(attrs) } opspec := tf.OpSpec{ - Type: "QueueCloseV2", + Type: "CombinedNonMaxSuppression", Input: []tf.Input{ - handle, + boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// Computes inverse hyperbolic tangent of x element-wise. 
-func Atanh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Atanh", - Input: []tf.Input{ - x, - }, - } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Returns true if queue is closed. +// Computes the matrix logarithm of one or more square matrices: // -// This operation returns true if the queue is closed and false if the queue -// is open. +// +// \\(log(exp(A)) = A\\) +// +// This op is only defined for complex matrices. If A is positive-definite and +// real, then casting to a complex matrix, taking the logarithm and casting back +// to a real matrix will give the correct result. +// +// This function computes the matrix logarithm using the Schur-Parlett algorithm. +// Details of the algorithm can be found in Section 11.6.2 of: +// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008. +// ISBN 978-0-898716-46-7. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor of the same shape as the input +// containing the exponential for all input submatrices `[..., :, :]`. // // Arguments: -// handle: The handle to a queue. -func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M, M]`. +// +// @compatibility(scipy) +// Equivalent to scipy.linalg.logm +// @end_compatibility +func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QueueIsClosedV2", + Type: "MatrixLogarithm", Input: []tf.Input{ - handle, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the absolute value of a tensor. +// This op is used as a placeholder in If branch functions. It doesn't provide a +// valid output when run, so must either be removed (e.g. 
replaced with a +// function input) or guaranteed not to be used (e.g. if mirroring an +// intermediate output needed for the gradient computation of the other branch). // -// Given a tensor `x`, this operation returns a tensor containing the absolute -// value of each element in `x`. For example, if x is an input element and y is -// an output element, this operation computes \\(y = |x|\\). -func Abs(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// dtype: The type of the output. +// shape: The purported shape of the output. This is only used for shape inference; +// the output will not necessarily have this shape. Can be a partial shape. +// +// Returns \"Fake\" output value. This should not be consumed by another op. +func FakeParam(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} opspec := tf.OpSpec{ - Type: "Abs", - Input: []tf.Input{ - x, - }, + Type: "FakeParam", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StackV2Attr is an optional argument to StackV2. -type StackV2Attr func(optionalAttr) - -// StackV2StackName sets the optional stack_name attribute to value. +// Returns the next representable value of `x1` in the direction of `x2`, element-wise. // -// value: Overrides the name used for the temporary stack resource. Default -// value is the name of the 'Stack' op (which is guaranteed unique). -// If not specified, defaults to "" -func StackV2StackName(value string) StackV2Attr { - return func(m optionalAttr) { - m["stack_name"] = value - } -} - -// A stack that produces elements in first-in last-out order. +// This operation returns the same result as the C++ std::nextafter function. // -// Arguments: -// max_size: The maximum size of the stack if non-negative. If negative, the stack -// size is unlimited. -// elem_type: The type of the elements on the stack. 
+// It can also return a subnormal number. // -// Returns The handle to the stack. -func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { +// @compatibility(cpp) +// Equivalent to C++ std::nextafter function. +// @end_compatibility +func NextAfter(scope *Scope, x1 tf.Output, x2 tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"elem_type": elem_type} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "StackV2", + Type: "NextAfter", Input: []tf.Input{ - max_size, + x1, x2, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -34503,6 +38072,45 @@ func OrderedMapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf return scope.AddOperation(opspec) } +// StackPushV2Attr is an optional argument to StackPushV2. +type StackPushV2Attr func(optionalAttr) + +// StackPushV2SwapMemory sets the optional swap_memory attribute to value. +// +// value: Swap `elem` to CPU. Default to false. +// If not specified, defaults to false +func StackPushV2SwapMemory(value bool) StackPushV2Attr { + return func(m optionalAttr) { + m["swap_memory"] = value + } +} + +// Push an element onto the stack. +// +// Arguments: +// handle: The handle to a stack. +// elem: The tensor to be pushed onto the stack. +// +// Returns The same tensor as the input 'elem'. +func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...StackPushV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StackPushV2", + Input: []tf.Input{ + handle, elem, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // RpcAttr is an optional argument to Rpc. 
type RpcAttr func(optionalAttr) @@ -34639,38 +38247,24 @@ func ExperimentalBytesProducedStatsDataset(scope *Scope, input_dataset tf.Output return op.Output(0) } -// StackPushV2Attr is an optional argument to StackPushV2. -type StackPushV2Attr func(optionalAttr) - -// StackPushV2SwapMemory sets the optional swap_memory attribute to value. -// -// value: Swap `elem` to CPU. Default to false. -// If not specified, defaults to false -func StackPushV2SwapMemory(value bool) StackPushV2Attr { - return func(m optionalAttr) { - m["swap_memory"] = value - } -} - -// Push an element onto the stack. +// A substitute for `InterleaveDataset` on a fixed list of `N` datasets. // // Arguments: -// handle: The handle to a stack. -// elem: The tensor to be pushed onto the stack. +// selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines which of the +// `N` data inputs should produce the next output element. +// data_input_datasets: `N` datasets with the same type that will be interleaved according to +// the values of `selector_input_dataset`. // -// Returns The same tensor as the input 'elem'. 
-func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...StackPushV2Attr) (output tf.Output) { +// +func ExperimentalDirectedInterleaveDataset(scope *Scope, selector_input_dataset tf.Output, data_input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "StackPushV2", + Type: "ExperimentalDirectedInterleaveDataset", Input: []tf.Input{ - handle, elem, + selector_input_dataset, tf.OutputList(data_input_datasets), }, Attrs: attrs, } @@ -34678,239 +38272,308 @@ func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...Sta return op.Output(0) } -// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. -type FusedBatchNormGradV2Attr func(optionalAttr) +// RandomUniformIntAttr is an optional argument to RandomUniformInt. +type RandomUniformIntAttr func(optionalAttr) -// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. +// RandomUniformIntSeed sets the optional seed attribute to value. // -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomUniformIntSeed(value int64) RandomUniformIntAttr { return func(m optionalAttr) { - m["epsilon"] = value + m["seed"] = value } } -// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. +// RandomUniformIntSeed2 sets the optional seed2 attribute to value. 
// -// value: The data format for y_backprop, x, x_backprop. -// Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomUniformIntSeed2(value int64) RandomUniformIntAttr { return func(m optionalAttr) { - m["data_format"] = value + m["seed2"] = value } } -// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. +// Outputs random integers from a uniform distribution. // -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["is_training"] = value +// The generated values are uniform integers in the range `[minval, maxval)`. +// The lower bound `minval` is included in the range, while the upper bound +// `maxval` is excluded. +// +// The random integers are slightly biased unless `maxval - minval` is an exact +// power of two. The bias is small for values of `maxval - minval` significantly +// smaller than the range of the output (either `2^32` or `2^64`). +// +// Arguments: +// shape: The shape of the output tensor. +// minval: 0-D. Inclusive lower bound on the generated integers. +// maxval: 0-D. Exclusive upper bound on the generated integers. +// +// Returns A tensor of the specified shape filled with uniform random integers. 
+func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomUniformInt", + Input: []tf.Input{ + shape, minval, maxval, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Gradient for batch normalization. +// Add the quantile summaries to each quantile stream resource. // -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// An op that adds a list of quantile summaries to a quantile stream resource. Each +// summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank) +// for a single feature. // // Arguments: -// y_backprop: A 4D Tensor for the gradient with respect to y. -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch -// mean to be reused in gradient computation. When is_training is -// False, a 1D Tensor for the population mean to be reused in both -// 1st and 2nd order gradient computation. -// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch -// variance (inverted variance in the cuDNN case) to be reused in -// gradient computation. When is_training is False, a 1D Tensor -// for the population variance to be reused in both 1st and 2nd -// order gradient computation. +// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. +// summaries: string; List of Rank 2 Tensor each containing the summaries for a single feature. +// +// Returns the created operation. 
+func BoostedTreesQuantileStreamResourceAddSummaries(scope *Scope, quantile_stream_resource_handle tf.Output, summaries []tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BoostedTreesQuantileStreamResourceAddSummaries", + Input: []tf.Input{ + quantile_stream_resource_handle, tf.OutputList(summaries), + }, + } + return scope.AddOperation(opspec) +} + +// Creates a Dataset that returns pseudorandom numbers. +// +// Arguments: +// seed: A scalar seed for the random number generator. If either seed or +// seed2 is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. +// // -// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input -// in FusedBatchNorm. 
-func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { +func ExperimentalRandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "FusedBatchNormGradV2", + Type: "ExperimentalRandomDataset", Input: []tf.Input{ - y_backprop, x, scale, reserve_space_1, reserve_space_2, + seed, seed2, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return op.Output(0) } -// Creates a TensorArray for storing multiple gradients of values in the given handle. -// -// Similar to TensorArrayGradV3. However it creates an accumulator with an -// expanded shape compared to the input TensorArray whose gradient is being -// computed. This enables multiple gradients for the same TensorArray to be -// calculated using the same accumulator. +// Creates a dataset that overrides the maximum intra-op parallelism. // // Arguments: -// handle: The handle to the forward TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// shape_to_prepend: An int32 vector representing a shape. Elements in the gradient accumulator will -// have shape which is this shape_to_prepend value concatenated with shape of the -// elements in the TensorArray corresponding to the input handle. -// source: The gradient source string, used to decide which gradient TensorArray -// to return. 
-func TensorArrayGradWithShape(scope *Scope, handle tf.Output, flow_in tf.Output, shape_to_prepend tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) { +// +// max_intra_op_parallelism: Identifies the maximum intra-op parallelism to use. +// +// +func ExperimentalMaxIntraOpParallelismDataset(scope *Scope, input_dataset tf.Output, max_intra_op_parallelism tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"source": source} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorArrayGradWithShape", + Type: "ExperimentalMaxIntraOpParallelismDataset", Input: []tf.Input{ - handle, flow_in, shape_to_prepend, + input_dataset, max_intra_op_parallelism, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. +// StringSplitV2Attr is an optional argument to StringSplitV2. +type StringSplitV2Attr func(optionalAttr) + +// StringSplitV2Maxsplit sets the optional maxsplit attribute to value. // -// Each comparison returns a boolean `true` (if `input_value > threshold`) -// or and `false` otherwise. +// value: An `int`. If `maxsplit > 0`, limit of the split of the result. +// If not specified, defaults to -1 +func StringSplitV2Maxsplit(value int64) StringSplitV2Attr { + return func(m optionalAttr) { + m["maxsplit"] = value + } +} + +// Split elements of `source` based on `sep` into a `SparseTensor`. // -// This operation is useful for Locality-Sensitive-Hashing (LSH) and other -// algorithms that use hashing approximations of cosine and `L2` distances; -// codes can be generated from an input via: +// Let N be the size of source (typically N will be the batch size). 
Split each +// element of `source` based on `sep` and return a `SparseTensor` +// containing the split tokens. Empty tokens are ignored. // -// ```python -// codebook_size = 50 -// codebook_bits = codebook_size * 32 -// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], -// dtype=x.dtype, -// initializer=tf.orthogonal_initializer()) -// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) -// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 -// # now codes has shape x.shape[:-1] + [codebook_size] +// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', +// then the output will be +// ``` +// st.indices = [0, 0; +// 0, 1; +// 1, 0; +// 1, 1; +// 1, 2] +// st.shape = [2, 3] +// st.values = ['hello', 'world', 'a', 'b', 'c'] // ``` // -// **NOTE**: Currently, the innermost dimension of the tensor must be divisible -// by 8. +// If `sep` is given, consecutive delimiters are not grouped together and are +// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and +// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty +// string, consecutive whitespace are regarded as a single separator, and the +// result will contain no empty strings at the startor end if the string has +// leading or trailing whitespace. // -// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is -// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. +// Note that the above mentioned behavior matches python's str.split. // // Arguments: -// input: Values to compare against `threshold` and bitpack. -// threshold: Threshold to compare against. -// -// Returns The bitpacked comparisons. -func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { +// input: `1-D` string `Tensor`, the strings to split. +// sep: `0-D` string `Tensor`, the delimiter character. 
+func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "CompareAndBitpack", + Type: "StringSplitV2", Input: []tf.Input{ - input, threshold, + input, sep, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Push an element onto the tensor_array. +// Creates a dataset that uses a custom thread pool to compute `input_dataset`. // // Arguments: -// handle: The handle to a TensorArray. -// index: The position to write to inside the TensorArray. -// value: The tensor to write to the TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// thread_pool: A resource produced by the ThreadPoolHandle op. +// +// +func ExperimentalThreadPoolDataset(scope *Scope, input_dataset tf.Output, thread_pool tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorArrayWriteV3", + Type: "ExperimentalThreadPoolDataset", Input: []tf.Input{ - handle, index, value, flow_in, + input_dataset, thread_pool, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Scatter the data from the input value into specific TensorArray elements. -// -// `indices` must be a vector, its length must match the first dim of `value`. -// -// Arguments: -// handle: The handle to a TensorArray. 
-// indices: The locations at which to write the tensor elements. -// value: The concatenated tensor to write to the TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// -// Returns A float scalar that enforces proper chaining of operations. -func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Computes softsign: `features / (abs(features) + 1)`. +func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayScatterV3", + Type: "Softsign", Input: []tf.Input{ - handle, indices, value, flow_in, + features, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// EmptyAttr is an optional argument to Empty. -type EmptyAttr func(optionalAttr) +// EncodeProtoAttr is an optional argument to EncodeProto. +type EncodeProtoAttr func(optionalAttr) -// EmptyInit sets the optional init attribute to value. -// -// value: If True, initialize the returned tensor with the default value of dtype. Otherwise, the implementation is free not to initializethe tensor's content. -// If not specified, defaults to false -func EmptyInit(value bool) EmptyAttr { +// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value. +// If not specified, defaults to "local://" +func EncodeProtoDescriptorSource(value string) EncodeProtoAttr { return func(m optionalAttr) { - m["init"] = value + m["descriptor_source"] = value } } -// Creates a tensor with the given shape. +// The op serializes protobuf messages provided in the input tensors. // -// This operation creates a tensor of `shape` and `dtype`. +// The types of the tensors in `values` must match the schema for the +// fields specified in `field_names`. All the tensors in `values` must +// have a common shape prefix, *batch_shape*. // -// Arguments: -// shape: 1-D. Represents the shape of the output tensor. 
+// The `sizes` tensor specifies repeat counts for each field. The repeat +// count (last dimension) of a each tensor in `values` must be greater +// than or equal to corresponding repeat count in `sizes`. +// +// A `message_type` name must be provided to give context for the field +// names. The actual message descriptor can be looked up either in the +// linked-in descriptor pool or a filename provided by the caller using +// the `descriptor_source` attribute. +// +// The `descriptor_source` attribute selects a source of protocol +// descriptors to consult when looking up `message_type`. This may be a +// filename containing a serialized `FileDescriptorSet` message, +// or the special value `local://`, in which case only descriptors linked +// into the code will be searched; the filename can be on any filesystem +// accessible to TensorFlow. // +// You can build a `descriptor_source` file using the `--descriptor_set_out` +// and `--include_imports` options to the protocol compiler `protoc`. // -// Returns A `Tensor` of type `T`. -func Empty(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...EmptyAttr) (output tf.Output) { +// The `local://` database only covers descriptors linked into the +// code via C++ libraries, not Python imports. You can link in a proto descriptor +// by creating a cc_library target with alwayslink=1. +// +// There are a few special cases in the value mapping: +// +// Submessage and group fields must be pre-serialized as TensorFlow strings. +// +// TensorFlow lacks support for unsigned int64s, so they must be +// represented as `tf.int64` with the same twos-complement bit pattern +// (the obvious way). +// +// Unsigned int32 values can be represented exactly with `tf.int64`, or +// with sign wrapping if the input is of type `tf.int32`. +// +// Arguments: +// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`. +// values: List of tensors containing values for the corresponding field. 
+// field_names: List of strings containing proto field names. +// message_type: Name of the proto message type to decode. +// +// Returns Tensor of serialized protos with shape `batch_shape`. +func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Empty", + Type: "EncodeProto", Input: []tf.Input{ - shape, + sizes, tf.OutputList(values), }, Attrs: attrs, } @@ -34918,285 +38581,419 @@ func Empty(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...EmptyAt return op.Output(0) } -// TensorArrayConcatV3Attr is an optional argument to TensorArrayConcatV3. -type TensorArrayConcatV3Attr func(optionalAttr) +// Creates an Optional variant with no value. +func OptionalNone(scope *Scope) (optional tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "OptionalNone", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// TensorArrayConcatV3ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. +// DecodeProtoV2Attr is an optional argument to DecodeProtoV2. +type DecodeProtoV2Attr func(optionalAttr) + +// DecodeProtoV2DescriptorSource sets the optional descriptor_source attribute to value. // -// value: The expected shape of an element, if known, -// excluding the first dimension. Used to validate the shapes of -// TensorArray elements. If this shape is not fully specified, concatenating -// zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayConcatV3ElementShapeExcept0(value tf.Shape) TensorArrayConcatV3Attr { +// value: Either the special value `local://` or a path to a file containing +// a serialized `FileDescriptorSet`. 
+// If not specified, defaults to "local://" +func DecodeProtoV2DescriptorSource(value string) DecodeProtoV2Attr { return func(m optionalAttr) { - m["element_shape_except0"] = value + m["descriptor_source"] = value } } -// Concat the elements from the TensorArray into value `value`. +// DecodeProtoV2MessageFormat sets the optional message_format attribute to value. // -// Takes `T` elements of shapes +// value: Either `binary` or `text`. +// If not specified, defaults to "binary" +func DecodeProtoV2MessageFormat(value string) DecodeProtoV2Attr { + return func(m optionalAttr) { + m["message_format"] = value + } +} + +// DecodeProtoV2Sanitize sets the optional sanitize attribute to value. // -// ``` -// (n0 x d0 x d1 x ...), (n1 x d0 x d1 x ...), ..., (n(T-1) x d0 x d1 x ...) -// ``` +// value: Whether to sanitize the result or not. +// If not specified, defaults to false +func DecodeProtoV2Sanitize(value bool) DecodeProtoV2Attr { + return func(m optionalAttr) { + m["sanitize"] = value + } +} + +// The op extracts fields from a serialized protocol buffers message into tensors. // -// and concatenates them into a Tensor of shape: +// The `decode_proto` op extracts fields from a serialized protocol buffers +// message into tensors. The fields in `field_names` are decoded and converted +// to the corresponding `output_types` if possible. // -// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)``` +// A `message_type` name must be provided to give context for the field +// names. The actual message descriptor can be looked up either in the +// linked-in descriptor pool or a filename provided by the caller using +// the `descriptor_source` attribute. // -// All elements must have the same shape (excepting the first dimension). +// Each output tensor is a dense tensor. This means that it is padded to +// hold the largest number of repeated elements seen in the input +// minibatch. (The shape is also padded by one to prevent zero-sized +// dimensions). 
The actual repeat counts for each example in the +// minibatch can be found in the `sizes` output. In many cases the output +// of `decode_proto` is fed immediately into tf.squeeze if missing values +// are not a concern. When using tf.squeeze, always pass the squeeze +// dimension explicitly to avoid surprises. +// +// For the most part, the mapping between Proto field types and +// TensorFlow dtypes is straightforward. However, there are a few +// special cases: +// +// - A proto field that contains a submessage or group can only be converted +// to `DT_STRING` (the serialized submessage). This is to reduce the +// complexity of the API. The resulting string can be used as input +// to another instance of the decode_proto op. +// +// - TensorFlow lacks support for unsigned integers. The ops represent uint64 +// types as a `DT_INT64` with the same twos-complement bit pattern +// (the obvious way). Unsigned int32 values can be represented exactly by +// specifying type `DT_INT64`, or using twos-complement if the caller +// specifies `DT_INT32` in the `output_types` attribute. +// +// The `descriptor_source` attribute selects a source of protocol +// descriptors to consult when looking up `message_type`. This may be a +// filename containing a serialized `FileDescriptorSet` message, +// or the special value `local://`, in which case only descriptors linked +// into the code will be searched; the filename can be on any filesystem +// accessible to TensorFlow. +// +// You can build a `descriptor_source` file using the `--descriptor_set_out` +// and `--include_imports` options to the protocol compiler `protoc`. +// +// The `local://` database only covers descriptors linked into the +// code via C++ libraries, not Python imports. You can link in a proto descriptor +// by creating a cc_library target with alwayslink=1. +// +// Both binary and text proto serializations are supported, and can be +// chosen using the `format` attribute. 
// // Arguments: -// handle: The handle to a TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. +// bytes: Tensor of serialized protos with shape `batch_shape`. +// message_type: Name of the proto message type to decode. +// field_names: List of strings containing proto field names. An extension field can be decoded +// by using its full name, e.g. EXT_PACKAGE.EXT_FIELD_NAME. +// output_types: List of TF types to use for the respective field in field_names. // -// Returns All of the elements in the TensorArray, concatenated along the first -// axis.A vector of the row sizes of the original T elements in the -// value output. In the example above, this would be the values: -// `(n1, n2, ..., n(T-1))`. -func TensorArrayConcatV3(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV3Attr) (value tf.Output, lengths tf.Output) { +// Returns Tensor of int32 with shape `[batch_shape, len(field_names)]`. +// Each entry is the number of values found for the corresponding field. +// Optional fields may have 0 or 1 values.List of tensors containing values for the corresponding field. +// `values[i]` has datatype `output_types[i]` +// and shape `[batch_shape, max(sizes[...,i])]`. 
+func DecodeProtoV2(scope *Scope, bytes tf.Output, message_type string, field_names []string, output_types []tf.DataType, optional ...DecodeProtoV2Attr) (sizes tf.Output, values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"message_type": message_type, "field_names": field_names, "output_types": output_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayConcatV3", + Type: "DecodeProtoV2", Input: []tf.Input{ - handle, flow_in, + bytes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + if scope.Err() != nil { + return + } + var idx int + var err error + sizes = op.Output(idx) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("DecodeProtoV2", err) + return + } + return sizes, values } -// Split the data from the input value into TensorArray elements. -// -// Assuming that `lengths` takes on values -// -// ```(n0, n1, ..., n(T-1))``` +// Creates a dataset that splits a SparseTensor into elements row-wise. +func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseTensorSliceDataset", + Input: []tf.Input{ + indices, values, dense_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x / y element-wise for real types. // -// and that `value` has shape +// If `x` and `y` are reals, this will return the floating-point division. // -// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, +// *NOTE*: `Div` supports broadcasting. 
More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RealDiv", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that concatenates `input_dataset` with `another_dataset`. +func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ConcatenateDataset", + Input: []tf.Input{ + input_dataset, another_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. // -// this splits values into a TensorArray with T tensors. +// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the +// `filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each +// input channel is processed independently of the others with its own structuring +// function. The `output` tensor has shape +// `[batch, out_height, out_width, depth]`. The spatial dimensions of the output +// tensor depend on the `padding` algorithm. We currently only support the default +// "NHWC" `data_format`. // -// TensorArray index t will be the subtensor of values with starting position +// In detail, the grayscale morphological 2-D dilation is the max-sum correlation +// (for consistency with `conv2d`, we use unmirrored filters): // -// ```(n0 + n1 + ... 
+ n(t-1), 0, 0, ...)``` +// output[b, y, x, c] = +// max_{dy, dx} input[b, +// strides[1] * y + rates[1] * dy, +// strides[2] * x + rates[2] * dx, +// c] + +// filter[dy, dx, c] // -// and having size +// Max-pooling is a special case when the filter has size equal to the pooling +// kernel size and contains all zeros. // -// ```nt x d0 x d1 x ...``` +// Note on duality: The dilation of `input` by the `filter` is equal to the +// negation of the erosion of `-input` by the reflected `filter`. // // Arguments: -// handle: The handle to a TensorArray. -// value: The concatenated tensor to write to the TensorArray. -// lengths: The vector of lengths, how to split the rows of value into the -// TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// strides: The stride of the sliding window for each dimension of the input +// tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: The input stride for atrous morphological dilation. Must be: +// `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Returns 4-D with shape `[batch, out_height, out_width, depth]`. 
+func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, rates []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "TensorArraySplitV3", + Type: "Dilation2D", Input: []tf.Input{ - handle, value, lengths, flow_in, + input, filter, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes gradients for the scaled exponential linear (Selu) operation. +// Creates a dataset that shuffles and repeats elements from `input_dataset` +// +// pseudorandomly. // // Arguments: -// gradients: The backpropagated gradients to the corresponding Selu operation. -// outputs: The outputs of the corresponding Selu operation. // -// Returns The gradients: `gradients * (outputs + scale * alpha)` -// if outputs < 0, `scale * gradients` otherwise. -func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { +// buffer_size: The number of output elements to buffer in an iterator over +// this dataset. Compare with the `min_after_dequeue` attr when creating a +// `RandomShuffleQueue`. +// seed: A scalar seed for the random number generator. If either `seed` or +// `seed2` is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. +// count: A scalar representing the number of times the underlying dataset +// should be repeated. The default is `-1`, which results in infinite repetition. 
+// +// +func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "SeluGrad", + Type: "ShuffleAndRepeatDataset", Input: []tf.Input{ - gradients, outputs, + input_dataset, buffer_size, seed, seed2, count, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Get the current size of the TensorArray. +// Creates a dataset that caches elements from `input_dataset`. +// +// A CacheDataset will iterate over the input_dataset, and store tensors. If the +// cache already exists, the cache will be used. If the cache is inappropriate +// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error +// will the returned when used. // // Arguments: -// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). -// flow_in: A float scalar that enforces proper chaining of operations. // -// Returns The current size of the TensorArray. -func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { +// filename: A path on the filesystem where we should cache the dataset. Note: this +// will be a directory. +// +// +func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorArraySizeV3", + Type: "CacheDataset", Input: []tf.Input{ - handle, flow_in, + input_dataset, filename, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deprecated. 
Use TensorArrayGradV3 +// Creates a dataset that emits the records from one or more binary files. // -// DEPRECATED at GraphDef version 26: Use TensorArrayGradV3 -func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) { +// Arguments: +// filenames: A scalar or a vector containing the name(s) of the file(s) to be +// read. +// header_bytes: A scalar representing the number of bytes to skip at the +// beginning of a file. +// record_bytes: A scalar representing the number of bytes in each record. +// footer_bytes: A scalar representing the number of bytes to skip at the end +// of a file. +// buffer_size: A scalar representing the number of bytes to buffer. Must be > 0. +func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf.Output, record_bytes tf.Output, footer_bytes tf.Output, buffer_size tf.Output) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"source": source} opspec := tf.OpSpec{ - Type: "TensorArrayGradV2", + Type: "FixedLengthRecordDataset", Input: []tf.Input{ - handle, flow_in, + filenames, header_bytes, record_bytes, footer_bytes, buffer_size, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// SparseReduceMaxAttr is an optional argument to SparseReduceMax. -type SparseReduceMaxAttr func(optionalAttr) - -// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the max of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` -// instead of a sparse one. +// Gradients for batch normalization. 
// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() // -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +// This op is deprecated. See `tf.nn.batch_normalization`. // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this Tensor will be multiplied +// with the normalized Tensor. +// backprop: 4D backprop Tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. // -// Returns `R-K`-D. The reduced Tensor. 
-func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { +// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. +func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "SparseReduceMax", + Type: "BatchNormWithGlobalNormalizationGrad", Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, + t, m, v, gamma, backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// AsStringAttr is an optional argument to AsString. -type AsStringAttr func(optionalAttr) - -// AsStringPrecision sets the optional precision attribute to value. +// Creates a dataset that emits the records from one or more TFRecord files. // -// value: The post-decimal precision to use for floating point numbers. -// Only used if precision > -1. -// If not specified, defaults to -1 -func AsStringPrecision(value int64) AsStringAttr { - return func(m optionalAttr) { - m["precision"] = value +// Arguments: +// filenames: A scalar or vector containing the name(s) of the file(s) to be +// read. +// compression_type: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". 
+// buffer_size: A scalar representing the number of bytes to buffer. A value of +// 0 means no buffering will be performed. +func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { + if scope.Err() != nil { + return } -} - -// AsStringScientific sets the optional scientific attribute to value. -// -// value: Use scientific notation for floating point numbers. -// If not specified, defaults to false -func AsStringScientific(value bool) AsStringAttr { - return func(m optionalAttr) { - m["scientific"] = value + opspec := tf.OpSpec{ + Type: "TFRecordDataset", + Input: []tf.Input{ + filenames, compression_type, buffer_size, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// AsStringShortest sets the optional shortest attribute to value. -// -// value: Use shortest representation (either scientific or standard) for -// floating point numbers. -// If not specified, defaults to false -func AsStringShortest(value bool) AsStringAttr { - return func(m optionalAttr) { - m["shortest"] = value - } -} +// ExperimentalStatsAggregatorHandleAttr is an optional argument to ExperimentalStatsAggregatorHandle. +type ExperimentalStatsAggregatorHandleAttr func(optionalAttr) -// AsStringWidth sets the optional width attribute to value. -// -// value: Pad pre-decimal numbers to this width. -// Applies to both floating point and integer numbers. -// Only used if width > -1. -// If not specified, defaults to -1 -func AsStringWidth(value int64) AsStringAttr { +// ExperimentalStatsAggregatorHandleContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func ExperimentalStatsAggregatorHandleContainer(value string) ExperimentalStatsAggregatorHandleAttr { return func(m optionalAttr) { - m["width"] = value + m["container"] = value } } -// AsStringFill sets the optional fill attribute to value. -// -// value: The value to pad if width > -1. If empty, pads with spaces. 
-// Another typical value is '0'. String cannot be longer than 1 character. +// ExperimentalStatsAggregatorHandleSharedName sets the optional shared_name attribute to value. // If not specified, defaults to "" -func AsStringFill(value string) AsStringAttr { +func ExperimentalStatsAggregatorHandleSharedName(value string) ExperimentalStatsAggregatorHandleAttr { return func(m optionalAttr) { - m["fill"] = value + m["shared_name"] = value } } -// Converts each entry in the given tensor to strings. Supports many numeric -// -// types and boolean. -func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { +// Creates a statistics manager resource. +func ExperimentalStatsAggregatorHandle(scope *Scope, optional ...ExperimentalStatsAggregatorHandleAttr) (handle tf.Output) { if scope.Err() != nil { return } @@ -35205,363 +39002,191 @@ func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output t a(attrs) } opspec := tf.OpSpec{ - Type: "AsString", - Input: []tf.Input{ - input, - }, + Type: "ExperimentalStatsAggregatorHandle", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deprecated. Use TensorArrayScatterV3 +// A container for an iterator resource. // -// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 -func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Returns A handle to the iterator that can be passed to a "MakeIterator" or +// "IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents +// resource sharing by name, and does not keep a reference to the resource +// container. 
+func AnonymousIterator(scope *Scope, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorArrayScatterV2", - Input: []tf.Input{ - handle, indices, value, flow_in, - }, + Type: "AnonymousIterator", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Applies sparse addition to `input` using individual values or slices -// -// from `updates` according to indices `indices`. The updates are non-aliasing: -// `input` is only modified in-place if no other operations will use it. -// Otherwise, a copy of `input` is made. This operation has a gradient with -// respect to both `input` and `updates`. -// -// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `input`. -// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`. -// -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or `(P-K)`-dimensional slices -// (if `K < P`) along the `K`th dimension of `input`. -// -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: -// -// $$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$ -// -// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 -// elements. In Python, that addition would look like this: -// -// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) -// with tf.Session() as sess: -// print(sess.run(output)) +// Adjust the contrast of one or more images. // -// The resulting value `output` would look like this: +// `images` is a tensor of at least 3 dimensions. 
The last 3 dimensions are +// interpreted as `[height, width, channels]`. The other dimensions only +// represent a collection of images, such as `[batch, height, width, channels].` // -// [1, 13, 3, 14, 14, 6, 7, 20] +// Contrast is adjusted independently for each channel of each image. // -// See `tf.scatter_nd` for more details about how to make updates to slices. +// For each channel, the Op first computes the mean of the image pixels in the +// channel and then adjusts each component of each pixel to +// `(x - mean) * contrast_factor + mean`. // // Arguments: -// input: A Tensor. -// indices: A Tensor. Must be one of the following types: `int32`, `int64`. -// A tensor of indices into `input`. -// updates: A Tensor. Must have the same type as ref. A tensor of updated values -// to add to `input`. +// images: Images to adjust. At least 3-D. +// contrast_factor: A float multiplier for adjusting contrast. // -// Returns A `Tensor` with the same shape as `input`, containing values of `input` -// updated with `updates`. -func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { +// Returns The contrast-adjusted image or images. +func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ScatterNdNonAliasingAdd", + Type: "AdjustContrastv2", Input: []tf.Input{ - input, indices, updates, + images, contrast_factor, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. -type FractionalMaxPoolAttr func(optionalAttr) - -// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. -// -// value: When set to True, generates the pooling sequence in a -// pseudorandom fashion, otherwise, in a random fashion. 
Check paper [Benjamin -// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for -// difference between pseudorandom and random. -// If not specified, defaults to false -func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["pseudo_random"] = value - } -} - -// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. -// -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: -// -// `index 0 1 2 3 4` -// -// `value 20 5 16 3 7` -// -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [20, 16] for fractional max pooling. -// If not specified, defaults to false -func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["overlapping"] = value +// Gets the next output from the given iterator . +func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { + if scope.Err() != nil { + return } -} - -// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. -// -// value: When set to True, a fixed pooling region will be used when -// iterating over a FractionalMaxPool node in the computation graph. Mainly used -// in unit test to make FractionalMaxPool deterministic. -// If not specified, defaults to false -func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["deterministic"] = value + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "IteratorGetNext", + Input: []tf.Input{ + iterator, + }, + Attrs: attrs, } -} - -// FractionalMaxPoolSeed sets the optional seed attribute to value. 
-// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["seed"] = value + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return } -} - -// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["seed2"] = value + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("IteratorGetNext", err) + return } + return components } -// Performs fractional max pooling on the input. -// -// Fractional max pooling is slightly different than regular max pooling. In -// regular max pooling, you downsize an input set by taking the maximum value of -// smaller N x N subsections of the set (often 2x2), and try to reduce the set by -// a factor of N, where N is an integer. Fractional max pooling, as you might -// expect from the word "fractional", means that the overall reduction ratio N -// does not have to be an integer. -// -// The sizes of the pooling regions are generated randomly but are fairly uniform. -// For example, let's look at the height dimension, and the constraints on the -// list of rows that will be pool boundaries. -// -// First we define the following: -// -// 1. input_row_length : the number of rows from the input set -// 2. output_row_length : which will be smaller than the input -// 3. alpha = input_row_length / output_row_length : our reduction ratio -// 4. K = floor(alpha) -// 5. 
row_pooling_sequence : this is the result list of pool boundary rows -// -// Then, row_pooling_sequence should satisfy: +// Outputs the single element from the given dataset. // -// 1. a[0] = 0 : the first value of the sequence is 0 -// 2. a[end] = input_row_length : the last value of the sequence is the size -// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size -// 4. length(row_pooling_sequence) = output_row_length+1 +// Arguments: +// dataset: A handle to a dataset that contains a single element. // -// For more details on fractional max pooling, see this paper: -// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) // -// Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// pooling_ratio: Pooling ratio for each dimension of `value`, currently only -// supports row and col dimension and should be >= 1.0. For example, a valid -// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements -// must be 1.0 because we don't allow pooling on batch and channels -// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions -// respectively. // -// Returns output tensor after fractional max pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. -func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { +// Returns The components of the single element of `input`. 
+func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "FractionalMaxPool", + Type: "DatasetToSingleElement", Input: []tf.Input{ - value, + dataset, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Creates a MultiDeviceIterator resource. -// -// Arguments: -// devices: A list of devices the iterator works across. -// shared_name: If non-empty, this resource will be shared under the given name -// across multiple sessions. -// container: If non-empty, this resource is placed in the given container. -// Otherwise, a default container is used. -// output_types: The type list for the return values. -// output_shapes: The list of shapes being produced. -// -// Returns Handle to the resource created. -func MultiDeviceIterator(scope *Scope, devices []string, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"devices": devices, "shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "MultiDeviceIterator", - - Attrs: attrs, + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("DatasetToSingleElement", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return components } -// Deprecated. Use TensorArraySizeV3 +// Converts the given `resource_handle` representing an iterator to a string. 
// -// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3 -func TensorArraySizeV2(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { +// Arguments: +// resource_handle: A handle to an iterator resource. +// +// Returns A string representation of the given handle. +func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArraySizeV2", + Type: "IteratorToStringHandle", Input: []tf.Input{ - handle, flow_in, + resource_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Conv2DAttr is an optional argument to Conv2D. -type Conv2DAttr func(optionalAttr) - -// Conv2DUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DUseCudnnOnGpu(value bool) Conv2DAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value - } -} +// IteratorFromStringHandleAttr is an optional argument to IteratorFromStringHandle. +type IteratorFromStringHandleAttr func(optionalAttr) -// Conv2DExplicitPaddings sets the optional explicit_paddings attribute to value. +// IteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. // -// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith -// dimension, the amount of padding inserted before and after the dimension is -// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If -// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. +// value: If specified, defines the type of each tuple component in an +// element produced by the resulting iterator. // If not specified, defaults to <> -func Conv2DExplicitPaddings(value []int64) Conv2DAttr { - return func(m optionalAttr) { - m["explicit_paddings"] = value - } -} - -// Conv2DDataFormat sets the optional data_format attribute to value. 
// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func Conv2DDataFormat(value string) Conv2DAttr { +// REQUIRES: len(value) >= 0 +func IteratorFromStringHandleOutputTypes(value []tf.DataType) IteratorFromStringHandleAttr { return func(m optionalAttr) { - m["data_format"] = value + m["output_types"] = value } } -// Conv2DDilations sets the optional dilations attribute to value. +// IteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv2DDilations(value []int64) Conv2DAttr { +// value: If specified, defines the shape of each tuple component in an +// element produced by the resulting iterator. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func IteratorFromStringHandleOutputShapes(value []tf.Shape) IteratorFromStringHandleAttr { return func(m optionalAttr) { - m["dilations"] = value + m["output_shapes"] = value } } -// Computes a 2-D convolution given 4-D `input` and `filter` tensors. -// -// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` -// and a filter / kernel tensor of shape -// `[filter_height, filter_width, in_channels, out_channels]`, this op -// performs the following: -// -// 1. Flattens the filter to a 2-D matrix with shape -// `[filter_height * filter_width * in_channels, output_channels]`. -// 2. 
Extracts image patches from the input tensor to form a *virtual* -// tensor of shape `[batch, out_height, out_width, -// filter_height * filter_width * in_channels]`. -// 3. For each patch, right-multiplies the filter matrix and the image patch -// vector. -// -// In detail, with the default NHWC format, -// -// output[b, i, j, k] = -// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * -// filter[di, dj, q, k] -// -// Must have `strides[0] = strides[3] = 1`. For the most common case of the same -// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// Converts the given string representing a handle to an iterator to a resource. // // Arguments: -// input: A 4-D tensor. The dimension order is interpreted according to the value -// of `data_format`, see below for details. -// filter: A 4-D tensor of shape -// `[filter_height, filter_width, in_channels, out_channels]` -// strides: 1-D tensor of length 4. The stride of the sliding window for each -// dimension of `input`. The dimension order is determined by the value of -// `data_format`, see below for details. -// padding: The type of padding algorithm to use. +// string_handle: A string representation of the given handle. // -// Returns A 4-D tensor. The dimension order is determined by the value of -// `data_format`, see below for details. -func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv2DAttr) (output tf.Output) { +// Returns A handle to an iterator resource. 
+func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...IteratorFromStringHandleAttr) (resource_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv2D", + Type: "IteratorFromStringHandle", Input: []tf.Input{ - input, filter, + string_handle, }, Attrs: attrs, } diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index af5503f2ad308fffb03d2ebd5964eec273896c72..d70e0d6c0ab2cfefb55ef04ad67b2a74f83ac48d 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -362,7 +362,8 @@ tf_cc_test( filegroup( name = "libtensorflow_jni", srcs = select({ - "//tensorflow:darwin": [":libtensorflow_jni.dylib"], + "//tensorflow:windows": [":tensorflow_jni.dll"], + "//tensorflow:macos": [":libtensorflow_jni.dylib"], "//conditions:default": [":libtensorflow_jni.so"], }), visibility = ["//visibility:public"], @@ -373,26 +374,25 @@ LINKER_VERSION_SCRIPT = ":config/version_script.lds" LINKER_EXPORTED_SYMBOLS = ":config/exported_symbols.lds" tf_cc_binary( - name = "libtensorflow_jni.so", + name = "tensorflow_jni", # Set linker options to strip out anything except the JNI # symbols from the library. This reduces the size of the library # considerably (~50% as of January 2017). 
linkopts = select({ "//tensorflow:debug": [], # Disable all custom linker options in debug mode - "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by LINKER_EXPORTED_SYMBOLS - "$(location {})".format(LINKER_EXPORTED_SYMBOLS), + "//tensorflow:macos": [ + "-Wl,-exported_symbols_list,$(location {})".format(LINKER_EXPORTED_SYMBOLS), ], "//tensorflow:windows": [], "//conditions:default": [ "-z defs", "-s", - "-Wl,--version-script", # This line must be directly followed by LINKER_VERSION_SCRIPT - "$(location {})".format(LINKER_VERSION_SCRIPT), + "-Wl,--version-script,$(location {})".format(LINKER_VERSION_SCRIPT), ], }), linkshared = 1, linkstatic = 1, + per_os_targets = True, deps = [ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/java/src/main/native", @@ -414,14 +414,3 @@ tf_cc_binary( srcs = ["generate_pom.cc"], deps = ["//tensorflow/c:c_api"], ) - -# System.loadLibrary() on OS X looks for ".dylib" or ".jnilib" -# and no ".so". If and when https://github.com/bazelbuild/bazel/issues/914 -# is resolved, perhaps this workaround rule can be removed. -genrule( - name = "darwin-compat", - srcs = [":libtensorflow_jni.so"], - outs = ["libtensorflow_jni.dylib"], - cmd = "cp $< $@", - output_to_bindir = 1, -) diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md index 951e8bdd0dd8aae46a361a8ffcff276579433641..4206f6f9fc8ed029d1a7d9b044dd079ec523de31 100644 --- a/tensorflow/java/README.md +++ b/tensorflow/java/README.md @@ -20,13 +20,13 @@ Releases built from release branches are available on Maven Central. 
Additionally, every day binaries are built from the `master` branch on GitHub: -- [JAR](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow.jar) -- [Sourc JAR](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-src.jar) -- JNI: - - [Linux CPU-only](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow_jni-cpu-linux-x86_64.tar.gz) - - [Linux GPU](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow_jni-gpu-linux-x86_64.tar.gz) - - [MacOS](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow_jni-cpu-darwin-x86_64.tar.gz) - - Windows: (No nightly builds available yet) +- [JAR](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow.jar) +- [Source JAR](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-src.jar) +- JNI: + - [Linux CPU-only](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow_jni-cpu-linux-x86_64.tar.gz) + - [Linux GPU](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow_jni-gpu-linux-x86_64.tar.gz) + - [MacOS](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow_jni-cpu-darwin-x86_64.tar.gz) + - Windows: (No nightly builds available yet) ## Building from source diff --git a/tensorflow/java/build_defs.bzl b/tensorflow/java/build_defs.bzl index e1916ca4d9d6aa179e1a69451a5e981783560026..f423cc4d8277509d45aa8344e322f71b7f1306a8 100644 --- a/tensorflow/java/build_defs.bzl +++ b/tensorflow/java/build_defs.bzl @@ -18,7 +18,7 @@ XLINT_OPTS = [ "-Xlint:-processing", "-Xlint:-serial", "-Xlint:-try", - "-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile + "-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile ] # The bazel 
errorprone plugin currently only enables default errorChecks diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index db3a3609f1ac4fda18ff5a1248e61c675a8bf9f9..7e04af42be0086bbc510bf07096f65d0c3a9ddda 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.12.0 + 1.13.0-rc2 ../ libtensorflow diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index 53f7a2d63ef5bc8cfe4fbe372cf2fd3f58a0fe33..8a4d471fa401cf031c2fe5b494aa55744fc882a4 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.12.0 + 1.13.0-rc2 ../ libtensorflow_jni diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index a17724c805e38239c61dd27a5cc9ec918bbb2e0f..2e2625641dd8f7eb61d17f52396d7d957ecce2d6 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.12.0 + 1.13.0-rc2 ../ libtensorflow_jni_gpu diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index 30831f90b9f7b4beb5ae3f2ceebadcb6e1f8771e..2cd600541d09994437438a86877f1ec98922fcb2 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ 4.0.0 org.tensorflow parentpom - 1.12.0 + 1.13.0-rc2 pom https://www.tensorflow.org diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index dd6b52be62487ba6cb989b4917a15df7f473a848..b79ee472c378e36ed795adb86cf24e19635f1c58 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.12.0 + 1.13.0-rc2 ../ proto diff --git 
a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml index f47c11809d58464953028c388d491b91f67c3510..70091ad1571631ffd200550e35dba39025dbbca2 100644 --- a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml +++ b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml @@ -6,7 +6,7 @@ org.tensorflow spark-tensorflow-connector_2.11 jar - 1.12.0 + 1.13.0-rc2 spark-tensorflow-connector https://www.tensorflow.org TensorFlow TFRecord connector for Apache Spark DataFrames diff --git a/tensorflow/java/maven/tensorflow-hadoop/pom.xml b/tensorflow/java/maven/tensorflow-hadoop/pom.xml index 11aaba983f6ded9a6e757703fd9a2411db82ceb6..94d6801cda89670d784f6e2c4c6779d9c3eb39c3 100644 --- a/tensorflow/java/maven/tensorflow-hadoop/pom.xml +++ b/tensorflow/java/maven/tensorflow-hadoop/pom.xml @@ -5,7 +5,7 @@ org.tensorflow tensorflow-hadoop jar - 1.12.0 + 1.13.0-rc2 tensorflow-hadoop https://www.tensorflow.org TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 07fcfa5144600f7d9cbf6edbfbecbecc7c115631..5f78f89b93b984e561dfc98b6cb4e8a8a3314b72 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.12.0 + 1.13.0-rc2 ../ tensorflow diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 5d6387e88e96802e9226774abd391ac2dd673143..db6116bd5c843c2846d6b9f67e253e87db6daffc 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -516,7 +516,7 @@ bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) { return false; } for (const auto& attr : op_def.attr()) { - if (attr.type() == "func") { + if (attr.type() == "func" || attr.type() == "list(func)") { return false; // TODO(karllessard) add support for function 
attributes } } diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index f4ff34ea0361fba5528126b93f3f6e45289d8df2..b46721a93dcbd105dea7c52e8ea615cbd00af1c8 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -17,46 +17,48 @@ load( # and then archive those source files into # ops/gen_sources.srcjar # -def tf_java_op_gen_srcjar(name, - gen_tool, - base_package, - api_def_srcs=[], - out_dir="ops/", - out_src_dir="src/main/java/", - visibility=["//tensorflow/java:__pkg__"]): +def tf_java_op_gen_srcjar( + name, + gen_tool, + base_package, + api_def_srcs = [], + out_dir = "ops/", + out_src_dir = "src/main/java/", + visibility = ["//tensorflow/java:__pkg__"]): + gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files + srcs = api_def_srcs[:] - gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files - srcs = api_def_srcs[:] + if not api_def_srcs: + api_def_args_str = "," + else: + api_def_args = [] + for api_def_src in api_def_srcs: + # Add directory of the first ApiDef source to args. + # We are assuming all ApiDefs in a single api_def_src are in the + # same directory. + api_def_args.append( + "$$(dirname $$(echo $(locations " + api_def_src + + ") | cut -d\" \" -f1))", + ) + api_def_args_str = ",".join(api_def_args) - if not api_def_srcs: - api_def_args_str = "," - else: - api_def_args = [] - for api_def_src in api_def_srcs: - # Add directory of the first ApiDef source to args. - # We are assuming all ApiDefs in a single api_def_src are in the - # same directory. 
- api_def_args.append( - "$$(dirname $$(echo $(locations " + api_def_src + - ") | cut -d\" \" -f1))") - api_def_args_str = ",".join(api_def_args) + gen_cmds += ["$(location " + gen_tool + ")" + + " --output_dir=$(@D)/" + out_src_dir + + " --base_package=" + base_package + + " --api_dirs=" + api_def_args_str] - gen_cmds += ["$(location " + gen_tool + ")" + - " --output_dir=$(@D)/" + out_src_dir + - " --base_package=" + base_package + - " --api_dirs=" + api_def_args_str] + # Generate a source archive containing generated code for these ops. + gen_srcjar = out_dir + name + ".srcjar" + gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"] - # Generate a source archive containing generated code for these ops. - gen_srcjar = out_dir + name + ".srcjar" - gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"] - - native.genrule( - name=name, - srcs=srcs, - outs=[gen_srcjar], - tools=[ - "@local_jdk//:jar", - "@local_jdk//:jdk", - gen_tool - ] + tf_binary_additional_srcs(), - cmd=" && ".join(gen_cmds)) + native.genrule( + name = name, + srcs = srcs, + outs = [gen_srcjar], + tools = [ + "@local_jdk//:jar", + "@local_jdk//:jdk", + gen_tool, + ] + tf_binary_additional_srcs(), + cmd = " && ".join(gen_cmds), + ) diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index df1426ad75143d720f1d5bd3cf4ce44d30cb226e..c668d4dda8cf729b2fd70644eb2a2cd38ae0923b 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -357,10 +357,10 @@ public final class OperatorProcessor extends AbstractProcessor { + " // Optional attributes\n" + " ops.math().matMul(a, b, MatMul.transposeA(true));\n" + " // Naming operators\n" - + " ops.withName(“foo”).constant(5); // name “foo”\n" + + " 
ops.withName(\"foo\").constant(5); // name \"foo\"\n" + " // Names can exist in a hierarchy\n" - + " Ops sub = ops.withSubScope(“sub”);\n" - + " sub.withName(“bar”).constant(4); // “sub/bar”\n" + + " Ops sub = ops.withSubScope(\"sub\");\n" + + " sub.withName(\"bar\").constant(4); // \"sub/bar\"\n" + "}\n" + "}\n", T_GRAPH, diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index 752b49af040268d7e3355b12e4ae6aae310789bd..d5dae187197347a5a77b3c9d819321f7f58841e7 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -235,7 +235,116 @@ public final class Graph implements AutoCloseable { public Output[] addGradients(Output y, Output[] x) { return addGradients(null, new Output[] {y}, x, null); } - + + /** + * Used to instantiate an abstract class which overrides the buildSubgraph method to build a + * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to + * create a lambda for the same purpose. + * + *

To be used when calling {@link #whileLoop(Output[], + * org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)} + * + *

Example usage (prior to Java 8): + * + *

{@code WhileSubgraphBuilder bodyGraphBuilder = new WhileSubgraphBuilder() { @Override public + * void buildSubgraph(Graph bodyGraph, Output[] bodyInputs, Output[] bodyOutputs) { // build + * body subgraph } }; } + * + *

Example usage (after Java 8): + * + *

{@code WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> { // + * build body subgraph };} + */ + public interface WhileSubgraphBuilder { + /** + * To be overridden by user with code to build conditional or body subgraph for a while loop + * + * @param g the subgraph + * @param inputs subgraph inputs + * @param outputs subgraph outputs + */ + public void buildSubgraph(Graph g, Output[] inputs, Output[] outputs); + } + + // called by while loop code in graph_jni.cc to construct conditional/body subgraphs + private static long[] buildSubgraph( + WhileSubgraphBuilder subgraphBuilder, + long subgraphHandle, + long[] inputHandles, + int[] inputIndices, + long[] outputHandles, + int[] outputIndices) { + Graph subgraph = new Graph(subgraphHandle); + + int ninputs = inputHandles.length; + int noutputs = outputHandles.length; + Output[] inputs = new Output[ninputs]; + Output[] outputs = new Output[noutputs]; + long[] outputHandlesAndIndices = new long[noutputs * 2]; + + synchronized (subgraph.nativeHandleLock) { + try (Reference ref = subgraph.ref()) { + + for (int i = 0; i < ninputs; i++) { + Operation op = new Operation(subgraph, inputHandles[i]); + inputs[i] = new Output<>(op, inputIndices[i]); + } + + for (int i = 0; i < noutputs; i++) { + Operation op = new Operation(subgraph, outputHandles[i]); + outputs[i] = new Output<>(op, outputIndices[i]); + } + + subgraphBuilder.buildSubgraph(subgraph, inputs, outputs); + + for (int i = 0, j = noutputs; i < noutputs; i++, j++) { + outputHandlesAndIndices[i] = outputs[i].op().getUnsafeNativeHandle(); + outputHandlesAndIndices[j] = (long) outputs[i].index(); + } + } + return outputHandlesAndIndices; + } + } + + /** + * Builds a while loop. 
+ * + * @param inputs the loop inputs + * @param cgBuilder WhileSubgraphBuilder to build the conditional subgraph + * @param bgBuilder WhileSubgraphBuilder to build the body subgraph + * @param name name for the loop + * @return list of loop outputs, of the same length as {@code inputs} + */ + public Output[] whileLoop( + Output[] inputs, + WhileSubgraphBuilder cgBuilder, + WhileSubgraphBuilder bgBuilder, + String name) { + int ninputs = inputs.length; + long[] inputHandles = new long[ninputs]; + int[] inputIndices = new int[ninputs]; + Output[] outputs = new Output[ninputs]; + + synchronized (nativeHandleLock) { + try (Reference ref = ref()) { + + for (int i = 0; i < ninputs; i++) { + inputHandles[i] = inputs[i].op().getUnsafeNativeHandle(); + inputIndices[i] = inputs[i].index(); + } + + long[] outputHandlesAndIndices = + whileLoop(nativeHandle, inputHandles, inputIndices, name, cgBuilder, bgBuilder); + + for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { + Operation op = new Operation(this, outputHandlesAndIndices[i]); + outputs[i] = new Output<>(op, (int) outputHandlesAndIndices[j]); + } + } + return outputs; + } + } + private final Object nativeHandleLock = new Object(); private long nativeHandle; private int refcount = 0; @@ -357,6 +466,14 @@ public final class Graph implements AutoCloseable { long[] gradInputHandles, int[] gradInputIndices); + private static native long[] whileLoop( + long handle, + long[] inputHandles, + int[] inputIndices, + String name, + WhileSubgraphBuilder condGraphBuilder, + WhileSubgraphBuilder bodyGraphBuilder); + static { TensorFlow.init(); } diff --git a/tensorflow/java/src/main/native/BUILD b/tensorflow/java/src/main/native/BUILD index 49348daa94ed04990a657922a0fbb515b7721d82..ff3b4102013aff3c63cc0b6ac9d17411194a3783 100644 --- a/tensorflow/java/src/main/native/BUILD +++ b/tensorflow/java/src/main/native/BUILD @@ -33,13 +33,12 @@ tf_cuda_library( "//tensorflow:android": [], "//conditions:default": ["."], }), - deps = [ - 
"//tensorflow/c:c_api", - ] + select({ + deps = select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib", ], "//conditions:default": [ + "//tensorflow/c:c_api", "//tensorflow/core:all_kernels", "//tensorflow/core:direct_session", "//tensorflow/core:ops", @@ -68,7 +67,7 @@ genrule( name = "copy_jni_md_h", srcs = select({ "//tensorflow:windows": ["@bazel_tools//tools/jdk:jni_md_header-windows"], - "//tensorflow:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], + "//tensorflow:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], }), outs = ["jni_md.h"], diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index f1744d87693ae8f43c032b24622aaecb41a30cb2..570ba8ac1074ace63f722a1af385a72e2d320b8d 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -18,19 +18,28 @@ limitations under the License. 
#include #include #include "tensorflow/c/c_api.h" -#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" +#include "tensorflow/java/src/main/native/utils_jni.h" namespace { -TF_Graph* requireHandle(JNIEnv* env, jlong handle) { - static_assert(sizeof(jlong) >= sizeof(TF_Graph*), +template +T* requireHandleImpl(JNIEnv* env, jlong handle) { + static_assert(sizeof(jlong) >= sizeof(T*), "Cannot package C object pointers as a Java long"); if (handle == 0) { throwException(env, kIllegalStateException, "close() has been called on the Graph"); return nullptr; } - return reinterpret_cast(handle); + return reinterpret_cast(handle); +} + +TF_Graph* requireHandle(JNIEnv* env, jlong handle) { + return requireHandleImpl(env, handle); +} + +TF_Operation* requireOperationHandle(JNIEnv* env, jlong handle) { + return requireHandleImpl(env, handle); } } // namespace @@ -56,10 +65,8 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env, return reinterpret_cast(op); } -JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv* env, - jclass clazz, - jlong handle, - jint position) { +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation( + JNIEnv* env, jclass clazz, jlong handle, jint position) { TF_Graph* g = requireHandle(env, handle); if (g == nullptr) return nullptr; @@ -189,3 +196,140 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients( return dy_handles_and_indices; } + +// helper function for while loop -- constructs conditional or body subgraph +jlongArray buildSubgraph(JNIEnv* env, jclass clazz, jobject subgraph_builder, + TF_Graph* const subgraph, + const TF_Output* const inputs, + const TF_Output* const outputs, const int ninputs, + const int noutputs) { + jmethodID build_subgraph_method_id = env->GetStaticMethodID( + clazz, "buildSubgraph", + "(Lorg/tensorflow/Graph$WhileSubgraphBuilder;J[J[I[J[I)[J"); + if (build_subgraph_method_id == 0) return 
nullptr; + + jlong subgraph_handle = reinterpret_cast(subgraph); + + jlongArray input_handles = env->NewLongArray(ninputs); + jintArray input_indices = env->NewIntArray(ninputs); + jlongArray output_handles = env->NewLongArray(noutputs); + jintArray output_indices = env->NewIntArray(noutputs); + + jlong* input_handles_elems = + env->GetLongArrayElements(input_handles, nullptr); + jint* input_indices_elems = env->GetIntArrayElements(input_indices, nullptr); + jlong* output_handles_elems = + env->GetLongArrayElements(output_handles, nullptr); + jint* output_indices_elems = + env->GetIntArrayElements(output_indices, nullptr); + + for (int i = 0; i < ninputs; ++i) { + input_handles_elems[i] = reinterpret_cast((inputs[i]).oper); + input_indices_elems[i] = static_cast((inputs[i]).index); + } + + for (int i = 0; i < noutputs; ++i) { + output_handles_elems[i] = reinterpret_cast((outputs[i]).oper); + output_indices_elems[i] = static_cast((outputs[i]).index); + } + + env->ReleaseLongArrayElements(input_handles, input_handles_elems, 0); + env->ReleaseIntArrayElements(input_indices, input_indices_elems, 0); + env->ReleaseLongArrayElements(output_handles, output_handles_elems, 0); + env->ReleaseIntArrayElements(output_indices, output_indices_elems, 0); + + // call Java code to construct the subgraph + jlongArray output_handles_and_indices = + (jlongArray)env->CallStaticObjectMethod( + clazz, build_subgraph_method_id, subgraph_builder, subgraph_handle, + input_handles, input_indices, output_handles, output_indices); + + if (env->ExceptionOccurred()) { + env->ExceptionDescribe(); + return nullptr; + } + + // returned array contains both op handles and output indices, in pair + return output_handles_and_indices; +} + +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop( + JNIEnv* env, jclass clazz, jlong handle, jlongArray input_handles, + jintArray input_indices, jstring name, jobject cond_graph_builder, + jobject body_graph_builder) { + TF_Graph* g = 
requireHandle(env, handle); + TF_Status* status = TF_NewStatus(); + if (g == nullptr) return nullptr; + + int ninputs = env->GetArrayLength(input_handles); + + std::unique_ptr inputs(new TF_Output[ninputs]); + resolveOutputs(env, "inputs", input_handles, input_indices, inputs.get(), + ninputs); + if (env->ExceptionCheck()) return nullptr; + + // initialize while params + TF_WhileParams params = TF_NewWhile(g, inputs.get(), ninputs, status); + throwExceptionIfNotOK(env, status); + + // build conditional subgraph + jlongArray cond_output_handles_and_indices = + buildSubgraph(env, clazz, cond_graph_builder, params.cond_graph, + params.cond_inputs, &params.cond_output, params.ninputs, 1); + + // build body subgraph + jlongArray body_output_handles_and_indices = buildSubgraph( + env, clazz, body_graph_builder, params.body_graph, params.body_inputs, + params.body_outputs, params.ninputs, params.ninputs); + + if (cond_output_handles_and_indices == nullptr || + body_output_handles_and_indices == nullptr) + return nullptr; + + // set cond_output param to output of the conditional subgraph + jlong* cond_output_elems = + env->GetLongArrayElements(cond_output_handles_and_indices, nullptr); + TF_Operation* cond_output_op = + requireOperationHandle(env, cond_output_elems[0]); + params.cond_output = {cond_output_op, + static_cast(cond_output_elems[1])}; + env->ReleaseLongArrayElements(cond_output_handles_and_indices, + cond_output_elems, 0); + + // set body_outputs param to outputs of the body subgraph + jlong* body_output_elems = + env->GetLongArrayElements(body_output_handles_and_indices, nullptr); + for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { + TF_Operation* body_output_op = + requireOperationHandle(env, body_output_elems[i]); + params.body_outputs[i] = {body_output_op, + static_cast(body_output_elems[j])}; + } + env->ReleaseLongArrayElements(body_output_handles_and_indices, + body_output_elems, 0); + + // set loop name param + params.name = env->GetStringUTFChars(name, 
0); + + // build the while loop, storing loop outputs in `outputs` + std::unique_ptr outputs(new TF_Output[ninputs]); + TF_FinishWhile(&params, status, outputs.get()); + + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + + env->ReleaseStringUTFChars(name, params.name); + + // returned array contains both op handles and output indices, in pair + jlongArray output_handles_and_indices = env->NewLongArray(ninputs * 2); + jlong* output_elems = + env->GetLongArrayElements(output_handles_and_indices, nullptr); + for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { + TF_Output output = outputs.get()[i]; + output_elems[i] = reinterpret_cast(output.oper); + output_elems[j] = static_cast(output.index); + } + env->ReleaseLongArrayElements(output_handles_and_indices, output_elems, 0); + + return output_handles_and_indices; +} diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h index efed23f83b6265e4df37cd8b35ce45576c415c43..4281297dca250604e3baf04c35af254fb126b888 100644 --- a/tensorflow/java/src/main/native/graph_jni.h +++ b/tensorflow/java/src/main/native/graph_jni.h @@ -51,8 +51,9 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv *, jclass, * Signature: (JI)[J */ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv *, - jclass, jlong, - jint); + jclass, + jlong, + jint); /* * Class: org_tensorflow_Graph @@ -82,6 +83,15 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients( JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray, jintArray, jlongArray, jintArray); +/* + * Class: org_tensorflow_Graph + * Method: whileLoop + * Signature: + * (J[J[IILjava/lang/String;Lorg/tensorflow/Graph/WhileSubgraphBuilder;Lorg/tensorflow/Graph/WhileSubgraphBuilder;)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop( + JNIEnv *, jclass, jlong, jlongArray, jintArray, jstring, jobject, jobject); + #ifdef __cplusplus } // extern "C" #endif 
// __cplusplus diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index 7c05c1deafeea5d0b482a70f528d997a3394b365..f0428a1ae6c23a132cd2c349ac8bc7a354f3ecf6 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -254,7 +254,115 @@ public class GraphTest { } } } - + + @Test + public void buildWhileLoopSingleInput() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output input = TestUtil.placeholder(g, "input1", Integer.class); + + // could write this using lambda after Java 8 + Graph.WhileSubgraphBuilder condGraphBuilder = + new Graph.WhileSubgraphBuilder() { + @Override + public void buildSubgraph( + Graph condGraph, Output[] condInputs, Output[] condOutputs) { + Output sixteen = TestUtil.constant(condGraph, "sixteen", 16); + // condInputs[0] < 16 + Output condOutput = + condGraph + .opBuilder("Less", "cond") + .addInput(condInputs[0]) + .addInput(sixteen) + .build() + .output(0); + + condOutputs[0] = condOutput; + } + }; + + // could write this using lambda after Java 8 + Graph.WhileSubgraphBuilder bodyGraphBuilder = + new Graph.WhileSubgraphBuilder() { + @Override + public void buildSubgraph( + Graph bodyGraph, Output[] bodyInputs, Output[] bodyOutputs) { + bodyOutputs[0] = TestUtil.square(bodyGraph, "square", bodyInputs[0]); + } + }; + + Output[] loopOutputs = + g.whileLoop(toArray(input), condGraphBuilder, bodyGraphBuilder, "test_loop"); + + try (Tensor c = Tensors.create(2); + Tensor output = s.runner().feed(input, c).fetch(loopOutputs[0]).run().get(0)) { + + assertEquals(16, output.intValue()); // ((2^2)^2) + } + } + } + + @Test + public void buildWhileLoopMultipleInputs() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output input1 = TestUtil.placeholder(g, "input1", Integer.class); + Output input2 = TestUtil.placeholder(g, "input2", Integer.class); + 
Output[] inputs = toArray(input1, input2); + + // could write this using lambda after Java 8 + Graph.WhileSubgraphBuilder condGraphBuilder = + new Graph.WhileSubgraphBuilder() { + @Override + public void buildSubgraph( + Graph condGraph, Output[] condInputs, Output[] condOutputs) { + Output sixteen = TestUtil.constant(condGraph, "sixteen", 16); + Output condOutput = + condGraph + .opBuilder("Less", "cond") + .addInput(condInputs[0]) + .addInput(sixteen) + .build() + .output(0); // condInputs[0] < 16 + + condOutputs[0] = condOutput; + } + }; + + // could write this using lambda after Java 8 + Graph.WhileSubgraphBuilder bodyGraphBuilder = + new Graph.WhileSubgraphBuilder() { + @Override + public void buildSubgraph( + Graph bodyGraph, Output[] bodyInputs, Output[] bodyOutputs) { + bodyOutputs[0] = TestUtil.square(bodyGraph, "square1", bodyInputs[0]); + bodyOutputs[1] = TestUtil.square(bodyGraph, "square2", bodyInputs[1]); + } + }; + + Output[] loopOutputs = + g.whileLoop(inputs, condGraphBuilder, bodyGraphBuilder, "test_loop"); + + try (Tensor c1 = Tensors.create(2); + Tensor c2 = Tensors.create(5); + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList<>( + s.runner() + .feed(input1, c1) + .feed(input2, c2) + .fetch(loopOutputs[0]) + .fetch(loopOutputs[1]) + .run())) { + + assertEquals(2, outputs.size()); + assertEquals(16, outputs.get(0).intValue()); // ((2^2)^2) + assertEquals(625, outputs.get(1).intValue()); // ((5^2)^2) + } + } + } + private static Output[] toArray(Output... 
outputs) { return outputs; } diff --git a/tensorflow/lite/Android.bp b/tensorflow/lite/Android.bp index a059c43d1a95183ebb7f20b6b5ec87fe726eb73a..4e54edc49f037da90e305f8a845c7918143ab64b 100644 --- a/tensorflow/lite/Android.bp +++ b/tensorflow/lite/Android.bp @@ -52,10 +52,13 @@ cc_library_static { "core/subgraph.cc", "graph_info.cc", "interpreter.cc", + "minimal_logging.cc", + "minimal_logging_android.cc", "mmap_allocation.cc", "model.cc", "mutable_op_resolver.cc", "nnapi_delegate.cc", + "nnapi/nnapi_implementation.cc", "optional_debug_tools.cc", "simple_memory_arena.cc", "stderr_reporter.cc", @@ -72,6 +75,7 @@ cc_library_static { cflags: [ "-Wno-deprecated-declarations", "-Wno-extern-c-compat", + "-Wno-ignored-attributes", "-Wno-invalid-partial-specialization", "-Wno-mismatched-tags", "-Wno-sign-compare", diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 18fd1c10126820527ad844e94a490d9af514d6d4..e9eb7af8ea4145f2fc6783821a7c75b8ef399cf5 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -4,13 +4,15 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/lite:build_def.bzl", "tflite_copts") -load("//tensorflow:tensorflow.bzl", "if_not_windows") +load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test") +load("//tensorflow/lite:build_def.bzl", "tflite_cc_shared_object", "tflite_copts") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") exports_files(glob([ "testdata/*.bin", "testdata/*.pb", + "testdata/*.tflite", + "testdata/*.csv", "models/testdata/*", ])) @@ -39,6 +41,7 @@ config_setting( TFLITE_DEFAULT_COPTS = if_not_windows([ "-Wall", "-Wno-comment", + "-Wno-extern-c-compat", ]) cc_library( @@ -173,27 +176,18 @@ cc_library( "stderr_reporter.h", ], copts = tflite_copts() + TFLITE_DEFAULT_COPTS, - linkopts = [ - ] + select({ - "//tensorflow:android": [ - "-llog", - ], - "//conditions:default": [ - ], - }), deps = [ ":arena_planner", 
":graph_info", ":memory_planner", + ":minimal_logging", ":schema_fbs_version", ":simple_memory_arena", ":string", ":util", "//tensorflow/lite/c:c_api_internal", - "//tensorflow/lite/core/api:api", - "//tensorflow/lite/kernels:eigen_support", - "//tensorflow/lite/kernels:gemm_support", - "//tensorflow/lite/nnapi:nnapi_lib", + "//tensorflow/lite/core/api", + "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/profiling:profiler", "//tensorflow/lite/schema:schema_fbs", ] + select({ @@ -219,6 +213,10 @@ cc_test( name = "string_util_test", size = "small", srcs = ["string_util_test.cc"], + features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs + tags = [ + "tflite_not_portable_ios", # TODO(b/117786830) + ], deps = [ ":framework", ":string_util", @@ -233,10 +231,13 @@ cc_test( name = "interpreter_test", size = "small", srcs = ["interpreter_test.cc"], + features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs + tags = [ + "tflite_not_portable_ios", # TODO(b/117786830) + ], deps = [ ":framework", ":string_util", - "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/core/api", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:kernel_util", @@ -252,6 +253,10 @@ cc_test( name = "graph_info_test", size = "small", srcs = ["graph_info_test.cc"], + features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs + tags = [ + "tflite_not_portable_ios", # TODO(b/117786830) + ], deps = [ ":framework", "//tensorflow/lite/testing:util", @@ -264,6 +269,10 @@ cc_test( name = "simple_memory_arena_test", size = "small", srcs = ["simple_memory_arena_test.cc"], + features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs + tags = [ + "tflite_not_portable_ios", # TODO(b/117786830) + ], deps = [ ":simple_memory_arena", "//tensorflow/lite/testing:util", @@ -284,9 +293,11 @@ cc_test( "testdata/test_model.bin", "testdata/test_model_broken.bin", ], + tags = [ + "tflite_not_portable", + ], deps = [ 
":framework", - "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/core/api", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/testing:util", @@ -323,6 +334,10 @@ cc_test( name = "mutable_op_resolver_test", size = "small", srcs = ["mutable_op_resolver_test.cc"], + features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs + tags = [ + "tflite_not_portable_ios", # TODO(b/117786830) + ], deps = [ ":framework", "//tensorflow/lite/testing:util", @@ -344,9 +359,76 @@ cc_test( name = "util_test", size = "small", srcs = ["util_test.cc"], + features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs + tags = [ + "tflite_not_portable_ios", # TODO(b/117786830) + ], deps = [ ":util", "//tensorflow/lite/c:c_api_internal", "@com_google_googletest//:gtest", ], ) + +cc_library( + name = "minimal_logging", + srcs = [ + "minimal_logging.cc", + ] + select({ + "//tensorflow:android": [ + "minimal_logging_android.cc", + ], + "//tensorflow:ios": [ + "minimal_logging_ios.cc", + ], + "//conditions:default": [ + "minimal_logging_default.cc", + ], + }), + hdrs = ["minimal_logging.h"], + copts = TFLITE_DEFAULT_COPTS + tflite_copts(), + linkopts = select({ + "//tensorflow:android": ["-llog"], + "//conditions:default": [], + }), + visibility = ["//visibility:private"], +) + +cc_test( + name = "minimal_logging_test", + size = "small", + srcs = ["minimal_logging_test.cc"], + tags = [ + "tflite_not_portable_ios", # TODO(b/117786830) + ], + deps = [ + ":minimal_logging", + "@com_google_googletest//:gtest", + ], +) + +# Shared lib target for convenience, pulls in the core runtime and builtin ops. +# Note: This target is not yet finalized, and the exact set of exported (C/C++) +# APIs is subject to change. 
+tflite_cc_shared_object( + name = "libtensorflowlite.so", + linkopts = select({ + "//tensorflow:macos": [ + "-Wl,-exported_symbols_list,$(location //tensorflow/lite:tflite_exported_symbols.lds)", + "-Wl,-install_name,@rpath/libtensorflowlite.so", + ], + "//tensorflow:windows": [], + "//conditions:default": [ + "-z defs", + "-Wl,--version-script,$(location //tensorflow/lite:tflite_version_script.lds)", + ], + }), + deps = [ + ":framework", + ":tflite_exported_symbols.lds", + ":tflite_version_script.lds", + "//tensorflow/lite/kernels:builtin_ops", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 5e6b13a46ae2ac8d4a1cfc2f802648ef27742e22..2b30309d83a544889fe80c9fd3eac293ef557d6b 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -2,6 +2,7 @@ load( "//tensorflow:tensorflow.bzl", + "tf_binary_additional_srcs", "tf_cc_shared_object", "tf_cc_test", ) @@ -85,24 +86,27 @@ def tflite_jni_linkopts_unstripped(): "//conditions:default": [], }) -def tflite_linkopts(): - """Defines linker flags to reduce size of TFLite binary.""" - return tflite_linkopts_unstripped() + select({ +def tflite_symbol_opts(): + """Defines linker flags whether to include symbols or not.""" + return select({ "//tensorflow:android": [ - "-s", # Omit symbol table. + "-latomic", # Required for some uses of ISO C++11 in x86. ], "//conditions:default": [], + }) + select({ + "//tensorflow:debug": [], + "//conditions:default": [ + "-s", # Omit symbol table, for all non debug builds + ], }) +def tflite_linkopts(): + """Defines linker flags to reduce size of TFLite binary.""" + return tflite_linkopts_unstripped() + tflite_symbol_opts() + def tflite_jni_linkopts(): """Defines linker flags to reduce size of TFLite binary with JNI.""" - return tflite_jni_linkopts_unstripped() + select({ - "//tensorflow:android": [ - "-s", # Omit symbol table. - "-latomic", # Required for some uses of ISO C++11 in x86. 
- ], - "//conditions:default": [], - }) + return tflite_jni_linkopts_unstripped() + tflite_symbol_opts() def tflite_jni_binary( name, @@ -157,7 +161,7 @@ def tf_to_tflite(name, src, options, out): """ toco_cmdline = " ".join([ - "//tensorflow/lite/toco:toco", + "$(location //tensorflow/lite/toco:toco)", "--input_format=TENSORFLOW_GRAPHDEF", "--output_format=TFLITE", ("--input_file=$(location %s)" % src), @@ -168,7 +172,7 @@ def tf_to_tflite(name, src, options, out): srcs = [src], outs = [out], cmd = toco_cmdline, - tools = ["//tensorflow/lite/toco:toco"], + tools = ["//tensorflow/lite/toco:toco"] + tf_binary_additional_srcs(), ) def tflite_to_json(name, src, out): @@ -225,9 +229,11 @@ def generated_test_models(): return [ "abs", "add", + "add_n", "arg_min_max", "avg_pool", "batch_to_space_nd", + "ceil", "concat", "constant", "control_dep", @@ -235,8 +241,10 @@ def generated_test_models(): "conv2d_transpose", "conv_with_shared_weights", "conv_to_depthwiseconv_with_shared_weights", + "cos", "depthwiseconv", "div", + "elu", "equal", "exp", "expand_dims", @@ -247,6 +255,7 @@ def generated_test_models(): "fully_connected", "fused_batch_norm", "gather", + "gather_nd", "gather_with_constant", "global_batch_norm", "greater", @@ -264,7 +273,7 @@ def generated_test_models(): "logical_and", "logical_or", "logical_xor", - #"lstm", TODO(b/122889684): Resolve toco structured line parsing in oss. 
+ "lstm", "max_pool", "maximum", "mean", @@ -281,6 +290,7 @@ def generated_test_models(): "prelu", "pow", "range", + "rank", "reduce_any", "reduce_max", "reduce_min", @@ -290,6 +300,9 @@ def generated_test_models(): "relu6", "reshape", "resize_bilinear", + "resolve_constant_strided_slice", + "reverse_sequence", + "reverse_v2", "rsqrt", "shape", "sigmoid", @@ -307,12 +320,13 @@ def generated_test_models(): "squeeze", "strided_slice", "strided_slice_1d_exhaustive", - "strided_slice_buggy", "sub", "tile", "topk", "transpose", "transpose_conv", + "unidirectional_sequence_lstm", + "unidirectional_sequence_rnn", "unique", "unpack", "unroll_batch_matmul", @@ -328,6 +342,8 @@ def generated_test_models_failing(conversion_mode): return [ "lstm", # TODO(b/117510976): Restore when lstm flex conversion works. "unroll_batch_matmul", # TODO(b/123030774): Fails in 1.13 tests. + "unidirectional_sequence_lstm", + "unidirectional_sequence_rnn", ] return [] @@ -442,10 +458,11 @@ def flex_dep(target_op_sets): else: return [] -def gen_model_coverage_test(model_name, data, failure_type, tags): +def gen_model_coverage_test(src, model_name, data, failure_type, tags): """Generates Python test targets for testing TFLite models. Args: + src: Main source file. model_name: Name of the model to test (must be also listed in the 'data' dependencies) data: List of BUILD targets linking the data. 
@@ -462,9 +479,9 @@ def gen_model_coverage_test(model_name, data, failure_type, tags): i = i + 1 native.py_test( name = "model_coverage_test_%s_%s" % (model_name, target_op_sets.lower().replace(",", "_")), - srcs = ["model_coverage_test.py"], + srcs = [src], + main = src, size = "large", - main = "model_coverage_test.py", args = [ "--model_name=%s" % model_name, "--target_ops=%s" % target_op_sets, diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index ce73aa0f9cddbe53021dc2e6fd515fc1606b9469..7b4efdf4a36c8bed835725d0277ffacb001c17de 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -129,6 +129,15 @@ typedef enum { kTfLiteBuiltinAbs = 101, kTfLiteBuiltinSplitV = 102, kTfLiteBuiltinUnique = 103, + kTfLiteBuiltinCeil = 104, + kTfLiteBuiltinReverseV2 = 105, + kTfLiteBuiltinAddN = 106, + kTfLiteBuiltinGatherNd = 107, + kTfLiteBuiltinCos = 108, + kTfLiteBuiltinWhere = 109, + kTfLiteBuiltinRank = 110, + kTfLiteBuiltinElu = 111, + kTfLiteBuiltinReverseSequence = 112, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index 91c04a5f1fb5bb1a15bd1da074a1276a3d8e7793..661b648550c9a3fc64b8bc2fb97a2f3b16e7aac1 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -12,7 +12,7 @@ cc_library( "c_api_internal.h", ], visibility = [ - "//tensorflow/contrib/lite:__subpackages__", + "//learning/brain/mobile/kernel_test:__subpackages__", "//tensorflow/lite:__subpackages__", ], ) diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index 332c2db14511af18a8e3d99fc93891ce92d1792a..5d1c92d36f5d73ba4de79be217daade4f1866b84 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -333,6 +333,9 @@ typedef struct { TfLiteType out_type; } TfLiteShapeParams; +typedef struct { +} TfLiteRankParams; + typedef struct { // Parameters supported by version 1: float min; @@ -365,6 +368,11 @@ typedef struct { 
TfLiteType index_out_type; } TfLiteUniqueParams; +typedef struct { + int seq_dim; + int batch_dim; +} TfLiteReverseSequenceParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/c/builtin_op_data_test.cc b/tensorflow/lite/c/builtin_op_data_test.cc index 4ce7c481e1c26e6fcfdaa680e9ca666b82968d53..4967183dd56df64b75c719869d16d052ae976081 100644 --- a/tensorflow/lite/c/builtin_op_data_test.cc +++ b/tensorflow/lite/c/builtin_op_data_test.cc @@ -71,6 +71,7 @@ TEST(IntArray, CanCompileStructs) { TfLiteTransposeConvParams transpose_conv_params; TfLiteSparseToDenseParams sparse_to_dense_params; TfLiteShapeParams shape_params; + TfLiteRankParams rank_params; TfLiteFakeQuantParams fake_quant_params; TfLitePackParams pack_params; TfLiteOneHotParams one_hot_params; diff --git a/tensorflow/lite/c/c_api_internal.c b/tensorflow/lite/c/c_api_internal.c index 29dba15c63cb9f2b87484cc0d777db471084d8ce..f20ee23bd81eb87c25a1a7f61cce59df7ae6678e 100644 --- a/tensorflow/lite/c/c_api_internal.c +++ b/tensorflow/lite/c/c_api_internal.c @@ -91,10 +91,10 @@ void TfLiteTensorDataFree(TfLiteTensor* t) { t->data.raw = NULL; } -void TfLiteQuantizationFree(TfLiteTensor* t) { - if (t->quantization.type == kTfLiteAffineQuantization) { +void TfLiteQuantizationFree(TfLiteQuantization* quantization) { + if (quantization->type == kTfLiteAffineQuantization) { TfLiteAffineQuantization* q_params = - (TfLiteAffineQuantization*)(t->quantization.params); + (TfLiteAffineQuantization*)(quantization->params); if (q_params->scale) { TfLiteFloatArrayFree(q_params->scale); q_params->scale = NULL; @@ -105,8 +105,8 @@ void TfLiteQuantizationFree(TfLiteTensor* t) { } free(q_params); } - t->quantization.params = NULL; - t->quantization.type = kTfLiteNoQuantization; + quantization->params = NULL; + quantization->type = kTfLiteNoQuantization; } void TfLiteTensorFree(TfLiteTensor* t) { @@ -114,7 +114,7 @@ void TfLiteTensorFree(TfLiteTensor* t) { if (t->dims) 
TfLiteIntArrayFree(t->dims); t->dims = NULL; - TfLiteQuantizationFree(t); + TfLiteQuantizationFree(&t->quantization); } void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h index 31f483370cf7797e5d071d637adc377eefd25352..83e2be690762be3e2cacf02ea8311b76dc1731c4 100644 --- a/tensorflow/lite/c/c_api_internal.h +++ b/tensorflow/lite/c/c_api_internal.h @@ -340,6 +340,9 @@ typedef struct { // Free data memory of tensor `t`. void TfLiteTensorDataFree(TfLiteTensor* t); +// Free quantization data. +void TfLiteQuantizationFree(TfLiteQuantization* quantization); + // Free memory of tensor `t`. void TfLiteTensorFree(TfLiteTensor* t); diff --git a/tensorflow/lite/context_util.h b/tensorflow/lite/context_util.h index 68b91ea0b93e602c20d1db3284a523e9f55dfd5b..2f846cc259e34b1f750ba0787dffa93db597cbe0 100644 --- a/tensorflow/lite/context_util.h +++ b/tensorflow/lite/context_util.h @@ -38,6 +38,7 @@ class TfLiteIntArrayView { const_iterator begin() const { return int_array_->data; } const_iterator end() const { return &int_array_->data[int_array_->size]; } size_t size() const { return end() - begin(); } + int operator[](size_t pos) const { return int_array_->data[pos]; } private: const TfLiteIntArray* int_array_; diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index 6a43b0322d17041a5ae4a0527376d1465a539b1d..db6b4a2d18ecd894fa3b8a0bf646ca9f8c6b6511 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -51,6 +51,7 @@ cc_test( srcs = ["flatbuffer_conversions_test.cc"], deps = [ ":api", + "//tensorflow/lite:string", "//tensorflow/lite/c:c_api_internal", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 970e45bbdcd5c5d582c0cab29ea89c657987c70d..2ba64f51d9aa13f6ff863f1a305a7ad36ae4c67a 100644 --- 
a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -28,22 +28,27 @@ namespace { // Copies the contents from the flatbuffer int vector `flatbuffer` into the // int array `buffer`. `flat_vector` and `buffer` represent the same // configuration operation for a given operation. -void FlatBufferIntVectorToArray(int max_size_of_buffer, - const flatbuffers::Vector* flat_vector, - int* buffer, ErrorReporter* error_reporter) { +TfLiteStatus FlatBufferIntVectorToArray( + int max_size_of_buffer, const flatbuffers::Vector* flat_vector, + int* buffer, ErrorReporter* error_reporter, const char* op_name) { if (!flat_vector) { - error_reporter->Report("Input array not provided for operation.\n"); + error_reporter->Report("Input array not provided for operation '%s'.\n", + op_name); + return kTfLiteError; } else { int num_dimensions = flat_vector->Length(); if (num_dimensions > max_size_of_buffer / sizeof(int)) { error_reporter->Report( - "Found too many dimensions in the operation's input array.\n"); + "Found too many dimensions in the input array of operation '%s'.\n", + op_name); + return kTfLiteError; } else { for (int i = 0; i < num_dimensions; ++i) { buffer[i] = flat_vector->Get(i); } } } + return kTfLiteOk; } } // namespace @@ -452,8 +457,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { auto* new_shape = schema_params->new_shape(); - FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, - params->shape, error_reporter); + TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( + sizeof(params->shape), new_shape, params->shape, error_reporter, + "reshape")); params->num_dimensions = new_shape->Length(); } *builtin_data = reinterpret_cast(params); @@ -521,8 +527,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, auto* params = allocator->AllocatePOD(); if 
(auto* schema_params = op->builtin_options_as_SqueezeOptions()) { const auto& squeeze_dims = schema_params->squeeze_dims(); - FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, - params->squeeze_dims, error_reporter); + TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( + sizeof(params->squeeze_dims), squeeze_dims, params->squeeze_dims, + error_reporter, "squeeze")); params->num_squeeze_dims = squeeze_dims->Length(); } *builtin_data = reinterpret_cast(params); @@ -665,6 +672,17 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_REVERSE_SEQUENCE: { + TfLiteReverseSequenceParams* params = + allocator->AllocatePOD(); + if (auto* reverse_seq_params = + op->builtin_options_as_ReverseSequenceOptions()) { + params->seq_dim = reverse_seq_params->seq_dim(); + params->batch_dim = reverse_seq_params->batch_dim(); + } + *builtin_data = reinterpret_cast(params); + break; + } // Below are the ops with no builtin_data strcture. case BuiltinOperator_ABS: @@ -673,12 +691,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, // ok for now, since there is no call implementation either. 
case BuiltinOperator_CALL: case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_COS: case BuiltinOperator_CUSTOM: case BuiltinOperator_DEQUANTIZE: + case BuiltinOperator_ELU: case BuiltinOperator_EMBEDDING_LOOKUP: case BuiltinOperator_EQUAL: case BuiltinOperator_EXP: case BuiltinOperator_EXPAND_DIMS: + case BuiltinOperator_CEIL: case BuiltinOperator_FLOOR: case BuiltinOperator_GREATER: case BuiltinOperator_GREATER_EQUAL: @@ -718,6 +739,11 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_FLOOR_MOD: case BuiltinOperator_RANGE: case BuiltinOperator_SQUARED_DIFFERENCE: + case BuiltinOperator_REVERSE_V2: + case BuiltinOperator_ADD_N: + case BuiltinOperator_GATHER_ND: + case BuiltinOperator_WHERE: + case BuiltinOperator_RANK: break; } return kTfLiteOk; diff --git a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc index 4d1d1b21fda106b3196ff43421996f45ab83af4f..4a5de48302c1e840c524335ee549c74a162e107e 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc @@ -17,8 +17,10 @@ limitations under the License. 
#include +#include #include #include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace { @@ -33,6 +35,8 @@ class MockErrorReporter : public ErrorReporter { char* GetBuffer() { return buffer_; } int GetBufferSize() { return buffer_size_; } + string GetAsString() const { return string(buffer_, buffer_size_); } + private: static constexpr int kBufferSize = 256; char buffer_[kBufferSize]; @@ -60,25 +64,56 @@ class MockDataAllocator : public BuiltinDataAllocator { } // namespace -TEST(FlatbufferConversions, TestParseOpDataConv) { - MockErrorReporter mock_reporter; - ErrorReporter* reporter = &mock_reporter; - MockDataAllocator mock_allocator; - - flatbuffers::FlatBufferBuilder builder; - flatbuffers::Offset conv_options = - CreateConv2DOptions(builder, Padding_SAME, 1, 2, - ActivationFunctionType_RELU, 3, 4) - .Union(); - flatbuffers::Offset conv_offset = CreateOperatorDirect( - builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options, - nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr); - builder.Finish(conv_offset); - void* conv_pointer = builder.GetBufferPointer(); - const Operator* conv_op = flatbuffers::GetRoot(conv_pointer); +class FlatbufferConversionsTest : public ::testing::Test { + public: + const Operator* BuildTestOperator(BuiltinOptions op_type, + flatbuffers::Offset options) { + flatbuffers::Offset offset = + CreateOperatorDirect(builder_, 0, nullptr, nullptr, op_type, options, + nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr); + builder_.Finish(offset); + void* pointer = builder_.GetBufferPointer(); + return flatbuffers::GetRoot(pointer); + } + + protected: + MockErrorReporter mock_reporter_; + MockDataAllocator mock_allocator_; + flatbuffers::FlatBufferBuilder builder_; +}; + +TEST_F(FlatbufferConversionsTest, ParseBadSqueeze) { + const Operator* op = BuildTestOperator( + BuiltinOptions_SqueezeOptions, CreateSqueezeOptions(builder_).Union()); + void* output_data = nullptr; + 
EXPECT_NE(kTfLiteOk, ParseOpData(op, BuiltinOperator_SQUEEZE, &mock_reporter_, + &mock_allocator_, &output_data)); + EXPECT_THAT(mock_reporter_.GetAsString(), + ::testing::ContainsRegex( + "Input array not provided for operation 'squeeze'")); +} + +TEST_F(FlatbufferConversionsTest, ParseBadReshape) { + const Operator* op = BuildTestOperator( + BuiltinOptions_ReshapeOptions, CreateSqueezeOptions(builder_).Union()); + void* output_data = nullptr; + EXPECT_NE(kTfLiteOk, ParseOpData(op, BuiltinOperator_RESHAPE, &mock_reporter_, + &mock_allocator_, &output_data)); + EXPECT_THAT(mock_reporter_.GetAsString(), + ::testing::ContainsRegex( + "Input array not provided for operation 'reshape'")); +} + +TEST_F(FlatbufferConversionsTest, TestParseOpDataConv) { + const Operator* conv_op = + BuildTestOperator(BuiltinOptions_Conv2DOptions, + CreateConv2DOptions(builder_, Padding_SAME, 1, 2, + ActivationFunctionType_RELU, 3, 4) + .Union()); void* output_data = nullptr; - EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter, - &mock_allocator, &output_data)); + EXPECT_EQ(kTfLiteOk, + ParseOpData(conv_op, BuiltinOperator_CONV_2D, &mock_reporter_, + &mock_allocator_, &output_data)); EXPECT_NE(nullptr, output_data); TfLiteConvParams* params = reinterpret_cast(output_data); EXPECT_EQ(kTfLitePaddingSame, params->padding); @@ -89,30 +124,20 @@ TEST(FlatbufferConversions, TestParseOpDataConv) { EXPECT_EQ(4, params->dilation_height_factor); } -TEST(FlatbufferConversions, TestParseOpDataCustom) { - MockErrorReporter mock_reporter; - ErrorReporter* reporter = &mock_reporter; - MockDataAllocator mock_allocator; - - flatbuffers::FlatBufferBuilder builder; - flatbuffers::Offset null_options; - flatbuffers::Offset custom_offset = CreateOperatorDirect( - builder, 0, nullptr, nullptr, BuiltinOptions_NONE, null_options, nullptr, - CustomOptionsFormat_FLEXBUFFERS, nullptr); - builder.Finish(custom_offset); - void* custom_pointer = builder.GetBufferPointer(); - const Operator* 
custom_op = flatbuffers::GetRoot(custom_pointer); +TEST_F(FlatbufferConversionsTest, TestParseOpDataCustom) { + const Operator* custom_op = + BuildTestOperator(BuiltinOptions_NONE, flatbuffers::Offset()); void* output_data = nullptr; - EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter, - &mock_allocator, &output_data)); + EXPECT_EQ(kTfLiteOk, + ParseOpData(custom_op, BuiltinOperator_CUSTOM, &mock_reporter_, + &mock_allocator_, &output_data)); EXPECT_EQ(nullptr, output_data); } -TEST(FlatbufferConversions, TestConvertTensorType) { - MockErrorReporter mock_reporter; - ErrorReporter* reporter = &mock_reporter; +TEST_F(FlatbufferConversionsTest, TestConvertTensorType) { TfLiteType type; - EXPECT_EQ(kTfLiteOk, ConvertTensorType(TensorType_FLOAT32, &type, reporter)); + EXPECT_EQ(kTfLiteOk, + ConvertTensorType(TensorType_FLOAT32, &type, &mock_reporter_)); EXPECT_EQ(kTfLiteFloat32, type); } diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 4be80d143e7d34b45be45e06e42519afe4d32827..ec6762b16c95e86fb65ec187d5e92f91eff1cbc5 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -72,6 +72,34 @@ bool HasDynamicTensor(const TfLiteContext& context, return HasDynamicTensorImpl(context, TfLiteIntArrayView{int_array}); } +// Gets the legacy TfLiteQuantizationParams from the current TfLiteQuantization. +TfLiteQuantizationParams GetLegacyQuantization( + const TfLiteQuantization& quantization) { + TfLiteQuantizationParams legacy_quantization; + legacy_quantization.scale = 0; + legacy_quantization.zero_point = 0; + + // If the quantization type isn't affine, return the empty + // legacy_quantization. 
+ if (quantization.type != kTfLiteAffineQuantization) { + return legacy_quantization; + } + + auto* affine_quantization = + reinterpret_cast(quantization.params); + if (!affine_quantization || !affine_quantization->scale || + !affine_quantization->zero_point || + affine_quantization->scale->size != 1 || + affine_quantization->zero_point->size != 1) { + return legacy_quantization; + } + + // We know its per-layer quantization now. + legacy_quantization.scale = affine_quantization->scale->data[0]; + legacy_quantization.zero_point = affine_quantization->zero_point->data[0]; + return legacy_quantization; +} + } // namespace // A trivial implementation of GraphInfo around the Interpreter. @@ -369,6 +397,10 @@ void Subgraph::SetCancellationFunction(void* data, check_cancelled_func_ = check_cancelled_func; } +void Subgraph::ReserveNodes(int count) { + nodes_and_registration_.reserve(count); +} + TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices, int length) { // Making sure kOptionalTensor is not re-defined to something other than -1. @@ -382,7 +414,9 @@ TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices, continue; } if (index < 0 || static_cast(index) >= context_->tensors_size) { - ReportError("Invalid tensor index %d in %s\n", index, label); + ReportError( + "Invalid tensor index %d in %s. The subgraph has %d tensors\n", index, + label, context_->tensors_size); consistent_ = false; return kTfLiteError; } @@ -558,7 +592,12 @@ TfLiteStatus Subgraph::ResizeInputTensor(int tensor_index, // Short-circuit the state change if the dimensions don't change, avoiding // unnecessary (re)allocations. - if (EqualArrayAndTfLiteIntArray(tensor->dims, dims.size(), dims.data())) { + // + // Note that it's required to check `tensor->data.raw != nullptr`. Otherwise + // the subgraph won't allocate memory for a dynamic tensor when its size + // is equal to the original tensor size. 
+ if (tensor->data.raw != nullptr && + EqualArrayAndTfLiteIntArray(tensor->dims, dims.size(), dims.data())) { return kTfLiteOk; } @@ -779,7 +818,7 @@ TfLiteStatus Subgraph::GetNodeAndRegistration( TfLiteStatus Subgraph::SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization, const char* buffer, + const int* dims, TfLiteQuantization quantization, const char* buffer, size_t bytes, const Allocation* allocation) { if (state_ == kStateInvokableAndImmutable) { ReportError( @@ -804,16 +843,22 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) { // Fast path which does not invalidate the invokable property. TfLiteTensorDataFree(&tensor); + TfLiteQuantizationFree(&tensor.quantization); tensor.data.raw = const_cast(buffer); if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims); - tensor.params = quantization; + tensor.params = GetLegacyQuantization(quantization); + tensor.quantization = quantization; tensor.allocation_type = kTfLiteMmapRo; tensor.allocation = allocation; } else { state_ = kStateUninvokable; TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), - quantization, const_cast(buffer), bytes, - kTfLiteMmapRo, allocation, false, &tensor); + GetLegacyQuantization(quantization), + const_cast(buffer), bytes, kTfLiteMmapRo, + allocation, false, &tensor); + // TODO(suharshs): Update TfLiteTensorReset to include the new quantization + // if there are other required callers. + tensor.quantization = quantization; } return kTfLiteOk; } @@ -824,7 +869,7 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( // to Interpreter. 
TfLiteStatus Subgraph::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization, bool is_variable) { + const int* dims, TfLiteQuantization quantization, bool is_variable) { if (state_ == kStateInvokableAndImmutable) { ReportError( "SetTensorParametersReadWrite is disallowed when graph is immutable."); @@ -854,10 +899,14 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite( allocation_type = kTfLiteArenaRwPersistent; } + TfLiteTensor& tensor = context_->tensors[tensor_index]; TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), - quantization, + GetLegacyQuantization(quantization), /*buffer=*/nullptr, required_bytes, allocation_type, - nullptr, is_variable, &context_->tensors[tensor_index]); + nullptr, is_variable, &tensor); + // TODO(suharshs): Update TfLiteTensorReset to include the new quantization + // if there are other required callers. + tensor.quantization = quantization; return kTfLiteOk; } diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 0bec218d898873459ce80eb0d191041851cda90d..5db15a177ef9fe8fcb54e0bf92f0193238440941 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -59,6 +59,11 @@ class Subgraph { // interpreter. TfLiteStatus SetVariables(std::vector variables); + // Ensure the internal node storage memory allocates at least `count` + // spots for node. NOTE, this doesn't actually add operators. This is an + // efficiency optimization that is subject to change. + void ReserveNodes(int count); + // Adds a node with the given parameters and returns the index of the new // node in `node_index` (optionally). Interpreter will take ownership of // `builtin_data` and destroy it with `free`. 
Ownership of 'init_data' @@ -68,29 +73,48 @@ class Subgraph { const char* init_data, size_t init_data_size, void* builtin_data, const TfLiteRegistration* registration, - int* node_index); + int* node_index = nullptr); // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries. // The value pointed to by `first_new_tensor_index` will be set to the // index of the first new tensor if `first_new_tensor_index` is non-null. - TfLiteStatus AddTensors(int tensors_to_add, int* first_new_tensor_index); + TfLiteStatus AddTensors(int tensors_to_add, + int* first_new_tensor_index = nullptr); // Set description of inputs/outputs/data/fptrs for node `node_index`. // This variant assumes an external buffer has been allocated of size // bytes. The lifetime of buffer must be ensured to be greater or equal // to Interpreter. + inline TfLiteStatus SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantization quantization, + const char* buffer, size_t bytes, + const Allocation* allocation = nullptr) { + return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(), + dims.data(), quantization, buffer, bytes, + allocation); + } TfLiteStatus SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization, - const char* buffer, size_t bytes, const Allocation* allocation); + const int* dims, TfLiteQuantization quantization, const char* buffer, + size_t bytes, const Allocation* allocation = nullptr); // Set description of inputs/outputs/data/fptrs for node `node_index`. // This variant assumes an external buffer has been allocated of size // bytes. The lifetime of buffer must be ensured to be greater or equal // to Interpreter. 
- TfLiteStatus SetTensorParametersReadWrite( - int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization, bool is_variable); + inline TfLiteStatus SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantization quantization, + bool is_variable = false) { + return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(), + dims.data(), quantization, is_variable); + } + TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type, + const char* name, const size_t rank, + const int* dims, + TfLiteQuantization quantization, + bool is_variable = false); // WARNING: Experimental interface, subject to change // Overrides execution plan. This bounds checks indices sent in. diff --git a/tensorflow/lite/delegates/flex/buffer_map.h b/tensorflow/lite/delegates/flex/buffer_map.h index 45bc571f9284abe95cd3550e64dd098157da14a8..b73ed88d3789d5df8dadaee19d468596ccd4c782 100644 --- a/tensorflow/lite/delegates/flex/buffer_map.h +++ b/tensorflow/lite/delegates/flex/buffer_map.h @@ -56,25 +56,7 @@ class BufferMap { // given TfLiteTensor's data. void SetFromTfLite(int tensor_index, const TfLiteTensor* tensor); - // Sets a bit indicating that the tensor associated with 'tensor_index' can - // be use by TF's forwarding optimizations. - void SetForwardable(int tensor_index) { forwardable_.insert(tensor_index); } - - // Removes all information about which tensors are forwardable. - void ClearForwardable() { forwardable_.clear(); } - - // Returns true if this tensor has been explicitly marks as forwardable by - // a call to SetForwardable(). - bool IsForwardable(int tensor_index) const { - return forwardable_.count(tensor_index) > 0; - } - private: - // List of tensors that can be used by TF in its forwarding optimization. - // Doing so allows an input tensor to be modified and used as the output - // tensor. 
The delegate takes care of not holding any references to tensors - // in this list while Eager is executing the corresponding op. - std::set forwardable_; // Mapping from TL Lite tensor ID to TensorFlow's Tensor. All tensors that // are inputs or outputs of a subgraph will be added here, irrespective of // whether their data are managed by TF Lite or TensorFlow. diff --git a/tensorflow/lite/delegates/flex/buffer_map_test.cc b/tensorflow/lite/delegates/flex/buffer_map_test.cc index 8edeb280ebdc7953a1031e30b77bb58a760e20e7..accaf3045246b35705085bd5324e5b33ec8ea12a 100644 --- a/tensorflow/lite/delegates/flex/buffer_map_test.cc +++ b/tensorflow/lite/delegates/flex/buffer_map_test.cc @@ -44,6 +44,7 @@ UniqueTfLiteTensor MakeLiteTensor(const std::vector& shape, tensor->dims = ConvertVectorToTfLiteIntArray(shape); tensor->data.raw = nullptr; tensor->is_variable = false; + memset(&tensor->quantization, 0, sizeof(TfLiteQuantization)); TfLiteTensorRealloc(data.size() * sizeof(T), tensor.get()); memcpy(tensor->data.raw, data.data(), data.size() * sizeof(T)); return tensor; @@ -62,6 +63,7 @@ UniqueTfLiteTensor MakeLiteTensor(const std::vector& shape, tensor->dims = ConvertVectorToTfLiteIntArray(shape); tensor->data.raw = nullptr; tensor->is_variable = false; + memset(&tensor->quantization, 0, sizeof(TfLiteQuantization)); TfLiteTensorRealloc(data.size() * sizeof(string), tensor.get()); DynamicBuffer b; @@ -222,15 +224,6 @@ TEST(BufferMapTest, TensorFlowOverwritesTfLite) { ElementsAre(0, 0, 0, 0.123f, 0, 0)); } -TEST(BufferMapTest, Forwardable) { - BufferMap buffer_map; - EXPECT_FALSE(buffer_map.IsForwardable(0)); - buffer_map.SetForwardable(0); - EXPECT_TRUE(buffer_map.IsForwardable(0)); - buffer_map.ClearForwardable(); - EXPECT_FALSE(buffer_map.IsForwardable(0)); -} - } // namespace } // namespace flex } // namespace tflite diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index 
2e0fc22ad6872884d04da4c2d2f8a4dce0246de9..ceb9918f6fa7ccfbb4d27a0bf921987faecc1c12 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -67,6 +67,7 @@ class OpInputs { for (int index : TfLiteIntArrayView(indexes)) { inputs_.push_back(index); } + forwardable_.resize(inputs_.size()); } ~OpInputs() {} @@ -89,11 +90,21 @@ class OpInputs { } } + void SetForwardable(int i, bool v) { forwardable_[i] = v; } + + bool IsForwardable(int i) const { return forwardable_[i]; } + TensorSource GetTensorSource(int i) const { return sources_[i]; } private: std::vector inputs_; std::vector sources_; + + // List of tensors that can be used by TF in its forwarding optimization. + // Doing so allows an input tensor to be modified and used as the output + // tensor. The delegate takes care of not holding any references to tensors + // in this list while Eager is executing the corresponding op. + std::vector forwardable_; }; // A list of outputs of a given node of the TensorFlow/Eager graph, along with @@ -279,7 +290,7 @@ class OpNode { } else { // If this is a forwardable tensor, we will remove it from the previous // op's list, giving TF the opportunity to reuse its buffer. - bool unref_handle = buffer_map->IsForwardable(input_index); + bool unref_handle = inputs_.IsForwardable(i); auto* handle = s.node->outputs_.GetHandle(s.node_output_index, unref_handle); op_->MutableInputs()->push_back(handle); @@ -485,13 +496,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } } - buffer_map->ClearForwardable(); - for (const auto& x : tensor_ref_count) { - if (x.second == 1) { - // This tensor is referenced once by a single op. We can allow the TF - // kernel to "forward" it to the output, meaning its buffer will be - // reused and overwritten. 
- buffer_map->SetForwardable(x.first); + // All tensors that are referenced exactly once are marked as "forwardable", + // meaning that we will allow TensorFlow to reuse its buffer as the output of + // an op. + for (auto& node_data : op_data->nodes) { + for (int i = 0; i < node_data->inputs().Size(); ++i) { + bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1); + node_data->mutable_inputs()->SetForwardable(i, f); } } diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD index 63f6da1fa6d38f39abed2bbe2b382841c7d5bce4..ec0d78e59d5e726f146f1d143b5fdc63f3e0d8aa 100644 --- a/tensorflow/lite/delegates/nnapi/BUILD +++ b/tensorflow/lite/delegates/nnapi/BUILD @@ -2,7 +2,6 @@ package(default_visibility = [ "//visibility:public", ]) -load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") licenses(["notice"]) # Apache 2.0 @@ -16,11 +15,11 @@ cc_library( "//tensorflow/lite:kernel_api", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/kernels:kernel_util", - "//tensorflow/lite/nnapi:nnapi_lib", + "//tensorflow/lite/nnapi:nnapi_implementation", ], ) -tf_cc_test( +cc_test( name = "nnapi_delegate_test", size = "small", srcs = ["nnapi_delegate_test.cc"], diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index 94fcc8740b42c8f0beaf8801b5a061bfbf3c6433..efbb7d7d6704cebe48e23a4f35a484956227366a 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -25,11 +25,13 @@ limitations under the License. 
#include "tensorflow/lite/context_util.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/nnapi/NeuralNetworksShim.h" +#include "tensorflow/lite/nnapi/nnapi_implementation.h" #ifdef __ANDROID__ -#include #include +#endif +#if defined __ANDROID__ || defined __unix__ +#include #include #endif @@ -49,65 +51,93 @@ namespace { } while (0) namespace { -int32_t GetAndroidSdkVersion() { -#ifdef __ANDROID__ - const char* sdkProp = "ro.build.version.sdk"; - char sdkVersion[PROP_VALUE_MAX]; - int length = __system_property_get(sdkProp, sdkVersion); - if (length != 0) { - for (int i = 0; i < length; ++i) { - int digit = sdkVersion[i] - '0'; - if (digit < 0 || digit > 9) { - // Non-numeric SDK version, assume it's higher then expected; - return std::numeric_limits::max(); - } + +bool IsFloat(TfLiteType type) { + switch (type) { + case kTfLiteFloat32: + return true; + default: + return false; + } +} + +bool IsQuantized(TfLiteType type) { + switch (type) { + case kTfLiteUInt8: + case kTfLiteInt8: + case kTfLiteInt16: + return true; + default: + return false; + } +} + +bool IsHybridOperator(const TfLiteContext* context, int builtin_code, + const TfLiteNode* node) { + switch (builtin_code) { + case kTfLiteBuiltinConv2d: + case kTfLiteBuiltinFullyConnected: { + const int input_id = node->inputs->data[0]; + const int filter_id = node->inputs->data[1]; + const TfLiteType input_type = context->tensors[input_id].type; + const TfLiteType filter_type = context->tensors[filter_id].type; + return IsFloat(input_type) && IsQuantized(filter_type); } - return atoi(sdkVersion); + default: + return false; } -#endif // __ANDROID__ - return 0; } constexpr int32_t kMinSdkVersionForNNAPI = 27; constexpr int32_t kMinSdkVersionForNNAPI11 = 28; -static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion(); +constexpr int32_t kMinSdkVersionForNNAPI12 = 29; } // namespace // RAII NN API Model Destructor for use 
with std::unique_ptr struct NNFreeModel { void operator()(ANeuralNetworksModel* model) { - ANeuralNetworksModel_free(model); + NnApiImplementation()->ANeuralNetworksModel_free(model); } }; // RAII NN API Compilation Destructor for use with std::unique_ptr struct NNFreeCompilation { void operator()(ANeuralNetworksCompilation* model) { - ANeuralNetworksCompilation_free(model); + NnApiImplementation()->ANeuralNetworksCompilation_free(model); + } +}; + +// RAII NN API Execution Destructor for use with std::unique_ptr +struct NNFreeExecution { + void operator()(ANeuralNetworksExecution* execution) { + NnApiImplementation()->ANeuralNetworksExecution_free(execution); } }; // Manage NNAPI shared memory handle class NNMemory { public: - NNMemory(const char* name, size_t size) { -#ifdef __ANDROID__ +#if defined __ANDROID__ || defined __unix__ + NNMemory(const NnApi* nnapi, const char* name, size_t size) { + nnapi_ = nnapi; byte_size_ = size; - fd_ = ASharedMemory_create(name, size); + fd_ = nnapi_->ASharedMemory_create(name, size); data_ptr_ = reinterpret_cast( mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0)); - ANeuralNetworksMemory_createFromFd(size, PROT_READ | PROT_WRITE, fd_, 0, - &nn_memory_handle_); -#endif + nnapi_->ANeuralNetworksMemory_createFromFd(size, PROT_READ | PROT_WRITE, + fd_, 0, &nn_memory_handle_); } +#else + NNMemory(const NnApi* /*nnapi*/, const char* /*name*/, size_t /*size*/) {} +#endif ~NNMemory() { -#ifdef __ANDROID__ +#if defined __ANDROID__ || defined __unix__ if (data_ptr_) { munmap(data_ptr_, byte_size_); } if (nn_memory_handle_) { - ANeuralNetworksMemory_free(nn_memory_handle_); + nnapi_->ANeuralNetworksMemory_free(nn_memory_handle_); } if (fd_ > 0) close(fd_); #endif @@ -117,7 +147,8 @@ class NNMemory { uint8_t* get_data_ptr() { return data_ptr_; } private: -#ifdef __ANDROID__ +#if defined __ANDROID__ || defined __unix__ + const NnApi* nnapi_; int fd_ = 0; size_t byte_size_ = 0; #endif @@ -162,15 +193,42 @@ class OperandMapping 
{ std::vector lite_tensor_to_ann_tensor_; }; +class DequantizeMapping { + public: + int DequantizedAnnIndex(int ann_index, TfLiteType type) const { + for (const auto& element : mapping_) { + if (ann_index == std::get<0>(element) && type == std::get<1>(element)) { + return std::get<2>(element); + } + } + return -1; + } + + void Add(int ann_index, TfLiteType type, int dequantized_ann_index) { + // This assumes it is not already mapped. + mapping_.emplace_back(ann_index, type, dequantized_ann_index); + } + + private: + // Each tuple specifies the ANN (quantized) tensor index, the desired + // floating-point type and the matching ANN (dequantized) tensor index. This + // could use a map but instead std::vector is used to keep code size lower. + std::vector> mapping_; +}; + // Abstract builder for building an op in the NN API graph. This handles // the disparity between TFLite and NN API operand types. NN API has singular // operands for both tensors and parameters, and TFLite separates the two. 
class NNAPIOpBuilder { public: - NNAPIOpBuilder(TfLiteContext* context, OperandMapping* tensor_mapping, + NNAPIOpBuilder(const NnApi* nnapi, TfLiteContext* context, + OperandMapping* tensor_mapping, + DequantizeMapping* dequantize_mapping, ANeuralNetworksModel* nn_model) - : context_(context), + : nnapi_(nnapi), + context_(context), operand_mapping_(tensor_mapping), + dequantize_mapping_(dequantize_mapping), nn_model_(nn_model) {} TfLiteStatus AddScalarInt32Operand(int32_t value) { @@ -204,48 +262,129 @@ class NNAPIOpBuilder { return kTfLiteOk; } - TfLiteStatus AddTensorInput(int tensor_index) { - int ann_index; - TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); - augmented_inputs_.push_back(ann_index); - return kTfLiteOk; + TfLiteStatus AddTensorInput(int tensor_index, bool hybrid_op) { + return AddTensor(tensor_index, hybrid_op, &augmented_inputs_); } TfLiteStatus AddTensorOutput(int tensor_index) { - int ann_index; - TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); - augmented_outputs_.push_back(ann_index); - return kTfLiteOk; + return AddTensor(tensor_index, /*hybrid_op=*/false, &augmented_outputs_); } TfLiteStatus AddAdditionalFloat32OutputTensor(uint32_t dimension_count) { std::vector dims(dimension_count, 0); - ANeuralNetworksOperandType operand_type{ - .type = ANEURALNETWORKS_TENSOR_FLOAT32, - .dimensionCount = dimension_count, - .dimensions = dims.data()}; - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context_, ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); - int ann_operand = operand_mapping_->add_new_non_tensor_operand(); - augmented_outputs_.push_back(ann_operand); - return kTfLiteOk; + return AddFloat32OutputTensor(dimension_count, dims.data(), nullptr); } TfLiteStatus AddStateFloat32Tensor(int tensor_index, int* ann_tensor_index_out) { TfLiteTensor* tensor = &context_->tensors[tensor_index]; - int ann_index = operand_mapping_->add_new_non_tensor_operand(); + return AddFloat32OutputTensor( + tensor->dims->size, 
reinterpret_cast(tensor->dims->data), + ann_tensor_index_out); + } + // Adds a Dequantize operator and replaces the input tensor index with the + // dequantized version. If the dequantized version of the operator already + // exists then it is not added again. + TfLiteStatus AddDequantize(int nn_input_index, int lite_index, + TfLiteType dequantized_type) { + const int ann_index = operand_mapping_->lite_index_to_ann(lite_index); + int dequantized_ann_index = + dequantize_mapping_->DequantizedAnnIndex(ann_index, dequantized_type); + + if (dequantized_ann_index == -1) { + // The dequantized version does not exist yet, it has to be added: a new + // Dequantize operation is added, yielding a new tensor. + const TfLiteTensor& tensor = context_->tensors[lite_index]; + ANeuralNetworksOperandType operand_type{ + dequantized_type, static_cast(tensor.dims->size), + reinterpret_cast(tensor.dims->data), 0.f, 0}; + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context_, + nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + dequantized_ann_index = operand_mapping_->add_new_non_tensor_operand(); + + // Add Dequantize operation. + const uint32_t dequantize_input[1] = {static_cast(ann_index)}; + const uint32_t dequantize_output[1] = { + static_cast(dequantized_ann_index)}; + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context_, nnapi_->ANeuralNetworksModel_addOperation( + nn_model_, ANEURALNETWORKS_DEQUANTIZE, 1, + dequantize_input, 1, dequantize_output)); + dequantize_mapping_->Add(ann_index, dequantized_type, + dequantized_ann_index); + } + + // The input for the original operation is modified so that the operation + // now uses the dequantized tensor as input. + augmented_inputs_[nn_input_index] = dequantized_ann_index; + + return kTfLiteOk; + } + + // Finish emitting the op (of type `type`) into the NN API. 
+ TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) { + // Actually add a NN API operation + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context_, + nnapi_->ANeuralNetworksModel_addOperation( + nn_model_, type, static_cast(augmented_inputs_.size()), + augmented_inputs_.data(), + static_cast(augmented_outputs_.size()), + augmented_outputs_.data())); + augmented_inputs_.clear(); + augmented_outputs_.clear(); + return kTfLiteOk; + } + + private: + template + TfLiteStatus AddScalarOperand(T value, int32_t nn_type) { + ANeuralNetworksOperandType operand_type{.type = nn_type}; + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context_, + nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + const int ann_index = operand_mapping_->add_new_non_tensor_operand(); + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context_, nnapi_->ANeuralNetworksModel_setOperandValue( + nn_model_, ann_index, &value, sizeof(T))); + augmented_inputs_.push_back(ann_index); + return kTfLiteOk; + } + + template + TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values, + int32_t nn_type) { ANeuralNetworksOperandType operand_type{ - ANEURALNETWORKS_TENSOR_FLOAT32, - static_cast(tensor->dims->size), - reinterpret_cast(tensor->dims->data), tensor->params.scale, - tensor->params.zero_point}; + .type = nn_type, .dimensionCount = 1, .dimensions = &num_values}; + RETURN_TFLITE_ERROR_IF_NN_ERROR( - context_, ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); - augmented_outputs_.push_back(ann_index); + context_, + nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + + const int ann_index = operand_mapping_->add_new_non_tensor_operand(); + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context_, nnapi_->ANeuralNetworksModel_setOperandValue( + nn_model_, ann_index, values, sizeof(T) * num_values)); + augmented_inputs_.push_back(ann_index); + return kTfLiteOk; + } - *ann_tensor_index_out = ann_index; + TfLiteStatus AddFloat32OutputTensor(uint32_t dimension_count, + const uint32_t* 
dimension_data, + int* ann_index_out) { + ANeuralNetworksOperandType operand_type{ + .type = ANEURALNETWORKS_TENSOR_FLOAT32, + .dimensionCount = dimension_count, + .dimensions = dimension_data, + }; + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context_, + nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + const int ann_index = operand_mapping_->add_new_non_tensor_operand(); + augmented_outputs_.push_back(ann_index); + if (ann_index_out) *ann_index_out = ann_index; return kTfLiteOk; } @@ -253,10 +392,11 @@ class NNAPIOpBuilder { // This returns the NN API tensor index corresponding to the created tensor. // If another caller previously created a NN API tensor for `tensor_index` // then the existing one is returned. - TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) { + TfLiteStatus AddTensor(int tensor_index, bool hybrid_op, + std::vector* indices) { int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index); if (ann_tensor_index != -1) { - *ann_tensor_index_out = ann_tensor_index; + indices->push_back(ann_tensor_index); return kTfLiteOk; } // Allocate a new tensor index @@ -267,11 +407,17 @@ class NNAPIOpBuilder { float scale = 0.0f; int32_t zeroPoint = 0; TfLiteTensor* tensor = &context_->tensors[tensor_index]; - switch (tensor->type) { + TfLiteType tensor_type = tensor->type; + if (hybrid_op && (tensor_type == kTfLiteUInt8)) { + // For legacy reason, UINT8 weights in hybrid operators are actually INT8 + // values and should be interpreted as such. + tensor_type = kTfLiteInt8; + } + switch (tensor_type) { case kTfLiteNoType: // Tensors added during initialization of Ops don't have a type yet and // should not be registered with the NNAPI. 
- *ann_tensor_index_out = -1; + indices->push_back(-1); return kTfLiteOk; case kTfLiteFloat32: nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; @@ -285,6 +431,10 @@ class NNAPIOpBuilder { scale = 1; } break; + case kTfLiteInt8: + nn_type = ANEURALNETWORKS_TENSOR_QUANT8_SYMM; + scale = tensor->params.scale; + break; case kTfLiteInt32: nn_type = ANEURALNETWORKS_TENSOR_INT32; scale = tensor->params.scale; @@ -299,69 +449,35 @@ class NNAPIOpBuilder { nn_type, static_cast(tensor->dims->size), reinterpret_cast(tensor->dims->data), scale, zeroPoint}; RETURN_TFLITE_ERROR_IF_NN_ERROR( - context_, ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + context_, + nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); if (tensor->allocation_type == kTfLiteMmapRo) { // TODO(b/80630405): Use NNAPIAllocation. RETURN_TFLITE_ERROR_IF_NN_ERROR( context_, - ANeuralNetworksModel_setOperandValue( + nnapi_->ANeuralNetworksModel_setOperandValue( nn_model_, ann_tensor_index, tensor->data.raw, tensor->bytes)); } - *ann_tensor_index_out = ann_tensor_index; - return kTfLiteOk; - } - - // Finish emitting the op (of type `type`) into the NN API. 
- TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) { - // Actually add a NN API operation - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context_, - ANeuralNetworksModel_addOperation( - nn_model_, type, static_cast(augmented_inputs_.size()), - augmented_inputs_.data(), - static_cast(augmented_outputs_.size()), - augmented_outputs_.data())); - augmented_inputs_.clear(); - augmented_outputs_.clear(); - return kTfLiteOk; - } - - private: - template - TfLiteStatus AddScalarOperand(T value, int32_t nn_type) { - ANeuralNetworksOperandType operand_type{.type = nn_type}; - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context_, ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); - int ann_operand = operand_mapping_->add_new_non_tensor_operand(); - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context_, ANeuralNetworksModel_setOperandValue(nn_model_, ann_operand, - &value, sizeof(T))); - augmented_inputs_.push_back(ann_operand); + indices->push_back(ann_tensor_index); return kTfLiteOk; } - template - TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values, - int32_t nn_type) { - ANeuralNetworksOperandType operand_type{ - .type = nn_type, .dimensionCount = 1, .dimensions = &num_values}; - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context_, ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); - int ann_operand = operand_mapping_->add_new_non_tensor_operand(); - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context_, ANeuralNetworksModel_setOperandValue( - nn_model_, ann_operand, values, sizeof(T) * num_values)); - augmented_inputs_.push_back(ann_operand); - return kTfLiteOk; - } + // Access to NNAPI. + const NnApi* const nnapi_; // TfLiteContext for error handling. TfLiteContext* const context_; // Tracks relationship between indices. - OperandMapping* operand_mapping_; + OperandMapping* const operand_mapping_; + + // Keeps mapping of ANN quantized tensor and float data type to equivalent + // dequantized ANN tensor. 
For example, tensor #4 (UINT8) + FLOAT32 could map + // to tensor #10 (FLOAT32) because a DEQUANTIZE operator was added to convert + // tensor #4 to a FLOAT32 tensor. + DequantizeMapping* const dequantize_mapping_; // The NNAPI model. ANeuralNetworksModel* const nn_model_; @@ -392,7 +508,7 @@ ANeuralNetworksOperationType BasicMappingFn( // The kernel that represents the node sub set of TF Lite being run on NN API. class NNAPIDelegateKernel { public: - NNAPIDelegateKernel() = default; + NNAPIDelegateKernel() { nnapi_ = NnApiImplementation(); } typedef ANeuralNetworksOperationType (*MappingFn)( const NNAPIOpMappingArgs& mapping_args); @@ -400,8 +516,9 @@ class NNAPIDelegateKernel { // Return a function that knows how to translate a node into its operands // when called. You can use this function to see if a node is supported // (i.e. that MappingFn is not nullptr). - static MappingFn Map(TfLiteContext* context, int builtin_code, int version, - TfLiteNode* node) { + static MappingFn Map(const TfLiteContext* context, int builtin_code, + int version, int android_sdk_version, + const TfLiteNode* node) { switch (builtin_code) { case kTfLiteBuiltinAdd: if (version == 1) { @@ -457,6 +574,11 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinConv2d: if (version == 1) { + if (android_sdk_version < kMinSdkVersionForNNAPI12 && + IsHybridOperator(context, builtin_code, node)) { + // Hybrid operators not supported before NNAPI 1.2. + return nullptr; + } auto builtin = reinterpret_cast(node->builtin_data); if (builtin->dilation_width_factor != 1 || @@ -494,6 +616,11 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinFullyConnected: if (version == 1) { + if (android_sdk_version < kMinSdkVersionForNNAPI12 && + IsHybridOperator(context, builtin_code, node)) { + // Hybrid operators not supported before NNAPI 1.2. 
+ return nullptr; + } return [](const NNAPIOpMappingArgs& mapping_args) -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( @@ -520,7 +647,7 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinSqueeze: - if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) { return [](const NNAPIOpMappingArgs& mapping_args) -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( @@ -528,7 +655,7 @@ class NNAPIDelegateKernel { // Note that we add the squeeze dimensions even if the dimensions // were unspecified (empty), as NNAPI requires the operand. mapping_args.builder->AddVectorInt32Operand( - builtin->squeeze_dims, + builtin->num_squeeze_dims ? builtin->squeeze_dims : nullptr, static_cast(builtin->num_squeeze_dims)); return ANEURALNETWORKS_SQUEEZE; }; @@ -576,9 +703,10 @@ class NNAPIDelegateKernel { if (version == 1 && reinterpret_cast(node->builtin_data) ->activation == kTfLiteActNone) { - if (context->tensors[node->inputs->data[0]].type == kTfLiteUInt8) { - // NNAPI only support concatenating quantized tensor of the same - // scale and offset. + if (context->tensors[node->inputs->data[0]].type == kTfLiteUInt8 && + android_sdk_version < kMinSdkVersionForNNAPI12) { + // NNAPI 1.0-1 only supported concatenating quantized tensor of the + // same scale and offset. auto first_param = context->tensors[node->inputs->data[0]].params; for (int i = 1; i < node->inputs->size; i++) { auto curr_param = context->tensors[node->inputs->data[i]].params; @@ -598,7 +726,16 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinDequantize: - if (version == 1) { + if (version == 1 || version == 2) { + const auto& input = context->tensors[node->inputs->data[0]]; + const auto zero_point = input.params.zero_point; + // NN API supports int8 type since version 1.2 but only for symmetric + // quantization. 
+ if (input.type == kTfLiteInt8 && + (zero_point != 0 || + android_sdk_version < kMinSdkVersionForNNAPI12)) { + return nullptr; + } return BasicMappingFn; } break; @@ -636,7 +773,7 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinSub: - if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11 && context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { // NNAPI only support float sub. return [](const NNAPIOpMappingArgs& mapping_args) @@ -649,7 +786,7 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinDiv: - if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11 && context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { // NNAPI only support float div. return [](const NNAPIOpMappingArgs& mapping_args) @@ -662,22 +799,24 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinPad: - if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && - node->inputs->size == 2 && - context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + if (version == 1 && node->inputs->size == 2 && + (android_sdk_version >= kMinSdkVersionForNNAPI11) && + (context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 || + android_sdk_version >= kMinSdkVersionForNNAPI12)) { // NNAPI does not support specifying the padding value. - // NNAPI pads physical zero for quantized tensors, so only delegate - // float pad to NNAPI. + // Before 1.2, NNAPI pads physical zero for quantized tensors, so only + // delegate float pad to NNAPI. NNAPI 1.2 onwards pads with + // zero-point, so delegate quantized pad as well. 
return BasicMappingFn; } break; case kTfLiteBuiltinSpaceToBatchNd: - if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) { return BasicMappingFn; } break; case kTfLiteBuiltinStridedSlice: - if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) { return [](const NNAPIOpMappingArgs& mapping_args) -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( @@ -695,7 +834,7 @@ class NNAPIDelegateKernel { // dimensions. // TODO(b/110888333): Support dynamically-sized tensors in delegates. if ((version == 1) && - (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) && + (android_sdk_version >= kMinSdkVersionForNNAPI11) && (node->inputs->size > 1) && (context->tensors[node->inputs->data[1]].allocation_type == kTfLiteMmapRo)) { @@ -726,7 +865,9 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinSvdf: // NNAPI only support float32 weights. + // Only delegate to NNAPI 1.1, as SVDF does not support rank > 1 on 1.0. if (version == 1 && node->inputs->size == 5 && + android_sdk_version >= kMinSdkVersionForNNAPI11 && context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]] .type == kTfLiteFloat32) { return [](const NNAPIOpMappingArgs& mapping_args) @@ -752,8 +893,11 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinLstm: // NNAPI only support float32 weights. + // Only delegate to NNAPI 1.1, as 1.0 has a bug for optional tensors + // which would affect LSTM. // TODO(miaowang): add loggings to indicate why the op is rejected. if (version == 1 && node->inputs->size == 20 && + android_sdk_version >= kMinSdkVersionForNNAPI11 && context->tensors[node->inputs ->data[/*kInputToOutputWeightsTensor*/ 4]] .type == kTfLiteFloat32) { @@ -793,7 +937,7 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinMean: // NNAPI does not support generating a scalar as output for MEAN. 
- if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11 && context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 && context->tensors[node->outputs->data[0]].dims->size > 0) { return [](const NNAPIOpMappingArgs& mapping_args) @@ -836,9 +980,9 @@ class NNAPIDelegateKernel { } if (!nn_model_) { - ANeuralNetworksModel* model; - RETURN_TFLITE_ERROR_IF_NN_ERROR(context, - ANeuralNetworksModel_create(&model)); + ANeuralNetworksModel* model = nullptr; + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, nnapi_->ANeuralNetworksModel_create(&model)); nn_model_.reset(model); TF_LITE_ENSURE_STATUS( @@ -846,12 +990,17 @@ class NNAPIDelegateKernel { } if (!nn_compilation_) { - ANeuralNetworksCompilation* compilation; + ANeuralNetworksCompilation* compilation = nullptr; RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, - ANeuralNetworksCompilation_create(nn_model_.get(), &compilation)); - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, ANeuralNetworksCompilation_finish(compilation)); + context, nnapi_->ANeuralNetworksCompilation_create(nn_model_.get(), + &compilation)); + const int finish_result = + nnapi_->ANeuralNetworksCompilation_finish(compilation); + if (finish_result != ANEURALNETWORKS_NO_ERROR) { + nnapi_->ANeuralNetworksCompilation_free(compilation); + compilation = nullptr; + } + RETURN_TFLITE_ERROR_IF_NN_ERROR(context, finish_result); nn_compilation_.reset(compilation); } return kTfLiteOk; @@ -860,8 +1009,10 @@ class NNAPIDelegateKernel { TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { ANeuralNetworksExecution* execution = nullptr; RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, - ANeuralNetworksExecution_create(nn_compilation_.get(), &execution)); + context, nnapi_->ANeuralNetworksExecution_create(nn_compilation_.get(), + &execution)); + std::unique_ptr + execution_unique_ptr(execution); // Set the input tensor buffers. 
Note: we access tflite tensors using // absolute indices but NN api indices inputs by relative indices. @@ -881,7 +1032,7 @@ class NNAPIDelegateKernel { tensor->data.raw, tensor->bytes); RETURN_TFLITE_ERROR_IF_NN_ERROR( context, - ANeuralNetworksExecution_setInputFromMemory( + nnapi_->ANeuralNetworksExecution_setInputFromMemory( execution, relative_input_index, nullptr, nn_input_memory_->get_handle(), input_offset, tensor->bytes)); input_offset += tensor->bytes; @@ -896,7 +1047,7 @@ class NNAPIDelegateKernel { TfLiteTensor* tensor = &context->tensors[output_index]; RETURN_TFLITE_ERROR_IF_NN_ERROR( context, - ANeuralNetworksExecution_setOutputFromMemory( + nnapi_->ANeuralNetworksExecution_setOutputFromMemory( execution, relative_output_index, nullptr, nn_output_memory_->get_handle(), output_offset, tensor->bytes)); output_offset += tensor->bytes; @@ -912,18 +1063,25 @@ class NNAPIDelegateKernel { // reading and writing into the same buffer during a invocation. // TODO(110369471): using double shared buffer to minimize the copies. RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, ANeuralNetworksExecution_setOutput( + context, nnapi_->ANeuralNetworksExecution_setOutput( execution, relative_output_index, nullptr, tensor->data.raw, tensor->bytes)); relative_output_index++; } // Invoke ANN in blocking fashion. 
- ANeuralNetworksEvent* event = nullptr; - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, ANeuralNetworksExecution_startCompute(execution, &event)); - RETURN_TFLITE_ERROR_IF_NN_ERROR(context, ANeuralNetworksEvent_wait(event)); - ANeuralNetworksEvent_free(event); - ANeuralNetworksExecution_free(execution); + if (nnapi_->android_sdk_version < kMinSdkVersionForNNAPI12) { + ANeuralNetworksEvent* event = nullptr; + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, + nnapi_->ANeuralNetworksExecution_startCompute(execution, &event)); + const int wait_result = nnapi_->ANeuralNetworksEvent_wait(event); + nnapi_->ANeuralNetworksEvent_free(event); + RETURN_TFLITE_ERROR_IF_NN_ERROR(context, wait_result); + } else { + // Use synchronous execution for NNAPI 1.2+. + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, nnapi_->ANeuralNetworksExecution_compute(execution)); + } // copy results from shared memory to the destination. output_offset = 0; @@ -938,6 +1096,8 @@ class NNAPIDelegateKernel { } private: + // Access to NNApi. + const NnApi* nnapi_; // ANN API state. std::unique_ptr nn_model_; std::unique_ptr @@ -954,18 +1114,69 @@ class NNAPIDelegateKernel { std::unique_ptr nn_input_memory_; std::unique_ptr nn_output_memory_; + void AddDequantizeOperatorsWhereNeeded(const TfLiteContext* context, + int builtin_code, + const TfLiteNode* node, + NNAPIOpBuilder* builder) { + // Depending on the operator and the input data format, Dequantize + // operators may need to be added. For example when the input is + // floating-point but weights are quantized then the weights will first be + // dequantized to the same format as the input before being passed to the + // operator. + + // The tensor determining whether the inputs should be floating-point. 
+ int input_tensor_index = -1; + std::vector inputs_to_potentially_dequantize; + + switch (builtin_code) { + case kTfLiteBuiltinConv2d: + case kTfLiteBuiltinFullyConnected: { + input_tensor_index = 0; + // Weights and bias are inputs #1 and #2 respectively and may require + // dequantization. + inputs_to_potentially_dequantize = {1, 2}; + break; + } + default: + return; + } + + int tensor_id = node->inputs->data[input_tensor_index]; + if (tensor_id < 0) return; + + // Nothing to do if the input is not floating-point. + if (!IsFloat(context->tensors[tensor_id].type)) return; + + for (int i : inputs_to_potentially_dequantize) { + tensor_id = node->inputs->data[i]; + if (tensor_id < 0) continue; // Ignore optional input. + + const TfLiteType type = context->tensors[tensor_id].type; + // Nothing to do for this tensor if it's not quantized. + if (type != kTfLiteUInt8) continue; + + // Insert Dequantize operator if it hasn't been done already and change + // the node's input accordingly. + builder->AddDequantize(i, node->inputs->data[i], type); + } + } + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { - // The operand builder allows creating a single op. We create it at this - // reduced power position rather than in the for loop to avoid reallocating - // the vectors. - NNAPIOpBuilder builder(context, &operand_mapping_, nn_model_.get()); - // Add Tensors - // allocate outside to avoid realloc + DequantizeMapping dequantize_mapping; + // The operand builder allows creating a single op. It is created outside + // the for loop to avoid reallocating the vectors. + NNAPIOpBuilder builder(nnapi_, context, &operand_mapping_, + &dequantize_mapping, nn_model_.get()); + // Add Tensors. for (auto node_index : nodes_) { // Obtain the op and registration. 
TfLiteNode* node; TfLiteRegistration* reg; - context->GetNodeAndRegistration(context, node_index, &node, ®); + TF_LITE_ENSURE_STATUS( + context->GetNodeAndRegistration(context, node_index, &node, ®)); + + const bool hybrid_op = IsHybridOperator(context, reg->builtin_code, node); + // Map inputs to NN API tensor indices. for (auto input_index : TfLiteIntArrayView(node->inputs)) { if (input_index == kOptionalTensor && @@ -977,18 +1188,24 @@ class NNAPIDelegateKernel { // tensor when supported by NNAPI. TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0)); } else { - TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op)); } } // Get op type and operands - int nn_op_type = Map(context, reg->builtin_code, reg->version, node)( - {context, &builder, node, &model_state_outputs_, - &model_state_tfl_inputs_}); + int nn_op_type = Map( + context, reg->builtin_code, reg->version, nnapi_->android_sdk_version, + node)({context, &builder, node, &model_state_outputs_, + &model_state_tfl_inputs_}); // Map outputs to NN API tensor indices. for (auto output_index : TfLiteIntArrayView(node->outputs)) { TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); } + // Dequantize operators may have to be added in case inputs are to be + // floating-point. + AddDequantizeOperatorsWhereNeeded(context, reg->builtin_code, node, + &builder); + builder.FinalizeAddOperation(nn_op_type); } return kTfLiteOk; @@ -1022,32 +1239,34 @@ class NNAPIDelegateKernel { total_output_byte_size += context->tensors[i].bytes; } - // Add state output tensors as model inputs + // Add state output tensors as model outputs. 
for (int i : model_state_outputs_) { outputs.push_back(i); } // Tell ANN to declare inputs/outputs RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, ANeuralNetworksModel_identifyInputsAndOutputs( + context, nnapi_->ANeuralNetworksModel_identifyInputsAndOutputs( nn_model_.get(), inputs.size(), inputs.data(), outputs.size(), outputs.data())); // Set relaxed computation mode for fp32 if possible. - if (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI11) { RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, ANeuralNetworksModel_relaxComputationFloat32toFloat16( - nn_model_.get(), context->allow_fp32_relax_to_fp16)); + context, + nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16( + nn_model_.get(), context->allow_fp32_relax_to_fp16)); } // Finalize the model RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, ANeuralNetworksModel_finish(nn_model_.get())); + context, nnapi_->ANeuralNetworksModel_finish(nn_model_.get())); // Create shared memory pool for inputs and outputs. - nn_input_memory_.reset(new NNMemory("input_pool", total_input_byte_size)); + nn_input_memory_.reset( + new NNMemory(nnapi_, "input_pool", total_input_byte_size)); nn_output_memory_.reset( - new NNMemory("output_pool", total_output_byte_size)); + new NNMemory(nnapi_, "output_pool", total_output_byte_size)); return kTfLiteOk; } @@ -1059,14 +1278,26 @@ class NNAPIDelegateKernel { TfLiteDelegate* NnApiDelegate() { static TfLiteDelegate delegate = { .data_ = nullptr, - .flags = kTfLiteDelegateFlagsNone, .Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) -> TfLiteStatus { // Do not check nodes_ if NN API is unavailable. - if (kAndroidSdkVersion < kMinSdkVersionForNNAPI || !NNAPIExists()) { + const NnApi* nnapi = NnApiImplementation(); + if (nnapi->android_sdk_version < kMinSdkVersionForNNAPI || + !nnapi->nnapi_exists) { return kTfLiteOk; } - + // For NNAPI 1.2+, check if there is any accelerator available. 
+ // If not, don't delegate to NNAPI's CPU reference implementation. + if (nnapi->android_sdk_version >= kMinSdkVersionForNNAPI12) { + uint32_t device_count = 0; + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, nnapi->ANeuralNetworks_getDeviceCount(&device_count)); + // Any available accelerator will make the device_count larger than 1. + // More sophisticated check and whitelisting can be added later. + if (device_count <= 1) { + return kTfLiteOk; + } + } // Allocate one element in vector already since TensorFlow Lite uses // the first value as the number of nodes. The actual value will be set // later, after the vector has been filled. @@ -1076,6 +1307,7 @@ TfLiteDelegate* NnApiDelegate() { TfLiteIntArray* plan; TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + int android_sdk_version = NnApiImplementation()->android_sdk_version; // Check for every node if it is supported // TODO(b/80625235): Fix this to do more careful checking of versioning. for (int node_index : TfLiteIntArrayView(plan)) { @@ -1084,7 +1316,8 @@ TfLiteDelegate* NnApiDelegate() { TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( context, node_index, &node, ®istration)); if (NNAPIDelegateKernel::Map(context, registration->builtin_code, - registration->version, node)) { + registration->version, + android_sdk_version, node)) { supported_nodes.push_back(node_index); } } @@ -1121,6 +1354,7 @@ TfLiteDelegate* NnApiDelegate() { return state->Invoke(context, node); }, + .profiling_string = nullptr, .builtin_code = kTfLiteBuiltinDelegate, }; @@ -1130,7 +1364,13 @@ TfLiteDelegate* NnApiDelegate() { context, nnapi_delegate_kernel, reinterpret_cast(supported_nodes.data()), delegate); - }}; + }, + + .CopyFromBufferHandle = nullptr, + .CopyToBufferHandle = nullptr, + .FreeBufferHandle = nullptr, + .flags = kTfLiteDelegateFlagsNone, + }; return &delegate; } diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc index 
5da052eb42275d684bfbf83e7b52227ccbb97a06..69284578625dc6086a7549707186a824aff29137 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -49,6 +49,27 @@ class SingleOpModelWithNNAPI : public SingleOpModel { const std::vector& dims) { return interpreter_->ResizeInputTensor(tensor_index, dims); } + + protected: + void SetData(int index, TensorType type, const std::vector& data) { + switch (type) { + case TensorType_FLOAT32: + PopulateTensor(index, data); + break; + case TensorType_INT32: + QuantizeAndPopulate(index, data); + break; + case TensorType_UINT8: + QuantizeAndPopulate(index, data); + break; + case TensorType_INT8: + QuantizeAndPopulate(index, data); + break; + default: + FAIL() << "Type not supported: " << type; + break; + } + } }; class FloatAddOpModel : public SingleOpModelWithNNAPI { @@ -225,14 +246,15 @@ TEST(NNAPIDelegate, L2PoolWithNoActivation) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); } -class BaseConvolutionOpModel : public SingleOpModelWithNNAPI { +class ConvolutionOpModel : public SingleOpModelWithNNAPI { public: - BaseConvolutionOpModel( + ConvolutionOpModel( const TensorData& input, const TensorData& filter, const TensorData& output, int stride_width = 2, int stride_height = 2, enum Padding padding = Padding_VALID, enum ActivationFunctionType activation = ActivationFunctionType_NONE, - int dilation_width_factor = 1, int dilation_height_factor = 1) { + int dilation_width_factor = 1, int dilation_height_factor = 1) + : input_type_(input.type), filter_type_(filter.type) { input_ = AddInput(input); filter_ = AddInput(filter); @@ -249,7 +271,8 @@ class BaseConvolutionOpModel : public SingleOpModelWithNNAPI { } output_ = AddOutput(output); - if (input.type != TensorType_FLOAT32) { + + if (input_type_ != TensorType_FLOAT32) { // The following is required by quantized inference. 
It is the unittest's // responsibility to make sure the output scale falls into the correct // range. @@ -265,56 +288,53 @@ class BaseConvolutionOpModel : public SingleOpModelWithNNAPI { BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); } - protected: - int input_; - int filter_; - int bias_; - int output_; -}; - -class ConvolutionOpModel : public BaseConvolutionOpModel { - public: - using BaseConvolutionOpModel::BaseConvolutionOpModel; - - void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } - - void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } - - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); - } - std::vector GetOutput() { return ExtractVector(output_); } -}; - -class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { - public: - using BaseConvolutionOpModel::BaseConvolutionOpModel; - void SetInput(std::initializer_list data) { - QuantizeAndPopulate(input_, data); + SetData(input_, input_type_, data); } void SetFilter(std::initializer_list data) { - QuantizeAndPopulate(filter_, data); + SetData(filter_, filter_type_, data); } void SetBias(std::initializer_list data) { - QuantizeAndPopulate(bias_, data); + const auto bias_type = + (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32; + SetData(bias_, bias_type, data); } - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + std::vector GetOutput() { + if (input_type_ == TensorType_FLOAT32) { + return ExtractVector(output_); + } else { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + } + + std::vector GetQuantizedOutput() { + if (input_type_ == TensorType_FLOAT32) { + return {}; // Not supported. 
+ } else { + return ExtractVector(output_); + } } + + protected: + int input_; + int filter_; + int bias_; + int output_; + + const TensorType input_type_; + const TensorType filter_type_; }; // In this tests we set the input and output scales so that the results // match exactly the 'non-quantized' version. -TEST(NNAPIDelegate, SimpleTestQuantized) { - QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, - {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, - {TensorType_UINT8, {}, -127, 128}); +TEST(ConvolutionOpTest, SimpleTestQuantized) { + ConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); m.SetInput({ // First batch 1, 1, 1, 1, // row = 1 @@ -332,25 +352,55 @@ TEST(NNAPIDelegate, SimpleTestQuantized) { m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear( - { - 18, 2, 5, // first batch, left - 18, 2, 5, // first batch, right - 17, 4, 3, // second batch, left - 37, 4, 3, // second batch, right - }, - 1e-5))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); // For good measure, let's also verify the quantized values: - EXPECT_THAT(m.GetOutput(), ElementsAreArray({ - 145, 129, 132, // - 145, 129, 132, // - 144, 131, 130, // - 164, 131, 130, // - })); + EXPECT_THAT(m.GetQuantizedOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); } -TEST(NNAPIDelegate, Conv2DWithNoActivation) { +TEST(ConvolutionOpTest, FloatInputQuantizedWeights) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_UINT8, {3, 2, 2, 1}, 0, 64}, + {TensorType_FLOAT32, {}}); + m.SetInput({ + // First batch + 1, 1, 1, 2, // row = 1 + 2, 2, 2, 1, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 
2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + 0, 1, 0, 1, // second 2x2 filter + 0, 0, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 18, 5, 7, // first batch, left + 16, 5, 6, // first batch, right + 17, 6, 6, // second batch, left + 37, 10, 10, // second batch, right + }, + 0.2))); +} + +TEST(ConvolutionOpTest, NoActivation) { ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, {TensorType_FLOAT32, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}}); @@ -458,56 +508,48 @@ TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { })); } -class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI { +class FullyConnectedOpModel : public SingleOpModelWithNNAPI { public: - FloatFullyConnectedOpModel(int units, int batches, const TensorData& input, - const TensorData& output = {TensorType_FLOAT32}) - : batches_(batches), units_(units) { - int total_input_size = 1; - for (int i = 0; i < input.shape.size(); ++i) { - total_input_size *= input.shape[i]; - } - input_size_ = total_input_size / batches_; - + FullyConnectedOpModel( + const TensorData& input, const TensorData& weights, + const TensorData& output, + enum ActivationFunctionType activation = ActivationFunctionType_NONE) + : input_type_(input.type), weights_type_(weights.type) { input_ = AddInput(input); - weights_ = - AddInput({input.type, {units_, input_size_}, input.min, input.max}); + weights_ = AddInput(weights); + const int units = weights.shape[0]; if (input.type == TensorType_FLOAT32) { - bias_ = AddInput({TensorType_FLOAT32, {units_}}); + bias_ = AddInput({TensorType_FLOAT32, {units}}); } else { // This is a quantized version. The scale of 'bias' depends on the scales // of input and filter. Supposedly this is correctly set during quantized // training. 
auto bias_scale = GetScale(input_) * GetScale(weights_); - TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + TensorData bias{TensorType_INT32, {units}, 0, 0, bias_scale}; bias_ = AddInput(bias); } output_ = AddOutput(output); - SetBuiltinOp( - BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) - .Union()); + SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED, + BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, activation).Union()); BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); } - int input_size() { return input_size_; } - int num_units() { return units_; } - int num_batches() { return batches_; } - - void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } - - void SetWeights(std::initializer_list f) { - PopulateTensor(weights_, f); + void SetInput(std::initializer_list data) { + SetData(input_, input_type_, data); } - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); + void SetWeights(std::initializer_list data) { + SetData(weights_, weights_type_, data); } - void SetInput(int offset, float* begin, float* end) { - PopulateTensor(input_, offset, begin, end); + + void SetBias(std::initializer_list data) { + const auto bias_type = + (input_type_ == TensorType_FLOAT32) ? 
input_type_ : TensorType_INT32; + SetData(bias_, bias_type, data); } std::vector GetOutput() { return ExtractVector(output_); } @@ -518,14 +560,14 @@ class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI { int bias_; int output_; - int batches_; - int units_; - int input_size_; + const TensorType input_type_; + const TensorType weights_type_; }; -TEST(NNAPIDelegate, FullyConnectedSimpleTest) { - FloatFullyConnectedOpModel m(/*units=*/3, /*batches=*/2, - /*input=*/{TensorType_FLOAT32, {2, 10}}); +TEST(FullyConnectedOpTest, SimpleTest) { + FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}}, + /*weights=*/{TensorType_FLOAT32, {3, 10}}, + /*output=*/{TensorType_FLOAT32}); m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 @@ -543,6 +585,28 @@ TEST(NNAPIDelegate, FullyConnectedSimpleTest) { EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); } +TEST(FullyConnectedOpTest, FloatInputQuantizedWeights) { + FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}}, + /*weights=*/{TensorType_UINT8, {3, 10}, 0, 64}, + /*output=*/{TensorType_FLOAT32}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60}, 1.3))); +} + class SoftmaxOpModel : public SingleOpModelWithNNAPI { public: SoftmaxOpModel(int batches, int size, float beta) @@ -955,8 +1019,9 @@ TEST(NNAPIDelegate, ConcatenationFourInputsQuantizedMixedRange) { class DequantizeOpModel : public SingleOpModelWithNNAPI { public: - DequantizeOpModel(std::initializer_list shape, float min, float max) { - input_ = AddInput({TensorType_UINT8, shape, min, max}); + DequantizeOpModel(TensorType inputType, 
std::initializer_list shape, + float min, float max) { + input_ = AddInput({inputType, shape, min, max}); output_ = AddOutput({TensorType_FLOAT32, shape}); SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions, CreateDequantizeOptions(builder_).Union()); @@ -964,7 +1029,8 @@ class DequantizeOpModel : public SingleOpModelWithNNAPI { BuildInterpreter({GetShape(input_)}); } - void SetInput(std::initializer_list data) { + template + void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } @@ -975,16 +1041,27 @@ class DequantizeOpModel : public SingleOpModelWithNNAPI { int output_; }; -TEST(NNAPIDelegate, DequantizeFourDimensional) { - DequantizeOpModel m({2, 5}, -63.5, 64); +TEST(NNAPIDelegate, DequantizeFourDimensionalUint8) { + DequantizeOpModel m(TensorType_UINT8, {2, 5}, -63.5, 64); - m.SetInput({0, 1, 2, 3, 4, 251, 252, 253, 254, 255}); + m.SetInput({0, 1, 2, 3, 4, 251, 252, 253, 254, 255}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64}))); } +TEST(NNAPIDelegate, DequantizeFourDimensionalInt8Symm) { + // [-64, 63.5] -> scale=0.5, zero_point=0 for INT8 + DequantizeOpModel m(TensorType_INT8, {2, 5}, -64, 63.5); + + m.SetInput({-128, -127, -126, -125, -124, 123, 124, 125, 126, 127}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-64, -63.5, -63, -62.5, -62, 61.5, 62, 62.5, 63, 63.5}))); +} + class FloorOpModel : public SingleOpModelWithNNAPI { public: FloorOpModel(std::initializer_list input_shape, TensorType input_type) { @@ -1810,8 +1887,8 @@ static std::initializer_list rnn_bias = { class RNNOpModel : public SingleOpModelWithNNAPI { public: RNNOpModel(int batches, int units, int size, - const TensorType& weights = TensorType_FLOAT32, - const TensorType& recurrent_weights = TensorType_FLOAT32) + const TensorType weights = TensorType_FLOAT32, + const TensorType recurrent_weights = TensorType_FLOAT32) : 
batches_(batches), units_(units), input_size_(size) { input_ = AddInput(TensorType_FLOAT32); weights_ = AddInput(weights); @@ -2169,11 +2246,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, const std::vector>& input_shapes, - const TensorType& weight_type = TensorType_FLOAT32) + const TensorType weight_type) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), - n_output_(n_output) { + n_output_(n_output), + weight_type_(weight_type) { input_ = AddInput(TensorType_FLOAT32); if (use_cifg) { @@ -2246,69 +2324,71 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { BuildInterpreter(input_shapes); } - void SetInputToInputWeights(std::vector f) { - PopulateTensor(input_to_input_weights_, f); + void SetInputToInputWeights(const std::vector& f) { + SetData(input_to_input_weights_, weight_type_, f); } - void SetInputToForgetWeights(std::vector f) { - PopulateTensor(input_to_forget_weights_, f); + void SetInputToForgetWeights(const std::vector& f) { + SetData(input_to_forget_weights_, weight_type_, f); } - void SetInputToCellWeights(std::vector f) { - PopulateTensor(input_to_cell_weights_, f); + void SetInputToCellWeights(const std::vector& f) { + SetData(input_to_cell_weights_, weight_type_, f); } - void SetInputToOutputWeights(std::vector f) { - PopulateTensor(input_to_output_weights_, f); + void SetInputToOutputWeights(const std::vector& f) { + SetData(input_to_output_weights_, weight_type_, f); } - void SetRecurrentToInputWeights(std::vector f) { - PopulateTensor(recurrent_to_input_weights_, f); + void SetRecurrentToInputWeights(const std::vector& f) { + SetData(recurrent_to_input_weights_, weight_type_, f); } - void SetRecurrentToForgetWeights(std::vector f) { - PopulateTensor(recurrent_to_forget_weights_, f); + void SetRecurrentToForgetWeights(const std::vector& f) { + SetData(recurrent_to_forget_weights_, weight_type_, f); } - void 
SetRecurrentToCellWeights(std::vector f) { - PopulateTensor(recurrent_to_cell_weights_, f); + void SetRecurrentToCellWeights(const std::vector& f) { + SetData(recurrent_to_cell_weights_, weight_type_, f); } - void SetRecurrentToOutputWeights(std::vector f) { - PopulateTensor(recurrent_to_output_weights_, f); + void SetRecurrentToOutputWeights(const std::vector& f) { + SetData(recurrent_to_output_weights_, weight_type_, f); } - void SetCellToInputWeights(std::vector f) { - PopulateTensor(cell_to_input_weights_, f); + void SetCellToInputWeights(const std::vector& f) { + SetData(cell_to_input_weights_, weight_type_, f); } - void SetCellToForgetWeights(std::vector f) { - PopulateTensor(cell_to_forget_weights_, f); + void SetCellToForgetWeights(const std::vector& f) { + SetData(cell_to_forget_weights_, weight_type_, f); } - void SetCellToOutputWeights(std::vector f) { - PopulateTensor(cell_to_output_weights_, f); + void SetCellToOutputWeights(const std::vector& f) { + SetData(cell_to_output_weights_, weight_type_, f); } - void SetInputGateBias(std::vector f) { + void SetInputGateBias(const std::vector& f) { PopulateTensor(input_gate_bias_, f); } - void SetForgetGateBias(std::vector f) { + void SetForgetGateBias(const std::vector& f) { PopulateTensor(forget_gate_bias_, f); } - void SetCellBias(std::vector f) { PopulateTensor(cell_bias_, f); } + void SetCellBias(const std::vector& f) { + PopulateTensor(cell_bias_, f); + } - void SetOutputGateBias(std::vector f) { + void SetOutputGateBias(const std::vector& f) { PopulateTensor(output_gate_bias_, f); } - void SetProjectionWeights(std::vector f) { - PopulateTensor(projection_weights_, f); + void SetProjectionWeights(const std::vector& f) { + SetData(projection_weights_, weight_type_, f); } - void SetProjectionBias(std::vector f) { + void SetProjectionBias(const std::vector& f) { PopulateTensor(projection_bias_, f); } @@ -2358,6 +2438,9 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { int n_input_; int n_cell_; int 
n_output_; + + private: + const TensorType weight_type_; }; class BaseLstmTest : public ::testing::Test { @@ -2503,7 +2586,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { {0, 0}, // projection_weight tensor {0}, // projection_bias tensor - }); + }, + /*weight_type=*/TensorType_FLOAT32); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -2606,7 +2690,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { {0, 0}, // projection_weight tensor {0}, // projection_bias tensor - }); + }, + /*weight_type=*/TensorType_FLOAT32); lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_); @@ -3260,7 +3345,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { {n_output, n_cell}, // projection_weight tensor {0}, // projection_bias tensor - }); + }, + /*weight_type=*/TensorType_FLOAT32); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -3291,7 +3377,7 @@ class BaseReduceOpModel : public SingleOpModelWithNNAPI { void SetAxis(const std::vector& data) { PopulateTensor(axis_, data); } template - void SetInput(std::vector data) { + void SetInput(const std::vector& data) { PopulateTensor(input_, data); } diff --git a/tensorflow/lite/examples/android/app/build.gradle b/tensorflow/lite/examples/android/app/build.gradle index b372afae190ded84947b45655018a78633715c16..d2bc9846af571af71d8d7cbdf1c985e3a24474f7 100644 --- a/tensorflow/lite/examples/android/app/build.gradle +++ b/tensorflow/lite/examples/android/app/build.gradle @@ -1,8 +1,16 @@ apply plugin: 'com.android.application' +// import DownloadModels task +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +project.ext.TMP_DIR = project.buildDir.toString() + '/downloads' + +// Download default models; if you wish to use your own models then +// place them in the "assets" directory 
and comment out this line. +apply from: "download-models.gradle" + android { compileSdkVersion 26 - buildToolsVersion '27.0.3' + buildToolsVersion '28.0.3' defaultConfig { applicationId "org.tensorflow.lite.demo" minSdkVersion 15 @@ -36,14 +44,6 @@ repositories { } } -// import DownloadModels task -project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' -project.ext.TMP_DIR = project.buildDir.toString() + '/downloads' - -// Download default models; if you wish to use your own models then -// place them in the "assets" directory and comment out this line. -apply from: "download-models.gradle" - dependencies { implementation fileTree(dir: 'libs', include: ['*.jar']) implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly' diff --git a/tensorflow/lite/examples/android/app/download-models.gradle b/tensorflow/lite/examples/android/app/download-models.gradle index 36bd177a1fd6bb21a27edd6d2b6e82fa7aa5d57b..514eeb013501a45ac78778e95df88c946bcc737b 100644 --- a/tensorflow/lite/examples/android/app/download-models.gradle +++ b/tensorflow/lite/examples/android/app/download-models.gradle @@ -67,9 +67,6 @@ task extractModels(type: Copy) { } } - - - tasks.whenTaskAdded { task -> if (task.name == 'assembleDebug') { task.dependsOn 'extractModels' diff --git a/tensorflow/lite/examples/android/app/src/main/AndroidManifest.xml b/tensorflow/lite/examples/android/app/src/main/AndroidManifest.xml index bc9574d646b7661de8ac9b745bd53cbba1eb9f31..d4c98c61cca7ee1cb9f803bffb1966991b0368c2 100644 --- a/tensorflow/lite/examples/android/app/src/main/AndroidManifest.xml +++ b/tensorflow/lite/examples/android/app/src/main/AndroidManifest.xml @@ -24,12 +24,7 @@ - - diff --git a/tensorflow/lite/examples/android/build.gradle b/tensorflow/lite/examples/android/build.gradle index 7c038ddd46418b6498251068a284e8ffcdeda96f..90b8f0f3badf1be86588b2236bb37e83676602b8 100644 --- a/tensorflow/lite/examples/android/build.gradle +++ b/tensorflow/lite/examples/android/build.gradle @@ -6,10 +6,7 
@@ buildscript { jcenter() } dependencies { - classpath 'com.android.tools.build:gradle:3.1.4' - - // NOTE: Do not place your application dependencies here; they belong - // in the individual module build.gradle files + classpath 'com.android.tools.build:gradle:3.2.1' } } diff --git a/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm index 4f6fcaa96c4b917b79dacc5180594c1458ef18ff..e5c230e65e9e53a8496c4a4f0de49a7677acff81 100644 --- a/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm @@ -23,12 +23,6 @@ #include #include -#if TFLITE_USE_CONTRIB_LITE -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/string_util.h" -#else #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/op_resolver.h" @@ -36,7 +30,6 @@ #if TFLITE_USE_GPU_DELEGATE #include "tensorflow/lite/delegates/gpu/metal_delegate.h" #endif -#endif #define LOG(x) std::cerr diff --git a/tensorflow/lite/examples/ios/camera/Podfile b/tensorflow/lite/examples/ios/camera/Podfile index 2e15cc63decb30eb2b8c9bffab3b5d1bff10e9b3..0bff676159c22d8cf9419d3d79437b63103b7d06 100644 --- a/tensorflow/lite/examples/ios/camera/Podfile +++ b/tensorflow/lite/examples/ios/camera/Podfile @@ -9,5 +9,5 @@ target 'tflite_camera_example' # Note: TFLite GPU Delegate binary isn't releast yet, and we're working # on it. 
- pod 'TensorFlowLite', '1.12.0' + pod 'TensorFlowLite', '1.13.1' # pod 'TensorFlowLiteGpuExperimental', '0.0.1' diff --git a/tensorflow/lite/examples/ios/simple/Podfile b/tensorflow/lite/examples/ios/simple/Podfile index 931b72c1f5e946e8be61ac6dec3c6106a75b9685..d9e6a6586bf80f332192a25a74a91abdf79046a5 100644 --- a/tensorflow/lite/examples/ios/simple/Podfile +++ b/tensorflow/lite/examples/ios/simple/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_simple_example' - pod 'TensorFlowLite', '1.12.0' + pod 'TensorFlowLite', '1.13.1' diff --git a/tensorflow/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/lite/examples/ios/simple/RunModelViewController.mm index 32da7f7e4fce5cafc3c4746e5847315172542fc9..b90f2d3cfd05a57460149a1f7868c01f32b28cd1 100644 --- a/tensorflow/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/lite/examples/ios/simple/RunModelViewController.mm @@ -22,10 +22,10 @@ #include #include -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/op_resolver.h" +#include "tensorflow/lite/string_util.h" #include "ios_image_load.h" diff --git a/tensorflow/lite/examples/label_image/BUILD b/tensorflow/lite/examples/label_image/BUILD index 4fc8648d46c4bdefe3865381a23f4d73c87c284b..75e1add60198425694356f326862f132e8e6e70e 100644 --- a/tensorflow/lite/examples/label_image/BUILD +++ b/tensorflow/lite/examples/label_image/BUILD @@ -5,14 +5,13 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/lite:build_def.bzl", "tflite_linkopts") exports_files(glob([ "testdata/*.bmp", ])) -tf_cc_binary( +cc_binary( name = "label_image", srcs = [ "get_top_n.h", 
diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc index b8dc2840dfb49f8c067fbd2bf09432f7b06d6265..340fbab5c6fcc960685bc4581b285b0610fa22df 100644 --- a/tensorflow/lite/examples/label_image/label_image.cc +++ b/tensorflow/lite/examples/label_image/label_image.cc @@ -113,6 +113,7 @@ void RunInference(Settings* s) { } interpreter->UseNNAPI(s->accel); + interpreter->SetAllowFp16PrecisionForFp32(s->allow_fp16); if (s->verbose) { LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n"; @@ -251,19 +252,21 @@ void RunInference(Settings* s) { } void display_usage() { - LOG(INFO) << "label_image\n" - << "--accelerated, -a: [0|1], use Android NNAPI or not\n" - << "--count, -c: loop interpreter->Invoke() for certain times\n" - << "--input_mean, -b: input mean\n" - << "--input_std, -s: input standard deviation\n" - << "--image, -i: image_name.bmp\n" - << "--labels, -l: labels for the model\n" - << "--tflite_model, -m: model_name.tflite\n" - << "--profiling, -p: [0|1], profiling or not\n" - << "--num_results, -r: number of results to show\n" - << "--threads, -t: number of threads\n" - << "--verbose, -v: [0|1] print more information\n" - << "\n"; + LOG(INFO) + << "label_image\n" + << "--accelerated, -a: [0|1], use Android NNAPI or not\n" + << "--allow_fp16, -f: [0|1], allow running fp32 models with fp16 not\n" + << "--count, -c: loop interpreter->Invoke() for certain times\n" + << "--input_mean, -b: input mean\n" + << "--input_std, -s: input standard deviation\n" + << "--image, -i: image_name.bmp\n" + << "--labels, -l: labels for the model\n" + << "--tflite_model, -m: model_name.tflite\n" + << "--profiling, -p: [0|1], profiling or not\n" + << "--num_results, -r: number of results to show\n" + << "--threads, -t: number of threads\n" + << "--verbose, -v: [0|1] print more information\n" + << "\n"; } int Main(int argc, char** argv) { @@ -273,6 +276,7 @@ int Main(int argc, char** argv) { while (1) { 
static struct option long_options[] = { {"accelerated", required_argument, nullptr, 'a'}, + {"allow_fp16", required_argument, nullptr, 'f'}, {"count", required_argument, nullptr, 'c'}, {"verbose", required_argument, nullptr, 'v'}, {"image", required_argument, nullptr, 'i'}, @@ -305,6 +309,10 @@ int Main(int argc, char** argv) { s.loop_count = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; + case 'f': + s.allow_fp16 = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) + break; case 'i': s.input_bmp_name = optarg; break; diff --git a/tensorflow/lite/examples/label_image/label_image.h b/tensorflow/lite/examples/label_image/label_image.h index 88b047fecc4b3efd10ef025193a79516516c03f1..cc46e56b64a9dc848bba1dafed373375a97dcfe7 100644 --- a/tensorflow/lite/examples/label_image/label_image.h +++ b/tensorflow/lite/examples/label_image/label_image.h @@ -26,6 +26,7 @@ struct Settings { bool accel = false; bool input_floating = false; bool profiling = false; + bool allow_fp16 = false; int loop_count = 1; float input_mean = 127.5f; float input_std = 127.5f; diff --git a/tensorflow/lite/examples/minimal/BUILD b/tensorflow/lite/examples/minimal/BUILD index cdd67af1e93661c1f65cc46d9b687acc1fa56fe8..498dbb9693e74545546f849a246630dc16ab7322 100644 --- a/tensorflow/lite/examples/minimal/BUILD +++ b/tensorflow/lite/examples/minimal/BUILD @@ -5,10 +5,9 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/lite:build_def.bzl", "tflite_linkopts") -tf_cc_binary( +cc_binary( name = "minimal", srcs = [ "minimal.cc", diff --git a/tensorflow/lite/experimental/c/BUILD b/tensorflow/lite/experimental/c/BUILD index cde53e283830aca9c7990e3d8c4901f997621bc2..aef39db02140389c0d0bcbde5ae4d2d89059c2a6 100644 --- a/tensorflow/lite/experimental/c/BUILD +++ b/tensorflow/lite/experimental/c/BUILD @@ -19,16 +19,14 @@ load( tflite_cc_shared_object( name = 
"libtensorflowlite_c.so", linkopts = select({ - "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow/lite/experimental/c:exported_symbols.lds)", + "//tensorflow:macos": [ + "-Wl,-exported_symbols_list,$(location //tensorflow/lite/experimental/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflowlite_c.so", ], "//tensorflow:windows": [], "//conditions:default": [ "-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow/lite/experimental/c:version_script.lds)", + "-Wl,--version-script,$(location //tensorflow/lite/experimental/c:version_script.lds)", ], }), deps = [ @@ -66,7 +64,6 @@ cc_library( ":c_api_internal", "//tensorflow/lite:context", "//tensorflow/lite:framework", - "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/kernels:builtin_ops", ], ) @@ -94,7 +91,6 @@ cc_test( deps = [ ":c_api", "//tensorflow/lite:context", - "//tensorflow/lite:kernel_api", "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/lite/experimental/examples/lstm/BUILD b/tensorflow/lite/experimental/examples/lstm/BUILD index f39673c028d5924359c5058cbc4a72e2f28394b0..27efa442348c997197becb6a8db603718a158f66 100644 --- a/tensorflow/lite/experimental/examples/lstm/BUILD +++ b/tensorflow/lite/experimental/examples/lstm/BUILD @@ -5,27 +5,29 @@ package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "py_test") py_library( - name = "tflite_lstm", - srcs = ["tflite_lstm.py"], + name = "rnn", + srcs = ["rnn.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow:tensorflow_py", - "//tensorflow/lite/python:lite", + "//tensorflow/lite/python:op_hint", "//tensorflow/python:framework", + "//tensorflow/python:layers_base", + "//tensorflow/python:rnn", + "//tensorflow/python:rnn_cell", 
"@six_archive//:six", ], ) py_library( - name = "tflite_rnn", - srcs = ["tflite_rnn.py"], + name = "rnn_cell", + srcs = ["rnn_cell.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow:tensorflow_py", - "//tensorflow/lite/python:lite", + "//tensorflow/lite/python:op_hint", "//tensorflow/python:framework", + "//tensorflow/python:rnn_cell", "@six_archive//:six", ], ) @@ -40,13 +42,14 @@ py_test( "no_pip", ], deps = [ - ":tflite_lstm", + ":rnn", + ":rnn_cell", "//tensorflow:tensorflow_py", "//tensorflow/examples/tutorials/mnist:input_data", "//tensorflow/lite/python:lite", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:framework", "//tensorflow/python:platform", - "//tensorflow/python/tools:optimize_for_inference", + "//tensorflow/python/tools:optimize_for_inference_lib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -62,14 +65,71 @@ py_test( "no_pip", ], deps = [ - ":tflite_rnn", + ":rnn", + ":rnn_cell", "//tensorflow:tensorflow_py", "//tensorflow/examples/tutorials/mnist:input_data", "//tensorflow/lite/python:lite", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", - "//tensorflow/python/tools:optimize_for_inference", + "//tensorflow/python/tools:optimize_for_inference_main_lib", "//third_party/py/numpy", "@six_archive//:six", ], ) + +py_test( + name = "bidirectional_sequence_lstm_test", + size = "large", + srcs = ["bidirectional_sequence_lstm_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_pip", + ], + deps = [ + ":rnn", + ":rnn_cell", + "//tensorflow:tensorflow_py", + "//tensorflow/examples/tutorials/mnist:input_data", + "//tensorflow/lite/python:lite", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python/tools:optimize_for_inference_main_lib", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_test( + name = "bidirectional_sequence_rnn_test", + size = "large", + srcs = 
["bidirectional_sequence_rnn_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_pip", + ], + deps = [ + ":rnn", + ":rnn_cell", + "//tensorflow:tensorflow_py", + "//tensorflow/examples/tutorials/mnist:input_data", + "//tensorflow/lite/python:lite", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python/tools:optimize_for_inference_main_lib", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "tflite_lstm_ops", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":rnn", + ":rnn_cell", + ], +) diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b71d8c22b23952021e7328c275a16e48987147ad --- /dev/null +++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py @@ -0,0 +1,231 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tempfile +import numpy as np +import tensorflow as tf + +from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.lite.experimental.examples.lstm.rnn import bidirectional_dynamic_rnn +from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs +from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test +from tensorflow.python.tools import optimize_for_inference_lib + +# Number of steps to train model. +TRAIN_STEPS = 1 + +CONFIG = tf.ConfigProto(device_count={"GPU": 0}) + + +class BidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): + + def setUp(self): + tf.reset_default_graph() + # Import MNIST dataset + self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) + + # Define constants + # Unrolled through 28 time steps + self.time_steps = 28 + # Rows of 28 pixels + self.n_input = 28 + # Learning rate for Adam optimizer + self.learning_rate = 0.001 + # MNIST is meant to be classified in 10 classes(0-9). + self.n_classes = 10 + # Batch size + self.batch_size = 16 + # Lstm Units. 
+ self.num_units = 16 + + def buildLstmLayer(self): + return tf.keras.layers.StackedRNNCells([ + tf.lite.experimental.nn.TFLiteLSTMCell( + self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"), + tf.lite.experimental.nn.TFLiteLSTMCell( + self.num_units, num_proj=8, forget_bias=0, name="rnn2"), + tf.lite.experimental.nn.TFLiteLSTMCell( + self.num_units // 2, + use_peepholes=True, + num_proj=8, + forget_bias=0, + name="rnn3"), + tf.lite.experimental.nn.TFLiteLSTMCell( + self.num_units, forget_bias=0, name="rnn4") + ]) + + def buildModel(self, fw_lstm_layer, bw_lstm_layer, is_dynamic_rnn): + # Weights and biases for output softmax layer. + out_weights = tf.Variable( + tf.random_normal([self.num_units * 2, self.n_classes])) + out_bias = tf.Variable(tf.random_normal([self.n_classes])) + + # input image placeholder + x = tf.placeholder( + "float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE") + + if is_dynamic_rnn: + lstm_inputs = tf.transpose(x, [1, 0, 2]) + outputs, _ = bidirectional_dynamic_rnn( + fw_lstm_layer, + bw_lstm_layer, + lstm_inputs, + dtype="float32", + time_major=True) + fw_outputs, bw_outputs = outputs + output = tf.concat([fw_outputs, bw_outputs], 2) + output = tf.unstack(output, axis=0) + output = output[-1] + else: + lstm_input = tf.unstack(x, self.time_steps, 1) + outputs, _, _ = tf.nn.static_bidirectional_rnn( + fw_lstm_layer, bw_lstm_layer, lstm_input, dtype="float32") + output = outputs[-1] + + # Compute logits by multiplying output of shape [batch_size,num_units*2] + # by the softmax layer's out_weight of shape [num_units*2,n_classes] + # plus out_bias + prediction = tf.matmul(output, out_weights) + out_bias + output_class = tf.nn.softmax(prediction, name="OUTPUT_CLASS") + + return x, prediction, output_class + + def trainModel(self, x, prediction, output_class, sess): + # input label placeholder + y = tf.placeholder("float", [None, self.n_classes]) + # Loss function + loss = tf.reduce_mean( + 
tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y)) + # Optimization + opt = tf.train.AdamOptimizer( + learning_rate=self.learning_rate).minimize(loss) + + # Initialize variables + init = tf.global_variables_initializer() + sess.run(init) + for _ in range(TRAIN_STEPS): + batch_x, batch_y = self.mnist.train.next_batch( + batch_size=self.batch_size, shuffle=False) + + batch_x = batch_x.reshape((self.batch_size, self.time_steps, + self.n_input)) + sess.run(opt, feed_dict={x: batch_x, y: batch_y}) + + def saveAndRestoreModel(self, fw_lstm_layer, bw_lstm_layer, sess, saver, + is_dynamic_rnn): + model_dir = tempfile.mkdtemp() + saver.save(sess, model_dir) + + # Reset the graph. + tf.reset_default_graph() + x, prediction, output_class = self.buildModel(fw_lstm_layer, bw_lstm_layer, + is_dynamic_rnn) + + new_sess = tf.Session(config=CONFIG) + saver = tf.train.Saver() + saver.restore(new_sess, model_dir) + return x, prediction, output_class, new_sess + + def getInferenceResult(self, x, output_class, sess): + b1, _ = self.mnist.train.next_batch(batch_size=1) + sample_input = np.reshape(b1, (1, self.time_steps, self.n_input)) + + expected_output = sess.run(output_class, feed_dict={x: sample_input}) + # It is important to keep all the ophint output nodes. 
+ hinted_outputs_nodes = find_all_hinted_output_nodes(sess) + hinted_outputs_nodes.append(output_class.op.name) + frozen_graph = tf.graph_util.convert_variables_to_constants( + sess, sess.graph_def, hinted_outputs_nodes) + return sample_input, expected_output, frozen_graph + + def tfliteInvoke(self, graph, test_inputs, outputs): + tf.reset_default_graph() + # Turn the input into placeholder of shape 1 + tflite_input = tf.placeholder( + "float", [1, self.time_steps, self.n_input], name="INPUT_IMAGE_LITE") + tf.import_graph_def(graph, name="", input_map={"INPUT_IMAGE": tflite_input}) + with tf.Session() as sess: + curr = sess.graph_def + curr = convert_op_hints_to_stubs(graph_def=curr) + + curr = optimize_for_inference_lib.optimize_for_inference( + curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"], + [tf.float32.as_datatype_enum]) + + converter = tf.lite.TFLiteConverter(curr, [tflite_input], [outputs]) + tflite = converter.convert() + + interpreter = tf.lite.Interpreter(model_content=tflite) + + try: + interpreter.allocate_tensors() + except ValueError: + assert False + + input_index = (interpreter.get_input_details()[0]["index"]) + interpreter.set_tensor(input_index, test_inputs) + interpreter.invoke() + output_index = (interpreter.get_output_details()[0]["index"]) + result = interpreter.get_tensor(output_index) + # Reset all variables so it will not pollute other inferences. 
+ interpreter.reset_all_variables() + return result + + def testStaticRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel(self.buildLstmLayer(), + self.buildLstmLayer(), False) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildLstmLayer(), self.buildLstmLayer(), sess, saver, False) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) + + @test_util.enable_control_flow_v2 + def testDynamicRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel(self.buildLstmLayer(), + self.buildLstmLayer(), True) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildLstmLayer(), + self.buildLstmLayer(), + sess, + saver, + is_dynamic_rnn=True) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2a0a86b731ad29e890638e9668d4654571aa3487 --- /dev/null +++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py @@ -0,0 +1,305 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tempfile +import numpy as np +import tensorflow as tf + +from tensorflow import flags + +from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.lite.experimental.examples.lstm.rnn import bidirectional_dynamic_rnn +from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs +from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test +from tensorflow.python.tools import optimize_for_inference_lib + +FLAGS = flags.FLAGS + +# Number of steps to train model. +TRAIN_STEPS = 1 + +CONFIG = tf.ConfigProto(device_count={"GPU": 0}) + + +class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase): + + def __init__(self, *args, **kwargs): + super(BidirectionalSequenceRnnTest, self).__init__(*args, **kwargs) + # Define constants + # Unrolled through 28 time steps + self.time_steps = 28 + # Rows of 28 pixels + self.n_input = 28 + # Learning rate for Adam optimizer + self.learning_rate = 0.001 + # MNIST is meant to be classified in 10 classes(0-9). + self.n_classes = 10 + # Batch size + self.batch_size = 16 + # Rnn Units. 
+ self.num_units = 16 + + def setUp(self): + super(BidirectionalSequenceRnnTest, self).setUp() + # Import MNIST dataset + data_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir) + self.mnist = input_data.read_data_sets(data_dir, one_hot=True) + + def buildRnnLayer(self): + return tf.keras.layers.StackedRNNCells([ + tf.lite.experimental.nn.TfLiteRNNCell(self.num_units, name="rnn1"), + tf.lite.experimental.nn.TfLiteRNNCell(self.num_units, name="rnn2") + ]) + + def buildModel(self, + fw_rnn_layer, + bw_rnn_layer, + is_dynamic_rnn, + is_inference, + use_sequence_length=False): + # Weights and biases for output softmax layer. + out_weights = tf.Variable( + tf.random_normal([self.num_units * 2, self.n_classes])) + out_bias = tf.Variable(tf.random_normal([self.n_classes])) + + batch_size = self.batch_size + if is_inference: + batch_size = 1 + # input image placeholder + x = tf.placeholder( + "float", [batch_size, self.time_steps, self.n_input], + name="INPUT_IMAGE") + + sequence_length = None + if use_sequence_length: + sequence_length = [self.time_steps] * batch_size + if is_dynamic_rnn: + rnn_inputs = tf.transpose(x, [1, 0, 2]) + outputs, _ = bidirectional_dynamic_rnn( + fw_rnn_layer, + bw_rnn_layer, + rnn_inputs, + sequence_length, + dtype="float32", + time_major=True) + fw_outputs, bw_outputs = outputs + output = tf.concat([fw_outputs, bw_outputs], 2) + output = tf.unstack(output, axis=0) + output = output[-1] + else: + rnn_inputs = tf.unstack(x, self.time_steps, 1) + # Sequence length is not supported for static since we don't have a + # wrapper for it. At training phase, we can still have sequence_length, + # but inference phase, we change it to None. 
+ if is_inference: + sequence_length = None + outputs, _, _ = tf.nn.static_bidirectional_rnn( + fw_rnn_layer, + bw_rnn_layer, + rnn_inputs, + dtype="float32", + sequence_length=sequence_length) + output = outputs[-1] + + # Compute logits by multiplying output of shape [batch_size,num_units*2] + # by the softmax layer's out_weight of shape [num_units*2,n_classes] + # plus out_bias + prediction = tf.matmul(output, out_weights) + out_bias + output_class = tf.nn.softmax(prediction, name="OUTPUT_CLASS") + + return x, prediction, output_class + + def trainModel(self, x, prediction, output_class, sess): + # input label placeholder + y = tf.placeholder("float", [None, self.n_classes]) + # Loss function + loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y)) + # Optimization + opt = tf.train.AdamOptimizer( + learning_rate=self.learning_rate).minimize(loss) + + # Initialize variables + init = tf.global_variables_initializer() + sess.run(init) + for _ in range(TRAIN_STEPS): + batch_x, batch_y = self.mnist.train.next_batch( + batch_size=self.batch_size, shuffle=False) + + batch_x = batch_x.reshape((self.batch_size, self.time_steps, + self.n_input)) + sess.run(opt, feed_dict={x: batch_x, y: batch_y}) + + def saveAndRestoreModel(self, + fw_rnn_layer, + bw_rnn_layer, + sess, + saver, + is_dynamic_rnn, + use_sequence_length=False): + model_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir) + saver.save(sess, model_dir) + + # Reset the graph. 
+ tf.reset_default_graph() + x, prediction, output_class = self.buildModel( + fw_rnn_layer, bw_rnn_layer, is_dynamic_rnn, True, use_sequence_length) + + new_sess = tf.Session(config=CONFIG) + saver = tf.train.Saver() + saver.restore(new_sess, model_dir) + return x, prediction, output_class, new_sess + + def getInferenceResult(self, x, output_class, sess): + b1, _ = self.mnist.train.next_batch(batch_size=1) + sample_input = np.reshape(b1, (1, self.time_steps, self.n_input)) + + expected_output = sess.run(output_class, feed_dict={x: sample_input}) + # It is important to keep all the ophint output nodes. + hinted_outputs_nodes = find_all_hinted_output_nodes(sess) + hinted_outputs_nodes.append(output_class.op.name) + frozen_graph = tf.graph_util.convert_variables_to_constants( + sess, sess.graph_def, hinted_outputs_nodes) + return sample_input, expected_output, frozen_graph + + def tfliteInvoke(self, graph, test_inputs, outputs): + tf.reset_default_graph() + # Turn the input into placeholder of shape 1 + tflite_input = tf.placeholder( + "float", [1, self.time_steps, self.n_input], name="INPUT_IMAGE_LITE") + tf.import_graph_def(graph, name="", input_map={"INPUT_IMAGE": tflite_input}) + with tf.Session() as sess: + curr = sess.graph_def + curr = convert_op_hints_to_stubs(graph_def=curr) + + curr = optimize_for_inference_lib.optimize_for_inference( + curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"], + [tf.float32.as_datatype_enum]) + + converter = tf.lite.TFLiteConverter(curr, [tflite_input], [outputs]) + tflite = converter.convert() + + interpreter = tf.lite.Interpreter(model_content=tflite) + + interpreter.allocate_tensors() + + input_index = interpreter.get_input_details()[0]["index"] + interpreter.set_tensor(input_index, test_inputs) + interpreter.invoke() + output_index = interpreter.get_output_details()[0]["index"] + result = interpreter.get_tensor(output_index) + # Reset all variables so it will not pollute other inferences. 
+ interpreter.reset_all_variables() + return result + + def testStaticRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildRnnLayer(), self.buildRnnLayer(), False, is_inference=False) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildRnnLayer(), self.buildRnnLayer(), sess, saver, False) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) + + def testStaticRnnMultiRnnCellWithSequenceLength(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildRnnLayer(), + self.buildRnnLayer(), + False, + is_inference=False, + use_sequence_length=True) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildRnnLayer(), + self.buildRnnLayer(), + sess, + saver, + False, + use_sequence_length=True) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) + + @test_util.enable_control_flow_v2 + def testDynamicRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildRnnLayer(), self.buildRnnLayer(), True, is_inference=False) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildRnnLayer(), + self.buildRnnLayer(), + sess, + saver, + is_dynamic_rnn=True) + + test_inputs, expected_output, frozen_graph 
= self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) + + @test_util.enable_control_flow_v2 + def testDynamicRnnMultiRnnCellWithSequenceLength(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildRnnLayer(), + self.buildRnnLayer(), + True, + is_inference=False, + use_sequence_length=True) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildRnnLayer(), + self.buildRnnLayer(), + sess, + saver, + is_dynamic_rnn=True, + use_sequence_length=True) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/lite/experimental/examples/lstm/g3doc/README.md b/tensorflow/lite/experimental/examples/lstm/g3doc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dfe2d0d153d6042680857c1ad7efe4892a2a5a73 --- /dev/null +++ b/tensorflow/lite/experimental/examples/lstm/g3doc/README.md @@ -0,0 +1,394 @@ +# TensorFlow Lite LSTM ops API + +TensorFlow Lite LSTM ops help developers deploy LSTM models to TensorFlow Lite. +This is currently an experimental API, it's likely to change in future. + +## Introduction + +LSTM ops in TensorFlow Lite realm are expressed as "fused ops" (e.g., +UnidirectionalSequenceRNN, BidirectionalSequenceLSTM, etc.). However, in +TensorFlow, LSTM ops are expressed as a "cell" (e.g., `tf.nn.rnn_cell.LSTMCell`, +`tf.nn.rnn_cell.BasicRNNCell`, etc., and they all contain multiple TensorFlow +ops) and a "rnn" ( e.g., `tf.nn.static_rnn`, +`tf.nn.bidirectional_dynamic_rnn`). 
+ +The ops breakdown in TensorFlow gives us flexibility while the "fused op" in +TensorFlow Lite gives us a performance boost. + +See the difference between TensorFlow LSTM and TensorFlow Lite LSTM. + +##### TensorFlow LSTM op ("cell") + +![TensorFlow LSTM op](./images/tf_lstm.png) + +##### TensorFlow Lite LSTM op ("fused ops") + +![TensorFlow Lite LSTM op](./images/tflite_lstm.png) + +The TensorFlow LSTM figure is credited to this +[blog](https://colah.github.io/posts/2015-08-Understanding-LSTMs/). + +## How to use + +Using TensorFlow Lite LSTM ops is actually pretty simple. + +### 1) Training & Evaluation. + +The first step is replacing `tf.nn.rnn_cell.LSTMCell` with +`tf.lite.experimental.nn.TFLiteLSTMCell` in the training phase, and replacing +`tf.nn.dynamic_rnn` with `tf.lite.experimental.nn.dynamic_rnn`, if you are +using dynamic_rnn. Note you don't need to change it if you're using static_rnn. + +Both `tf.lite.experimental.nn.TFLiteLSTMCell` & +`tf.lite.experimental.nn.dynamic_rnn` are just the normal `tf.nn.rnn_cell.LSTMCell` +and `tf.nn.dynamic_rnn` with OpHinted nodes in them to help the graph +transformation. + +Then you can train and export the model as usual. + +### 2) Export for TensorFlow Lite inference. + +When you want to convert to a TensorFlow Lite model, here's one simple step you +need to do for your frozen graph: + +```python +with tf.Session() as sess: + ophinted_graph = tf.lite.experimental.convert_op_hints_to_stubs(session=sess) +``` + +Then you can convert the model to a TensorFlow Lite model as usual. + +```python +converter = tf.lite.TFLiteConverter(ophinted_graph, [INPUTS], [OUTPUTS]) +converter.post_training_quantize = True # If post training quantize is desired. +tflite_model = converter.convert() # You got a tflite model! +``` + +#### Simple example diff for using original TF code VS.
TensorFlow Lite code: + +```python +@@ -56,7 +56,7 @@ class MnistLstmModel(object): + for _ in range(self.num_lstm_layer): + lstm_layers.append( + # Note here, we use `tf.lite.experimental.nn.TFLiteLSTMCell`. +- tf.nn.rnn_cell.LSTMCell( ++ tf.lite.experimental.nn.TFLiteLSTMCell( + self.num_lstm_units, forget_bias=0)) + # Weights and biases for output softmax layer. + out_weights = tf.Variable(tf.random_normal([self.units, self.num_class])) +@@ -67,7 +67,7 @@ class MnistLstmModel(object): + lstm_cells = tf.nn.rnn_cell.MultiRNNCell(lstm_layers) + # Note here, we use `tf.lite.experimental.nn.dynamic_rnn` and `time_major` + # is set to True. +- outputs, _ = tf.nn.dynamic_rnn( ++ outputs, _ = tf.lite.experimental.nn.dynamic_rnn( + lstm_cells, lstm_inputs, dtype='float32', time_major=True) + + # Transpose the outputs back to [batch, time, output] +@@ -154,7 +154,9 @@ def export(model, model_dir, tflite_model_file, + sess, sess.graph_def, [output_class.op.name]) + + # Convert ophinted lstm ops to tflite UnidirectionalSequenceLstm ops. +- converted_graph = tf.graph_util.remove_training_nodes(frozen_graph) ++ converted_graph = tf.lite.experimental.convert_op_hints_to_stubs( ++ graph_def=frozen_graph) ++ converted_graph = tf.graph_util.remove_training_nodes(converted_graph) + converter = tf.lite.TFLiteConverter(converted_graph, [x], [output_class]) + converter.post_training_quantize = use_post_training_quantize + tflite = converter.convert() +``` + +## Why introduce another set of LSTM APIs? + +Bridging TensorFlow LSTM and TensorFlow Lite is not easy, and the use of +`dynamic_rnn` adds additional complexity (as the while loop is introduced). 
+With the help of +[OpHint](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/op_hint.py) +(also see the next section), we create special wrappers around `rnn_cell` and +`rnn` to help us identify the inputs and outputs of the LSTM ops, and these +ops are converted to a single fused LSTM op when converting TensorFlow models +to TensorFlow Lite format. + +### What's OpHint + +`OpHint` is essentially an `Identity` op that is inserted after input tensors and +output tensors to "hint" the customized op boundary, see the following figure. + +##### Ophinted Customized Graph + +Let's say we have a "customized conv" op which is a normal conv2d op with a bias +add op followed by an activation op (graph on the left), we use `OpHint` to +track down all the inputs and outputs, during the graph transformation phase +(done by `tf.lite.experimental.convert_op_hints_to_stubs`), the conv2d op, bias +add op and the activation op will become a "my customized conv" op (see the +graph on the right), and all the "OpHinted" tensors will become the +inputs/outputs of the "my customized conv" op. + +![Ophinted Customized Graph](./images/op_hint.png) + + +## Simple Tutorial + +The following tutorial uses MNIST dataset to build a simple two-layer LSTM model +and convert to quantized TensorFlow Lite model. + +Note since we will be using dynamic_rnn, we need to turn on `control_flow_v2`. + +### 0. Turn on `control_flow_v2`. + +```python +# Note this needs to happen before import tensorflow. +import os +os.environ['TF_ENABLE_CONTROL_FLOW_V2'] = '1' +``` + +### 1. Build the model. + +```python +class MnistLstmModel(object): + """Build a simple LSTM based MNIST model. + + Attributes: + time_steps: The maximum length of the time_steps, but since we're just using + the 'width' dimension as time_steps, it's actually a fixed number. + input_size: The LSTM layer input size. + num_lstm_layer: Number of LSTM layers for the stacked LSTM cell case.
+ num_lstm_units: Number of units in the LSTM cell. + units: The units for the last layer. + num_class: Number of classes to predict. + """ + + def __init__(self, time_steps, input_size, num_lstm_layer, num_lstm_units, + units, num_class): + self.time_steps = time_steps + self.input_size = input_size + self.num_lstm_layer = num_lstm_layer + self.num_lstm_units = num_lstm_units + self.units = units + self.num_class = num_class + + def build_model(self): + """Build the model using the given configs. + + Returns: + x: The input placehoder tensor. + logits: The logits of the output. + output_class: The prediction. + """ + x = tf.placeholder( + 'float32', [None, self.time_steps, self.input_size], name='INPUT') + lstm_layers = [] + for _ in range(self.num_lstm_layer): + lstm_layers.append( + # Important: + # + # Note here, we use `tf.lite.experimental.nn.TFLiteLSTMCell` + # (OpHinted LSTMCell). + tf.lite.experimental.nn.TFLiteLSTMCell( + self.num_lstm_units, forget_bias=0)) + # Weights and biases for output softmax layer. + out_weights = tf.Variable(tf.random_normal([self.units, self.num_class])) + out_bias = tf.Variable(tf.zeros([self.num_class])) + + # Transpose input x to make it time major. + lstm_inputs = tf.transpose(x, perm=[1, 0, 2]) + lstm_cells = tf.keras.layers.StackedRNNCells(lstm_layers) + # Important: + # + # Note here, we use `tf.lite.experimental.nn.dynamic_rnn` and `time_major` + # is set to True. + outputs, _ = tf.lite.experimental.nn.dynamic_rnn( + lstm_cells, lstm_inputs, dtype='float32', time_major=True) + + # Transpose the outputs back to [batch, time, output] + outputs = tf.transpose(outputs, perm=[1, 0, 2]) + outputs = tf.unstack(outputs, axis=1) + logits = tf.matmul(outputs[-1], out_weights) + out_bias + output_class = tf.nn.softmax(logits, name='OUTPUT_CLASS') + + return x, logits, output_class +``` + +### 2. Let's define the train & eval function. 
+ +```python +def train(model, + model_dir, + batch_size=20, + learning_rate=0.001, + train_steps=2000, + eval_steps=500, + save_every_n_steps=1000): + """Train & save the MNIST recognition model.""" + # Train & test dataset. + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + train_iterator = train_dataset.shuffle( + buffer_size=1000).batch(batch_size).repeat().make_one_shot_iterator() + x, logits, output_class = model.build_model() + test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + test_iterator = test_dataset.batch( + batch_size).repeat().make_one_shot_iterator() + # input label placeholder + y = tf.placeholder(tf.int32, [ + None, + ]) + one_hot_labels = tf.one_hot(y, depth=model.num_class) + # Loss function + loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits( + logits=logits, labels=one_hot_labels)) + correct = tf.nn.in_top_k(output_class, y, 1) + accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) + # Optimization + opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss) + + # Initialize variables + init = tf.global_variables_initializer() + saver = tf.train.Saver() + batch_x, batch_y = train_iterator.get_next() + batch_test_x, batch_test_y = test_iterator.get_next() + with tf.Session() as sess: + sess.run([init]) + for i in range(train_steps): + batch_x_value, batch_y_value = sess.run([batch_x, batch_y]) + _, loss_value = sess.run([opt, loss], + feed_dict={ + x: batch_x_value, + y: batch_y_value + }) + if i % 100 == 0: + tf.logging.info('Training step %d, loss is %f' % (i, loss_value)) + if i > 0 and i % save_every_n_steps == 0: + accuracy_sum = 0.0 + for _ in range(eval_steps): + test_x_value, test_y_value = sess.run([batch_test_x, batch_test_y]) + accuracy_value = sess.run( + accuracy, feed_dict={ + x: test_x_value, + y: test_y_value + }) + accuracy_sum += accuracy_value + tf.logging.info('Training step 
%d, accuracy is %f' % + (i, accuracy_sum / (eval_steps * 1.0))) + saver.save(sess, model_dir) +``` + +### 3. Let's define the export to TensorFlow Lite model function. + +```python +def export(model, model_dir, tflite_model_file, + use_post_training_quantize=True): + """Export trained model to tflite model.""" + tf.reset_default_graph() + x, _, output_class = model.build_model() + saver = tf.train.Saver() + sess = tf.Session() + saver.restore(sess, model_dir) + # Freeze the graph. + frozen_graph = tf.graph_util.convert_variables_to_constants( + sess, sess.graph_def, [output_class.op.name]) + + # Important: + # + # Convert ophinted lstm ops to tflite UnidirectionalSequenceLstm ops. + converted_graph = tf.lite.experimental.convert_op_hints_to_stubs( + graph_def=frozen_graph) + converted_graph = tf.graph_util.remove_training_nodes(converted_graph) + converter = tf.lite.TFLiteConverter(converted_graph, [x], [output_class]) + converter.post_training_quantize = use_post_training_quantize + tflite = converter.convert() + with open(tflite_model_file, 'wb') as f: + f.write(tflite) +``` + +### 4. Hook everything together. + +```python +def train_and_export(parsed_flags): + """Train the MNIST LSTM model and export to TfLite.""" + model = MnistLstmModel( + time_steps=28, + input_size=28, + num_lstm_layer=2, + num_lstm_units=64, + units=64, + num_class=10) + tf.logging.info('Starts training...') + train(model, parsed_flags.model_dir) + tf.logging.info('Finished training, starts exporting to tflite to %s ...'
% + parsed_flags.tflite_model_file) + export(model, parsed_flags.model_dir, parsed_flags.tflite_model_file, + parsed_flags.use_post_training_quantize) + tf.logging.info( + 'Finished exporting, model is %s' % parsed_flags.tflite_model_file) + + +def run_main(_): + """Main in the TfLite LSTM tutorial.""" + parser = argparse.ArgumentParser( + description=('Train a MNIST recognition model then export to TfLite.')) + parser.add_argument( + '--model_dir', + type=str, + help='Directory where the models will store.', + required=True) + parser.add_argument( + '--tflite_model_file', + type=str, + help='Full filepath to the exported tflite model file.', + required=True) + parser.add_argument( + '--use_post_training_quantize', + action='store_true', + default=True, + help='Whether or not to use post_training_quantize.') + parsed_flags, _ = parser.parse_known_args() + train_and_export(parsed_flags) + + +def main(): + app.run(main=run_main, argv=sys.argv[:1]) + + +if __name__ == '__main__': + main() + +``` + +### 5. Visualize the exported TensorFlow Lite model. + +Let's go to where the TensorFlow Lite model is exported and use +[Netron](https://github.com/lutzroeder/netron) to visualize the graph. + +See below. + +##### Exported TensorFlow Lite Model. + +![Exported TensorFlow Lite Model](./images/exported_tflite_model.png) + +## Caveat + +* Currently, `tf.lite.experimental.nn.dynamic_rnn` & + `tf.lite.experimental.nn.bidirectional_dynamic_rnn` only support + `control_flow_v2`; you can turn this on by setting the environment variable + `TF_ENABLE_CONTROL_FLOW_V2=1`, as shown in the tutorial. +* Currently, `sequence_length` is not supported; prefer setting it to None. +* `num_unit_shards` & `num_proj_shards` in LSTMCell are not supported + either. +* Currently, `tf.lite.experimental.nn.dynamic_rnn` & + `tf.lite.experimental.nn.bidirectional_dynamic_rnn` only take + `time_major=True`.
+* `tf.lite.experimental.nn.bidirectional_dynamic_rnn` is a + wrapper around `tf.nn.bidirectional_dynamic_rnn`, not + `tf.contrib.rnn.stack_bidirectional_dynamic_rnn`. +* For bidirectional_rnn cases, make sure you include all the op_hinted nodes + before freezing the graph. See below: + +```python +all_output_nodes = [OUTPUT_NODES] +with tf.Session() as sess: + all_output_nodes += tf.lite.find_all_hinted_output_nodes(sess) + frozen_graph = tf.graph_util.convert_variables_to_constants( + sess, sess.graph_def, all_output_nodes) +``` diff --git a/tensorflow/lite/experimental/examples/lstm/g3doc/images/exported_tflite_model.png b/tensorflow/lite/experimental/examples/lstm/g3doc/images/exported_tflite_model.png new file mode 100644 index 0000000000000000000000000000000000000000..0d489d7602e6c579555734e252269c16b4a4e4da Binary files /dev/null and b/tensorflow/lite/experimental/examples/lstm/g3doc/images/exported_tflite_model.png differ diff --git a/tensorflow/lite/experimental/examples/lstm/g3doc/images/op_hint.png b/tensorflow/lite/experimental/examples/lstm/g3doc/images/op_hint.png new file mode 100644 index 0000000000000000000000000000000000000000..583d4869b6d63e576085b601d071f7686a868d54 Binary files /dev/null and b/tensorflow/lite/experimental/examples/lstm/g3doc/images/op_hint.png differ diff --git a/tensorflow/lite/experimental/examples/lstm/g3doc/images/tf_lstm.png b/tensorflow/lite/experimental/examples/lstm/g3doc/images/tf_lstm.png new file mode 100644 index 0000000000000000000000000000000000000000..e962a3c720781e37949a0d654e11dffff1b6803a Binary files /dev/null and b/tensorflow/lite/experimental/examples/lstm/g3doc/images/tf_lstm.png differ diff --git a/tensorflow/lite/experimental/examples/lstm/g3doc/images/tflite_lstm.png b/tensorflow/lite/experimental/examples/lstm/g3doc/images/tflite_lstm.png new file mode 100644 index 0000000000000000000000000000000000000000..8f6befbb581cc85f99ca676ab65e3c0294e2e8dd Binary files /dev/null and
b/tensorflow/lite/experimental/examples/lstm/g3doc/images/tflite_lstm.png differ diff --git a/tensorflow/lite/experimental/examples/lstm/rnn.py b/tensorflow/lite/experimental/examples/lstm/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..e005d81e4051537e8bed2589005f1637251b301d --- /dev/null +++ b/tensorflow/lite/experimental/examples/lstm/rnn.py @@ -0,0 +1,429 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TfLite LSTMCell wrapper. + +TODO(renjieliu): Find a better home for this one. 
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.lite.python.op_hint as op_hint +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops.rnn import _best_effort_input_batch_size +from tensorflow.python.ops.rnn import _dynamic_rnn_loop +from tensorflow.python.ops.rnn import _should_cache +from tensorflow.python.ops.rnn import _transpose_batch_time +from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("lite.experimental.nn.dynamic_rnn") +def dynamic_rnn(cell, + inputs, + sequence_length=None, + initial_state=None, + dtype=None, + parallel_iterations=None, + swap_memory=False, + time_major=True, + scope=None): + """Creates a recurrent neural network specified by RNNCell `cell`. + + Performs fully dynamic unrolling of `inputs`. 
+ + Example: + + ```python + # create a BasicRNNCell + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + + # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] + + # defining initial state + initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) + + # 'state' is a tensor of shape [batch_size, cell_state_size] + outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data, + initial_state=initial_state, + dtype=tf.float32) + ``` + + ```python + # create 2 LSTMCells + rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] + + # create a RNN cell composed sequentially of a number of RNNCells + multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers) + + # 'outputs' is a tensor of shape [batch_size, max_time, 256] + # 'state' is a N-tuple where N is the number of LSTMCells containing a + # tf.contrib.rnn.LSTMStateTuple for each cell + outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell, + inputs=data, + dtype=tf.float32) + ``` + + + Args: + cell: An instance of RNNCell. + inputs: The RNN inputs. + If `time_major == False` (default), this must be a `Tensor` of shape: + `[batch_size, max_time, ...]`, or a nested tuple of such elements. + If `time_major == True`, this must be a `Tensor` of shape: `[max_time, + batch_size, ...]`, or a nested tuple of such elements. This may also be + a (possibly nested) tuple of Tensors satisfying this property. The + first two dimensions must match across all the inputs, but otherwise the + ranks and other shape components may differ. In this case, input to + `cell` at each time-step will replicate the structure of these tuples, + except for the time dimension (from which the time is taken). The input + to `cell` at each time step will be a `Tensor` or (possibly nested) + tuple of Tensors each with dimensions `[batch_size, ...]`. + sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. 
Used + to copy-through state and zero-out outputs when past a batch element's + sequence length. So it's more for performance than correctness. + initial_state: (optional) An initial state for the RNN. If `cell.state_size` + is an integer, this must be a `Tensor` of appropriate type and shape + `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this + should be a tuple of tensors having shapes `[batch_size, s] for s in + cell.state_size`. + dtype: (optional) The data type for the initial state and expected output. + Required if initial_state is not provided or RNN state has a heterogeneous + dtype. + parallel_iterations: (Default: 32). The number of iterations to run in + parallel. Those operations which do not have any temporal dependency and + can be run in parallel, will be. This parameter trades off time for + space. Values >> 1 use more memory but take less time, while smaller + values use less memory but computations take longer. + swap_memory: Transparently swap the tensors produced in forward inference + but needed for back prop from GPU to CPU. This allows training RNNs which + would typically not fit on a single GPU, with very minimal (or no) + performance penalty. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, + these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using + `time_major = True` is a bit more efficient because it avoids transposes + at the beginning and end of the RNN calculation. However, most TensorFlow + data is batch-major, so by default this function accepts input and emits + output in batch-major form. + scope: VariableScope for the created subgraph; defaults to "rnn". + + Returns: + A pair (outputs, state) where: + + outputs: The RNN output `Tensor`. + + If time_major == False (default), this will be a `Tensor` shaped: + `[batch_size, max_time, cell.output_size]`. 
+ + If time_major == True, this will be a `Tensor` shaped: + `[max_time, batch_size, cell.output_size]`. + + Note, if `cell.output_size` is a (possibly nested) tuple of integers + or `TensorShape` objects, then `outputs` will be a tuple having the + same structure as `cell.output_size`, containing Tensors having shapes + corresponding to the shape data in `cell.output_size`. + + state: The final state. If `cell.state_size` is an int, this + will be shaped `[batch_size, cell.state_size]`. If it is a + `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. + If it is a (possibly nested) tuple of ints or `TensorShape`, this will + be a tuple having the corresponding shapes. If cells are `LSTMCells` + `state` will be a tuple containing a `LSTMStateTuple` for each cell. + + Raises: + TypeError: If `cell` is not an instance of RNNCell. + ValueError: If inputs is None or an empty list. + RuntimeError: If not using control flow v2. + """ + + # Currently only support time_major == True case. + assert time_major + + # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or + # TfLiteRNNCells. + rnn_cell_impl.assert_like_rnncell("cell", cell) + + if not control_flow_util.ENABLE_CONTROL_FLOW_V2: + raise RuntimeError("OpHint dynamic rnn only supports control flow v2.") + + parent_first_child_input = [{ + "parent_ophint_input_index": 0, + "first_child_ophint_input_index": 0 + }] + parent_last_child_output = [{ + "parent_output_index": 0, + # For LstmCell, the index is 2. + # For RnnCell, the index is 1. + # So we use -1 meaning it's the last one. + "child_output_index": -1 + }] + internal_children_input_output = [{ + "child_input_index": 0, + # For LstmCell, the index is 2. + # For RnnCell, the index is 1. + # So we use -1 meaning it's the last one. 
+ "child_output_index": -1 + }] + inputs_outputs_mappings = { + "parent_first_child_input": parent_first_child_input, + "parent_last_child_output": parent_last_child_output, + "internal_children_input_output": internal_children_input_output + } + tflite_wrapper = op_hint.OpHint( + "TfLiteDynamicRnn", + level=2, + children_inputs_mappings=inputs_outputs_mappings) + with vs.variable_scope(scope or "rnn") as varscope: + # Create a new scope in which the caching device is either + # determined by the parent scope, or is set to place the cached + # Variable using the same placement as for the rest of the RNN. + if _should_cache(): + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) + + inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0) + + # By default, time_major==False and inputs are batch-major: shaped + # [batch, time, depth] + # For internal calculations, we transpose to [time, batch, depth] + flat_input = nest.flatten(inputs) + + if not time_major: + # (batch, time, depth) => (time, batch, depth) + flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] + flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) + + parallel_iterations = parallel_iterations or 32 + if sequence_length is not None: + sequence_length = math_ops.cast(sequence_length, dtypes.int32) + if sequence_length.shape.rank not in (None, 1): + raise ValueError( + "sequence_length must be a vector of length batch_size, " + "but saw shape: %s" % sequence_length.shape) + sequence_length = array_ops.identity( # Just to find it in the graph. 
+ sequence_length, + name="sequence_length") + + batch_size = _best_effort_input_batch_size(flat_input) + + if initial_state is not None: + state = initial_state + else: + if not dtype: + raise ValueError("If there is no initial_state, you must give a dtype.") + if getattr(cell, "get_initial_state", None) is not None: + state = cell.get_initial_state( + inputs=None, batch_size=batch_size, dtype=dtype) + else: + state = cell.zero_state(batch_size, dtype) + + def _assert_has_shape(x, shape): + x_shape = array_ops.shape(x) + packed_shape = array_ops.stack(shape) + return control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [ + "Expected shape for Tensor %s is " % x.name, packed_shape, + " but saw shape: ", x_shape + ]) + + if not context.executing_eagerly() and sequence_length is not None: + # Perform some shape validation + with ops.control_dependencies( + [_assert_has_shape(sequence_length, [batch_size])]): + sequence_length = array_ops.identity( + sequence_length, name="CheckSeqLen") + + inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) + + outputs, final_state = _dynamic_rnn_loop( + cell, + inputs, + state, + parallel_iterations=parallel_iterations, + swap_memory=swap_memory, + sequence_length=sequence_length, + dtype=dtype) + + # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 
+ # If we are performing batch-major calculations, transpose output back + # to shape [batch, time, depth] + if not time_major: + # (time, batch, depth) => (batch, time, depth) + outputs = nest.map_structure(_transpose_batch_time, outputs) + outputs = tflite_wrapper.add_output(outputs, name="outputs") + + return outputs, final_state + + +def bidirectional_dynamic_rnn(cell_fw, + cell_bw, + inputs, + sequence_length=None, + initial_state_fw=None, + initial_state_bw=None, + dtype=None, + parallel_iterations=None, + swap_memory=False, + time_major=False, + scope=None): + """Creates a dynamic version of bidirectional recurrent neural network. + + Takes input and builds independent forward and backward RNNs. The input_size + of forward and backward cell must match. The initial state for both directions + is zero by default (but can be set optionally) and no intermediate states are + ever returned -- the network is fully unrolled for the given (passed in) + length(s) of the sequence(s) or completely unrolled if length(s) is not + given. + + Args: + cell_fw: An instance of RNNCell, to be used for forward direction. + cell_bw: An instance of RNNCell, to be used for backward direction. + inputs: The RNN inputs. + If time_major == False (default), this must be a tensor of shape: + `[batch_size, max_time, ...]`, or a nested tuple of such elements. + If time_major == True, this must be a tensor of shape: `[max_time, + batch_size, ...]`, or a nested tuple of such elements. + sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, + containing the actual lengths for each of the sequences in the batch. If + not provided, all batch entries are assumed to be full sequences; and time + reversal is applied from time `0` to `max_time` for each sequence. + initial_state_fw: (optional) An initial state for the forward RNN. This must + be a tensor of appropriate type and shape `[batch_size, + cell_fw.state_size]`. 
If `cell_fw.state_size` is a tuple, this should be a + tuple of tensors having shapes `[batch_size, s] for s in + cell_fw.state_size`. + initial_state_bw: (optional) Same as for `initial_state_fw`, but using the + corresponding properties of `cell_bw`. + dtype: (optional) The data type for the initial states and expected output. + Required if initial_states are not provided or RNN states have a + heterogeneous dtype. + parallel_iterations: (Default: 32). The number of iterations to run in + parallel. Those operations which do not have any temporal dependency and + can be run in parallel, will be. This parameter trades off time for + space. Values >> 1 use more memory but take less time, while smaller + values use less memory but computations take longer. + swap_memory: Transparently swap the tensors produced in forward inference + but needed for back prop from GPU to CPU. This allows training RNNs which + would typically not fit on a single GPU, with very minimal (or no) + performance penalty. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, + these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using + `time_major = True` is a bit more efficient because it avoids transposes + at the beginning and end of the RNN calculation. However, most TensorFlow + data is batch-major, so by default this function accepts input and emits + output in batch-major form. + scope: VariableScope for the created subgraph; defaults to + "bidirectional_rnn" + + Returns: + A tuple (outputs, output_states) where: + outputs: A tuple (output_fw, output_bw) containing the forward and + the backward rnn output `Tensor`. + If time_major == False (default), + output_fw will be a `Tensor` shaped: + `[batch_size, max_time, cell_fw.output_size]` + and output_bw will be a `Tensor` shaped: + `[batch_size, max_time, cell_bw.output_size]`. 
+ If time_major == True, + output_fw will be a `Tensor` shaped: + `[max_time, batch_size, cell_fw.output_size]` + and output_bw will be a `Tensor` shaped: + `[max_time, batch_size, cell_bw.output_size]`. + It returns a tuple instead of a single concatenated `Tensor`, unlike + in the `bidirectional_rnn`. If the concatenated one is preferred, + the forward and backward outputs can be concatenated as + `tf.concat(outputs, 2)`. + output_states: A tuple (output_state_fw, output_state_bw) containing + the forward and the backward final states of bidirectional rnn. + + Raises: + TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. + """ + rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) + rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) + + with vs.variable_scope(scope or "bidirectional_rnn"): + # Forward direction + with vs.variable_scope("fw") as fw_scope: + output_fw, output_state_fw = dynamic_rnn( + cell=cell_fw, + inputs=inputs, + sequence_length=sequence_length, + initial_state=initial_state_fw, + dtype=dtype, + parallel_iterations=parallel_iterations, + swap_memory=swap_memory, + time_major=time_major, + scope=fw_scope) + + # Backward direction + if not time_major: + time_axis = 1 + batch_axis = 0 + else: + time_axis = 0 + batch_axis = 1 + + def _reverse(input_, seq_lengths, seq_axis, batch_axis): + if seq_lengths is not None: + return array_ops.reverse_sequence( + input=input_, + seq_lengths=seq_lengths, + seq_axis=seq_axis, + batch_axis=batch_axis) + else: + return array_ops.reverse(input_, axis=[seq_axis]) + + with vs.variable_scope("bw") as bw_scope: + + def _map_reverse(inp): + return _reverse( + inp, + seq_lengths=sequence_length, + seq_axis=time_axis, + batch_axis=batch_axis) + + inputs_reverse = nest.map_structure(_map_reverse, inputs) + tmp, output_state_bw = dynamic_rnn( + cell=cell_bw, + inputs=inputs_reverse, + sequence_length=sequence_length, + initial_state=initial_state_bw, + dtype=dtype, + 
parallel_iterations=parallel_iterations, + swap_memory=swap_memory, + time_major=time_major, + scope=bw_scope) + + output_bw = _reverse( + tmp, + seq_lengths=sequence_length, + seq_axis=time_axis, + batch_axis=batch_axis) + + outputs = (output_fw, output_bw) + output_states = (output_state_fw, output_state_bw) + + return (outputs, output_states) diff --git a/tensorflow/lite/experimental/examples/lstm/tflite_lstm.py b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py similarity index 71% rename from tensorflow/lite/experimental/examples/lstm/tflite_lstm.py rename to tensorflow/lite/experimental/examples/lstm/rnn_cell.py index 461345060badbad0fc65b37466436b1a1eb424a4..ba2f7875563663bbc78bcf1716280e9cf05c7277 100644 --- a/tensorflow/lite/experimental/examples/lstm/tflite_lstm.py +++ b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TfLite LSTMCell wrapper. +"""TfLite BasicRnnCell wrapper. TODO(renjieliu): Find a better home for this one. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import tensorflow as tf +import itertools -from tensorflow.lite.python import lite +import tensorflow.lite.python.op_hint as op_hint from tensorflow.python.keras import activations from tensorflow.python.keras import initializers from tensorflow.python.layers import base as base_layer @@ -33,8 +33,130 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export +@tf_export("lite.experimental.nn.TfLiteRNNCell") +class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell): + """The most basic RNN cell. 
+ + This is used only for TfLite, it provides hints and it also makes the + variables in the desired for the tflite ops. + """ + + def __init__(self, + num_units, + activation=None, + reuse=None, + name=None, + dtype=None, + **kwargs): + """Initializes the parameters for an RNN cell. + + Args: + num_units: int, The number of units in the RNN cell. + activation: Nonlinearity to use. Default: `tanh`. It could also be string + that is within Keras activation function names. + reuse: (optional) Python boolean describing whether to reuse variables in + an existing scope. Raises an error if not `True` and the existing scope + already has the given variables. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). + + Raises: + ValueError: If the existing scope already has the given variables. + """ + super(TfLiteRNNCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) + + # Inputs must be Rank-2. + self.input_spec = base_layer.InputSpec(ndim=2) + + self._tflite_wrapper = op_hint.OpHint("UnidirectionalSequenceRnn") + self._num_units = num_units + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + def build(self, inputs_shape): + """Builds the RNN cell. + + Args: + inputs_shape: Rnn input tensor shape. + + Raises: + ValueError: If last dimension of the input shape is not known. 
+ """ + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + (inputs_shape,)) + + input_depth = inputs_shape[-1] + + def add_variable_wrapped(name, shape, initializer, index): + var = self.add_weight(name, shape=shape, initializer=initializer) + return self._tflite_wrapper.add_input( + var, name=name, index_override=index) + + self._input_weights = add_variable_wrapped( + "input_weights", [self._num_units, input_depth], None, 1) + self._recurrent_weights = add_variable_wrapped( + "recurrent_weights", [self._num_units, self._num_units], None, 2) + self._bias = add_variable_wrapped( + "bias", + shape=[self._num_units], + initializer=init_ops.zeros_initializer(dtype=self.dtype), + index=3) + + self.built = True + + def call(self, inputs, state): + """Most basic RNN: output = new_state = act(W * input + U * state + B).""" + inputs = self._tflite_wrapper.add_input( + inputs, tag="input", name="input", aggregate="stack", index_override=0) + state = self._tflite_wrapper.add_input( + state, + tag="hidden_state", + name="hidden_state", + aggregate="first", + index_override=4) + weights = array_ops.transpose( + array_ops.concat([self._input_weights, self._recurrent_weights], 1)) + gate_inputs = math_ops.matmul(array_ops.concat([inputs, state], 1), weights) + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + output = self._activation(gate_inputs) + output = self._tflite_wrapper.add_output( + output, + tag="output", + name="output", + index_override=1, + aggregate="stack") + return output, output + + def get_config(self): + config = { + "num_units": self._num_units, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(TfLiteRNNCell, self).get_config() + return dict(itertools.chain(base_config.items(), config.items())) + + +@tf_export("lite.experimental.nn.TFLiteLSTMCell") class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): """Long short-term memory unit (LSTM) 
recurrent network cell. @@ -132,7 +254,7 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): # TODO(raziel): layers stuff -- chop if un-layerizing Op. self.input_spec = base_layer.InputSpec(ndim=2) - self._tflite_wrapper = lite.OpHint("UnidirectionalSequenceLstm") + self._tflite_wrapper = op_hint.OpHint("UnidirectionalSequenceLstm") self._num_units = num_units self._use_peepholes = use_peepholes @@ -148,7 +270,7 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): self._output_size = num_proj if num_proj else num_units self._state_size = ( - tf.nn.rnn_cell.LSTMStateTuple(num_units, self._output_size) + rnn_cell_impl.LSTMStateTuple(num_units, self._output_size) if state_is_tuple else num_units + self._output_size) @property @@ -169,10 +291,15 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): Raises: ValueError: if the inputs_shape is invalid. """ - if len(inputs_shape) != 2 or inputs_shape[1].value is None: + if len(inputs_shape) != 2: + raise ValueError( + "inputs_shape must be 2-dimensional, saw shape: %s" % inputs_shape) + input_depth = ( + inputs_shape[1] + if isinstance(inputs_shape[1], int) else inputs_shape[1].value) + if input_depth is None: raise ValueError("Invalid inputs_shape, saw shape: %s" % inputs_shape) - input_depth = inputs_shape[1].value maybe_partitioner = ( partitioned_variables.fixed_size_partitioner(self._num_unit_shards) if self._num_unit_shards is not None else None) @@ -181,7 +308,7 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): bias_shape = [self._num_units] def add_variable_wrapped(name, shape, initializer, index, partitioner): - var = self.add_variable( + var = self.add_weight( name, shape=shape, initializer=initializer, partitioner=partitioner) return self._tflite_wrapper.add_input( var, name=name, index_override=index) @@ -192,6 +319,8 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): else: bias_initializer = init_ops.zeros_initializer(dtype=self.dtype) + forget_bias_initializer = 
init_ops.constant_initializer(self._forget_bias) + self.input_to_input_w = add_variable_wrapped( "input_to_input_w", input_weight_shape, weight_initializer, 1, maybe_partitioner) @@ -219,8 +348,9 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): self.input_bias = add_variable_wrapped( "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner) - self.forget_bias = add_variable_wrapped( - "forget_bias", bias_shape, bias_initializer, 13, maybe_partitioner) + self.forget_bias = add_variable_wrapped("forget_bias", bias_shape, + forget_bias_initializer, 13, + maybe_partitioner) self.cell_bias = add_variable_wrapped( "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner) self.output_bias = add_variable_wrapped( @@ -230,10 +360,10 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): # f stands for forget, i stands for input and o stands for output. if self._use_peepholes: self._w_f_diag = add_variable_wrapped("w_f_diag", [self._num_units], - self._initializer, 9, + self._initializer, 10, maybe_partitioner) self._w_i_diag = add_variable_wrapped("w_i_diag", [self._num_units], - self._initializer, 10, + self._initializer, 9, maybe_partitioner) self._w_o_diag = add_variable_wrapped("w_o_diag", [self._num_units], self._initializer, 11, @@ -306,9 +436,9 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): aggregate="first", index_override=18) - input_size = inputs.get_shape().with_rank(2)[1] + input_size = inputs.shape.with_rank(2)[1] if input_size.value is None: - raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + raise ValueError("Could not infer input size from inputs.shape[-1]") inputs_and_m_prev = array_ops.concat([inputs, m_prev], axis=1) @@ -319,35 +449,37 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): # c is the final state. # m is the output. 
i = nn_ops.bias_add( - tf.matmul( + math_ops.matmul( inputs_and_m_prev, - tf.concat([self.input_to_input_w, self.cell_to_input_w], axis=1), + array_ops.concat([self.input_to_input_w, self.cell_to_input_w], + axis=1), transpose_b=True), self.input_bias) f = nn_ops.bias_add( - tf.matmul( + math_ops.matmul( inputs_and_m_prev, - tf.concat([self.input_to_forget_w, self.cell_to_forget_w], axis=1), + array_ops.concat([self.input_to_forget_w, self.cell_to_forget_w], + axis=1), transpose_b=True), self.forget_bias) o = nn_ops.bias_add( - tf.matmul( + math_ops.matmul( inputs_and_m_prev, - tf.concat([self.input_to_output_w, self.cell_to_output_w], axis=1), + array_ops.concat([self.input_to_output_w, self.cell_to_output_w], + axis=1), transpose_b=True), self.output_bias) j = nn_ops.bias_add( - tf.matmul( + math_ops.matmul( inputs_and_m_prev, - tf.concat([self.input_to_cell_w, self.cell_to_cell_w], axis=1), + array_ops.concat([self.input_to_cell_w, self.cell_to_cell_w], + axis=1), transpose_b=True), self.cell_bias) # Diagonal connections if self._use_peepholes: c = ( - sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + + sigmoid(f + self._w_f_diag * c_prev) * c_prev + sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) else: - c = ( - sigmoid(f + self._forget_bias) * c_prev + - sigmoid(i) * self._activation(j)) + c = (sigmoid(f) * c_prev + sigmoid(i) * self._activation(j)) if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type @@ -359,7 +491,7 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): m = sigmoid(o) * self._activation(c) if self._num_proj is not None: - transposed_proj_kernel = tf.transpose(self._proj_kernel) + transposed_proj_kernel = array_ops.transpose(self._proj_kernel) m = math_ops.matmul(m, transposed_proj_kernel) if self._proj_clip is not None: @@ -373,7 +505,7 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): m, tag="m", name="m", index_override=2, aggregate="stack") new_state = ( - 
tf.nn.rnn_cell.LSTMStateTuple(c, m) + rnn_cell_impl.LSTMStateTuple(c, m) if self._state_is_tuple else array_ops.concat([c, m], 1)) return m, new_state diff --git a/tensorflow/lite/experimental/examples/lstm/tflite_rnn.py b/tensorflow/lite/experimental/examples/lstm/tflite_rnn.py deleted file mode 100644 index e4aad18367e6c8bf9669e928dff8d7c9376043b7..0000000000000000000000000000000000000000 --- a/tensorflow/lite/experimental/examples/lstm/tflite_rnn.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TfLite BasicRnnCell wrapper. - -TODO(renjieliu): Find a better home for this one. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import itertools - -from tensorflow.lite.python import lite -from tensorflow.python.keras import activations -from tensorflow.python.layers import base as base_layer -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import rnn_cell_impl - - -class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell): - """The most basic RNN cell. - - This is used only for TfLite, it provides hints and it also makes the - variables in the desired for the tflite ops. 
- """ - - def __init__(self, - num_units, - activation=None, - reuse=None, - name=None, - dtype=None, - **kwargs): - """Initializes the parameters for an RNN cell. - - Args: - num_units: int, The number of units in the RNN cell. - activation: Nonlinearity to use. Default: `tanh`. It could also be string - that is within Keras activation function names. - reuse: (optional) Python boolean describing whether to reuse variables in - an existing scope. Raises an error if not `True` and the existing scope - already has the given variables. - name: String, the name of the layer. Layers with the same name will share - weights, but to avoid mistakes we require reuse=True in such cases. - dtype: Default dtype of the layer (default of `None` means use the type of - the first input). Required when `build` is called before `call`. - **kwargs: Dict, keyword named properties for common layer attributes, like - `trainable` etc when constructing the cell from configs of get_config(). - - Raises: - ValueError: If the existing scope already has the given variables. - """ - super(TfLiteRNNCell, self).__init__( - _reuse=reuse, name=name, dtype=dtype, **kwargs) - - # Inputs must be Rank-2. - self.input_spec = base_layer.InputSpec(ndim=2) - - self._tflite_wrapper = lite.OpHint("UnidirectionalSequenceRnn") - self._num_units = num_units - if activation: - self._activation = activations.get(activation) - else: - self._activation = math_ops.tanh - - @property - def state_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - def build(self, inputs_shape): - """Builds the RNN cell. - - Args: - inputs_shape: Rnn input tensor shape. - - Raises: - ValueError: If last dimension of the input shape is not known. 
- """ - if inputs_shape[-1] is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % - (inputs_shape,)) - - input_depth = inputs_shape[-1] - - def add_variable_wrapped(name, shape, initializer, index): - var = self.add_variable(name, shape=shape, initializer=initializer) - return self._tflite_wrapper.add_input( - var, name=name, index_override=index) - - self._input_weights = add_variable_wrapped( - "input_weights", [self._num_units, input_depth], None, 1) - self._recurrent_weights = add_variable_wrapped( - "recurrent_weights", [self._num_units, self._num_units], None, 2) - self._bias = add_variable_wrapped( - "bias", - shape=[self._num_units], - initializer=init_ops.zeros_initializer(dtype=self.dtype), - index=3) - - self.built = True - - def call(self, inputs, state): - """Most basic RNN: output = new_state = act(W * input + U * state + B).""" - inputs = self._tflite_wrapper.add_input( - inputs, tag="input", name="input", aggregate="stack", index_override=0) - state = self._tflite_wrapper.add_input( - state, - tag="hidden_state", - name="hidden_state", - aggregate="first", - index_override=4) - weights = array_ops.transpose( - array_ops.concat([self._input_weights, self._recurrent_weights], 1)) - gate_inputs = math_ops.matmul(array_ops.concat([inputs, state], 1), weights) - gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) - output = self._activation(gate_inputs) - output = self._tflite_wrapper.add_output( - output, - tag="output", - name="output", - index_override=1, - aggregate="stack") - return output, output - - def get_config(self): - config = { - "num_units": self._num_units, - "activation": activations.serialize(self._activation), - "reuse": self._reuse, - } - base_config = super(TfLiteRNNCell, self).get_config() - return dict(itertools.chain(base_config.items(), config.items())) diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py 
b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py index eeb48d123113c5924a74286ad1e0851eb484cdb8..e29c75100344c95d9e76e1d6e27a8b70fe2add4d 100644 --- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py +++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py @@ -20,12 +20,12 @@ import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data -from tensorflow.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.python.framework import test_util from tensorflow.python.platform import test from tensorflow.python.tools import optimize_for_inference_lib + # Number of steps to train model. TRAIN_STEPS = 1 @@ -54,20 +54,22 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): self.num_units = 16 def buildLstmLayer(self): - return tf.nn.rnn_cell.MultiRNNCell([ - TFLiteLSTMCell( - self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"), - TFLiteLSTMCell(self.num_units, num_proj=8, forget_bias=0, name="rnn2"), - TFLiteLSTMCell( + return tf.keras.layers.StackedRNNCells([ + tf.lite.experimental.nn.TFLiteLSTMCell( + self.num_units, use_peepholes=True, forget_bias=1.0, name="rnn1"), + tf.lite.experimental.nn.TFLiteLSTMCell( + self.num_units, num_proj=8, forget_bias=1.0, name="rnn2"), + tf.lite.experimental.nn.TFLiteLSTMCell( self.num_units // 2, use_peepholes=True, num_proj=8, forget_bias=0, name="rnn3"), - TFLiteLSTMCell(self.num_units, forget_bias=0, name="rnn4") + tf.lite.experimental.nn.TFLiteLSTMCell( + self.num_units, forget_bias=1.0, name="rnn4") ]) - def buildModel(self, lstm_layer, is_dynamic_rnn, is_train): + def buildModel(self, lstm_layer, is_dynamic_rnn): # Weights and biases for output softmax layer. 
out_weights = tf.Variable( tf.random_normal([self.num_units, self.n_classes])) @@ -77,16 +79,12 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): x = tf.placeholder( "float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE") - # For dynamic_rnn, train with dynamic_rnn and inference with static_rnn. # x is shaped [batch_size,time_steps,num_inputs] if is_dynamic_rnn: - if is_train: - lstm_input = x - outputs, _ = tf.nn.dynamic_rnn(lstm_layer, lstm_input, dtype="float32") - outputs = tf.unstack(outputs, axis=1) - else: - lstm_input = tf.unstack(x, self.time_steps, 1) - outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32") + lstm_input = tf.transpose(x, perm=[1, 0, 2]) + outputs, _ = tf.lite.experimental.nn.dynamic_rnn( + lstm_layer, lstm_input, dtype="float32") + outputs = tf.unstack(outputs, axis=0) else: lstm_input = tf.unstack(x, self.time_steps, 1) outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32") @@ -126,8 +124,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): # Reset the graph. 
tf.reset_default_graph() - x, prediction, output_class = self.buildModel( - lstm_layer, is_dynamic_rnn, is_train=False) + x, prediction, output_class = self.buildModel(lstm_layer, is_dynamic_rnn) new_sess = tf.Session(config=CONFIG) saver = tf.train.Saver() @@ -157,8 +154,8 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"], [tf.float32.as_datatype_enum]) - tflite = tf.lite.toco_convert( - curr, [tflite_input], [outputs], allow_custom_ops=False) + converter = tf.lite.TFLiteConverter(curr, [tflite_input], [outputs]) + tflite = converter.convert() interpreter = tf.lite.Interpreter(model_content=tflite) try: @@ -179,7 +176,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): sess = tf.Session(config=CONFIG) x, prediction, output_class = self.buildModel( - self.buildLstmLayer(), is_dynamic_rnn=False, is_train=True) + self.buildLstmLayer(), is_dynamic_rnn=False) self.trainModel(x, prediction, output_class, sess) saver = tf.train.Saver() @@ -192,26 +189,15 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) + @test_util.enable_control_flow_v2 def testDynamicRnnMultiRnnCell(self): sess = tf.Session(config=CONFIG) x, prediction, output_class = self.buildModel( - self.buildLstmLayer(), is_dynamic_rnn=True, is_train=True) + self.buildLstmLayer(), is_dynamic_rnn=True) self.trainModel(x, prediction, output_class, sess) - # Since we don't yet support OpHints for dynamic, we will load the model - # back in as a static model. This requires the variables to have the same - # names as if they were trained as a static. Thus, we get rid of while/rnn - # names. 
- variables_to_save = {} - for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): - op_name = i.name - if op_name.startswith("while/rnn/"): - op_name = op_name.split("while/rnn/")[1] - if op_name.endswith(":0"): - op_name = op_name.split(":0")[0] - variables_to_save[op_name] = i - saver = tf.train.Saver(variables_to_save) + saver = tf.train.Saver() x, prediction, output_class, new_sess = self.saveAndRestoreModel( self.buildLstmLayer(), sess, saver, is_dynamic_rnn=True) diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py index 6f9e2dd9498f03665b52e423db43ce38d5401eb1..aa7c16536bbd1faf1122e50d67ef3df7216150ca 100644 --- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py +++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py @@ -22,7 +22,6 @@ import tensorflow as tf from tensorflow import flags from tensorflow.examples.tutorials.mnist import input_data -from tensorflow.lite.experimental.examples.lstm.tflite_rnn import TfLiteRNNCell from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.python.framework import test_util from tensorflow.python.platform import test @@ -61,12 +60,12 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase): self.mnist = input_data.read_data_sets(data_dir, one_hot=True) def buildRnnLayer(self): - return tf.nn.rnn_cell.MultiRNNCell([ - TfLiteRNNCell(self.num_units, name="rnn1"), - TfLiteRNNCell(self.num_units, name="rnn2") + return tf.keras.layers.StackedRNNCells([ + tf.lite.experimental.nn.TfLiteRNNCell(self.num_units, name="rnn1"), + tf.lite.experimental.nn.TfLiteRNNCell(self.num_units, name="rnn2") ]) - def buildModel(self, rnn_layer): + def buildModel(self, rnn_layer, is_dynamic_rnn): # Weights and biases for output softmax layer. 
out_weights = tf.Variable( tf.random_normal([self.num_units, self.n_classes])) @@ -77,8 +76,14 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase): "float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE") # x is shaped [batch_size,time_steps,num_inputs] - rnn_input = tf.unstack(x, self.time_steps, 1) - outputs, _ = tf.nn.static_rnn(rnn_layer, rnn_input, dtype="float32") + if is_dynamic_rnn: + rnn_input = tf.transpose(x, perm=[1, 0, 2]) + outputs, _ = tf.lite.experimental.nn.dynamic_rnn( + rnn_layer, rnn_input, dtype="float32") + outputs = tf.unstack(outputs, axis=0) + else: + rnn_input = tf.unstack(x, self.time_steps, 1) + outputs, _ = tf.nn.static_rnn(rnn_layer, rnn_input, dtype="float32") # Compute logits by multiplying outputs[-1] of shape [batch_size,num_units] # by the softmax layer's out_weight of shape [num_units,n_classes] @@ -108,13 +113,14 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase): self.n_input)) sess.run(opt, feed_dict={x: batch_x, y: batch_y}) - def saveAndRestoreModel(self, rnn_layer, sess, saver): + def saveAndRestoreModel(self, rnn_layer, sess, saver, is_dynamic_rnn): """Saves and restores the model to mimic the most common use case. Args: rnn_layer: The rnn layer either a single rnn cell or a multi rnn cell. sess: Old session. saver: saver created by tf.train.Saver() + is_dynamic_rnn: use dynamic_rnn or not. Returns: A tuple containing: @@ -130,7 +136,7 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase): # Reset the graph. 
tf.reset_default_graph() - x, prediction, output_class = self.buildModel(rnn_layer) + x, prediction, output_class = self.buildModel(rnn_layer, is_dynamic_rnn) new_sess = tf.Session(config=CONFIG) saver = tf.train.Saver() @@ -160,8 +166,8 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase): curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"], [tf.float32.as_datatype_enum]) - tflite = tf.lite.toco_convert( - curr, [tflite_input], [outputs], allow_custom_ops=False) + converter = tf.lite.TFLiteConverter(curr, [tflite_input], [outputs]) + tflite = converter.convert() interpreter = tf.lite.Interpreter(model_content=tflite) interpreter.allocate_tensors() @@ -177,12 +183,32 @@ class UnidirectionalSequenceRnnTest(test_util.TensorFlowTestCase): def testStaticRnnMultiRnnCell(self): sess = tf.Session(config=CONFIG) - x, prediction, output_class = self.buildModel(self.buildRnnLayer()) + x, prediction, output_class = self.buildModel( + self.buildRnnLayer(), is_dynamic_rnn=False) self.trainModel(x, prediction, output_class, sess) saver = tf.train.Saver() x, prediction, output_class, new_sess = self.saveAndRestoreModel( - self.buildRnnLayer(), sess, saver) + self.buildRnnLayer(), sess, saver, is_dynamic_rnn=False) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) + + @test_util.enable_control_flow_v2 + def testDynamicRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildRnnLayer(), is_dynamic_rnn=True) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildRnnLayer(), sess, saver, is_dynamic_rnn=True) test_inputs, expected_output, frozen_graph = self.getInferenceResult( x, output_class, new_sess) 
diff --git a/tensorflow/lite/experimental/kernels/BUILD b/tensorflow/lite/experimental/kernels/BUILD index dd314545cb6488ea2a76494df39b4b69e92eca33..78af889cf1ef4c90e3e096e3cc447ecc557f8b51 100644 --- a/tensorflow/lite/experimental/kernels/BUILD +++ b/tensorflow/lite/experimental/kernels/BUILD @@ -5,7 +5,6 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow/lite:build_def.bzl", "tflite_copts") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") # ctc support classes imported directly from TensorFlow. cc_library( @@ -50,26 +49,18 @@ cc_library( }), deps = [ ":ctc_utils", - "//tensorflow/lite:builtin_op_data", "//tensorflow/lite:framework", - "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", - "//tensorflow/lite/kernels:builtin_ops", - "//tensorflow/lite/kernels:gemm_support", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:op_macros", - "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:optimized", "//tensorflow/lite/kernels/internal:optimized_base", - "//tensorflow/lite/kernels/internal:quantization_util", - "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", - "//tensorflow/lite/kernels/internal:tensor_utils", "@flatbuffers", ], ) -tf_cc_test( +cc_test( name = "ctc_beam_search_decoder_test", size = "small", srcs = ["ctc_beam_search_decoder_test.cc"], diff --git a/tensorflow/lite/experimental/micro/BUILD b/tensorflow/lite/experimental/micro/BUILD index 2d00ef76f4a3e5360e45f31ee486e0b8a7c74cc3..b16b8b49f8a46a43475cc08807570e84a160aed4 100644 --- a/tensorflow/lite/experimental/micro/BUILD +++ b/tensorflow/lite/experimental/micro/BUILD @@ -29,7 +29,6 @@ cc_library( "simple_tensor_allocator.h", ], deps = [ - "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/core/api", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/experimental/micro/README.md 
b/tensorflow/lite/experimental/micro/README.md index b47e0feb79de2129a375476565f9b61128bbf974..9eeae442ee9f5db3e1e608f57a73e2faab4d2807 100644 --- a/tensorflow/lite/experimental/micro/README.md +++ b/tensorflow/lite/experimental/micro/README.md @@ -7,118 +7,73 @@ so it's designed to be portable even to 'bare metal' systems. The core runtime fits in 16KB on a Cortex M3, and with enough operators to run a speech keyword detection model, takes up a total of 22KB. -The design goals are for the framework to be: - -- **Readable**: We want embedded software engineers to be able to understand - what's required to run ML inference without having to study research papers. - We've tried to keep the code base small, modular, and have reference - implementations of all operations to help with this. - -- **Easy to modify**: We know that there are a lot of different platforms and - requirements in the embedded world, and we don't expect to cover all of them - in one framework. Instead, we're hoping that it can be a good starting point - for developers to build on top of to meet their own needs. For example, we - tried to make it easy to replace the implementations of key computational - operators that are often crucial for performance, without having to touch - the data flow and other runtime code. We want it to make more sense to use - our workflow to handle things like model import and less-important - operations, and customize the parts that matter, rather than having to - reimplement everything in your own engine. - -- **Well-tested**: If you're modifying code, you need to know if your changes - are correct. Having an easy way to test lets you develop much faster. To - help there, we've written tests for all the components, and we've made sure - that the tests can be run on almost any platform, with no dependencies apart - from the ability to log text to a debug console somewhere. 
We also provide - an easy way to run all the tests on-device as part of an automated test - framework, and we use qemu/Renode emulation so that tests can be run even - without physical devices present. - -- **Easy to integrate**: We want to be as open a system as possible, and use - the best code available for each platform. To do that, we're going to rely - on projects like - [CMSIS-NN](https://www.keil.com/pack/doc/CMSIS/NN/html/index.html), - [uTensor](https://github.com/uTensor/uTensor), and other vendor libraries to - handle as much performance-critical code as possible. We know that there are - an increasing number of options to accelerate neural networks on - microcontrollers, so we're aiming to be a good host for deploying those - hardware technologies too. - -- **Compatible**: We're using the same file schema, interpreter API, and - kernel interface as regular TensorFlow Lite, so we leverage the large - existing set of tools, documentation, and examples for the project. The - biggest barrier to deploying ML models is getting them from a training - environment into a form that's easy to run inference on, so we see reusing - this rich ecosystem as being crucial to being easily usable. We also hope to - integrate this experimental work back into the main codebase in the future. - -To meet those goals, we've made some tradeoffs: - -- **Simple C++**: To help with readability, our code is written in a modern - version of C++, but we generally treat it as a "better C", rather relying on - more complex features such as template meta-programming. As mentioned - earlier, we avoid any use of dynamic memory allocation (new/delete) or the - standard C/C++ libraries, so we believe this should still be fairly - portable. It does mean that some older devices with C-only toolchains won't - be supported, but we're hoping that the reference operator implementations - (which are simple C-like functions) can still be useful in those cases. 
The - interfaces are also designed to be C-only, so it should be possible to - integrate the resulting library with pure C projects. - -- **Interpreted**: Code generation is a popular pattern for embedded code, - because it gives standalone code that's easy to modify and step through, but - we've chosen to go with an interpreted approach. In our internal - microcontroller work we've found that using an extremely stripped-down - interpreter with almost no dependencies gives us a lot of the same - advantages, but is easier to maintain. For example, when new updates come - out for the underlying library, you can just merge your local modifications - in a single step, rather than having to regenerate new code and then patch - in any changes you subsequently made. The coarse granularity of the - interpreted primitives means that each operation call typically takes - hundreds of thousands of instruction cycles at least, so we don't see - noticeable performance gains from avoiding what's essentially a single - switch statement at the interpreter level to call each operation. We're - still working on improving the packaging though, for example we're - considering having the ability to snapshot all the source files and headers - used for a particular model, being able to compile the code and data - together as a library, and then access it through a minimal set of C - interface calls which hide the underlying complexity. - -- **Flatbuffers**: We represent our models using - [the standard flatbuffer schema used by the rest of TensorFlow Lite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs), - with the difference that we always keep it in read-only program memory - (typically flash) rather than relying on having a file system to read it - from. This is a good fit because flatbuffer's serialized format is designed - to be mapped into memory without requiring any extra memory allocations or - modifications to access it. 
All of the functions to read model values work - directly on the serialized bytes, and large sections of data like weights - are directly accessible as sequential C-style arrays of their data type, - with no strides or unpacking needed. We do get a lot of value from using - flatbuffers, but there is a cost in complexity. The flat buffer library code - is all inline - [inside the main headers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema_generated.h), - but it isn't straightforward to inspect their implementations, and the model - data structures aren't easy to comprehend from the debugger. The header for - the schema itself also has to be periodically updated when new information - is added to the file format, though we try to handle that transparently for - most developers by checking in a pre-generated version. - -- **Code Duplication**: Some of the code in this prototype largely duplicates - the logic in other parts of the TensorFlow Lite code base, for example the - operator wrappers. We've tried to keep share as much as we can between the - two interpreters, but there are some assumptions built into the original - runtime that make this difficult. We'll be working on modularizing the main - interpreter so that we can move to an entirely shared system. - -This initial preview release is designed to get early feedback, and is not -intended to be a final product. It only includes enough operations to run a -simple keyword recognition model, and the implementations are not optimized. -We're hoping this will be a good way to get feedback and collaborate to improve -the framework. - -## Getting Started with Make - -Building requires a Linux or OS X machine. 
+## Table of Contents + +- [Getting Started](#getting-started) + + * [Getting Started with Portable Reference Code](#getting-started-with-portable-reference-code) + * [Building Portable Reference Code using Make](#building-portable-reference-code-using-make) + * [Building for the "Blue Pill" STM32F103 using Make](#building-for-the-blue-pill-stm32f103-using-make) + * [Building for "Hifive1" SiFive FE310 development board using Make](#building-for-hifive1-sifive-fe310-development-board-using-make) + * [Building for Ambiq Micro Apollo3Blue EVB using Make](#building-for-ambiq-micro-apollo3blue-evb-using-make) + * [Additional Apollo3 Instructions](#additional-apollo3-instructions) + * [Building for the Eta Compute ECM3531 EVB using Make](#building-for-the-eta-compute-ecm3531-evb-using-make) + +- [Goals](#goals) + +- [Generating Project Files](#generating-project-files) + +- [How to Port TensorFlow Lite Micro to a New Platform](#how-to-port-tensorflow-lite-micro-to-a-new-platform) + + * [Requirements](#requirements) + * [Getting Started](#getting-started) + * [Troubleshooting](#troubleshooting) + * [Optimizing for your Platform](#optimizing-for-your-platform) + * [Code Module Organization](#code-module-organization) + * [Working with Generated Projects](#working-with-generated-projects) + * [Supporting a Platform with Makefiles](#supporting-a-platform-with-makefiles) + * [Supporting a Platform with Emulation Testing](#supporting-a-platform-with-emulation-testing) + * [Implementing More Optimizations](#implementing-more-optimizations) + +# Getting Started + +One of the challenges of embedded software development is that there are a lot +of different architectures, devices, operating systems, and build systems. We +aim to support as many of the popular combinations as we can, and make it as +easy as possible to add support for others.
+ +If you're a product developer, we have build instructions or pre-generated +project files that you can download for the following platforms: + +Device | Mbed | Keil | Make/GCC +---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------ | ------------------------------------------------------------------------------ | -------- +[STM32F746G Discovery Board](https://www.st.com/en/evaluation-tools/32f746gdiscovery.html) | [Download](https://drive.google.com/open?id=1OtgVkytQBrEYIpJPsE8F6GUKHPBS3Xeb) | - | [Download](https://drive.google.com/open?id=1u46mTtAMZ7Y1aD-He1u3R8AE4ZyEpnOl) +["Blue Pill" STM32F103-compatible development board](https://github.com/google/stm32_bare_lib) | - | - | [Instructions](#building-for-the-blue-pill-stm32f103-using-make) +[Ambiq Micro Apollo3Blue EVB using Make](https://ambiqmicro.com/apollo-ultra-low-power-mcus/) | - | - | [Instructions](#building-for-ambiq-micro-apollo3blue-evb-using-make) +[Generic Keil uVision Projects](http://www2.keil.com/mdk5/uvision/) | - | [Download](https://drive.google.com/open?id=1Lw9rsdquNKObozClLPoE5CTJLuhfh5mV) | - +[Eta Compute ECM3531 EVB](https://etacompute.com/) | - | - | [Instructions](#Building-for-the-Eta-Compute-ECM3531-EVB-using-Make) + +If your device is not yet supported, it may not be too hard to add support. You +can learn about that process +[here](#how-to-port-tensorflow-lite-micro-to-a-new-platform). We're looking +forward to getting your help expanding this table! + +## Getting Started with Portable Reference Code + +If you don't have a particular microcontroller platform in mind yet, or just +want to try out the code before beginning porting, the easiest way to begin is +by +[downloading the platform-agnostic reference code](https://drive.google.com/open?id=1cawEQAkqquK_SO4crReDYqf_v7yAwOY8). 
+You'll see a series of folders inside the archive, with each one containing just +the source files you need to build one binary. There is a simple Makefile for +each folder, but you should be able to load the files into almost any IDE and +build them. There's also a [Visual Studio Code](https://code.visualstudio.com/) project file already set up, so +you can easily explore the code in a cross-platform IDE. + +## Building Portable Reference Code using Make + +It's easy to build portable reference code directly from GitHub using make if +you're on a Linux or OS X machine. - Open a terminal - Download the TensorFlow source with `git clone @@ -177,7 +132,7 @@ building binaries that run locally on the Mac OS or Linux machine you're building on, but this approach becomes important when we're targeting simple micro controller devices. -## Building for the "Blue Pill" STM32F103 +## Building for the "Blue Pill" STM32F103 using Make The goal of this library is to enable machine learning on resource-constrained micro controllers and DSPs, and as part of that we've targeted the @@ -283,7 +238,7 @@ Successfully tagged riscv_build:latest Building micro_speech_test binary -- Lauch the Docker that we just created using: `docker run -it-v +- Launch the Docker that we just created using: `docker run -it-v /tmp/copybara_out:/workspace riscv_build:latest bash` - Enter the source root directory by running `cd /workspace` - Download the dependencies by running @@ -294,7 +249,7 @@ Building micro_speech_test binary - Build the binary: `make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=riscv32_mcu` -Lauching Renode to test the binary, currently this set up is not automated. +Launching Renode to test the binary, currently this set up is not automated. 
- Execute the binary on Renode: `renode -P 5000 --disable-xwt -e 's @/workspace/tensorflow/lite/experimental/micro/testing/sifive_fe310.resc'` @@ -313,16 +268,15 @@ You should see the following log with the magic string `~~~ALL TEST PASSED~~~`: 02:25:22.4253 [DEBUG] uart0: [+0.16ms host +0s virt 0.28s virt from start] Progam has exited with code:0x00000000 ``` -## Building for Ambiq Micro Apollo3Blue EVB +## Building for Ambiq Micro Apollo3Blue EVB using Make Follow these steps to get the pushbutton yes/no example working on Apollo 3: -1. Make sure to run the "Getting Started" section before performing the - following steps -2. Download Apollo3-SDK-2018.08.13 and place in - `tensorflow/lite/experimental/micro/tools/make/downloads`. This is not yet - publicly released, but you can contact ashah@ambiqmicro.com to request a - copy. +1. Make sure to run the "Building Portable Reference Code using Make" section + before performing the following steps +2. The Ambiq Micro SDK is downloaded into + `tensorflow/lite/experimental/micro/tools/make/downloads` by + 'download_dependencies.sh'. 3. Compile the project with the following command: make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=apollo3evb pushbutton_cmsis_speech_test_bin @@ -349,7 +303,10 @@ Follow these steps to get the pushbutton yes/no example working on Apollo 3: 4. Press BTN2. An LED will flash for 1 second. Speak your utterance during this one second 5. The debugger will print out four numbers. They are the probabilites for - 1) no speech, 2) unknown speech, 3) yes, 4) no + 1. no speech + 2. unknown speech + 3. yes + 4. no 6. The EVB LEDs will indicate detection. 1. LED0 (rightmost LED) - ON when capturing 1sec of audio 2. LED1 - ON when detecting silence @@ -359,12 +316,172 @@ Follow these steps to get the pushbutton yes/no example working on Apollo 3: ### Additional Apollo3 Instructions -To flash a part with JFlash Lite, do the following: 1. At the command line: -JFlashLiteExe 2. 
Device = AMA3B1KK-KBR 3. Interface = SWD at 1000 kHz 4. Data -file = -tensorflow/lite/experimental/micro/tools/make/gen/apollo3evb_cortex-m4/bin/pushbutton_cmsis_speech_test.bin +To flash a part with JFlash Lite, do the following: + +1. At the command line: JFlashLiteExe +2. Device = AMA3B1KK-KBR +3. Interface = SWD at 1000 kHz +4. Data file = `tensorflow/lite/experimental/micro/tools/make/gen/apollo3evb_cortex-m4/bin/pushbutton_cmsis_speech_test.bin` 5. Prog Addr = 0x0000C000 +## Building for the Eta Compute ECM3531 EVB using Make + +1. Follow the instructions at + [Tensorflow Micro Speech](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/micro/examples/micro_speech#getting-started) + to download the Tensorflow source code and the support libraries \(but do + not run the make command shown there.\) +2. Download the Eta Compute SDK, version 0.0.17. Contact info@etacompute.com +3. You will need the Arm compiler arm-none-eabi-gcc, version 7.3.1 + 20180622, release ARM/embedded-7-branch revision 261907, 7-2018-q2-update. + This compiler is downloaded when you run the + tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh + script. +4. Edit the file + tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc + so that the variables ETA_SDK and GCC_ARM point to the correct directories. +5. Compile the code with the command \ +     make -f + tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=ecm3531 + TAGS="CMSIS" test \ + This will produce a set of executables in the + tensorflow/lite/experimental/micro/tools/make/gen/ecm3531_cortex-m3/bin + directory. +6.
To load an executable into SRAM \ +     Start ocd \ +     cd + tensorflow/lite/experimental/micro/tools/make/targets/ecm3531 \ +     ./load_program name_of_executable, for e.g., + ./load_program audio_provider_test \ +     Start PuTTY \(Connection type = Serial, Speed = + 11520, Data bits = 8, Stop bits = 1, Parity = None\) \ + The following output should appear: \ + Testing TestAudioProvider \ + Testing TestTimer \ + 2/2 tests passed \ + \~\~\~ALL TESTS PASSED\~\~\~ \ + Execution time \(msec\) = 7 +7. To load into flash \ +     Edit the variable ETA_LDS_FILE in + tensorflow/lite/experimental/micro/tools/  make/targets/ecm3531_makefile.inc + to point to the ecm3531_flash.lds file \ +     Recompile \( make -f + tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=ecm3531 + TAGS="CMSIS" test\) \ +     cd + tensorflow/lite/experimental/micro/tools/make/targets/ecm3531 \ +     ./flash_program executable_name to load into flash. + +## Goals + +The design goals are for the framework to be: + +- **Readable**: We want embedded software engineers to be able to understand + what's required to run ML inference without having to study research papers. + We've tried to keep the code base small, modular, and have reference + implementations of all operations to help with this. + +- **Easy to modify**: We know that there are a lot of different platforms and + requirements in the embedded world, and we don't expect to cover all of them + in one framework. Instead, we're hoping that it can be a good starting point + for developers to build on top of to meet their own needs. For example, we + tried to make it easy to replace the implementations of key computational + operators that are often crucial for performance, without having to touch + the data flow and other runtime code. 
We want it to make more sense to use + our workflow to handle things like model import and less-important + operations, and customize the parts that matter, rather than having to + reimplement everything in your own engine. + +- **Well-tested**: If you're modifying code, you need to know if your changes + are correct. Having an easy way to test lets you develop much faster. To + help there, we've written tests for all the components, and we've made sure + that the tests can be run on almost any platform, with no dependencies apart + from the ability to log text to a debug console somewhere. We also provide + an easy way to run all the tests on-device as part of an automated test + framework, and we use qemu/Renode emulation so that tests can be run even + without physical devices present. + +- **Easy to integrate**: We want to be as open a system as possible, and use + the best code available for each platform. To do that, we're going to rely + on projects like + [CMSIS-NN](https://www.keil.com/pack/doc/CMSIS/NN/html/index.html), + [uTensor](https://github.com/uTensor/uTensor), and other vendor libraries to + handle as much performance-critical code as possible. We know that there are + an increasing number of options to accelerate neural networks on + microcontrollers, so we're aiming to be a good host for deploying those + hardware technologies too. + +- **Compatible**: We're using the same file schema, interpreter API, and + kernel interface as regular TensorFlow Lite, so we leverage the large + existing set of tools, documentation, and examples for the project. The + biggest barrier to deploying ML models is getting them from a training + environment into a form that's easy to run inference on, so we see reusing + this rich ecosystem as being crucial to being easily usable. We also hope to + integrate this experimental work back into the main codebase in the future. 
+ +To meet those goals, we've made some tradeoffs: + +- **Simple C++**: To help with readability, our code is written in a modern + version of C++, but we generally treat it as a "better C", rather than relying on + more complex features such as template meta-programming. As mentioned + earlier, we avoid any use of dynamic memory allocation (new/delete) or the + standard C/C++ libraries, so we believe this should still be fairly + portable. It does mean that some older devices with C-only toolchains won't + be supported, but we're hoping that the reference operator implementations + (which are simple C-like functions) can still be useful in those cases. The + interfaces are also designed to be C-only, so it should be possible to + integrate the resulting library with pure C projects. + +- **Interpreted**: Code generation is a popular pattern for embedded code, + because it gives standalone code that's easy to modify and step through, but + we've chosen to go with an interpreted approach. In our internal + microcontroller work we've found that using an extremely stripped-down + interpreter with almost no dependencies gives us a lot of the same + advantages, but is easier to maintain. For example, when new updates come + out for the underlying library, you can just merge your local modifications + in a single step, rather than having to regenerate new code and then patch + in any changes you subsequently made. The coarse granularity of the + interpreted primitives means that each operation call typically takes + hundreds of thousands of instruction cycles at least, so we don't see + noticeable performance gains from avoiding what's essentially a single + switch statement at the interpreter level to call each operation.
We're + still working on improving the packaging though, for example we're + considering having the ability to snapshot all the source files and headers + used for a particular model, being able to compile the code and data + together as a library, and then access it through a minimal set of C + interface calls which hide the underlying complexity. + +- **Flatbuffers**: We represent our models using + [the standard flatbuffer schema used by the rest of TensorFlow Lite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs), + with the difference that we always keep it in read-only program memory + (typically flash) rather than relying on having a file system to read it + from. This is a good fit because flatbuffer's serialized format is designed + to be mapped into memory without requiring any extra memory allocations or + modifications to access it. All of the functions to read model values work + directly on the serialized bytes, and large sections of data like weights + are directly accessible as sequential C-style arrays of their data type, + with no strides or unpacking needed. We do get a lot of value from using + flatbuffers, but there is a cost in complexity. The flat buffer library code + is all inline + [inside the main headers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema_generated.h), + but it isn't straightforward to inspect their implementations, and the model + data structures aren't easy to comprehend from the debugger. The header for + the schema itself also has to be periodically updated when new information + is added to the file format, though we try to handle that transparently for + most developers by checking in a pre-generated version. + +- **Code Duplication**: Some of the code in this prototype largely duplicates + the logic in other parts of the TensorFlow Lite code base, for example the + operator wrappers. 
We've tried to share as much as we can between the + two interpreters, but there are some assumptions built into the original + runtime that make this difficult. We'll be working on modularizing the main + interpreter so that we can move to an entirely shared system. + +This initial preview release is designed to get early feedback, and is not +intended to be a final product. It only includes enough operations to run a +simple keyword recognition model, and the implementations are not optimized. +We're hoping this will be a good way to get feedback and collaborate to improve +the framework. + ## Generating Project Files It's not always easy or convenient to use a makefile-based build process, @@ -384,7 +501,8 @@ This will create a folder in `tensorflow/lite/experimental/micro/tools/make/gen/mbed_cortex-m4/prj/micro_speech_main_test/mbed` that contains the source and header files, some Mbed configuration files, and a README. You should then be able to copy this directory to another machine, and -use it just like any other Mbed project. +use it just like any other Mbed project. There's more information about project +files [below](#working-with-generated-projects). ## How to Port TensorFlow Lite Micro to a New Platform @@ -537,7 +655,9 @@ critical parts of the code with versions specifically tailored to your architecture. The framework has been designed with this in mind, and we hope the combination of small modules and many tests makes it as straightforward as possible to swap in your own code a piece at a time, ensuring you have a working -version at every step. +version at every step. To write specialized implementations for a platform, it's +useful to understand how optional components are handled inside the build +system. ### Code Module Organization @@ -628,7 +748,7 @@ kernel implementations, but with some specific conventions: latest in the ordered list will be chosen.
This allows us to express “I’d like generically-optimized fixed point if it’s available, but I’d prefer something using the CMSIS library” using the list 'fixed_point cmsis'. These - tags are passed in as `TAGS="/tensorflow`. + +You'll either need Linux, MacOS, or Windows with something like CygWin installed +to run the next steps, since they involve building a makefile. Run the following +commands from a terminal, inside the root of the source folder: + +``` +tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh +make -f tensorflow/lite/experimental/micro/tools/make/Makefile generate_projects +``` + +This will take a few minutes, since it has to download some large toolchains for +the dependencies. Once it has finished, you should see some folders created +inside a path like +`tensorflow/lite/experimental/micro/tools/make/gen/linux_x86_64/prj/`. The exact +path depends on your host operating system, but you should be able to figure it +out from all the copy commands. These folders contain the generated project and +source files, with +`tensorflow/lite/experimental/micro/tools/make/gen/linux_x86_64/prj/keil` +containing the Keil uVision targets, +`tensorflow/lite/experimental/micro/tools/make/gen/linux_x86_64/prj/mbed` with +the Mbed versions, and so on. + +If you've got this far, you've successfully set up the project generation flow. +Now you need to add your specialized implementation of `DebugLog()`. Start by +creating a folder inside `tensorflow/lite/experimental/micro/` named after the +tag you picked earlier. Put your `debug_log.cc` file inside this folder, and +then run this command, with '<tag>' replaced by the actual folder name: + +``` +make -f tensorflow/lite/experimental/micro/tools/make/Makefile TAGS="<tag>" generate_projects +``` + +If your tag name actually refers to a whole target architecture, then you'll use +TARGET or TARGET_ARCH instead.
For example, here's how a simple RISC-V set of +projects is generated: + +``` +make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET="riscv32_mcu" generate_projects +``` + +The way it works is the same as TAGS though, it just looks for specialized +implementations with the same containing folder name. + +If you look inside the projects that have been created, you should see that the +default `DebugLog()` implementation is no longer present at +`tensorflow/lite/experimental/micro/debug_log.cc`, and instead +`tensorflow/lite/experimental/micro/<tag>/debug_log.cc` is being used. Copy +over the generated project files and try building them in your own IDE. If +everything works, then you're ready to submit your change. + +To do this, run something like: + +``` +git add tensorflow/lite/experimental/micro/<tag>/debug_log.cc +git commit -a -m "Added DebugLog() support for <your platform>" +git push origin master +``` + +Then go back to https://github.com/<your account>/tensorflow, and choose "New +Pull Request" near the top. You should then be able to go through the standard +TensorFlow PR process to get your change added to the main repository, and +available to the rest of the community! + +### Supporting a Platform with Makefiles + +The changes you've made so far will enable other developers using the generated +projects to use your platform, but TensorFlow's continuous integration process +uses makefiles to build frequently and ensure changes haven't broken the build +process for different systems. If you are able to convert your build procedure +into something that can be expressed by a makefile, then we can integrate your +platform into our CI builds and make sure it continues to work. + +Fully describing how to do this is beyond the scope of this documentation, but +the biggest needs are: + +- A command-line compiler that can be called for every source file. +- A list of the arguments to pass into the compiler to build and link all + files.
+- The correct linker map files and startup assembler to ensure `main()` gets + called. + +### Supporting a Platform with Emulation Testing + +Integrating your platform into the makefile process should help us make sure +that it continues to build, but it doesn't guarantee that the results of the +build process will run correctly. Running tests is something we require to be +able to say that TensorFlow officially supports a platform, since otherwise we +can't guarantee that users will have a good experience when they try using it. +Since physically maintaining a full set of all supported hardware devices isn't +feasible, we rely on software emulation to run these tests. A good example is +our +[STM32F103 'Bluepill' support](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/micro/testing/test_bluepill_binary.sh), +which uses [Docker](https://www.docker.com/) and [Renode](https://renode.io/) to +run built binaries in an emulator. You can use whatever technologies you want; +the only requirements are that they capture the debug log output of the tests +being run in the emulator, and parse them for the string that indicates the test +was successful. These scripts need to run on Ubuntu 18.04, in a bash +environment, though Docker is available if you need to install extra software or +have other dependencies. + +### Implementing More Optimizations + +Clearly, getting debug logging support is only the beginning of the work you'll +need to do on a particular platform. It's very likely that you'll want to +optimize the core deep learning operations that take up the most time when +running models you care about. The good news is that the process for providing +optimized implementations is the same as the one you just went through to +provide your own logging. You'll need to identify parts of the code that are +bottlenecks, and then add specialized implementations in their own folders.
+These don't need to be platform specific, they can also be broken out by which +library they rely on for example. [Here's where we do that for the CMSIS +implementation of integer fast-fourier +transforms](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/preprocessor.cc). +This more complex case shows that you can also add helper source files alongside +the main implementation, as long as you +[mention them in the platform-specific makefile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/Makefile.inc). +You can also do things like update the list of libraries that need to be linked +in, or add include paths to required headers. diff --git a/tensorflow/lite/experimental/micro/ecm3531/debug_log.cc b/tensorflow/lite/experimental/micro/ecm3531/debug_log.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d961963969039c75232b91bba12b54870225605 --- /dev/null +++ b/tensorflow/lite/experimental/micro/ecm3531/debug_log.cc @@ -0,0 +1,20 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/debug_log.h" + +#include "eta_csp_io.h" +/* DebugLog for the ecm3531 target: hands s to EtaCspIoPrintf as the argument of a literal %s format rather than as the format itself. EtaCspIoPrintf is presumably declared in eta_csp_io.h, which is not visible here, and presumably printf-style - TODO confirm. */ +extern "C" void DebugLog(const char* s) { EtaCspIoPrintf("%s", s); } diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD b/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD index 171d07cf569a366aad8edf0715c6c1c8ea4a6d87..218b5de86a742f974520f82c8628fa9e3d08e210 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD @@ -11,34 +11,34 @@ load( ) cc_library( - name = "model_settings", + name = "simple_model_settings", srcs = [ - "model_settings.cc", + "simple_features/simple_model_settings.cc", ], hdrs = [ - "model_settings.h", + "simple_features/simple_model_settings.h", ], ) cc_library( - name = "tiny_conv_model_data", + name = "tiny_conv_simple_features_model_data", srcs = [ - "tiny_conv_model_data.cc", + "simple_features/tiny_conv_simple_features_model_data.cc", ], hdrs = [ - "tiny_conv_model_data.h", + "simple_features/tiny_conv_simple_features_model_data.h", ], ) cc_library( - name = "features_test_data", + name = "simple_features_test_data", srcs = [ - "no_features_data.cc", - "yes_features_data.cc", + "simple_features/no_simple_features_data.cc", + "simple_features/yes_simple_features_data.cc", ], hdrs = [ - "no_features_data.h", - "yes_features_data.h", + "simple_features/no_simple_features_data.h", + "simple_features/yes_simple_features_data.h", ], ) @@ -48,10 +48,10 @@ tflite_micro_cc_test( "micro_speech_test.cc", ], deps = [ - ":features_test_data", - ":tiny_conv_model_data", 
"//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_features_test_data", +
"//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:tiny_conv_micro_features_model_data", "//tensorflow/lite/experimental/micro/kernels:all_ops_resolver", "//tensorflow/lite/experimental/micro/kernels:micro_ops", "//tensorflow/lite/experimental/micro/testing:micro_test", @@ -60,45 +60,66 @@ tflite_micro_cc_test( ) cc_library( - name = "preprocessor_test_data", + name = "audio_sample_test_data", srcs = [ "no_30ms_sample_data.cc", - "no_power_spectrum_data.cc", "yes_30ms_sample_data.cc", - "yes_power_spectrum_data.cc", ], hdrs = [ "no_30ms_sample_data.h", - "no_power_spectrum_data.h", "yes_30ms_sample_data.h", - "yes_power_spectrum_data.h", ], ) cc_library( - name = "preprocessor_reference", + name = "audio_large_sample_test_data", srcs = [ - "preprocessor.cc", + "no_1000ms_sample_data.cc", + "yes_1000ms_sample_data.cc", ], hdrs = [ - "preprocessor.h", + "no_1000ms_sample_data.h", + "yes_1000ms_sample_data.h", + ], +) + +cc_library( + name = "simple_features_generator_test_data", + srcs = [ + "simple_features/no_power_spectrum_data.cc", + "simple_features/yes_power_spectrum_data.cc", + ], + hdrs = [ + "simple_features/no_power_spectrum_data.h", + "simple_features/yes_power_spectrum_data.h", + ], +) + +cc_library( + name = "simple_features_generator_reference", + srcs = [ + "simple_features/simple_features_generator.cc", + ], + hdrs = [ + "simple_features/simple_features_generator.h", ], deps = [ - ":model_settings", + ":simple_model_settings", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", ], ) tflite_micro_cc_test( - name = "preprocessor_reference_test", + name = "simple_features_generator_reference_test", srcs = [ - "preprocessor_test.cc", + "simple_features/simple_features_generator_test.cc", ], deps = [ - ":model_settings", - ":preprocessor_reference", - ":preprocessor_test_data", + ":audio_sample_test_data", + ":simple_features_generator_reference", + ":simple_features_generator_test_data", + 
":simple_model_settings", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", "//tensorflow/lite/experimental/micro/testing:micro_test", @@ -106,29 +127,30 @@ tflite_micro_cc_test( ) cc_library( - name = "preprocessor_fixed", + name = "simple_features_generator_fixed", srcs = [ - "fixed_point/preprocessor.cc", + "simple_features/fixed_point/simple_features_generator.cc", ], hdrs = [ - "preprocessor.h", + "simple_features/simple_features_generator.h", ], deps = [ - ":model_settings", + ":simple_model_settings", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", ], ) tflite_micro_cc_test( - name = "preprocessor_fixed_test", + name = "simple_features_generator_fixed_test", srcs = [ - "preprocessor_test.cc", + "simple_features/simple_features_generator_test.cc", ], deps = [ - ":model_settings", - ":preprocessor_fixed", - ":preprocessor_test_data", + ":audio_sample_test_data", + ":simple_features_generator_fixed", + ":simple_features_generator_test_data", + ":simple_model_settings", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", "//tensorflow/lite/experimental/micro/testing:micro_test", @@ -144,9 +166,25 @@ cc_library( "audio_provider.h", ], deps = [ - ":model_settings", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", + ], +) + +cc_library( + name = "audio_provider_mock", + srcs = [ + "audio_provider_mock.cc", + ], + hdrs = [ + "audio_provider.h", + ], + deps = [ + ":audio_large_sample_test_data", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", ], ) @@ -157,9 +195,24 @@ tflite_micro_cc_test( ], deps = [ ":audio_provider", - ":model_settings", 
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "audio_provider_mock_test", + srcs = [ + "audio_provider_mock_test.cc", + ], + deps = [ + ":audio_large_sample_test_data", + ":audio_provider_mock", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) @@ -174,10 +227,10 @@ cc_library( ], deps = [ ":audio_provider", - ":model_settings", - ":preprocessor_reference", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_features_generator", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", ], ) @@ -189,9 +242,41 @@ tflite_micro_cc_test( deps = [ ":audio_provider", ":feature_provider", - ":model_settings", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +cc_library( + name = "feature_provider_mock", + srcs = [ + "feature_provider.cc", + ], + hdrs = [ + "feature_provider.h", + ], + deps = [ + ":audio_provider_mock", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_features_generator", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", + ], +) + 
+tflite_micro_cc_test( + name = "feature_provider_mock_test", + srcs = [ + "feature_provider_mock_test.cc", + ], + deps = [ + ":feature_provider_mock", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_features_test_data", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) @@ -205,9 +290,9 @@ cc_library( "recognize_commands.h", ], deps = [ - ":model_settings", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", ], ) @@ -227,6 +312,33 @@ tflite_micro_cc_test( ], ) +cc_library( + name = "command_responder", + srcs = [ + "command_responder.cc", + ], + hdrs = [ + "command_responder.h", + ], + deps = [ + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +tflite_micro_cc_test( + name = "command_responder_test", + srcs = [ + "command_responder_test.cc", + ], + deps = [ + ":command_responder", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + cc_binary( name = "micro_speech", srcs = [ @@ -234,16 +346,33 @@ cc_binary( ], deps = [ ":audio_provider", + ":command_responder", ":feature_provider", - ":features_test_data", - ":model_settings", - ":preprocessor_reference", ":recognize_commands", - ":tiny_conv_model_data", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:tiny_conv_micro_features_model_data", + 
"//tensorflow/lite/experimental/micro/kernels:all_ops_resolver", + "//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_binary( + name = "micro_speech_mock", + srcs = [ + "main.cc", + ], + deps = [ + ":audio_provider_mock", + ":command_responder", + ":feature_provider", + ":recognize_commands", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:micro_model_settings", + "//tensorflow/lite/experimental/micro/examples/micro_speech/micro_features:tiny_conv_micro_features_model_data", "//tensorflow/lite/experimental/micro/kernels:all_ops_resolver", - "//tensorflow/lite/experimental/micro/kernels:micro_ops", "//tensorflow/lite/schema:schema_fbs", ], ) diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/Makefile.inc b/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/Makefile.inc index 3d560510ad140ff0bba84ebcf790a0fda90e72fa..22134152afb27b558b14589a0f9a22ae117a93e4 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/Makefile.inc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/Makefile.inc @@ -5,6 +5,11 @@ ifneq ($(filter CMSIS,$(ALL_TAGS)),) -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Include/ \ -I$(MAKEFILE_DIR)/downloads/CMSIS_ext/ + GENERATED_PROJECT_INCLUDES += \ + -isystemthird_party/cmsis/CMSIS/Core/Include/ \ + -isystemthird_party/cmsis/CMSIS/DSP/Include/ \ + -Ithird_party/CMSIS_ext/ + CMSIS_PREPROCESSOR_SRCS := \ tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/hanning.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/sin_1k.cc \ @@ -12,6 +17,7 @@ ifneq ($(filter CMSIS,$(ALL_TAGS)),) CMSIS_PREPROCESSOR_HDRS := \ tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/hanning.h \ tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/sin_1k.h \ + third_party/CMSIS_ext/README.md \ 
third_party/CMSIS_ext/arm_cmplx_mag_squared_q10p6.h PREPROCESSOR_TEST_SRCS += $(CMSIS_PREPROCESSOR_SRCS) @@ -20,24 +26,34 @@ ifneq ($(filter CMSIS,$(ALL_TAGS)),) FEATURE_PROVIDER_TEST_SRCS += $(CMSIS_PREPROCESSOR_SRCS) FEATURE_PROVIDER_TEST_HDRS += $(CMSIS_PREPROCESSOR_HDRS) + SIMPLE_FEATURES_GENERATOR_TEST_SRCS += $(CMSIS_PREPROCESSOR_SRCS) + SIMPLE_FEATURES_GENERATOR_TEST_HDRS += $(CMSIS_PREPROCESSOR_HDRS) + MICRO_SPEECH_SRCS += $(CMSIS_PREPROCESSOR_SRCS) MICRO_SPEECH_HDRS += $(CMSIS_PREPROCESSOR_HDRS) THIRD_PARTY_CC_SRCS += \ - third_party/CMSIS_ext/arm_cmplx_mag_squared_q10p6.c \ - third_party/cmsis/CMSIS/DSP/Source/BasicMathFunctions/arm_mult_q15.c \ - third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_rfft_init_q15.c \ - third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_rfft_q15.c \ - third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_cfft_q15.c \ - third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_cfft_radix4_q15.c \ - third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_bitreversal2.S \ - third_party/cmsis/CMSIS/DSP/Source/CommonTables/arm_const_structs.c \ - third_party/cmsis/CMSIS/DSP/Source/CommonTables/arm_common_tables.c \ - third_party/cmsis/CMSIS/DSP/Source/StatisticsFunctions/arm_mean_q15.c \ - third_party/cmsis/CMSIS/DSP/Source/StatisticsFunctions/arm_max_q7.c + $(MAKEFILE_DIR)/downloads/CMSIS_ext/arm_cmplx_mag_squared_q10p6.c \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/BasicMathFunctions/arm_mult_q15.c \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_bitreversal.c \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_rfft_init_q15.c \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_rfft_q15.c \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_cfft_q15.c \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_cfft_radix4_q15.c \ + 
$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_bitreversal2.S \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/CommonTables/arm_const_structs.c \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/CommonTables/arm_common_tables.c \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/StatisticsFunctions/arm_mean_q15.c \ + $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/StatisticsFunctions/arm_max_q7.c THIRD_PARTY_CC_HDRS += \ + third_party/cmsis/LICENSE.txt \ + third_party/cmsis/CMSIS/Core/Include/cmsis_compiler.h \ + third_party/cmsis/CMSIS/Core/Include/cmsis_gcc.h \ + third_party/cmsis/CMSIS/Core/Include/cmsis_version.h \ + third_party/cmsis/CMSIS/Core/Include/core_cm3.h \ third_party/cmsis/CMSIS/DSP/Include/arm_common_tables.h \ - third_party/cmsis/CMSIS/DSP/Include/arm_const_structs.h + third_party/cmsis/CMSIS/DSP/Include/arm_const_structs.h \ + third_party/cmsis/CMSIS/DSP/Include/arm_math.h endif diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/Makefile.inc b/tensorflow/lite/experimental/micro/examples/micro_speech/Makefile.inc index 49aace3d7d05ba1d7010d3d834c66dc13e488c96..6015d0d63614d9deeb42d78e3f954c4403c457c2 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/Makefile.inc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/Makefile.inc @@ -1,106 +1,373 @@ +INCLUDES += \ + -I$(MAKEFILE_DIR)/downloads/kissfft + +GENERATED_PROJECT_INCLUDES += \ +-I./third_party/kissfft + +PROJECT_INCLUDES += \ +third_party/kissfft + +KISSFFT_LIB_SRCS := \ +$(MAKEFILE_DIR)/downloads/kissfft/kiss_fft.c \ +$(MAKEFILE_DIR)/downloads/kissfft/tools/kiss_fftr.c + +KISSFFT_LIB_HDRS := \ +$(MAKEFILE_DIR)/downloads/kissfft/COPYING \ +$(MAKEFILE_DIR)/downloads/kissfft/kiss_fft.h \ +$(MAKEFILE_DIR)/downloads/kissfft/_kiss_fft_guts.h \ +$(MAKEFILE_DIR)/downloads/kissfft/tools/kiss_fftr.h + +THIRD_PARTY_CC_HDRS += \ +third_party/kissfft/COPYING \ +third_party/kissfft/kiss_fft.h \ 
+third_party/kissfft/_kiss_fft_guts.h \ +third_party/kissfft/tools/kiss_fftr.h + MICRO_SPEECH_TEST_SRCS := \ tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc MICRO_SPEECH_TEST_HDRS := \ -tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h \ -PREPROCESSOR_TEST_SRCS := \ -tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor_test.cc \ +SIMPLE_FEATURES_GENERATOR_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator_test.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.cc \ 
-tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.cc +tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.cc -PREPROCESSOR_TEST_HDRS := \ -tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h \ +SIMPLE_FEATURES_GENERATOR_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.h \ tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.h \ tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h +tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.h + +MICRO_FEATURES_LIB_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.cc \ 
+tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.cc \ +$(KISSFFT_LIB_SRCS) + +MICRO_FEATURES_LIB_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h \ 
+tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.h \ +$(KISSFFT_LIB_HDRS) + +MICRO_FEATURES_FFT_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.cc \ +$(KISSFFT_LIB_SRCS) + +MICRO_FEATURES_FFT_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h \ +$(KISSFFT_LIB_HDRS) + +MICRO_FEATURES_FILTERBANK_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.cc + +MICRO_FEATURES_FILTERBANK_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h + +MICRO_FEATURES_FRONTEND_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_test.cc \ +$(MICRO_FEATURES_LIB_SRCS) + +MICRO_FEATURES_FRONTEND_TEST_HDRS := \ 
+tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.h \ +$(MICRO_FEATURES_LIB_HDRS) + +MICRO_FEATURES_LOG_SCALE_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.cc + +MICRO_FEATURES_LOG_SCALE_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h + +MICRO_FEATURES_NOISE_REDUCTION_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.cc + +MICRO_FEATURES_NOISE_REDUCTION_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h + +MICRO_FEATURES_PCAN_GAIN_CONTROL_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.cc \ 
+tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.cc + +MICRO_FEATURES_PCAN_GAIN_CONTROL_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h + +MICRO_FEATURES_WINDOW_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.cc + +MICRO_FEATURES_WINDOW_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h + +MICRO_FEATURES_GENERATOR_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.cc \ +$(MICRO_FEATURES_LIB_SRCS) + +MICRO_FEATURES_GENERATOR_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h \ +$(MICRO_FEATURES_LIB_HDRS) + +MICRO_FEATURES_GENERATOR_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.cc \ 
+tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.cc \ +$(MICRO_FEATURES_GENERATOR_SRCS) + +MICRO_FEATURES_GENERATOR_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.h \ +$(MICRO_FEATURES_GENERATOR_HDRS) AUDIO_PROVIDER_TEST_SRCS := \ tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_test.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc AUDIO_PROVIDER_TEST_HDRS := \ -tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h \ + +AUDIO_PROVIDER_MOCK_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock.cc + +AUDIO_PROVIDER_MOCK_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.h \ 
+tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.h \ tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h \ FEATURE_PROVIDER_TEST_SRCS := \ tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc +tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc \ +$(MICRO_FEATURES_GENERATOR_SRCS) FEATURE_PROVIDER_TEST_HDRS := \ -tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h \ tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h +tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h \ +$(MICRO_FEATURES_GENERATOR_HDRS) + +FEATURE_PROVIDER_MOCK_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc \ +$(MICRO_FEATURES_GENERATOR_SRCS) + +FEATURE_PROVIDER_MOCK_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h \ 
+tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h \ +$(MICRO_FEATURES_GENERATOR_HDRS) RECOGNIZE_COMMANDS_TEST_SRCS := \ tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc RECOGNIZE_COMMANDS_TEST_HDRS := \ -tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h \ tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h +COMMAND_RESPONDER_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/command_responder_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.cc + +COMMAND_RESPONDER_TEST_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h + MICRO_SPEECH_SRCS := \ tensorflow/lite/experimental/micro/examples/micro_speech/main.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc \ 
-tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.cc \ +$(MICRO_FEATURES_GENERATOR_SRCS) MICRO_SPEECH_HDRS := \ -tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h \ tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h \ tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h \ -tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h \ +$(MICRO_FEATURES_GENERATOR_HDRS) + +MICRO_SPEECH_MOCK_SRCS := \ 
+tensorflow/lite/experimental/micro/examples/micro_speech/main.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.cc \ +$(MICRO_FEATURES_GENERATOR_SRCS) + +MICRO_SPEECH_MOCK_HDRS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h \ +tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h \ +$(MICRO_FEATURES_GENERATOR_HDRS) # Find any platform-specific rules for this example. 
include $(wildcard tensorflow/lite/experimental/micro/examples/micro_speech/*/Makefile.inc) +$(eval $(call microlite_test,micro_features_fft_test,\ +$(MICRO_FEATURES_FFT_TEST_SRCS),$(MICRO_FEATURES_FFT_TEST_HDRS))) + +$(eval $(call microlite_test,micro_features_filterbank_test,\ +$(MICRO_FEATURES_FILTERBANK_TEST_SRCS),$(MICRO_FEATURES_FILTERBANK_TEST_HDRS))) + +$(eval $(call microlite_test,micro_features_frontend_test,\ +$(MICRO_FEATURES_FRONTEND_TEST_SRCS),$(MICRO_FEATURES_FRONTEND_TEST_HDRS))) + +$(eval $(call microlite_test,micro_features_log_scale_test,\ +$(MICRO_FEATURES_LOG_SCALE_TEST_SRCS),$(MICRO_FEATURES_LOG_SCALE_TEST_HDRS))) + +$(eval $(call microlite_test,micro_features_noise_reduction_test,\ +$(MICRO_FEATURES_NOISE_REDUCTION_TEST_SRCS),$(MICRO_FEATURES_NOISE_REDUCTION_TEST_HDRS))) + +$(eval $(call microlite_test,micro_features_pcan_gain_control_test,\ +$(MICRO_FEATURES_PCAN_GAIN_CONTROL_TEST_SRCS),$(MICRO_FEATURES_PCAN_GAIN_CONTROL_TEST_HDRS))) + +$(eval $(call microlite_test,micro_features_window_test,\ +$(MICRO_FEATURES_WINDOW_TEST_SRCS),$(MICRO_FEATURES_WINDOW_TEST_HDRS))) + +# Test the code for feature generation. +$(eval $(call microlite_test,micro_features_generator_test,\ +$(MICRO_FEATURES_GENERATOR_TEST_SRCS), $(MICRO_FEATURES_GENERATOR_TEST_HDRS))) + # Tests loading and running a speech model. $(eval $(call microlite_test,micro_speech_test,\ $(MICRO_SPEECH_TEST_SRCS),$(MICRO_SPEECH_TEST_HDRS))) # Test the code for feature generation. -$(eval $(call microlite_test,preprocessor_test,\ -$(PREPROCESSOR_TEST_SRCS), $(PREPROCESSOR_TEST_HDRS))) +$(eval $(call microlite_test,simple_features_generator_test,\ +$(SIMPLE_FEATURES_GENERATOR_TEST_SRCS), $(SIMPLE_FEATURES_GENERATOR_TEST_HDRS))) # Tests the audio provider module. $(eval $(call microlite_test,audio_provider_test,\ $(AUDIO_PROVIDER_TEST_SRCS),$(AUDIO_PROVIDER_TEST_HDRS))) +# Tests the audio provider mock module. 
+$(eval $(call microlite_test,audio_provider_mock_test,\ +$(AUDIO_PROVIDER_MOCK_TEST_SRCS),$(AUDIO_PROVIDER_MOCK_TEST_HDRS))) + # Tests the feature provider module. $(eval $(call microlite_test,feature_provider_test,\ $(FEATURE_PROVIDER_TEST_SRCS),$(FEATURE_PROVIDER_TEST_HDRS))) -# Tests the feature provider module. +# Tests the feature provider module using the mock audio provider. +$(eval $(call microlite_test,feature_provider_mock_test,\ +$(FEATURE_PROVIDER_MOCK_TEST_SRCS),$(FEATURE_PROVIDER_MOCK_TEST_HDRS))) + +# Tests the command recognizer module. $(eval $(call microlite_test,recognize_commands_test,\ $(RECOGNIZE_COMMANDS_TEST_SRCS),$(RECOGNIZE_COMMANDS_TEST_HDRS))) +# Tests responding to a command. +$(eval $(call microlite_test,command_responder_test,\ +$(COMMAND_RESPONDER_TEST_SRCS),$(COMMAND_RESPONDER_TEST_HDRS))) + # Builds a standalone speech command recognizer binary. $(eval $(call microlite_test,micro_speech,\ $(MICRO_SPEECH_SRCS),$(MICRO_SPEECH_HDRS))) + +# Builds a standalone speech command recognizer binary using fake audio input. +$(eval $(call microlite_test,micro_speech_mock,\ +$(MICRO_SPEECH_MOCK_SRCS),$(MICRO_SPEECH_MOCK_HDRS))) diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/README.md b/tensorflow/lite/experimental/micro/examples/micro_speech/README.md index 500eed33bab0187f9b2cf9647c046f4a541b9e2c..4a88de1b5becb2037a7f1454ad22b39a6388e7c3 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/README.md +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/README.md @@ -31,6 +31,9 @@ To understand how TensorFlow Lite does this, you can look at the `TestInvoke()` Once you have downloaded the dependencies and got the x86/Linux build working, you can try building a version for the STM32F103 'bluepill' device. 
The following command will build the test and then run it on an emulator, assuming you have Docker installed: +*On Mac OS you need to have an ARM compiler installed; one way of doing so is with +brew: brew install caskroom/cask/gcc-arm-embedded* + ``` make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test_micro_speech ``` @@ -83,7 +86,8 @@ If you see a compiling error on older machines, try leaving out the `--copt` arg ``` bazel run tensorflow/examples/speech_commands:freeze -- \ --model_architecture=tiny_conv --window_stride=20 --preprocess=average \ ---wanted_words="yes,no" --quantize=1 --output_file=/tmp/tiny_conv.pb +--wanted_words="yes,no" --quantize=1 --output_file=/tmp/tiny_conv.pb \ +--start_checkpoint=/tmp/speech_commands_train/tiny_conv.ckpt-18000 ``` The next step is to create a TensorFlow Lite file from the frozen graph: @@ -99,5 +103,59 @@ bazel run tensorflow/lite/toco:toco -- \ Finally, convert the file into a C source file that can be compiled into an embedded system: ``` -xxd -i /tmp/tiny_conv.tflite > /tmp/tiny_conv_model_data.cc +xxd -i /tmp/tiny_conv.tflite > /tmp/tiny_conv_simple_features_model_data.cc +``` + +### Creating Your Own Model With Google Cloud + +If you want to train your model in Google Cloud you can do so by using +pre-configured Deep Learning images.
+ +First create the VM: + +``` +export IMAGE_FAMILY="tf-latest-cpu" +export ZONE="us-west1-b" # Or any other required region +export INSTANCE_NAME="model-trainer" +export INSTANCE_TYPE="n1-standard-8" # or any other instance type +gcloud compute instances create $INSTANCE_NAME \ + --zone=$ZONE \ + --image-family=$IMAGE_FAMILY \ + --image-project=deeplearning-platform-release \ + --machine-type=$INSTANCE_TYPE \ + --boot-disk-size=120GB \ + --min-cpu-platform=Intel\ Skylake +``` + +As soon as the instance has been created you can SSH to it (as the jupyter user!): + +``` +gcloud compute ssh "jupyter@${INSTANCE_NAME}" +``` + +Now install Bazel: + +``` +wget https://github.com/bazelbuild/bazel/releases/download/0.15.0/bazel-0.15.0-installer-linux-x86_64.sh +sudo bash ./bazel-0.15.0-installer-linux-x86_64.sh +source /usr/local/lib/bazel/bin/bazel-complete.bash +sudo ln /usr/local/bin/bazel /usr/bin/bazel +``` + +And finally run the build: + +``` +# TensorFlow already pre-baked on the image +cd src/tensorflow +bazel run -c opt --copt=-mavx2 --copt=-mfma \ +tensorflow/examples/speech_commands:train -- \ +--model_architecture=tiny_conv --window_stride=20 --preprocess=average \ +--wanted_words="yes,no" --silence_percentage=25 --unknown_percentage=25 --quantize=1 +``` + +After the build is over, follow the rest of the instructions from this tutorial.
And +finally do not forget to remove the instance when training is done: + +``` +gcloud compute instances delete "${INSTANCE_NAME}" --zone="${ZONE}" ``` diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/Makefile.inc b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/Makefile.inc index 0aa362be0038f8757387a6311021e183dc19dabd..c83090344ba0d82e9f774897577b1eb924e92329 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/Makefile.inc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/Makefile.inc @@ -5,7 +5,7 @@ ifeq ($(TARGET), apollo3evb) $(AP3_MICRO_DIR)/../preprocessor.cc \ $(AP3_MICRO_DIR)/pushbutton_main.c \ $(AP3_MICRO_DIR)/pushbutton_test.cc \ - $(AP3_MICRO_DIR)/../tiny_conv_model_data.cc \ + $(AP3_MICRO_DIR)/../simple_features/tiny_conv_simple_features_model_data.cc \ $(APOLLO3_SDK)/devices/am_devices_led.c ALL_SRCS += $(PUSHBUTTON_MICRO_SPEECH_TEST_SRCS) PUSHBUTTON_MICRO_SPEECH_TEST_OBJS := $(addprefix $(OBJDIR), \ @@ -24,8 +24,8 @@ ifeq ($(TARGET), apollo3evb) PUSHBUTTON_CMSIS_SPEECH_TEST_SRCS := \ $(AP3_MICRO_DIR)/pushbutton_main.c \ $(AP3_MICRO_DIR)/pushbutton_test.cc \ - $(AP3_MICRO_DIR)/../tiny_conv_model_data.cc \ - $(CMSIS_DIR)/preprocessor.cc \ + $(AP3_MICRO_DIR)/../simple_features/tiny_conv_simple_features_model_data.cc \ + $(CMSIS_DIR)/simple_features_generator.cc \ $(CMSIS_EXT_DIR)/arm_cmplx_mag_squared_q10p6.c \ $(CMSIS_DIR)/hanning.c \ $(APOLLO3_SDK)/devices/am_devices_led.c \ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/captured_data_to_wav.py b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/captured_data_to_wav.py index 10a05b6dcf1bbd5c779f7ee7bdf4d01ebde76017..52604f5f2feaaf535a7ed9dc24020e48f4982308 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/captured_data_to_wav.py +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/captured_data_to_wav.py @@ -28,7 
+28,7 @@ def new_data_to_array(fn): vals = [] with open(fn) as f: for n, line in enumerate(f): - if n is not 0: + if n != 0: vals.extend([int(v, 16) for v in line.split()]) b = ''.join(map(chr, vals)) y = struct.unpack('<' + 'h' * int(len(b) / 2), b) diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/compare_1k.py b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/compare_1k.py index 52352bad94a1e5627a9ca35d07a5082b6d79e6a6..fab178b3176cb680d739b53b0f6e3171e32ef721 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/compare_1k.py +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/compare_1k.py @@ -30,7 +30,7 @@ def new_data_to_array(fn, datatype='int16'): vals = [] with open(fn) as f: for n, line in enumerate(f): - if n is not 0: + if n != 0: vals.extend([int(v, 16) for v in line.split()]) b = ''.join(map(chr, vals)) diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/pushbutton_main.c b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/pushbutton_main.c index afee38343b3fac81de945dcd01b53ad35e8be270..4f70d47c3ea9b6f7df884ceabeca245a2a5e55ce 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/pushbutton_main.c +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/pushbutton_main.c @@ -157,7 +157,7 @@ void pdm_data_get(void) { // PDM interrupt handler. // //***************************************************************************** -void am_pdm_isr(void) { +void am_pdm0_isr(void) { uint32_t ui32Status; // // Read the interrupt status. 
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/pushbutton_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/pushbutton_test.cc index 95043f857b34b953c91a762bc1a54e9489431bff..d4583dbf4a6dcb083e4d9cd2818e63a116debd7f 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/pushbutton_test.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/apollo3/pushbutton_test.cc @@ -17,8 +17,8 @@ limitations under the License. * micro_speech_test.cc */ #include "tensorflow/lite/c/c_api_internal.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/tiny_conv_simple_features_model_data.h" #include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" #include "tensorflow/lite/experimental/micro/micro_error_reporter.h" #include "tensorflow/lite/experimental/micro/micro_interpreter.h" @@ -32,20 +32,36 @@ uint8_t g_unknown_score = 0; uint8_t g_yes_score = 0; uint8_t g_no_score = 0; +namespace { + +TfLiteStatus GenerateSimpleFeatures_1sec(tflite::ErrorReporter* error_reporter, + const int16_t* input, + uint8_t* output) { + int i; + for (i = 0; i < 49; i++) { + GenerateSimpleFeatures(error_reporter, input + i * 320, 480, 43, + output + i * 43); + } + return kTfLiteOk; +} + +} // namespace + TF_LITE_MICRO_TESTS_BEGIN -TF_LITE_MICRO_TEST(TestPreprocessor) { +TF_LITE_MICRO_TEST(TestSimpleFeaturesGenerator) { tflite::MicroErrorReporter micro_error_reporter; tflite::ErrorReporter* error_reporter = µ_error_reporter; uint8_t preprocessed_data[43 * 49]; - TfLiteStatus preprocess_1sec_status = - Preprocess_1sec(error_reporter, captured_data, preprocessed_data); - 
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, preprocess_1sec_status); + TfLiteStatus generate_1sec_status = GenerateSimpleFeatures_1sec( + error_reporter, captured_data, preprocessed_data); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, generate_1sec_status); // Map the model into a usable data structure. This doesn't involve any // copying or parsing, it's a very lightweight operation. - const tflite::Model* model = ::tflite::GetModel(g_tiny_conv_model_data); + const tflite::Model* model = + ::tflite::GetModel(g_tiny_conv_simple_features_model_data); if (model->version() != TFLITE_SCHEMA_VERSION) { error_reporter->Report( "Model provided is schema version %d not equal " diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc index 52db18e6868371afc0b7cd39f6f41d0d60b91689..08811c83b437e66bf1e77a1a1f32d1cb5be02c43 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" namespace { int16_t g_dummy_audio_data[kMaxAudioSampleSize]; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock.cc new file mode 100644 index 0000000000000000000000000000000000000000..9c9792510b055e243ab4f6e804717647afa0b418 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h" + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.h" + +namespace { +int16_t g_dummy_audio_data[kMaxAudioSampleSize]; +int32_t g_latest_audio_timestamp = 0; +} // namespace + +TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter, + int start_ms, int duration_ms, + int* audio_samples_size, int16_t** audio_samples) { + const int yes_start = (0 * kAudioSampleFrequency) / 1000; + const int yes_end = (1000 * kAudioSampleFrequency) / 1000; + const int no_start = (4000 * kAudioSampleFrequency) / 1000; + const int no_end = (5000 * kAudioSampleFrequency) / 1000; + const int wraparound = (8000 * kAudioSampleFrequency) / 1000; + const int start_sample = (start_ms * kAudioSampleFrequency) / 1000; + for (int i = 0; i < kMaxAudioSampleSize; ++i) { + const int sample_index = (start_sample + i) % wraparound; + int16_t sample; + if ((sample_index >= yes_start) && (sample_index < yes_end)) { + sample = g_yes_1000ms_sample_data[sample_index - yes_start]; + } else if ((sample_index >= no_start) && (sample_index < no_end)) { + sample = 
g_no_1000ms_sample_data[sample_index - no_start]; + } else { + sample = 0; + } + g_dummy_audio_data[i] = sample; + } + *audio_samples_size = kMaxAudioSampleSize; + *audio_samples = g_dummy_audio_data; + return kTfLiteOk; +} + +int32_t LatestAudioTimestamp() { + g_latest_audio_timestamp += 100; + return g_latest_audio_timestamp; +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b73d436ad638277d3f052715c506668e7f163f17 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_mock_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h" + +#include + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestAudioProviderMock) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + int audio_samples_size = 0; + int16_t* audio_samples = nullptr; + TfLiteStatus get_status = + GetAudioSamples(error_reporter, 0, kFeatureSliceDurationMs, + &audio_samples_size, &audio_samples); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, get_status); + TF_LITE_MICRO_EXPECT_LE(audio_samples_size, kMaxAudioSampleSize); + TF_LITE_MICRO_EXPECT_NE(audio_samples, nullptr); + for (int i = 0; i < audio_samples_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(g_yes_1000ms_sample_data[i], audio_samples[i]); + } + + get_status = GetAudioSamples(error_reporter, 500, kFeatureSliceDurationMs, + &audio_samples_size, &audio_samples); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, get_status); + TF_LITE_MICRO_EXPECT_LE(audio_samples_size, kMaxAudioSampleSize); + TF_LITE_MICRO_EXPECT_NE(audio_samples, nullptr); + for (int i = 0; i < audio_samples_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(g_yes_1000ms_sample_data[i + 8000], + audio_samples[i]); + } + + get_status = GetAudioSamples(error_reporter, 1500, kFeatureSliceDurationMs, + &audio_samples_size, &audio_samples); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, get_status); + TF_LITE_MICRO_EXPECT_LE(audio_samples_size, 
kMaxAudioSampleSize); + TF_LITE_MICRO_EXPECT_NE(audio_samples, nullptr); + for (int i = 0; i < audio_samples_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(0, audio_samples[i]); + } + + get_status = GetAudioSamples(error_reporter, 12250, kFeatureSliceDurationMs, + &audio_samples_size, &audio_samples); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, get_status); + TF_LITE_MICRO_EXPECT_LE(audio_samples_size, kMaxAudioSampleSize); + TF_LITE_MICRO_EXPECT_NE(audio_samples, nullptr); + for (int i = 0; i < audio_samples_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(g_no_1000ms_sample_data[i + 4000], + audio_samples[i]); + } +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_test.cc index 85fbbb80a6c5b330230c1d1d0186de795edc4754..f9212aa3491e99104c2a3f1f5e315e9e96481345 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_test.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "tensorflow/lite/c/c_api_internal.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" #include "tensorflow/lite/experimental/micro/micro_error_reporter.h" #include "tensorflow/lite/experimental/micro/testing/micro_test.h" diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.cc new file mode 100644 index 0000000000000000000000000000000000000000..afff5109d9d1252481304a895045b450b9be85bb --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.cc @@ -0,0 +1,28 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h" + +// The default implementation writes out the name of the recognized command +// to the error console. Real applications will want to take some custom +// action instead, and should implement their own versions of this function. +void RespondToCommand(tflite::ErrorReporter* error_reporter, + int32_t current_time, const char* found_command, + uint8_t score, bool is_new_command) { + if (is_new_command) { + error_reporter->Report("Heard %s (%d) @%dms", found_command, score, + current_time); + } +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h b/tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h new file mode 100644 index 0000000000000000000000000000000000000000..3d91209de0aa035738cee8b52afe46f348c796bd --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides an interface to take an action based on an audio command. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_COMMAND_RESPONDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_COMMAND_RESPONDER_H_ + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" + +// Called every time the results of an audio recognition run are available. The +// human-readable name of any recognized command is in the `found_command` +// argument, `score` has the numerical confidence, and `is_new_command` is set +// if the previous command was different to this one. +void RespondToCommand(tflite::ErrorReporter* error_reporter, + int32_t current_time, const char* found_command, + uint8_t score, bool is_new_command); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_COMMAND_RESPONDER_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/command_responder_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/command_responder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8acf4552f59b0b6dffa3023cc86b44d422a28fd5 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/command_responder_test.cc @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h" + +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/experimental/micro/testing/test_utils.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestCallability) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + // This will have external side-effects (like printing to the debug console + // or lighting an LED) that are hard to observe, so the most we can do is + // make sure the call doesn't crash. + RespondToCommand(error_reporter, 0, "foo", 0, true); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/disco_f746ng/audio_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/disco_f746ng/audio_provider.cc index 06647d0c536564c26d72cb73396ca36efb3aeb25..49fea826759956d479e9171e2ba7a41331e31023 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/disco_f746ng/audio_provider.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/disco_f746ng/audio_provider.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" #include "AUDIO_DISCO_F746NG.h" #include "SDRAM_DISCO_F746NG.h" diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc index 7f9ece41dd3f013ae328ffd1bdc98f197855a131..b5dfa3d944076a21cde2dfafc6ce1ed39f15164d 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h" #include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data) : feature_size_(feature_size), @@ -48,6 +48,10 @@ TfLiteStatus FeatureProvider::PopulateFeatureData( int slices_needed = current_step - last_step; // If this is the first call, make sure we don't use any cached information. 
if (is_first_run_) { + TfLiteStatus init_status = InitializeMicroFeatures(error_reporter); + if (init_status != kTfLiteOk) { + return init_status; + } is_first_run_ = false; slices_needed = kFeatureSliceCount; } @@ -94,16 +98,17 @@ TfLiteStatus FeatureProvider::PopulateFeatureData( GetAudioSamples(error_reporter, slice_start_ms, kFeatureSliceDurationMs, &audio_samples_size, &audio_samples); if (audio_samples_size < kMaxAudioSampleSize) { - error_reporter->Report("Audio data size %d too small, want %d", + error_reporter->Report("Audio data size %d too small, want %d", audio_samples_size, kMaxAudioSampleSize); return kTfLiteError; } uint8_t* new_slice_data = feature_data_ + (new_slice * kFeatureSliceSize); - TfLiteStatus preprocess_status = - Preprocess(error_reporter, audio_samples, audio_samples_size, - kFeatureSliceSize, new_slice_data); - if (preprocess_status != kTfLiteOk) { - return preprocess_status; + size_t num_samples_read; + TfLiteStatus generate_status = GenerateMicroFeatures( + error_reporter, audio_samples, audio_samples_size, kFeatureSliceSize, + new_slice_data, &num_samples_read); + if (generate_status != kTfLiteOk) { + return generate_status; } } } diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_mock_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_mock_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b05912e26f8a7f5e89b9f45766adf4270c033ed5 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_mock_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestFeatureProviderMockYes) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + uint8_t feature_data[kFeatureElementCount]; + FeatureProvider feature_provider(kFeatureElementCount, feature_data); + + int how_many_new_slices = 0; + TfLiteStatus populate_status = feature_provider.PopulateFeatureData( + error_reporter, /* last_time_in_ms= */ 0, /* time_in_ms= */ 970, + &how_many_new_slices); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, populate_status); + TF_LITE_MICRO_EXPECT_EQ(kFeatureSliceCount, how_many_new_slices); + + for (int i = 0; i < kFeatureElementCount; ++i) { + TF_LITE_MICRO_EXPECT_EQ(g_yes_micro_f2e59fea_nohash_1_data[i], + feature_data[i]); + } +} + +TF_LITE_MICRO_TEST(TestFeatureProviderMockNo) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* 
error_reporter = &micro_error_reporter; + + uint8_t feature_data[kFeatureElementCount]; + FeatureProvider feature_provider(kFeatureElementCount, feature_data); + + int how_many_new_slices = 0; + TfLiteStatus populate_status = feature_provider.PopulateFeatureData( + error_reporter, /* last_time_in_ms= */ 4000, /* time_in_ms= */ 4970, + &how_many_new_slices); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, populate_status); + TF_LITE_MICRO_EXPECT_EQ(kFeatureSliceCount, how_many_new_slices); + + for (int i = 0; i < kFeatureElementCount; ++i) { + TF_LITE_MICRO_EXPECT_EQ(g_no_micro_f9643d42_nohash_4_data[i], + feature_data[i]); + } +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc index 556cbfe799bd9adf2df8f584a4f10b4a1c834bd4..e7655a3be53ae6a032195dd4ca991f740bb19537 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h" #include "tensorflow/lite/c/c_api_internal.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" #include "tensorflow/lite/experimental/micro/micro_error_reporter.h" #include "tensorflow/lite/experimental/micro/testing/micro_test.h" diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc index 3a9a5a4df1bf8239950dd2c79a1048706004e1f5..1bf96bc5a7e04bd05b85dd286bf2e9d06774e224 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc @@ -14,10 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h" #include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h" #include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" #include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" #include "tensorflow/lite/experimental/micro/micro_error_reporter.h" #include "tensorflow/lite/experimental/micro/micro_interpreter.h" @@ -31,7 +32,8 @@ int main(int argc, char* argv[]) { // 
Map the model into a usable data structure. This doesn't involve any // copying or parsing, it's a very lightweight operation. - const tflite::Model* model = ::tflite::GetModel(g_tiny_conv_model_data); + const tflite::Model* model = + ::tflite::GetModel(g_tiny_conv_micro_features_model_data); if (model->version() != TFLITE_SCHEMA_VERSION) { error_reporter->Report( "Model provided is schema version %d not equal " @@ -122,9 +124,11 @@ int main(int argc, char* argv[]) { "RecognizeCommands::ProcessLatestResults() failed"); return 1; } - if (is_new_command) { - error_reporter->Report("Heard %s (%d)", found_command, score); - } + // Do something based on the recognized command. The default implementation + // just prints to the error console, but you should replace this with your + // own function for a real application. + RespondToCommand(error_reporter, current_time, found_command, score, + is_new_command); } return 0; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/BUILD b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..1e684e1efd0bfbc676635e8c3233ef6284e6954d --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/BUILD @@ -0,0 +1,300 @@ +# Library for generating feature vectors from audio data + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + +cc_library( + name = "micro_model_settings", + srcs = [ + "micro_model_settings.cc", + ], + hdrs = [ + "micro_model_settings.h", + ], +) + +cc_library( + name = "tiny_conv_micro_features_model_data", + srcs = [ + "tiny_conv_micro_features_model_data.cc", + ], + hdrs = [ + "tiny_conv_micro_features_model_data.h", + ], +) + +cc_library( + name = "micro_features_test_data", + srcs = [ + 
"no_micro_features_data.cc", + "yes_micro_features_data.cc", + ], + hdrs = [ + "no_micro_features_data.h", + "yes_micro_features_data.h", + ], +) + +cc_library( + name = "bits", + hdrs = ["bits.h"], +) + +cc_library( + name = "static_alloc", + hdrs = ["static_alloc.h"], +) + +cc_library( + name = "fft", + srcs = [ + "fft.cc", + "fft_util.cc", + ], + hdrs = [ + "fft.h", + "fft_util.h", + ], + deps = [ + ":micro_model_settings", + ":static_alloc", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "@kissfft//:kiss_fftr_16", + ], +) + +cc_library( + name = "filterbank", + srcs = [ + "filterbank.cc", + "filterbank_util.cc", + ], + hdrs = [ + "filterbank.h", + "filterbank_util.h", + ], + deps = [ + ":bits", + ":fft", + ":micro_model_settings", + ":static_alloc", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "frontend", + srcs = [ + "frontend.cc", + "frontend_util.cc", + ], + hdrs = [ + "frontend.h", + "frontend_util.h", + ], + deps = [ + ":bits", + ":fft", + ":filterbank", + ":log_scale", + ":micro_model_settings", + ":noise_reduction", + ":pcan_gain_control", + ":window", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "log_scale", + srcs = [ + "log_lut.cc", + "log_scale.cc", + "log_scale_util.cc", + ], + hdrs = [ + "log_lut.h", + "log_scale.h", + "log_scale_util.h", + ], + deps = [ + ":bits", + ":micro_model_settings", + ":static_alloc", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "noise_reduction", + srcs = [ + "noise_reduction.cc", + "noise_reduction_util.cc", + ], + hdrs = [ + "noise_reduction.h", + "noise_reduction_util.h", + ], + deps = [ + ":micro_model_settings", + ":static_alloc", + "//tensorflow/lite/c:c_api_internal", + 
"//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "pcan_gain_control", + srcs = [ + "pcan_gain_control.cc", + "pcan_gain_control_util.cc", + ], + hdrs = [ + "pcan_gain_control.h", + "pcan_gain_control_util.h", + ], + deps = [ + ":bits", + ":micro_model_settings", + ":static_alloc", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "window", + srcs = [ + "window.cc", + "window_util.cc", + ], + hdrs = [ + "window.h", + "window_util.h", + ], + deps = [ + ":micro_model_settings", + ":static_alloc", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "micro_features_generator", + srcs = [ + "micro_features_generator.cc", + ], + hdrs = [ + "micro_features_generator.h", + ], + deps = [ + ":frontend", + ":micro_model_settings", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "micro_features_generator_test_data", + srcs = [ + "no_feature_data_slice.cc", + "yes_feature_data_slice.cc", + ], + hdrs = [ + "no_feature_data_slice.h", + "yes_feature_data_slice.h", + ], +) + +tflite_micro_cc_test( + name = "fft_test", + srcs = ["fft_test.cc"], + deps = [ + ":fft", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "filterbank_test", + srcs = ["filterbank_test.cc"], + deps = [ + ":filterbank", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "frontend_test", + srcs = ["frontend_test.cc"], + deps = [ + ":frontend", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "log_scale_test", + srcs = ["log_scale_test.cc"], + deps = [ + ":log_scale", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = 
"noise_reduction_test", + srcs = ["noise_reduction_test.cc"], + deps = [ + ":noise_reduction", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "pcan_gain_control_test", + srcs = ["pcan_gain_control_test.cc"], + deps = [ + ":pcan_gain_control", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "window_test", + srcs = ["window_test.cc"], + deps = [ + ":window", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "micro_features_generator_test", + srcs = [ + "micro_features_generator_test.cc", + ], + deps = [ + ":micro_features_generator", + ":micro_features_generator_test_data", + ":micro_model_settings", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/examples/micro_speech:audio_sample_test_data", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h new file mode 100644 index 0000000000000000000000000000000000000000..3b19ee6f030ae9fa8a931c6693cfe490747e336a --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_BITS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_BITS_H_ + +#include <cstdint> + +static inline int CountLeadingZeros32Slow(uint64_t n) { + int zeroes = 28; + if (n >> 16) zeroes -= 16, n >>= 16; + if (n >> 8) zeroes -= 8, n >>= 8; + if (n >> 4) zeroes -= 4, n >>= 4; + return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[n] + zeroes; +} + +static inline int CountLeadingZeros32(uint32_t n) { +#if defined(_MSC_VER) + unsigned long result = 0; // NOLINT(runtime/int) + if (_BitScanReverse(&result, n)) { + return 31 - result; + } + return 32; +#elif defined(__GNUC__) + + // Handle 0 as a special case because __builtin_clz(0) is undefined. + if (n == 0) { + return 32; + } + return __builtin_clz(n); +#else + return CountLeadingZeros32Slow(n); +#endif +} + +static inline int MostSignificantBit32(uint32_t n) { + return 32 - CountLeadingZeros32(n); +} + +static inline int CountLeadingZeros64Slow(uint64_t n) { + int zeroes = 60; + if (n >> 32) zeroes -= 32, n >>= 32; + if (n >> 16) zeroes -= 16, n >>= 16; + if (n >> 8) zeroes -= 8, n >>= 8; + if (n >> 4) zeroes -= 4, n >>= 4; + return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[n] + zeroes; +} + +static inline int CountLeadingZeros64(uint64_t n) { +#if defined(_MSC_VER) && defined(_M_X64) + // MSVC does not have __builtin_clzll. Use _BitScanReverse64. + unsigned long result = 0; // NOLINT(runtime/int) + if (_BitScanReverse64(&result, n)) { + return 63 - result; + } + return 64; +#elif defined(_MSC_VER) + // MSVC does not have __builtin_clzll. 
Compose two calls to _BitScanReverse + unsigned long result = 0; // NOLINT(runtime/int) + if ((n >> 32) && _BitScanReverse(&result, n >> 32)) { + return 31 - result; + } + if (_BitScanReverse(&result, n)) { + return 63 - result; + } + return 64; +#elif defined(__GNUC__) + + // Handle 0 as a special case because __builtin_clzll(0) is undefined. + if (n == 0) { + return 64; + } + return __builtin_clzll(n); +#else + return CountLeadingZeros64Slow(n); +#endif +} + +static inline int MostSignificantBit64(uint64_t n) { + return 64 - CountLeadingZeros64(n); +} + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_BITS_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.cc new file mode 100644 index 0000000000000000000000000000000000000000..cde4e38740e65cf56cd179d577528263177a649e --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h" + +#include <cstring> + +#define FIXED_POINT 16 +#include "kiss_fft.h" +// Internal test dependency placeholder1 +// Internal test dependency placeholder2 +#include "tools/kiss_fftr.h" +// Internal test dependency placeholder3 + +void FftCompute(struct FftState* state, const int16_t* input, + int input_scale_shift) { + const size_t input_size = state->input_size; + const size_t fft_size = state->fft_size; + + int16_t* fft_input = state->input; + // First, scale the input by the given shift. + int i; + for (i = 0; i < input_size; ++i) { + *fft_input++ = (*input++) << input_scale_shift; + } + // Zero out whatever else remains in the top part of the input. + for (; i < fft_size; ++i) { + *fft_input++ = 0; + } + + // Apply the FFT. + kiss_fftr(reinterpret_cast<const kiss_fftr_cfg>(state->scratch), state->input, + reinterpret_cast<kiss_fft_cpx*>(state->output)); +} + +void FftInit(struct FftState* state) { + // All the initialization is done in FftPopulateState() +} + +void FftReset(struct FftState* state) { + memset(state->input, 0, state->fft_size * sizeof(*state->input)); + memset(state->output, 0, (state->fft_size / 2 + 1) * sizeof(*state->output)); +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h new file mode 100644 index 0000000000000000000000000000000000000000..d5d29f68a2cc5688f9644a2b556abb4787e3bb93 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FFT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FFT_H_ + +#include <cstdint> +#include <cstdlib> + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" + +struct complex_int16_t { + int16_t real; + int16_t imag; +}; + +struct FftState { + int16_t input[kMaxAudioSampleSize]; + struct complex_int16_t output[kMaxAudioSampleSize + 2]; + size_t fft_size; + size_t input_size; + // This magic number was derived from KissFFT's estimate of how much space it + // will need to process the particular lengths and datatypes we need for + // these model settings. This size will need to be recalculated for different + // models, but you will see a runtime error if it's not large enough. 
+ char scratch[2848]; + size_t scratch_size; +}; + +void FftCompute(struct FftState* state, const int16_t* input, + int input_scale_shift); + +void FftInit(struct FftState* state); + +void FftReset(struct FftState* state); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FFT_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b89b01445a641c8152aaff8165495688ab6861b2 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.h" + +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +namespace { + +const int16_t kFakeWindow[] = { + 0, 1151, 0, -5944, 0, 13311, 0, -21448, 0, 28327, 0, -32256, 0, 32255, + 0, -28328, 0, 21447, 0, -13312, 0, 5943, 0, -1152, 0}; +const int kScaleShift = 0; + +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(FftTest_CheckOutputValues) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + struct FftState state; + TF_LITE_MICRO_EXPECT(FftPopulateState( + error_reporter, &state, sizeof(kFakeWindow) / sizeof(kFakeWindow[0]))); + + FftInit(&state); + FftCompute(&state, kFakeWindow, kScaleShift); + + const struct complex_int16_t expected[] = { + {0, 0}, {-10, 9}, {-20, 0}, {-9, -10}, {0, 25}, {-119, 119}, + {-887, 0}, {3000, 3000}, {0, -6401}, {-3000, 3000}, {886, 0}, {118, 119}, + {0, 25}, {9, -10}, {19, 0}, {9, 9}, {0, 0}}; + TF_LITE_MICRO_EXPECT_EQ(state.fft_size / 2 + 1, + sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i <= state.fft_size / 2; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(state.output[i].real, expected[i].real, 2); + TF_LITE_MICRO_EXPECT_NEAR(state.output[i].imag, expected[i].imag, 2); + } +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..ab742893197e6fda9ec2266e85997b555e0a4fc0 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.h" + +#define FIXED_POINT 16 +#include "kiss_fft.h" +#include "tools/kiss_fftr.h" + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h" + +int FftPopulateState(tflite::ErrorReporter* error_reporter, + struct FftState* state, size_t input_size) { + state->input_size = input_size; + state->fft_size = 1; + while (state->fft_size < state->input_size) { + state->fft_size <<= 1; + } + + STATIC_ALLOC_ENSURE_ARRAY_SIZE(state->input, + (state->fft_size * sizeof(*state->input))); + + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + state->output, ((state->fft_size / 2 + 1) * sizeof(*state->output) * 2)); + + // Ask kissfft how much memory it wants. 
+ size_t scratch_size = 0; + kiss_fftr_cfg kfft_cfg = + kiss_fftr_alloc(state->fft_size, 0, nullptr, &scratch_size); + if (kfft_cfg != nullptr) { + error_reporter->Report("Kiss memory sizing failed."); + return 0; + } + STATIC_ALLOC_ENSURE_ARRAY_SIZE(state->scratch, scratch_size); + state->scratch_size = scratch_size; + // Let kissfft configure the scratch space we just allocated + kfft_cfg = kiss_fftr_alloc(state->fft_size, 0, state->scratch, &scratch_size); + if (reinterpret_cast<char*>(kfft_cfg) != state->scratch) { + error_reporter->Report("Kiss memory preallocation strategy failed."); + return 0; + } + return 1; +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.h new file mode 100644 index 0000000000000000000000000000000000000000..1dea097bc996e194cef7987431c67be3c976ed2b --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.h @@ -0,0 +1,26 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FFT_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FFT_UTIL_H_ + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" + +// Prepares an FFT for the given input size. +int FftPopulateState(tflite::ErrorReporter* error_reporter, + struct FftState* state, size_t input_size); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FFT_UTIL_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.cc new file mode 100644 index 0000000000000000000000000000000000000000..67f69dd67581ff6c15063b2467810f6c212ed1e5 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.cc @@ -0,0 +1,135 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.h" + +#include <cstring> + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h" + +void FilterbankConvertFftComplexToEnergy(struct FilterbankState* state, + struct complex_int16_t* fft_output, + int32_t* energy) { + const int end_index = state->end_index; + int i; + energy += state->start_index; + fft_output += state->start_index; + for (i = state->start_index; i < end_index; ++i) { + const int32_t real = fft_output->real; + const int32_t imag = fft_output->imag; + fft_output++; + const uint32_t mag_squared = (real * real) + (imag * imag); + *energy++ = mag_squared; + } +} + +void FilterbankAccumulateChannels(struct FilterbankState* state, + const int32_t* energy) { + uint64_t* work = state->work; + uint64_t weight_accumulator = 0; + uint64_t unweight_accumulator = 0; + + const int16_t* channel_frequency_starts = state->channel_frequency_starts; + const int16_t* channel_weight_starts = state->channel_weight_starts; + const int16_t* channel_widths = state->channel_widths; + + int num_channels_plus_1 = state->num_channels + 1; + int i; + for (i = 0; i < num_channels_plus_1; ++i) { + const int32_t* magnitudes = energy + *channel_frequency_starts++; + const int16_t* weights = state->weights + *channel_weight_starts; + const int16_t* unweights = state->unweights + *channel_weight_starts++; + const int width = *channel_widths++; + int j; + for (j = 0; j < width; ++j) { + weight_accumulator += *weights++ * (static_cast<uint64_t>(*magnitudes)); + unweight_accumulator += + *unweights++ * (static_cast<uint64_t>(*magnitudes)); + ++magnitudes; + } + *work++ = weight_accumulator; + weight_accumulator = unweight_accumulator; + unweight_accumulator = 0; + } +} + +static uint16_t Sqrt32(uint32_t num) { + if (num == 0) { + return 0; + } + uint32_t res = 0; + int max_bit_number = 32 - 
MostSignificantBit32(num); + max_bit_number |= 1; + uint32_t bit = 1U << (31 - max_bit_number); + int iterations = (31 - max_bit_number) / 2 + 1; + while (iterations--) { + if (num >= res + bit) { + num -= res + bit; + res = (res >> 1U) + bit; + } else { + res >>= 1U; + } + bit >>= 2U; + } + // Do rounding - if we have the bits. + if (num > res && res != 0xFFFF) { + ++res; + } + return res; +} + +static uint32_t Sqrt64(uint64_t num) { + // Take a shortcut and just use 32 bit operations if the upper word is all + // clear. This will cause a slight off by one issue for numbers close to 2^32, + // but it probably isn't going to matter (and gives us a big performance win). + if ((num >> 32) == 0) { + return Sqrt32(static_cast(num)); + } + uint64_t res = 0; + int max_bit_number = 64 - MostSignificantBit64(num); + max_bit_number |= 1; + uint64_t bit = 1ULL << (63 - max_bit_number); + int iterations = (63 - max_bit_number) / 2 + 1; + while (iterations--) { + if (num >= res + bit) { + num -= res + bit; + res = (res >> 1U) + bit; + } else { + res >>= 1U; + } + bit >>= 2U; + } + // Do rounding - if we have the bits. + if (num > res && res != 0xFFFFFFFFLL) { + ++res; + } + return res; +} + +uint32_t* FilterbankSqrt(struct FilterbankState* state, int scale_down_shift) { + const int num_channels = state->num_channels; + const int64_t* work = reinterpret_cast(state->work + 1); + // Reuse the work buffer since we're fine clobbering it at this point to hold + // the output. 
+ uint32_t* output = reinterpret_cast(state->work); + int i; + for (i = 0; i < num_channels; ++i) { + *output++ = Sqrt64(*work++) >> scale_down_shift; + } + return reinterpret_cast(state->work); +} + +void FilterbankReset(struct FilterbankState* state) { + memset(state->work, 0, (state->num_channels + 1) * sizeof(*state->work)); +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.h new file mode 100644 index 0000000000000000000000000000000000000000..f7b479d4899a985482710dfcb12b908f50ec1690 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FILTERBANK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FILTERBANK_H_ + +#include +#include + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" + +#define kFilterbankBits 12 + +struct FilterbankState { + int num_channels; + int start_index; + int end_index; + int16_t channel_frequency_starts[kFeatureSliceSize + 1]; + int16_t channel_weight_starts[kFeatureSliceSize + 1]; + int16_t channel_widths[kFeatureSliceSize + 1]; + int16_t weights[316]; + int16_t unweights[316]; + uint64_t work[kFeatureSliceSize + 1]; +}; + +// Converts the relevant complex values of an FFT output into energy (the +// square magnitude). +void FilterbankConvertFftComplexToEnergy(struct FilterbankState* state, + struct complex_int16_t* fft_output, + int32_t* energy); + +// Computes the mel-scale filterbank on the given energy array. Output is cached +// internally - to fetch it, you need to call FilterbankSqrt. +void FilterbankAccumulateChannels(struct FilterbankState* state, + const int32_t* energy); + +// Applies an integer square root to the 64 bit intermediate values of the +// filterbank, and returns a pointer to them. Memory will be invalidated the +// next time FilterbankAccumulateChannels is called. 
+uint32_t* FilterbankSqrt(struct FilterbankState* state, int scale_down_shift); + +void FilterbankReset(struct FilterbankState* state); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FILTERBANK_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..682b216ace37fa0a809db3c06386b5b4b3ca94c8 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_test.cc @@ -0,0 +1,228 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.h" + +#include + +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +namespace { + +const int kSampleRate = 1000; +const int kSpectrumSize = 17; +const int kStartIndex = 1; +const int kEndIndex = 15; +const int32_t kEnergy[] = {-1, 181, 400, 181, 625, 28322, + 786769, 18000000, 40972801, 18000000, 784996, 28085, + 625, 181, 361, -1, -1}; +const uint64_t kWork[] = {1835887, 61162970173, 258694800000}; +const int kScaleShift = 0; + +// Test filterbank generation using scaled-down defaults. +class FilterbankTestConfig { + public: + FilterbankTestConfig() { + config_.num_channels = 2; + config_.lower_band_limit = 8.0; + config_.upper_band_limit = 450.0; + } + + struct FilterbankConfig config_; +}; + +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(FilterbankTest_CheckStartIndex) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FilterbankTestConfig config; + struct FilterbankState state; + TF_LITE_MICRO_EXPECT(FilterbankPopulateState( + error_reporter, &config.config_, &state, kSampleRate, kSpectrumSize)); + + TF_LITE_MICRO_EXPECT_EQ(state.start_index, kStartIndex); +} + +TF_LITE_MICRO_TEST(FilterbankTest_CheckEndIndex) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FilterbankTestConfig config; + struct FilterbankState state; + TF_LITE_MICRO_EXPECT(FilterbankPopulateState( + error_reporter, &config.config_, &state, kSampleRate, kSpectrumSize)); + + TF_LITE_MICRO_EXPECT_EQ(state.end_index, kEndIndex); +} + +TF_LITE_MICRO_TEST(FilterbankTest_CheckChannelFrequencyStarts) { + tflite::MicroErrorReporter micro_error_reporter; +
tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FilterbankTestConfig config; + struct FilterbankState state; + TF_LITE_MICRO_EXPECT(FilterbankPopulateState( + error_reporter, &config.config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {0, 4, 8}; + TF_LITE_MICRO_EXPECT_EQ(state.num_channels + 1, + sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i <= state.num_channels; ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.channel_frequency_starts[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(FilterbankTest_CheckChannelWeightStarts) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FilterbankTestConfig config; + struct FilterbankState state; + TF_LITE_MICRO_EXPECT(FilterbankPopulateState( + error_reporter, &config.config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {0, 8, 16}; + TF_LITE_MICRO_EXPECT_EQ(state.num_channels + 1, + sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i <= state.num_channels; ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.channel_weight_starts[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(FilterbankTest_CheckChannelWidths) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FilterbankTestConfig config; + struct FilterbankState state; + TF_LITE_MICRO_EXPECT(FilterbankPopulateState( + error_reporter, &config.config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {8, 8, 8}; + TF_LITE_MICRO_EXPECT_EQ(state.num_channels + 1, + sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i <= state.num_channels; ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.channel_widths[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(FilterbankTest_CheckWeights) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FilterbankTestConfig config; + struct FilterbankState state; +
TF_LITE_MICRO_EXPECT(FilterbankPopulateState( + error_reporter, &config.config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {0, 3277, 2217, 1200, 222, 0, 0, 0, + 0, 3376, 2468, 1591, 744, 0, 0, 0, + 0, 4020, 3226, 2456, 1708, 983, 277, 0}; + TF_LITE_MICRO_EXPECT_EQ(state.channel_weight_starts[state.num_channels] + + state.channel_widths[state.num_channels], + sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.weights[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(FilterbankTest_CheckUnweights) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FilterbankTestConfig config; + struct FilterbankState state; + TF_LITE_MICRO_EXPECT(FilterbankPopulateState( + error_reporter, &config.config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {0, 819, 1879, 2896, 3874, 0, 0, 0, + 0, 720, 1628, 2505, 3352, 0, 0, 0, + 0, 76, 870, 1640, 2388, 3113, 3819, 0}; + TF_LITE_MICRO_EXPECT_EQ(state.channel_weight_starts[state.num_channels] + + state.channel_widths[state.num_channels], + sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.unweights[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(FilterbankTest_CheckConvertFftComplexToEnergy) { + struct FilterbankState state; + state.start_index = kStartIndex; + state.end_index = kEndIndex; + + struct complex_int16_t fake_fft[] = { + {0, 0}, {-10, 9}, {-20, 0}, {-9, -10}, {0, 25}, {-119, 119}, + {-887, 0}, {3000, 3000}, {0, -6401}, {-3000, 3000}, {886, 0}, {118, 119}, + {0, 25}, {9, -10}, {19, 0}, {9, 9}, {0, 0}}; + int32_t* energy = reinterpret_cast(fake_fft); + FilterbankConvertFftComplexToEnergy(&state, fake_fft, energy); + + int i; + for (i = state.start_index; i < state.end_index; ++i) { + TF_LITE_MICRO_EXPECT_EQ(energy[i], kEnergy[i]);
+ } +} + +TF_LITE_MICRO_TEST(FilterbankTest_CheckAccumulateChannels) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FilterbankTestConfig config; + struct FilterbankState state; + TF_LITE_MICRO_EXPECT(FilterbankPopulateState( + error_reporter, &config.config_, &state, kSampleRate, kSpectrumSize)); + + FilterbankAccumulateChannels(&state, kEnergy); + + TF_LITE_MICRO_EXPECT_EQ(state.num_channels + 1, + sizeof(kWork) / sizeof(kWork[0])); + int i; + for (i = 0; i <= state.num_channels; ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.work[i], kWork[i]); + } +} + +TF_LITE_MICRO_TEST(FilterbankTest_CheckSqrt) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FilterbankTestConfig config; + struct FilterbankState state; + TF_LITE_MICRO_EXPECT(FilterbankPopulateState( + error_reporter, &config.config_, &state, kSampleRate, kSpectrumSize)); + std::memcpy(state.work, kWork, sizeof(kWork)); + + uint32_t* scaled_filterbank = FilterbankSqrt(&state, kScaleShift); + + const uint32_t expected[] = {247311, 508620}; + TF_LITE_MICRO_EXPECT_EQ(state.num_channels, + sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < state.num_channels; ++i) { + TF_LITE_MICRO_EXPECT_EQ(scaled_filterbank[i], expected[i]); + } +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea4aa5179bc6e1d875127152b596f220a076191e --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.cc @@ -0,0 +1,212 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.h" + +#include +#include + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h" + +#define kFilterbankIndexAlignment 4 +#define kFilterbankChannelBlockSize 4 + +void FilterbankFillConfigWithDefaults(struct FilterbankConfig* config) { + config->num_channels = 32; + config->lower_band_limit = 125.0f; + config->upper_band_limit = 7500.0f; + config->output_scale_shift = 7; +} + +static float FreqToMel(float freq) { + return 1127.0 * log(1.0 + (freq / 700.0)); +} + +static void CalculateCenterFrequencies(const int num_channels, + const float lower_frequency_limit, + const float upper_frequency_limit, + float* center_frequencies) { + assert(lower_frequency_limit >= 0.0f); + assert(upper_frequency_limit > lower_frequency_limit); + + const float mel_low = FreqToMel(lower_frequency_limit); + const float mel_hi = FreqToMel(upper_frequency_limit); + const float mel_span = mel_hi - mel_low; + const float mel_spacing = mel_span / (static_cast(num_channels)); + int i; + for (i = 0; i < num_channels; ++i) { + center_frequencies[i] = mel_low + (mel_spacing * (i + 1)); + } +} + +static void QuantizeFilterbankWeights(const float float_weight, int16_t* weight, + int16_t* unweight) { + *weight = floor(float_weight * (1 << kFilterbankBits) 
+ 0.5); + *unweight = floor((1.0 - float_weight) * (1 << kFilterbankBits) + 0.5); +} + +int FilterbankPopulateState(tflite::ErrorReporter* error_reporter, + const struct FilterbankConfig* config, + struct FilterbankState* state, int sample_rate, + int spectrum_size) { + state->num_channels = config->num_channels; + const int num_channels_plus_1 = config->num_channels + 1; + + // How should we align things to index counts given the byte alignment? + const int index_alignment = + (kFilterbankIndexAlignment < sizeof(int16_t) + ? 1 + : kFilterbankIndexAlignment / sizeof(int16_t)); + + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + state->channel_frequency_starts, + (num_channels_plus_1 * sizeof(*state->channel_frequency_starts))); + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + state->channel_weight_starts, + (num_channels_plus_1 * sizeof(*state->channel_weight_starts))); + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + state->channel_widths, + (num_channels_plus_1 * sizeof(*state->channel_widths))); + STATIC_ALLOC_ENSURE_ARRAY_SIZE(state->work, + (num_channels_plus_1 * sizeof(*state->work))); + + float center_mel_freqs[kFeatureSliceSize + 1]; + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + center_mel_freqs, (num_channels_plus_1 * sizeof(*center_mel_freqs))); + + int16_t actual_channel_starts[kFeatureSliceSize + 1]; + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + actual_channel_starts, + (num_channels_plus_1 * sizeof(*actual_channel_starts))); + + int16_t actual_channel_widths[kFeatureSliceSize + 1]; + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + actual_channel_widths, + (num_channels_plus_1 * sizeof(*actual_channel_widths))); + + CalculateCenterFrequencies(num_channels_plus_1, config->lower_band_limit, + config->upper_band_limit, center_mel_freqs); + + // Always exclude DC. + const float hz_per_sbin = + 0.5 * sample_rate / (static_cast(spectrum_size) - 1); + state->start_index = 1.5 + config->lower_band_limit / hz_per_sbin; + state->end_index = 0; // Initialized to zero here, but actually set below. 
+ + // For each channel, we need to figure out what frequencies belong to it, and + // how much padding we need to add so that we can efficiently multiply the + // weights and unweights for accumulation. To simplify the multiplication + // logic, all channels will have some multiplication to do (even if there are + // no frequencies that accumulate to that channel) - they will be directed to + // a set of zero weights. + int chan_freq_index_start = state->start_index; + int weight_index_start = 0; + int needs_zeros = 0; + + int chan; + for (chan = 0; chan < num_channels_plus_1; ++chan) { + // Keep jumping frequencies until we overshoot the bound on this channel. + int freq_index = chan_freq_index_start; + while (FreqToMel((freq_index)*hz_per_sbin) <= center_mel_freqs[chan]) { + ++freq_index; + } + + const int width = freq_index - chan_freq_index_start; + actual_channel_starts[chan] = chan_freq_index_start; + actual_channel_widths[chan] = width; + + if (width == 0) { + // This channel doesn't actually get anything from the frequencies, it's + // always zero. We need then to insert some 'zero' weights into the + // output, and just redirect this channel to do a single multiplication at + // this point. For simplicity, the zeros are placed at the beginning of + // the weights arrays, so we have to go and update all the other + // weight_starts to reflect this shift (but only once). + state->channel_frequency_starts[chan] = 0; + state->channel_weight_starts[chan] = 0; + state->channel_widths[chan] = kFilterbankChannelBlockSize; + if (!needs_zeros) { + needs_zeros = 1; + int j; + for (j = 0; j < chan; ++j) { + state->channel_weight_starts[j] += kFilterbankChannelBlockSize; + } + weight_index_start += kFilterbankChannelBlockSize; + } + } else { + // How far back do we need to go to ensure that we have the proper + // alignment? 
+ const int aligned_start = + (chan_freq_index_start / index_alignment) * index_alignment; + const int aligned_width = (chan_freq_index_start - aligned_start + width); + const int padded_width = + (((aligned_width - 1) / kFilterbankChannelBlockSize) + 1) * + kFilterbankChannelBlockSize; + + state->channel_frequency_starts[chan] = aligned_start; + state->channel_weight_starts[chan] = weight_index_start; + state->channel_widths[chan] = padded_width; + weight_index_start += padded_width; + } + chan_freq_index_start = freq_index; + } + + // Allocate the two arrays to store the weights - weight_index_start contains + // the index of what would be the next set of weights that we would need to + // add, so that's how many weights we need to allocate. + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + state->weights, (weight_index_start * sizeof(*state->weights))); + for (int i = 0; i < weight_index_start; ++i) { + state->weights[i] = 0; + } + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + state->unweights, (weight_index_start * sizeof(*state->unweights))); + for (int i = 0; i < weight_index_start; ++i) { + state->unweights[i] = 0; + } + + // Next pass, compute all the weights. Since everything has been memset to + // zero, we only need to fill in the weights that correspond to some frequency + // for a channel. + const float mel_low = FreqToMel(config->lower_band_limit); + for (chan = 0; chan < num_channels_plus_1; ++chan) { + int frequency = actual_channel_starts[chan]; + const int num_frequencies = actual_channel_widths[chan]; + const int frequency_offset = + frequency - state->channel_frequency_starts[chan]; + const int weight_start = state->channel_weight_starts[chan]; + const float denom_val = (chan == 0) ? 
mel_low : center_mel_freqs[chan - 1]; + + int j; + for (j = 0; j < num_frequencies; ++j, ++frequency) { + const float weight = + (center_mel_freqs[chan] - FreqToMel(frequency * hz_per_sbin)) / + (center_mel_freqs[chan] - denom_val); + + // Make the float into an integer for the weights (and unweights). + const int weight_index = weight_start + frequency_offset + j; + QuantizeFilterbankWeights(weight, state->weights + weight_index, + state->unweights + weight_index); + } + if (frequency > state->end_index) { + state->end_index = frequency; + } + } + + if (state->end_index >= spectrum_size) { + error_reporter->Report("Filterbank end_index is above spectrum size."); + return 0; + } + return 1; +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0bf0c8494ce11753c2f2c2185e6c1141d0adbc74 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FILTERBANK_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FILTERBANK_UTIL_H_ + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" + +struct FilterbankConfig { + // number of frequency channel buckets for filterbank + int num_channels; + // maximum frequency to include + float upper_band_limit; + // minimum frequency to include + float lower_band_limit; + // unused + int output_scale_shift; +}; + +// Fills the FilterbankConfig with "sane" defaults. +void FilterbankFillConfigWithDefaults(struct FilterbankConfig* config); + +// Allocates any buffers. +int FilterbankPopulateState(tflite::ErrorReporter* error_reporter, + const struct FilterbankConfig* config, + struct FilterbankState* state, int sample_rate, + int spectrum_size); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FILTERBANK_UTIL_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.cc new file mode 100644 index 0000000000000000000000000000000000000000..c609190e4db90adb3ec79e2a794b61923cea0978 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.h" + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h" + +struct FrontendOutput FrontendProcessSamples(struct FrontendState* state, + const int16_t* samples, + size_t num_samples, + size_t* num_samples_read) { + struct FrontendOutput output; + output.values = nullptr; + output.size = 0; + + // Try to apply the window - if it fails, return and wait for more data. + if (!WindowProcessSamples(&state->window, samples, num_samples, + num_samples_read)) { + return output; + } + + // Apply the FFT to the window's output (and scale it so that the fixed point + // FFT can have as much resolution as possible). + int input_shift = + 15 - MostSignificantBit32(state->window.max_abs_output_value); + FftCompute(&state->fft, state->window.output, input_shift); + + // We can reuse the fft's output buffer to hold the energy. + int32_t* energy = reinterpret_cast(state->fft.output); + FilterbankConvertFftComplexToEnergy(&state->filterbank, state->fft.output, + energy); + FilterbankAccumulateChannels(&state->filterbank, energy); + uint32_t* scaled_filterbank = FilterbankSqrt(&state->filterbank, input_shift); + + // Apply noise reduction. + NoiseReductionApply(&state->noise_reduction, scaled_filterbank); + + if (state->pcan_gain_control.enable_pcan) { + PcanGainControlApply(&state->pcan_gain_control, scaled_filterbank); + } + + // Apply the log and scale.
+ int correction_bits = + MostSignificantBit32(state->fft.fft_size) - 1 - (kFilterbankBits / 2); + uint16_t* logged_filterbank = + LogScaleApply(&state->log_scale, scaled_filterbank, + state->filterbank.num_channels, correction_bits); + + output.size = state->filterbank.num_channels; + output.values = logged_filterbank; + return output; +} + +void FrontendReset(struct FrontendState* state) { + WindowReset(&state->window); + FftReset(&state->fft); + FilterbankReset(&state->filterbank); + NoiseReductionReset(&state->noise_reduction); +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.h new file mode 100644 index 0000000000000000000000000000000000000000..3221d283e8740cd95fa39ece4ad9533d059018d7 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FRONTEND_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FRONTEND_H_ + +#include +#include + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.h" + +struct FrontendState { + struct WindowState window; + struct FftState fft; + struct FilterbankState filterbank; + struct NoiseReductionState noise_reduction; + struct PcanGainControlState pcan_gain_control; + struct LogScaleState log_scale; +}; + +struct FrontendOutput { + const uint16_t* values; + size_t size; +}; + +// Main entry point to processing frontend samples. Updates num_samples_read to +// contain the number of samples that have been consumed from the input array. +// Returns a struct containing the generated output. If not enough samples were +// added to generate a feature vector, the returned size will be 0 and the +// values pointer will be NULL. Note that the output pointer will be invalidated +// as soon as FrontendProcessSamples is called again, so copy the contents +// elsewhere if you need to use them later. 
+struct FrontendOutput FrontendProcessSamples(struct FrontendState* state, + const int16_t* samples, + size_t num_samples, + size_t* num_samples_read); + +void FrontendReset(struct FrontendState* state); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FRONTEND_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d9f86d48d86b1795fa45b9bf24db4dd75fb0a20 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_test.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.h" + +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +namespace { + +const int kSampleRate = 1000; +const int kWindowSamples = 25; +const int kStepSamples = 10; +const int16_t kFakeAudioData[] = { + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768, + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768, + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768}; + +// Test end-to-end frontend behaviors. +class FrontendTestConfig { + public: + FrontendTestConfig() { + config_.window.size_ms = 25; + config_.window.step_size_ms = 10; + config_.noise_reduction.smoothing_bits = 10; + config_.filterbank.num_channels = 2; + config_.filterbank.lower_band_limit = 8.0; + config_.filterbank.upper_band_limit = 450.0; + config_.noise_reduction.smoothing_bits = 10; + config_.noise_reduction.even_smoothing = 0.025; + config_.noise_reduction.odd_smoothing = 0.06; + config_.noise_reduction.min_signal_remaining = 0.05; + config_.pcan_gain_control.enable_pcan = true; + config_.pcan_gain_control.strength = 0.95; + config_.pcan_gain_control.offset = 80.0; + config_.pcan_gain_control.gain_bits = 21; + config_.log_scale.enable_log = true; + config_.log_scale.scale_shift = 6; + } + + struct FrontendConfig config_; +}; + +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(FrontendTest_CheckOutputValues) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + FrontendTestConfig config; + struct FrontendState state; + TF_LITE_MICRO_EXPECT(FrontendPopulateState(error_reporter, &config.config_, + &state, kSampleRate)); + size_t num_samples_read; + + struct FrontendOutput output =
FrontendProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read); + + const uint16_t expected[] = {479, 425}; + TF_LITE_MICRO_EXPECT_EQ(output.size, sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < output.size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(output.values[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(FrontendTest_CheckConsecutiveWindow) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + FrontendTestConfig config; + struct FrontendState state; + TF_LITE_MICRO_EXPECT(FrontendPopulateState(error_reporter, &config.config_, + &state, kSampleRate)); + size_t num_samples_read; + + FrontendProcessSamples(&state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), + &num_samples_read); + struct FrontendOutput output = FrontendProcessSamples( + &state, kFakeAudioData + kWindowSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, + &num_samples_read); + + const int16_t expected[] = {436, 378}; + TF_LITE_MICRO_EXPECT_EQ(output.size, sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < output.size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(output.values[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(FrontendTest_CheckNotEnoughSamples) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + FrontendTestConfig config; + struct FrontendState state; + TF_LITE_MICRO_EXPECT(FrontendPopulateState(error_reporter, &config.config_, + &state, kSampleRate)); + size_t num_samples_read; + + FrontendProcessSamples(&state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), + &num_samples_read); + FrontendProcessSamples( + &state, kFakeAudioData + kWindowSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, + &num_samples_read); + struct FrontendOutput output = FrontendProcessSamples( + &state, 
kFakeAudioData + kWindowSamples + kStepSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples - + kStepSamples, + &num_samples_read); + + TF_LITE_MICRO_EXPECT_EQ(output.size, 0); + TF_LITE_MICRO_EXPECT_EQ(output.values, nullptr); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..220bc130fb9332e4afbe02a4432b61c8a4bcd544 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.h"
+
+#include <string.h>
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h"
+
+void FrontendFillConfigWithDefaults(struct FrontendConfig* config) {
+  WindowFillConfigWithDefaults(&config->window);
+  FilterbankFillConfigWithDefaults(&config->filterbank);
+  NoiseReductionFillConfigWithDefaults(&config->noise_reduction);
+  PcanGainControlFillConfigWithDefaults(&config->pcan_gain_control);
+  LogScaleFillConfigWithDefaults(&config->log_scale);
+}
+
+int FrontendPopulateState(tflite::ErrorReporter* error_reporter,
+                          const struct FrontendConfig* config,
+                          struct FrontendState* state, int sample_rate) {
+  memset(state, 0, sizeof(*state));
+
+  if (!WindowPopulateState(error_reporter, &config->window, &state->window,
+                           sample_rate)) {
+    error_reporter->Report("Failed to populate window state");
+    return 0;
+  }
+
+  if (!FftPopulateState(error_reporter, &state->fft, state->window.size)) {
+    error_reporter->Report("Failed to populate fft state");
+    return 0;
+  }
+  FftInit(&state->fft);
+
+  if (!FilterbankPopulateState(error_reporter, &config->filterbank,
+                               &state->filterbank, sample_rate,
+                               state->fft.fft_size / 2 + 1)) {
+    error_reporter->Report("Failed to populate filterbank state");
+    return 0;
+  }
+
+  if (!NoiseReductionPopulateState(error_reporter, &config->noise_reduction,
+                                   &state->noise_reduction,
+                                   state->filterbank.num_channels)) {
+    error_reporter->Report("Failed to populate noise reduction state");
+    return 0;
+  }
+
+  int input_correction_bits =
+      MostSignificantBit32(state->fft.fft_size) - 1 - (kFilterbankBits / 2);
+  if (!PcanGainControlPopulateState(
+          error_reporter, &config->pcan_gain_control, &state->pcan_gain_control,
+          state->noise_reduction.estimate, state->filterbank.num_channels,
+          state->noise_reduction.smoothing_bits, input_correction_bits)) {
+    error_reporter->Report("Failed to populate pcan gain control state");
+    return 0;
+  }
+
+  if (!LogScalePopulateState(error_reporter, &config->log_scale,
+                             &state->log_scale)) {
+    error_reporter->Report("Failed to populate log scale state");
+    return 0;
+  }
+
+  FrontendReset(state);
+
+  // All good, return a true value.
+  return 1;
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..f7267644ae053e1f816cb22bb5e0ecd04e4de0ef
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FRONTEND_UTIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FRONTEND_UTIL_H_
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/fft_util.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/filterbank_util.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+
+struct FrontendConfig {
+  struct WindowConfig window;
+  struct FilterbankConfig filterbank;
+  struct NoiseReductionConfig noise_reduction;
+  struct PcanGainControlConfig pcan_gain_control;
+  struct LogScaleConfig log_scale;
+};
+
+// Fills every sub-config of the FrontendConfig with "sane" defaults.
+void FrontendFillConfigWithDefaults(struct FrontendConfig* config);
+
+// Prepares any buffers. Returns 1 on success and 0 if any stage fails.
+int FrontendPopulateState(tflite::ErrorReporter* error_reporter, + const struct FrontendConfig* config, + struct FrontendState* state, int sample_rate); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_FRONTEND_UTIL_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.cc new file mode 100644 index 0000000000000000000000000000000000000000..c651caad8c67773f20a485eea5519f286b1b0253 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.cc @@ -0,0 +1,30 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.h"
+const uint16_t kLogLut[]  // kLogSegments + 1 entries, plus one pad entry.
+#ifndef _MSC_VER
+    __attribute__((aligned(4)))
+#endif  // _MSC_VER
+    = {0,    224,  442,  654,  861,  1063, 1259, 1450, 1636, 1817, 1992, 2163,
+       2329, 2490, 2646, 2797, 2944, 3087, 3224, 3358, 3487, 3611, 3732, 3848,
+       3960, 4068, 4172, 4272, 4368, 4460, 4549, 4633, 4714, 4791, 4864, 4934,
+       5001, 5063, 5123, 5178, 5231, 5280, 5326, 5368, 5408, 5444, 5477, 5507,
+       5533, 5557, 5578, 5595, 5610, 5622, 5631, 5637, 5640, 5641, 5638, 5633,
+       5626, 5615, 5602, 5586, 5568, 5547, 5524, 5498, 5470, 5439, 5406, 5370,
+       5332, 5291, 5249, 5203, 5156, 5106, 5054, 5000, 4944, 4885, 4825, 4762,
+       4697, 4630, 4561, 4490, 4416, 4341, 4264, 4184, 4103, 4020, 3935, 3848,
+       3759, 3668, 3575, 3481, 3384, 3286, 3186, 3084, 2981, 2875, 2768, 2659,
+       2549, 2437, 2323, 2207, 2090, 1971, 1851, 1729, 1605, 1480, 1353, 1224,
+       1094, 963,  830,  695,  559,  421,  282,  142,  0,    0};
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.h
new file mode 100644
index 0000000000000000000000000000000000000000..d5ed9339bd02e23cc134992badce5cdb72a74771
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_LOG_LUT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_LOG_LUT_H_
+
+#include <stdint.h>
+
+// Number of segments in the log lookup table. The table will be kLogSegments+1
+// in length (with some padding).
+#define kLogSegments 128
+#define kLogSegmentsLog2 7
+
+// Scale used by lookup table.
+#define kLogScale 65536
+#define kLogScaleLog2 16
+#define kLogCoeff 45426  // ln(2) * kLogScale, rounded: converts log2 to ln.
+
+extern const uint16_t kLogLut[];
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_LOG_LUT_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f85e9c1a2f5a0056deaffad7b99b774c772ce562
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.h"
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_lut.h"
+
+#define kuint16max 0x0000FFFF
+
+// The following functions implement integer logarithms of various sizes. The
+// approximation is calculated according to method described in
+//       www.inti.gob.ar/electronicaeinformatica/instrumentacion/utic/
+//       publicaciones/SPL2007/Log10-spl07.pdf
+// It first calculates log2 of the input and then converts it to natural
+// logarithm.
+
+static uint32_t Log2FractionPart(const uint32_t x, const uint32_t log2x) {
+  // Part 1
+  int32_t frac = x - (1LL << log2x);
+  if (log2x < kLogScaleLog2) {
+    frac <<= kLogScaleLog2 - log2x;
+  } else {
+    frac >>= log2x - kLogScaleLog2;
+  }
+  // Part 2
+  const uint32_t base_seg = frac >> (kLogScaleLog2 - kLogSegmentsLog2);
+  const uint32_t seg_unit =
+      ((static_cast<uint32_t>(1)) << kLogScaleLog2) >> kLogSegmentsLog2;
+
+  const int32_t c0 = kLogLut[base_seg];
+  const int32_t c1 = kLogLut[base_seg + 1];
+  const int32_t seg_base = seg_unit * base_seg;
+  const int32_t rel_pos = ((c1 - c0) * (frac - seg_base)) >> kLogScaleLog2;
+  return frac + c0 + rel_pos;
+}
+
+static uint32_t Log(const uint32_t x, const uint32_t scale_shift) {
+  const uint32_t integer = MostSignificantBit32(x) - 1;
+  const uint32_t fraction = Log2FractionPart(x, integer);
+  const uint32_t log2 = (integer << kLogScaleLog2) + fraction;
+  const uint32_t round = kLogScale / 2;
+  const uint32_t loge =
+      ((static_cast<uint64_t>(kLogCoeff)) * log2 + round) >> kLogScaleLog2;
+  // Finally scale to our output scale
+  const uint32_t loge_scaled = ((loge << scale_shift) + round) >> kLogScaleLog2;
+  return loge_scaled;
+}
+
+uint16_t* LogScaleApply(struct LogScaleState* state, uint32_t* signal,
+                        int signal_size, int correction_bits) {
+  const int scale_shift = state->scale_shift;
+  uint16_t* output = reinterpret_cast<uint16_t*>(signal);
+  uint16_t* ret = output;
+  int i;
+  for (i = 0; i < signal_size; ++i) {
+    uint32_t value = *signal++;
+    if (state->enable_log) {
+      if (correction_bits < 0) {
+        value >>= -correction_bits;
+      } else {
+        value <<= correction_bits;
+      }
+      if (value > 1) {
+        value = Log(value, scale_shift);
+      } else {
+        value = 0;
+      }
+    }
+    *output++ = (value < kuint16max) ? value : kuint16max;
+  }
+  return ret;
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.h
new file mode 100644
index 0000000000000000000000000000000000000000..d90b87fb6d6fe181158b209a87a42f6d075ba457
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.h
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_LOG_SCALE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_LOG_SCALE_H_
+
+#include <stdint.h>
+#include <stdlib.h>
+
+struct LogScaleState {
+  int enable_log;
+  int scale_shift;
+};
+
+// Applies a fixed point logarithm to the signal and converts it to 16 bit.
Note +// that the signal array will be modified. +uint16_t* LogScaleApply(struct LogScaleState* state, uint32_t* signal, + int signal_size, int correction_bits); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_LOG_SCALE_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d5b23323a273314a347f25e691d538781558980a --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.h"
+
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+
+namespace {
+
+const int kScaleShift = 6;
+const int kCorrectionBits = -1;
+
+}  // namespace
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(LogScaleTest_CheckOutputValues) {
+  struct LogScaleState state;
+  state.enable_log = true;
+  state.scale_shift = kScaleShift;
+
+  uint32_t fake_signal[] = {3578, 1533};
+  uint16_t* output = LogScaleApply(&state, fake_signal,
+                                   sizeof(fake_signal) / sizeof(fake_signal[0]),
+                                   kCorrectionBits);
+
+  const uint16_t expected[] = {479, 425};
+  size_t i;  // Unsigned, to match the sizeof-based bound below.
+  for (i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(output[i], expected[i]);
+  }
+}
+
+TF_LITE_MICRO_TEST(LogScaleTest_CheckOutputValuesNoLog) {
+  struct LogScaleState state;
+  state.enable_log = false;
+  state.scale_shift = kScaleShift;
+
+  uint32_t fake_signal[] = {85964, 45998};
+  uint16_t* output = LogScaleApply(&state, fake_signal,
+                                   sizeof(fake_signal) / sizeof(fake_signal[0]),
+                                   kCorrectionBits);
+
+  const uint16_t expected[] = {65535, 45998};
+  size_t i;  // Unsigned, to match the sizeof-based bound below.
+  for (i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(output[i], expected[i]);
+  }
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.cc
new file mode 100644
index 0000000000000000000000000000000000000000..09adc09c3511fdedeb7246b6717fa4bfb4c83ba1
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.cc
@@ -0,0 +1,28 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.h"
+
+void LogScaleFillConfigWithDefaults(struct LogScaleConfig* config) {
+  config->enable_log = 1;   // Log scaling is enabled by default.
+  config->scale_shift = 6;  // Default scale_shift is 6.
+}
+
+int LogScalePopulateState(tflite::ErrorReporter* error_reporter,
+                          const struct LogScaleConfig* config,
+                          struct LogScaleState* state) {
+  state->enable_log = config->enable_log;
+  state->scale_shift = config->scale_shift;
+  return 1;  // Plain field copy cannot fail; error_reporter is unused.
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..3caf207f2693756783b6c1dc64246d2522388d3b
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale_util.h
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_LOG_SCALE_UTIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_LOG_SCALE_UTIL_H_
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/log_scale.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+
+struct LogScaleConfig {
+  // set to false (0) to disable this module
+  int enable_log;
+  // scale results by 2^(scale_shift)
+  int scale_shift;
+};
+
+// Populates the LogScaleConfig with "sane" default values.
+void LogScaleFillConfigWithDefaults(struct LogScaleConfig* config);
+
+// Copies the config values into the state. Always returns 1 (success).
+int LogScalePopulateState(tflite::ErrorReporter* error_reporter,
+                          const struct LogScaleConfig* config,
+                          struct LogScaleState* state);
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_LOG_SCALE_UTIL_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6657c6f3205903c178b2aa4314551f5b4fee1101
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.cc
@@ -0,0 +1,99 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.h"
+
+#include <cmath>
+#include <cstring>
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/frontend_util.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h"
+
+namespace {
+
+FrontendState g_micro_features_state;
+bool g_is_first_time = true;
+
+}  // namespace
+
+TfLiteStatus InitializeMicroFeatures(tflite::ErrorReporter* error_reporter) {
+  FrontendConfig config;
+  config.window.size_ms = kFeatureSliceDurationMs;
+  config.window.step_size_ms = kFeatureSliceStrideMs;
+  config.noise_reduction.smoothing_bits = 10;
+  config.filterbank.num_channels = kFeatureSliceSize;
+  config.filterbank.lower_band_limit = 125.0;
+  config.filterbank.upper_band_limit = 7500.0;
+  config.noise_reduction.smoothing_bits = 10;
+  config.noise_reduction.even_smoothing = 0.025;
+  config.noise_reduction.odd_smoothing = 0.06;
+  config.noise_reduction.min_signal_remaining = 0.05;
+  config.pcan_gain_control.enable_pcan = 1;
+  config.pcan_gain_control.strength = 0.95;
+  config.pcan_gain_control.offset = 80.0;
+  config.pcan_gain_control.gain_bits = 21;
+  config.log_scale.enable_log = 1;
+  config.log_scale.scale_shift = 6;
+  if (!FrontendPopulateState(error_reporter, &config, &g_micro_features_state,
+                             kAudioSampleFrequency)) {
+    error_reporter->Report("FrontendPopulateState() failed");
+    return kTfLiteError;
+  }
+  g_is_first_time = true;
+  return kTfLiteOk;
+}
+
+// This is not exposed in any header, and is only used for testing, to ensure
+// that the state is correctly set up before generating results.
+void SetMicroFeaturesNoiseEstimates(const uint32_t* estimate_presets) {
+  for (int i = 0; i < g_micro_features_state.filterbank.num_channels; ++i) {
+    g_micro_features_state.noise_reduction.estimate[i] = estimate_presets[i];
+  }
+}
+
+TfLiteStatus GenerateMicroFeatures(tflite::ErrorReporter* error_reporter,
+                                   const int16_t* input, int input_size,
+                                   int output_size, uint8_t* output,
+                                   size_t* num_samples_read) {
+  const int16_t* frontend_input;
+  if (g_is_first_time) {
+    frontend_input = input;
+    g_is_first_time = false;
+  } else {
+    frontend_input = input + 160;
+  }
+  FrontendOutput frontend_output = FrontendProcessSamples(
+      &g_micro_features_state, frontend_input, input_size, num_samples_read);
+
+  for (int i = 0; i < frontend_output.size && i < output_size; ++i) {
+    // These scaling values are derived from those used in input_data.py in the
+    // training pipeline.
+    constexpr int32_t value_scale = (10 * 255);
+    constexpr int32_t value_div = (256 * 26);
+    int32_t value =
+        ((frontend_output.values[i] * value_scale) + (value_div / 2)) /
+        value_div;
+    if (value < 0) {
+      value = 0;
+    }
+    if (value > 255) {
+      value = 255;
+    }
+    output[i] = value;
+  }
+
+  return kTfLiteOk;
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.h
new file mode 100644
index 0000000000000000000000000000000000000000..46fa55d62ff7a8032cb94e512d4e856fb5960276
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.h
@@ -0,0 +1,32 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_FEATURES_GENERATOR_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_FEATURES_GENERATOR_H_
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+
+// Sets up any resources needed for the feature generation pipeline.
+TfLiteStatus InitializeMicroFeatures(tflite::ErrorReporter* error_reporter); + +// Converts audio sample data into a more compact form that's appropriate for +// feeding into a neural network. +TfLiteStatus GenerateMicroFeatures(tflite::ErrorReporter* error_reporter, + const int16_t* input, int input_size, + int output_size, uint8_t* output, + size_t* num_samples_read); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_FEATURES_GENERATOR_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c2da7a799ce398ba7faf31d577d79bb96b2072a --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_features_generator.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +// This is a test-only API, not exposed in any public headers, so declare it. +void SetMicroFeaturesNoiseEstimates(const uint32_t* estimate_presets); + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestMicroFeaturesGeneratorYes) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, InitializeMicroFeatures(error_reporter)); + + // The micro features pipeline retains state from previous calls to help + // estimate the background noise. Unfortunately this makes it harder to + // exactly reproduce results in a test environment, so use a known snapshot + // of the parameters at the point that the golden feature values were + // created. 
+ const uint32_t yes_estimate_presets[] = { + 1062898, 2644477, 1257642, 1864718, 412722, 725703, 395721, 474082, + 173046, 255856, 158966, 153736, 69181, 199100, 144493, 227740, + 110573, 164330, 79666, 144650, 122947, 476799, 398553, 497493, + 322152, 1140005, 566716, 690605, 308902, 347481, 109891, 170457, + 73901, 100975, 42963, 72325, 34183, 20207, 6640, 9468, + }; + SetMicroFeaturesNoiseEstimates(yes_estimate_presets); + + uint8_t yes_calculated_data[g_yes_feature_data_slice_size]; + size_t num_samples_read; + TfLiteStatus yes_status = GenerateMicroFeatures( + error_reporter, g_yes_30ms_sample_data, g_yes_30ms_sample_data_size, + g_yes_feature_data_slice_size, yes_calculated_data, &num_samples_read); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, yes_status); + + for (int i = 0; i < g_yes_feature_data_slice_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(g_yes_feature_data_slice[i], + yes_calculated_data[i]); + if (g_yes_feature_data_slice[i] != yes_calculated_data[i]) { + error_reporter->Report("Expected value %d but found %d", + g_yes_feature_data_slice[i], + yes_calculated_data[i]); + } + } +} + +TF_LITE_MICRO_TEST(TestMicroFeaturesGeneratorNo) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, InitializeMicroFeatures(error_reporter)); + // As we did for the previous features, set known good noise state + // parameters.
+ const uint32_t no_estimate_presets[] = { + 2563964, 1909393, 559801, 538670, 203643, 175959, 75088, 139491, + 59691, 95307, 43865, 129263, 52517, 80058, 51330, 100731, + 76674, 76262, 15497, 22598, 13778, 21460, 8946, 17806, + 10023, 18810, 8002, 10842, 7578, 9983, 6267, 10759, + 8946, 18488, 9691, 39785, 9939, 17835, 9671, 18512, + }; + SetMicroFeaturesNoiseEstimates(no_estimate_presets); + + uint8_t no_calculated_data[g_no_feature_data_slice_size]; + size_t num_samples_read; + TfLiteStatus no_status = GenerateMicroFeatures( + error_reporter, g_no_30ms_sample_data, g_no_30ms_sample_data_size, + g_no_feature_data_slice_size, no_calculated_data, &num_samples_read); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, no_status); + + for (int i = 0; i < g_no_feature_data_slice_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(g_no_feature_data_slice[i], no_calculated_data[i]); + if (g_no_feature_data_slice[i] != no_calculated_data[i]) { + error_reporter->Report("Expected value %d but found %d", + g_no_feature_data_slice[i], no_calculated_data[i]); + } + } +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.cc similarity index 95% rename from tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.cc index b9b8fb37b19d384fe92edf8ce2292aee19b99b7f..09f65ca24b3cd03485a5a79599dc0143ca83329c 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" const char* kCategoryLabels[kCategoryCount] = { "silence", diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h new file mode 100644 index 0000000000000000000000000000000000000000..b74a4d01ca49d37d62daf3710c878cfc6d9940f0 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_MODEL_SETTINGS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_MODEL_SETTINGS_H_ + +// Keeping these as constant expressions allows us to allocate fixed-sized arrays +// on the stack for our working memory. + +// The size of the input time series data we pass to the FFT to produce the +// frequency information.
This has to be a power of two, and since we're dealing +// with 30ms of 16KHz inputs, which means 480 samples, this is the next value. +constexpr int kMaxAudioSampleSize = 512; +constexpr int kAudioSampleFrequency = 16000; + +// All of these values are derived from the values used during model training, +// if you change your model you'll need to update these constants. +constexpr int kFeatureSliceSize = 40; +constexpr int kFeatureSliceCount = 49; +constexpr int kFeatureElementCount = (kFeatureSliceSize * kFeatureSliceCount); +constexpr int kFeatureSliceStrideMs = 20; +constexpr int kFeatureSliceDurationMs = 30; + +constexpr int kCategoryCount = 4; +constexpr int kSilenceIndex = 0; +constexpr int kUnknownIndex = 1; +extern const char* kCategoryLabels[kCategoryCount]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_MODEL_SETTINGS_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dbb606e184e70e0fa97d417bcbab6010b8a88a5 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.cc @@ -0,0 +1,24 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// See the header for documentation on the meaning of this data. + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.h" + +const uint8_t g_no_feature_data_slice[g_no_feature_data_slice_size] = { + 216, 195, 223, 211, 238, 223, 243, 215, 226, 204, 232, 211, 232, 213, + 240, 218, 235, 214, 238, 205, 207, 173, 149, 201, 215, 200, 230, 213, + 208, 195, 175, 151, 195, 175, 182, 163, 235, 217, 218, 190, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.h new file mode 100644 index 0000000000000000000000000000000000000000..72ea2bf6a23e83bff5dea771931e585d74c757ec --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_feature_data_slice.h @@ -0,0 +1,29 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This data was extracted from the larger feature data held in +// no_features_data.cc and consists of the 29th spectrogram slice of 40 values. +// This is the expected result of running the sample data in +// no_30ms_sample_data.cc through the preprocessing pipeline.
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NO_FEATURE_DATA_SLICE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NO_FEATURE_DATA_SLICE_H_ + +#include <cstdint> + +constexpr int g_no_feature_data_slice_size = 40; +extern const uint8_t g_no_feature_data_slice[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NO_FEATURE_DATA_SLICE_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..865209b01df7e8f77139bcd5b6a37537a6f674f4 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.cc @@ -0,0 +1,165 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h" + +/* File automatically created by + * tensorflow/examples/speech_commands/wav_to_features.py \ + * --sample_rate=16000 \ + * --clip_duration_ms=1000 \ + * --window_size_ms=30 \ + * --window_stride_ms=20 \ + * --feature_bin_count=40 \ + * --quantize=1 \ + * --preprocess="micro" \ + * --input_wav="speech_commands_test_set_v0.02/no/f9643d42_nohash_4.wav" \ + * --output_c_file="/tmp/no_micro_features_data.cc" \ + */ + +const int g_no_micro_f9643d42_nohash_4_width = 40; +const int g_no_micro_f9643d42_nohash_4_height = 49; +const unsigned char g_no_micro_f9643d42_nohash_4_data[] = { + 230, 205, 191, 203, 202, 181, 180, 194, 205, 187, 183, 197, 203, 198, 196, + 186, 202, 159, 151, 126, 110, 138, 141, 142, 137, 148, 133, 120, 110, 126, + 117, 110, 117, 116, 137, 134, 95, 116, 123, 110, 184, 144, 183, 189, 197, + 172, 188, 164, 194, 179, 175, 174, 182, 173, 184, 174, 200, 145, 154, 148, + 147, 135, 143, 122, 127, 138, 116, 99, 122, 105, 110, 125, 127, 133, 131, + 123, 116, 119, 127, 114, 193, 176, 185, 170, 175, 146, 166, 167, 185, 185, + 185, 183, 195, 185, 176, 178, 197, 155, 137, 144, 164, 132, 153, 132, 138, + 137, 134, 95, 120, 116, 131, 122, 99, 120, 120, 110, 116, 110, 126, 127, + 128, 159, 187, 119, 178, 187, 197, 167, 199, 184, 180, 165, 194, 176, 144, + 134, 187, 136, 142, 134, 145, 132, 145, 105, 119, 123, 125, 116, 125, 102, + 129, 138, 130, 99, 99, 90, 120, 123, 134, 95, 194, 172, 187, 123, 191, + 179, 195, 182, 201, 137, 167, 142, 185, 161, 187, 146, 167, 152, 154, 107, + 152, 112, 134, 144, 117, 116, 105, 85, 105, 105, 99, 90, 123, 112, 112, + 68, 107, 105, 117, 99, 116, 143, 139, 90, 154, 142, 188, 172, 178, 135, + 175, 149, 177, 110, 173, 160, 169, 162, 173, 119, 132, 110, 85, 85, 117, + 129, 117, 112, 117, 51, 112, 95, 139, 102, 105, 90, 128, 119, 112, 99, + 170, 
168, 195, 152, 174, 173, 180, 0, 157, 130, 169, 149, 149, 123, 170, + 130, 170, 133, 159, 102, 134, 90, 85, 105, 126, 119, 130, 90, 78, 68, + 127, 120, 95, 51, 122, 110, 112, 78, 116, 95, 180, 135, 179, 146, 179, + 162, 197, 153, 172, 135, 154, 0, 149, 95, 145, 114, 166, 0, 114, 110, + 145, 107, 114, 90, 136, 68, 95, 95, 95, 85, 116, 99, 116, 0, 95, + 68, 102, 51, 102, 78, 185, 157, 138, 158, 180, 117, 173, 142, 145, 117, + 169, 130, 159, 99, 138, 123, 169, 90, 78, 0, 123, 85, 107, 51, 114, + 102, 95, 0, 116, 85, 119, 95, 95, 68, 85, 51, 116, 68, 102, 78, + 167, 105, 164, 163, 178, 126, 164, 154, 154, 51, 177, 120, 156, 85, 134, + 139, 168, 90, 161, 102, 114, 116, 122, 95, 112, 102, 107, 51, 114, 85, + 119, 78, 114, 90, 102, 51, 102, 51, 114, 99, 177, 68, 152, 102, 184, + 166, 179, 129, 177, 129, 180, 110, 158, 105, 139, 0, 145, 85, 148, 102, + 117, 102, 116, 0, 78, 68, 90, 51, 107, 85, 78, 0, 51, 0, 51, + 0, 95, 51, 107, 68, 180, 117, 90, 0, 138, 0, 187, 146, 119, 140, + 164, 90, 136, 0, 131, 51, 159, 99, 141, 138, 116, 51, 90, 51, 90, + 68, 105, 0, 85, 78, 112, 51, 122, 95, 128, 68, 85, 0, 112, 68, + 147, 126, 178, 146, 171, 130, 190, 147, 188, 123, 170, 78, 132, 0, 130, + 125, 159, 95, 102, 0, 110, 0, 95, 85, 120, 68, 78, 51, 99, 51, + 105, 0, 112, 102, 105, 68, 90, 51, 90, 0, 127, 95, 166, 175, 187, + 133, 135, 0, 171, 139, 132, 128, 140, 51, 126, 107, 161, 0, 95, 51, + 119, 0, 114, 0, 95, 110, 116, 51, 112, 0, 90, 0, 116, 51, 68, + 0, 105, 68, 105, 0, 164, 78, 173, 0, 194, 166, 145, 114, 116, 51, + 107, 122, 151, 0, 156, 102, 148, 51, 122, 95, 129, 0, 85, 0, 127, + 78, 90, 0, 78, 0, 95, 0, 110, 0, 68, 119, 120, 68, 68, 0, + 122, 99, 147, 127, 200, 167, 85, 114, 161, 85, 161, 125, 143, 99, 156, + 85, 147, 68, 99, 0, 107, 102, 132, 51, 112, 68, 95, 78, 99, 0, + 68, 0, 51, 0, 90, 78, 128, 51, 95, 0, 166, 136, 174, 138, 189, + 144, 130, 129, 138, 134, 132, 120, 134, 0, 51, 78, 147, 51, 51, 0, + 51, 0, 78, 0, 68, 68, 95, 78, 90, 0, 0, 0, 68, 0, 90, + 68, 110, 0, 
95, 51, 165, 151, 157, 0, 0, 0, 112, 0, 112, 95, + 149, 107, 119, 68, 126, 68, 138, 0, 78, 0, 78, 0, 99, 51, 112, + 0, 102, 0, 78, 51, 85, 0, 0, 0, 78, 0, 95, 0, 95, 78, + 105, 0, 152, 0, 0, 51, 132, 105, 159, 0, 129, 102, 114, 0, 138, + 51, 123, 0, 129, 78, 119, 51, 51, 51, 105, 0, 78, 85, 95, 0, + 85, 0, 0, 0, 85, 0, 78, 0, 0, 0, 172, 142, 141, 0, 137, + 0, 148, 128, 157, 120, 146, 120, 120, 0, 95, 78, 141, 68, 68, 0, + 68, 0, 90, 0, 85, 0, 107, 0, 78, 0, 85, 51, 102, 0, 68, + 78, 68, 0, 51, 0, 125, 0, 141, 51, 102, 138, 175, 51, 120, 51, + 173, 85, 116, 141, 164, 68, 150, 123, 133, 51, 114, 0, 117, 68, 150, + 51, 116, 68, 78, 0, 68, 0, 68, 0, 85, 0, 78, 0, 51, 78, + 155, 90, 161, 0, 132, 99, 123, 78, 107, 0, 134, 90, 95, 0, 78, + 0, 162, 143, 85, 0, 107, 78, 125, 90, 90, 51, 51, 0, 85, 0, + 0, 0, 132, 102, 102, 154, 128, 0, 99, 68, 162, 102, 151, 0, 99, + 51, 147, 141, 156, 0, 112, 120, 158, 127, 145, 139, 187, 171, 135, 138, + 146, 0, 95, 68, 127, 0, 85, 0, 105, 0, 0, 0, 187, 170, 162, + 188, 165, 51, 51, 78, 243, 215, 225, 196, 205, 181, 205, 168, 176, 134, + 157, 110, 126, 114, 133, 139, 193, 163, 159, 116, 160, 126, 122, 127, 171, + 99, 114, 68, 123, 85, 90, 0, 157, 146, 166, 179, 136, 0, 116, 90, + 242, 219, 240, 204, 216, 164, 188, 171, 176, 164, 154, 158, 190, 157, 190, + 141, 182, 177, 169, 128, 172, 145, 105, 129, 157, 90, 78, 51, 119, 68, + 137, 68, 116, 78, 141, 132, 151, 122, 156, 140, 234, 206, 229, 201, 216, + 174, 191, 144, 162, 85, 122, 157, 194, 167, 204, 149, 180, 166, 166, 139, + 122, 133, 156, 126, 145, 85, 128, 0, 99, 51, 145, 0, 126, 51, 166, + 162, 166, 162, 177, 157, 228, 198, 221, 197, 214, 177, 173, 166, 173, 139, + 185, 191, 202, 163, 205, 172, 206, 189, 135, 68, 166, 134, 149, 134, 135, + 90, 127, 107, 175, 90, 136, 117, 135, 140, 172, 167, 166, 149, 177, 152, + 221, 191, 215, 194, 211, 0, 156, 147, 182, 178, 208, 163, 190, 157, 208, + 200, 195, 164, 179, 154, 181, 150, 143, 99, 132, 137, 185, 143, 163, 85, + 51, 107, 132, 134, 164, 
127, 167, 159, 175, 141, 216, 195, 223, 211, 238, + 223, 243, 215, 226, 204, 232, 211, 232, 213, 240, 218, 235, 214, 238, 205, + 207, 173, 149, 201, 215, 200, 230, 213, 208, 195, 175, 151, 195, 175, 182, + 163, 235, 217, 218, 190, 211, 191, 215, 191, 217, 220, 241, 215, 229, 206, + 236, 210, 227, 216, 236, 188, 183, 149, 202, 189, 208, 172, 191, 201, 220, + 193, 221, 207, 216, 208, 201, 131, 170, 187, 229, 197, 211, 194, 226, 201, + 205, 184, 206, 177, 221, 210, 226, 184, 204, 197, 218, 198, 212, 209, 213, + 141, 172, 110, 175, 167, 180, 156, 213, 188, 192, 179, 213, 205, 204, 174, + 200, 147, 162, 181, 203, 167, 198, 187, 210, 164, 196, 169, 189, 168, 224, + 198, 213, 204, 198, 195, 230, 211, 221, 197, 208, 0, 0, 0, 85, 90, + 167, 130, 175, 173, 203, 164, 193, 144, 170, 145, 185, 148, 154, 139, 198, + 159, 180, 171, 216, 174, 178, 161, 166, 136, 216, 184, 215, 197, 199, 190, + 228, 195, 208, 51, 117, 0, 0, 0, 0, 0, 140, 51, 135, 154, 188, + 155, 168, 0, 90, 0, 156, 85, 110, 0, 174, 90, 172, 154, 179, 99, + 142, 166, 179, 157, 177, 95, 192, 142, 204, 198, 217, 147, 173, 0, 112, + 0, 0, 0, 0, 0, 0, 0, 110, 0, 107, 0, 160, 0, 148, 95, + 172, 0, 0, 0, 116, 0, 122, 114, 170, 0, 0, 0, 0, 0, 179, + 110, 196, 85, 205, 183, 169, 0, 99, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 141, 0, 112, 0, 0, 0, 134, 0, 0, 0, 0, + 0, 0, 0, 139, 0, 0, 0, 0, 112, 186, 78, 163, 0, 169, 128, + 174, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 95, + 0, 105, 0, 0, 0, 105, 0, 0, 0, 0, 0, 0, 0, 95, 0, + 0, 0, 0, 0, 0, 0, 119, 0, 164, 78, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 90, 0, 0, 68, + 117, 0, 0, 0, 0, 0, 0, 0, 148, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, + 0, 0, 0, 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 78, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h new file mode 100644 index 0000000000000000000000000000000000000000..178323eeba6669d247edfe9cb675b37fe5c7d526 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NO_MICRO_FEATURES_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NO_MICRO_FEATURES_DATA_H_ + +extern const int g_no_micro_f9643d42_nohash_4_width; +extern const int g_no_micro_f9643d42_nohash_4_height; +extern const unsigned char g_no_micro_f9643d42_nohash_4_data[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NO_MICRO_FEATURES_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b3aa19cda5d1c8151de5e8bf5aad45df09259a0 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.cc @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.h" + +#include <string.h> + +void NoiseReductionApply(struct NoiseReductionState* state, uint32_t* signal) { + int i; + for (i = 0; i < state->num_channels; ++i) { + const uint32_t smoothing = + ((i & 1) == 0) ? state->even_smoothing : state->odd_smoothing; + const uint32_t one_minus_smoothing = (1 << kNoiseReductionBits) - smoothing; + + // Update the estimate of the noise. + const uint32_t signal_scaled_up = signal[i] << state->smoothing_bits; + uint32_t estimate = + ((static_cast<uint64_t>(signal_scaled_up) * smoothing) + + (static_cast<uint64_t>(state->estimate[i]) * one_minus_smoothing)) >> + kNoiseReductionBits; + state->estimate[i] = estimate; + + // Make sure that we can't get a negative value for the signal - estimate. + if (estimate > signal_scaled_up) { + estimate = signal_scaled_up; + } + + const uint32_t floor = + (static_cast<uint64_t>(signal[i]) * state->min_signal_remaining) >> + kNoiseReductionBits; + const uint32_t subtracted = + (signal_scaled_up - estimate) >> state->smoothing_bits; + const uint32_t output = subtracted > floor ? subtracted : floor; + signal[i] = output; + } +} + +void NoiseReductionReset(struct NoiseReductionState* state) { + memset(state->estimate, 0, sizeof(*state->estimate) * state->num_channels); +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..699144345d5751f27f7adcafec551180f82725d1 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NOISE_REDUCTION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NOISE_REDUCTION_H_ + +#define kNoiseReductionBits 14 + +#include <stdint.h> +#include <stdlib.h> + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" + +struct NoiseReductionState { + int smoothing_bits; + uint16_t even_smoothing; + uint16_t odd_smoothing; + uint16_t min_signal_remaining; + int num_channels; + uint32_t estimate[kFeatureSliceSize]; +}; + +// Removes stationary noise from each channel of the signal using a low pass +// filter.
+void NoiseReductionApply(struct NoiseReductionState* state, uint32_t* signal); + +void NoiseReductionReset(struct NoiseReductionState* state); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NOISE_REDUCTION_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..de7181d710bc00938e411869bf071b91e22f2044 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.h" + +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +namespace { + +const int kNumChannels = 2; + +// Test noise reduction using default config values. 
+class NoiseReductionTestConfig { + public: + NoiseReductionTestConfig() { + config_.smoothing_bits = 10; + config_.even_smoothing = 0.025; + config_.odd_smoothing = 0.06; + config_.min_signal_remaining = 0.05; + } + + struct NoiseReductionConfig config_; +}; + +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(NoiseReductionTest_TestNoiseReductionEstimate) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + NoiseReductionTestConfig config; + struct NoiseReductionState state; + TF_LITE_MICRO_EXPECT(NoiseReductionPopulateState( + error_reporter, &config.config_, &state, kNumChannels)); + + uint32_t signal[] = {247311, 508620}; + NoiseReductionApply(&state, signal); + + const uint32_t expected[] = {6321887, 31248341}; + TF_LITE_MICRO_EXPECT_EQ(state.num_channels, + sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < state.num_channels; ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.estimate[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(NoiseReductionTest_TestNoiseReduction) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + NoiseReductionTestConfig config; + struct NoiseReductionState state; + TF_LITE_MICRO_EXPECT(NoiseReductionPopulateState( + error_reporter, &config.config_, &state, kNumChannels)); + + uint32_t signal[] = {247311, 508620}; + NoiseReductionApply(&state, signal); + + const uint32_t expected[] = {241137, 478104}; + TF_LITE_MICRO_EXPECT_EQ(state.num_channels, + sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < state.num_channels; ++i) { + TF_LITE_MICRO_EXPECT_EQ(signal[i], expected[i]); + } +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.cc new file mode 100644 index
0000000000000000000000000000000000000000..42a5c2136f2a85b0ddd7e3a620bb879d13eeb258 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.cc @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.h" + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h" + +void NoiseReductionFillConfigWithDefaults(struct NoiseReductionConfig* config) { + config->smoothing_bits = 10; + config->even_smoothing = 0.025; + config->odd_smoothing = 0.06; + config->min_signal_remaining = 0.05; +} + +int NoiseReductionPopulateState(tflite::ErrorReporter* error_reporter, + const struct NoiseReductionConfig* config, + struct NoiseReductionState* state, + int num_channels) { + state->smoothing_bits = config->smoothing_bits; + state->odd_smoothing = config->odd_smoothing * (1 << kNoiseReductionBits); + state->even_smoothing = config->even_smoothing * (1 << kNoiseReductionBits); + state->min_signal_remaining = + config->min_signal_remaining * (1 << kNoiseReductionBits); + state->num_channels = num_channels; + STATIC_ALLOC_ENSURE_ARRAY_SIZE( + state->estimate, (state->num_channels * sizeof(*state->estimate))); + for (int i = 0; i < state->num_channels; ++i) { + 
state->estimate[i] = 0; + } + return 1; +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.h new file mode 100644 index 0000000000000000000000000000000000000000..60f9de5067c606158bc0e29771d1e83a495cd4c1 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction_util.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NOISE_REDUCTION_UTIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NOISE_REDUCTION_UTIL_H_
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/noise_reduction.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+
+struct NoiseReductionConfig {
+  // Scale the signal up by 2^(smoothing_bits) before reduction.
+  int smoothing_bits;
+  // Smoothing coefficient for even-numbered channels.
+  float even_smoothing;
+  // Smoothing coefficient for odd-numbered channels.
+  float odd_smoothing;
+  // Fraction of signal to preserve (1.0 disables this module).
+  float min_signal_remaining;
+};
+
+// Populates |config| with "sane" default values.
+void NoiseReductionFillConfigWithDefaults(struct NoiseReductionConfig* config);
+
+// Populates |state| from |config|; returns 1 on success, 0 on failure.
+int NoiseReductionPopulateState(tflite::ErrorReporter* error_reporter,
+                                const struct NoiseReductionConfig* config,
+                                struct NoiseReductionState* state,
+                                int num_channels);
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NOISE_REDUCTION_UTIL_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.cc
new file mode 100644
index 0000000000000000000000000000000000000000..50656758d722844b8aeb6a32c04d3df36f0e5242
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.h"
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/bits.h"
+
+// Evaluates the table |lut| at |x|: direct lookup up to 2, else interpolated.
+int16_t WideDynamicFunction(const uint32_t x, const int16_t* lut) {
+  if (x <= 2) {
+    return lut[x];
+  }
+  const int16_t interval = MostSignificantBit32(x);
+  lut += 4 * interval - 6;
+
+  const int16_t frac =
+      ((interval < 11) ? (x << (11 - interval)) : (x >> (interval - 11))) &
+      0x3FF;
+  // Fixed-point evaluation of lut[0] plus terms in frac from lut[1], lut[2].
+  int32_t result = (static_cast<int32_t>(lut[2]) * frac) >> 5;
+  result += (static_cast<int32_t>(lut[1])) << 5;
+  result *= frac;
+  result = (result + (1 << 14)) >> 15;
+  result += lut[0];
+  return static_cast<int16_t>(result);
+}
+
+// Compresses |x|: quadratic below 2 << kPcanSnrBits, shifted linear above.
+uint32_t PcanShrink(const uint32_t x) {
+  if (x < (2 << kPcanSnrBits)) {
+    return (x * x) >> (2 + 2 * kPcanSnrBits - kPcanOutputBits);
+  }
+  return (x >> (kPcanSnrBits - kPcanOutputBits)) - (1 << kPcanOutputBits);
+}
+
+// Replaces each signal[i] with PcanShrink((signal[i] * gain) >> snr_shift).
+void PcanGainControlApply(struct PcanGainControlState* state,
+                          uint32_t* signal) {
+  for (int i = 0; i < state->num_channels; ++i) {
+    const uint32_t gain =
+        WideDynamicFunction(state->noise_estimate[i], state->gain_lut);
+    const uint32_t snr =
+        (static_cast<uint64_t>(signal[i]) * gain) >> state->snr_shift;
+    signal[i] = PcanShrink(snr);
+  }
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.h
new file mode 100644
index
0000000000000000000000000000000000000000..06d6fc990385cd74ccba8510765eb7ad8da4eeca
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_PCAN_GAIN_CONTROL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_PCAN_GAIN_CONTROL_H_
+
+#include <stdint.h>
+#include <stdlib.h>
+
+// Bit widths used by the shifts in PcanShrink().
+#define kPcanSnrBits 12
+#define kPcanOutputBits 6
+
+// Highest interval index in the gain table, and the table's entry count.
+#define kWideDynamicFunctionBits 32
+#define kWideDynamicFunctionLUTSize (4 * kWideDynamicFunctionBits - 3)
+
+// State consumed by PcanGainControlApply().
+struct PcanGainControlState {
+  // Copied from the config by PcanGainControlPopulateState().
+  int enable_pcan;
+  // Caller-supplied array; entry i is read when processing channel i.
+  uint32_t* noise_estimate;
+  // Number of entries processed in |noise_estimate| and in the signal.
+  int num_channels;
+  // Lookup table that WideDynamicFunction() maps noise estimates through.
+  int16_t gain_lut[kWideDynamicFunctionLUTSize];
+  // Right shift applied to signal * gain before PcanShrink().
+  int32_t snr_shift;
+};
+
+// Evaluates the lookup table |lut| at |x|.
+int16_t WideDynamicFunction(const uint32_t x, const int16_t* lut);
+
+// Compresses |x|; see the .cc file for the two regimes.
+uint32_t PcanShrink(const uint32_t x);
+
+// Applies the gain control to signal[0 .. num_channels) in place.
+void PcanGainControlApply(struct PcanGainControlState* state, uint32_t* signal);
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_PCAN_GAIN_CONTROL_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_test.cc
new file mode 100644
index
0000000000000000000000000000000000000000..7dee66746f381ea50127e416fe90f063353eca89
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_test.cc
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.h"
+
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+
+namespace {
+
+const int kNumChannels = 2;
+const int kSmoothingBits = 10;
+const int kCorrectionBits = -1;
+
+// Test configuration that uses the library's default PCAN settings, except
+// that enable_pcan is switched on.
+class PcanGainControlTestConfig {
+ public:
+  PcanGainControlTestConfig() {
+    config_.enable_pcan = 1;
+    config_.strength = 0.95;
+    config_.offset = 80.0;
+    config_.gain_bits = 21;
+  }
+
+  struct PcanGainControlConfig config_;
+};
+
+}  // namespace
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+// Populates the state for two channels, applies the gain control to a fixed
+// signal, and checks every output sample against its expected value.
+TF_LITE_MICRO_TEST(PcanGainControlTest_TestPcanGainControl) {
+  tflite::MicroErrorReporter micro_error_reporter;
+  tflite::ErrorReporter* error_reporter = &micro_error_reporter;
+
+  uint32_t estimate[] = {6321887, 31248341};
+  PcanGainControlTestConfig config;
+  struct PcanGainControlState state;
+  TF_LITE_MICRO_EXPECT(PcanGainControlPopulateState(
+      error_reporter, &config.config_, &state, estimate, kNumChannels,
+      kSmoothingBits, kCorrectionBits));
+
+  uint32_t signal[] = {241137, 478104};
+  PcanGainControlApply(&state, signal);
+
+  const uint32_t expected[] = {3578, 1533};
+  TF_LITE_MICRO_EXPECT_EQ(state.num_channels,
+                          sizeof(expected) / sizeof(expected[0]));
+  int i;
+  for (i = 0; i < state.num_channels; ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(signal[i], expected[i]);
+  }
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e7867ac6284d519ba6dd35f601bc3cb40e2f95fe
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.cc
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.h"
+
+#include <math.h>
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h"
+
+#define kint16max 0x00007FFF
+
+// Resets |config| to the defaults; note that PCAN itself starts disabled.
+void PcanGainControlFillConfigWithDefaults(
+    struct PcanGainControlConfig* config) {
+  config->enable_pcan = 0;
+  config->strength = 0.95;
+  config->offset = 80.0;
+  config->gain_bits = 21;
+}
+
+// Returns the rounded gain for |x|, saturated at kint16max (0x7FFF).
+int16_t PcanGainLookupFunction(const struct PcanGainControlConfig* config,
+                               int32_t input_bits, uint32_t x) {
+  const float x_as_float =
+      (static_cast<float>(x)) / (static_cast<uint32_t>(1) << input_bits);
+  const float gain_as_float =
+      (static_cast<uint32_t>(1) << config->gain_bits) *
+      powf(x_as_float + config->offset, -config->strength);
+
+  if (gain_as_float > kint16max) {
+    return kint16max;
+  }
+  return static_cast<int16_t>(gain_as_float + 0.5f);
+}
+
+// Fills |state| from |config| and builds gain_lut. Returns 1, or 0 on error.
+int PcanGainControlPopulateState(tflite::ErrorReporter* error_reporter,
+                                 const struct PcanGainControlConfig* config,
+                                 struct PcanGainControlState* state,
+                                 uint32_t* noise_estimate,
+                                 const int num_channels,
+                                 const uint16_t smoothing_bits,
+                                 const int32_t input_correction_bits) {
+  state->enable_pcan = config->enable_pcan;
+  if (!state->enable_pcan) {
+    return 1;
+  }
+  state->noise_estimate = noise_estimate;
+  state->num_channels = num_channels;
+  STATIC_ALLOC_ENSURE_ARRAY_SIZE(
+      state->gain_lut, (kWideDynamicFunctionLUTSize * sizeof(int16_t)));
+  state->snr_shift = config->gain_bits - input_correction_bits - kPcanSnrBits;
+
+  const int32_t input_bits = smoothing_bits - input_correction_bits;
+  state->gain_lut[0] = PcanGainLookupFunction(config, input_bits, 0);
+  state->gain_lut[1] = PcanGainLookupFunction(config, input_bits, 1);
+  // Interval i occupies gain_lut[4 * i - 6] through gain_lut[4 * i - 4].
+  for (int interval = 2; interval <= kWideDynamicFunctionBits; ++interval) {
+    const uint32_t x0 = static_cast<uint32_t>(1) << (interval - 1);
+    const uint32_t x1 = x0 + (x0 >> 1);
+    const uint32_t x2 =
+        (interval == kWideDynamicFunctionBits) ? x0 + (x0 - 1) : 2 * x0;
+    const int16_t y0 = PcanGainLookupFunction(config, input_bits, x0);
+    const int16_t y1 = PcanGainLookupFunction(config, input_bits, x1);
+    const int16_t y2 = PcanGainLookupFunction(config, input_bits, x2);
+    const int32_t diff1 = static_cast<int32_t>(y1) - y0;
+    const int32_t diff2 = static_cast<int32_t>(y2) - y0;
+    const int32_t a1 = 4 * diff1 - diff2;
+    const int32_t a2 = diff2 - a1;
+
+    state->gain_lut[4 * interval - 6] = y0;
+    state->gain_lut[4 * interval - 5] = static_cast<int16_t>(a1);
+    state->gain_lut[4 * interval - 4] = static_cast<int16_t>(a2);
+  }
+  return 1;
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..4cc1de7bb25db509f8271d12f053e61554d07680
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control_util.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_PCAN_GAIN_CONTROL_UTIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_PCAN_GAIN_CONTROL_UTIL_H_
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/pcan_gain_control.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+
+struct PcanGainControlConfig {
+  // Set to false (0) to disable this module.
+  int enable_pcan;
+  // Gain normalization exponent (0.0 disables, 1.0 full strength).
+  float strength;
+  // Positive value added in the normalization denominator.
+  float offset;
+  // Number of fractional bits in the gain.
+  int gain_bits;
+};
+// Sets |config| to the defaults; note that enable_pcan defaults to 0 (off).
+void PcanGainControlFillConfigWithDefaults(
+    struct PcanGainControlConfig* config);
+// Returns the rounded gain for |x|, saturated at the int16 maximum.
+int16_t PcanGainLookupFunction(const struct PcanGainControlConfig* config,
+                               int32_t input_bits, uint32_t x);
+// Fills |state| from |config| and builds gain_lut. Returns 1, or 0 on error.
+int PcanGainControlPopulateState(tflite::ErrorReporter* error_reporter,
+                                 const struct PcanGainControlConfig* config,
+                                 struct PcanGainControlState* state,
+                                 uint32_t* noise_estimate,
+                                 const int num_channels,
+                                 const uint16_t smoothing_bits,
+                                 const int32_t input_correction_bits);
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_PCAN_GAIN_CONTROL_UTIL_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h
b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h
new file mode 100644
index 0000000000000000000000000000000000000000..e2af862de7590323819c99de3a6702d1bd046681
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h
@@ -0,0 +1,36 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_STATIC_ALLOC_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_STATIC_ALLOC_H_
+
+// Checks to ensure that the C-style array passed in has a compile-time size of
+// at least the number of bytes requested. This doesn't work with raw pointers
+// since sizeof() doesn't know their actual length, so only use this to check
+// statically-allocated arrays with known sizes.
+// The macro expands to code that uses a variable named error_reporter and
+// executes `return 0;` on failure, so the enclosing function must provide both.
+#define STATIC_ALLOC_ENSURE_ARRAY_SIZE(A, N)                                 \
+  do {                                                                       \
+    if (sizeof(A) < (N)) {                                                   \
+      /* Report() is printf-style; size_t values must be narrowed for %d. */ \
+      error_reporter->Report(#A " too small (%d bytes, wanted %d) at %s:%d", \
+                             static_cast<int>(sizeof(A)),                    \
+                             static_cast<int>(N), __FILE__, __LINE__);       \
+      return 0;                                                              \
+    }                                                                        \
+  } while (0)
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_STATIC_ALLOC_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.cc
new file mode 100644
index 0000000000000000000000000000000000000000..57a32c3595da7ae17c2328bb4c98fb005fd253ef
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.cc
@@ -0,0 +1,1541 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Automatically created from a TensorFlow Lite flatbuffer using the command:
+// xxd -i tiny_conv.tflite > tiny_conv_micro_features_model_data.cc
+// See the README for a full description of the creation process.
+ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h" + +const unsigned char g_tiny_conv_micro_features_model_data[] = { + 0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00, + 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xd0, 0x46, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0xb4, 0x41, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, + 0x54, 0x4f, 0x43, 0x4f, 0x20, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x74, + 0x65, 0x64, 0x2e, 0x00, 0x09, 0x00, 0x00, 0x00, 0x94, 0x41, 0x00, 0x00, + 0x74, 0x41, 0x00, 0x00, 0x44, 0x41, 0x00, 0x00, 0xb4, 0x3e, 0x00, 0x00, + 0xac, 0x3e, 0x00, 0x00, 0xa4, 0x3e, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf0, 0xb9, 0xff, 0xff, + 0xf4, 0xb9, 0xff, 0xff, 0x52, 0xba, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x80, 0x3e, 0x00, 0x00, 0x68, 0x95, 0x91, 0x7d, 0x9b, 0x85, 0x85, 0x81, + 0x77, 0x85, 0x99, 0x89, 0x7e, 0x8a, 0x85, 0x92, 0xa5, 0x7e, 0x93, 0x97, + 0x97, 0x91, 0xa3, 0x97, 0x88, 0x8b, 0xa6, 0x71, 0x77, 0x85, 0x95, 0x86, + 0x6b, 0x93, 0xcb, 0x96, 0x7a, 0x9a, 0x7f, 0x85, 0x7a, 0x8e, 0xac, 0x98, + 0x6d, 0x9d, 0x9b, 0x70, 0x9a, 0x90, 0xba, 0x99, 0x7b, 0x93, 0x6e, 0x68, + 0x75, 0x86, 0xc4, 0x8b, 0x66, 0x5d, 0x96, 0x7f, 0x92, 0x91, 0xb6, 0x7b, + 0x96, 0x95, 0x9a, 0x77, 0x9a, 0x96, 0xce, 0x80, 0x88, 0x65, 0x8e, 0x80, + 0x88, 0x85, 0xb7, 0x9c, 0x7b, 0x93, 0x9d, 0x95, 0x83, 0x92, 0xd0, 0x7e, + 0x68, 0x88, 0x6c, 0x78, 0x98, 0x81, 0xac, 0x95, 0x9e, 0x98, 0xa2, 0x99, + 0x8d, 0x7d, 0xb8, 0x81, 0x6e, 0x68, 0xa1, 0x81, 0x9d, 0x99, 0xb4, 0x7d, + 0x92, 0x86, 0x9d, 0x93, 0xa3, 0xb0, 0xd6, 0x79, 0x93, 0x76, 0x8d, 0x84, + 0x91, 0x9d, 0xbe, 0x94, 0xb0, 0x70, 0x84, 0x80, 0x85, 0x99, 0x9e, 0xa2, + 0x86, 0x8a, 0x7a, 0x76, 0x91, 0x8d, 0xa6, 0x76, 0x8d, 0x82, 0x98, 0x8c, + 0x92, 0x8f, 0x8c, 0xb3, 0x78, 
0x75, 0xa5, 0x88, 0x73, 0x8c, 0x91, 0x7c, + 0x82, 0x7d, 0x93, 0x9e, 0x8b, 0x97, 0x7c, 0x90, 0x84, 0x95, 0x7e, 0x9e, + 0xa4, 0x52, 0x8a, 0xb4, 0x97, 0x65, 0x7d, 0xb6, 0x83, 0x7d, 0x99, 0x80, + 0x97, 0x85, 0x96, 0x5f, 0x8e, 0x87, 0x95, 0x6d, 0x76, 0x84, 0x97, 0x8c, + 0x66, 0x97, 0xae, 0x6b, 0x93, 0xb3, 0xa8, 0x8b, 0xa1, 0x79, 0xa3, 0x94, + 0x7e, 0xa8, 0x8d, 0xad, 0x78, 0x82, 0xa2, 0x7b, 0x90, 0xa4, 0x7d, 0xb3, + 0xa0, 0x7b, 0x94, 0x85, 0x9a, 0x8d, 0x76, 0x82, 0x65, 0x73, 0xab, 0xa4, + 0xaa, 0x74, 0x93, 0x9c, 0x83, 0x66, 0xbf, 0x7a, 0xaa, 0x81, 0x92, 0x89, + 0x7e, 0x88, 0xa6, 0x66, 0xaf, 0x92, 0x9f, 0x97, 0x6c, 0x89, 0x9c, 0x74, + 0x7e, 0x82, 0x8e, 0x88, 0xb2, 0x85, 0xba, 0x96, 0x90, 0x78, 0x8d, 0xa7, + 0x9e, 0x87, 0xbc, 0x7f, 0xb2, 0x8b, 0x77, 0x9b, 0xab, 0x8f, 0xa4, 0x7d, + 0x6f, 0x77, 0x8c, 0x98, 0x6f, 0x89, 0xb1, 0x9f, 0xa7, 0x94, 0x7d, 0xae, + 0x88, 0x8a, 0xa9, 0x75, 0x7d, 0x7c, 0x88, 0x99, 0x90, 0x9d, 0x97, 0xa7, + 0x8d, 0x7f, 0x73, 0xa1, 0xa3, 0x87, 0xa9, 0x92, 0x98, 0x7e, 0x9c, 0x88, + 0x73, 0x6b, 0x78, 0x8e, 0x7d, 0x86, 0x6c, 0x7c, 0x92, 0x40, 0x86, 0xa7, + 0x65, 0x93, 0x67, 0x91, 0x67, 0x71, 0x6c, 0xa8, 0x81, 0x70, 0x8e, 0xa8, + 0x7b, 0x63, 0x89, 0x76, 0x69, 0x90, 0x73, 0x5e, 0x92, 0x78, 0x7e, 0x9d, + 0x87, 0x86, 0x89, 0x64, 0x66, 0xa9, 0x92, 0x8d, 0x72, 0x7c, 0x63, 0x7f, + 0x94, 0x5c, 0x92, 0x89, 0x87, 0x9d, 0x8b, 0x75, 0x93, 0x8c, 0x94, 0x68, + 0x97, 0x87, 0x78, 0x7d, 0x7f, 0x84, 0x84, 0x77, 0x6b, 0x8e, 0x83, 0xab, + 0x7e, 0x62, 0x90, 0x83, 0x8e, 0x71, 0x7e, 0x9b, 0x96, 0x6d, 0x83, 0x6a, + 0x76, 0x68, 0x71, 0x90, 0x98, 0x90, 0x9b, 0x68, 0x89, 0x89, 0x95, 0x85, + 0x6e, 0x75, 0x8e, 0x95, 0x83, 0x7a, 0x91, 0x7f, 0x8b, 0x71, 0x90, 0x7d, + 0xad, 0x91, 0x6f, 0x74, 0x98, 0x8a, 0xb0, 0xa8, 0x80, 0xa3, 0x8e, 0x7c, + 0xa5, 0x67, 0xa4, 0x66, 0xa9, 0x7b, 0x85, 0x9d, 0x88, 0xab, 0x7d, 0x81, + 0x6e, 0x7f, 0x8f, 0x97, 0x97, 0x84, 0x89, 0x74, 0x9d, 0x5f, 0x9c, 0x88, + 0x6f, 0x74, 0x96, 0x9e, 0x7e, 0x7e, 0xa4, 0x85, 0x94, 0x91, 0xaf, 0x99, + 0x7a, 0xaa, 0x8c, 0x92, 0x85, 
0x9d, 0x6c, 0x79, 0x57, 0x7a, 0x80, 0x84, + 0x79, 0x79, 0x74, 0xa5, 0x55, 0xab, 0x73, 0x8c, 0x72, 0x9d, 0x72, 0xa9, + 0x90, 0x73, 0x8f, 0xa0, 0x89, 0x6d, 0x68, 0x66, 0x61, 0x6f, 0x57, 0x7d, + 0x66, 0x8c, 0x65, 0x87, 0x62, 0x76, 0x83, 0x77, 0x89, 0xa4, 0x73, 0x89, + 0x7f, 0x70, 0x79, 0x6b, 0x86, 0x6f, 0x8d, 0x96, 0x65, 0x89, 0x66, 0x53, + 0x73, 0xae, 0x6a, 0x72, 0x88, 0x97, 0x7a, 0x7f, 0x5d, 0xa1, 0x86, 0x88, + 0x5f, 0x9f, 0x9b, 0x8a, 0x74, 0x9a, 0x7a, 0x7e, 0x8b, 0x71, 0x58, 0x74, + 0x8f, 0x9b, 0x9b, 0x8d, 0x6b, 0x83, 0x60, 0x7f, 0x75, 0x91, 0x79, 0x93, + 0x7a, 0x92, 0x8c, 0x7e, 0x7a, 0x95, 0x84, 0x69, 0x8f, 0x8c, 0x7c, 0x6e, + 0x8b, 0x87, 0x82, 0x62, 0xa6, 0x97, 0x91, 0x65, 0xa2, 0xa4, 0x9b, 0x8b, + 0x85, 0xa4, 0x84, 0x7b, 0x67, 0x93, 0x96, 0x84, 0x85, 0x75, 0x6d, 0x9e, + 0x80, 0x80, 0x73, 0x8c, 0x81, 0x70, 0x8a, 0x68, 0x9c, 0x8e, 0x63, 0x91, + 0x89, 0x79, 0x8d, 0x79, 0xa4, 0x9a, 0x96, 0xa0, 0x83, 0x63, 0x88, 0x8f, + 0x76, 0xb4, 0xa8, 0x8e, 0x68, 0x8d, 0x8e, 0x95, 0x78, 0xae, 0x5d, 0x89, + 0x66, 0x7e, 0x7b, 0x8a, 0x75, 0x86, 0x71, 0x97, 0x6d, 0xb3, 0x67, 0x76, + 0x82, 0x7d, 0x70, 0x79, 0x8a, 0x9c, 0x82, 0xa7, 0x82, 0xab, 0x58, 0x86, + 0x5c, 0x70, 0x8c, 0x71, 0x61, 0xa6, 0x74, 0xa8, 0x65, 0x78, 0x72, 0x9d, + 0x6c, 0x92, 0x70, 0x88, 0x88, 0x79, 0x96, 0x6f, 0x68, 0xa4, 0x7a, 0x7b, + 0x96, 0xac, 0x6d, 0x76, 0x6a, 0xab, 0x82, 0x7d, 0x71, 0x8d, 0x6b, 0x81, + 0x6c, 0x9d, 0x71, 0x59, 0x5c, 0x71, 0x77, 0x6d, 0x6a, 0x96, 0x76, 0x69, + 0x80, 0x83, 0x88, 0x70, 0x97, 0xb4, 0x8a, 0x6c, 0xa5, 0x6e, 0x64, 0x75, + 0x73, 0xa2, 0x7f, 0x97, 0x9e, 0x75, 0x8f, 0x86, 0x68, 0xbb, 0x6b, 0x86, + 0x8d, 0x80, 0x8e, 0x58, 0x6d, 0xb2, 0x76, 0x99, 0x8f, 0x70, 0x6c, 0x86, + 0x78, 0x9e, 0x91, 0x90, 0xa2, 0x7c, 0x8c, 0x81, 0x80, 0xb4, 0x77, 0x7a, + 0x8c, 0x5f, 0x85, 0x56, 0x7a, 0x93, 0x6b, 0x5c, 0x74, 0x59, 0x7e, 0x86, + 0x8c, 0xae, 0x76, 0x7d, 0x76, 0x7e, 0x81, 0x5f, 0x81, 0x8e, 0x7b, 0x90, + 0xaa, 0x99, 0x79, 0x89, 0x93, 0xbc, 0x86, 0x91, 0xa2, 0x88, 0x79, 0x82, + 0x80, 0xb6, 0x4a, 0x93, 0x7b, 
0x89, 0x75, 0x8d, 0x7a, 0x8d, 0x66, 0x7c, + 0x81, 0x9f, 0x6e, 0x86, 0x4d, 0x82, 0x66, 0x88, 0x73, 0x89, 0x7d, 0xac, + 0x89, 0x9f, 0x58, 0x7f, 0x6b, 0x8c, 0x6a, 0x82, 0x59, 0xb8, 0x83, 0x67, + 0x8b, 0x8a, 0x84, 0x7b, 0x7f, 0xb5, 0x44, 0x57, 0x5a, 0x73, 0x8b, 0x6d, + 0x7c, 0x9e, 0x71, 0x72, 0x8d, 0x93, 0x80, 0x60, 0x7f, 0xc5, 0x69, 0x5c, + 0x67, 0x92, 0x6c, 0x75, 0x66, 0x8f, 0x91, 0x5a, 0x6c, 0x70, 0x90, 0x84, + 0x88, 0xab, 0x90, 0x66, 0x9c, 0x64, 0x6e, 0x68, 0x92, 0x9e, 0x89, 0x8d, + 0x82, 0x97, 0x77, 0x75, 0x7f, 0xa7, 0x91, 0x75, 0x8c, 0x89, 0xa4, 0x6b, + 0x98, 0x99, 0x80, 0x7d, 0x6b, 0x7f, 0x7d, 0x88, 0x79, 0xa1, 0x87, 0x90, + 0x81, 0x8e, 0x94, 0x96, 0x7d, 0xa8, 0x86, 0x84, 0x86, 0x79, 0x97, 0x6e, + 0xaa, 0x95, 0x8a, 0x9f, 0x8c, 0x72, 0x99, 0x77, 0x81, 0x94, 0x91, 0x9f, + 0x6e, 0x67, 0x87, 0x70, 0x7d, 0xad, 0x58, 0x7f, 0x6d, 0x96, 0x8e, 0x82, + 0x7d, 0xa6, 0x77, 0x99, 0x87, 0x95, 0x89, 0x7e, 0xa6, 0x9e, 0x86, 0xac, + 0x78, 0x9f, 0x9b, 0x85, 0x76, 0x99, 0x6a, 0x92, 0x66, 0x7b, 0x9a, 0x99, + 0x83, 0x8b, 0x57, 0x65, 0x75, 0x9f, 0xa6, 0x8a, 0x8d, 0x96, 0x6f, 0x80, + 0x65, 0x8f, 0x80, 0x9f, 0x82, 0x85, 0x55, 0x75, 0x5c, 0x84, 0x91, 0x86, + 0x76, 0x96, 0x5a, 0x6c, 0x62, 0x7b, 0x92, 0x88, 0x61, 0xca, 0x75, 0x66, + 0x70, 0x70, 0x8e, 0x7a, 0x75, 0xb2, 0x66, 0x81, 0x5b, 0x79, 0x92, 0x97, + 0x94, 0xaf, 0x72, 0x8a, 0x9b, 0x5f, 0x65, 0x96, 0x81, 0xb6, 0x8a, 0x6f, + 0x94, 0x7a, 0x96, 0x92, 0x79, 0x94, 0x8e, 0x53, 0x9a, 0x73, 0x6a, 0x9d, + 0xa1, 0xa3, 0xa4, 0x8f, 0x6b, 0xa4, 0x8b, 0x82, 0x96, 0xb1, 0x8c, 0x92, + 0x7f, 0x91, 0x5f, 0x98, 0x8a, 0xa4, 0x7e, 0x80, 0x97, 0x86, 0x86, 0x86, + 0x8f, 0xa6, 0x77, 0x9a, 0x82, 0x80, 0x6e, 0x73, 0x83, 0xaf, 0x87, 0x6d, + 0x77, 0x9a, 0x83, 0x9f, 0x7c, 0xa4, 0x71, 0x6f, 0x7d, 0x75, 0x9d, 0x82, + 0x83, 0xaf, 0x85, 0x80, 0x8d, 0x7f, 0xa4, 0xa2, 0x88, 0xba, 0x76, 0x76, + 0x94, 0x6b, 0x76, 0x83, 0x77, 0x96, 0x78, 0x8c, 0xb0, 0x8e, 0x83, 0x87, + 0xa0, 0xcc, 0x7f, 0xa4, 0x8c, 0x77, 0x84, 0x8c, 0x80, 0xa0, 0x57, 0x76, + 0x76, 0x71, 0x86, 0x9c, 0x7f, 
0x88, 0x57, 0x95, 0x4d, 0x8c, 0x7f, 0x80, + 0x66, 0x9e, 0x42, 0x8d, 0x6a, 0x8e, 0x8c, 0x80, 0x89, 0x9d, 0x4f, 0x83, + 0x54, 0x8a, 0x5e, 0x64, 0x70, 0x94, 0x78, 0x90, 0x7d, 0x78, 0x8d, 0x71, + 0x56, 0x9a, 0x8c, 0x65, 0x8b, 0x62, 0x88, 0x9a, 0x6c, 0x8e, 0x7b, 0x78, + 0x68, 0x86, 0x64, 0x6b, 0x67, 0xaa, 0x8c, 0x7b, 0x67, 0x75, 0x58, 0x7e, + 0x6b, 0x97, 0x92, 0x87, 0x9c, 0x79, 0x71, 0x76, 0x7d, 0xbb, 0x89, 0x75, + 0x83, 0x57, 0x74, 0x98, 0xa1, 0x8f, 0xb0, 0x89, 0x76, 0x88, 0x69, 0x9c, + 0x74, 0xb0, 0x86, 0x9c, 0x79, 0x6f, 0x84, 0x70, 0x94, 0xa1, 0x6e, 0x7a, + 0xa3, 0x88, 0xa0, 0x7a, 0x94, 0xa1, 0x82, 0x93, 0x99, 0x95, 0x7f, 0xab, + 0x97, 0x9d, 0x6e, 0x68, 0x79, 0x73, 0x76, 0x83, 0x76, 0xbd, 0x87, 0x87, + 0x86, 0x74, 0x8f, 0x6e, 0x65, 0xba, 0x6a, 0x78, 0x91, 0x62, 0x72, 0x67, + 0x75, 0xbd, 0x8c, 0x5e, 0x85, 0x6d, 0x72, 0x85, 0x7d, 0x96, 0x8f, 0xb9, + 0x9f, 0x97, 0xa2, 0x8a, 0xa1, 0xc1, 0x8d, 0xbc, 0x85, 0x78, 0x93, 0x97, + 0x99, 0x9f, 0x3a, 0x98, 0x65, 0x8d, 0x6a, 0x6c, 0x92, 0x85, 0x49, 0x7e, + 0x6a, 0xaa, 0x8a, 0x94, 0x6b, 0x93, 0x40, 0x8a, 0x8c, 0x9c, 0x6f, 0xad, + 0x72, 0xb0, 0x58, 0x88, 0x60, 0x8c, 0x86, 0x84, 0x74, 0x96, 0x8f, 0x97, + 0x5e, 0x6c, 0x79, 0x92, 0x51, 0xa8, 0x92, 0x58, 0x62, 0x6f, 0x6c, 0x76, + 0x5f, 0x9e, 0x86, 0x71, 0x9c, 0x69, 0x7e, 0x80, 0x8a, 0x97, 0x6f, 0x79, + 0x8b, 0x6f, 0x6c, 0x88, 0x73, 0x9c, 0x6d, 0x91, 0x77, 0x73, 0x7f, 0x97, + 0x86, 0xa9, 0xac, 0x71, 0x82, 0x90, 0x83, 0x8a, 0x80, 0x9d, 0xa8, 0x85, + 0x78, 0x7f, 0x94, 0x99, 0x8e, 0xa3, 0x89, 0x70, 0x87, 0x62, 0x82, 0x87, + 0x8c, 0x98, 0x7a, 0x88, 0x72, 0x7e, 0x78, 0xa0, 0x78, 0x95, 0x97, 0x8f, + 0x7b, 0x7c, 0x83, 0x94, 0x93, 0xa7, 0x77, 0x97, 0x90, 0x5e, 0x76, 0x7c, + 0x68, 0xaa, 0x69, 0x67, 0x76, 0x84, 0x7e, 0x64, 0xa3, 0xbe, 0x7e, 0x8b, + 0x82, 0x50, 0x8a, 0x82, 0x89, 0xc0, 0x79, 0x78, 0x68, 0x7c, 0x6b, 0x77, + 0x82, 0x99, 0x7b, 0x83, 0x80, 0x90, 0x96, 0x96, 0x87, 0xb7, 0xa5, 0x94, + 0x82, 0x99, 0x95, 0x91, 0x7e, 0xa2, 0x49, 0x95, 0x6d, 0x8e, 0xa9, 0x89, + 0x8e, 0x8f, 0x3d, 0x95, 0x6a, 
0x8c, 0x8b, 0x8c, 0x7e, 0x88, 0x63, 0x94, + 0x69, 0x94, 0x88, 0x92, 0x79, 0xa7, 0x68, 0x60, 0x76, 0x85, 0xa1, 0x6f, + 0x54, 0x96, 0x63, 0x7a, 0x5c, 0x73, 0x74, 0x6e, 0x53, 0x99, 0x69, 0x76, + 0x69, 0x57, 0x6a, 0x82, 0x55, 0x93, 0x82, 0x80, 0x65, 0x7f, 0x7b, 0x76, + 0x72, 0x87, 0x8d, 0x97, 0x98, 0x78, 0x7e, 0x6d, 0x7a, 0x95, 0x78, 0x70, + 0x90, 0x83, 0x89, 0x80, 0x7f, 0x9d, 0x73, 0x73, 0x84, 0x77, 0x8e, 0x77, + 0x8e, 0x75, 0x9e, 0xa5, 0x86, 0x68, 0x89, 0x7d, 0x8d, 0x99, 0x79, 0x8f, + 0x8e, 0x87, 0x87, 0x97, 0x8c, 0x91, 0xa1, 0x96, 0x83, 0x73, 0x87, 0xa9, + 0x8c, 0xa6, 0x85, 0x8c, 0x96, 0x7d, 0x7f, 0x8e, 0x7e, 0xb0, 0x85, 0x8f, + 0x7f, 0x7d, 0x95, 0x7d, 0x9c, 0xb3, 0x71, 0x86, 0x81, 0x69, 0x7b, 0x69, + 0x76, 0xb6, 0x5d, 0x67, 0x8a, 0x68, 0x9c, 0xa6, 0x70, 0xbf, 0x79, 0x60, + 0x8b, 0x7f, 0x7a, 0x7b, 0x8b, 0xaf, 0x8c, 0xa1, 0x86, 0x92, 0x76, 0x8d, + 0x89, 0xa2, 0xa8, 0xa3, 0xa0, 0xa2, 0x96, 0x9d, 0x7c, 0x92, 0x3f, 0x9b, + 0x6d, 0x8a, 0x80, 0x81, 0xa0, 0x92, 0x50, 0x7c, 0x82, 0x99, 0x80, 0xa6, + 0x8e, 0x8d, 0x4f, 0x8d, 0x65, 0x71, 0x77, 0x81, 0x51, 0xa6, 0x3f, 0x5c, + 0x63, 0x6f, 0x61, 0x93, 0x5c, 0xaa, 0x77, 0x8f, 0x5d, 0x53, 0x79, 0x74, + 0x6b, 0x94, 0x86, 0x81, 0x85, 0x48, 0x81, 0x80, 0x6b, 0x85, 0x6c, 0x91, + 0x92, 0x6a, 0x74, 0x78, 0x72, 0x87, 0x6c, 0x82, 0x88, 0x7b, 0x93, 0x71, + 0x91, 0x8d, 0x67, 0x83, 0x86, 0x5b, 0x86, 0x79, 0x81, 0x9f, 0x95, 0x8a, + 0x70, 0x66, 0x9e, 0x6b, 0x72, 0x98, 0x97, 0x95, 0x72, 0x93, 0x84, 0x92, + 0x8c, 0x96, 0xa2, 0x65, 0x80, 0x75, 0xa2, 0xa7, 0x7d, 0x97, 0x71, 0x8f, + 0x69, 0x65, 0x8f, 0xae, 0x9c, 0x97, 0x5d, 0xb3, 0x98, 0x83, 0x98, 0xa0, + 0x5f, 0x7e, 0x7a, 0x7a, 0x87, 0x7c, 0x92, 0xa0, 0x81, 0xa6, 0x71, 0x8e, + 0x88, 0x52, 0xa3, 0x88, 0x6a, 0x9d, 0x84, 0x82, 0x7c, 0x78, 0x9f, 0x92, + 0x66, 0xa4, 0x53, 0x6a, 0x7e, 0x84, 0x60, 0x84, 0x92, 0xb0, 0x93, 0x9d, + 0xa0, 0x5f, 0x95, 0x8c, 0x77, 0xa1, 0x8c, 0x90, 0xa0, 0x9c, 0x9a, 0x95, + 0x85, 0xa1, 0x22, 0x8f, 0x57, 0x80, 0x96, 0x7d, 0x92, 0x8b, 0x41, 0xa6, + 0x61, 0xa2, 0x6f, 0x80, 0x5d, 
0x91, 0x66, 0xab, 0x6d, 0x7e, 0x88, 0x93, + 0x5c, 0xa5, 0x75, 0x6e, 0x6c, 0x86, 0x69, 0x73, 0x4e, 0x8e, 0x77, 0x6b, + 0x6c, 0x60, 0x67, 0x91, 0x75, 0x91, 0x6c, 0x7c, 0x53, 0x6e, 0x75, 0x8e, + 0x79, 0x8c, 0x8b, 0x74, 0x6b, 0x57, 0x71, 0xa1, 0x7f, 0x83, 0x6c, 0x6b, + 0x93, 0x99, 0x7a, 0x78, 0x71, 0x8c, 0x78, 0x88, 0x9f, 0x85, 0x77, 0x7b, + 0x86, 0x85, 0xa1, 0x61, 0x78, 0x65, 0x61, 0x75, 0x82, 0x7d, 0xa9, 0xa2, + 0x84, 0x82, 0x94, 0x95, 0x90, 0x9f, 0x83, 0x97, 0x76, 0x95, 0x8a, 0x83, + 0x9b, 0x87, 0x8b, 0x7a, 0x6c, 0x6e, 0x75, 0x95, 0x85, 0x95, 0x84, 0x9e, + 0x96, 0x74, 0x7d, 0xa5, 0x85, 0x8e, 0x7e, 0x73, 0x85, 0x8d, 0x87, 0x80, + 0x8a, 0x96, 0x65, 0x87, 0x7c, 0x73, 0x80, 0x96, 0x73, 0x8d, 0x5e, 0x79, + 0x7e, 0x8d, 0x79, 0x85, 0x63, 0xa0, 0x62, 0x89, 0x9d, 0x8c, 0x74, 0x7b, + 0x9c, 0xa5, 0x71, 0x8c, 0x83, 0x91, 0x8e, 0x8d, 0x89, 0x8b, 0x8b, 0xa4, + 0x78, 0x88, 0x9e, 0x85, 0x8b, 0x94, 0x38, 0x84, 0x7b, 0x86, 0x7d, 0xa2, + 0x73, 0x8f, 0x47, 0x7b, 0x69, 0xb4, 0x85, 0x71, 0x61, 0x9d, 0x59, 0x95, + 0x74, 0x93, 0x6a, 0x88, 0x62, 0xa2, 0x56, 0x93, 0x8d, 0x68, 0x7e, 0x80, + 0x6b, 0xb7, 0x63, 0x90, 0x5d, 0x54, 0x6c, 0x90, 0x5a, 0x8e, 0x7e, 0x7d, + 0x82, 0x73, 0x7f, 0x89, 0x94, 0x8e, 0x7a, 0x70, 0x6c, 0x79, 0x88, 0x88, + 0x9b, 0x8b, 0x70, 0x81, 0x83, 0x83, 0x8b, 0x86, 0x64, 0x93, 0x82, 0x66, + 0x66, 0x79, 0x74, 0x91, 0x92, 0x94, 0x7c, 0x87, 0x72, 0x79, 0x8d, 0xaa, + 0xa2, 0x9e, 0xaf, 0x95, 0xb1, 0x8a, 0x95, 0x8b, 0x94, 0x7e, 0x79, 0x8e, + 0x99, 0x98, 0x97, 0x9e, 0x94, 0x87, 0x74, 0x72, 0x63, 0x92, 0x92, 0x95, + 0xb0, 0x94, 0x86, 0x91, 0x77, 0x8f, 0x91, 0x7e, 0x83, 0x88, 0x90, 0xa5, + 0x79, 0x70, 0x85, 0x8f, 0x67, 0x90, 0x98, 0x8d, 0x8a, 0x5d, 0x8c, 0x9c, + 0x94, 0x91, 0x80, 0x95, 0x6e, 0x95, 0x73, 0x8d, 0x63, 0x8e, 0x53, 0x8a, + 0x77, 0x88, 0x8f, 0x6f, 0x87, 0x9e, 0x8b, 0xb7, 0x99, 0xb2, 0x85, 0x82, + 0xa1, 0x89, 0x9b, 0xa7, 0x80, 0x81, 0xa0, 0x8e, 0x84, 0xa9, 0x27, 0x73, + 0x5e, 0x85, 0x5f, 0x92, 0x8c, 0xa2, 0x34, 0x8e, 0x6e, 0xb2, 0x7b, 0x8c, + 0x69, 0x93, 0x47, 0x9e, 0x58, 
0x7e, 0x94, 0x86, 0x47, 0xa3, 0x53, 0x6b, + 0x6e, 0x6a, 0x7f, 0x73, 0x5b, 0x8c, 0x7a, 0x99, 0x6c, 0x5d, 0x82, 0x82, + 0x62, 0x8a, 0x7a, 0x8e, 0x88, 0x62, 0xa0, 0x8e, 0x5c, 0x9a, 0x72, 0x79, + 0x66, 0x6b, 0x75, 0x78, 0x82, 0x8a, 0x59, 0x91, 0x93, 0x68, 0x78, 0xb4, + 0x86, 0x7e, 0x8c, 0x6e, 0x88, 0x7f, 0x96, 0x8e, 0x6e, 0x8b, 0x8c, 0x73, + 0xab, 0x79, 0x88, 0xa6, 0x86, 0x81, 0x9a, 0x80, 0x9a, 0x9e, 0x8b, 0x6d, + 0x9a, 0x70, 0x8e, 0x8a, 0x84, 0x7a, 0xaf, 0xb8, 0x9e, 0x90, 0x89, 0xb3, + 0x9b, 0x85, 0x94, 0xb6, 0x87, 0x8c, 0x6e, 0xa3, 0xac, 0x9e, 0x8c, 0x7c, + 0x81, 0x83, 0x70, 0x8d, 0x7c, 0x81, 0x77, 0x82, 0x69, 0x8e, 0x5e, 0x80, + 0x8a, 0x8e, 0x7c, 0x8a, 0x89, 0x90, 0x58, 0x59, 0x85, 0x88, 0x7a, 0x86, + 0x73, 0x9c, 0x4a, 0x81, 0x8d, 0x89, 0x91, 0x95, 0x72, 0x83, 0x9d, 0x99, + 0x8d, 0x6b, 0x95, 0x7e, 0x70, 0x94, 0x8c, 0x9f, 0x8a, 0x8f, 0xa7, 0x84, + 0x87, 0xb6, 0x42, 0x81, 0x63, 0x8a, 0x79, 0x77, 0x74, 0x90, 0x23, 0x85, + 0x74, 0x8f, 0x87, 0x80, 0x50, 0xa1, 0x4d, 0x9b, 0x55, 0x82, 0x74, 0x8e, + 0x4a, 0xa7, 0x52, 0x4d, 0x77, 0x67, 0x77, 0x9e, 0x62, 0xa5, 0x7d, 0x96, + 0x6f, 0x45, 0x80, 0x8c, 0x6c, 0x92, 0x99, 0x6f, 0x5d, 0x56, 0x93, 0xac, + 0x94, 0x9c, 0x95, 0x92, 0x6e, 0x71, 0x87, 0x8c, 0x7b, 0xa9, 0x7f, 0x7a, + 0x69, 0x6b, 0x7d, 0x90, 0x6f, 0x81, 0x9f, 0x80, 0x83, 0x67, 0x78, 0x85, + 0x85, 0x91, 0x8a, 0x80, 0xaa, 0x86, 0x8c, 0x88, 0x8c, 0x8f, 0x9b, 0x85, + 0x8b, 0x7e, 0x83, 0x82, 0x95, 0x75, 0x6b, 0x8f, 0x85, 0x8b, 0xb0, 0x9f, + 0xa7, 0x8e, 0x61, 0x9d, 0x72, 0xac, 0x92, 0x87, 0x94, 0x96, 0x68, 0x8f, + 0x63, 0x85, 0x9c, 0xa8, 0x82, 0x9b, 0x85, 0x9b, 0x6b, 0x72, 0x83, 0x85, + 0x90, 0x87, 0x74, 0xa4, 0x88, 0x57, 0x63, 0x90, 0x8e, 0x7b, 0x80, 0x81, + 0x94, 0x74, 0x68, 0x8a, 0x7f, 0x86, 0x78, 0x72, 0x75, 0x67, 0x7a, 0x8a, + 0x7a, 0x74, 0x8c, 0xad, 0x75, 0xa2, 0x7d, 0x9a, 0x9e, 0x83, 0x92, 0xa2, + 0xa3, 0x98, 0xa5, 0x91, 0x84, 0xb0, 0x21, 0x9a, 0x5f, 0x8c, 0x7e, 0x86, + 0x80, 0xa0, 0x16, 0x9b, 0x5b, 0x9c, 0x76, 0x8d, 0x77, 0x9f, 0x62, 0x86, + 0x6a, 0x6c, 0x6e, 0x8f, 0x4e, 
0xc1, 0x61, 0x6f, 0x74, 0x79, 0x80, 0x5f, + 0x59, 0x9e, 0x7c, 0x87, 0x7f, 0x4b, 0x6c, 0x8b, 0x5a, 0x8f, 0x65, 0x8a, + 0x62, 0x58, 0x66, 0x8d, 0x83, 0x97, 0x8a, 0x7a, 0x77, 0x79, 0x6c, 0x83, + 0x8c, 0x93, 0x82, 0x5e, 0x61, 0x8c, 0x82, 0x80, 0x88, 0x88, 0x85, 0x87, + 0x77, 0x70, 0x8d, 0x7f, 0x7a, 0x89, 0x72, 0x7e, 0xa3, 0x99, 0x6b, 0xaa, + 0x81, 0x87, 0x90, 0x6f, 0x7f, 0x77, 0x96, 0x83, 0x89, 0x89, 0x6a, 0x77, + 0xa4, 0x6c, 0x97, 0x7e, 0x95, 0xa4, 0x63, 0x8d, 0x71, 0x96, 0x8a, 0xa4, + 0x9f, 0x7c, 0x54, 0x94, 0x7a, 0x89, 0x8a, 0x90, 0x7e, 0x9d, 0x53, 0x7c, + 0x9d, 0x83, 0x90, 0x84, 0xa1, 0x8e, 0x80, 0x74, 0x69, 0x7a, 0x69, 0x93, + 0x8a, 0x90, 0x83, 0x76, 0x8b, 0x6f, 0x8e, 0x93, 0x82, 0x84, 0x7d, 0x94, + 0xa1, 0x78, 0x7d, 0x68, 0x79, 0x83, 0x85, 0x9d, 0x89, 0xa0, 0x8a, 0x93, + 0x90, 0x8c, 0x82, 0x86, 0x80, 0x71, 0xb3, 0xa1, 0x90, 0xb2, 0x27, 0xa3, + 0x5e, 0xa3, 0xa6, 0x64, 0x75, 0xa0, 0x23, 0x8c, 0x7c, 0xc4, 0x7a, 0x8c, + 0x4d, 0xa3, 0x4c, 0x93, 0x71, 0x7b, 0x71, 0x8b, 0x34, 0xa5, 0x47, 0x7f, + 0x4e, 0x73, 0x51, 0x8a, 0x67, 0xa0, 0x9d, 0x7f, 0x65, 0x38, 0x61, 0x70, + 0x71, 0x8d, 0x6a, 0x7e, 0x7e, 0x4c, 0x7d, 0x8d, 0x81, 0x80, 0xa5, 0x84, + 0x6f, 0x57, 0x70, 0x91, 0x8b, 0x99, 0x9d, 0x84, 0x77, 0x7f, 0x6b, 0x7f, + 0x76, 0x8f, 0x90, 0x72, 0x6c, 0x58, 0x6b, 0x85, 0xa6, 0x8a, 0xa2, 0x6d, + 0x8a, 0x71, 0x71, 0x95, 0x92, 0x7c, 0x88, 0x67, 0x86, 0x6d, 0x8d, 0x95, + 0x79, 0x8e, 0x65, 0x71, 0x71, 0x91, 0x85, 0x99, 0xa9, 0x87, 0x80, 0x88, + 0x74, 0x86, 0x75, 0x83, 0x8b, 0x7f, 0x78, 0xb1, 0x90, 0xa8, 0x7b, 0x98, + 0x8a, 0x7b, 0x5b, 0x99, 0x6f, 0x7f, 0xa0, 0x79, 0xa5, 0x93, 0x8b, 0x7b, + 0x7e, 0x7a, 0x61, 0x9d, 0x98, 0x8b, 0x82, 0x7c, 0x76, 0x73, 0x81, 0x8a, + 0x7e, 0x8d, 0x6e, 0x71, 0xa0, 0x65, 0x80, 0x62, 0x7d, 0x8d, 0x5e, 0x9b, + 0x8f, 0x85, 0x89, 0xad, 0x71, 0x73, 0x7f, 0x89, 0x8d, 0x89, 0xb3, 0xa1, + 0x7c, 0xaf, 0x43, 0x82, 0x49, 0x92, 0x62, 0x7f, 0x79, 0xa6, 0x23, 0x99, + 0x6c, 0x9a, 0x8a, 0x90, 0x6c, 0xb9, 0x6f, 0x8a, 0x61, 0x7f, 0x8f, 0x8a, + 0x57, 0xb9, 0x55, 0x65, 0x4b, 
0x51, 0x66, 0x6e, 0x4a, 0xa1, 0x83, 0x8a, + 0x73, 0x23, 0x8a, 0x6d, 0x46, 0xa7, 0x87, 0x64, 0x84, 0x5f, 0x6f, 0x6f, + 0x9b, 0x9d, 0x76, 0x83, 0x60, 0x6e, 0x76, 0x8a, 0x9a, 0xa6, 0x75, 0x73, + 0x86, 0x5b, 0x97, 0x88, 0x7b, 0x8e, 0x82, 0x5c, 0x97, 0x71, 0x74, 0x85, + 0x83, 0x91, 0x89, 0x6f, 0x93, 0x94, 0x8b, 0xa9, 0x7d, 0x84, 0x80, 0x89, + 0x97, 0x80, 0x65, 0x92, 0x9a, 0x85, 0x5a, 0x6a, 0x6b, 0x58, 0x6f, 0x8c, + 0x9a, 0x8b, 0x6e, 0x81, 0x9d, 0xae, 0x8c, 0x86, 0x8d, 0x90, 0x6c, 0xb8, + 0x91, 0x89, 0x98, 0xbd, 0x8b, 0x78, 0x7d, 0x87, 0x9c, 0x72, 0x73, 0x80, + 0x9e, 0x92, 0x5d, 0x77, 0x78, 0x4f, 0x87, 0x7b, 0x7a, 0x9e, 0x74, 0x67, + 0x6a, 0x58, 0x95, 0x80, 0x75, 0x97, 0x81, 0x75, 0x94, 0x75, 0x73, 0x92, + 0x83, 0x7b, 0x6b, 0x8e, 0x82, 0x6e, 0x7d, 0x9b, 0x91, 0x7f, 0x9e, 0xaa, + 0x8c, 0xa3, 0xa8, 0x8c, 0x9a, 0xc1, 0x28, 0xac, 0x49, 0x9b, 0x59, 0x8a, + 0x60, 0xa7, 0x39, 0xa7, 0x75, 0x9b, 0x95, 0x94, 0x76, 0xb3, 0x4a, 0x6b, + 0x60, 0x6c, 0xa5, 0x71, 0x40, 0xc4, 0x4c, 0x7c, 0x76, 0x7b, 0x67, 0x76, + 0x76, 0xa4, 0x7b, 0x83, 0x67, 0x4d, 0x87, 0x87, 0x6e, 0x93, 0x84, 0x70, + 0x78, 0x41, 0x87, 0x9f, 0x7a, 0x8c, 0x87, 0x69, 0x73, 0x6c, 0x93, 0x73, + 0x77, 0xa2, 0x52, 0x72, 0x5c, 0x75, 0x6c, 0x8f, 0x65, 0x92, 0x87, 0x52, + 0x67, 0x54, 0x54, 0x75, 0x90, 0x9c, 0x91, 0x6f, 0xa3, 0x86, 0x87, 0x9c, + 0x99, 0x86, 0x9f, 0x71, 0x8a, 0x7a, 0x7a, 0x97, 0x7a, 0x86, 0x6c, 0x99, + 0x89, 0x7e, 0x9c, 0x83, 0x98, 0x78, 0x73, 0x7f, 0x91, 0x96, 0x9a, 0x8d, + 0xb0, 0x9e, 0x6a, 0x80, 0x92, 0x86, 0x95, 0x83, 0x94, 0x92, 0x6f, 0x86, + 0x8a, 0x52, 0x6e, 0x82, 0x84, 0x8b, 0x77, 0x88, 0x70, 0x54, 0x8f, 0x7f, + 0x7d, 0x7e, 0x57, 0x89, 0x6d, 0x6f, 0x9c, 0x93, 0x90, 0x93, 0x52, 0x70, + 0x75, 0x92, 0x73, 0x88, 0x93, 0x77, 0x77, 0x91, 0x89, 0xa2, 0x9d, 0xa6, + 0xae, 0x84, 0x7d, 0xab, 0x92, 0x7e, 0x9c, 0x98, 0x7b, 0xc3, 0x38, 0x98, + 0x4f, 0x97, 0x8f, 0x93, 0x62, 0xb8, 0x23, 0xa4, 0x6d, 0x9c, 0x81, 0x8e, + 0x6f, 0x9d, 0x56, 0x89, 0x50, 0x94, 0x70, 0x77, 0x5d, 0xb7, 0x60, 0x5b, + 0x72, 0x45, 0x81, 0x8c, 0x66, 
0xbc, 0x8f, 0x7f, 0x57, 0x43, 0x85, 0x96, + 0x5a, 0xb2, 0x91, 0x7d, 0x6c, 0x3a, 0x73, 0x92, 0x63, 0x93, 0x89, 0x90, + 0x7f, 0x52, 0x7f, 0x7b, 0xa1, 0xa6, 0x8f, 0x60, 0x78, 0x51, 0x5f, 0xac, + 0x7b, 0x89, 0x88, 0x97, 0x7e, 0x64, 0x57, 0x72, 0x6c, 0x96, 0x74, 0x78, + 0xab, 0x66, 0x62, 0x8d, 0x6f, 0x86, 0x91, 0x93, 0x7d, 0x74, 0x82, 0x80, + 0x73, 0x84, 0x9c, 0x8e, 0x68, 0x69, 0x9e, 0xa1, 0x8a, 0x83, 0x7a, 0x87, + 0x94, 0x8c, 0x83, 0x7e, 0x91, 0x92, 0x82, 0x7b, 0xa0, 0x8e, 0x73, 0x86, + 0xa9, 0x95, 0x7c, 0xa5, 0x6c, 0x6f, 0x8c, 0x87, 0xa6, 0x8a, 0x77, 0x86, + 0x7d, 0x79, 0x89, 0x75, 0x8f, 0x82, 0x54, 0x61, 0x82, 0x8e, 0x80, 0x84, + 0x7b, 0x8e, 0x61, 0x82, 0x86, 0x77, 0x7d, 0x7c, 0x7e, 0x6c, 0x7b, 0xad, + 0x7b, 0x90, 0x88, 0x80, 0x64, 0x83, 0x7e, 0xa7, 0x83, 0x7e, 0xb5, 0xbb, + 0x88, 0xd9, 0x21, 0x9a, 0x4d, 0x9f, 0x91, 0x97, 0x64, 0xb5, 0x1c, 0x8a, + 0x5f, 0xaf, 0x7e, 0x7b, 0x67, 0xad, 0x48, 0x7f, 0x4e, 0x87, 0x8f, 0x7c, + 0x46, 0xab, 0x70, 0x7f, 0x4b, 0x4e, 0x48, 0x8c, 0x63, 0xc5, 0xa2, 0x7f, + 0x68, 0x3b, 0x59, 0x7f, 0x53, 0xa1, 0x8e, 0x6e, 0x7a, 0x4a, 0x5f, 0x62, + 0x5b, 0xa1, 0x62, 0x78, 0x74, 0x57, 0x78, 0x91, 0x7b, 0x9b, 0x75, 0x73, + 0x73, 0x72, 0x94, 0x92, 0x79, 0xaa, 0x94, 0x75, 0x86, 0x58, 0x8c, 0x71, + 0x77, 0x91, 0xa5, 0x74, 0x8f, 0x73, 0x89, 0x77, 0x68, 0x8e, 0x90, 0x96, + 0x9f, 0x79, 0x77, 0x7d, 0x89, 0x9b, 0x8c, 0x94, 0x81, 0x88, 0x91, 0x8f, + 0x9b, 0x91, 0x78, 0x87, 0x82, 0x72, 0xa7, 0xa2, 0x85, 0x98, 0xa3, 0x91, + 0x83, 0x75, 0x72, 0x93, 0x80, 0x8f, 0x85, 0x70, 0x97, 0x58, 0x9f, 0x72, + 0x91, 0x8e, 0x93, 0x74, 0x97, 0x73, 0x74, 0x91, 0x80, 0x84, 0x96, 0x94, + 0x76, 0x69, 0x66, 0x9e, 0x81, 0x8a, 0x8b, 0x63, 0x65, 0x7c, 0xa1, 0x9a, + 0x72, 0x84, 0x9e, 0x89, 0x9a, 0x86, 0x98, 0x7f, 0x77, 0x85, 0x82, 0xaa, + 0xa3, 0x88, 0xac, 0x9e, 0x76, 0xca, 0x2b, 0xa0, 0x40, 0xad, 0x6f, 0x6c, + 0x66, 0xc8, 0x07, 0x9e, 0x3e, 0x9f, 0x85, 0x9f, 0x5e, 0xb7, 0x53, 0x91, + 0x56, 0x6d, 0x62, 0x95, 0x4c, 0xc7, 0x46, 0x56, 0x4b, 0x5d, 0x6f, 0x52, + 0x4d, 0xa3, 0x8c, 0x90, 0x78, 
0x4d, 0x58, 0x8d, 0x53, 0x93, 0x8e, 0x68, + 0x6f, 0x3b, 0x49, 0x86, 0x6e, 0x9d, 0x76, 0x74, 0x5b, 0x44, 0x7b, 0x8c, + 0x89, 0xb0, 0x64, 0x62, 0x6a, 0x6d, 0x7a, 0xae, 0x84, 0x95, 0x8c, 0x71, + 0x8b, 0x60, 0x82, 0x9e, 0x8c, 0xa8, 0x90, 0x66, 0xa1, 0x7b, 0x65, 0x82, + 0x8f, 0x7d, 0x8d, 0x78, 0x8e, 0x5f, 0x75, 0x88, 0x5d, 0x93, 0xa1, 0x93, + 0x6b, 0x67, 0x7a, 0xa7, 0x92, 0x8c, 0x65, 0x88, 0x95, 0x93, 0x87, 0x81, + 0x9c, 0x97, 0x62, 0x9d, 0x90, 0x62, 0xa1, 0x9f, 0x87, 0x94, 0x94, 0x99, + 0x92, 0x8f, 0x71, 0x80, 0x77, 0x82, 0x92, 0x78, 0x67, 0x69, 0x7e, 0x81, + 0x93, 0x89, 0x80, 0x9b, 0x71, 0x57, 0x63, 0x83, 0x7b, 0x9f, 0x5d, 0x92, + 0x85, 0x96, 0x7e, 0x92, 0x84, 0x7f, 0x81, 0xa3, 0xa8, 0x96, 0x91, 0x8e, + 0x8c, 0x8e, 0x7d, 0xb0, 0x86, 0x72, 0x9d, 0x8e, 0x8e, 0xd0, 0x05, 0x77, + 0x45, 0xad, 0x91, 0x95, 0x71, 0xb8, 0x01, 0x9a, 0x41, 0xb8, 0x94, 0x6e, + 0x63, 0xd3, 0x58, 0x8c, 0x5a, 0x89, 0x85, 0x83, 0x52, 0xc1, 0x7b, 0x6a, + 0x65, 0x6e, 0x73, 0x63, 0x68, 0xba, 0x67, 0x78, 0x79, 0x4a, 0x73, 0x8f, + 0x51, 0xc9, 0x85, 0x8a, 0x6b, 0x45, 0x6a, 0x8f, 0x6c, 0xad, 0x8a, 0x8d, + 0x6a, 0x6e, 0x6b, 0x7f, 0x86, 0xb4, 0x88, 0x7d, 0xaa, 0x71, 0x5c, 0x69, + 0x5d, 0xa8, 0x62, 0x7d, 0x6c, 0x6e, 0x6f, 0x6a, 0x7c, 0x9d, 0x7a, 0x83, + 0x7d, 0x79, 0x7b, 0x9c, 0x73, 0x93, 0x7f, 0x9d, 0x8c, 0x75, 0x78, 0x83, + 0x85, 0x88, 0x81, 0x81, 0x98, 0x79, 0xa3, 0xae, 0x5b, 0x90, 0x89, 0x9d, + 0x6d, 0x90, 0xa3, 0x8e, 0x87, 0x96, 0x60, 0xa7, 0x76, 0x82, 0x81, 0x84, + 0x84, 0x9c, 0x73, 0x8a, 0x6c, 0x58, 0x64, 0x96, 0x89, 0x8b, 0x76, 0x60, + 0x91, 0x72, 0x7f, 0x86, 0x9a, 0x89, 0x67, 0x7d, 0x77, 0x84, 0x73, 0x5c, + 0x67, 0x8a, 0x82, 0x8c, 0x8c, 0x94, 0x8a, 0xa2, 0xaa, 0x7e, 0x5f, 0x7f, + 0x86, 0x90, 0x96, 0xab, 0x8d, 0x91, 0x7c, 0xb6, 0x82, 0x8d, 0xb8, 0xa9, + 0x92, 0xea, 0x1b, 0x74, 0x25, 0xab, 0x8d, 0x61, 0x81, 0xd8, 0x2c, 0x86, + 0x2f, 0xcf, 0xa2, 0x84, 0x7f, 0xa4, 0x36, 0x86, 0x47, 0x8d, 0x60, 0x8a, + 0x62, 0xb1, 0x4a, 0x54, 0x48, 0x73, 0x64, 0x9d, 0x72, 0xb2, 0x76, 0x4c, + 0x8e, 0x4e, 0x76, 0x94, 0x7c, 
0xad, 0x74, 0x6c, 0x6c, 0x54, 0x7f, 0x63, + 0x97, 0xb3, 0x74, 0x6c, 0x99, 0x5f, 0x86, 0x6a, 0xa3, 0x94, 0x7c, 0x83, + 0x8d, 0x81, 0x79, 0xac, 0x61, 0x9b, 0x65, 0x7b, 0x66, 0x89, 0x60, 0x76, + 0x8d, 0x93, 0x8d, 0x84, 0x71, 0x65, 0x82, 0x8c, 0x94, 0xa7, 0x59, 0xa1, + 0x8b, 0x72, 0x84, 0x65, 0x75, 0x95, 0x62, 0x71, 0x71, 0x7e, 0x7b, 0x97, + 0x9b, 0x9a, 0x80, 0xb1, 0x77, 0x7a, 0x73, 0x8e, 0x9c, 0x8c, 0x7d, 0x96, + 0x89, 0x7d, 0x7e, 0x80, 0x8e, 0x93, 0x63, 0x72, 0x6b, 0x57, 0x78, 0x8f, + 0x90, 0x86, 0x62, 0x75, 0x7e, 0x54, 0x7d, 0x95, 0x85, 0x84, 0x73, 0x7b, + 0x8f, 0x9e, 0x72, 0x8c, 0x90, 0x96, 0x8e, 0x6c, 0x80, 0x8b, 0x9e, 0x8c, + 0x87, 0x8e, 0x9b, 0x97, 0x8f, 0x94, 0xa3, 0x6b, 0xad, 0x93, 0x8a, 0x96, + 0x8d, 0x91, 0xa6, 0x8a, 0x9e, 0xce, 0x6b, 0x98, 0x6d, 0xa9, 0x92, 0x92, + 0x7c, 0xe2, 0x63, 0x97, 0x42, 0xc8, 0xa3, 0xa0, 0x88, 0xdc, 0x75, 0x9b, + 0x51, 0x7d, 0x5c, 0x80, 0x89, 0xc0, 0x83, 0x5e, 0x5e, 0xa4, 0x3e, 0x74, + 0x9b, 0xb6, 0x7f, 0x63, 0x78, 0x7d, 0x74, 0x57, 0x93, 0xa2, 0x83, 0x70, + 0x5e, 0x7d, 0x60, 0x69, 0x93, 0x9e, 0x79, 0x86, 0x91, 0x67, 0x86, 0x95, + 0xa2, 0xad, 0x62, 0x74, 0x68, 0x7e, 0x7e, 0x82, 0x8c, 0xb0, 0xa0, 0x63, + 0x8b, 0x82, 0x8f, 0x8c, 0xa4, 0xa3, 0x76, 0x6c, 0x8e, 0x87, 0x72, 0x85, + 0xaa, 0xa4, 0x7f, 0x7b, 0x8e, 0x9a, 0x69, 0x91, 0x9d, 0xa0, 0x81, 0x92, + 0x90, 0x85, 0x66, 0x82, 0xa3, 0xa9, 0x7f, 0x8f, 0x83, 0x9d, 0x8b, 0x8d, + 0x96, 0xa3, 0x8f, 0x7a, 0x6d, 0x89, 0x74, 0x8a, 0xa9, 0xa9, 0x7b, 0x77, + 0x93, 0x8b, 0x63, 0x92, 0x99, 0x8b, 0x88, 0x4f, 0x87, 0x7c, 0x67, 0x78, + 0x83, 0xa5, 0xa5, 0x58, 0x8d, 0x70, 0x86, 0x82, 0x9e, 0xa7, 0xa5, 0x96, + 0x8d, 0x7b, 0x96, 0x8c, 0x95, 0xa3, 0x8d, 0x9c, 0x92, 0x95, 0x98, 0x94, + 0x87, 0x90, 0x92, 0x92, 0x95, 0x96, 0xad, 0x6e, 0x97, 0x8c, 0x92, 0x7f, + 0x95, 0x8b, 0x8a, 0x90, 0x9b, 0x87, 0x9e, 0x86, 0x91, 0xa0, 0x68, 0x82, + 0x85, 0x8e, 0x82, 0xa8, 0x9f, 0x68, 0x87, 0x75, 0x9b, 0x70, 0x95, 0x91, + 0x6c, 0x77, 0x8b, 0x7b, 0x95, 0x80, 0x99, 0x65, 0x95, 0x82, 0x92, 0x9a, + 0x8a, 0x65, 0x70, 0x8c, 0x98, 
0x9e, 0x80, 0x7b, 0xa5, 0x9b, 0x93, 0x94, + 0x84, 0x6a, 0x69, 0x82, 0x80, 0x7a, 0x75, 0x72, 0x94, 0x79, 0xad, 0xb2, + 0x81, 0x8b, 0x85, 0x6c, 0x86, 0x88, 0x9e, 0x79, 0x86, 0x9e, 0x7e, 0x91, + 0x7b, 0x6d, 0x93, 0x91, 0x82, 0x97, 0x6b, 0xa6, 0xaa, 0x9f, 0xa8, 0x74, + 0x94, 0x7f, 0x63, 0x98, 0x90, 0xa1, 0x8c, 0x7f, 0x71, 0x86, 0x89, 0x95, + 0x88, 0x80, 0x77, 0x67, 0x85, 0x7d, 0x89, 0x6d, 0x9c, 0x76, 0x72, 0x8d, + 0x96, 0x94, 0x88, 0x98, 0x9f, 0x94, 0x8e, 0x84, 0x7a, 0x88, 0x79, 0x9f, + 0x81, 0xa1, 0x7c, 0x8b, 0x71, 0x79, 0x7d, 0x9d, 0x7b, 0x6a, 0x8c, 0x66, + 0x9e, 0x7b, 0x77, 0x7a, 0xb0, 0x74, 0x7f, 0x8d, 0x8d, 0x71, 0x72, 0x84, + 0x90, 0x98, 0x7b, 0x89, 0x9b, 0x8e, 0x85, 0x7a, 0x67, 0x8a, 0x72, 0x84, + 0x82, 0x91, 0x91, 0x7a, 0x85, 0x8a, 0xae, 0x8a, 0x9a, 0x9a, 0x7f, 0x85, + 0x8a, 0x90, 0x69, 0x7b, 0x76, 0x78, 0x98, 0x54, 0x94, 0x7e, 0x6c, 0x72, + 0x89, 0x88, 0x82, 0x96, 0x59, 0x95, 0x76, 0x91, 0x94, 0x96, 0x83, 0x84, + 0x72, 0x8d, 0x97, 0x71, 0x68, 0x8e, 0x88, 0x8b, 0x7c, 0xa9, 0x73, 0x8a, + 0x95, 0x86, 0x87, 0x96, 0x91, 0x77, 0xb1, 0x88, 0x6e, 0x7d, 0x7c, 0x9f, + 0x8f, 0x82, 0x79, 0x83, 0xa6, 0x81, 0x89, 0x83, 0x85, 0x9b, 0x7c, 0x68, + 0x6f, 0x84, 0x7c, 0xa1, 0x8e, 0x80, 0x78, 0x8f, 0x96, 0x77, 0x7e, 0x7b, + 0x8f, 0x81, 0xa5, 0x84, 0x86, 0x91, 0x7b, 0x73, 0x92, 0x85, 0xa3, 0x7e, + 0x80, 0x95, 0x7d, 0x5f, 0x8c, 0x94, 0x95, 0x73, 0x95, 0x78, 0x87, 0xa1, + 0x94, 0x6c, 0xac, 0x6c, 0x77, 0x89, 0x86, 0x9c, 0x82, 0x76, 0x99, 0x93, + 0x92, 0x88, 0x80, 0x80, 0x85, 0x8a, 0xa8, 0x8f, 0x7a, 0x89, 0x9a, 0x7a, + 0x8f, 0x91, 0x86, 0x82, 0x7f, 0x82, 0x91, 0x95, 0x85, 0x71, 0x7d, 0x8f, + 0x83, 0x8c, 0x79, 0x97, 0x7a, 0x9b, 0x91, 0x88, 0xa2, 0x86, 0x8a, 0x80, + 0xa0, 0x96, 0x8b, 0x7d, 0x76, 0x96, 0x9f, 0x8d, 0x95, 0x8a, 0x94, 0xa0, + 0x80, 0x95, 0x9b, 0x96, 0x81, 0xa8, 0x59, 0x89, 0x92, 0xb2, 0x83, 0x89, + 0x85, 0x81, 0x7e, 0x64, 0x77, 0x82, 0x90, 0x96, 0x7e, 0x9f, 0xab, 0x8a, + 0x6e, 0x9b, 0x90, 0x89, 0x6e, 0x7d, 0x81, 0x65, 0x81, 0x86, 0xa1, 0x93, + 0x8b, 0x83, 0x81, 0x89, 0x8b, 
0x90, 0x7e, 0x97, 0x8e, 0x75, 0x7e, 0x7e, + 0x7b, 0x81, 0x9a, 0x64, 0x90, 0xab, 0x90, 0x82, 0x8a, 0x82, 0x8d, 0xad, + 0x90, 0x74, 0x7f, 0x9a, 0x88, 0x92, 0x83, 0x97, 0xa6, 0x6e, 0x9d, 0x81, + 0xa2, 0x98, 0x74, 0x84, 0x93, 0x85, 0x84, 0x7d, 0xa2, 0x92, 0x92, 0x87, + 0x73, 0x8b, 0x92, 0x74, 0x96, 0x70, 0x83, 0x86, 0x8a, 0x89, 0x86, 0x88, + 0x87, 0x7c, 0x7d, 0x81, 0x8d, 0x71, 0x8c, 0x89, 0x70, 0x94, 0x8f, 0x9a, + 0x83, 0x9d, 0x99, 0x78, 0x74, 0x88, 0x84, 0x9a, 0x95, 0x8b, 0x8e, 0x7f, + 0xa2, 0xa0, 0x76, 0x93, 0x9b, 0x7c, 0x97, 0x81, 0x83, 0x8c, 0xa1, 0x99, + 0x9d, 0x7f, 0x87, 0x75, 0xa7, 0x75, 0x89, 0x7e, 0x88, 0x80, 0x8f, 0x84, + 0x9a, 0x77, 0x8d, 0x90, 0x9d, 0x6c, 0x88, 0x8d, 0x8e, 0x81, 0x97, 0x6d, + 0x81, 0x88, 0x64, 0x8c, 0x77, 0x8e, 0x91, 0x8a, 0x7f, 0x8a, 0x94, 0x7a, + 0x89, 0x93, 0x8c, 0x69, 0x85, 0x8c, 0x93, 0x61, 0x7e, 0x89, 0x7e, 0x8a, + 0x65, 0x8a, 0xa9, 0x7f, 0x80, 0x86, 0x82, 0x90, 0x66, 0x7a, 0x99, 0x71, + 0x7f, 0x73, 0x8d, 0x94, 0x7d, 0x73, 0x7a, 0x7d, 0x87, 0x7a, 0x97, 0x70, + 0x81, 0x60, 0x61, 0x7a, 0x91, 0x88, 0x93, 0x7a, 0x9e, 0xa6, 0x92, 0x9d, + 0x92, 0x67, 0x99, 0x9a, 0xae, 0x71, 0x89, 0xa5, 0x9f, 0xa6, 0x98, 0x89, + 0x97, 0x90, 0x9b, 0x9a, 0xc0, 0x95, 0x8f, 0x9c, 0x95, 0x93, 0x88, 0x95, + 0x95, 0xa0, 0x8e, 0x8c, 0xa8, 0x94, 0x6e, 0x9e, 0x6f, 0x7b, 0xa5, 0x96, + 0x98, 0x90, 0x91, 0x89, 0x93, 0x8f, 0x84, 0xb2, 0x7f, 0x5e, 0xc2, 0x75, + 0x8f, 0x90, 0x9c, 0xbf, 0x8a, 0x84, 0xa6, 0x85, 0x7d, 0x84, 0x8a, 0xad, + 0x6f, 0x88, 0xac, 0x77, 0x91, 0x8d, 0x94, 0xac, 0x8f, 0x7f, 0xa1, 0xa5, + 0x8e, 0x6d, 0x8a, 0x82, 0x85, 0x80, 0x9b, 0x7a, 0x9f, 0x60, 0x95, 0x97, + 0x90, 0x67, 0x8f, 0x91, 0x86, 0x89, 0x88, 0x89, 0x96, 0x6c, 0x8b, 0x94, + 0x8a, 0x75, 0x84, 0x96, 0x8a, 0x86, 0x7c, 0x91, 0x74, 0x8f, 0x97, 0x89, + 0x8f, 0x8e, 0x6b, 0x97, 0x93, 0x89, 0x6b, 0x7e, 0x65, 0xa4, 0xa5, 0x63, + 0x85, 0x88, 0x81, 0xa3, 0x70, 0x9b, 0x9e, 0x8c, 0x62, 0x73, 0x85, 0xb4, + 0x88, 0x6e, 0x92, 0x6f, 0x91, 0x88, 0x79, 0x91, 0x7f, 0x7d, 0x9a, 0x6b, + 0x78, 0x93, 0x7e, 0x79, 0x93, 
0x7a, 0x74, 0x91, 0x8d, 0x92, 0xb3, 0x61, + 0xa3, 0x76, 0x81, 0x99, 0x96, 0x8b, 0x93, 0x8f, 0xa7, 0x6f, 0x8f, 0xa6, + 0xb2, 0x76, 0xa1, 0x83, 0xa8, 0x8b, 0xae, 0x99, 0x90, 0x6a, 0x97, 0x97, + 0xaa, 0x95, 0x85, 0x7d, 0x97, 0x94, 0x86, 0x94, 0x89, 0xa4, 0xa9, 0x81, + 0x89, 0x7c, 0x96, 0xb3, 0x92, 0x7d, 0xa4, 0x6f, 0x6d, 0x92, 0x83, 0xb4, + 0x7b, 0x94, 0x8c, 0x79, 0x61, 0x6f, 0x8f, 0xb7, 0x88, 0x66, 0xaa, 0x7d, + 0x89, 0x7f, 0x90, 0xbd, 0x99, 0xac, 0xb1, 0x96, 0x9c, 0x7c, 0x92, 0xb7, + 0x73, 0x94, 0xad, 0x9d, 0x7c, 0x80, 0x87, 0x96, 0x73, 0x8d, 0xa8, 0x88, + 0xa9, 0x83, 0x7b, 0x84, 0x9d, 0x99, 0x83, 0x89, 0x9d, 0x7f, 0x7e, 0x86, + 0x75, 0x83, 0x77, 0x7d, 0x8b, 0x7d, 0x80, 0x9d, 0xa2, 0x94, 0x72, 0x92, + 0x75, 0x95, 0x99, 0xa0, 0x7b, 0x83, 0x99, 0x89, 0x82, 0x92, 0x5b, 0x9e, + 0x7c, 0x91, 0x95, 0x79, 0x61, 0x86, 0x60, 0xc7, 0x72, 0x91, 0xb5, 0x88, + 0x71, 0x8d, 0x85, 0x91, 0x83, 0x74, 0xa8, 0x67, 0x79, 0x77, 0x7f, 0x79, + 0x68, 0x84, 0x95, 0x69, 0x98, 0x88, 0x74, 0x72, 0x9c, 0x86, 0x87, 0x95, + 0x90, 0x95, 0x9b, 0x8b, 0xc5, 0x7d, 0x81, 0x8f, 0x88, 0x8c, 0xb0, 0x95, + 0xa8, 0x8c, 0x84, 0xa0, 0xb0, 0x89, 0x9a, 0x90, 0xaa, 0x88, 0x96, 0x9b, + 0x88, 0xa9, 0x89, 0x99, 0xb7, 0x82, 0x99, 0xa0, 0x85, 0x70, 0x9c, 0x9a, + 0x94, 0x74, 0x91, 0x81, 0x76, 0x70, 0x8f, 0xc2, 0x8c, 0x91, 0x8f, 0x69, + 0x74, 0x7e, 0x6d, 0x9a, 0x80, 0x77, 0xa5, 0x94, 0x8b, 0x6d, 0x82, 0xcf, + 0x8e, 0x74, 0xc4, 0x86, 0x7f, 0x78, 0x72, 0xb3, 0x78, 0x7a, 0xac, 0x9c, + 0x7d, 0x77, 0x8d, 0xca, 0x67, 0x8c, 0xd5, 0x8f, 0x7f, 0x71, 0x70, 0x82, + 0x7e, 0x9f, 0xb0, 0x7f, 0x75, 0x90, 0x79, 0x7b, 0x8d, 0x7b, 0xa6, 0x87, + 0x98, 0x76, 0x84, 0x96, 0x81, 0x6a, 0x96, 0x86, 0x8e, 0x77, 0xa3, 0x83, + 0x91, 0x83, 0x8a, 0x6c, 0x74, 0x83, 0x99, 0x7d, 0x7c, 0x8a, 0x88, 0x9a, + 0x6b, 0x86, 0x59, 0xa3, 0x8a, 0x8e, 0xbb, 0x8a, 0x75, 0x78, 0x68, 0xb5, + 0x9b, 0x7b, 0xa7, 0x93, 0x5b, 0x6c, 0x6b, 0xa0, 0x74, 0x99, 0xc0, 0x73, + 0x8b, 0x7e, 0x8e, 0x83, 0x64, 0x7c, 0x7d, 0x7a, 0x98, 0x7d, 0x82, 0x7c, + 0x8f, 0x7e, 0x74, 0x86, 0xa9, 
0x84, 0xba, 0x8f, 0xc7, 0x6f, 0x87, 0xae, + 0x97, 0x91, 0xad, 0x82, 0xb2, 0x70, 0x8a, 0xa0, 0xb0, 0x7d, 0x95, 0x8d, + 0xc2, 0x85, 0x80, 0xad, 0x9f, 0x85, 0x8b, 0x76, 0xaa, 0xab, 0x8f, 0xa0, + 0x89, 0x9b, 0x8a, 0xb3, 0xa0, 0x72, 0xbe, 0x8c, 0x93, 0x7a, 0xa0, 0xad, + 0x99, 0x6f, 0xa2, 0x79, 0x78, 0x8b, 0x6d, 0xae, 0x75, 0x6f, 0xa1, 0x8d, + 0x68, 0x81, 0x74, 0xb3, 0x8f, 0x81, 0xc6, 0x96, 0x77, 0x68, 0x85, 0xaf, + 0x86, 0x9f, 0xbb, 0x8a, 0x7e, 0x8a, 0x86, 0xab, 0x8b, 0x87, 0x94, 0x96, + 0x99, 0x82, 0x6a, 0xaa, 0x7b, 0x81, 0xa6, 0x9b, 0xb6, 0x73, 0x78, 0x9a, + 0x8f, 0xaa, 0x93, 0x81, 0x97, 0x7a, 0x72, 0x82, 0x79, 0x81, 0x7c, 0x88, + 0x8e, 0x79, 0x9d, 0x81, 0x9a, 0x75, 0x9b, 0x89, 0x73, 0x6a, 0xa6, 0x84, + 0x5c, 0x6f, 0xa0, 0x9d, 0x81, 0x84, 0x3e, 0xaf, 0x94, 0xa1, 0xb8, 0x93, + 0x81, 0x89, 0x68, 0xd4, 0x87, 0x99, 0x99, 0x95, 0x79, 0x72, 0x81, 0xa1, + 0x78, 0x7d, 0x8f, 0x7e, 0x87, 0x78, 0x8e, 0x97, 0x7e, 0x96, 0x86, 0x86, + 0x97, 0x74, 0x6f, 0x7d, 0xa5, 0x81, 0x6f, 0x8e, 0x9e, 0x8b, 0xad, 0xac, + 0xbd, 0x75, 0x84, 0xa2, 0x93, 0x76, 0xc7, 0x9e, 0xb0, 0x75, 0x89, 0xa4, + 0x95, 0x92, 0xb5, 0xaa, 0xb9, 0x7d, 0x79, 0xa5, 0x88, 0x70, 0x84, 0x70, + 0xa3, 0x81, 0xa1, 0xa6, 0x8f, 0x96, 0x96, 0x8d, 0xa5, 0x83, 0xb2, 0x8f, + 0x88, 0x74, 0x96, 0xbc, 0x8b, 0x81, 0xa4, 0x85, 0x7c, 0x87, 0x64, 0xb4, + 0x80, 0x88, 0x92, 0x90, 0x78, 0x79, 0x77, 0xa5, 0x79, 0x8b, 0xbd, 0x7d, + 0x84, 0x8c, 0x96, 0xd4, 0x78, 0x81, 0xa4, 0x8c, 0x97, 0x89, 0x78, 0xc4, + 0x9f, 0x94, 0xb9, 0x83, 0x76, 0x78, 0x89, 0x86, 0x81, 0x8f, 0xbd, 0xa7, + 0x88, 0x79, 0x8e, 0x92, 0x86, 0x88, 0xad, 0x8a, 0x7b, 0x7f, 0x80, 0xad, + 0x7a, 0xaf, 0x8a, 0x93, 0xa6, 0x84, 0x92, 0x8e, 0x84, 0x99, 0x80, 0xae, + 0x74, 0x7c, 0x95, 0x9c, 0x7b, 0x84, 0x84, 0x84, 0xa4, 0x82, 0x57, 0xb5, + 0x95, 0xc1, 0xb7, 0xa0, 0x85, 0x7b, 0x69, 0xc3, 0xb1, 0x8e, 0xa0, 0x8e, + 0x81, 0x88, 0x78, 0x9e, 0x81, 0x97, 0xb2, 0x74, 0x81, 0x84, 0x91, 0x87, + 0x6f, 0x6f, 0x75, 0x78, 0x92, 0x7a, 0x6d, 0x80, 0x9a, 0x7e, 0x81, 0xa1, + 0xa8, 0x6d, 0xb5, 0x98, 0xb4, 
0x7f, 0x9a, 0xa4, 0x9d, 0x7b, 0xba, 0xaa, + 0xce, 0x93, 0x79, 0xa5, 0x81, 0x95, 0xa6, 0x7f, 0x8c, 0x8b, 0x96, 0xa4, + 0xa1, 0x8d, 0x91, 0x97, 0xce, 0x8e, 0x8e, 0x9d, 0x86, 0x7f, 0x97, 0xa3, + 0x99, 0x75, 0xa3, 0xa0, 0x69, 0x6a, 0x87, 0xa0, 0x9a, 0x80, 0xa2, 0x72, + 0x6d, 0x85, 0x6b, 0x94, 0x8d, 0x77, 0x9f, 0x84, 0x7f, 0x92, 0x64, 0xaa, + 0x78, 0x82, 0xa7, 0x8f, 0x84, 0x79, 0x84, 0xb9, 0x92, 0x7c, 0xb6, 0x96, + 0x9c, 0x99, 0x8f, 0xab, 0xab, 0x8a, 0xa2, 0xab, 0x6d, 0x97, 0x7b, 0xb1, + 0x9e, 0x6c, 0x9a, 0x99, 0xaa, 0xa3, 0x70, 0x80, 0x81, 0x6f, 0xb6, 0x95, + 0x93, 0x93, 0x8e, 0x80, 0x86, 0xb0, 0x87, 0x91, 0x8f, 0x8c, 0xa4, 0x86, + 0x89, 0x8f, 0x93, 0x83, 0x75, 0x7d, 0x9b, 0x86, 0x7d, 0x5a, 0x9d, 0x67, + 0x9f, 0x78, 0x5c, 0xa5, 0x8e, 0xa2, 0xc1, 0x95, 0x89, 0x84, 0x53, 0xd1, + 0x7d, 0x9b, 0xc0, 0x8f, 0x73, 0x7f, 0x85, 0x9e, 0x8a, 0x7b, 0xa6, 0x84, + 0x6c, 0x74, 0x95, 0x93, 0x7a, 0x7a, 0x81, 0x7d, 0x89, 0x86, 0x76, 0x8a, + 0xad, 0x66, 0x90, 0x90, 0x9d, 0x77, 0xb4, 0xad, 0xac, 0x8e, 0xb3, 0xa5, + 0x9d, 0x91, 0xd7, 0x94, 0xba, 0x8b, 0x72, 0xa4, 0x93, 0x7e, 0xa7, 0x86, + 0xae, 0x83, 0x63, 0xa6, 0xa0, 0x78, 0x81, 0x8b, 0xc4, 0x82, 0x8f, 0x98, + 0xa1, 0x8f, 0x79, 0x9a, 0x92, 0x85, 0x9d, 0x91, 0x92, 0x84, 0x8f, 0x84, + 0x91, 0x6d, 0x7b, 0x69, 0x75, 0x87, 0x5d, 0x99, 0x92, 0x83, 0xab, 0x8f, + 0x53, 0x90, 0x7b, 0xa0, 0x71, 0x89, 0xc2, 0x7f, 0x6a, 0x7c, 0x86, 0xb2, + 0x8d, 0x89, 0xaf, 0x9c, 0x81, 0x8c, 0x84, 0xbe, 0x93, 0x9c, 0xa8, 0x97, + 0x68, 0x9b, 0x84, 0xa3, 0x8a, 0x77, 0xa5, 0x79, 0x7b, 0x87, 0x86, 0xa5, + 0x80, 0x83, 0x9e, 0x8d, 0xb1, 0x94, 0x7a, 0x8b, 0xa6, 0xa8, 0x80, 0x98, + 0x8c, 0x73, 0xa9, 0x7b, 0x91, 0x8f, 0x71, 0x82, 0x68, 0x84, 0xa5, 0x96, + 0x67, 0x63, 0xa6, 0x71, 0xa7, 0x85, 0x57, 0x9f, 0x91, 0xb2, 0xa6, 0x87, + 0x80, 0x8f, 0x6a, 0xba, 0x9d, 0xb7, 0xb9, 0x8b, 0x75, 0x7c, 0x6f, 0x9f, + 0x74, 0x8d, 0xaf, 0x6e, 0x7c, 0x65, 0x6c, 0x8a, 0x7c, 0x81, 0x89, 0x77, + 0x8b, 0x74, 0x65, 0x9b, 0xa5, 0x6b, 0x92, 0x71, 0xbb, 0x70, 0x99, 0xbf, + 0xb0, 0x7b, 0x92, 0xb4, 0xa4, 
0x84, 0xc4, 0x92, 0xa8, 0x94, 0x7e, 0xcd, + 0x83, 0x87, 0xaf, 0xa0, 0xa5, 0x94, 0x72, 0xb9, 0x90, 0xa6, 0x9e, 0x9e, + 0x9b, 0x7a, 0x68, 0xc0, 0x8f, 0x89, 0x72, 0x94, 0x9b, 0x81, 0x81, 0x91, + 0x88, 0x90, 0xa8, 0x8d, 0x90, 0x78, 0x7c, 0x67, 0x64, 0x8e, 0x55, 0xa1, + 0x6d, 0x86, 0xa3, 0x6f, 0x5c, 0x7d, 0x79, 0xa3, 0x64, 0x71, 0xd4, 0x87, + 0x73, 0x85, 0x76, 0xc7, 0x72, 0x86, 0xb2, 0x8c, 0x7b, 0x8d, 0x96, 0xc3, + 0xad, 0x87, 0xac, 0xa8, 0x84, 0x94, 0x7b, 0xbf, 0x83, 0x74, 0x8e, 0x8c, + 0x9c, 0x99, 0x88, 0x8e, 0x86, 0x88, 0xae, 0x7f, 0x70, 0x96, 0x6f, 0x74, + 0x8f, 0x85, 0x7c, 0x86, 0x97, 0x83, 0xa0, 0x6a, 0x8b, 0x82, 0x88, 0x90, + 0x72, 0x84, 0x9b, 0xa1, 0x6f, 0x72, 0xa4, 0x95, 0xa6, 0x7d, 0x65, 0xbd, + 0x90, 0xb6, 0x9e, 0x98, 0xa1, 0x94, 0x66, 0xb3, 0x9c, 0xb3, 0xa7, 0x7f, + 0x91, 0x69, 0x6e, 0xb1, 0x68, 0x7a, 0xaa, 0x91, 0x7c, 0x71, 0x9f, 0x95, + 0x83, 0x86, 0x76, 0x69, 0x9b, 0x7f, 0x8c, 0x94, 0x9c, 0x89, 0x86, 0x93, + 0xc1, 0x79, 0x98, 0x9e, 0xb1, 0x90, 0x9b, 0xb7, 0xab, 0x86, 0xc6, 0xa1, + 0xa9, 0xaa, 0x86, 0xb0, 0x8b, 0x79, 0xb9, 0x85, 0xbe, 0x92, 0x60, 0xc0, + 0x9f, 0x9a, 0x90, 0x8d, 0xb5, 0x77, 0x95, 0xad, 0x8b, 0x93, 0x8a, 0x93, + 0x93, 0x7e, 0x86, 0xa6, 0x7d, 0x89, 0x6b, 0x81, 0x93, 0x75, 0x7f, 0x86, + 0x66, 0x8f, 0x56, 0x8f, 0x84, 0x75, 0x9e, 0x77, 0x78, 0x89, 0x62, 0xb3, + 0x78, 0x76, 0xb5, 0x92, 0x7f, 0x80, 0x7a, 0xb9, 0x7d, 0x80, 0xc2, 0xb9, + 0x7d, 0x8f, 0x8f, 0x8c, 0xa0, 0x78, 0xa2, 0xaf, 0x68, 0x98, 0x77, 0xac, + 0x96, 0x77, 0x96, 0x99, 0x84, 0xb1, 0x72, 0x8e, 0x96, 0xa4, 0xa9, 0x8e, + 0x84, 0x7b, 0x85, 0x8d, 0x8f, 0x83, 0x83, 0x7f, 0x85, 0x6e, 0xa4, 0x98, + 0xab, 0x83, 0x90, 0x8e, 0x77, 0x8e, 0xab, 0x9c, 0x73, 0x79, 0x8d, 0x6e, + 0xa0, 0x97, 0x68, 0xa7, 0x8a, 0xbd, 0x95, 0x96, 0x96, 0x8b, 0x72, 0xc7, + 0x8d, 0x8c, 0xa5, 0x83, 0x9b, 0x8b, 0x6c, 0xac, 0x62, 0x78, 0xae, 0x78, + 0x71, 0x7a, 0x8d, 0xae, 0x91, 0x87, 0x90, 0x82, 0x9b, 0x83, 0x90, 0x97, + 0xb0, 0x96, 0x82, 0xa5, 0xa9, 0x76, 0xa5, 0xa0, 0xac, 0xa1, 0x93, 0x94, + 0xb7, 0x91, 0xbb, 0x9b, 0xa4, 
0xa5, 0x8c, 0xb5, 0x95, 0x7b, 0x92, 0x91, + 0xb0, 0x97, 0x73, 0xb9, 0x86, 0xa7, 0x92, 0x98, 0x9e, 0x70, 0x77, 0xba, + 0x96, 0x7b, 0xa6, 0x86, 0x97, 0x85, 0x8e, 0xaa, 0x93, 0x97, 0x8f, 0x8b, + 0x8d, 0x79, 0x84, 0x7e, 0x70, 0x95, 0x52, 0x8f, 0x62, 0x75, 0x8b, 0x8b, + 0x7b, 0x8b, 0x79, 0xaf, 0x90, 0x6d, 0xc8, 0x8d, 0x84, 0x8c, 0x72, 0xaf, + 0x70, 0x8d, 0xa5, 0x8a, 0x76, 0x97, 0x87, 0x8e, 0xa9, 0x83, 0xb2, 0x8d, + 0x7e, 0x9b, 0x76, 0xc2, 0xa2, 0x72, 0xc5, 0x87, 0x75, 0xb7, 0x92, 0x95, + 0x9e, 0xa0, 0xc3, 0x82, 0x8d, 0x8f, 0x7d, 0x85, 0x90, 0x99, 0x7b, 0x82, + 0x87, 0x87, 0xa0, 0x87, 0x9a, 0x8b, 0xa2, 0xa4, 0x67, 0x93, 0xa5, 0xbb, + 0x73, 0x5f, 0x8c, 0x60, 0xa5, 0x7d, 0x6c, 0xb3, 0xb2, 0xb3, 0xa9, 0xa9, + 0x8d, 0x8d, 0x67, 0xd7, 0x63, 0x99, 0xaa, 0x83, 0x88, 0x6a, 0x6f, 0x9e, + 0x5e, 0x9e, 0x9d, 0x81, 0x84, 0x6e, 0x98, 0x90, 0x89, 0x7c, 0x95, 0x7d, + 0x81, 0x8a, 0xa2, 0x8c, 0x92, 0x85, 0x80, 0x92, 0xac, 0x80, 0x9b, 0x9b, + 0xc3, 0x8c, 0x95, 0xbc, 0xaa, 0x7c, 0xb5, 0x8d, 0xa1, 0xb8, 0x70, 0xb6, + 0x8c, 0x92, 0xa8, 0x8e, 0xa3, 0x76, 0x6c, 0xbe, 0xa0, 0x8c, 0x92, 0x8e, + 0xa1, 0x83, 0x76, 0xb2, 0x91, 0x7b, 0x8e, 0x87, 0x7f, 0x89, 0x8a, 0xa1, + 0x91, 0xa0, 0x7a, 0x95, 0x7b, 0x86, 0x99, 0x92, 0x78, 0x8a, 0x62, 0x9e, + 0x7b, 0x7b, 0x89, 0x79, 0x78, 0x87, 0x82, 0x94, 0x7d, 0x91, 0x96, 0x79, + 0x7b, 0x8d, 0x80, 0xa7, 0x88, 0x95, 0xa6, 0x8f, 0x7d, 0x95, 0x79, 0xa2, + 0x91, 0x9b, 0x9d, 0x90, 0x79, 0xa4, 0x88, 0x98, 0x9b, 0x7a, 0xa5, 0x7f, + 0x71, 0x9c, 0x87, 0x96, 0x8c, 0x8f, 0xbc, 0x74, 0x95, 0x99, 0x7f, 0x78, + 0x8c, 0x63, 0x7c, 0x7a, 0x92, 0x8c, 0xa8, 0x78, 0xa8, 0x89, 0x9a, 0x86, + 0x69, 0x7e, 0xa1, 0xc3, 0x57, 0x68, 0x84, 0x89, 0xa9, 0x8d, 0x6f, 0xa9, + 0x8a, 0xab, 0xa5, 0xad, 0x94, 0x83, 0x6b, 0xa7, 0x7e, 0x95, 0x9b, 0x7f, + 0x8b, 0x78, 0x73, 0x90, 0x65, 0x8d, 0xb1, 0x91, 0x84, 0x65, 0x90, 0xb4, + 0x8c, 0x89, 0x94, 0x7c, 0x99, 0x8b, 0x98, 0xb7, 0xb0, 0x91, 0x9e, 0x88, + 0xbd, 0xa0, 0xa4, 0xb9, 0xad, 0x96, 0x97, 0xa3, 0xb6, 0x81, 0xba, 0x9b, + 0xbc, 0xa9, 0x94, 0xb9, 0xa0, 
0x85, 0x8e, 0xa1, 0xac, 0x87, 0x65, 0xa6, + 0x98, 0x8e, 0xaa, 0xa3, 0xa3, 0x7f, 0x79, 0xb4, 0x93, 0x76, 0x90, 0x99, + 0x8b, 0x90, 0x84, 0xa6, 0x90, 0x8f, 0x88, 0xa6, 0x89, 0x83, 0x86, 0x7a, + 0x5d, 0x96, 0x71, 0xa5, 0x64, 0x94, 0x9a, 0x85, 0x7c, 0xa1, 0x96, 0x9d, + 0x76, 0x8f, 0x95, 0xa0, 0x7f, 0x8c, 0x80, 0xc7, 0x6c, 0x7d, 0xb7, 0xb2, + 0x82, 0x8e, 0x82, 0xbd, 0xb3, 0x82, 0x99, 0x9b, 0x80, 0x94, 0x8c, 0x94, + 0x94, 0x6b, 0xc6, 0xa9, 0x81, 0x9f, 0x8c, 0x7e, 0x87, 0x88, 0xb3, 0x7d, + 0x88, 0x8c, 0x81, 0x81, 0x7e, 0x7e, 0x86, 0x87, 0x96, 0x85, 0xb4, 0x87, + 0xab, 0x91, 0x8f, 0xa1, 0x72, 0x83, 0xa4, 0x89, 0x6b, 0x75, 0x85, 0x7c, + 0x94, 0x85, 0x6f, 0xad, 0x91, 0xae, 0xa4, 0xa5, 0xa7, 0x8e, 0x6c, 0xb2, + 0x73, 0x99, 0x96, 0x92, 0x89, 0x81, 0x7d, 0x88, 0x60, 0x8d, 0x94, 0x83, + 0x99, 0x68, 0x86, 0xa2, 0x94, 0x8e, 0x82, 0x76, 0x89, 0x8d, 0x98, 0x86, + 0x94, 0x90, 0x83, 0x7d, 0xad, 0x94, 0xa6, 0x90, 0xcb, 0x96, 0xa2, 0xb2, + 0xb6, 0x89, 0xc4, 0x9d, 0xc7, 0xa5, 0x75, 0xc3, 0x92, 0x8c, 0x8e, 0xad, + 0x96, 0x94, 0x8e, 0xab, 0x94, 0x90, 0xa8, 0x84, 0xb5, 0x84, 0x66, 0xce, + 0x74, 0x8c, 0x93, 0x8d, 0x8f, 0x95, 0x8b, 0xa1, 0x7b, 0xa1, 0x79, 0x9e, + 0x81, 0xa4, 0xa0, 0x98, 0x5f, 0x78, 0x8e, 0x97, 0x6f, 0x81, 0x96, 0x8d, + 0x70, 0x93, 0x72, 0x9c, 0x7b, 0x98, 0x8b, 0x8a, 0x8f, 0x8b, 0x6c, 0xa9, + 0x81, 0x99, 0xb3, 0xa3, 0x71, 0x9c, 0x8b, 0x94, 0xa6, 0x8a, 0xb8, 0xa0, + 0x7b, 0x98, 0x74, 0x9f, 0x92, 0x92, 0xb2, 0x89, 0x81, 0xa8, 0x87, 0x97, + 0x96, 0x86, 0xa4, 0x7b, 0x63, 0x8e, 0x86, 0x7d, 0x76, 0x81, 0x93, 0x94, + 0x98, 0x8b, 0xaf, 0x6d, 0xab, 0x9b, 0x85, 0x9b, 0x91, 0x86, 0x95, 0x95, + 0x65, 0x89, 0x9e, 0x6b, 0xa4, 0x82, 0x68, 0xb5, 0x8b, 0xd1, 0x9d, 0x93, + 0x7d, 0x67, 0x5e, 0xba, 0x9b, 0x94, 0x93, 0x8d, 0x88, 0x73, 0x7c, 0x8e, + 0x7d, 0x83, 0x9a, 0x82, 0xa4, 0x62, 0x9a, 0x8d, 0x86, 0xa0, 0x7b, 0x72, + 0xa9, 0x84, 0xa7, 0x94, 0xb2, 0x98, 0x8f, 0x81, 0xbe, 0x84, 0x9d, 0x94, + 0x9c, 0x9a, 0x94, 0x8f, 0xb1, 0x82, 0xb1, 0x82, 0xb1, 0xb2, 0x78, 0xa7, + 0x95, 0x99, 0x8b, 0x8c, 0xb1, 
0x81, 0x5b, 0xbb, 0x88, 0x7a, 0x90, 0xa3, + 0x8d, 0x78, 0x6f, 0xbf, 0x8c, 0x93, 0xa1, 0x8e, 0x9f, 0x98, 0x88, 0xb3, + 0x7e, 0x82, 0x8a, 0x8e, 0x7d, 0x8a, 0x96, 0x6a, 0x6c, 0x7b, 0x91, 0x94, + 0x6f, 0x89, 0x9a, 0x84, 0x73, 0x8b, 0x8c, 0x91, 0x7d, 0x8e, 0x9e, 0x80, + 0x88, 0x81, 0x78, 0xaf, 0x86, 0xa5, 0xa2, 0x8d, 0x6a, 0x8a, 0x75, 0xa1, + 0x83, 0x87, 0xaf, 0x7d, 0x6c, 0xa3, 0x65, 0x77, 0x89, 0x91, 0x9a, 0xa1, + 0xa1, 0xaf, 0x78, 0x94, 0x93, 0xb2, 0xaf, 0x92, 0x74, 0x7a, 0xa7, 0x7b, + 0x8f, 0x9c, 0x86, 0x8d, 0x8f, 0x79, 0xb0, 0xb3, 0x97, 0x82, 0x8e, 0x92, + 0x92, 0x81, 0xa7, 0xbc, 0x6e, 0x6e, 0x89, 0xa5, 0x9a, 0x8d, 0x84, 0xb6, + 0x83, 0xae, 0xa5, 0xa7, 0xae, 0x86, 0x6b, 0xb9, 0x89, 0xb0, 0x8f, 0x82, + 0x8f, 0x6f, 0x83, 0x98, 0x6a, 0x98, 0x9a, 0x85, 0x9f, 0x78, 0x93, 0x8d, + 0x83, 0x88, 0x88, 0x7e, 0x97, 0x99, 0x8a, 0x9b, 0xb0, 0x90, 0x86, 0x88, + 0xb5, 0x90, 0xb3, 0xaa, 0xad, 0x96, 0x93, 0xa3, 0x9d, 0x81, 0xa3, 0x9a, + 0x9f, 0x99, 0x90, 0x9c, 0x9e, 0x8e, 0x88, 0x93, 0xa8, 0x94, 0x62, 0xa6, + 0x94, 0x92, 0xa1, 0x86, 0xb7, 0x8a, 0x6a, 0xa6, 0x81, 0x7e, 0x7b, 0x80, + 0x89, 0x8f, 0x74, 0xa6, 0x72, 0x91, 0xa6, 0x9b, 0x73, 0x97, 0x7e, 0x6f, + 0x70, 0x8d, 0x73, 0x98, 0x80, 0x90, 0x8f, 0x7e, 0x83, 0x77, 0x84, 0x92, + 0x7f, 0x8c, 0x91, 0xa6, 0x99, 0x90, 0x9d, 0xb1, 0x88, 0x85, 0x89, 0x85, + 0x7c, 0x9f, 0x7e, 0xb0, 0xaa, 0x84, 0xa0, 0x8e, 0x74, 0x93, 0x78, 0x90, + 0x9a, 0x8b, 0x8e, 0x97, 0x8f, 0x9f, 0x7c, 0x83, 0x8a, 0x88, 0xa5, 0x8f, + 0x8b, 0x74, 0x84, 0x9a, 0x7f, 0x91, 0x88, 0x77, 0x9c, 0x91, 0xbc, 0x93, + 0x9c, 0x82, 0x89, 0x9b, 0x8a, 0x7d, 0xb7, 0xb8, 0x6f, 0x68, 0xb5, 0x8e, + 0xb4, 0x86, 0x8c, 0xb3, 0x94, 0xb6, 0xa4, 0x93, 0x98, 0x8b, 0x70, 0xb3, + 0x96, 0xaa, 0x87, 0x89, 0x99, 0x68, 0x74, 0xa4, 0x69, 0x9e, 0x8e, 0x6b, + 0x9f, 0x6b, 0x95, 0x9c, 0x88, 0x89, 0x8a, 0x86, 0x8d, 0x75, 0x94, 0x88, + 0xa0, 0x94, 0x77, 0x8c, 0x9c, 0x8d, 0x8e, 0xa4, 0xac, 0xa7, 0x8a, 0x9b, + 0xa9, 0x81, 0xab, 0xac, 0xaf, 0xaf, 0x87, 0xbb, 0x9b, 0x95, 0x8e, 0x9e, + 0x9f, 0xa1, 0x6c, 0xb4, 0x98, 
0x8f, 0x81, 0x8d, 0x98, 0x8f, 0x78, 0x96, + 0x89, 0x86, 0x6c, 0x91, 0x8d, 0x9f, 0x95, 0x9f, 0x6b, 0x7f, 0x93, 0x7c, + 0x96, 0x8e, 0x8a, 0x58, 0x80, 0x8e, 0x7a, 0x93, 0x8b, 0x78, 0x99, 0x92, + 0x62, 0x8e, 0x83, 0x8e, 0x87, 0x83, 0x86, 0x99, 0x93, 0x92, 0x80, 0x95, + 0xa2, 0x72, 0xa2, 0x97, 0x78, 0x87, 0x7b, 0xa3, 0x99, 0x78, 0x98, 0x9c, + 0x80, 0x9b, 0x5e, 0x8a, 0x9c, 0x99, 0xa6, 0x7a, 0x8e, 0x99, 0x7a, 0x8e, + 0x8b, 0x76, 0x9b, 0x89, 0x80, 0x8e, 0x83, 0x8a, 0x80, 0x7c, 0x80, 0x74, + 0x95, 0x8c, 0xbf, 0x7e, 0xa8, 0x7a, 0x99, 0x7d, 0x7d, 0x73, 0xb4, 0xae, + 0x88, 0x76, 0xae, 0x78, 0xaa, 0x65, 0x94, 0xbe, 0x97, 0xaf, 0xa4, 0x91, + 0x9c, 0x95, 0x6c, 0xbe, 0x82, 0xb1, 0x9b, 0x91, 0x85, 0x7d, 0x66, 0x9c, + 0x99, 0xbd, 0xa3, 0x88, 0xa8, 0x73, 0x81, 0x94, 0x92, 0x8e, 0x90, 0x8d, + 0xaf, 0x75, 0x86, 0x9b, 0x8b, 0x8b, 0x8d, 0x74, 0xbd, 0x85, 0x97, 0x8b, + 0x9d, 0xba, 0x90, 0xa8, 0x9d, 0x72, 0xa5, 0xa8, 0xbf, 0xbb, 0x7b, 0xb6, + 0xad, 0x94, 0x6f, 0x9a, 0xa7, 0x97, 0x78, 0x9c, 0x98, 0x8d, 0x8c, 0x93, + 0xb8, 0xa8, 0x7f, 0x9d, 0x98, 0x7f, 0x8f, 0x8a, 0x8d, 0xa8, 0x86, 0x7b, + 0x5d, 0x89, 0x8a, 0x83, 0x8c, 0x8b, 0x81, 0x56, 0x7c, 0x87, 0x89, 0xa6, + 0x75, 0x7c, 0x92, 0x74, 0x96, 0x92, 0x78, 0x8d, 0x8d, 0x98, 0xae, 0x7a, + 0x95, 0x8f, 0x8b, 0x9c, 0x95, 0x9f, 0xae, 0x93, 0x7b, 0x93, 0x8c, 0x9a, + 0x79, 0x74, 0x94, 0x6e, 0x7e, 0x8f, 0x64, 0x9f, 0x9c, 0x88, 0x8f, 0x8e, + 0x84, 0x8d, 0x89, 0x95, 0x96, 0x8f, 0x9d, 0x60, 0x85, 0x86, 0x7c, 0x93, + 0x8d, 0x68, 0x83, 0x7c, 0x94, 0x87, 0xb8, 0xa2, 0x9d, 0x82, 0x8e, 0x84, + 0x6c, 0x73, 0xa8, 0xbc, 0x84, 0x85, 0xa2, 0x79, 0x92, 0x64, 0x69, 0xa9, + 0x82, 0xa7, 0x9d, 0x95, 0x8e, 0x6f, 0x9f, 0xa7, 0x97, 0xb1, 0x9d, 0x8e, + 0xa1, 0x70, 0x80, 0x9e, 0x8e, 0x91, 0xa0, 0xaa, 0x81, 0x5b, 0x98, 0x8f, + 0xa0, 0xaa, 0x83, 0x7a, 0x91, 0x7a, 0x73, 0x80, 0xa6, 0x9a, 0x80, 0x7d, + 0x9e, 0x75, 0x7b, 0xa3, 0xad, 0x92, 0x98, 0xc0, 0xa1, 0x80, 0x88, 0xa2, + 0xa5, 0xa4, 0x7e, 0x9b, 0xa0, 0x80, 0x6e, 0xa0, 0x9f, 0xa3, 0x8a, 0x8f, + 0xa2, 0x93, 0x86, 0x8d, 0x8f, 
0x93, 0x7e, 0x90, 0x98, 0x83, 0x7d, 0x9b, + 0x9f, 0x9a, 0x97, 0x83, 0x6e, 0x8d, 0x94, 0x6c, 0x7b, 0x7f, 0x73, 0x65, + 0x6a, 0x93, 0x8a, 0x94, 0x83, 0x89, 0x7d, 0x7b, 0x77, 0x8a, 0x7a, 0x9b, + 0x8e, 0x8d, 0x94, 0x89, 0x86, 0x83, 0x7c, 0x8e, 0x8b, 0x90, 0xab, 0x99, + 0x81, 0x8e, 0x77, 0x9c, 0x8c, 0x82, 0x97, 0x8f, 0x78, 0x91, 0x5f, 0xa1, + 0x8b, 0x83, 0xa9, 0x8d, 0x7b, 0x97, 0x77, 0x80, 0x84, 0x7e, 0x9e, 0x75, + 0xa3, 0x86, 0x67, 0x7c, 0x80, 0x6d, 0x77, 0x75, 0x88, 0x75, 0xad, 0x7a, + 0x93, 0x89, 0x8c, 0x87, 0x7a, 0x79, 0xb2, 0xa1, 0x69, 0x80, 0xb5, 0x7a, + 0xa6, 0x7b, 0x95, 0xac, 0x95, 0xa9, 0x98, 0xa4, 0xad, 0x83, 0x8d, 0xbe, + 0xa4, 0x98, 0xad, 0x7d, 0x8b, 0x65, 0x65, 0xad, 0x6a, 0xae, 0xa3, 0xa8, + 0x9c, 0x63, 0x90, 0x91, 0x6d, 0x9a, 0x81, 0x98, 0x86, 0x6a, 0x83, 0x84, + 0x94, 0x9c, 0x77, 0x86, 0xc2, 0x7f, 0x9b, 0xa9, 0xad, 0xae, 0xa7, 0xa6, + 0xd4, 0x70, 0x9d, 0xb5, 0xaa, 0xdb, 0x8f, 0xa3, 0xa5, 0x87, 0x88, 0x9e, + 0xa9, 0x9f, 0x62, 0xa7, 0xa2, 0x8e, 0x7d, 0x8a, 0x9d, 0xa2, 0x6b, 0xa7, + 0x96, 0x6d, 0x76, 0x8c, 0x9b, 0x8c, 0x86, 0x86, 0x93, 0x7c, 0x9d, 0x7c, + 0x7e, 0x93, 0x5c, 0x79, 0x76, 0x8c, 0x8a, 0x87, 0x79, 0x97, 0x9a, 0x7a, + 0x85, 0x8c, 0x7f, 0x85, 0x7a, 0xa1, 0xa7, 0x72, 0x87, 0x7f, 0x96, 0x9e, + 0x92, 0x92, 0x9e, 0xa0, 0x72, 0x99, 0x7a, 0xb0, 0x8c, 0x8d, 0xa3, 0x9b, + 0x91, 0xa6, 0x63, 0x94, 0x8b, 0x81, 0xbb, 0x94, 0x79, 0x95, 0x99, 0x9a, + 0xa0, 0x7a, 0x96, 0x72, 0x82, 0x9a, 0x83, 0x7f, 0x72, 0x7f, 0x6d, 0x75, + 0x91, 0x7f, 0xbc, 0x84, 0x9a, 0x81, 0x95, 0x69, 0x7d, 0x6d, 0xa2, 0xa8, + 0x7e, 0x64, 0xac, 0x86, 0x85, 0x6d, 0x99, 0xaa, 0x7e, 0x79, 0x9c, 0xa0, + 0xa4, 0x77, 0x99, 0xac, 0xa8, 0x8d, 0xb7, 0xa2, 0xa3, 0x61, 0x82, 0x98, + 0x84, 0x8e, 0xa1, 0x8c, 0x88, 0x82, 0x6f, 0x7d, 0x88, 0x80, 0x7a, 0x8a, + 0x8c, 0x6d, 0x87, 0x6f, 0xab, 0x8f, 0x8b, 0x76, 0xa0, 0x7d, 0x9f, 0xab, + 0xb0, 0xb8, 0x9c, 0x8d, 0xb8, 0x81, 0x89, 0x94, 0xa8, 0xc8, 0x92, 0x9b, + 0x8d, 0x83, 0x7b, 0xaf, 0x97, 0x94, 0x6e, 0xa5, 0x9b, 0x97, 0x89, 0x8d, + 0xaa, 0x8a, 0x66, 0x88, 0x93, 
0x84, 0xa1, 0x88, 0xa0, 0x99, 0x85, 0x89, + 0x7d, 0x84, 0x8b, 0x6a, 0x92, 0xa1, 0x74, 0x76, 0x73, 0x87, 0x7a, 0x9a, + 0x77, 0x86, 0x89, 0x5f, 0x7f, 0x8b, 0x7f, 0x8d, 0x7e, 0x81, 0x95, 0x8a, + 0x7d, 0x85, 0x74, 0x9a, 0x87, 0x8c, 0x9e, 0xae, 0x80, 0x88, 0x7d, 0x8b, + 0xaa, 0x79, 0x7c, 0x97, 0x79, 0x90, 0x7b, 0x97, 0x97, 0x9f, 0xa1, 0xa2, + 0xab, 0x97, 0x69, 0x7a, 0x8d, 0x9f, 0x9f, 0x89, 0x90, 0x8c, 0x66, 0x98, + 0x6e, 0x86, 0x7b, 0x6e, 0x86, 0x8a, 0xb2, 0xa6, 0x93, 0x7d, 0x8c, 0x81, + 0x7e, 0x84, 0xa6, 0xb6, 0x83, 0x92, 0xa0, 0x88, 0x90, 0x5f, 0x7c, 0x92, + 0x98, 0x94, 0x92, 0x98, 0xa7, 0x65, 0x90, 0xa2, 0xa2, 0x9b, 0xa6, 0x7d, + 0x8b, 0x5a, 0x94, 0x95, 0x9b, 0xa5, 0x99, 0xa5, 0x7e, 0x61, 0x9a, 0x7a, + 0x8b, 0x77, 0x87, 0x76, 0x9d, 0x72, 0x9a, 0x84, 0x98, 0x94, 0x92, 0x73, + 0xae, 0x78, 0x8e, 0xaa, 0xa0, 0xc3, 0x7a, 0xa4, 0xa0, 0x75, 0xa9, 0xae, + 0x8c, 0xd6, 0x87, 0x8f, 0x9f, 0x8c, 0x9b, 0x90, 0x99, 0x97, 0x73, 0x8f, + 0x9b, 0x9c, 0x8c, 0x89, 0xa5, 0x84, 0x8f, 0x7b, 0x8b, 0x7f, 0x97, 0x98, + 0x8d, 0x7b, 0x94, 0x9d, 0x9c, 0x8e, 0x92, 0x89, 0x88, 0x8d, 0x6c, 0x63, + 0x73, 0x81, 0x72, 0x8a, 0x88, 0x8a, 0x9f, 0x79, 0x81, 0x82, 0x9a, 0xa9, + 0x7a, 0x92, 0x7d, 0x76, 0x7b, 0x7a, 0x6a, 0xbe, 0x91, 0x7d, 0x86, 0xad, + 0x84, 0x86, 0x6c, 0x91, 0x91, 0x9f, 0x92, 0x6b, 0x95, 0x98, 0x84, 0xa0, + 0x8f, 0x8b, 0x9e, 0x7f, 0x9f, 0x97, 0x7e, 0x87, 0x80, 0x9e, 0x79, 0x8d, + 0x68, 0x87, 0x88, 0x7d, 0x89, 0x81, 0x6d, 0x85, 0x80, 0x82, 0xa0, 0x97, + 0xa3, 0x72, 0x94, 0x74, 0x8e, 0x56, 0x96, 0x98, 0x91, 0x6f, 0xa0, 0xae, + 0x7c, 0x6e, 0x8e, 0xa9, 0x7c, 0x80, 0x87, 0xa3, 0x9e, 0x57, 0x8e, 0xb5, + 0x87, 0xa6, 0x87, 0x79, 0x8f, 0x55, 0x8a, 0x81, 0x97, 0x6c, 0x9b, 0x99, + 0x78, 0x5c, 0x82, 0x80, 0x91, 0x76, 0x80, 0x91, 0x8b, 0x65, 0x89, 0x7d, + 0xa9, 0x95, 0x89, 0x97, 0x96, 0x6a, 0x89, 0xad, 0x92, 0x9f, 0xb6, 0x82, + 0x88, 0x79, 0x9d, 0xa5, 0x9c, 0xae, 0x9a, 0x93, 0x77, 0x8e, 0x8a, 0xb5, + 0x84, 0xb0, 0x76, 0xa2, 0x89, 0xa0, 0x96, 0x7a, 0xa5, 0x8e, 0x7e, 0x74, + 0x8d, 0x89, 0x89, 0x9e, 0x93, 
0x95, 0x90, 0x78, 0x93, 0x8f, 0xa5, 0x7c, + 0x9d, 0x7c, 0x77, 0x85, 0x81, 0x92, 0x7c, 0x87, 0x92, 0x82, 0x98, 0xa3, + 0x63, 0x76, 0x9b, 0x91, 0x7b, 0x8e, 0x97, 0x7e, 0x66, 0x90, 0x63, 0xb4, + 0x71, 0x88, 0x86, 0x8e, 0x6f, 0x89, 0x7a, 0x88, 0x93, 0x7f, 0x96, 0xa8, + 0x7d, 0x88, 0x88, 0x86, 0x7b, 0x91, 0x88, 0x6b, 0xa6, 0x8b, 0x69, 0x78, + 0x82, 0x80, 0x83, 0x6b, 0xaf, 0x81, 0x7b, 0x64, 0x8f, 0x78, 0x6e, 0x7f, + 0x86, 0x91, 0x92, 0xa3, 0xa0, 0x97, 0x82, 0x88, 0x92, 0x90, 0x9e, 0x89, + 0x9d, 0x7b, 0x96, 0x82, 0xa3, 0x8c, 0x7f, 0x84, 0x7a, 0x6c, 0x60, 0x85, + 0xa9, 0x74, 0x83, 0xa2, 0x89, 0x87, 0x9b, 0x77, 0x9b, 0x9a, 0x99, 0x84, + 0x7c, 0x9c, 0x8d, 0x90, 0x8d, 0x7b, 0x74, 0x77, 0x93, 0x8c, 0x6c, 0x8b, + 0x85, 0x78, 0x7f, 0x7d, 0x75, 0x7f, 0x7e, 0x85, 0x8f, 0x7d, 0x62, 0x8c, + 0x7c, 0xad, 0x7f, 0x83, 0xa1, 0xa1, 0x97, 0x7b, 0x72, 0x82, 0x9d, 0x81, + 0x94, 0x81, 0x8d, 0x9f, 0x6f, 0x8f, 0x9d, 0x89, 0x6a, 0x7e, 0x7f, 0x7f, + 0x8d, 0x7e, 0x91, 0x86, 0x7d, 0x8a, 0x7e, 0x70, 0x7b, 0x9b, 0x6e, 0x5f, + 0xa8, 0x7a, 0x73, 0x8a, 0x7a, 0x71, 0x90, 0x95, 0x8d, 0x78, 0x7b, 0x72, + 0x5e, 0x89, 0x62, 0xa1, 0x87, 0x7f, 0x83, 0x75, 0x98, 0x7f, 0x76, 0x72, + 0x8f, 0x9b, 0x7a, 0x8b, 0xa1, 0x7f, 0x60, 0x99, 0x96, 0x6e, 0x67, 0x76, + 0x88, 0x98, 0x6c, 0x7b, 0x9b, 0x8d, 0x5f, 0x89, 0x7c, 0x81, 0x79, 0x86, + 0x69, 0x9e, 0x83, 0x65, 0x8e, 0x82, 0x83, 0x89, 0x85, 0x7f, 0x90, 0x80, + 0xa2, 0x81, 0x85, 0x83, 0x8e, 0x94, 0x94, 0x75, 0x86, 0x87, 0x9a, 0xb2, + 0x82, 0x99, 0x85, 0x7f, 0x8c, 0x7e, 0x81, 0x9a, 0x81, 0x7d, 0x87, 0x81, + 0xa3, 0x8c, 0x8d, 0x85, 0x8d, 0x96, 0x86, 0x7c, 0xa7, 0x87, 0x7e, 0x9d, + 0x63, 0xa8, 0x7c, 0x97, 0xa2, 0xa4, 0x7e, 0x87, 0x93, 0x9e, 0x89, 0x8d, + 0x6b, 0x6d, 0x9d, 0x9b, 0x78, 0x8a, 0x8e, 0x7f, 0x7b, 0xa5, 0x6e, 0x8c, + 0x89, 0x88, 0x73, 0x7e, 0x77, 0x9d, 0xa6, 0xa7, 0x77, 0x87, 0x7e, 0x7e, + 0x97, 0x84, 0x6b, 0x59, 0x60, 0x90, 0x85, 0x76, 0x8f, 0x61, 0x7f, 0x94, + 0x8f, 0x84, 0x8b, 0x7f, 0x73, 0x77, 0x73, 0x71, 0x8a, 0x9b, 0x7b, 0x89, + 0x97, 0x8f, 0x76, 0x63, 0xa3, 
0xa1, 0x6b, 0x7c, 0x62, 0x95, 0x8e, 0xa3, + 0x9f, 0x89, 0x8f, 0x7f, 0x92, 0x7c, 0xa2, 0xa4, 0xa6, 0x92, 0x89, 0x93, + 0x74, 0x73, 0x73, 0x96, 0xad, 0x9b, 0x87, 0xac, 0x91, 0x8a, 0xa0, 0x70, + 0x70, 0x7e, 0x8f, 0x74, 0x75, 0xaf, 0x8d, 0x82, 0x8e, 0x82, 0x96, 0x7d, + 0x69, 0x9c, 0x64, 0xa2, 0x82, 0x89, 0x83, 0x9d, 0x83, 0x88, 0x62, 0x92, + 0x72, 0x89, 0x6d, 0x7f, 0x92, 0x70, 0x8e, 0x80, 0x7e, 0x8d, 0x91, 0x85, + 0x8d, 0x89, 0x83, 0x96, 0x90, 0x96, 0x9c, 0xa6, 0x8a, 0x73, 0x89, 0x79, + 0xa9, 0x70, 0x80, 0x78, 0x96, 0x80, 0x7b, 0x85, 0xa5, 0x80, 0x93, 0x95, + 0xc5, 0x74, 0x81, 0x88, 0xa2, 0x93, 0x86, 0x9c, 0xa3, 0x6d, 0x92, 0x8a, + 0x92, 0x99, 0x98, 0x65, 0xad, 0x63, 0x9d, 0x95, 0x99, 0x89, 0x7f, 0x7a, + 0x99, 0x91, 0x7f, 0x78, 0x90, 0x8f, 0x80, 0x85, 0xa1, 0x68, 0x9d, 0x6c, + 0x83, 0x8f, 0x7c, 0x5e, 0x99, 0x7b, 0x80, 0x91, 0x66, 0x8a, 0x92, 0xb3, + 0x7a, 0x99, 0x91, 0x7e, 0x7d, 0x96, 0x69, 0x9e, 0x7c, 0x89, 0xad, 0x8f, + 0x9d, 0x90, 0x85, 0x8e, 0x72, 0xa9, 0x89, 0x83, 0x7c, 0x82, 0x70, 0x82, + 0x6b, 0x79, 0x75, 0x8d, 0x77, 0x9b, 0x7c, 0x8f, 0x8a, 0x95, 0x87, 0x9f, + 0x7c, 0x90, 0x87, 0x70, 0x83, 0x83, 0x98, 0x9f, 0x85, 0x86, 0x8d, 0x81, + 0x87, 0x87, 0x87, 0x9d, 0x8f, 0x9d, 0x7c, 0x98, 0xa2, 0xac, 0x88, 0x93, + 0x88, 0x7d, 0x9b, 0x76, 0x82, 0x67, 0x69, 0x7f, 0x8c, 0x8d, 0x94, 0x7d, + 0x7b, 0xae, 0x8c, 0x85, 0x8b, 0xa7, 0x8c, 0x87, 0x96, 0x7d, 0x8b, 0x90, + 0x90, 0x7c, 0x92, 0xa8, 0x81, 0x87, 0xa4, 0xa4, 0x82, 0x8b, 0x8d, 0x89, + 0x8f, 0x70, 0x9d, 0x7f, 0xa0, 0x84, 0x99, 0x65, 0x99, 0x78, 0x94, 0x8b, + 0xc5, 0x8d, 0x8d, 0x55, 0xb3, 0x8d, 0x78, 0x93, 0xb4, 0x6d, 0x84, 0x90, + 0xd5, 0x76, 0x7a, 0x9e, 0xc8, 0x8f, 0x86, 0x8a, 0xaa, 0x8b, 0x7f, 0x90, + 0xaa, 0x95, 0x9c, 0x81, 0xb4, 0x6b, 0x64, 0x8a, 0x99, 0x84, 0x74, 0x6e, + 0x95, 0x75, 0x98, 0x92, 0x9a, 0x91, 0x8c, 0x7d, 0x88, 0x6e, 0x89, 0x7d, + 0x87, 0x80, 0x8e, 0x86, 0x78, 0x9f, 0x96, 0x75, 0x76, 0x82, 0x84, 0xaf, + 0x8a, 0xb3, 0x93, 0x97, 0x86, 0x7c, 0x7e, 0x96, 0x7c, 0x6d, 0x90, 0x8e, + 0x85, 0x88, 0x8a, 0x9f, 0x70, 
0x89, 0x9f, 0x99, 0x95, 0x87, 0x91, 0x9d, + 0x80, 0x74, 0x88, 0x7c, 0x7f, 0xa8, 0x93, 0x77, 0x66, 0xa6, 0x80, 0xa2, + 0x88, 0xa0, 0xaf, 0x6f, 0x76, 0x70, 0x82, 0x9a, 0x73, 0x89, 0x9a, 0x75, + 0x75, 0x8e, 0x5f, 0x85, 0x6a, 0x76, 0x98, 0x66, 0x87, 0xa3, 0x7a, 0x73, + 0x9d, 0xa1, 0x98, 0x8e, 0x78, 0x91, 0x83, 0x8c, 0x82, 0x9e, 0x90, 0x87, + 0x8f, 0x9b, 0x8b, 0x8f, 0x89, 0x62, 0x74, 0x82, 0x7b, 0x7f, 0x8a, 0x9d, + 0x89, 0x93, 0x8c, 0x7a, 0x99, 0x77, 0xac, 0x75, 0x9b, 0x7f, 0x7f, 0x56, + 0x8c, 0x96, 0x70, 0x79, 0xc2, 0x7d, 0x90, 0x64, 0xe9, 0x79, 0x68, 0xb2, + 0xc2, 0xa6, 0xa7, 0x7e, 0xd9, 0x98, 0x79, 0x87, 0xc0, 0x97, 0x87, 0x66, + 0xd0, 0x9f, 0x92, 0x82, 0xa4, 0xa8, 0x8d, 0x78, 0xa6, 0xa1, 0x76, 0x7d, + 0xa4, 0x87, 0x89, 0x51, 0xae, 0x88, 0x5b, 0x76, 0x7d, 0x70, 0x74, 0x93, + 0x89, 0x74, 0x9e, 0x7a, 0x79, 0x64, 0x9a, 0x94, 0x65, 0x93, 0xb0, 0x8d, + 0x88, 0x7e, 0x8e, 0xa5, 0x63, 0x94, 0x94, 0x7d, 0x91, 0x87, 0x84, 0x95, + 0x75, 0x9e, 0x81, 0x99, 0x65, 0x76, 0x82, 0x9c, 0x6a, 0xab, 0x84, 0x85, + 0x88, 0x72, 0x92, 0x83, 0x82, 0xaf, 0x6d, 0x9d, 0x9e, 0x73, 0x98, 0x7f, + 0x91, 0xb4, 0x62, 0x8d, 0x74, 0x6e, 0xb4, 0x94, 0x97, 0x9e, 0x6f, 0x9a, + 0x83, 0x7b, 0xa9, 0x7d, 0x87, 0x97, 0x60, 0xa9, 0x7a, 0x75, 0xad, 0x6c, + 0x77, 0xa4, 0x88, 0x82, 0x6f, 0x8a, 0x83, 0x74, 0x9a, 0xa7, 0x83, 0x91, + 0x7c, 0x7c, 0x78, 0x77, 0x83, 0x92, 0x7a, 0x83, 0x90, 0x6f, 0x79, 0x6b, + 0x9b, 0x8d, 0x99, 0x95, 0x7b, 0x89, 0x8e, 0x6c, 0x8e, 0x6c, 0x9b, 0x91, + 0x97, 0x80, 0x83, 0x6f, 0xaa, 0x91, 0x66, 0x76, 0xc9, 0x77, 0x82, 0x4d, + 0xd7, 0x5f, 0x58, 0x9a, 0xb1, 0x7a, 0xb1, 0x6b, 0xe5, 0x9d, 0x76, 0x89, + 0xb6, 0x94, 0x90, 0x5b, 0xb8, 0x92, 0x7d, 0x90, 0xbd, 0x9a, 0x85, 0x4e, + 0xb4, 0x84, 0x61, 0x82, 0x94, 0x8e, 0x70, 0x57, 0x90, 0x89, 0x6f, 0x60, + 0x78, 0x90, 0x78, 0x85, 0x8e, 0x7c, 0x76, 0x74, 0x71, 0x5d, 0x94, 0x93, + 0x71, 0x8f, 0xc2, 0x80, 0x75, 0x7d, 0x77, 0xa8, 0x70, 0x8f, 0xa6, 0x83, + 0x74, 0x6b, 0x79, 0x97, 0x76, 0xa2, 0xad, 0x93, 0x5b, 0x8c, 0x7c, 0x7e, + 0x82, 0x9b, 0xa0, 0x76, 0x71, 
0x7a, 0xa3, 0x80, 0x87, 0x90, 0x92, 0xa6, + 0x85, 0x71, 0x99, 0x91, 0x91, 0x8c, 0x99, 0x9b, 0x92, 0x74, 0xb2, 0x79, + 0x9c, 0x7c, 0x7b, 0xa8, 0x8c, 0x6f, 0xb5, 0x69, 0x7a, 0x8a, 0x68, 0x9f, + 0x82, 0x7d, 0xbd, 0x5f, 0xa1, 0x92, 0x83, 0x9f, 0x6f, 0xa1, 0x88, 0x61, + 0x7b, 0x94, 0x89, 0x83, 0x6f, 0x6e, 0x92, 0x9d, 0x65, 0x7f, 0x97, 0x83, + 0x87, 0x75, 0x92, 0x8a, 0x82, 0x82, 0x79, 0x92, 0x78, 0x89, 0x92, 0x7a, + 0x91, 0x64, 0x8a, 0x93, 0x9d, 0x74, 0x78, 0x64, 0xab, 0x57, 0x7a, 0x84, + 0xcf, 0x7d, 0x95, 0x4f, 0xde, 0x63, 0x78, 0x9a, 0xb7, 0x7a, 0x8b, 0x5b, + 0xda, 0xa3, 0x94, 0x99, 0xbd, 0x88, 0xa4, 0x53, 0xad, 0x8b, 0x81, 0x96, + 0xca, 0x8f, 0x76, 0x5e, 0xbd, 0x9d, 0x70, 0x81, 0x9b, 0x7d, 0x8a, 0x44, + 0xa0, 0x77, 0x52, 0x6e, 0x82, 0x62, 0x6a, 0x6b, 0x9d, 0xaa, 0x81, 0x85, + 0x7d, 0x5f, 0x7f, 0x9c, 0x65, 0x99, 0x97, 0x81, 0x7f, 0x65, 0x65, 0xa4, + 0x84, 0x8c, 0xa1, 0x6d, 0x7a, 0x70, 0x79, 0x90, 0x98, 0xaa, 0x76, 0x95, + 0x7f, 0x91, 0x95, 0x96, 0x6e, 0xa5, 0x95, 0xa2, 0x7d, 0x7e, 0x93, 0x87, + 0x7d, 0x9b, 0x85, 0x9b, 0x85, 0x79, 0x96, 0x6b, 0x9d, 0x9d, 0x61, 0x99, + 0x9c, 0x74, 0xcc, 0x7e, 0x9a, 0x83, 0x83, 0x98, 0x6f, 0x6d, 0xc5, 0x69, + 0xb0, 0xa5, 0x5c, 0x91, 0x6c, 0x7b, 0xcc, 0x72, 0x9a, 0x9d, 0x7e, 0xa3, + 0x8a, 0x96, 0x8e, 0x74, 0x7b, 0x80, 0x6b, 0x85, 0x84, 0x56, 0x92, 0x83, + 0x64, 0x90, 0x86, 0x86, 0x88, 0x79, 0x8b, 0xa0, 0x86, 0x72, 0xab, 0x95, + 0x80, 0x81, 0x96, 0x8f, 0x75, 0x7f, 0x71, 0x92, 0x9e, 0x75, 0x62, 0x5e, + 0xc3, 0x7a, 0x6c, 0x84, 0xba, 0x81, 0x8f, 0x49, 0xc9, 0x76, 0x54, 0x89, + 0xc2, 0x8c, 0xa2, 0x54, 0xd8, 0xa4, 0x72, 0x90, 0xb1, 0x91, 0xa0, 0x7a, + 0xbf, 0x9a, 0x6f, 0x82, 0xbb, 0x81, 0x6a, 0x52, 0xc2, 0x82, 0x52, 0x65, + 0x8d, 0x8a, 0x84, 0x46, 0xa2, 0x90, 0x45, 0x52, 0x82, 0x61, 0x8c, 0x77, + 0x92, 0x6d, 0x87, 0x5b, 0x5e, 0x72, 0x76, 0x97, 0x73, 0x8d, 0x8d, 0x70, + 0x7a, 0x66, 0x76, 0x89, 0x72, 0xbf, 0xb0, 0x84, 0x7d, 0x80, 0x71, 0x8f, + 0x85, 0xa9, 0xa3, 0x7d, 0x7b, 0x84, 0x83, 0xa1, 0x97, 0xa7, 0xaf, 0x84, + 0x86, 0x7d, 0x94, 0x78, 0x80, 
0x98, 0x71, 0x84, 0x94, 0x73, 0xb0, 0x74, + 0x99, 0xa2, 0x68, 0xa7, 0x8b, 0x86, 0xe0, 0x75, 0x9e, 0x93, 0x5c, 0xb2, + 0xa2, 0x68, 0xb8, 0x61, 0x92, 0xa3, 0x68, 0xa4, 0x89, 0x59, 0xd0, 0x77, + 0x97, 0xa9, 0x6a, 0x9b, 0x7d, 0x69, 0x9b, 0x79, 0x8c, 0x7c, 0x68, 0x8b, + 0x7a, 0x53, 0x99, 0x9c, 0x7e, 0x8d, 0x89, 0x96, 0x9e, 0x83, 0x89, 0x74, + 0x7f, 0x94, 0x92, 0x8f, 0x85, 0x8a, 0x8a, 0x80, 0x99, 0x87, 0x7a, 0x7d, + 0xac, 0x93, 0x74, 0x68, 0xba, 0x87, 0x6a, 0x98, 0xc7, 0x79, 0x91, 0x54, + 0xeb, 0x80, 0x45, 0x80, 0xc4, 0xb4, 0x94, 0x61, 0xd2, 0xa6, 0x7b, 0x95, + 0xa4, 0xaa, 0x93, 0x7b, 0xb1, 0x74, 0x53, 0x7c, 0xaa, 0x91, 0x64, 0x51, + 0xa9, 0x6e, 0x5e, 0x7c, 0x79, 0x82, 0x8b, 0x2e, 0x9d, 0x66, 0x61, 0x5e, + 0x72, 0x7f, 0x6e, 0x6d, 0x8c, 0x79, 0x7d, 0x60, 0x76, 0x79, 0x68, 0x84, + 0x4d, 0x8e, 0xa8, 0x8f, 0x78, 0x74, 0x69, 0xa4, 0x6e, 0xa9, 0xb9, 0x59, + 0x83, 0x7f, 0x7a, 0x93, 0x90, 0x9b, 0x8d, 0x93, 0x78, 0x80, 0x77, 0x8b, + 0x72, 0xa3, 0x97, 0x73, 0x91, 0x6c, 0x9a, 0x97, 0xa3, 0xad, 0x89, 0x96, + 0x9e, 0x6d, 0xb5, 0x7c, 0xa4, 0x98, 0x61, 0x8a, 0x93, 0x5f, 0xdc, 0x63, + 0xba, 0x92, 0x84, 0x94, 0xab, 0x6f, 0xbf, 0x66, 0x98, 0x93, 0x74, 0x85, + 0x96, 0x63, 0xb8, 0x60, 0x94, 0xbb, 0x79, 0x94, 0x7b, 0x67, 0x8a, 0x64, + 0x99, 0xac, 0x60, 0x98, 0xb0, 0x65, 0xa2, 0x73, 0x8f, 0x94, 0x8c, 0x92, + 0x84, 0x84, 0x9b, 0x8f, 0x84, 0x8d, 0x9f, 0x90, 0x91, 0x85, 0x93, 0x74, + 0x97, 0x66, 0x7f, 0x78, 0xa2, 0x95, 0x73, 0x6b, 0xc5, 0x6f, 0x62, 0x79, + 0xbd, 0x81, 0x89, 0x4a, 0xbd, 0x93, 0x57, 0x81, 0xba, 0xb0, 0x9b, 0x4c, + 0xe8, 0xa2, 0x85, 0xa2, 0x96, 0x92, 0x93, 0x62, 0xbe, 0x7a, 0x71, 0x8b, + 0x8d, 0x97, 0x53, 0x56, 0xb1, 0x5f, 0x67, 0x60, 0x7a, 0x8e, 0x8a, 0x3a, + 0x86, 0x67, 0x6d, 0x53, 0x6e, 0x91, 0x7b, 0x60, 0x99, 0x6d, 0x71, 0x5d, + 0x67, 0x65, 0x63, 0x87, 0x71, 0x8a, 0x92, 0x6d, 0x8f, 0x6f, 0x6f, 0xae, + 0x6c, 0xa2, 0x87, 0x6f, 0x99, 0x88, 0x78, 0x94, 0x8a, 0xb2, 0x93, 0x89, + 0x90, 0x8d, 0x8c, 0x98, 0x81, 0x86, 0x90, 0x6d, 0xa2, 0x82, 0xa2, 0xa3, + 0x9d, 0x8f, 0x7a, 0x9f, 0x87, 
0x70, 0xbd, 0x8e, 0xa5, 0x99, 0x5d, 0x70, + 0x8c, 0x60, 0xc7, 0x78, 0x97, 0xb0, 0x6f, 0x94, 0x92, 0x5a, 0xc3, 0x6e, + 0x8b, 0x9f, 0x79, 0xa3, 0x8c, 0x5e, 0xbf, 0x79, 0x8e, 0x98, 0x76, 0x8e, + 0x67, 0x31, 0x9b, 0x85, 0x8e, 0x85, 0x71, 0x99, 0x72, 0x77, 0x84, 0x81, + 0x91, 0x95, 0x80, 0x98, 0x82, 0x6f, 0x90, 0xa0, 0x91, 0x91, 0x8e, 0x75, + 0x8a, 0x89, 0x93, 0x69, 0x95, 0x7f, 0x9a, 0xa0, 0x9e, 0x9b, 0x88, 0x4e, + 0xc3, 0x8d, 0x65, 0x74, 0xba, 0x8d, 0x97, 0x4d, 0xd6, 0x94, 0x73, 0xa0, + 0xb1, 0xb3, 0x8c, 0x67, 0xdd, 0x9f, 0x7f, 0xaa, 0xaf, 0x9a, 0x88, 0x67, + 0xc2, 0x8f, 0x71, 0x7b, 0x8f, 0x9f, 0x47, 0x52, 0x93, 0x72, 0x5a, 0x52, + 0x97, 0x9d, 0x67, 0x3c, 0xa9, 0x59, 0x59, 0x5b, 0x88, 0x92, 0x82, 0x57, + 0x83, 0x67, 0x94, 0x77, 0x52, 0x74, 0x60, 0x9e, 0x52, 0x84, 0xa2, 0x69, + 0x71, 0x96, 0x73, 0xb0, 0x5e, 0xb0, 0x89, 0x71, 0x94, 0x8a, 0x66, 0xa0, + 0x75, 0xc1, 0x99, 0x8e, 0x83, 0x8a, 0x91, 0x89, 0x6b, 0xa5, 0x79, 0x82, + 0x8b, 0x73, 0x95, 0xb0, 0x77, 0x9b, 0x82, 0x7d, 0x8f, 0x60, 0xb9, 0x78, + 0x8b, 0x8f, 0x7b, 0x74, 0x84, 0x6d, 0xbf, 0x76, 0x8f, 0xa3, 0x91, 0xa1, + 0x81, 0x59, 0xcb, 0x69, 0xac, 0x90, 0x98, 0x92, 0xa7, 0x5d, 0xb4, 0x8b, + 0xaa, 0xb1, 0x98, 0x8c, 0xa2, 0x4d, 0xa1, 0x69, 0x7f, 0xa0, 0x7d, 0x8a, + 0x9b, 0x77, 0x8e, 0x71, 0x82, 0x8a, 0x78, 0x8d, 0x98, 0x78, 0x90, 0x91, + 0x7e, 0x7f, 0x78, 0x85, 0x97, 0x8a, 0x97, 0x6d, 0xb3, 0x94, 0x89, 0xa3, + 0xa5, 0x9a, 0x76, 0x6b, 0xbd, 0x79, 0x71, 0x95, 0xce, 0xab, 0x93, 0x1f, + 0xe9, 0x97, 0x4c, 0x84, 0xd5, 0x9f, 0x98, 0x6e, 0xdd, 0x8d, 0x80, 0x9c, + 0xa8, 0x9e, 0x8d, 0x75, 0xbc, 0x8c, 0x80, 0x89, 0xa1, 0x89, 0x74, 0x58, + 0x92, 0x86, 0x55, 0x87, 0x91, 0x8d, 0x70, 0x33, 0xb8, 0x50, 0x63, 0x6b, + 0x79, 0x99, 0x76, 0x71, 0x75, 0x59, 0x73, 0x6b, 0x62, 0x62, 0x74, 0x85, + 0x73, 0xa3, 0xac, 0x78, 0x77, 0x88, 0x64, 0xa0, 0x73, 0xa1, 0xa8, 0x73, + 0x91, 0x8e, 0x5f, 0x9a, 0x68, 0xc9, 0xa1, 0x92, 0x7a, 0x7c, 0x69, 0x77, + 0x7d, 0x9e, 0x8f, 0x76, 0x88, 0x80, 0x92, 0x93, 0x91, 0x99, 0x8c, 0x85, + 0x9f, 0x69, 0xa8, 0x9b, 0x9f, 
0x9a, 0x64, 0x7a, 0x99, 0x70, 0xc4, 0x6d, + 0x9a, 0x99, 0x82, 0xa0, 0x8b, 0x59, 0xc8, 0x61, 0x8f, 0x95, 0x72, 0x8c, + 0x90, 0x63, 0xa9, 0x7e, 0x88, 0x8c, 0x85, 0x78, 0x76, 0x58, 0x8e, 0x72, + 0xa3, 0x9a, 0x7c, 0xa0, 0x7f, 0x6d, 0xa6, 0x83, 0x7e, 0x8d, 0x83, 0x88, + 0x86, 0x68, 0x8d, 0x96, 0xaa, 0x78, 0x90, 0xa5, 0x9c, 0x9d, 0x99, 0x88, + 0xb0, 0x82, 0x6f, 0x7e, 0xad, 0xa9, 0x7b, 0x6a, 0xba, 0x6c, 0x6d, 0x89, + 0xc1, 0x9e, 0x8e, 0x2f, 0xf2, 0x77, 0x50, 0x73, 0xdb, 0xc4, 0x9c, 0x6c, + 0xd0, 0x90, 0x88, 0xbe, 0x97, 0xb9, 0x9e, 0x6e, 0xbe, 0x8e, 0x83, 0x8e, + 0x96, 0x98, 0x4c, 0x4e, 0xa7, 0x8d, 0x43, 0x92, 0x8f, 0x92, 0x6d, 0x27, + 0x94, 0x73, 0x5f, 0x42, 0x7c, 0xa7, 0x8a, 0x5a, 0x81, 0x60, 0x85, 0x66, + 0x73, 0x72, 0x74, 0x9d, 0x5a, 0x9e, 0xa3, 0x71, 0x75, 0x91, 0x4f, 0xa2, + 0x67, 0xa6, 0x91, 0x64, 0x92, 0x7e, 0x95, 0x8d, 0x6e, 0xbe, 0x9b, 0x57, + 0x9b, 0x82, 0x89, 0x70, 0x6f, 0x9e, 0x7e, 0x86, 0x97, 0x81, 0x85, 0x8e, + 0x70, 0x96, 0x6c, 0x72, 0xab, 0x6d, 0x9c, 0x91, 0xa0, 0x8a, 0x8d, 0x88, + 0x9e, 0x75, 0xc6, 0x76, 0x7c, 0xa7, 0x6b, 0xa8, 0x94, 0x72, 0xb6, 0x78, + 0x8d, 0x90, 0x7b, 0x8c, 0xa6, 0x65, 0xad, 0x9b, 0xaa, 0x94, 0x89, 0x7d, + 0x90, 0x69, 0xaa, 0x7e, 0x9e, 0xad, 0x7f, 0x94, 0x81, 0x7d, 0xa1, 0x7b, + 0x6c, 0x65, 0x83, 0x95, 0x89, 0x75, 0x93, 0x87, 0x94, 0x87, 0xa8, 0x92, + 0x8d, 0xa6, 0x9f, 0x78, 0xaa, 0x72, 0x95, 0x94, 0xac, 0xa6, 0x91, 0x5a, + 0xdb, 0x82, 0x55, 0xb6, 0xc1, 0xa3, 0x84, 0x4f, 0xc9, 0x88, 0x53, 0x8f, + 0xbb, 0xae, 0x9b, 0x8a, 0xd8, 0xa9, 0x68, 0xc2, 0xa0, 0xa9, 0x87, 0x6b, + 0xbd, 0x99, 0x7e, 0x86, 0x88, 0xa7, 0x5e, 0x53, 0xa4, 0x84, 0x6b, 0x6e, + 0x89, 0x95, 0x84, 0x2d, 0xb5, 0x43, 0x3e, 0x50, 0x71, 0x96, 0x9a, 0x5b, + 0xa1, 0x60, 0x80, 0x70, 0x6a, 0x73, 0x8f, 0x95, 0x52, 0x9b, 0xae, 0x71, + 0x76, 0x7d, 0x61, 0x99, 0x5b, 0xc3, 0xa8, 0x76, 0x98, 0x72, 0x7f, 0x8a, + 0x66, 0xc7, 0xa3, 0x7b, 0x8e, 0x8f, 0x70, 0x74, 0x6a, 0xae, 0x85, 0x83, + 0x96, 0x7d, 0x98, 0xa7, 0x8f, 0x94, 0x7e, 0x84, 0x96, 0x7a, 0xab, 0x7d, + 0x83, 0xb1, 0x6f, 0x7d, 0x9f, 
0x80, 0xca, 0x8f, 0x9b, 0xa9, 0x69, 0x7a, + 0x92, 0x73, 0xaa, 0x74, 0x88, 0x98, 0x87, 0x8f, 0xa7, 0x68, 0xa0, 0x74, + 0x97, 0x95, 0x6e, 0x6f, 0x83, 0x53, 0x9b, 0x79, 0x71, 0x87, 0x7d, 0x8b, + 0x79, 0x87, 0xa3, 0x75, 0x68, 0x73, 0x7e, 0x89, 0x8f, 0x81, 0x98, 0x7a, + 0x9a, 0x83, 0x9d, 0x95, 0x90, 0x98, 0x97, 0x57, 0x93, 0x7e, 0xa2, 0x9a, + 0xa8, 0x8a, 0x85, 0x53, 0xbd, 0x7a, 0x61, 0x8b, 0xca, 0xac, 0x9b, 0x2e, + 0xe8, 0xa5, 0x66, 0x86, 0xca, 0xa7, 0xa0, 0x85, 0xcf, 0xa4, 0x6a, 0xc2, + 0xb0, 0xaa, 0x76, 0x76, 0xb6, 0xa2, 0x72, 0xa9, 0xa1, 0xa1, 0x67, 0x67, + 0xac, 0x90, 0x70, 0x6d, 0x8f, 0xb5, 0x6d, 0x3b, 0x85, 0x64, 0x4a, 0x6e, + 0x72, 0x9f, 0x98, 0x5b, 0x97, 0x3e, 0x8a, 0x6a, 0x6c, 0x7d, 0x77, 0x98, + 0x5a, 0x92, 0xa3, 0x81, 0x6f, 0x91, 0x7b, 0xa6, 0x6e, 0x9c, 0x9b, 0x5f, + 0x9e, 0x7e, 0x77, 0x9d, 0x88, 0xc6, 0x81, 0x5a, 0x93, 0x8b, 0x6c, 0x71, + 0x63, 0x9e, 0x78, 0x79, 0x70, 0x90, 0x95, 0x9f, 0x71, 0xa9, 0x90, 0x73, + 0x98, 0x8a, 0xa5, 0x8e, 0x87, 0xb0, 0x79, 0x79, 0x92, 0x7d, 0xcc, 0xa8, + 0x7a, 0x92, 0x82, 0x91, 0x90, 0x69, 0xa4, 0x9b, 0x97, 0x8f, 0x75, 0x7c, + 0xa3, 0x69, 0xb5, 0x87, 0x8d, 0x88, 0x7b, 0x94, 0x8b, 0x55, 0xa2, 0x6d, + 0x89, 0x8e, 0x81, 0x8a, 0x9e, 0x87, 0x86, 0x83, 0x8b, 0x84, 0x87, 0xa7, + 0x8e, 0x79, 0xa4, 0x9c, 0x99, 0x82, 0xa3, 0x8f, 0x91, 0x9a, 0x95, 0x5b, + 0x9f, 0x6e, 0x85, 0x93, 0xa6, 0x9a, 0x91, 0x4c, 0xd8, 0x6b, 0x6d, 0x85, + 0xde, 0xaa, 0x97, 0x51, 0xcf, 0x8c, 0x5f, 0x9a, 0xc2, 0x9d, 0x9a, 0x7c, + 0xc6, 0xb1, 0x84, 0xac, 0xba, 0xa5, 0x7c, 0x76, 0xbd, 0x93, 0x7f, 0xa0, + 0x86, 0xae, 0x47, 0x41, 0x88, 0x82, 0x62, 0x62, 0x73, 0xad, 0x6b, 0x23, + 0xa0, 0x48, 0x5a, 0x5a, 0x8f, 0x98, 0xbd, 0x5c, 0x9c, 0x72, 0x7c, 0x68, + 0x50, 0x78, 0x91, 0xab, 0x5c, 0xc1, 0xc6, 0x66, 0x87, 0x86, 0x60, 0x99, + 0x65, 0xac, 0x94, 0x91, 0x7e, 0x8c, 0x7d, 0x9b, 0x70, 0xb2, 0x9a, 0x7d, + 0x82, 0x91, 0x6b, 0x86, 0x6f, 0xbb, 0x7f, 0x66, 0x7a, 0x79, 0x94, 0x96, + 0x71, 0xa5, 0x75, 0x73, 0x95, 0x81, 0xa4, 0x8b, 0x87, 0xaa, 0x8e, 0x92, + 0xa9, 0x82, 0xb0, 0x92, 0x89, 
0xa7, 0x83, 0x81, 0x8c, 0x6d, 0xc4, 0x7a, + 0x89, 0xa5, 0xa1, 0xa2, 0xa4, 0x6b, 0xa4, 0x82, 0x90, 0xb2, 0x8d, 0x72, + 0x83, 0x60, 0xa7, 0x7a, 0x80, 0x97, 0x65, 0x90, 0x87, 0x85, 0xae, 0x71, + 0x7d, 0x71, 0x98, 0xa8, 0x90, 0x75, 0xa9, 0x96, 0xa2, 0x91, 0x7b, 0x6b, + 0xa0, 0x9d, 0x8d, 0x5d, 0xa4, 0x79, 0x8c, 0xa4, 0xad, 0x94, 0x7e, 0x77, + 0xb6, 0x92, 0x74, 0xaf, 0xb5, 0x9b, 0x99, 0x67, 0xe7, 0x8e, 0x6a, 0x87, + 0xc1, 0x98, 0x9b, 0x7e, 0xd7, 0x9b, 0x5b, 0xae, 0xc9, 0x94, 0x7a, 0x6d, + 0x9e, 0xb4, 0x86, 0x8e, 0xa3, 0xa1, 0x5e, 0x5d, 0x8e, 0x8f, 0x6b, 0x59, + 0xa5, 0xa9, 0x69, 0x20, 0xa4, 0x64, 0x35, 0x61, 0x83, 0x9d, 0x8a, 0x4e, + 0x8b, 0x6c, 0x5e, 0x5b, 0x68, 0x76, 0x89, 0x94, 0x5f, 0x87, 0x98, 0x7a, + 0x5d, 0x81, 0x89, 0xa6, 0x54, 0xa3, 0xb4, 0x7b, 0x83, 0x8a, 0x90, 0x8b, + 0x86, 0xbc, 0x86, 0x59, 0x91, 0x79, 0x71, 0x6b, 0x7c, 0x94, 0x98, 0x7f, + 0x81, 0x76, 0x85, 0xad, 0x69, 0xa8, 0x83, 0x8c, 0x8f, 0x70, 0x9a, 0x91, + 0x78, 0xb3, 0x8f, 0x6d, 0x90, 0x86, 0xbd, 0x97, 0x7f, 0xaf, 0x7e, 0x90, + 0x8f, 0x63, 0xa2, 0x93, 0x6e, 0xab, 0x75, 0x72, 0x8d, 0x74, 0xa1, 0x72, + 0x82, 0xaa, 0x70, 0x82, 0x8d, 0x67, 0x94, 0x91, 0x92, 0xa5, 0x7f, 0xa5, + 0x6f, 0x6d, 0xaf, 0x80, 0x89, 0x7d, 0x92, 0x99, 0x92, 0x72, 0x9d, 0x7d, + 0x92, 0x78, 0xa9, 0x89, 0xa9, 0x9b, 0xa3, 0x73, 0x98, 0x71, 0x98, 0x86, + 0x9e, 0x97, 0x9e, 0x6a, 0xb9, 0x6a, 0x6e, 0x90, 0xde, 0x94, 0x9a, 0x52, + 0xdd, 0xa9, 0x6a, 0x79, 0xb9, 0xa3, 0xaa, 0x95, 0xba, 0xa2, 0x75, 0xc2, + 0xbf, 0xb5, 0x6d, 0x8d, 0xae, 0x9b, 0x8d, 0x9a, 0x92, 0xb4, 0x5e, 0x4b, + 0x8b, 0x99, 0x4f, 0x65, 0x94, 0xb6, 0x5d, 0x3a, 0xa3, 0x77, 0x51, 0x4e, + 0x6d, 0xa3, 0x94, 0x59, 0x80, 0x56, 0x8c, 0x67, 0x67, 0x74, 0x99, 0x85, + 0x57, 0x7b, 0x9e, 0x7e, 0x84, 0x85, 0x94, 0x96, 0x71, 0xbf, 0x97, 0x5f, + 0x7d, 0x80, 0x93, 0x87, 0x6b, 0xb9, 0x7d, 0x8b, 0x84, 0x84, 0x6b, 0x8c, + 0x6c, 0xc4, 0x85, 0x82, 0x87, 0x8d, 0x64, 0x90, 0x80, 0xb6, 0x9a, 0x70, + 0x9c, 0x68, 0xa0, 0x88, 0x81, 0x9d, 0x83, 0x75, 0x9d, 0x84, 0xbf, 0x8f, + 0x83, 0x9b, 0x75, 0x82, 0x9c, 
0x76, 0xa4, 0x9d, 0x8a, 0xa7, 0x8e, 0x96, + 0x9c, 0x64, 0xc0, 0x95, 0x88, 0xa5, 0x6f, 0x74, 0x7e, 0x5d, 0x9f, 0x7d, + 0x89, 0x81, 0x71, 0xa8, 0x82, 0x6e, 0x9b, 0x9a, 0x6f, 0xa5, 0x88, 0x89, + 0xa4, 0x7e, 0xa4, 0x90, 0xa1, 0x83, 0x8b, 0x9c, 0x9a, 0x89, 0xa2, 0x89, + 0x9d, 0x5d, 0x86, 0xa5, 0xc4, 0x96, 0x9c, 0x85, 0xd6, 0x7c, 0x69, 0x88, + 0xc9, 0xa5, 0x9b, 0x60, 0xea, 0xab, 0x62, 0x9f, 0xd1, 0xa5, 0x86, 0x7e, + 0xb3, 0xbd, 0x7a, 0xa1, 0xbd, 0xa0, 0x7c, 0x92, 0xa6, 0xa3, 0x7d, 0xa9, + 0x98, 0xa6, 0x71, 0x5c, 0x9b, 0x9b, 0x58, 0x6f, 0x8f, 0xaa, 0x5e, 0x3b, + 0xa6, 0x5f, 0x3a, 0x79, 0x94, 0xa5, 0x84, 0x6f, 0x83, 0x5d, 0x75, 0x65, + 0x6c, 0x77, 0x86, 0xad, 0x4a, 0x92, 0x8e, 0x8a, 0x8f, 0x7b, 0x72, 0x96, + 0x79, 0xa6, 0xa8, 0x6d, 0x7b, 0x7b, 0x98, 0xa9, 0x79, 0xb9, 0x9e, 0x8f, + 0x90, 0x6d, 0x76, 0x82, 0x81, 0xc1, 0x95, 0x7c, 0x97, 0x8d, 0x95, 0xa2, + 0x7c, 0xa4, 0x7b, 0x9b, 0x7f, 0x6f, 0xac, 0x83, 0x7e, 0xa1, 0x7c, 0x7c, + 0xa1, 0x7a, 0xa1, 0x6d, 0x95, 0x86, 0x77, 0x98, 0x8e, 0x58, 0xa2, 0x76, + 0x8e, 0xa8, 0x94, 0x90, 0xa7, 0x62, 0xb8, 0x8a, 0x9f, 0xac, 0x87, 0x91, + 0x88, 0x50, 0xa7, 0x83, 0x88, 0x65, 0x7a, 0x92, 0x9d, 0x70, 0xa9, 0x99, + 0x7c, 0x87, 0x8c, 0x96, 0x8e, 0x73, 0xa4, 0xa7, 0x9b, 0x70, 0x99, 0x96, + 0x8f, 0x88, 0xb4, 0x85, 0xa8, 0x6a, 0x9e, 0x78, 0xb0, 0x82, 0x9f, 0x89, + 0xc9, 0x8d, 0x71, 0x7f, 0xc0, 0x98, 0xa0, 0x6d, 0xd2, 0x8e, 0x64, 0x9e, + 0xb2, 0xa9, 0x93, 0x6e, 0xcc, 0xbb, 0x89, 0xb1, 0xc1, 0x9b, 0x86, 0x94, + 0xb5, 0xb5, 0x95, 0xa0, 0x9c, 0x9b, 0x62, 0x5f, 0x7b, 0x91, 0x69, 0x74, + 0x9e, 0xa3, 0x81, 0x30, 0x85, 0x59, 0x49, 0x5e, 0x83, 0x85, 0x7d, 0x6a, + 0x90, 0x51, 0x80, 0x5e, 0x64, 0x6f, 0x99, 0x93, 0x75, 0x9a, 0xa7, 0x72, + 0x6c, 0x5d, 0xa3, 0x93, 0x87, 0xa7, 0xbd, 0x6f, 0x92, 0x6d, 0x85, 0x98, + 0x6f, 0xc7, 0xb6, 0x7c, 0x80, 0x71, 0x8a, 0x9f, 0x71, 0xb5, 0x8c, 0x6d, + 0xac, 0x7b, 0x72, 0xb7, 0x69, 0xa6, 0x9d, 0x66, 0xab, 0x7a, 0x8b, 0x70, + 0x8c, 0x9e, 0x86, 0x75, 0x96, 0x7b, 0xa3, 0x93, 0x8f, 0xb7, 0x84, 0x8c, + 0x87, 0x56, 0xae, 0x82, 0x71, 
0xa3, 0x8d, 0x93, 0xaf, 0x59, 0xb3, 0x8a, + 0x97, 0x99, 0x75, 0x73, 0x8e, 0x51, 0xae, 0x84, 0x8b, 0x7a, 0x76, 0x77, + 0x6e, 0x75, 0xa4, 0x8a, 0x75, 0x8e, 0x8f, 0xa2, 0x96, 0x76, 0x9a, 0x80, + 0x96, 0x7d, 0x94, 0x71, 0x8a, 0x90, 0xac, 0x82, 0xa5, 0x61, 0xa3, 0x84, + 0xac, 0x8f, 0x74, 0x5c, 0xb6, 0x77, 0x8b, 0x9b, 0xb5, 0x8b, 0xb6, 0x52, + 0xd7, 0xaa, 0x4b, 0x8c, 0xbf, 0xb8, 0x9f, 0x6d, 0xcb, 0xa3, 0x6e, 0x97, + 0xaa, 0x8d, 0x7c, 0x99, 0xc0, 0xd0, 0x9e, 0xb7, 0x93, 0xaa, 0x5a, 0x6a, + 0x7d, 0x9a, 0x63, 0x71, 0x78, 0x8c, 0x67, 0x43, 0x87, 0x52, 0x64, 0x68, + 0x68, 0x9c, 0x65, 0x60, 0x7a, 0x35, 0x68, 0x66, 0x63, 0x69, 0x8d, 0x8f, + 0x72, 0x9b, 0x99, 0x5b, 0x80, 0x67, 0x93, 0xa2, 0x97, 0x9d, 0x8c, 0x68, + 0x80, 0x86, 0x96, 0x91, 0x64, 0xbf, 0x98, 0x63, 0x83, 0x85, 0x61, 0x97, + 0x6a, 0xac, 0xb4, 0x99, 0x8d, 0x7b, 0x7b, 0xad, 0x8b, 0xb2, 0x9e, 0x7f, + 0x9a, 0x73, 0x91, 0x84, 0x89, 0x9f, 0x8a, 0x87, 0x8b, 0x72, 0x8e, 0x79, + 0x86, 0xa7, 0x77, 0x84, 0x90, 0x58, 0xb2, 0x90, 0x93, 0xa0, 0x7f, 0x8a, + 0x91, 0x5a, 0xb1, 0x80, 0x99, 0xc1, 0x80, 0x7d, 0x97, 0x5c, 0x9a, 0x8c, + 0x71, 0x96, 0x7e, 0x7f, 0xad, 0x7b, 0xb9, 0x8a, 0x84, 0x84, 0x81, 0x97, + 0x94, 0x64, 0x9f, 0x7e, 0x9b, 0x8d, 0x7d, 0x8d, 0x9a, 0x9e, 0xac, 0x72, + 0xb2, 0x73, 0x81, 0x84, 0xc8, 0x81, 0x88, 0x72, 0xbe, 0x85, 0x86, 0x97, + 0xd3, 0x8a, 0xc7, 0x75, 0xce, 0x9c, 0x69, 0xa6, 0xb0, 0xa1, 0x8e, 0x64, + 0xb1, 0xa6, 0x67, 0xaa, 0xcd, 0x95, 0x97, 0xa2, 0xb2, 0xb2, 0x85, 0x9a, + 0x9d, 0xa3, 0x5e, 0x73, 0x6e, 0xae, 0x50, 0x83, 0x8c, 0xab, 0x92, 0x43, + 0x6b, 0x66, 0x43, 0x5c, 0x8f, 0x8a, 0x9a, 0x6c, 0x84, 0x48, 0x80, 0x6b, + 0x8d, 0x82, 0xaf, 0x89, 0x71, 0x9f, 0xa4, 0x9a, 0x7b, 0x68, 0x91, 0xaa, + 0x6b, 0xa3, 0x9c, 0x62, 0x8d, 0x6d, 0x87, 0x87, 0x81, 0x9a, 0x97, 0x6c, + 0x9c, 0x76, 0x63, 0xbc, 0x62, 0xbc, 0xb0, 0x97, 0xa7, 0x81, 0x70, 0x8f, + 0x7d, 0xb2, 0xa6, 0x98, 0xa1, 0x7b, 0x8e, 0x83, 0x8c, 0xa2, 0x7e, 0x73, + 0x99, 0x65, 0xc1, 0x77, 0x8e, 0xbc, 0x72, 0xa6, 0x8c, 0x55, 0xab, 0x8e, + 0x7d, 0xa3, 0x79, 0x80, 0x9e, 
0x6b, 0xa9, 0x6c, 0x80, 0xb6, 0x81, 0xa6, + 0x92, 0x5b, 0xb7, 0x99, 0x81, 0x7e, 0x8e, 0x89, 0x97, 0x86, 0x93, 0x86, + 0x7b, 0x9a, 0x7f, 0x9a, 0x8e, 0x69, 0xa3, 0xa4, 0x9f, 0x8b, 0x96, 0x6f, + 0x8b, 0x97, 0xb4, 0x74, 0x96, 0x53, 0x99, 0x91, 0xa7, 0xa8, 0x69, 0x72, + 0xc9, 0x85, 0x99, 0x93, 0xc0, 0x90, 0xaa, 0x7f, 0xc7, 0x71, 0x74, 0x8d, + 0xb7, 0xab, 0x91, 0x69, 0xb4, 0x9b, 0x7d, 0x95, 0xc3, 0xb0, 0x9b, 0xa9, + 0xb3, 0x9f, 0x79, 0xa5, 0x9f, 0xad, 0x6b, 0x85, 0x90, 0xad, 0x69, 0x62, + 0x7e, 0xa6, 0x69, 0x4e, 0x80, 0x7e, 0x52, 0x57, 0x5f, 0x95, 0x72, 0x4c, + 0x87, 0x4e, 0x5a, 0x62, 0x7d, 0x70, 0x92, 0x98, 0x76, 0x8e, 0x99, 0x7d, + 0x73, 0x6d, 0x86, 0x8e, 0x6b, 0x80, 0xa7, 0x9d, 0x91, 0x73, 0x95, 0x70, + 0x80, 0xc3, 0x9f, 0x8b, 0x72, 0x86, 0x6b, 0xad, 0x76, 0xbe, 0xad, 0x8e, + 0x9c, 0x78, 0x6a, 0xbf, 0x7d, 0xa8, 0x88, 0x8a, 0x8b, 0x8c, 0x9c, 0x8c, + 0x8a, 0x85, 0x73, 0x92, 0xa2, 0x7b, 0xa5, 0x96, 0x9b, 0xa3, 0x6c, 0x80, + 0xa6, 0x63, 0xac, 0x98, 0xa3, 0x9a, 0x83, 0x8a, 0x8c, 0x63, 0xb9, 0x8c, + 0x99, 0xa1, 0x7a, 0x6c, 0x9e, 0x59, 0x90, 0x84, 0x8a, 0x93, 0x8f, 0x87, + 0x98, 0x84, 0x99, 0xa4, 0x72, 0x6d, 0x95, 0xa2, 0x95, 0x72, 0xc3, 0x88, + 0x8f, 0x6a, 0x77, 0x7d, 0x8b, 0xae, 0xa3, 0x7c, 0xa8, 0x5d, 0x7c, 0xa8, + 0xa1, 0x85, 0x7e, 0x8c, 0xac, 0x8d, 0x73, 0x88, 0xc1, 0x89, 0xaa, 0x89, + 0xb2, 0x92, 0x75, 0x9a, 0x9c, 0x8e, 0xb9, 0xaa, 0xaa, 0xac, 0x78, 0x85, + 0xbc, 0x9f, 0x6d, 0xb7, 0x89, 0xa6, 0xb3, 0x8e, 0xa5, 0xbb, 0x6b, 0x9d, + 0x8f, 0x8b, 0x69, 0x7a, 0x82, 0x99, 0x8c, 0x49, 0x87, 0x74, 0x37, 0x63, + 0x5d, 0x92, 0x77, 0x66, 0x63, 0x56, 0x77, 0x5d, 0x7f, 0x68, 0x97, 0x74, + 0x84, 0x94, 0x7d, 0x7d, 0x91, 0x78, 0x87, 0x96, 0x7f, 0x97, 0x94, 0x6f, + 0x89, 0x6c, 0x96, 0x71, 0x83, 0x8f, 0x8a, 0x89, 0x7d, 0x84, 0x8a, 0xa6, + 0x7b, 0x95, 0x89, 0x77, 0x94, 0x80, 0x7f, 0x93, 0x5e, 0xbb, 0x9c, 0xa8, + 0xa2, 0x7e, 0xa6, 0x86, 0x7d, 0x8b, 0x92, 0x73, 0xac, 0x78, 0xaa, 0x98, + 0xb1, 0x94, 0x79, 0x8b, 0x8f, 0x70, 0xa7, 0xae, 0x92, 0xad, 0xb1, 0x8b, + 0xb0, 0x78, 0xbc, 0xa9, 0xa4, 
0xa3, 0x9e, 0x76, 0x89, 0x67, 0xab, 0x98, + 0x75, 0x8c, 0x86, 0x95, 0x9e, 0x77, 0x96, 0x85, 0x8c, 0x8e, 0x8b, 0x8a, + 0x8a, 0x4b, 0x71, 0x8a, 0x9b, 0x6d, 0x6e, 0x89, 0x81, 0x82, 0xa7, 0x98, + 0xa5, 0x66, 0x72, 0x8b, 0x99, 0x9a, 0x8b, 0x8b, 0x9f, 0x87, 0x79, 0x84, + 0x99, 0x6d, 0x90, 0x7d, 0x9d, 0xa7, 0x81, 0xa3, 0x9d, 0x96, 0x82, 0x86, + 0xa2, 0x8e, 0x8d, 0x7f, 0x84, 0x8c, 0x98, 0xbc, 0x83, 0xb4, 0xb5, 0x78, + 0x7d, 0xab, 0x8d, 0x87, 0x71, 0x8d, 0x6e, 0x8f, 0x89, 0xaa, 0x7c, 0x6f, + 0x71, 0x69, 0x65, 0x60, 0x81, 0x91, 0x94, 0x6d, 0x76, 0x66, 0x74, 0x5e, + 0x77, 0x7c, 0xa2, 0xa6, 0x70, 0x90, 0xa3, 0x68, 0x83, 0x69, 0x71, 0x72, + 0x6c, 0xa9, 0x85, 0x71, 0x88, 0x60, 0x90, 0x84, 0x8a, 0xba, 0x8b, 0x8c, + 0x72, 0x8f, 0x98, 0x84, 0x8b, 0x8a, 0xb1, 0xa2, 0x93, 0x8d, 0x86, 0x99, + 0xa2, 0x99, 0xb0, 0xa6, 0x92, 0x78, 0x86, 0x87, 0x9c, 0x9d, 0x6f, 0x92, + 0x9a, 0x8a, 0xbf, 0xaa, 0xa3, 0xa2, 0x71, 0x8d, 0x93, 0x70, 0xb5, 0x9c, + 0xa8, 0x97, 0xb4, 0x93, 0xa6, 0x75, 0xbb, 0xa3, 0x92, 0x95, 0x95, 0x94, + 0x90, 0x5b, 0xbf, 0x92, 0x8a, 0x95, 0xa0, 0xa1, 0x68, 0x7e, 0x9a, 0x7f, + 0x88, 0xa7, 0x93, 0xa1, 0x7a, 0x93, 0x95, 0x8b, 0x96, 0x94, 0x70, 0xa0, + 0x70, 0x8f, 0x9d, 0x96, 0x8e, 0x9c, 0x90, 0x9f, 0x7e, 0x83, 0x84, 0x9e, + 0x7f, 0x65, 0x72, 0x84, 0x64, 0x94, 0x75, 0xa7, 0x62, 0xa3, 0x8a, 0x9b, + 0x82, 0x99, 0x87, 0x70, 0x81, 0x6d, 0xac, 0x7b, 0x74, 0x68, 0x5d, 0x95, + 0xa0, 0x6e, 0x84, 0xab, 0x79, 0x8e, 0x8b, 0x79, 0x7b, 0x83, 0xa0, 0x7b, + 0x96, 0x71, 0x5d, 0xad, 0xa4, 0x82, 0x79, 0x96, 0x73, 0x84, 0x7d, 0x98, + 0x87, 0x93, 0x86, 0xa6, 0x7f, 0x7c, 0x71, 0x9d, 0xa4, 0x9b, 0x8a, 0x7c, + 0x87, 0x6a, 0x7f, 0x8d, 0x97, 0x92, 0xa0, 0x88, 0x77, 0x7d, 0x70, 0x9c, + 0x9f, 0xa0, 0x71, 0xa3, 0x73, 0x95, 0x76, 0x79, 0x94, 0x95, 0x83, 0x8b, + 0x8d, 0x82, 0x7a, 0x77, 0xa6, 0x88, 0x72, 0x7a, 0x90, 0x76, 0x7f, 0x95, + 0x83, 0x90, 0x9e, 0x7c, 0x8e, 0x9a, 0x6b, 0xa4, 0x98, 0x9f, 0x86, 0x8c, + 0x76, 0x70, 0x74, 0x97, 0x7e, 0xa4, 0x5f, 0xa3, 0xa7, 0x7f, 0x67, 0x8d, + 0x82, 0x95, 0x93, 0x99, 0x82, 
0x70, 0x75, 0xa8, 0xa1, 0xaf, 0x8a, 0x8a, + 0xb0, 0x89, 0x88, 0x6b, 0x98, 0xaf, 0x75, 0x7f, 0x86, 0x90, 0x8f, 0x8c, + 0x84, 0x8d, 0x7f, 0x8b, 0x94, 0x9f, 0x80, 0x8b, 0x93, 0xa2, 0x98, 0xa5, + 0x83, 0x81, 0x8a, 0xaa, 0x86, 0xa3, 0xb0, 0xac, 0x64, 0x9c, 0x7c, 0x93, + 0xac, 0x85, 0x7f, 0x88, 0x7a, 0xa5, 0x75, 0x69, 0x94, 0xa8, 0x95, 0xa9, + 0x6f, 0x9f, 0x85, 0x8a, 0xa5, 0x97, 0x98, 0xa9, 0x76, 0x80, 0x7e, 0x95, + 0x89, 0xaf, 0x68, 0x7b, 0xb4, 0x8a, 0x6b, 0xa4, 0x7b, 0x90, 0x79, 0xba, + 0x9f, 0x82, 0x7d, 0x89, 0x85, 0x82, 0x94, 0xa5, 0x78, 0x8f, 0x6f, 0x71, + 0x62, 0x66, 0x73, 0x98, 0x8c, 0x7d, 0x81, 0xa2, 0x69, 0x7c, 0x76, 0xa4, + 0x94, 0x8f, 0x6f, 0x8a, 0x94, 0x8e, 0x8a, 0x88, 0x8c, 0xa3, 0x6f, 0xa2, + 0x7d, 0x90, 0x8f, 0x96, 0x6c, 0x76, 0x6e, 0x8e, 0x82, 0x85, 0x7f, 0x93, + 0x81, 0x83, 0x7b, 0x9f, 0x91, 0x89, 0x75, 0x9c, 0x9f, 0x86, 0x7a, 0x8c, + 0x7a, 0x7b, 0x82, 0xae, 0x6a, 0x7d, 0x82, 0x82, 0xa0, 0x85, 0x99, 0x9f, + 0x88, 0x8b, 0x8c, 0x8f, 0x90, 0x96, 0x8e, 0x98, 0xa3, 0x87, 0x7f, 0x9b, + 0x94, 0x73, 0x96, 0x86, 0x72, 0x7c, 0x75, 0x7c, 0x90, 0x79, 0x83, 0x80, + 0x79, 0x9e, 0x9c, 0x8e, 0x99, 0x8c, 0x7a, 0x9c, 0x8d, 0x99, 0x9d, 0x84, + 0xa5, 0x93, 0x85, 0x96, 0x88, 0x94, 0x80, 0x90, 0x73, 0xa3, 0x7c, 0xa1, + 0x88, 0xa4, 0x98, 0x9f, 0x9e, 0x92, 0x6c, 0xa0, 0x84, 0x87, 0x8a, 0x83, + 0x7b, 0x91, 0x8c, 0x9e, 0x73, 0xa6, 0x93, 0xa0, 0x8d, 0x98, 0x74, 0xa1, + 0x83, 0x9a, 0x80, 0xbc, 0x62, 0x70, 0x9e, 0xad, 0x9e, 0x8f, 0x8f, 0x9e, + 0x7e, 0xac, 0xb0, 0xa9, 0x79, 0x6f, 0x79, 0x8f, 0x7e, 0x71, 0x8d, 0xab, + 0x97, 0x76, 0x86, 0xa2, 0x98, 0x95, 0x8b, 0x9b, 0x75, 0x7a, 0x71, 0x85, + 0x7f, 0x61, 0x76, 0x8e, 0x99, 0x91, 0x88, 0x73, 0x71, 0x65, 0x82, 0xa0, + 0x9b, 0x8f, 0x79, 0x70, 0x78, 0x66, 0x85, 0x94, 0x8b, 0x91, 0x75, 0x80, + 0x9c, 0x94, 0x7f, 0xa5, 0x82, 0x91, 0x7d, 0x76, 0x80, 0x78, 0x83, 0x82, + 0x79, 0x98, 0x83, 0x87, 0x94, 0x71, 0x73, 0x77, 0x71, 0x94, 0x6a, 0xa8, + 0x9e, 0x8d, 0x90, 0x78, 0x7a, 0x81, 0x9c, 0x91, 0x96, 0x80, 0x79, 0x83, + 0x92, 0x9f, 0x8a, 0x84, 0x8e, 
0x97, 0x8c, 0x81, 0x87, 0x74, 0x8b, 0x8e, + 0xa7, 0x86, 0x8b, 0x8a, 0x8e, 0x8f, 0x9b, 0x6b, 0x82, 0x8a, 0x9f, 0x7a, + 0x96, 0x80, 0x91, 0x94, 0xa6, 0x8e, 0x7a, 0x97, 0x8a, 0x6c, 0xad, 0xa1, + 0x78, 0x95, 0x9d, 0x9d, 0x88, 0x94, 0x99, 0x86, 0x80, 0x9b, 0x7c, 0x9c, + 0x87, 0x7a, 0xa0, 0xa8, 0x83, 0x74, 0x8e, 0x9b, 0x65, 0x95, 0x83, 0xc2, + 0x69, 0x88, 0x87, 0xa7, 0x86, 0x98, 0x9f, 0xc6, 0x5c, 0x7f, 0xb9, 0x9c, + 0x8b, 0x6e, 0x95, 0xbd, 0x72, 0x83, 0xbf, 0xb1, 0x89, 0x6d, 0x89, 0x8e, + 0x9d, 0x87, 0x95, 0x92, 0x76, 0x8d, 0x7f, 0x7f, 0x6d, 0x9d, 0x7b, 0x95, + 0x86, 0x69, 0x90, 0xa0, 0x62, 0x7c, 0x56, 0xa0, 0x9c, 0x8b, 0x81, 0x79, + 0xa6, 0x73, 0x69, 0xaa, 0x7b, 0x87, 0x8b, 0x7e, 0xa1, 0x9f, 0x6d, 0xa6, + 0x7e, 0x7e, 0x87, 0x7c, 0xa5, 0x84, 0x7b, 0xa2, 0xae, 0x92, 0x8e, 0x67, + 0x93, 0x88, 0x8b, 0xa2, 0x8d, 0x96, 0x92, 0x8e, 0x71, 0x7a, 0x82, 0x80, + 0x9e, 0x8b, 0x7b, 0x87, 0x96, 0xa0, 0xa4, 0x92, 0x88, 0x7e, 0x77, 0x8e, + 0x91, 0x7e, 0x81, 0x77, 0x79, 0x93, 0x8d, 0x9d, 0x8a, 0x71, 0x8d, 0x88, + 0x9d, 0x89, 0x85, 0x94, 0x99, 0x80, 0x89, 0x8f, 0x87, 0x81, 0x83, 0x74, + 0x8a, 0x89, 0x68, 0x7e, 0x99, 0x82, 0x8c, 0x76, 0xc6, 0x8f, 0x90, 0x7d, + 0x6c, 0x68, 0xbd, 0x90, 0x78, 0x9d, 0x7b, 0xa3, 0x99, 0x76, 0xaf, 0x8d, + 0x7d, 0x84, 0x7f, 0x9f, 0x8b, 0x7a, 0xaa, 0xa8, 0x79, 0x89, 0x8f, 0x8f, + 0x71, 0x80, 0x7f, 0xaa, 0x85, 0x70, 0xa8, 0x96, 0x6c, 0x8c, 0xaf, 0xeb, + 0x57, 0x7e, 0xcf, 0x8d, 0x93, 0x72, 0xa6, 0xd2, 0x52, 0xab, 0xbb, 0xa8, + 0x8d, 0x82, 0x7a, 0xbc, 0x72, 0x95, 0xa3, 0xa7, 0x8b, 0x74, 0x84, 0x85, + 0x6a, 0x85, 0x92, 0x9f, 0x91, 0x6b, 0x9b, 0x73, 0x77, 0xa2, 0x7f, 0x81, + 0x8e, 0x8b, 0x71, 0x8c, 0x7f, 0x60, 0x86, 0x81, 0x9c, 0x86, 0x93, 0x65, + 0x84, 0x84, 0x89, 0xa2, 0x98, 0x67, 0x88, 0x71, 0x92, 0x80, 0x65, 0xa2, + 0xa5, 0x99, 0x85, 0x95, 0x8f, 0x85, 0x8f, 0x82, 0x7e, 0x9a, 0x8a, 0x74, + 0x9d, 0x75, 0x88, 0x7e, 0xa2, 0x77, 0x82, 0x9e, 0x78, 0xa1, 0x74, 0x79, + 0x7f, 0x87, 0x91, 0x8d, 0x7a, 0x73, 0x96, 0xa2, 0xa3, 0x81, 0x7d, 0x8a, + 0x85, 0x75, 0x84, 0x81, 0x8b, 
0x7f, 0x6c, 0x86, 0x8d, 0x7b, 0x79, 0x78, + 0x89, 0x85, 0x8c, 0x9a, 0xa6, 0x96, 0x7a, 0x78, 0xa2, 0x85, 0x9b, 0x89, + 0xc8, 0x97, 0xa3, 0x82, 0x8b, 0x7f, 0xe7, 0x8f, 0x8f, 0x74, 0x75, 0x83, + 0x87, 0x79, 0xb3, 0xab, 0x70, 0x9a, 0x9a, 0xa6, 0x81, 0x7e, 0xb8, 0x91, + 0x8b, 0x8d, 0x93, 0xa1, 0x79, 0x7d, 0x81, 0xb4, 0x79, 0x94, 0xa5, 0x89, + 0x8e, 0x7c, 0x9b, 0xe2, 0x50, 0x94, 0xdf, 0xa0, 0x53, 0x5d, 0x90, 0xde, + 0x67, 0x90, 0xaf, 0x8a, 0x8f, 0x73, 0x7b, 0xcb, 0x64, 0x9f, 0x91, 0x86, + 0x95, 0x84, 0x83, 0x88, 0x76, 0x8b, 0x8a, 0x8f, 0x9c, 0x9a, 0x92, 0x96, + 0x7f, 0x8e, 0x79, 0x80, 0x91, 0x6d, 0x86, 0x59, 0x74, 0x8a, 0x53, 0x88, + 0xae, 0x7b, 0x80, 0x70, 0x87, 0x74, 0x75, 0x91, 0xa4, 0x74, 0x8d, 0x5a, + 0x83, 0x95, 0x65, 0xa1, 0xb3, 0x74, 0x87, 0x7d, 0xaa, 0x82, 0x79, 0x78, + 0x9b, 0x7c, 0x78, 0x74, 0x9e, 0x74, 0x92, 0x92, 0xa3, 0x6e, 0x75, 0x92, + 0x6a, 0x6f, 0xa3, 0x7c, 0x9e, 0x7f, 0x92, 0x6b, 0x96, 0x79, 0x9a, 0x87, + 0x83, 0x8c, 0x72, 0x79, 0x6a, 0xa3, 0x79, 0x7d, 0x6d, 0x6c, 0x81, 0x96, + 0x98, 0x7f, 0x94, 0x81, 0x8a, 0x8a, 0xa7, 0x8c, 0x9a, 0x84, 0xa7, 0x89, + 0x9d, 0x85, 0xa6, 0xa8, 0xd0, 0x92, 0x97, 0x9f, 0x76, 0x86, 0xe6, 0x6f, + 0x7c, 0x84, 0x98, 0x8d, 0x80, 0x75, 0xc5, 0x86, 0x6b, 0x8d, 0x9e, 0x9e, + 0x7f, 0x71, 0x97, 0xa1, 0x75, 0x92, 0xa9, 0x9e, 0x91, 0x5e, 0xa2, 0xa2, + 0x68, 0xad, 0xa5, 0xa0, 0x7e, 0x68, 0xac, 0xdc, 0x50, 0xa2, 0xc1, 0x8a, + 0x63, 0x74, 0x7e, 0xd9, 0x3f, 0xbb, 0xba, 0x9d, 0x7f, 0x76, 0x5f, 0xb0, + 0x74, 0x8e, 0xb1, 0x95, 0x9a, 0x81, 0x63, 0x9f, 0x98, 0x74, 0x80, 0x89, + 0x95, 0x8e, 0x9e, 0x78, 0x87, 0x82, 0x57, 0x87, 0x8d, 0x90, 0x79, 0x80, + 0x76, 0x7c, 0x7d, 0x8a, 0xa6, 0x82, 0x98, 0x7a, 0x96, 0x97, 0x84, 0x87, + 0xab, 0x7f, 0x87, 0x57, 0x83, 0x6a, 0x6a, 0x84, 0x9c, 0x8d, 0x74, 0x68, + 0xa2, 0x92, 0x90, 0x98, 0x98, 0x8b, 0x6d, 0x72, 0x90, 0x8c, 0x7c, 0x7d, + 0x9b, 0x6e, 0x71, 0x76, 0x6b, 0x7b, 0x63, 0x81, 0xad, 0x71, 0x78, 0x8e, + 0x74, 0x87, 0x8e, 0x8a, 0xab, 0x8e, 0x83, 0x85, 0x7d, 0xa0, 0x67, 0x7f, + 0x9c, 0x74, 0x6b, 0x88, 0x66, 
0x92, 0x7f, 0x83, 0x94, 0x92, 0xa5, 0x82, + 0xa1, 0x7b, 0x6f, 0x70, 0xab, 0x72, 0xb5, 0x91, 0xb7, 0x89, 0x91, 0x77, + 0x77, 0x8a, 0xdb, 0x88, 0x8a, 0x8d, 0x89, 0x6c, 0x7b, 0x83, 0xc8, 0xb5, + 0x4b, 0x96, 0x8b, 0x92, 0x91, 0x76, 0xa9, 0xae, 0x70, 0xa8, 0x74, 0x9d, + 0x96, 0x6d, 0xa1, 0xba, 0x86, 0xbc, 0xbc, 0xa2, 0x8d, 0x6c, 0x96, 0xd8, + 0x71, 0xb1, 0xae, 0xb0, 0x79, 0x7b, 0x71, 0xd8, 0x32, 0xaa, 0xae, 0xa7, + 0x7c, 0x6b, 0x77, 0xc0, 0x7c, 0x9e, 0x9f, 0x89, 0x92, 0x8a, 0x76, 0xae, + 0x97, 0x75, 0x87, 0x8c, 0x7f, 0x86, 0x8b, 0x73, 0x6b, 0x64, 0x87, 0x6d, + 0x99, 0x8f, 0x8d, 0x66, 0x76, 0x87, 0x6d, 0x6e, 0x98, 0x7a, 0x91, 0x92, + 0x8c, 0x7c, 0x89, 0x9b, 0x9e, 0x83, 0x86, 0x62, 0x90, 0x6e, 0x62, 0x82, + 0xa3, 0x7e, 0x86, 0x6a, 0x93, 0x9b, 0x73, 0x6c, 0xa8, 0x99, 0x73, 0x99, + 0x8c, 0x89, 0x85, 0x67, 0x98, 0x78, 0x63, 0x98, 0x77, 0xa6, 0x6e, 0x81, + 0xa4, 0x64, 0x8f, 0x8a, 0x7f, 0x9b, 0x91, 0x91, 0x94, 0x82, 0x8b, 0x8b, + 0x76, 0x66, 0x83, 0x81, 0x94, 0x71, 0x82, 0x9e, 0x93, 0x85, 0x80, 0x8c, + 0xae, 0x94, 0x96, 0x74, 0x91, 0x9a, 0x6f, 0x9e, 0xa9, 0x76, 0xab, 0x8e, + 0xd6, 0x9c, 0x7d, 0x98, 0x83, 0x6e, 0xfe, 0x83, 0x71, 0x82, 0x9f, 0x93, + 0x7b, 0x67, 0xcb, 0xb9, 0x66, 0x89, 0x99, 0x8a, 0xac, 0x8c, 0xa0, 0x9c, + 0x70, 0xaf, 0x81, 0x88, 0x9c, 0x7e, 0xa8, 0xa5, 0x65, 0x8c, 0xa1, 0x8c, + 0x83, 0x85, 0x9d, 0xcb, 0x4b, 0xc1, 0xb5, 0xa2, 0x75, 0x63, 0x75, 0xbd, + 0x34, 0xae, 0xca, 0xa2, 0x89, 0x7a, 0x69, 0xb0, 0x70, 0xae, 0x94, 0x76, + 0x85, 0x93, 0x6a, 0x90, 0x6a, 0x8a, 0xac, 0x71, 0x7e, 0x81, 0xa2, 0x71, + 0x98, 0x86, 0x99, 0x76, 0x8f, 0x6f, 0x90, 0x93, 0x7c, 0x72, 0x81, 0x8c, + 0x78, 0x77, 0x97, 0x84, 0x98, 0x70, 0x96, 0x9a, 0x9b, 0x93, 0x92, 0x5f, + 0xaa, 0x88, 0x5b, 0x74, 0xaa, 0x96, 0x6a, 0x73, 0x87, 0x83, 0x72, 0x89, + 0xab, 0x8a, 0x5f, 0x71, 0xa4, 0x94, 0x92, 0x60, 0x96, 0x7b, 0x53, 0x88, + 0x69, 0x8b, 0x5e, 0x7b, 0xa0, 0x83, 0x70, 0x95, 0x6d, 0x9b, 0x6d, 0x98, + 0x99, 0x86, 0x6e, 0x7a, 0x87, 0x86, 0x68, 0x8a, 0x7e, 0x87, 0x90, 0x7d, + 0x76, 0x93, 0x80, 0x8a, 0x8f, 
0x97, 0xac, 0x71, 0xa2, 0x96, 0x7f, 0x8e, + 0xc2, 0x71, 0xab, 0xa9, 0xd1, 0x85, 0x8c, 0x74, 0x70, 0x72, 0xff, 0x77, + 0x6d, 0x77, 0x91, 0x5d, 0x71, 0x5d, 0xb2, 0xb1, 0x38, 0x76, 0xa6, 0x80, + 0x91, 0x86, 0xa3, 0x9c, 0x85, 0x95, 0x99, 0xab, 0x8a, 0x6e, 0x9f, 0xa6, + 0x75, 0xa9, 0xb3, 0x97, 0x69, 0x85, 0xa4, 0xc9, 0x59, 0xb4, 0xca, 0x8d, + 0x5c, 0x67, 0x7d, 0xcd, 0x29, 0xca, 0xdb, 0x8c, 0x86, 0x8c, 0x70, 0xaa, + 0x5c, 0x9e, 0x98, 0x86, 0x92, 0x7e, 0x6b, 0x8e, 0x8f, 0x6a, 0x84, 0x71, + 0x9a, 0x76, 0x87, 0x84, 0x8b, 0x7f, 0x7f, 0x6e, 0xa3, 0x83, 0x85, 0x78, + 0x6f, 0x7c, 0x6f, 0x96, 0x95, 0x8c, 0xa3, 0x72, 0x92, 0x66, 0x7b, 0x99, + 0x9c, 0x9c, 0x9a, 0x63, 0xaa, 0x81, 0x7f, 0x90, 0x8c, 0xa0, 0x7e, 0x67, + 0x94, 0x96, 0x7f, 0x8a, 0x95, 0x91, 0x5c, 0x73, 0x88, 0x9b, 0x85, 0x70, + 0x87, 0x79, 0x56, 0x92, 0x69, 0x95, 0x62, 0x78, 0x93, 0x83, 0x63, 0x98, + 0x7a, 0xa4, 0x95, 0x7c, 0x8e, 0x69, 0x86, 0x92, 0x7d, 0x6b, 0x69, 0x85, + 0xa8, 0x90, 0x7c, 0x7b, 0x9e, 0x87, 0x7b, 0x90, 0x98, 0x7a, 0xa4, 0x92, + 0xad, 0x97, 0xa0, 0x6d, 0xa6, 0x74, 0xb7, 0x7f, 0xb9, 0x94, 0x6c, 0x77, + 0x65, 0x6f, 0xfc, 0x7d, 0x68, 0x74, 0xa1, 0x6c, 0x71, 0x61, 0xc3, 0xb5, + 0x60, 0x86, 0x8b, 0x7d, 0x89, 0x8b, 0x93, 0xa4, 0x68, 0xa0, 0x8f, 0x73, + 0x96, 0x6e, 0x81, 0x99, 0x81, 0x9d, 0xae, 0x93, 0x6a, 0x8b, 0x9a, 0xcb, + 0x68, 0xaf, 0xca, 0x81, 0x73, 0x6e, 0x70, 0xd7, 0x49, 0xb9, 0xc5, 0x9d, + 0x87, 0x8d, 0x61, 0xa8, 0x5e, 0xa4, 0xb7, 0xab, 0x96, 0x84, 0x76, 0x98, + 0x84, 0x99, 0x8f, 0x70, 0x79, 0x94, 0xa5, 0x87, 0x6e, 0x73, 0x63, 0x7e, + 0x83, 0x8c, 0x88, 0x71, 0x7a, 0x81, 0x7d, 0x94, 0x92, 0x89, 0xab, 0x7a, + 0x96, 0x66, 0x7b, 0x8b, 0x8f, 0x8e, 0x94, 0x5b, 0xa0, 0x7f, 0x82, 0x84, + 0x84, 0x80, 0x7d, 0x81, 0x89, 0x7b, 0x97, 0x78, 0x83, 0x93, 0x4c, 0x95, + 0x7f, 0x93, 0x8e, 0x70, 0x89, 0x81, 0x69, 0x87, 0x76, 0x73, 0x9a, 0x74, + 0xa2, 0x88, 0x5e, 0xac, 0x74, 0x8e, 0x74, 0x8e, 0x94, 0x85, 0x7b, 0x7a, + 0x72, 0x82, 0x68, 0x77, 0x96, 0x8a, 0x7b, 0x6c, 0x88, 0x8b, 0x6b, 0x86, + 0xa4, 0x88, 0xac, 0xa1, 0x90, 
0x8e, 0x85, 0x6d, 0xb1, 0x69, 0xb1, 0xa2, + 0xbe, 0x9a, 0x7c, 0xb4, 0x63, 0x56, 0xf2, 0x90, 0x5e, 0x71, 0xa3, 0x6a, + 0x8b, 0x67, 0xbe, 0xa8, 0x6e, 0x8b, 0x90, 0x83, 0xa0, 0x78, 0x9f, 0xa5, + 0x65, 0xa3, 0x8b, 0x94, 0x84, 0x6c, 0xa5, 0x97, 0x7d, 0xa7, 0x9f, 0x9c, + 0x62, 0x7d, 0xb5, 0xb1, 0x58, 0x98, 0xba, 0x8d, 0x7f, 0x57, 0x86, 0xc5, + 0x39, 0xb3, 0xc9, 0xa9, 0x89, 0x8e, 0x55, 0xaf, 0x54, 0xb4, 0xb0, 0x8f, + 0x8b, 0x7c, 0x6e, 0x8e, 0x96, 0x90, 0x8a, 0x83, 0x84, 0x8c, 0x96, 0x7f, + 0x89, 0x67, 0x99, 0x60, 0x74, 0x8d, 0x9b, 0x82, 0x6f, 0x61, 0x84, 0x9a, + 0x7c, 0x85, 0x86, 0x7c, 0x9b, 0x5f, 0x81, 0x96, 0x90, 0x9b, 0xa0, 0x58, + 0xaf, 0x78, 0x81, 0x8f, 0x96, 0x81, 0x77, 0x7d, 0xa2, 0x85, 0x74, 0x84, + 0x99, 0x8d, 0x5f, 0x77, 0x8a, 0x8c, 0x85, 0x78, 0x8f, 0x80, 0x5c, 0x6f, + 0x77, 0x73, 0x80, 0x99, 0x83, 0x89, 0x6f, 0x8e, 0x85, 0x7e, 0x6c, 0x81, + 0x99, 0x89, 0x69, 0x70, 0x8c, 0x8f, 0x6b, 0x89, 0x80, 0x7a, 0x83, 0x7a, + 0x96, 0x99, 0x73, 0x76, 0x9c, 0x67, 0xab, 0xab, 0xbd, 0x8b, 0x85, 0x90, + 0xb0, 0x6b, 0xbd, 0x9c, 0xb9, 0xa0, 0x7c, 0x7d, 0x66, 0x78, 0xdb, 0x97, + 0x55, 0x67, 0x96, 0x69, 0x80, 0x49, 0xc1, 0xbb, 0x6c, 0x91, 0x8a, 0x92, + 0x9a, 0x98, 0xa5, 0x98, 0x51, 0xa6, 0x99, 0x8e, 0x73, 0x73, 0x9d, 0x9f, + 0x77, 0xa6, 0xa4, 0x92, 0x64, 0x75, 0xac, 0xb2, 0x5d, 0xa1, 0xab, 0xa4, + 0x5a, 0x5b, 0xb3, 0xb7, 0x2d, 0xca, 0xc8, 0x76, 0x94, 0x8e, 0x59, 0xb0, + 0x52, 0x9d, 0xbd, 0x89, 0x97, 0x84, 0x5d, 0x9a, 0x87, 0x9b, 0x94, 0x6c, + 0x7b, 0xaa, 0x8a, 0x8b, 0x79, 0x5d, 0x90, 0x5c, 0x8b, 0x7b, 0xbe, 0x68, + 0x84, 0x6f, 0x75, 0x72, 0x98, 0x82, 0x92, 0x7a, 0xa2, 0x6e, 0x7b, 0x7d, + 0x9c, 0x99, 0x97, 0x5d, 0x9b, 0x69, 0x80, 0xa3, 0x96, 0x8d, 0x7c, 0x82, + 0xa3, 0x76, 0x95, 0x67, 0x93, 0x8e, 0x62, 0x7b, 0x78, 0x96, 0x69, 0x67, + 0x84, 0x8f, 0x62, 0x80, 0x88, 0x7e, 0x6c, 0x94, 0xab, 0x8b, 0x82, 0x9e, + 0x7e, 0x8c, 0x70, 0x83, 0x9c, 0x9c, 0x80, 0x87, 0x8f, 0xa1, 0x7f, 0x81, + 0x95, 0x83, 0x6d, 0x7a, 0xa0, 0x77, 0x6d, 0x76, 0x91, 0x7e, 0xa3, 0x62, + 0xa0, 0x93, 0x7e, 0x97, 0xb6, 
0x6c, 0xad, 0x72, 0xb2, 0x95, 0x73, 0x83, + 0x62, 0x56, 0xe2, 0x99, 0x6e, 0x66, 0xb0, 0x6c, 0x75, 0x4e, 0xb2, 0xc7, + 0x51, 0x98, 0x90, 0x8c, 0x82, 0x63, 0xa8, 0x99, 0x54, 0xc1, 0x87, 0x80, + 0x79, 0x62, 0xad, 0x81, 0x76, 0x99, 0xa9, 0x9b, 0x4e, 0x8c, 0xaf, 0xb6, + 0x5d, 0x9b, 0xb4, 0x9f, 0x6d, 0x60, 0xa5, 0xb5, 0x3e, 0xb2, 0xc4, 0x96, + 0x86, 0x6d, 0x48, 0x99, 0x50, 0xc1, 0xa8, 0x93, 0x8a, 0x92, 0x7d, 0x8f, + 0x74, 0x87, 0x91, 0x71, 0x8c, 0x87, 0x90, 0x80, 0x80, 0x82, 0x7b, 0x85, + 0x81, 0x7f, 0xa7, 0x6a, 0x78, 0x4e, 0x90, 0x85, 0x9f, 0x93, 0x91, 0x91, + 0xa5, 0x6e, 0x9d, 0xa7, 0x9e, 0x7f, 0x9a, 0x66, 0xbe, 0x6f, 0x82, 0x81, + 0x85, 0x86, 0x89, 0x6c, 0x88, 0x92, 0x6d, 0x6a, 0x8c, 0x95, 0x68, 0x70, + 0x91, 0x9b, 0x76, 0x59, 0x87, 0x93, 0x6f, 0x79, 0x7a, 0x99, 0x7d, 0x76, + 0xa3, 0x9c, 0x69, 0x75, 0x8f, 0x8e, 0x7e, 0x7a, 0x80, 0x8b, 0x76, 0x82, + 0x70, 0x71, 0x77, 0x7a, 0x88, 0xa1, 0x79, 0x75, 0x9e, 0x7e, 0x6d, 0x6f, + 0xa5, 0x84, 0xb1, 0x77, 0xad, 0x94, 0x98, 0x90, 0xa7, 0x5c, 0xb6, 0x84, + 0x99, 0x91, 0x71, 0x7b, 0x6d, 0x54, 0xd2, 0x84, 0x5d, 0x75, 0xb4, 0x7e, + 0x7d, 0x53, 0xc5, 0x98, 0x70, 0xaa, 0x9e, 0x81, 0x7d, 0x68, 0xa7, 0x8d, + 0x63, 0xab, 0x9b, 0x96, 0x7e, 0x6b, 0xa3, 0x9e, 0x6d, 0x98, 0xaf, 0x9b, + 0x78, 0x74, 0xae, 0xc7, 0x70, 0x98, 0xd4, 0x9a, 0x6e, 0x75, 0xa2, 0xcd, + 0x42, 0xb0, 0xc9, 0x89, 0x88, 0x77, 0x6a, 0xa4, 0x66, 0xb5, 0xbc, 0x8a, + 0x96, 0x87, 0x5e, 0xa5, 0x87, 0x95, 0x91, 0x5d, 0x85, 0x91, 0xaa, 0x8f, + 0x99, 0x78, 0x79, 0x74, 0x7f, 0x81, 0xa1, 0x74, 0x77, 0x64, 0x6c, 0x94, + 0xa0, 0x8b, 0x9b, 0x8e, 0xac, 0x6a, 0x98, 0x9c, 0x7a, 0x9f, 0xab, 0x7e, + 0xa3, 0x8b, 0x68, 0x7f, 0x84, 0x9f, 0x93, 0x77, 0x90, 0x98, 0x8f, 0x87, + 0x81, 0x8e, 0x76, 0x95, 0x66, 0x78, 0x85, 0x79, 0x95, 0x89, 0x64, 0x8e, + 0x8a, 0x87, 0x6f, 0x65, 0xa4, 0x98, 0x7a, 0x83, 0x85, 0x7e, 0x6b, 0xaa, + 0x81, 0x94, 0x7c, 0x6e, 0x78, 0x85, 0x87, 0x6d, 0x7a, 0x92, 0x67, 0x7a, + 0x8d, 0x95, 0x77, 0x7f, 0x9f, 0x71, 0xb1, 0xa1, 0xb2, 0x91, 0x7f, 0xb0, + 0xac, 0x5c, 0xaf, 0x6a, 0xae, 
0x98, 0x63, 0x7e, 0x67, 0x6f, 0xc4, 0x8a, + 0x75, 0x61, 0xac, 0x73, 0x86, 0x54, 0xc3, 0xa8, 0x5d, 0xa9, 0xb4, 0x9b, + 0x80, 0x6d, 0xa1, 0x8d, 0x64, 0xaa, 0x86, 0x96, 0x86, 0x6c, 0x9b, 0x8b, + 0x73, 0x9f, 0x9a, 0x87, 0x64, 0x6c, 0xad, 0xa6, 0x64, 0x8a, 0xbe, 0x88, + 0x67, 0x67, 0xaf, 0xb0, 0x71, 0xae, 0xde, 0x95, 0x9f, 0x7c, 0x7d, 0xa1, + 0x79, 0xb8, 0xaa, 0x9c, 0x84, 0x91, 0x6b, 0xac, 0x74, 0xa1, 0xad, 0x74, + 0x88, 0x93, 0x94, 0x72, 0x97, 0x7a, 0x78, 0x86, 0x76, 0x93, 0xb1, 0x6f, + 0x91, 0x44, 0x96, 0x8e, 0x8e, 0xa5, 0x9a, 0x70, 0x99, 0x79, 0x84, 0x82, + 0x7f, 0x78, 0xac, 0x6f, 0x9c, 0x80, 0x7d, 0x87, 0x7f, 0x9d, 0x6a, 0x71, + 0x7c, 0x92, 0x78, 0x7a, 0x93, 0x90, 0x55, 0x83, 0x7a, 0x8a, 0x9a, 0x65, + 0x86, 0x9b, 0x7c, 0x6b, 0xa3, 0x85, 0x86, 0x71, 0xab, 0x9a, 0x86, 0x90, + 0x86, 0x88, 0x88, 0x88, 0x99, 0x98, 0x77, 0x86, 0x88, 0x90, 0x79, 0x7c, + 0x6e, 0x9f, 0x76, 0x70, 0x84, 0x67, 0x7e, 0x8b, 0xa5, 0x68, 0xa7, 0x9d, + 0xb5, 0x9b, 0x8b, 0x8a, 0xc0, 0x60, 0x9e, 0x83, 0xb0, 0xb7, 0x65, 0x7f, + 0x7a, 0x7e, 0xc3, 0x7b, 0x74, 0x8f, 0xa4, 0x68, 0x5f, 0x47, 0xbb, 0xa4, + 0x74, 0x95, 0xab, 0x80, 0x70, 0x5c, 0x9a, 0x8a, 0x7d, 0xa5, 0x90, 0x7d, + 0x86, 0x68, 0xb1, 0x73, 0x6d, 0xad, 0x93, 0x8d, 0x7b, 0x64, 0xbd, 0xae, + 0x7a, 0x98, 0xcb, 0x97, 0x83, 0x67, 0xab, 0xb0, 0x61, 0xa7, 0xcd, 0x7e, + 0x87, 0x78, 0x76, 0x95, 0x6a, 0xba, 0xa9, 0x84, 0x8f, 0x95, 0x7c, 0x8b, + 0x90, 0x89, 0x8b, 0x81, 0x87, 0x8b, 0x76, 0x73, 0x6f, 0x61, 0x94, 0x73, + 0x83, 0x97, 0xb3, 0x6b, 0x9c, 0x55, 0x7f, 0x96, 0x9a, 0x92, 0x85, 0x52, + 0xc6, 0x73, 0x88, 0x9c, 0x7c, 0x86, 0x98, 0x6d, 0x99, 0x87, 0x80, 0x7c, + 0x7d, 0x98, 0x74, 0x7c, 0x89, 0x8a, 0x7d, 0x7b, 0x83, 0x90, 0x7d, 0x81, + 0x7a, 0xa0, 0x86, 0x5f, 0x74, 0x8e, 0x68, 0x7b, 0x6c, 0x86, 0x90, 0x84, + 0x7e, 0xae, 0x73, 0x6f, 0x8d, 0x81, 0x7c, 0x93, 0xa0, 0xb3, 0x6b, 0x9a, + 0x88, 0xab, 0x8a, 0x94, 0x9c, 0x87, 0x9c, 0x75, 0x7d, 0x8f, 0x7c, 0x7f, + 0x9b, 0x69, 0xa8, 0x99, 0x9d, 0x89, 0x8f, 0x72, 0xba, 0x61, 0xac, 0x91, + 0xb5, 0xa7, 0x84, 0x99, 0x71, 
0x7e, 0xd0, 0x7c, 0x6d, 0x66, 0xb6, 0x72, + 0x79, 0x61, 0xb6, 0xab, 0x69, 0xa0, 0xaa, 0x7d, 0x74, 0x61, 0x95, 0xa5, + 0x71, 0xb0, 0x93, 0x95, 0x86, 0x7d, 0x9f, 0x7e, 0x6c, 0x97, 0x85, 0x87, + 0x72, 0x7b, 0xb4, 0xad, 0x84, 0x7b, 0xcd, 0xa9, 0x7e, 0x6d, 0xc8, 0xc7, + 0x7e, 0xb7, 0xcf, 0x98, 0x7b, 0x7c, 0x69, 0xaf, 0x64, 0xa6, 0xc1, 0x8e, + 0x8f, 0x9c, 0x7d, 0x93, 0x7a, 0x96, 0x8a, 0x65, 0x92, 0x95, 0x8d, 0x6f, + 0x9f, 0x7f, 0x65, 0x69, 0x7a, 0x92, 0x9f, 0x5c, 0x90, 0x4e, 0x69, 0x89, + 0x8f, 0x9c, 0xa8, 0x7a, 0xb6, 0x7d, 0x84, 0x97, 0x7f, 0x91, 0x8d, 0x71, + 0xae, 0x86, 0x80, 0x78, 0x81, 0x87, 0x6e, 0x88, 0x87, 0x7f, 0x8f, 0x9d, + 0x78, 0x91, 0x74, 0x91, 0x7f, 0x7a, 0x80, 0x63, 0x93, 0xa0, 0x7f, 0x6f, + 0xa3, 0x88, 0x76, 0x5c, 0x6e, 0xa1, 0x6e, 0x7f, 0x84, 0x8b, 0x87, 0x6d, + 0x87, 0x9f, 0x79, 0x7c, 0x83, 0x89, 0x7e, 0x86, 0xa0, 0x82, 0x80, 0x8e, + 0x8b, 0x6c, 0x6e, 0x69, 0x9f, 0x79, 0xaa, 0x6e, 0xa2, 0x8f, 0x9d, 0x87, + 0xb4, 0x5d, 0xba, 0x6c, 0xaf, 0xa0, 0x84, 0x87, 0x8c, 0x89, 0xcb, 0x6f, + 0x8e, 0x71, 0xae, 0x5d, 0x6c, 0x61, 0xb3, 0xaf, 0x7a, 0x94, 0xb1, 0x8a, + 0x80, 0x65, 0x8a, 0x9d, 0x61, 0xb6, 0x8b, 0x97, 0x8a, 0x73, 0xa8, 0x82, + 0x74, 0x8a, 0x9c, 0x73, 0x61, 0x69, 0xb8, 0x9f, 0x76, 0x90, 0xc5, 0xaa, + 0x6b, 0x5f, 0xb7, 0xce, 0x6d, 0xb7, 0xcc, 0x97, 0x7a, 0x81, 0x95, 0xbe, + 0x78, 0xb1, 0xb4, 0x97, 0x8e, 0x99, 0x70, 0xa2, 0x72, 0x8d, 0x8e, 0x7d, + 0x90, 0x9f, 0x7b, 0x63, 0x87, 0x89, 0x7a, 0x5f, 0x81, 0x97, 0x8d, 0x78, + 0x94, 0x64, 0x95, 0x9d, 0x90, 0x87, 0xb3, 0x6e, 0xc2, 0x80, 0x94, 0x86, + 0x87, 0x93, 0xb3, 0x57, 0xb8, 0x73, 0x8a, 0x81, 0x6f, 0x95, 0x89, 0x82, + 0x94, 0x7a, 0x8e, 0x97, 0x8a, 0x91, 0x7f, 0x77, 0x98, 0x72, 0x67, 0x5f, + 0x7b, 0x8d, 0x78, 0x74, 0x91, 0x82, 0x86, 0x5c, 0x88, 0xa3, 0x73, 0x6f, + 0x92, 0x78, 0x9c, 0x95, 0x99, 0x9d, 0x70, 0x89, 0x8f, 0xa7, 0x74, 0x89, + 0x77, 0x90, 0x72, 0x8d, 0x9c, 0x6f, 0x7a, 0x6c, 0x9f, 0x72, 0xad, 0x6c, + 0xa5, 0x7a, 0x9d, 0x78, 0xa4, 0x52, 0xbd, 0x94, 0xb5, 0x97, 0x75, 0x78, + 0x86, 0x72, 0xdf, 0x6f, 0x98, 
0x81, 0xab, 0x5d, 0x62, 0x65, 0x9d, 0xbc, + 0x68, 0x8a, 0xc1, 0x7e, 0x67, 0x7f, 0x88, 0x95, 0x7f, 0xbd, 0x9c, 0x77, + 0x7d, 0x7e, 0x96, 0x7c, 0x7f, 0xa1, 0xa4, 0x90, 0x7c, 0x74, 0xc0, 0xac, + 0x7d, 0xa1, 0xdb, 0x85, 0x85, 0x51, 0xbc, 0xb1, 0x6c, 0xcb, 0xd1, 0xa7, + 0x76, 0x70, 0x7d, 0xba, 0x88, 0xb6, 0xaf, 0xa2, 0x9d, 0x9b, 0x71, 0x96, + 0x80, 0x89, 0xa3, 0x86, 0x89, 0x8f, 0x76, 0x77, 0xa9, 0x82, 0x8f, 0x69, + 0x7f, 0x9d, 0xac, 0x80, 0x98, 0x6c, 0x70, 0x72, 0x81, 0x8b, 0xaf, 0x80, + 0xb1, 0x6f, 0x7c, 0x90, 0x91, 0x82, 0xa5, 0x67, 0x9c, 0x76, 0x8c, 0x6b, + 0x9c, 0x9b, 0x87, 0x8c, 0x8e, 0x8b, 0xb0, 0x9d, 0x89, 0x8f, 0x76, 0x87, + 0x9b, 0x90, 0x8e, 0x74, 0x73, 0x91, 0x85, 0x80, 0x81, 0x72, 0x99, 0x84, + 0x87, 0x95, 0x84, 0x8c, 0x8a, 0x6e, 0x8c, 0x82, 0xad, 0x9d, 0x80, 0x7f, + 0x96, 0x9c, 0x7f, 0x67, 0xb0, 0x98, 0x69, 0x84, 0x94, 0xa9, 0x7e, 0x83, + 0x9d, 0x62, 0x92, 0x6e, 0x95, 0x88, 0xa4, 0x90, 0x97, 0x4d, 0xae, 0x89, + 0xb6, 0xa1, 0x88, 0x9f, 0x7a, 0x70, 0xc2, 0x71, 0x7f, 0x83, 0x90, 0x83, + 0x5e, 0x50, 0xa9, 0x9f, 0x73, 0x8c, 0xb2, 0x80, 0x79, 0x65, 0x7c, 0x90, + 0x6d, 0x9a, 0x91, 0x8d, 0x6f, 0x65, 0x97, 0x87, 0x82, 0xa0, 0xa4, 0x8c, + 0x68, 0x76, 0xa8, 0xa2, 0x7f, 0xa4, 0xcd, 0x91, 0x70, 0x54, 0x95, 0xc6, + 0x6e, 0x9c, 0xe2, 0xa1, 0x86, 0x82, 0x73, 0xbc, 0x89, 0xaa, 0xb2, 0x7d, + 0x82, 0x84, 0x8b, 0x9e, 0x84, 0x94, 0xa0, 0x7a, 0x98, 0x9d, 0x99, 0x7b, + 0x7b, 0x89, 0x8f, 0x66, 0x89, 0x9b, 0xa7, 0x8b, 0x9b, 0x62, 0x9b, 0x78, + 0x8b, 0x95, 0xbd, 0x7a, 0x9e, 0x61, 0x80, 0x84, 0x89, 0x8e, 0xb4, 0x7b, + 0xb8, 0x70, 0x75, 0x8e, 0x7b, 0x9c, 0x9e, 0x9f, 0x89, 0x86, 0x9b, 0x7a, + 0x7b, 0x95, 0x83, 0x95, 0x80, 0x94, 0x85, 0x65, 0x8c, 0x81, 0x67, 0x77, + 0x94, 0x8a, 0x92, 0x74, 0x72, 0x90, 0x6b, 0x74, 0x7e, 0x75, 0x71, 0x84, + 0x9e, 0xa6, 0x64, 0x80, 0x8d, 0x7a, 0x8c, 0x82, 0x98, 0x96, 0x64, 0x7d, + 0x8b, 0x82, 0x6a, 0x7f, 0x97, 0x4e, 0x91, 0x74, 0x94, 0x99, 0x6d, 0x6a, + 0xb3, 0x5a, 0xb8, 0x64, 0xa3, 0x95, 0x5d, 0x95, 0x90, 0x87, 0xcc, 0x72, + 0x85, 0x85, 0x8f, 0x55, 0x6f, 
0x65, 0x84, 0xb6, 0x7b, 0x77, 0xce, 0x79, + 0x82, 0x59, 0x8a, 0xa2, 0x68, 0x9b, 0xa3, 0x81, 0x9c, 0x7a, 0x97, 0x87, + 0x6b, 0x8c, 0x9c, 0xaa, 0x5c, 0x69, 0xb8, 0xb7, 0x7c, 0xa0, 0xb5, 0x92, + 0x8d, 0x67, 0x96, 0xd2, 0x77, 0xa6, 0xd9, 0xad, 0xaa, 0x79, 0x90, 0xc9, + 0x81, 0xbf, 0xd0, 0x8d, 0x9d, 0x88, 0x9c, 0x91, 0x90, 0x94, 0x89, 0x8a, + 0x91, 0x9b, 0x89, 0x79, 0x92, 0x80, 0x8f, 0x7b, 0x7e, 0x8b, 0xb1, 0x85, + 0xa4, 0x5a, 0xb4, 0x7a, 0xa7, 0x8c, 0xa4, 0x75, 0xb9, 0x66, 0x93, 0x86, + 0x8a, 0x87, 0xad, 0x64, 0xa2, 0x7e, 0x99, 0x9f, 0x81, 0xa2, 0x9b, 0x88, + 0x9e, 0xa2, 0xb9, 0x8a, 0x78, 0x84, 0x91, 0x8e, 0x8b, 0x90, 0x83, 0x80, + 0x64, 0x93, 0x77, 0x89, 0x81, 0x86, 0x96, 0x7a, 0x81, 0xab, 0x6d, 0x73, + 0x7d, 0x7e, 0xaa, 0x85, 0x95, 0xac, 0x8b, 0x89, 0x8b, 0x77, 0xa3, 0x8b, + 0xa3, 0xa0, 0x87, 0x86, 0x7a, 0x74, 0x6f, 0x7c, 0x90, 0x58, 0xa2, 0x64, + 0x94, 0x8b, 0xa0, 0x88, 0xab, 0x53, 0xce, 0x67, 0xb7, 0x7f, 0x8d, 0x69, + 0x84, 0x74, 0xaf, 0x72, 0xab, 0x70, 0x8f, 0x6e, 0x5d, 0x61, 0x96, 0xa1, + 0x7b, 0x6f, 0xa2, 0x75, 0x8f, 0x5d, 0x93, 0x72, 0x82, 0x97, 0x76, 0x65, + 0x7e, 0x96, 0xb3, 0x8b, 0x8d, 0x89, 0x8f, 0x7b, 0x6f, 0x71, 0xa1, 0x9e, + 0x91, 0x7c, 0xc9, 0x9f, 0x7c, 0x71, 0xa1, 0xba, 0x77, 0xa5, 0xd4, 0xa6, + 0xa0, 0x82, 0x7b, 0x95, 0x9d, 0xb7, 0xaa, 0x8d, 0x71, 0x87, 0x94, 0x7e, + 0x88, 0x7f, 0x8b, 0x6e, 0x93, 0x9f, 0x82, 0x88, 0x94, 0x8a, 0x97, 0x7f, + 0x7d, 0x8c, 0xa0, 0x84, 0xb4, 0x7c, 0x8c, 0x7f, 0x71, 0x8c, 0x8e, 0x7f, + 0xc6, 0x64, 0x81, 0x8d, 0x89, 0x8d, 0xc4, 0x77, 0xaf, 0x75, 0x92, 0x7f, + 0x84, 0xa1, 0x99, 0x94, 0x9e, 0x82, 0x7a, 0x98, 0x7e, 0x8e, 0x93, 0x8c, + 0x6b, 0x93, 0x84, 0xaa, 0x7f, 0x8f, 0x6b, 0x94, 0xa3, 0x8a, 0x78, 0x82, + 0x60, 0x92, 0x8b, 0x8d, 0x75, 0x8c, 0x8e, 0x6e, 0x7e, 0x9d, 0x6d, 0x8e, + 0x79, 0x8d, 0x80, 0x89, 0xaa, 0x99, 0x7e, 0xa3, 0x83, 0x95, 0x83, 0x85, + 0x9c, 0x60, 0x99, 0x78, 0x93, 0x8b, 0x80, 0x82, 0x9d, 0x6b, 0xc2, 0x54, + 0xb9, 0x7a, 0x83, 0x98, 0x88, 0x65, 0xcb, 0x52, 0xa7, 0x8d, 0x7f, 0x81, + 0x6b, 0x6d, 0x9e, 0x92, 0x85, 
0x82, 0x9f, 0x67, 0x6f, 0x74, 0xaa, 0x75, + 0x99, 0x9f, 0x8a, 0x8b, 0x88, 0x82, 0xb8, 0x6b, 0x85, 0x99, 0x93, 0x90, + 0x8d, 0x7a, 0xaa, 0x9d, 0x86, 0x7f, 0xbd, 0x91, 0x67, 0x65, 0x8c, 0xb3, + 0x87, 0x94, 0xa3, 0x9a, 0x7e, 0x73, 0x83, 0xaa, 0x7a, 0xba, 0xaa, 0x9e, + 0x9e, 0x86, 0x9a, 0x63, 0x9c, 0x98, 0x5e, 0xa0, 0x9c, 0x9e, 0x8b, 0x85, + 0xa2, 0x74, 0x80, 0x8d, 0x7e, 0x89, 0xc0, 0x75, 0xa5, 0x3f, 0x97, 0xa2, + 0x8c, 0x8c, 0x9d, 0x88, 0xa4, 0x5e, 0x75, 0x5f, 0x87, 0x82, 0xbc, 0x72, + 0xa3, 0x77, 0x83, 0x79, 0x82, 0x95, 0x8d, 0x77, 0x73, 0x81, 0x9d, 0x9b, + 0x6c, 0x87, 0x93, 0x96, 0x83, 0x86, 0x8b, 0x89, 0x72, 0x7d, 0x96, 0x78, + 0x67, 0xa2, 0x8d, 0x81, 0x6a, 0x98, 0x75, 0x80, 0x8a, 0x80, 0x9e, 0x82, + 0x76, 0x9b, 0x6c, 0x94, 0x7a, 0x96, 0x74, 0x92, 0x78, 0x91, 0x7a, 0x7c, + 0x9a, 0x98, 0x70, 0x5d, 0x9c, 0x4b, 0x70, 0x7d, 0xa9, 0x9b, 0x70, 0x96, + 0xad, 0x59, 0xc4, 0x63, 0xbc, 0x8f, 0x5c, 0x86, 0x8e, 0x97, 0xa0, 0x7c, + 0xa6, 0x77, 0xaa, 0x93, 0x68, 0x66, 0x93, 0x91, 0x7b, 0x7e, 0xa2, 0x7a, + 0x98, 0x77, 0x97, 0x59, 0x84, 0x76, 0x9c, 0x7b, 0x8b, 0x76, 0x88, 0x7a, + 0x8c, 0x7b, 0xa4, 0xae, 0x6e, 0x7d, 0xb3, 0x99, 0x8d, 0x68, 0x9e, 0x7e, + 0x77, 0x59, 0x80, 0xbe, 0x80, 0x83, 0xd9, 0x9f, 0x7d, 0x60, 0x8b, 0x98, + 0x7f, 0x9e, 0xa3, 0x8d, 0x7d, 0x81, 0x9e, 0x78, 0x99, 0x94, 0x70, 0x80, + 0x9b, 0x89, 0x8c, 0x6d, 0x9c, 0x95, 0x76, 0x7c, 0x83, 0x87, 0x97, 0x93, + 0x89, 0x6d, 0x77, 0x7e, 0x7e, 0x87, 0x8e, 0x7e, 0x94, 0x61, 0x94, 0xa2, + 0x94, 0x91, 0xa1, 0x64, 0xc1, 0x78, 0x79, 0xaf, 0x67, 0x7a, 0x9b, 0xa1, + 0x95, 0x8e, 0x97, 0x84, 0x7b, 0x85, 0x80, 0xa1, 0x6f, 0x87, 0x79, 0x83, + 0x73, 0x9d, 0x81, 0x64, 0x7a, 0x7f, 0x8f, 0x91, 0x73, 0x97, 0x74, 0x8b, + 0x7e, 0x88, 0x7f, 0x7e, 0x6e, 0xa1, 0x85, 0x8f, 0x77, 0x93, 0x7a, 0x6f, + 0x7b, 0x91, 0x67, 0x73, 0x8b, 0x97, 0x6d, 0x87, 0x84, 0xf8, 0xff, 0xff, + 0x88, 0xf8, 0xff, 0xff, 0xe6, 0xf8, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x80, 0x02, 0x00, 0x00, 0x73, 0x84, 0xbb, 0xa4, 0xa5, 0x44, 0x5c, 0xb1, + 0x8e, 0x50, 0x82, 0x8b, 0x81, 
0x86, 0x48, 0x80, 0xa9, 0x61, 0xa3, 0xa8, + 0xca, 0x5a, 0x9d, 0x8a, 0x89, 0x7c, 0x65, 0x91, 0x5e, 0x70, 0x84, 0x71, + 0xbc, 0x36, 0x8e, 0x8b, 0xa6, 0x63, 0xb7, 0x75, 0x92, 0x59, 0x60, 0x7e, + 0x33, 0x8f, 0x90, 0x7a, 0xa9, 0x27, 0x72, 0x80, 0x62, 0x95, 0x93, 0x7b, + 0x60, 0x46, 0x40, 0x55, 0x01, 0x9e, 0x8a, 0x6b, 0x58, 0x8a, 0xa6, 0xb7, + 0x91, 0x39, 0x72, 0xb4, 0x6e, 0x67, 0x83, 0x91, 0x82, 0x7b, 0x64, 0x7a, + 0x87, 0x6e, 0xb0, 0xa0, 0xd3, 0x53, 0xb7, 0x93, 0x76, 0xa6, 0x68, 0x8a, + 0x74, 0x6a, 0x96, 0x6e, 0xb3, 0x53, 0xaa, 0x89, 0xf1, 0x76, 0xb8, 0x75, + 0x8b, 0x66, 0x5f, 0x6e, 0x52, 0x92, 0x6f, 0x82, 0xbe, 0x45, 0x8d, 0x69, + 0x98, 0x98, 0x80, 0x87, 0x73, 0x7d, 0x4d, 0x42, 0x1f, 0xa5, 0x6a, 0x73, + 0x47, 0x87, 0x8a, 0xd1, 0x75, 0x30, 0x91, 0xae, 0x60, 0x82, 0x7a, 0x94, + 0x75, 0x71, 0x6a, 0x7c, 0x74, 0x7a, 0xac, 0xa2, 0xb6, 0x51, 0xc6, 0x97, + 0x63, 0xa0, 0x67, 0x7f, 0x80, 0x69, 0x88, 0x6b, 0xa5, 0x5e, 0xc2, 0x72, + 0xf4, 0x6e, 0xaf, 0x76, 0x7f, 0x7c, 0x55, 0x68, 0x67, 0x97, 0x61, 0x7b, + 0xbe, 0x5e, 0xab, 0x58, 0xca, 0xa2, 0x77, 0x7a, 0x8f, 0x6e, 0x54, 0x33, + 0x4d, 0xa7, 0x5d, 0x66, 0x47, 0x92, 0x6f, 0xd6, 0x5c, 0x25, 0xa9, 0xbc, + 0x5c, 0xb8, 0x64, 0x9b, 0x58, 0x6e, 0x77, 0x76, 0x6a, 0x94, 0xb2, 0xac, + 0x9a, 0x51, 0xd0, 0x94, 0x62, 0xcc, 0x5a, 0x7f, 0x74, 0x6e, 0x7d, 0x71, + 0x9b, 0x69, 0xd3, 0x64, 0xef, 0x76, 0xaa, 0x75, 0x89, 0x84, 0x50, 0x76, + 0x72, 0x97, 0x5f, 0x77, 0xc5, 0x66, 0xce, 0x3a, 0xe5, 0xad, 0x5a, 0x81, + 0x9e, 0x8e, 0x60, 0x3d, 0x6d, 0xa9, 0x46, 0x6b, 0x44, 0x89, 0x4d, 0xd8, + 0x4c, 0x28, 0xb1, 0xb7, 0x60, 0xc7, 0x57, 0xb5, 0x50, 0x68, 0x88, 0x7c, + 0x60, 0x98, 0xac, 0x9a, 0x7f, 0x51, 0xce, 0x8a, 0x5e, 0xd8, 0x51, 0x7d, + 0x68, 0x6e, 0x7f, 0x6e, 0x90, 0x7b, 0xdf, 0x60, 0xda, 0x77, 0x91, 0x6f, + 0x85, 0xa0, 0x58, 0x73, 0x70, 0x93, 0x51, 0x7d, 0xb9, 0x70, 0xf5, 0x31, + 0xe9, 0xa3, 0x47, 0x76, 0xa7, 0x9b, 0x72, 0x3d, 0x90, 0xb2, 0x57, 0x64, + 0x5b, 0x6f, 0x2b, 0xcf, 0x52, 0x28, 0xc1, 0xa7, 0x6a, 0x78, 0x51, 0xad, + 0x49, 0x70, 0x90, 0x81, 0x5c, 
0x7e, 0x9e, 0x99, 0x77, 0x50, 0xc0, 0x94, + 0x63, 0xb7, 0x4d, 0x71, 0x58, 0x66, 0x76, 0x6d, 0x78, 0x6a, 0xe1, 0x40, + 0xc7, 0x73, 0x7f, 0x65, 0x7c, 0x7f, 0x4d, 0x80, 0x64, 0x95, 0x57, 0x81, + 0xb1, 0x5e, 0xff, 0x26, 0xd6, 0xa2, 0x3a, 0x73, 0xa7, 0x81, 0x76, 0x5d, + 0x92, 0xb1, 0x58, 0x48, 0x4e, 0x5e, 0x1a, 0xc8, 0x58, 0x2c, 0xb6, 0xa7, + 0x67, 0x89, 0x5e, 0xa0, 0x4f, 0x78, 0x93, 0x8b, 0x57, 0x7b, 0x95, 0x78, + 0x6e, 0x46, 0xb2, 0x98, 0x55, 0xd3, 0x5e, 0x66, 0x56, 0x68, 0x74, 0x7e, + 0x72, 0x74, 0xdd, 0x36, 0xa6, 0x64, 0x65, 0x6b, 0x81, 0x98, 0x56, 0x76, + 0x65, 0x93, 0x58, 0x7d, 0x9b, 0x82, 0xef, 0x44, 0xbf, 0xa4, 0x3d, 0x57, + 0xa0, 0xa7, 0x7a, 0x74, 0x9f, 0xa8, 0x70, 0x52, 0x55, 0x5f, 0x1a, 0x94, + 0x64, 0x37, 0xa7, 0xa6, 0x80, 0x7d, 0x6e, 0x99, 0x5d, 0x81, 0x8a, 0x99, + 0x5c, 0x76, 0x8f, 0x44, 0x68, 0x50, 0x94, 0x97, 0x63, 0xb6, 0x73, 0x56, + 0x5b, 0x70, 0x66, 0x8b, 0x72, 0x78, 0xcc, 0x31, 0x8b, 0x68, 0x4a, 0x74, + 0x7d, 0x99, 0x54, 0x91, 0x6a, 0x90, 0x5d, 0x80, 0x8c, 0x82, 0xcd, 0x4f, + 0xb0, 0x96, 0x63, 0x56, 0x97, 0xb3, 0x7e, 0x97, 0xa4, 0x9d, 0x7a, 0x5d, + 0x49, 0x36, 0x18, 0x64, 0x60, 0x43, 0x89, 0xa2, 0x6a, 0x49, 0x7f, 0x58, + 0x6a, 0x83, 0x77, 0x9d, 0x70, 0x3b, 0x83, 0x21, 0x59, 0x52, 0x6d, 0x95, + 0x48, 0xa8, 0x8a, 0x42, 0x50, 0x6d, 0x44, 0x95, 0x69, 0x50, 0xc1, 0x4b, + 0x7c, 0x59, 0x42, 0x78, 0x77, 0x7f, 0x5b, 0x98, 0x67, 0x89, 0x55, 0x8b, + 0x82, 0x47, 0xb7, 0x64, 0x9d, 0x83, 0x5c, 0x53, 0x89, 0x90, 0x79, 0xb2, + 0x90, 0x98, 0x85, 0x5a, 0x4d, 0x2b, 0x19, 0x1e, 0x52, 0x50, 0x57, 0x8b, + 0x73, 0x3a, 0x88, 0x1e, 0x65, 0x80, 0x4d, 0x9b, 0x6c, 0x3c, 0x86, 0x26, + 0x5b, 0x56, 0x36, 0x98, 0x49, 0x87, 0x9f, 0x2a, 0x40, 0x61, 0x27, 0x9d, + 0x63, 0x40, 0xa8, 0x46, 0x6b, 0x52, 0x52, 0x7f, 0x67, 0x6a, 0x58, 0xa1, + 0x5d, 0x6d, 0x5f, 0x9a, 0x72, 0x3a, 0x99, 0x63, 0x8c, 0x80, 0x68, 0x58, + 0x72, 0x6a, 0x7c, 0xbb, 0x7e, 0x78, 0x94, 0x60, 0x72, 0xfb, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x96, 0xfe, 0xff, 0xff, + 0x8f, 0x00, 0x00, 0x00, 0x8f, 
0xfc, 0xff, 0xff, 0xb4, 0xfe, 0xff, 0xff, + 0xc1, 0xfd, 0xff, 0xff, 0x59, 0xff, 0xff, 0xff, 0xbc, 0xfe, 0xff, 0xff, + 0x09, 0xff, 0xff, 0xff, 0x9e, 0xfb, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0xe9, 0x03, 0x00, 0x00, 0x2b, 0xfd, 0xff, 0xff, + 0x3b, 0xfd, 0xff, 0xff, 0x91, 0x01, 0x00, 0x00, 0x60, 0xfb, 0xff, 0xff, + 0x04, 0xfd, 0xff, 0xff, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0xf0, 0x03, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x18, 0x03, 0x00, 0x00, 0x78, 0x03, 0x00, 0x00, + 0x88, 0x01, 0x00, 0x00, 0xf8, 0x01, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x64, 0x02, 0x00, 0x00, 0xe0, 0x00, 0x00, 0x00, + 0xb2, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, 0x24, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x5f, 0x73, + 0x6f, 0x66, 0x74, 0x6d, 0x61, 0x78, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x84, 0xfd, 0xff, 0xff, + 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3b, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x3f, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x22, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x1c, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x61, 0x64, 0x64, 0x5f, + 0x31, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xec, 0xfd, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x7d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 
0x00, 0x00, 0x00, 0x28, 0x17, 0xb1, 0x3d, + 0x01, 0x00, 0x00, 0x00, 0x84, 0xdb, 0x33, 0x41, 0x01, 0x00, 0x00, 0x00, + 0x9d, 0xf0, 0x2c, 0xc1, 0x8e, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x48, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, + 0x68, 0x74, 0x73, 0x5f, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x5f, 0x31, 0x2f, + 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x57, 0x69, 0x74, + 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61, 0x78, 0x56, 0x61, 0x72, 0x73, 0x2f, + 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa0, 0x0f, 0x00, 0x00, + 0x84, 0xfe, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x89, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0xac, 0x5f, 0xf6, 0x39, 0x01, 0x00, 0x00, 0x00, 0x1d, 0xaf, 0x62, 0x3d, + 0x01, 0x00, 0x00, 0x00, 0x5e, 0x1b, 0x83, 0xbd, 0x22, 0xfe, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x03, 0x1c, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x52, 0x65, 0x6c, 0x75, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0xf4, 0xfe, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x47, 0x0f, 0x72, 0x3d, + 0x01, 0x00, 0x00, 0x00, 0x38, 0x1d, 0x71, 0x41, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x96, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x20, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x52, 0x65, 0x73, 0x68, + 0x61, 0x70, 0x65, 0x5f, 0x31, 
0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x6c, 0xff, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xc6, 0xd0, 0xd0, 0x3d, + 0x01, 0x00, 0x00, 0x00, 0xf5, 0xff, 0xcf, 0x41, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x0e, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x3c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, + 0x68, 0x74, 0x73, 0x5f, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x2f, 0x46, 0x61, + 0x6b, 0x65, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x57, 0x69, 0x74, 0x68, 0x4d, + 0x69, 0x6e, 0x4d, 0x61, 0x78, 0x56, 0x61, 0x72, 0x73, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, + 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x7f, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xf7, 0x5e, 0x6c, 0x3a, + 0x01, 0x00, 0x00, 0x00, 0x30, 0x42, 0xec, 0x3d, 0x01, 0x00, 0x00, 0x00, + 0x42, 0xca, 0xe8, 0xbd, 0xaa, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x02, + 0x20, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x43, 0x6f, 0x6e, 0x76, + 0x32, 0x44, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x9c, 0xff, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0xec, 0xcd, 0xc0, 0x38, 0x00, 
0x00, 0x0e, 0x00, 0x18, 0x00, 0x08, 0x00, + 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, 0x20, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x25, 0xf5, 0xe8, 0x37, 0x03, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, + 0x5c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc0, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x09, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, + 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, + 0x14, 0x00, 0x1c, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x07, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x08, 0x01, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x18, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, 0x1c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x30, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x02, 
0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xff, + 0x00, 0x19, 0x06, 0x00, 0x06, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x00, 0x09, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x04}; +const int g_tiny_conv_micro_features_model_data_len = 18208; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h new file mode 100644 index 0000000000000000000000000000000000000000..22c0a970b774299aea629ce034b9dd2e4c04e1ca --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h @@ -0,0 +1,27 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This is a standard TensorFlow Lite model file that has been converted into a +// C data array, so it can be easily compiled into a binary for devices that +// don't have a file system. It was created using the command: +// xxd -i tiny_conv.tflite > tiny_conv_micro_features_model_data.cc + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_TINY_CONV_MICRO_FEATURES_MODEL_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_TINY_CONV_MICRO_FEATURES_MODEL_DATA_H_ + +extern const unsigned char g_tiny_conv_micro_features_model_data[]; +extern const int g_tiny_conv_micro_features_model_data_len; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_TINY_CONV_MICRO_FEATURES_MODEL_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.cc new file mode 100644 index 0000000000000000000000000000000000000000..51b7d8b35bcbce892fbd891415e3f09725c62faa --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.h" + +#include <string.h> + +int WindowProcessSamples(struct WindowState* state, const int16_t* samples, + size_t num_samples, size_t* num_samples_read) { + const int size = state->size; + + // Copy samples from the samples buffer over to our local input. + size_t max_samples_to_copy = state->size - state->input_used; + if (max_samples_to_copy > num_samples) { + max_samples_to_copy = num_samples; + } + memcpy(state->input + state->input_used, samples, + max_samples_to_copy * sizeof(*samples)); + *num_samples_read = max_samples_to_copy; + state->input_used += max_samples_to_copy; + + if (state->input_used < state->size) { + // We don't have enough samples to compute a window. + return 0; + } + + // Apply the window to the input. + const int16_t* coefficients = state->coefficients; + const int16_t* input = state->input; + int16_t* output = state->output; + int i; + int16_t max_abs_output_value = 0; + for (i = 0; i < size; ++i) { + int16_t new_value = ((static_cast<int32_t>(*input++)) * *coefficients++) >> + kFrontendWindowBits; + *output++ = new_value; + if (new_value < 0) { + new_value = -new_value; + } + if (new_value > max_abs_output_value) { + max_abs_output_value = new_value; + } + } + // Shuffle the input down by the step size, and update how much we have used. + memmove(state->input, state->input + state->step, + sizeof(*state->input) * (state->size - state->step)); + state->input_used -= state->step; + state->max_abs_output_value = max_abs_output_value; + + // Indicate that the output buffer is valid for the next stage. 
+ return 1; +} + +void WindowReset(struct WindowState* state) { + memset(state->input, 0, state->size * sizeof(*state->input)); + memset(state->output, 0, state->size * sizeof(*state->output)); + state->input_used = 0; + state->max_abs_output_value = 0; +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.h new file mode 100644 index 0000000000000000000000000000000000000000..b32c059d81a8efe68c8a87a250fd733e1849479b --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_WINDOW_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_WINDOW_H_ + +#include +#include + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" + +#define kFrontendWindowBits 12 + +struct WindowState { + size_t size; + int16_t coefficients[kMaxAudioSampleSize]; + size_t step; + + int16_t input[kMaxAudioSampleSize]; + size_t input_used; + int16_t output[kMaxAudioSampleSize]; + int16_t max_abs_output_value; +}; + +// Applies a window to the samples coming in, stepping forward at the given +// rate. +int WindowProcessSamples(struct WindowState* state, const int16_t* samples, + size_t num_samples, size_t* num_samples_read); + +void WindowReset(struct WindowState* state); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_WINDOW_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..310f84fc60b32e37f7e7d9d79bc2425ce7cddf8a --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.h" + +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +namespace { + +const int kSampleRate = 1000; +const int kWindowSamples = 25; +const int kStepSamples = 10; +const int16_t kFakeAudioData[] = { + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768, + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768, + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768}; + +// Test window function behaviors using default config values. +class WindowTestConfig { + public: + WindowTestConfig() { + config_.size_ms = 25; + config_.step_size_ms = 10; + } + + struct WindowConfig config_; +}; + +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(WindowState_CheckCoefficients) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + WindowTestConfig config; + struct WindowState state; + TF_LITE_MICRO_EXPECT(WindowPopulateState(error_reporter, &config.config_, + &state, kSampleRate)); + + const int16_t expected[] = {16, 144, 391, 743, 1176, 1664, 2177, + 2681, 3145, 3541, 3843, 4032, 4096, 4032, + 3843, 3541, 3145, 2681, 2177, 1664, 1176, + 743, 391, 144, 16}; + TF_LITE_MICRO_EXPECT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < state.size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.coefficients[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(WindowState_CheckResidualInput) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + WindowTestConfig config; + struct WindowState state; + 
TF_LITE_MICRO_EXPECT(WindowPopulateState(error_reporter, &config.config_, + &state, kSampleRate)); + size_t num_samples_read; + + TF_LITE_MICRO_EXPECT(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + + int i; + for (i = kStepSamples; i < kWindowSamples; ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.input[i - kStepSamples], kFakeAudioData[i]); + } +} + +TF_LITE_MICRO_TEST(WindowState_CheckOutputValues) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + WindowTestConfig config; + struct WindowState state; + TF_LITE_MICRO_EXPECT(WindowPopulateState(error_reporter, &config.config_, + &state, kSampleRate)); + size_t num_samples_read; + + TF_LITE_MICRO_EXPECT(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + + const int16_t expected[] = { + 0, 1151, 0, -5944, 0, 13311, 0, -21448, 0, 28327, 0, -32256, 0, 32255, + 0, -28328, 0, 21447, 0, -13312, 0, 5943, 0, -1152, 0}; + TF_LITE_MICRO_EXPECT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < state.size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.output[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(WindowState_CheckMaxAbsValue) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + WindowTestConfig config; + struct WindowState state; + TF_LITE_MICRO_EXPECT(WindowPopulateState(error_reporter, &config.config_, + &state, kSampleRate)); + size_t num_samples_read; + + TF_LITE_MICRO_EXPECT(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + + TF_LITE_MICRO_EXPECT_EQ(state.max_abs_output_value, 32256); +} + +TF_LITE_MICRO_TEST(WindowState_CheckConsecutiveWindow) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = 
&micro_error_reporter; + + WindowTestConfig config; + struct WindowState state; + TF_LITE_MICRO_EXPECT(WindowPopulateState(error_reporter, &config.config_, + &state, kSampleRate)); + size_t num_samples_read; + + TF_LITE_MICRO_EXPECT(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + TF_LITE_MICRO_EXPECT(WindowProcessSamples( + &state, kFakeAudioData + kWindowSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, + &num_samples_read)); + + const int16_t expected[] = { + 0, -1152, 0, 5943, 0, -13312, 0, 21447, 0, -28328, 0, 32255, 0, -32256, + 0, 28327, 0, -21448, 0, 13311, 0, -5944, 0, 1151, 0}; + TF_LITE_MICRO_EXPECT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); + int i; + for (i = 0; i < state.size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(state.output[i], expected[i]); + } +} + +TF_LITE_MICRO_TEST(WindowState_CheckNotEnoughSamples) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + WindowTestConfig config; + struct WindowState state; + TF_LITE_MICRO_EXPECT(WindowPopulateState(error_reporter, &config.config_, + &state, kSampleRate)); + size_t num_samples_read; + + TF_LITE_MICRO_EXPECT(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + TF_LITE_MICRO_EXPECT(WindowProcessSamples( + &state, kFakeAudioData + kWindowSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, + &num_samples_read)); + TF_LITE_MICRO_EXPECT_EQ( + false, WindowProcessSamples( + &state, kFakeAudioData + kWindowSamples + kStepSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - + kWindowSamples - kStepSamples, + &num_samples_read)); + + TF_LITE_MICRO_EXPECT_EQ( + state.input_used, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - 2 * kStepSamples); +} + +TF_LITE_MICRO_TESTS_END diff --git 
a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..618973b39b2ebd2088b4c3756ea6ca1c1f7e8181 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.h" + +#include +#include +#include + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/static_alloc.h" + +// Needed because some platforms don't have M_PI defined. +#define WINDOW_PI (3.14159265358979323846f) + +void WindowFillConfigWithDefaults(struct WindowConfig* config) { + config->size_ms = 25; + config->step_size_ms = 10; +} + +int WindowPopulateState(tflite::ErrorReporter* error_reporter, + const struct WindowConfig* config, + struct WindowState* state, int sample_rate) { + state->size = config->size_ms * sample_rate / 1000; + state->step = config->step_size_ms * sample_rate / 1000; + + STATIC_ALLOC_ENSURE_ARRAY_SIZE(state->coefficients, + (state->size * sizeof(*state->coefficients))); + + // Populate the window values. 
+ const float arg = WINDOW_PI * 2.0 / (static_cast(state->size)); + int i; + for (i = 0; i < state->size; ++i) { + float float_value = 0.5 - (0.5 * cos(arg * (i + 0.5))); + // Scale it to fixed point and round it. + state->coefficients[i] = + floor(float_value * (1 << kFrontendWindowBits) + 0.5); + } + + state->input_used = 0; + STATIC_ALLOC_ENSURE_ARRAY_SIZE(state->input, + (state->size * sizeof(*state->input))); + + STATIC_ALLOC_ENSURE_ARRAY_SIZE(state->output, + (state->size * sizeof(*state->output))); + return 1; +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.h new file mode 100644 index 0000000000000000000000000000000000000000..d0c61c29dc9cd2a91f37ea89ace5e031235dd337 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window_util.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_WINDOW_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_WINDOW_UTIL_H_ + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/window.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" + +struct WindowConfig { + // length of window frame in milliseconds + size_t size_ms; + // length of step for next frame in milliseconds + size_t step_size_ms; +}; + +// Populates the WindowConfig with "sane" default values. +void WindowFillConfigWithDefaults(struct WindowConfig* config); + +// Allocates any buffers. +int WindowPopulateState(tflite::ErrorReporter* error_reporter, + const struct WindowConfig* config, + struct WindowState* state, int sample_rate); + +// Frees any allocated buffers. +void WindowFreeStateContents(struct WindowState* state); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_WINDOW_UTIL_H_ diff --git a/tensorflow/core/kernels/bitcast_op.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.cc similarity index 53% rename from tensorflow/core/kernels/bitcast_op.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.cc index 1f3659f303338efc69da56da0a67144e9400844b..48535d12d5db850cf0a497645f9e77d98fbcb8a1 100644 --- a/tensorflow/core/kernels/bitcast_op.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -13,17 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// See docs in ../ops/array_ops.cc. +// See the header for documentation on the meaning of this data. -#ifndef TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ -#define TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h" -#include // for memcpy - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" - -#endif // TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ +const uint8_t g_yes_feature_data_slice[g_yes_feature_data_slice_size] = { + 214, 215, 236, 202, 235, 203, 225, 191, 203, 188, 199, 194, 212, 127, + 51, 0, 174, 188, 219, 196, 228, 221, 240, 207, 235, 220, 241, 219, + 237, 207, 212, 142, 95, 0, 139, 78, 162, 177, 197, 183, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h new file mode 100644 index 0000000000000000000000000000000000000000..e73a13153b65be78a2a57edce0d09f48a8cb444f --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h @@ -0,0 +1,29 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This data was extracted from the larger feature data held in +// yes_micro_features_data.cc and consists of the 26th spectrogram slice of 40 +// values. This is the expected result of running the sample data in +// yes_30ms_sample_data.cc through the preprocessing pipeline. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_YES_FEATURE_DATA_SLICE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_YES_FEATURE_DATA_SLICE_H_ + +#include + +constexpr int g_yes_feature_data_slice_size = 40; +extern const uint8_t g_yes_feature_data_slice[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_YES_FEATURE_DATA_SLICE_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c2ee0995c00ee0da1337c86cf9aa18ba726bfe1 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc @@ -0,0 +1,165 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h" + +/* File automatically created by + * tensorflow/examples/speech_commands/wav_to_features.py \ + * --sample_rate=16000 \ + * --clip_duration_ms=1000 \ + * --window_size_ms=30 \ + * --window_stride_ms=20 \ + * --feature_bin_count=40 \ + * --quantize=1 \ + * --preprocess="micro" \ + * --input_wav="speech_commands_test_set_v0.02/yes/f2e59fea_nohash_1.wav" \ + * --output_c_file="yes_micro_features_data.cc" \ + */ + +const int g_yes_micro_f2e59fea_nohash_1_width = 40; +const int g_yes_micro_f2e59fea_nohash_1_height = 49; +const unsigned char g_yes_micro_f2e59fea_nohash_1_data[] = { + 244, 226, 245, 223, 234, 213, 228, 208, 194, 110, 95, 116, 102, 0, 137, + 161, 183, 173, 137, 116, 133, 157, 151, 156, 128, 110, 128, 0, 68, 78, + 78, 90, 68, 68, 78, 102, 95, 78, 95, 78, 210, 188, 209, 183, 204, + 188, 201, 191, 166, 119, 90, 107, 110, 107, 175, 157, 179, 168, 182, 145, + 152, 164, 171, 165, 136, 143, 122, 68, 0, 78, 90, 90, 110, 90, 102, + 99, 90, 68, 78, 68, 223, 186, 179, 123, 182, 110, 196, 171, 159, 110, + 102, 95, 90, 99, 160, 134, 125, 136, 153, 152, 164, 134, 164, 151, 141, + 136, 99, 90, 90, 90, 78, 78, 102, 119, 102, 90, 110, 90, 68, 51, + 177, 175, 211, 172, 183, 0, 95, 68, 129, 102, 68, 85, 114, 105, 110, + 85, 102, 95, 140, 51, 85, 51, 95, 90, 143, 116, 90, 78, 78, 51, + 107, 85, 68, 0, 68, 51, 90, 51, 68, 0, 164, 117, 193, 120, 156, + 0, 138, 51, 90, 0, 51, 0, 51, 85, 0, 0, 
51, 0, 0, 0, + 0, 0, 114, 0, 85, 78, 90, 51, 0, 0, 51, 85, 99, 85, 107, + 68, 90, 85, 78, 0, 51, 0, 110, 0, 68, 0, 0, 0, 51, 0, + 51, 0, 0, 0, 68, 90, 107, 0, 68, 0, 0, 0, 68, 0, 51, + 68, 0, 78, 68, 0, 51, 0, 78, 68, 90, 68, 78, 51, 51, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 90, 0, 0, 0, 0, + 0, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 68, + 0, 0, 78, 0, 78, 0, 78, 0, 51, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 51, 0, 51, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 0, 51, + 0, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, + 0, 0, 0, 0, 51, 78, 0, 0, 51, 51, 0, 0, 0, 78, 0, + 213, 170, 192, 180, 196, 188, 173, 131, 173, 116, 137, 105, 159, 127, 0, + 0, 0, 0, 127, 164, 165, 161, 170, 164, 185, 197, 195, 167, 134, 138, + 159, 134, 136, 105, 51, 0, 99, 0, 51, 0, 228, 215, 229, 218, 237, + 215, 228, 210, 237, 222, 239, 211, 208, 211, 234, 218, 220, 209, 225, 219, + 235, 222, 245, 225, 245, 224, 243, 223, 241, 218, 237, 224, 234, 213, 221, + 193, 197, 164, 157, 128, 227, 188, 232, 196, 220, 220, 240, 219, 234, 213, + 234, 211, 231, 218, 233, 213, 239, 215, 228, 207, 229, 206, 224, 208, 226, + 207, 232, 210, 225, 208, 230, 199, 227, 206, 210, 205, 218, 174, 178, 141, + 235, 208, 220, 206, 225, 203, 233, 203, 225, 167, 205, 199, 208, 190, 221, + 204, 223, 207, 225, 188, 225, 197, 215, 188, 199, 183, 225, 195, 224, 200, + 216, 178, 208, 188, 215, 202, 214, 183, 176, 140, 198, 150, 211, 194, 203, + 120, 175, 188, 204, 189, 219, 192, 223, 202, 216, 186, 203, 185, 210, 182, + 214, 183, 204, 170, 204, 125, 184, 187, 206, 185, 198, 182, 210, 161, 202, + 198, 218, 173, 145, 120, 188, 183, 205, 168, 200, 170, 210, 177, 187, 190, + 209, 193, 193, 166, 210, 162, 175, 119, 174, 147, 182, 161, 181, 134, 176, + 143, 187, 165, 186, 149, 185, 141, 192, 181, 202, 123, 170, 143, 144, 78, + 149, 0, 208, 182, 170, 78, 170, 0, 117, 51, 156, 99, 195, 170, 200, + 130, 152, 68, 175, 141, 173, 134, 194, 132, 189, 164, 198, 134, 173, 117, + 
171, 149, 183, 181, 185, 99, 153, 117, 125, 0, 166, 0, 173, 117, 144, + 0, 117, 102, 188, 120, 193, 166, 197, 68, 163, 119, 169, 99, 134, 0, + 162, 0, 164, 68, 171, 116, 126, 0, 120, 68, 68, 0, 105, 0, 159, + 95, 150, 51, 90, 85, 0, 0, 131, 0, 105, 0, 145, 51, 170, 51, + 120, 0, 107, 0, 145, 85, 160, 0, 85, 0, 0, 51, 149, 0, 78, + 0, 0, 0, 0, 0, 0, 0, 90, 0, 112, 0, 78, 102, 122, 0, + 0, 0, 0, 0, 105, 0, 0, 0, 0, 0, 0, 0, 0, 0, 112, + 0, 164, 120, 143, 0, 0, 0, 0, 0, 51, 0, 90, 0, 78, 0, + 0, 0, 0, 0, 110, 0, 139, 0, 112, 51, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 102, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 107, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 78, 0, 51, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 127, 110, 133, 0, 167, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 132, 0, 190, + 194, 202, 0, 197, 187, 161, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 214, 213, 223, 203, 218, 189, 200, 122, 78, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 191, 210, 231, 197, 226, 217, 238, 216, 236, 207, + 199, 0, 0, 0, 0, 0, 107, 122, 155, 160, 214, 215, 236, 202, 235, + 203, 225, 191, 203, 188, 199, 194, 212, 127, 51, 0, 174, 188, 219, 196, + 228, 221, 240, 207, 235, 220, 241, 219, 237, 207, 212, 142, 95, 0, 139, + 78, 162, 177, 197, 183, 211, 199, 235, 208, 238, 215, 227, 207, 211, 201, + 224, 213, 226, 192, 213, 170, 223, 205, 234, 221, 245, 225, 242, 220, 245, + 221, 239, 221, 238, 213, 226, 180, 159, 112, 176, 159, 208, 202, 213, 191, + 205, 191, 225, 197, 238, 219, 224, 201, 227, 200, 221, 201, 225, 203, 212, + 195, 229, 210, 228, 210, 239, 216, 226, 212, 233, 205, 225, 200, 229, 207, + 222, 151, 147, 119, 179, 185, 230, 218, 223, 192, 202, 136, 205, 177, 223, + 
204, 228, 215, 232, 209, 221, 189, 221, 205, 209, 200, 226, 209, 229, 205, + 235, 192, 209, 198, 228, 190, 206, 185, 207, 187, 214, 175, 177, 184, 220, + 195, 214, 207, 230, 184, 205, 159, 208, 184, 189, 169, 224, 213, 219, 199, + 229, 203, 216, 205, 222, 204, 224, 206, 231, 208, 231, 176, 197, 184, 216, + 193, 211, 139, 212, 195, 231, 164, 166, 195, 217, 182, 208, 190, 217, 179, + 205, 68, 182, 119, 195, 168, 182, 136, 204, 179, 193, 158, 182, 140, 188, + 154, 197, 169, 190, 99, 184, 0, 125, 0, 131, 0, 99, 68, 179, 85, + 190, 184, 213, 203, 223, 202, 212, 190, 209, 138, 178, 0, 159, 51, 128, + 51, 105, 0, 139, 51, 179, 125, 185, 114, 171, 128, 175, 132, 181, 174, + 155, 0, 0, 0, 90, 0, 125, 0, 176, 188, 227, 217, 244, 215, 234, + 221, 239, 192, 224, 210, 0, 0, 134, 0, 51, 0, 105, 0, 105, 0, + 143, 90, 192, 119, 175, 147, 141, 51, 184, 110, 85, 0, 0, 0, 0, + 0, 0, 0, 151, 139, 201, 203, 232, 203, 226, 208, 236, 206, 230, 212, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 169, 0, 119, + 0, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 68, 0, 0, 133, + 200, 180, 220, 197, 228, 201, 221, 184, 213, 193, 110, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 78, 0, 164, 0, 0, 0, 0, 0, 107, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 150, 164, 202, 182, 224, + 197, 211, 179, 212, 193, 134, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 85, 0, 150, 0, 85, 0, 95, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 102, 90, 193, 160, 203, 164, 200, 178, 205, 174, + 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 120, 114, 123, 0, 114, + 0, 145, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 102, 68, 199, 170, 195, 180, 208, 176, 200, 164, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 110, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 142, 102, 172, 110, 186, + 167, 185, 147, 189, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 177, 0, 158, 136, 197, 155, 189, 166, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 85, 0, 155, 90, 175, 117, 175, 138, 202, 165, 0, 0, 0, 0, 0, + 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 0, 139, + 0, 120, 68, 51, 123, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 119, 0, 78, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h new file mode 100644 index 0000000000000000000000000000000000000000..d19bf8f067d7329dcda0b866d0d323b92f175e61 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_YES_MICRO_FEATURES_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_YES_MICRO_FEATURES_DATA_H_ + +extern const int g_yes_micro_f2e59fea_nohash_1_width; +extern const int g_yes_micro_f2e59fea_nohash_1_height; +extern const unsigned char g_yes_micro_f2e59fea_nohash_1_data[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_YES_MICRO_FEATURES_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc index 4e54ff670eb9badd648aee99cf154c0d3b988bff..6f0c2581771e87e69481726adaea4fab3108640b 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h" #include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" #include "tensorflow/lite/experimental/micro/micro_error_reporter.h" #include "tensorflow/lite/experimental/micro/micro_interpreter.h" @@ -32,7 +32,8 @@ TF_LITE_MICRO_TEST(TestInvoke) { // Map the model into a usable data structure. This doesn't involve any // copying or parsing, it's a very lightweight operation. - const tflite::Model* model = ::tflite::GetModel(g_tiny_conv_model_data); + const tflite::Model* model = + ::tflite::GetModel(g_tiny_conv_micro_features_model_data); if (model->version() != TFLITE_SCHEMA_VERSION) { error_reporter->Report( "Model provided is schema version %d not equal " @@ -61,12 +62,12 @@ TF_LITE_MICRO_TEST(TestInvoke) { TF_LITE_MICRO_EXPECT_EQ(4, input->dims->size); TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]); TF_LITE_MICRO_EXPECT_EQ(49, input->dims->data[1]); - TF_LITE_MICRO_EXPECT_EQ(43, input->dims->data[2]); + TF_LITE_MICRO_EXPECT_EQ(40, input->dims->data[2]); TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, input->type); // Copy a spectrogram created from a .wav audio file of someone saying "Yes", // into the memory area used for the input. 
- const uint8_t* yes_features_data = g_yes_f2e59fea_nohash_1_data; + const uint8_t* yes_features_data = g_yes_micro_f2e59fea_nohash_1_data; for (int i = 0; i < input->bytes; ++i) { input->data.uint8[i] = yes_features_data[i]; } @@ -102,7 +103,7 @@ TF_LITE_MICRO_TEST(TestInvoke) { TF_LITE_MICRO_EXPECT_GT(yes_score, no_score); // Now test with a different input, from a recording of "No". - const uint8_t* no_features_data = g_no_f9643d42_nohash_4_data; + const uint8_t* no_features_data = g_no_micro_f9643d42_nohash_4_data; for (int i = 0; i < input->bytes; ++i) { input->data.uint8[i] = no_features_data[i]; } diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..85113a90dcf610a38f21e17f0b303befd6c1e071 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.cc @@ -0,0 +1,1477 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See the header for documentation on the meaning of this data. 
+ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.h" + +const int g_no_1000ms_sample_data_size = 16000; +const int16_t g_no_1000ms_sample_data[16000] = { + 5, 1, -10, -16, -14, -10, -4, -5, -10, -15, -13, + -17, -22, -21, -23, -25, -22, -26, -28, -31, -28, -25, + -20, -24, -21, -13, -7, -1, -1, 3, 3, 4, -4, + -6, -8, -10, -13, -4, -2, 5, 8, 11, 26, 28, + 34, 32, 34, 30, 21, 18, 15, 13, 8, 5, 14, + 13, 7, 8, 4, -5, -7, -4, -9, -13, -17, -21, + -16, -14, -12, -12, -14, -11, -9, -2, 5, -1, 2, + 0, 2, 1, -3, -13, -14, -16, -11, -10, -9, -13, + -17, -19, -25, -21, -21, -20, -13, -5, -3, 0, 3, + 6, 5, 1, 0, -1, -7, -10, -11, -9, -6, -7, + -11, -10, -5, -14, -20, -23, -22, -22, -19, -15, -12, + -6, -5, 3, 13, 16, 17, 25, 26, 28, 34, 34, + 33, 34, 30, 21, 22, 18, 13, 20, 22, 24, 27, + 26, 23, 21, 18, 9, 5, -2, -7, -8, -10, -8, + -8, -4, 2, 2, -1, -7, -10, -8, -12, -13, -15, + -9, -5, -4, -3, -6, -11, -11, -18, -16, -13, -10, + -12, -6, 0, -2, 0, -3, -4, -8, -12, -19, -16, + -17, -19, -23, -30, -33, -36, -38, -39, -40, -36, -37, + -32, -27, -25, -31, -38, -41, -47, -52, -50, -42, -32, + -16, -7, -3, 0, -1, -1, -5, -16, -23, -29, -34, + -33, -27, -17, -11, 1, 4, 10, 18, 21, 24, 24, + 25, 30, 34, 30, 29, 26, 23, 20, 15, 14, 13, + 14, 16, 23, 28, 21, 23, 21, 13, 12, 12, 14, + 17, 21, 26, 27, 30, 30, 26, 20, 15, 15, 9, + 8, 9, 10, 7, 8, 7, 1, -2, -6, -10, -10, + -12, -15, -10, -7, -6, -5, 0, -3, -3, -12, -25, + -35, -49, -53, -49, -51, -48, -46, -48, -39, -33, -31, + -37, -42, -47, -49, -46, -47, -47, -46, -42, -39, -33, + -26, -23, -14, -8, -9, -7, -10, -11, -13, -13, -19, + -20, -16, -11, -9, 7, 16, 21, 29, 27, 29, 28, + 21, 14, 13, 17, 19, 20, 18, 13, 17, 16, 18, + 20, 17, 13, 16, 23, 26, 26, 25, 27, 31, 30, + 31, 34, 32, 35, 32, 36, 31, 26, 23, 27, 27, + 29, 27, 26, 32, 31, 28, 26, 23, 14, 6, 0, + -4, -7, -9, -10, -8, -3, 4, 12, 11, 15, 11, + 8, 2, -3, -3, -4, -6, -11, -14, -20, -28, -32, + -38, -46, -42, -44, -40, -34, -26, -29, 
-25, -23, -24, + -17, -21, -26, -23, -25, -19, -10, -11, -10, -10, -12, + -9, -3, 0, -3, -7, -10, -13, -10, -14, -13, -17, + -22, -22, -30, -28, -29, -26, -18, -6, -1, -3, -4, + -6, -10, -13, -10, -14, -16, -11, -15, -9, -3, -6, + -1, 2, 3, 4, 6, 6, 3, 4, 12, 14, 17, + 21, 19, 20, 16, 17, 15, 21, 21, 22, 20, 17, + 16, 16, 20, 17, 15, 9, 5, 11, 18, 24, 28, + 26, 23, 23, 26, 22, 18, 21, 23, 26, 27, 25, + 27, 29, 26, 20, 10, 7, 11, 8, 16, 25, 33, + 37, 38, 39, 35, 30, 20, 13, 9, 6, 5, 13, + 13, 14, 15, 12, 8, 3, 3, 3, 2, 9, 11, + 10, 5, 5, 0, -7, -11, -12, -15, -17, -12, -13, + -18, -19, -21, -24, -22, -27, -34, -36, -36, -32, -20, + -16, -15, -5, -5, -9, -10, -9, -17, -19, -20, -14, + -13, -10, -4, -7, -7, -14, -19, -28, -31, -30, -31, + -23, -19, -20, -12, -11, -14, -16, -20, -18, -20, -21, + -24, -29, -30, -30, -34, -31, -25, -21, -18, -11, -4, + 2, 2, 3, 3, 2, 4, -1, -4, -8, -3, -1, + 7, 15, 18, 22, 20, 20, 16, 16, 14, 13, 21, + 25, 26, 35, 28, 28, 28, 25, 21, 19, 18, 21, + 24, 20, 25, 28, 19, 16, 15, 8, 3, -1, 3, + 5, 13, 18, 25, 31, 33, 39, 36, 36, 32, 36, + 37, 39, 42, 36, 32, 27, 30, 24, 18, 15, 10, + 7, 5, 6, -1, -4, -10, -17, -15, -19, -15, -7, + -4, 3, 0, 3, 4, -2, -7, -13, -21, -23, -28, + -27, -26, -25, -15, -10, -4, -6, -5, -9, -5, -3, + 1, 2, -1, 1, -4, -7, -8, -17, -17, -15, -14, + -9, -5, -7, -6, -9, -16, -15, -15, -16, -16, -11, + -15, -15, -6, -6, -5, -2, 0, -9, -10, -12, -13, + -10, -4, 0, 8, 5, 4, 2, 0, -5, -8, -16, + -15, -12, -3, 9, 17, 24, 26, 30, 28, 22, 17, + 14, 9, 8, 9, 8, 11, 12, 12, 15, 14, 18, + 20, 17, 19, 22, 21, 12, 5, 0, 3, -3, -4, + -6, -7, 1, 8, 8, 8, 10, 2, -3, -8, -15, + -20, -24, -22, -23, -13, -6, -7, -5, -10, -8, -15, + -19, -22, -20, -17, -18, -13, -10, -1, 6, 5, 3, + 1, -5, -11, -10, -14, -19, -15, -13, -8, -2, -3, + -4, -3, -4, -1, 1, 0, -3, -4, -8, -18, -21, + -25, -24, -16, -9, -2, 1, 5, 1, 3, -2, -7, + -10, -23, -30, -29, -23, -9, -3, 4, 11, 11, 6, + 2, 0, -12, -20, -28, -24, -22, -17, -22, -19, -14, + -21, -17, -17, 
-12, -8, -3, 2, 0, -6, -5, -8, + -12, -17, -27, -34, -31, -30, -27, -19, -14, -14, -14, + -14, -19, -22, -21, -19, -14, -1, 5, 9, 8, 6, + 5, -4, -2, -3, -3, -1, -2, -3, 2, 7, 8, + 7, 6, 6, 3, 2, 1, -2, 0, 6, 11, 18, + 18, 19, 17, 14, 9, 4, 3, 3, 0, -1, 3, + -1, -5, 0, -2, 0, 1, 7, 7, 8, 20, 29, + 33, 31, 24, 14, 5, -6, -11, -8, -11, -2, 6, + 10, 12, 16, 26, 26, 24, 18, 12, 10, 4, 7, + 6, -2, -12, -17, -17, -20, -23, -23, -18, -8, 1, + 3, 5, 6, 3, 0, -6, -12, -12, -15, -12, -7, + 3, 3, 8, 7, 7, 7, 1, -1, -1, 4, 11, + 17, 25, 32, 35, 42, 50, 52, 56, 50, 55, 53, + 52, 47, 40, 38, 30, 26, 27, 28, 29, 25, 23, + 23, 28, 30, 25, 26, 21, 19, 14, 9, 16, 22, + 25, 33, 39, 45, 49, 48, 55, 51, 43, 35, 20, + 14, 13, 23, 25, 24, 20, 22, 28, 22, 22, 17, + 16, 13, 10, 10, 10, 9, 9, 14, 11, 10, 10, + 4, 0, 0, -2, -3, -5, -7, -3, 1, -8, -8, + -9, -4, 4, 9, 11, 14, 11, 6, 8, 3, -6, + -10, -19, -22, -24, -27, -22, -16, -21, -25, -33, -33, + -32, -30, -21, -13, -6, -5, 2, 1, 4, 9, 7, + 5, 1, 1, 8, 6, 7, 6, 0, -6, -15, -18, + -23, -22, -23, -25, -22, -21, -19, -17, -13, -10, -10, + -16, -17, -15, -13, -8, -9, -14, -13, -17, -20, -26, + -28, -31, -29, -26, -23, -13, -10, -6, -1, 5, 7, + 2, -3, -7, -20, -18, -16, -21, -27, -33, -25, -27, + -22, -22, -21, -16, -11, -7, -2, 2, 11, 18, 11, + 9, 4, 1, -1, -6, -4, -5, -9, -12, -16, -25, + -29, -37, -37, -38, -37, -33, -23, -16, -14, -7, -1, + -4, -3, -4, -5, -11, -14, -8, -8, -8, -8, -9, + -4, -14, -21, -22, -21, -18, -15, -2, 3, -3, 0, + -2, 0, -4, -7, -1, -2, 3, 3, -3, -10, -13, + -10, -16, -19, -17, -17, -14, -7, 5, 5, 7, 8, + 12, 7, 0, -5, -13, -17, -18, -14, -7, -4, 3, + 11, 11, 12, 11, 8, 4, -5, -5, -11, -15, -17, + -23, -22, -18, -14, -14, -12, -6, -4, -1, 3, 1, + -4, -10, -22, -29, -30, -26, -15, -2, 6, 16, 21, + 28, 32, 25, 24, 20, 9, 5, 0, 3, 7, 10, + 11, 13, 17, 15, 16, 13, 11, 11, 8, 7, 1, + 1, -5, -2, -2, -1, 4, 8, 17, 22, 24, 24, + 26, 23, 20, 17, 16, 9, 4, 6, 5, 8, 2, + -1, -5, -4, -10, -14, -14, -17, -19, -18, -16, -14, + -6, 
-3, 1, 3, 0, -4, -6, -4, -1, -1, 2, + 5, 3, 8, 7, 7, 14, 13, 20, 24, 29, 24, + 12, 7, -1, -6, -15, -22, -20, -27, -22, -14, -6, + 2, 7, 9, 9, 2, -3, -7, -8, -10, -9, -3, + -6, -11, -12, -8, -5, -4, -5, -3, 0, 3, 6, + 6, 7, 5, -7, -10, -14, -13, -14, -17, -11, -7, + -4, 1, 1, 4, -4, -8, -18, -23, -23, -25, -19, + -16, -15, -9, 3, 10, 19, 25, 30, 31, 26, 27, + 23, 19, 16, 8, 7, 2, 0, -1, -1, 1, 5, + 6, 6, 1, 3, -1, -7, -11, -17, -19, -19, -7, + 0, 3, 11, 12, 18, 20, 16, 9, -2, -7, -14, + -19, -22, -30, -33, -34, -36, -26, -14, -11, -9, -3, + 0, -2, 1, -2, -3, -5, -12, -15, -19, -14, -9, + -8, -2, -6, -13, -15, -19, -22, -25, -26, -21, -20, + -11, -1, 1, 5, 9, 13, 15, 12, 11, 3, 1, + -1, 0, 8, 13, 16, 16, 15, 16, 15, 12, 9, + 7, 8, 4, 6, 4, 3, 3, 7, 0, -4, -8, + -11, -18, -18, -15, -20, -23, -21, -22, -21, -27, -25, + -15, -7, -2, 8, 9, 8, 8, 3, 3, 7, 8, + 8, 8, 12, 11, 12, 4, -1, -7, -11, -15, -18, + -17, -17, -20, -19, -13, -11, -3, -3, -1, 1, -3, + 1, 1, 8, 10, 15, 24, 26, 29, 34, 36, 26, + 20, 12, -2, -6, -9, -7, -6, 1, 10, 13, 19, + 22, 22, 18, 21, 24, 28, 35, 37, 34, 33, 34, + 34, 30, 19, 15, 10, 19, 21, 23, 24, 21, 19, + 18, 21, 22, 22, 27, 30, 31, 32, 33, 32, 32, + 24, 18, 10, 8, 10, 10, 6, 2, -7, -14, -22, + -29, -27, -29, -32, -30, -28, -23, -22, -11, -11, -13, + -3, 2, -1, 1, 1, -3, -7, -5, -7, -11, -17, + -23, -25, -26, -27, -26, -23, -14, -5, -3, -1, -2, + -2, -1, 1, -2, -7, -4, 2, 4, 10, 13, 6, + 3, -2, -6, -7, -11, -17, -21, -15, -7, -2, 11, + 16, 22, 25, 25, 23, 24, 23, 21, 22, 25, 23, + 17, 17, 12, 8, -2, -4, 1, 0, 4, 9, 8, + 10, 9, 9, 15, 13, 10, 8, 1, 1, -3, 1, + 4, 11, 10, 9, 5, 5, 4, 1, -1, -4, 0, + 8, 7, 4, 3, 3, 0, -9, -16, -19, -20, -21, + -18, -16, -11, -10, -9, -13, -12, -19, -25, -21, -15, + -5, 8, 14, 21, 24, 18, 20, 17, 6, 1, -2, + -2, 1, 1, 4, 1, -3, 2, 0, -3, -3, -4, + 1, 0, -5, -11, -17, -21, -20, -20, -20, -14, -9, + -3, 3, 7, 5, 3, 1, -1, -3, -4, -1, 1, + -5, -1, -1, -7, -11, -14, -12, -14, -17, -18, -23, + -29, -24, -27, -19, -12, 
-13, -2, -3, 4, 4, 0, + -3, -5, -2, -1, -5, -6, -7, -7, -7, -9, -13, + -9, -4, 1, 1, 1, -4, -11, -8, -15, -19, -19, + -12, -5, 1, 7, 12, 8, 10, 10, 10, 11, 11, + 19, 12, 9, 9, 2, -4, -13, -22, -24, -25, -24, + -26, -19, -14, -10, -1, 5, 4, -1, -4, -5, -10, + -14, -11, -8, -10, -8, -9, -7, -8, -6, -1, -5, + -10, -18, -27, -29, -24, -19, -11, -7, 1, 10, 8, + 8, 5, 2, -5, -1, -1, 0, 2, 2, -2, -8, + -8, -14, -26, -25, -23, -18, -9, 2, 2, 7, 13, + 6, 7, 5, 4, 3, 2, 1, 7, 2, -1, 1, + -2, 2, 0, -2, -6, -3, 5, 7, 9, 6, 5, + 4, 2, 0, -1, -3, 3, 7, 6, 14, 18, 22, + 20, 22, 19, 13, 9, 2, -8, -11, -6, -2, -3, + -3, 0, 0, 0, 1, -1, -2, 1, 7, 11, 10, + 11, 17, 17, 11, 11, 4, 6, 6, 13, 19, 22, + 23, 27, 25, 24, 22, 14, 11, 13, 7, 0, -3, + -9, -11, -7, -7, -6, -4, 1, 7, 9, 15, 18, + 18, 10, 5, 3, -3, -6, -5, -8, -5, 4, 8, + 8, 11, 10, 9, 4, 4, 1, -3, -10, -11, -8, + -16, -20, -22, -19, -12, -7, -10, -10, -13, -14, -11, + -11, -13, -18, -21, -19, -17, -22, -18, -22, -22, -16, + -9, -3, 0, 3, 6, 3, 3, -3, -6, -9, -14, + -1, 14, 21, 30, 37, 33, 27, 26, 19, 15, 14, + 11, 20, 12, 9, 10, 19, 20, 19, 22, 20, 22, + 17, 13, 14, 10, 8, 12, 15, 13, 12, 12, 12, + 9, 10, 11, 11, 9, 6, 4, 5, -2, 1, 1, + -1, 5, 1, 8, 6, 3, -1, -4, -15, -24, -27, + -26, -23, -19, -9, -3, -4, -9, -9, -10, -16, -22, + -19, -18, -15, -2, 3, 5, 6, 7, 8, 11, 3, + 1, 2, 1, 1, 0, -4, -13, -18, -19, -19, -20, + -23, -15, -10, -5, -3, -1, -1, -1, 3, -1, 0, + -8, -11, -13, -14, -13, -8, -6, -3, 1, 1, 0, + 0, 5, 4, 5, 5, 5, 4, 0, -1, -4, -13, + -22, -21, -28, -26, -22, -28, -23, -23, -14, -11, -10, + -7, -8, -5, -4, 1, 9, 10, 15, 19, 21, 17, + 18, 19, 16, 13, 16, 21, 27, 29, 22, 22, 13, + 4, 1, 0, -5, -6, -2, 3, 5, 8, 6, 9, + 10, 2, -3, -9, -8, -4, -2, -7, -6, -4, -8, + -6, -8, -11, -8, -8, -6, 2, -2, -2, -1, 2, + 4, 8, 5, -1, -8, -10, -7, -6, -5, -6, -5, + 6, 13, 22, 28, 33, 31, 38, 35, 28, 27, 22, + 22, 23, 26, 23, 21, 28, 28, 23, 23, 22, 21, + 20, 14, 6, -1, -5, -8, -5, -1, 2, 5, 5, + 7, 8, 5, 4, 0, 3, 6, 10, 13, 13, 6, 
+ 4, 4, 0, -2, -3, 0, 3, 5, 7, 9, 7, + 6, 10, 8, 3, 4, -1, -4, -2, 0, -2, -2, + -2, -3, 5, 8, 6, 4, -1, -7, -6, -7, -12, + -18, -11, -2, -1, -1, -1, -2, -7, -7, -3, -3, + -5, -6, -6, -6, -6, -6, -9, -12, -9, -5, 1, + 3, 5, 5, 8, 7, 3, -5, -3, -2, 2, 3, + 5, 5, -1, -2, -4, -8, -9, -9, -7, -12, -13, + -17, -19, -16, -19, -21, -21, -19, -11, -6, -3, 7, + 8, 6, 2, 0, 1, 1, -2, -5, 0, -2, 2, + 1, 2, 0, -2, -1, -10, -21, -25, -24, -21, -19, + -14, -8, -3, -5, 0, 0, -5, -6, -3, -6, -9, + -13, -19, -20, -21, -21, -24, -25, -27, -27, -29, -26, + -19, -14, -14, -13, -8, -5, -10, -10, -6, 1, 4, + 14, 22, 23, 24, 20, 20, 18, 14, 11, 9, 6, + 8, 12, 15, 18, 18, 12, 8, 9, 9, 9, 7, + 4, 9, 5, 6, 5, 3, 3, -1, -1, -6, -10, + -6, -8, -3, 0, -2, -3, -2, -6, -6, -7, -3, + -3, -3, -2, 1, -1, -10, -7, -13, -21, -23, -20, + -19, -18, -18, -19, -15, -16, -7, -6, -9, -13, -12, + -6, -1, 3, 6, 7, 5, 3, -3, -11, -18, -20, + -26, -29, -27, -27, -24, -30, -29, -28, -23, -18, -21, + -18, -15, -9, 1, 9, 17, 21, 23, 18, 14, 5, + -1, -2, -1, 0, 3, 6, 5, 4, 4, 0, -1, + 1, -4, -9, -13, -11, -20, -21, -19, -14, -9, -4, + 1, 6, 10, 16, 24, 30, 35, 31, 38, 37, 35, + 39, 36, 36, 32, 30, 33, 31, 24, 19, 12, 4, + -1, -7, -11, -7, -5, -3, 2, 6, 10, 16, 19, + 21, 21, 16, 10, 14, 12, 14, 13, 12, 12, 5, + 6, 2, 0, 1, 3, 4, 6, 9, 6, 2, -1, + -3, -10, -15, -13, -17, -19, -15, -16, -15, -13, -8, + -8, -7, -10, -5, -2, 1, 5, 5, 11, 10, 12, + 10, 9, 9, 15, 23, 33, 35, 33, 34, 34, 35, + 34, 24, 30, 26, 23, 21, 20, 15, 10, 3, 4, + 0, -7, -8, -9, -9, -8, -4, 0, 5, 5, 2, + 3, -2, 0, 0, -1, 0, -1, 1, 2, 6, 3, + 1, -9, -5, -6, -2, -8, -12, -9, -10, -7, -8, + -8, -6, -2, -2, -1, 0, -2, -1, -8, -18, -19, + -27, -37, -42, -40, -39, -33, -30, -23, -16, -16, -9, + -13, -11, -10, -10, -8, -3, -1, 2, 0, -1, 2, + 6, 4, 8, 10, 17, 21, 28, 31, 33, 28, 20, + 12, 8, -3, -5, -4, -3, 2, 6, 9, 8, 2, + 7, 4, -6, -9, -15, -13, -15, -17, -14, -11, -12, + -5, -6, -4, -6, -11, -11, -7, -4, -6, -8, -13, + -10, -7, -12, -11, -12, -13, -12, 
-9, -9, -10, -10, + -6, -8, -8, -7, -9, -9, -7, 2, 5, 5, 6, + 3, 4, 6, 3, -1, -2, -2, -2, 1, 5, 3, + 4, 2, -2, -7, -9, -13, -11, -8, 2, 12, 23, + 31, 37, 41, 40, 37, 36, 31, 31, 27, 28, 24, + 13, 16, 14, 15, 9, 4, 4, 5, 4, 7, 12, + 16, 14, 11, 13, 6, -2, -4, -1, -3, 3, 6, + 6, 9, 7, 9, 7, 5, 0, 1, -1, -2, -4, + -1, 0, 0, -4, 0, -4, -9, -15, -16, -18, -15, + -10, -6, -8, -5, -2, -2, 0, 4, 7, 0, -2, + -3, 4, 3, 2, -1, -3, -8, -19, -19, -19, -16, + -8, -5, 0, 1, 2, 1, -1, -2, -10, -12, -10, + -4, 3, 4, 2, 7, 8, 4, 1, -5, -5, -4, + -1, 9, 10, 12, 15, 15, 14, 11, 20, 16, 19, + 18, 26, 29, 21, 23, 16, 16, 3, -3, -4, -10, + -12, -10, -6, -7, -12, -17, -14, -16, -19, -13, -10, + -13, -13, -2, 2, 3, 7, 13, 22, 21, 21, 21, + 24, 27, 23, 22, 20, 17, 17, 16, 13, 11, 5, + 1, 1, 5, 5, 3, 2, -1, 2, -5, -6, -3, + -11, -9, -6, -5, -10, -4, -1, 1, 2, -1, -4, + -4, -9, -9, -7, -3, 3, -2, 1, 1, 4, -4, + -8, -8, -17, -17, -13, -13, -18, -18, -25, -27, -21, + -22, -18, -7, -1, 5, 9, 11, 11, 11, 15, 11, + 4, 1, 6, 8, 17, 12, 10, 5, -2, -3, -14, + -17, -25, -26, -22, -20, -13, -12, -12, -13, -10, -4, + -6, -6, -4, -6, -4, 0, -3, -7, -7, -10, -17, + -14, -9, -3, 4, 4, 6, 1, 0, 0, -6, -3, + -4, -3, -6, -9, -9, -5, 0, 1, 2, -2, 3, + -1, -4, -5, -11, -14, -17, -14, -12, -14, -19, -21, + -25, -35, -40, -39, -31, -24, -13, -4, -1, 0, 0, + 2, -2, -5, -8, -8, -9, -6, -2, 0, -5, -6, + 2, 5, 4, 1, 6, 8, 9, 14, 13, 19, 15, + 19, 13, 14, 20, 16, 16, 14, 14, 17, 13, 12, + 11, 6, -1, -7, -9, -10, -11, -2, 8, 12, 12, + 12, 8, 4, 1, -3, -4, -4, -3, 1, 9, 14, + 16, 10, 12, 9, 6, 4, -1, 8, 6, 3, 6, + 1, -11, -10, -10, -13, -9, -6, -2, -2, 9, 13, + 17, 17, 19, 17, 16, 9, -2, -5, -5, -3, -9, + -8, -8, -12, -17, -16, -18, -15, -9, -7, 1, 10, + 17, 18, 23, 25, 23, 20, 15, 17, 18, 23, 33, + 40, 43, 45, 51, 53, 47, 36, 27, 10, 5, 1, + 4, 5, 4, 0, 0, 6, 7, 8, 9, 3, 2, + 1, 0, -1, 3, 5, 5, 13, 7, 4, 4, 3, + 11, 17, 21, 31, 31, 31, 31, 28, 26, 23, 19, + 16, 17, 16, 10, 10, 12, 9, 7, -1, -7, -12, + -15, -15, -15, 
-13, -13, -16, -19, -19, -23, -31, -34, + -38, -39, -31, -30, -21, -21, -18, -11, -16, -20, -25, + -22, -18, -14, -7, -8, -3, 2, 10, 13, 12, 10, + 6, 2, 0, 0, 0, -6, -4, -1, 0, 0, -1, + -2, 1, 3, 8, 9, 3, 6, 2, -4, -2, -3, + -7, -4, -3, 2, 6, 8, 10, 12, 15, 11, 15, + 12, 13, 14, 15, 18, 14, 8, 4, 4, 3, -4, + -5, -4, -2, -3, -2, 4, 9, 13, 18, 21, 20, + 18, 15, 11, 6, 7, 10, 8, 6, 3, -3, -7, + -14, -21, -29, -33, -32, -26, -17, -12, -11, -9, -3, + -10, -13, -18, -23, -21, -26, -26, -24, -28, -25, -29, + -30, -30, -27, -17, -7, 2, 10, 13, 16, 16, 17, + 18, 17, 19, 19, 20, 15, 14, 16, 14, 10, 5, + 0, -4, -18, -21, -25, -20, -16, -13, -8, -5, 2, + 6, 11, 12, 18, 16, 18, 15, 13, 17, 18, 22, + 21, 25, 26, 25, 26, 28, 31, 27, 20, 10, 3, + -6, -10, -16, -19, -18, -15, -13, -10, -2, 0, 2, + 4, 3, 5, -1, 0, 1, 2, 0, -2, -1, -6, + -5, -7, -12, -10, -9, -4, -1, 3, 4, 2, 4, + 4, 3, -3, -6, -11, -14, -15, -23, -25, -29, -30, + -28, -25, -22, -19, -21, -19, -11, -7, -7, -3, -3, + -6, -8, -13, -10, -10, -5, 1, 4, 9, 7, 6, + 6, 4, -5, -11, -8, -6, -3, 0, 3, 7, 11, + 7, 3, 5, 6, 10, 12, 14, 16, 8, 5, -1, + -1, 4, 0, 0, -3, -5, -5, -4, -2, -2, 1, + 4, 7, 5, 10, 9, 6, 9, 12, 19, 28, 32, + 32, 33, 31, 29, 20, 17, 16, 14, 15, 6, -2, + -5, -7, -10, -10, -11, -9, -6, -3, 8, 10, 10, + 10, 12, 12, 7, 7, 5, 3, 2, 2, -2, -5, + -4, -7, -2, -6, -5, -6, -11, -14, -13, -10, -11, + -15, -16, -11, -11, -11, -10, -16, -15, -15, -16, -10, + -11, -11, -5, -1, 2, 1, 2, 0, 1, 4, 8, + 5, -4, -2, -4, -12, -18, -24, -20, -25, -14, -3, + 4, 11, 13, 13, 7, 4, -4, -9, -13, -17, -10, + -6, -1, 0, 2, 2, -1, 1, -8, -18, -22, -19, + -19, -22, -20, -22, -20, -17, -12, -9, -4, 3, 9, + 9, 9, 7, 6, 13, 10, 11, 8, 4, -1, 5, + 7, 7, 8, 4, 2, 2, -2, -8, -11, -16, -18, + -12, -12, -9, -2, 3, 3, 5, 5, 6, 9, 11, + 20, 22, 26, 30, 28, 22, 15, 15, 10, 11, 9, + 6, 9, 9, 11, 10, 12, 10, 8, 8, 7, 9, + 4, 3, 9, 5, 1, 2, 0, -3, -3, 0, 3, + 0, -2, 1, 4, 6, 4, 0, 1, -4, -13, -13, + -11, -20, -21, -15, -17, -23, -22, -24, -29, -24, 
-29, + -32, -21, -13, -11, -9, -9, -8, -13, -11, -11, -11, + -11, -17, -17, -21, -23, -27, -32, -33, -32, -31, -35, + -31, -26, -24, -18, -10, -1, 5, 13, 17, 15, 13, + 8, 4, 6, 9, 10, 13, 11, 12, 13, 9, 5, + 6, 8, 12, 21, 25, 24, 23, 16, 8, 7, 0, + -3, -8, -9, -2, 1, 11, 18, 25, 30, 31, 27, + 21, 19, 19, 18, 18, 22, 24, 16, 14, 8, 2, + -4, -9, -7, -10, -6, -8, -8, -13, -14, -11, -13, + -8, -7, 6, 9, 10, 15, 17, 11, 11, 9, 2, + 2, -2, 2, -6, -6, -7, -14, -11, -12, -13, -17, + -22, -25, -30, -24, -16, -4, 5, 2, 7, 5, 2, + -1, 1, -4, -4, 4, 8, 8, 5, 6, 6, 2, + 1, -2, -9, -14, -17, -16, -15, -14, -12, -11, -6, + -6, -2, -3, -3, 6, 13, 18, 27, 27, 26, 24, + 22, 19, 18, 19, 12, 8, 7, -2, 0, -6, -8, + -6, -4, -6, -14, -16, -16, -15, -12, -2, 6, 12, + 16, 18, 14, 16, 13, 12, 17, 16, 17, 17, 12, + 13, 10, 14, 14, 10, 2, -1, -3, -5, -10, -15, + -13, -20, -21, -21, -21, -19, -20, -18, -8, -4, -1, + -1, 4, 2, -3, 0, -5, -5, -3, -1, 0, 6, + 5, 6, 7, 7, 3, 2, 1, -5, -3, 0, 3, + 5, 7, 4, 10, 15, 15, 11, 6, 8, 9, 14, + 19, 18, 14, 12, 16, 15, 11, 9, 9, 5, 4, + 0, -7, -12, -18, -22, -29, -32, -36, -37, -38, -39, + -32, -24, -20, -14, -10, -2, 0, 1, 9, 13, 21, + 26, 31, 35, 40, 38, 32, 33, 25, 14, 11, 7, + 1, -1, -6, -5, -11, -20, -22, -19, -16, -9, 2, + 9, 14, 14, 13, 13, 12, 10, 3, 2, 1, 0, + 6, 5, -1, -4, -13, -17, -21, -25, -29, -30, -23, + -14, -4, 4, 11, 11, 12, 13, 13, 5, 6, 6, + 7, 5, 5, 9, -2, 3, 0, -2, -3, -5, -1, + 3, 9, 16, 18, 17, 17, 11, 5, 1, -4, -13, + -12, -7, -7, 1, 6, 4, 2, 3, 1, 1, 0, + -1, -5, -5, -3, -5, -1, 8, 9, 7, 12, 7, + 6, 4, 3, -1, -1, -4, -14, -16, -18, -24, -34, + -44, -37, -37, -36, -28, -19, -15, -6, -2, -3, 2, + 5, 6, 3, 6, 6, 9, 7, 3, -4, -15, -25, + -34, -37, -41, -41, -38, -33, -27, -22, -14, -15, -18, + -18, -15, -8, -7, -2, 2, 0, 4, 12, 13, 10, + 17, 20, 16, 17, 23, 24, 22, 24, 22, 28, 26, + 24, 22, 26, 28, 27, 23, 17, 10, 4, 4, 1, + -1, 0, 4, 9, 15, 14, 15, 14, 14, 13, 8, + 0, -1, -11, -13, -4, -3, -5, -3, -1, -6, -5, + -7, -4, -2, 2, 7, 15, 
20, 14, 13, 8, 2, + -6, -15, -23, -25, -20, -22, -20, -14, -10, -4, -2, + 1, -10, -15, -12, -8, -8, -7, -5, -10, -12, -20, + -28, -26, -24, -16, -8, -5, 3, 8, 9, 12, 12, + 12, 14, 13, 12, 10, 13, 23, 29, 28, 33, 36, + 32, 28, 23, 25, 26, 30, 34, 27, 22, 16, 12, + 3, -6, -13, -13, -15, -14, -9, -11, -13, -13, -16, + -15, -20, -22, -20, -32, -30, -29, -24, -18, -18, -18, + -13, -15, -15, -16, -17, -10, -11, -12, -15, -17, -17, + -19, -21, -22, -26, -28, -21, -18, -14, -5, 2, 6, + 7, 5, 3, -2, 0, -4, -2, -3, -6, -9, -12, + -11, -11, -19, -23, -20, -21, -16, -19, -23, -22, -24, + -21, -22, -17, -15, -8, -1, 4, 14, 18, 23, 24, + 25, 25, 18, 15, 7, 2, 14, 19, 22, 20, 23, + 22, 20, 19, 20, 17, 16, 21, 22, 21, 18, 9, + 3, -6, -14, -19, -30, -36, -40, -32, -22, -21, -16, + -7, -1, 3, 2, 3, 6, 9, 16, 20, 22, 26, + 27, 29, 32, 30, 23, 19, 20, 21, 18, 22, 24, + 15, 14, 9, 9, 7, 6, 9, 9, 16, 22, 20, + 18, 18, 9, -1, -10, -16, -19, -22, -22, -20, -16, + -11, -5, 0, 1, 4, 2, 0, 3, 5, 10, 8, + 12, 10, 11, 9, 8, 7, -3, -4, -10, -11, -5, + 2, 8, 12, 12, 13, 14, 15, 14, 12, 10, 14, + 13, 8, 0, -2, -3, -9, -6, -13, -21, -12, -12, + -8, -9, -14, -16, -19, -23, -22, -23, -30, -26, -17, + -14, -9, -2, 3, 11, 16, 17, 17, 11, 12, 13, + 12, 9, 8, 7, 10, 17, 14, 13, 9, 7, 6, + 5, 10, 10, 6, 10, 9, 1, -5, -10, -12, -17, + -16, -14, -13, -10, -6, -2, 0, -1, 2, 2, -1, + 2, 6, 12, 18, 23, 22, 23, 24, 20, 16, 10, + 6, 9, 16, 15, 15, 16, 14, 8, 4, 0, -3, + -7, -4, -5, -5, 0, -4, 1, 1, 1, -4, -10, + -17, -25, -25, -28, -28, -27, -25, -20, -20, -20, -22, + -14, -11, -4, 4, 6, 11, 10, 12, 9, 6, 2, + -6, -10, -12, -7, -1, -6, 0, 1, 2, 5, 1, + -1, 1, -3, -6, -4, -5, -4, -6, -5, -7, -10, + -10, -8, -11, -9, -2, 9, 15, 14, 20, 19, 19, + 16, 16, 11, 3, 2, 2, 5, 4, 5, 3, -1, + -1, -6, -11, -16, -18, -18, -12, -17, -18, -13, -15, + -5, -4, -3, -1, 2, 6, 7, 11, 14, 17, 17, + 18, 21, 18, 19, 18, 23, 27, 36, 32, 35, 30, + 24, 25, 18, 10, 3, -1, -4, -11, -16, -21, -33, + -37, -35, -36, -35, -30, -26, -26, -21, 
-10, -7, -3, + -4, -3, -3, -9, -12, -16, -25, -22, -11, -6, 2, + 5, 7, 4, -2, -8, -16, -23, -30, -28, -23, -20, + -11, -11, -8, 5, 2, -3, -1, -11, -15, -10, -13, + -8, -8, -12, -9, -10, -15, -8, -4, -3, 7, 6, + 13, 20, 25, 24, 25, 27, 28, 25, 23, 22, 27, + 28, 27, 30, 28, 26, 20, 16, 13, 7, 2, 1, + 6, 3, -4, -6, -13, -18, -19, -21, -15, -3, -1, + 10, 16, 17, 20, 24, 28, 28, 26, 26, 28, 27, + 24, 23, 20, 20, 24, 20, 17, 14, 6, 0, 2, + 1, 0, -3, -7, -12, -18, -29, -28, -30, -32, -23, + -27, -25, -20, -17, -13, -11, -14, -17, -21, -22, -18, + -11, -12, -6, -8, -9, -5, -6, -10, -18, -19, -16, + -13, -9, -6, -7, -13, -10, -14, -22, -30, -37, -35, + -37, -35, -34, -36, -30, -23, -17, -16, -16, -11, -6, + -2, 3, 7, 7, 6, 7, 7, 13, 21, 20, 22, + 23, 22, 24, 17, 5, -1, -2, -8, -13, -14, -17, + -24, -28, -23, -22, -19, -12, -14, -10, -14, -21, -20, + -21, -22, -13, -6, -1, 6, 4, 10, 11, 8, 10, + 10, 17, 20, 27, 34, 32, 26, 26, 24, 17, 13, + 6, 9, 12, 15, 17, 12, 11, 9, 3, -3, -3, + -8, -9, -4, -2, -2, 2, 1, -1, -3, -7, -8, + -11, -15, -8, -5, 1, 9, 7, 10, 13, 17, 14, + 12, 8, 6, 3, 6, 9, 8, 5, 0, -2, 1, + 1, -3, -6, -12, -17, -17, -23, -28, -33, -31, -29, + -30, -35, -28, -25, -17, -5, 0, 6, 10, 14, 27, + 31, 26, 31, 30, 32, 41, 42, 42, 43, 34, 32, + 21, 12, 2, 1, -3, -1, 8, 13, 20, 19, 18, + 19, 13, 8, 5, 7, 6, 7, 6, 4, 3, -2, + 0, 2, -4, -1, -3, 2, 12, 22, 33, 32, 31, + 35, 35, 34, 32, 26, 27, 26, 21, 17, 10, 1, + -3, -14, -21, -19, -21, -19, -24, -24, -19, -16, -13, + -16, -13, -15, -17, -12, -9, -4, 7, 19, 27, 33, + 37, 34, 35, 30, 24, 23, 25, 21, 20, 18, 15, + 12, 13, 8, 2, -4, -12, -18, -17, -14, -10, -14, + -8, -14, -14, -12, -14, -19, -23, -31, -32, -28, -30, + -22, -20, -13, 1, 0, 6, 14, 15, 20, 22, 20, + 16, 9, 2, 1, 3, 6, 7, 9, 10, 14, 17, + 16, 14, 4, -7, -16, -31, -40, -41, -40, -38, -34, + -40, -37, -33, -28, -22, -17, -11, -10, -12, -5, -5, + -8, -4, 0, -1, 1, 1, 6, 11, 14, 22, 25, + 28, 31, 32, 32, 31, 31, 20, 13, 12, 5, 4, + 4, 2, 0, -3, -6, -8, -4, -4, -4, 
-1, 7, + 9, 10, 13, 13, 16, 10, 7, 3, 6, 8, 8, + 15, 20, 23, 18, 15, 12, 4, 1, 0, -4, -4, + -1, 8, 11, 13, 21, 24, 19, 12, 2, -5, -11, + -15, -17, -17, -19, -23, -28, -34, -33, -37, -29, -27, + -24, -17, -13, -8, -6, -2, 5, 3, 4, -2, -5, + -4, 0, 2, 3, 1, -5, -5, -6, -11, -11, -15, + -15, -19, -17, -17, -21, -23, -21, -22, -24, -28, -27, + -25, -15, -8, -1, 2, 2, 3, 3, 2, -2, 0, + 1, -1, 2, 5, 7, 2, 0, 2, -6, -9, -8, + -6, -3, -3, 3, 0, 5, 0, 0, -5, -12, -13, + -20, -14, -14, -6, -5, -2, 0, 6, 11, 9, 9, + 11, 10, 13, 19, 26, 29, 36, 37, 40, 35, 27, + 20, 13, 6, 3, -1, -1, -1, -3, -6, -8, -14, + -16, -25, -28, -23, -21, -24, -22, -22, -22, -24, -28, + -35, -43, -42, -37, -29, -20, -5, 2, 10, 23, 28, + 30, 31, 30, 39, 43, 40, 41, 43, 43, 38, 29, + 18, 14, 12, 3, 6, 3, 3, 0, -1, -3, -5, + -5, -8, -8, -10, -6, -1, 1, 5, 1, 2, 6, + 0, -3, -7, -13, -10, -7, -8, -7, -3, -5, -4, + -4, -4, -5, -2, 2, 3, 6, 4, 3, -1, -2, + -5, -16, -22, -31, -39, -38, -42, -47, -42, -42, -35, + -27, -30, -28, -25, -26, -24, -20, -19, -19, -19, -19, + -14, -16, -13, -9, -10, -1, 8, 17, 21, 28, 26, + 28, 24, 14, 8, 2, 0, -4, -4, -13, -16, -16, + -13, -12, -7, -5, 0, -4, -1, 2, 4, 8, 8, + 10, 10, 10, 14, 16, 17, 23, 20, 27, 27, 27, + 21, 14, 11, 0, -4, -8, -8, -1, -1, 1, 6, + 8, 23, 22, 23, 23, 25, 26, 26, 22, 21, 20, + 22, 17, 12, 8, 3, -2, -2, -4, -5, -3, 1, + 7, 6, 8, 9, 12, 6, 1, -4, -8, -6, -3, + -4, -5, -3, -7, -6, -6, -11, -11, -19, -23, -26, + -28, -34, -41, -41, -44, -45, -47, -40, -39, -33, -29, + -21, -14, -16, -6, -7, -3, 1, 6, 8, 11, 14, + 14, 15, 15, 18, 18, 16, 17, 12, 15, 20, 21, + 19, 21, 23, 22, 21, 16, 12, 8, 7, 7, 10, + 13, 13, 16, 16, 16, 16, 15, 15, 12, 14, 14, + 15, 12, 11, 17, 19, 19, 14, 13, 15, 17, 18, + 20, 24, 27, 24, 19, 11, 10, 1, 0, 0, -1, + 3, 8, 16, 18, 17, 22, 22, 21, 19, 7, 0, + 1, -1, -2, -1, -6, -8, -12, -14, -20, -21, -24, + -19, -9, -4, -3, 2, 2, 3, 0, -10, -19, -23, + -29, -31, -35, -29, -33, -28, -25, -25, -19, -22, -23, + -24, -21, -17, -15, -17, -13, 
-15, -12, -15, -14, -14, + -12, -9, -5, 1, 9, 13, 13, 17, 17, 15, 11, + 12, 8, 13, 20, 24, 30, 29, 33, 30, 26, 23, + 13, 9, 4, 3, 3, 5, 3, 2, 5, 3, 2, + 1, 3, 6, 10, 14, 19, 23, 21, 20, 21, 17, + 11, 5, -3, -7, -12, -15, -16, -13, -15, -13, -7, + -4, -5, -5, -1, 5, 11, 8, 7, -2, -2, -5, + -6, -1, -2, 0, 2, 8, 13, 15, 17, 15, 16, + 10, 13, 3, -1, -4, -4, -4, 0, 8, 13, 15, + 9, 11, 9, 12, 9, 10, 10, 5, 11, 16, 21, + 20, 15, 13, 5, 3, -3, 1, 1, 0, -4, -7, + -9, -7, -9, -10, -7, -6, -3, -2, -3, -3, -6, + -12, -16, -22, -21, -26, -28, -25, -24, -23, -23, -28, + -32, -29, -26, -26, -23, -29, -23, -16, -11, -7, -9, + -10, -12, -18, -20, -20, -26, -23, -16, -17, -10, -7, + 0, 3, -2, 0, -4, -7, -8, -6, -3, -7, -5, + -5, 1, 0, -3, -2, -3, 5, 7, 10, 19, 17, + 22, 21, 20, 16, 8, 9, 10, 12, 20, 28, 31, + 28, 28, 26, 21, 14, 8, 5, 4, 5, 8, 9, + 9, 13, 17, 16, 14, 20, 17, 13, 16, 17, 18, + 18, 15, 11, 5, -2, -8, -15, -17, -17, -24, -24, + -23, -18, -13, -13, -9, -7, -4, 0, 3, 6, 2, + 2, -4, -5, -5, -4, -4, -2, 2, 6, 10, 7, + 4, 2, -2, -3, -8, -10, -14, -27, -29, -37, -36, + -29, -27, -19, -7, -3, 0, -2, 2, 8, 13, 18, + 15, 10, 10, 6, 1, -5, -12, -17, -20, -23, -23, + -22, -19, -17, -10, -6, -3, 2, 0, 4, 11, 14, + 19, 16, 6, 7, 3, 3, 4, 1, 7, 8, 7, + 3, -2, 0, 0, 0, -1, -2, 0, 4, 3, 5, + 9, 9, 12, 7, 5, 0, 0, 1, 0, 2, -6, + -10, -9, -13, -15, -19, -15, -18, -16, -17, -9, -5, + -2, 2, 2, 3, 7, 2, -3, -8, -13, -8, 1, + 8, 12, 15, 17, 17, 11, 7, 0, -4, -8, -8, + -3, -1, -4, -6, -6, -13, -12, -12, -13, -12, -8, + -9, -5, -4, -2, 0, -1, -6, -7, -6, -10, -10, + -8, -6, 1, 5, 6, 15, 18, 16, 12, 12, 12, + 10, 13, 7, 0, -9, -10, -11, -6, -8, -8, -4, + 0, 6, 10, 11, 15, 15, 15, 12, 10, 6, 6, + 11, 12, 20, 25, 23, 25, 18, 12, 6, -1, -4, + -10, -12, -9, -13, -16, -15, -18, -18, -22, -22, -17, + -14, -12, -8, -3, 1, 4, 11, 13, 7, 0, -8, + -11, -11, -13, -14, -12, -11, -9, -6, -5, -2, 1, + 5, 6, 10, 18, 17, 15, 13, 11, 12, 13, 10, + 9, 13, 16, 16, 13, 11, 6, 5, 0, -5, -4, + -3, 2, 6, 5, 6, 11, 
14, 20, 23, 28, 27, + 22, 24, 23, 22, 16, 17, 12, 7, -1, -9, -10, + -9, -9, -13, -11, -9, -2, -2, -7, -8, -6, -7, + -12, -12, -10, 0, 5, 11, 13, 11, 10, 7, 3, + 0, 0, 3, 10, 14, 16, 18, 19, 21, 14, 15, + 12, 7, 6, 7, 9, 7, 11, 6, 4, 4, -1, + -9, -12, -12, -14, -9, -9, -6, -5, -4, -6, -7, + -12, -15, -17, -27, -23, -20, -19, -19, -18, -24, -20, + -25, -28, -33, -31, -29, -27, -15, -12, -7, -3, 1, + -3, -3, -5, -8, -6, 0, 13, 17, 24, 25, 23, + 24, 18, 8, -3, -4, -4, -7, -3, 1, 4, 7, + 9, 10, 14, 14, 20, 28, 35, 38, 42, 43, 43, + 39, 30, 27, 19, 15, 8, 10, 12, 19, 25, 26, + 27, 23, 22, 15, 10, 6, 8, 4, 6, 6, 3, + 7, 7, 15, 11, 7, 6, 5, 9, 6, 0, -3, + -14, -21, -21, -30, -39, -42, -40, -37, -37, -36, -32, + -30, -24, -21, -22, -23, -24, -28, -31, -31, -29, -27, + -30, -31, -31, -31, -34, -33, -34, -26, -21, -15, -10, + -5, -3, -2, -3, -6, -5, -11, -14, -10, -5, 0, + 9, 10, 18, 21, 19, 21, 11, 7, 4, 6, 6, + 7, 3, -6, -9, -16, -23, -24, -23, -26, -18, -16, + -11, -8, 0, 6, 5, 6, 10, 8, 8, 16, 24, + 24, 23, 24, 24, 24, 18, 9, 4, -3, -11, -16, + -15, -18, -14, -12, -9, -3, -4, -1, 8, 11, 10, + 19, 21, 21, 23, 20, 22, 15, 9, 7, 5, 3, + 1, 12, 13, 10, 18, 23, 31, 37, 40, 36, 38, + 40, 40, 38, 27, 24, 21, 14, 12, 12, 7, 7, + 15, 18, 19, 18, 17, 18, 14, 12, 11, 7, 5, + 7, 9, 9, 15, 14, 15, 18, 16, 7, 0, -5, + -6, -6, -6, -1, 7, 9, 12, 6, 4, 4, 2, + -1, 2, 3, 3, 5, 4, -1, -13, -19, -29, -34, + -39, -43, -49, -54, -53, -55, -55, -56, -59, -58, -49, + -41, -32, -19, -10, -2, -4, -1, -6, -19, -27, -26, + -27, -27, -21, -22, -20, -26, -26, -20, -20, -20, -21, + -17, -18, -7, -6, -6, -5, -1, 7, 18, 10, 16, + 25, 24, 31, 30, 32, 30, 26, 24, 22, 23, 21, + 23, 21, 24, 19, 17, 13, 12, 15, 6, 2, -5, + -9, -13, -10, -5, 1, 10, 13, 17, 13, 8, 5, + 5, 6, 5, 13, 19, 16, 14, 12, 7, 15, 18, + 19, 16, 4, -1, 0, -1, -2, -9, -15, -19, -21, + -13, -13, -10, -7, -7, -7, -6, -11, -22, -18, -19, + -22, -22, -19, -18, -10, -7, -9, -7, -12, -16, -20, + -27, -35, -37, -37, -33, -24, -14, -4, 8, 14, 19, 
+ 19, 16, 12, 6, 2, -5, -6, -11, -17, -16, -14, + -13, -12, -17, -21, -22, -24, -18, -14, -12, -1, 4, + 9, 17, 14, 9, 13, 14, 13, 14, 14, 12, 11, + 15, 11, 16, 21, 20, 20, 22, 31, 30, 26, 15, + 13, 6, 8, 5, 1, -5, -3, 2, 9, 14, 13, + 16, 17, 18, 13, 10, 8, 7, 9, 12, 21, 23, + 23, 21, 19, 16, 14, 5, -4, -12, -15, -16, -12, + -9, -12, -14, -17, -16, -15, -14, -15, -28, -27, -24, + -12, -8, -3, 3, 9, 15, 18, 25, 25, 31, 32, + 35, 36, 33, 36, 24, 13, 2, -11, -19, -18, -18, + -10, -6, -4, 0, -3, -3, -15, -18, -17, -9, -7, + 2, 5, 7, 6, 2, -2, -12, -16, -16, -9, -3, + 6, 8, 15, 17, 16, 18, 11, 5, -4, -8, -17, + -16, -22, -24, -25, -28, -23, -19, -11, -3, 5, 11, + 22, 26, 29, 24, 14, 12, 7, 6, -2, -1, 2, + 10, 23, 33, 36, 32, 31, 16, 3, -4, -3, -3, + 1, 8, 11, 13, 12, 8, 3, 5, 3, 1, -1, + 4, 2, 3, 8, 5, 5, 1, -2, -1, -3, -1, + 5, 8, 10, 17, 17, 15, 19, 27, 18, 21, 23, + 19, 20, 15, 1, -7, -18, -24, -24, -33, -28, -32, + -30, -30, -30, -30, -29, -30, -41, -43, -50, -51, -49, + -42, -32, -19, -10, 0, 4, -2, 5, 9, 8, 12, + 19, 17, 10, 9, 3, 1, -4, -8, -4, 0, 5, + 7, 10, 9, 12, 0, -6, -7, -13, -16, -10, -10, + -9, -1, -1, -2, -6, -11, -14, -17, -18, -10, -3, + -3, 0, 6, 1, 6, 4, 3, 3, 9, 16, 22, + 28, 27, 32, 18, 21, 25, 20, 21, 18, 18, 22, + 23, 15, 8, -3, -9, -10, -13, -8, 3, 7, 18, + 26, 23, 26, 30, 17, 11, 9, -1, 0, 2, 2, + 12, 15, 6, 1, 0, -5, 2, 1, -3, -1, -6, + -2, -4, -11, -18, -30, -38, -36, -33, -32, -27, -19, + -18, -14, -13, -16, -11, -12, -12, -4, 0, 7, 13, + 13, 10, 11, 6, 3, 3, 3, 4, 10, 4, -1, + -3, -11, -21, -27, -34, -33, -31, -33, -28, -22, -21, + -14, -8, -13, -10, -8, -12, -7, -11, -3, 3, 5, + 7, 7, -1, -12, -13, -17, -21, -8, -2, 4, 7, + 13, 18, 18, 16, 15, 13, 11, 15, 13, 12, 17, + 18, 15, 15, 11, -3, -1, 2, 11, 15, 10, 18, + 13, 10, 12, 9, 2, 2, 4, -1, 6, 9, 11, + 5, 7, 13, 8, 9, 10, 11, 9, 7, 11, 5, + 3, 1, -9, -19, -31, -40, -42, -33, -27, -24, -22, + -20, -25, -20, -12, -17, -23, -23, -25, -25, -20, -18, + -17, -19, -15, -22, -20, -19, -13, -8, -12, 
0, 2, + -6, -1, -5, -15, -10, -12, -19, -8, -6, -3, 9, + 5, 12, 22, 10, 9, 12, 5, 8, 28, 13, 20, + 25, 11, 16, 19, 10, 15, 14, 6, 23, 19, 18, + 32, 17, 12, 19, -1, -8, 11, -4, -8, 9, -4, + -6, 0, -10, -7, -3, -8, -11, -11, -23, -7, -4, + -4, 14, 6, 4, 9, 3, -4, 4, 2, 9, 26, + 19, 26, 33, 22, 22, 24, 13, 20, 18, 18, 28, + 28, 19, 24, 16, -1, 1, -12, -34, -28, -25, -27, + -13, 6, 8, 21, 25, 22, 19, 3, 4, 0, -5, + 6, 8, 1, 6, 8, -4, -3, -10, -23, -17, -9, + -10, 3, 6, -1, 3, -10, -22, -28, -49, -49, -36, + -29, -10, 8, -1, 4, 14, -3, -14, -5, -16, -10, + 8, 7, 21, 24, 17, 25, 15, -4, 13, -7, -23, + 0, -7, -14, 12, 1, -18, -10, -27, -43, -31, -34, + -19, -3, -10, 15, 20, -7, 10, 9, -20, 7, 28, + 14, 42, 54, 32, 34, 24, 5, 10, -11, -13, 11, + -6, -4, 31, 7, 0, 34, 3, -9, 5, -24, -33, + -14, -11, -1, 8, 0, 10, 7, -7, 11, 10, -6, + 17, 16, 0, 10, 3, -26, -23, -33, -39, -26, -29, + -18, -6, -9, -1, 5, -11, -6, 7, -6, 1, 13, + 8, 1, 3, -13, -23, -25, -33, -28, -21, -9, 2, + 4, 1, 8, 4, -13, -5, -12, -14, 3, 14, 18, + 26, 30, 21, 20, 15, 15, 10, 5, 13, 11, 20, + 25, 29, 18, 19, 9, -10, -15, -13, -12, 1, 16, + 20, 30, 39, 37, 21, 15, 3, -7, -9, -1, 2, + -6, -7, -10, -20, -19, -19, -31, -25, -12, -15, -13, + -17, -18, -14, -24, -24, -18, -28, -24, -3, 1, 17, + 46, 48, 43, 46, 34, 12, 6, -14, -19, -10, -14, + 3, 15, 3, 7, 7, -13, 4, 9, -2, 3, 22, + 19, 25, 41, 48, 46, 36, 42, 40, 24, 33, 50, + 29, 30, 57, 35, 13, 29, 17, -9, 5, 15, 7, + 13, 38, 47, 40, 56, 72, 42, 29, 40, 18, 14, + 36, 52, 50, 58, 55, 42, 22, 20, 13, -8, 8, + 32, 26, 41, 70, 48, 51, 65, 36, 27, 23, 4, + 5, 1, -3, 2, -8, -23, -6, -30, -46, -24, -40, + -45, -22, -32, -35, -24, -50, -41, -35, -56, -38, -29, + -55, -25, -7, -40, -26, -25, -63, -51, -40, -61, -47, + -38, -38, -5, 2, 3, 26, -1, -7, 8, -20, -17, + 10, -14, -6, 41, 24, 27, 52, 26, 13, 25, 5, + -6, 2, -7, -2, 10, 4, 29, 36, 30, 74, 93, + 91, 131, 150, 132, 167, 177, 158, 189, 188, 178, 200, + 199, 187, 212, 202, 188, 210, 188, 173, 187, 175, 183, + 215, 
218, 236, 264, 253, 279, 296, 275, 290, 288, 261, + 261, 261, 230, 216, 199, 157, 160, 147, 115, 108, 84, + 50, 32, 7, -30, -56, -96, -130, -146, -179, -199, -223, + -255, -280, -293, -326, -341, -352, -391, -410, -429, -464, -489, + -507, -538, -559, -577, -602, -634, -656, -679, -696, -702, -700, + -699, -700, -687, -666, -665, -656, -634, -626, -609, -572, -539, + -518, -484, -462, -444, -418, -390, -364, -336, -295, -245, -210, + -175, -127, -97, -63, -28, 10, 45, 83, 121, 167, 222, + 272, 324, 369, 396, 439, 485, 502, 536, 571, 585, 618, + 656, 676, 705, 729, 744, 767, 776, 786, 798, 796, 813, + 849, 855, 865, 883, 862, 843, 834, 794, 781, 778, 767, + 746, 744, 721, 702, 681, 638, 607, 562, 521, 490, 447, + 398, 361, 313, 255, 204, 123, 20, -59, -143, -217, -270, + -328, -400, -462, -529, -607, -666, -737, -797, -854, -906, -936, + -944, -955, -965, -976, -993, -1003, -1007, -1032, -1040, -1045, -1055, + -1039, -1016, -1003, -990, -995, -1026, -1046, -1070, -1079, -1058, -1060, + -1062, -1028, -1010, -1006, -991, -1000, -1004, -987, -981, -958, -921, + -890, -852, -798, -754, -713, -681, -682, -658, -617, -585, -524, + -452, -404, -332, -258, -224, -183, -144, -132, -94, -64, -31, + 37, 99, 147, 219, 280, 329, 389, 439, 483, 563, 632, + 702, 799, 884, 965, 1050, 1107, 1150, 1209, 1260, 1308, 1383, + 1446, 1514, 1582, 1632, 1679, 1727, 1770, 1804, 1837, 1872, 1916, + 1961, 1999, 2038, 2071, 2089, 2097, 2107, 2091, 2084, 2072, 2051, + 2021, 1998, 1940, 1868, 1814, 1734, 1641, 1559, 1480, 1395, 1305, + 1213, 1115, 1015, 901, 785, 667, 520, 381, 256, 110, -26, + -141, -284, -417, -528, -670, -805, -935, -1080, -1206, -1324, -1438, + -1527, -1622, -1725, -1798, -1879, -1956, -2006, -2063, -2128, -2166, -2201, + -2238, -2257, -2292, -2316, -2337, -2357, -2356, -2362, -2382, -2375, -2368, + -2367, -2358, -2337, -2329, -2318, -2296, -2273, -2240, -2195, -2140, -2095, + -2044, -1990, -1932, -1872, -1803, -1737, -1673, -1602, -1520, -1428, -1325, + -1219, -1112, -1006, 
-896, -780, -681, -591, -481, -388, -294, -189, + -85, 30, 148, 252, 348, 466, 579, 692, 811, 918, 1041, + 1162, 1271, 1389, 1507, 1611, 1735, 1864, 1965, 2085, 2203, 2312, + 2436, 2536, 2614, 2697, 2760, 2812, 2886, 2956, 3010, 3066, 3088, + 3098, 3120, 3110, 3101, 3106, 3108, 3130, 3149, 3139, 3122, 3085, + 3016, 2951, 2874, 2770, 2671, 2559, 2435, 2315, 2198, 2059, 1915, + 1761, 1570, 1387, 1185, 984, 787, 601, 413, 224, 40, -158, + -348, -560, -760, -960, -1147, -1312, -1471, -1621, -1779, -1925, -2069, + -2206, -2333, -2463, -2570, -2664, -2743, -2811, -2860, -2886, -2934, -2976, + -3015, -3057, -3074, -3076, -3079, -3060, -3032, -2998, -2950, -2920, -2893, + -2863, -2837, -2806, -2761, -2715, -2662, -2607, -2554, -2486, -2402, -2325, + -2264, -2190, -2127, -2063, -1989, -1932, -1862, -1788, -1724, -1640, -1545, + -1455, -1346, -1234, -1112, -984, -859, -735, -610, -494, -384, -280, + -176, -68, 40, 140, 244, 363, 478, 596, 739, 876, 1001, + 1128, 1240, 1352, 1474, 1595, 1717, 1853, 1972, 2093, 2215, 2328, + 2432, 2533, 2641, 2744, 2855, 2949, 3055, 3157, 3242, 3329, 3415, + 3479, 3528, 3569, 3588, 3617, 3649, 3676, 3708, 3747, 3751, 3753, + 3744, 3693, 3640, 3576, 3470, 3369, 3248, 3098, 2976, 2838, 2690, + 2557, 2395, 2222, 2055, 1872, 1675, 1488, 1279, 1057, 851, 623, + 393, 180, -74, -315, -537, -771, -979, -1161, -1373, -1558, -1729, + -1932, -2110, -2294, -2478, -2636, -2785, -2917, -3007, -3094, -3183, -3247, + -3319, -3402, -3450, -3510, -3564, -3595, -3622, -3635, -3627, -3635, -3639, + -3620, -3620, -3610, -3596, -3581, -3535, -3495, -3455, -3410, -3361, -3323, + -3265, -3202, -3141, -3078, -3001, -2919, -2830, -2739, -2640, -2540, -2430, + -2320, -2192, -2057, -1909, -1761, -1603, -1422, -1244, -1059, -887, -726, + -570, -425, -256, -92, 69, 238, 411, 557, 728, 910, 1066, + 1229, 1403, 1561, 1727, 1895, 2050, 2208, 2352, 2492, 2638, 2765, + 2893, 3025, 3145, 3263, 3387, 3496, 3595, 3707, 3804, 3884, 3975, + 4046, 4105, 4167, 4204, 4220, 4237, 4243, 
4247, 4260, 4255, 4251, + 4246, 4201, 4143, 4092, 3996, 3885, 3772, 3604, 3435, 3283, 3086, + 2923, 2742, 2535, 2341, 2130, 1887, 1649, 1411, 1137, 915, 659, + 398, 163, -81, -351, -580, -814, -1069, -1262, -1476, -1689, -1850, + -2043, -2237, -2395, -2591, -2763, -2918, -3095, -3224, -3319, -3435, -3508, + -3582, -3698, -3772, -3858, -3950, -4008, -4047, -4088, -4093, -4085, -4098, + -4064, -4052, -4057, -4033, -4028, -4018, -3991, -3971, -3933, -3865, -3802, + -3727, -3633, -3562, -3477, -3392, -3300, -3210, -3115, -3018, -2924, -2819, + -2721, -2606, -2490, -2381, -2246, -2111, -1963, -1810, -1638, -1460, -1293, + -1132, -980, -828, -666, -496, -322, -125, 72, 264, 470, 676, + 879, 1087, 1280, 1457, 1633, 1799, 1970, 2152, 2327, 2501, 2678, + 2840, 3007, 3165, 3301, 3434, 3558, 3667, 3791, 3912, 4023, 4140, + 4257, 4359, 4475, 4554, 4614, 4656, 4682, 4697, 4726, 4749, 4775, + 4810, 4812, 4812, 4810, 4768, 4697, 4620, 4502, 4368, 4210, 4031, + 3860, 3663, 3472, 3291, 3076, 2849, 2642, 2392, 2140, 1890, 1610, + 1325, 1064, 782, 494, 231, -50, -329, -593, -861, -1112, -1345, + -1588, -1812, -2022, -2257, -2467, -2682, -2924, -3126, -3317, -3495, -3630, + -3737, -3855, -3941, -4031, -4128, -4200, -4281, -4348, -4388, -4427, -4449, + -4444, -4450, -4458, -4452, -4464, -4460, -4451, -4444, -4425, -4384, -4344, + -4289, -4234, -4160, -4076, -4000, -3917, -3837, -3753, -3669, -3558, -3460, + -3354, -3230, -3111, -2966, -2824, -2665, -2495, -2333, -2151, -1951, -1752, + -1554, -1367, -1222, -1053, -882, -716, -520, -331, -141, 62, 270, + 476, 707, 923, 1133, 1349, 1534, 1735, 1943, 2124, 2317, 2511, + 2668, 2839, 3002, 3140, 3317, 3481, 3615, 3771, 3920, 4050, 4196, + 4319, 4430, 4556, 4657, 4765, 4868, 4945, 4999, 5057, 5075, 5100, + 5123, 5133, 5134, 5127, 5104, 5084, 5058, 4968, 4896, 4750, 4575, + 4381, 4179, 3971, 3776, 3590, 3394, 3209, 2991, 2800, 2535, 2269, + 1972, 1654, 1319, 998, 697, 384, 105, -187, -476, -759, -1047, + -1316, -1579, -1841, -2085, -2317, 
-2550, -2745, -2938, -3145, -3326, -3523, + -3706, -3859, -3998, -4124, -4218, -4288, -4346, -4386, -4437, -4495, -4550, + -4619, -4680, -4732, -4779, -4813, -4820, -4842, -4825, -4791, -4773, -4742, + -4715, -4709, -4683, -4652, -4605, -4527, -4428, -4315, -4194, -4086, -3978, + -3872, -3779, -3685, -3569, -3458, -3313, -3121, -2921, -2693, -2454, -2230, + -1998, -1783, -1588, -1414, -1240, -1069, -886, -690, -473, -256, -36, + 170, 384, 594, 797, 1015, 1235, 1449, 1664, 1882, 2098, 2311, + 2504, 2681, 2843, 3019, 3171, 3337, 3534, 3709, 3885, 4072, 4235, + 4380, 4524, 4641, 4746, 4864, 4979, 5087, 5213, 5308, 5393, 5450, + 5468, 5475, 5472, 5452, 5462, 5467, 5453, 5451, 5425, 5342, 5255, + 5113, 4914, 4725, 4512, 4273, 4053, 3866, 3632, 3436, 3205, 2955, + 2705, 2420, 2095, 1794, 1503, 1195, 941, 639, 342, 56, -269, + -601, -894, -1208, -1499, -1736, -1994, -2239, -2426, -2652, -2891, -3099, + -3361, -3588, -3793, -4013, -4183, -4302, -4439, -4523, -4613, -4734, -4809, + -4891, -4999, -5056, -5090, -5131, -5092, -5061, -5044, -4987, -4954, -4955, + -4924, -4911, -4873, -4809, -4755, -4673, -4555, -4440, -4316, -4187, -4088, + -3986, -3881, -3802, -3717, -3605, -3495, -3359, -3207, -3063, -2889, -2698, + -2504, -2306, -2088, -1861, -1627, -1415, -1201, -1000, -799, -593, -410, + -220, -7, 203, 412, 634, 865, 1126, 1367, 1602, 1838, 2052, + 2257, 2474, 2659, 2863, 3076, 3255, 3429, 3617, 3773, 3939, 4102, + 4222, 4358, 4501, 4611, 4733, 4846, 4939, 5056, 5147, 5217, 5301, + 5357, 5388, 5428, 5417, 5400, 5430, 5422, 5406, 5442, 5446, 5431, + 5437, 5381, 5304, 5212, 5057, 4874, 4683, 4465, 4249, 4026, 3767, + 3545, 3304, 3021, 2741, 2450, 2113, 1807, 1490, 1151, 841, 544, + 212, -102, -439, -788, -1091, -1413, -1730, -2033, -2336, -2627, -2854, + -3118, -3350, -3560, -3781, -4008, -4194, -4376, -4524, -4640, -4757, -4865, + -4945, -5016, -5083, -5131, -5170, -5184, -5198, -5208, -5211, -5210, -5209, + -5192, -5174, -5154, -5108, -5052, -5002, -4932, -4854, -4780, 
-4704, -4604, + -4514, -4421, -4309, -4208, -4111, -4004, -3880, -3751, -3622, -3496, -3367, + -3210, -3047, -2867, -2654, -2430, -2177, -1897, -1651, -1417, -1182, -983, + -793, -593, -406, -211, 17, 232, 461, 716, 958, 1197, 1441, + 1674, 1899, 2130, 2355, 2573, 2788, 3004, 3220, 3419, 3612, 3809, + 3973, 4120, 4277, 4433, 4573, 4742, 4902, 5037, 5165, 5282, 5377, + 5460, 5539, 5596, 5654, 5716, 5741, 5759, 5770, 5776, 5762, 5751, + 5737, 5706, 5675, 5644, 5550, 5446, 5324, 5169, 4974, 4767, 4530, + 4289, 4067, 3823, 3621, 3391, 3145, 2878, 2575, 2228, 1890, 1525, + 1149, 807, 473, 145, -152, -454, -769, -1057, -1374, -1703, -2033, + -2372, -2701, -2977, -3258, -3495, -3694, -3897, -4089, -4270, -4483, -4668, + -4840, -5015, -5140, -5225, -5304, -5334, -5350, -5390, -5398, -5403, -5428, + -5438, -5449, -5472, -5463, -5441, -5401, -5333, -5252, -5151, -5051, -4974, + -4880, -4805, -4729, -4626, -4526, -4403, -4248, -4088, -3939, -3778, -3617, + -3464, -3308, -3173, -3027, -2852, -2669, -2461, -2233, -1979, -1713, -1455, + -1216, -996, -796, -610, -397, -198, 21, 272, 517, 775, 1037, + 1295, 1544, 1790, 2007, 2211, 2423, 2634, 2848, 3081, 3319, 3551, + 3792, 4000, 4171, 4303, 4418, 4518, 4596, 4679, 4807, 4913, 5044, + 5172, 5288, 5405, 5518, 5609, 5664, 5713, 5735, 5735, 5737, 5701, + 5691, 5656, 5633, 5611, 5552, 5475, 5394, 5293, 5177, 5064, 4924, + 4737, 4599, 4420, 4237, 4048, 3828, 3623, 3413, 3183, 2915, 2622, + 2308, 1980, 1657, 1261, 901, 549, 205, -85, -383, -688, -969, + -1246, -1530, -1850, -2206, -2561, -2915, -3224, -3482, -3713, -3921, -4107, + -4287, -4470, -4660, -4850, -5057, -5239, -5395, -5540, -5619, -5697, -5724, + -5697, -5675, -5633, -5590, -5579, -5530, -5486, -5442, -5426, -5391, -5348, + -5276, -5197, -5124, -5039, -4925, -4808, -4677, -4581, -4479, -4343, -4218, + -4087, -3970, -3858, -3729, -3570, -3384, -3206, -3020, -2839, -2636, -2453, + -2287, -2185, -2154, -1926, -1562, -1223, -758, -473, -64, 395, 599, + 880, 814, 938, 1172, 1498, 
1928, 2127, 2422, 2608, 2841, 2937, + 2886, 2815, 2985, 3324, 3757, 4152, 4481, 4652, 4917, 4965, 4766, + 4583, 4328, 4503, 4815, 5118, 5408, 5682, 5956, 6082, 6055, 5744, + 5426, 5341, 5427, 5606, 5882, 6065, 6226, 6428, 6477, 6385, 6009, + 5728, 5552, 5439, 5339, 5200, 5008, 4947, 4835, 4614, 4330, 3887, + 3521, 3111, 2460, 1983, 1297, 650, 279, -353, -720, -1044, -1518, + -1668, -2117, -2496, -2743, -3266, -3607, -3790, -4149, -4075, -4042, -4096, + -3981, -4138, -4226, -4214, -4503, -4455, -4577, -4642, -4346, -4351, -4270, + -4263, -4522, -4521, -4673, -4814, -4731, -4950, -5011, -5004, -5288, -5341, + -5566, -5833, -5783, -5929, -5847, -5765, -5828, -5644, -5613, -5615, -5428, + -5291, -5014, -4554, -4277, -3964, -3854, -3829, -3612, -3603, -3438, -3137, + -2831, -2164, -1438, -939, -330, -156, 46, 242, 73, 242, 220, + 239, 542, 565, 739, 872, 801, 857, 676, 543, 586, 567, + 828, 1142, 1490, 1985, 2508, 2982, 3438, 3699, 3939, 4069, 4178, + 4420, 4622, 4917, 5338, 5801, 6285, 6658, 6963, 7213, 7233, 7328, + 7176, 7038, 7031, 6860, 6957, 6767, 6599, 6523, 6212, 6147, 6063, + 5860, 6020, 6015, 6033, 6184, 5722, 5607, 5016, 4337, 4063, 3229, + 3080, 3006, 2804, 3035, 2541, 2136, 1879, 1012, 401, -575, -1584, + -1930, -2278, -2485, -2477, -2712, -2747, -2766, -3320, -3592, -4188, -4669, + -4672, -4939, -4789, -4426, -4203, -3674, -3563, -3656, -3759, -4067, -4257, + -4522, -4970, -5204, -5237, -5139, -4907, -4911, -4917, -4921, -5007, -5230, + -5654, -6122, -6464, -6733, -6948, -7067, -6972, -6800, -6520, -6132, -5830, + -5382, -5091, -4797, -4546, -4472, -4362, -4350, -4235, -3851, -3454, -3144, + -2735, -2341, -1845, -1262, -958, -549, -166, 66, 382, 366, 352, + 341, 85, -13, -176, -303, -235, -341, -309, -227, -249, -50, + 143, 384, 874, 1149, 1552, 2155, 2767, 3499, 3994, 4460, 4920, + 5288, 5569, 5704, 5881, 6094, 6461, 6653, 6803, 7115, 7311, 7521, + 7612, 7443, 7380, 7124, 6742, 6495, 5964, 5656, 5415, 5167, 5656, + 5813, 6027, 6401, 6351, 6787, 7019, 
6581, 6512, 5965, 5308, 5140, + 4336, 4147, 3899, 3398, 3360, 2830, 2624, 1968, 1026, 395, -699, + -1424, -2327, -3006, -3192, -3435, -3337, -3686, -3513, -3350, -3502, -3261, + -3878, -4005, -4063, -4187, -3767, -3598, -3384, -3300, -3094, -2857, -3023, + -3274, -3851, -4352, -4523, -4943, -5477, -5612, -5682, -5733, -5714, -5965, + -6110, -5950, -6158, -6548, -6897, -7165, -7281, -7352, -7258, -7185, -6659, + -5946, -5470, -4738, -4046, -3707, -3210, -3108, -3270, -3227, -3222, -3218, + -3017, -2943, -2668, -2296, -1593, -1061, -811, -403, -513, -361, -128, + -595, -633, -991, -1205, -1159, -1284, -1330, -1164, -999, -729, -538, + -336, 27, 350, 794, 1245, 1646, 2446, 3210, 4017, 4835, 5271, + 5739, 6028, 6140, 6212, 6161, 6066, 5984, 6081, 5995, 6152, 6301, + 6278, 6424, 6377, 6396, 6362, 6152, 5788, 5309, 5071, 4860, 4704, + 4804, 4919, 5258, 5869, 6121, 6365, 6694, 6692, 6694, 6532, 6187, + 5808, 5704, 5302, 4816, 4611, 4043, 3775, 3249, 2600, 1933, 982, + 336, -848, -1538, -2242, -3103, -3374, -3756, -3975, -4017, -4061, -3972, + -3749, -3609, -3853, -3850, -3714, -3760, -3736, -3914, -3923, -3830, -3541, + -3649, -3757, -3661, -3913, -4038, -4231, -4594, -4769, -5009, -5273, -5588, + -5676, -5937, -5997, -6060, -6164, -6414, -6623, -6765, -6857, -6771, -6921, + -6914, -6535, -6187, -5626, -5206, -4742, -4189, -3618, -3120, -2823, -2606, + -2550, -2703, -2736, -2626, -2498, -2406, -2133, -1852, -1348, -753, -318, + 162, 330, 524, 375, 9, -204, -866, -1249, -1532, -1669, -1455, + -1235, -723, -283, 262, 535, 862, 1340, 1712, 2316, 2625, 3171, + 4015, 4698, 5516, 6006, 6452, 6838, 6921, 7003, 6735, 6339, 6138, + 5768, 5575, 5593, 5568, 5728, 6041, 6233, 6260, 6175, 6048, 5728, + 5366, 4931, 4340, 4194, 4174, 4330, 4743, 5028, 5754, 6250, 6598, + 7120, 7114, 6962, 6675, 6157, 5373, 4797, 4081, 3237, 3153, 2588, + 2143, 1639, 1021, 681, -149, -816, -1987, -3003, -3493, -4138, -4420, + -4607, -4841, -4725, -4254, -4033, -3845, -3842, -4063, -4035, -4099, -4582, + 
-4718, -4779, -4689, -4437, -4327, -4352, -4119, -3881, -4061, -4345, -4768, + -5248, -5610, -5920, -6383, -6779, -6731, -6673, -6677, -6597, -6659, -6619, + -6417, -6516, -6862, -7017, -7069, -6944, -6715, -6376, -6000, -5162, -4333, + -3577, -2884, -2355, -1807, -1366, -1380, -1590, -1869, -1962, -1945, -2006, + -2141, -1960, -1516, -1025, -471, -135, 85, 348, 239, -8, -475, + -951, -1245, -1520, -1569, -1448, -1188, -517, 134, 827, 1585, 2114, + 2792, 3214, 3651, 4230, 4546, 4894, 5321, 5588, 6105, 6583, 6877, + 7014, 7087, 7068, 6876, 6695, 6280, 5684, 5385, 5205, 5064, 5033, + 5028, 5080, 5322, 5510, 5461, 5390, 5541, 5494, 5443, 5306, 5065, + 5193, 5338, 5513, 5818, 5911, 6345, 6506, 6514, 6543, 5981, 5703, + 5082, 4228, 3517, 2424, 1880, 1245, 562, -130, -864, -1156, -1561, + -1970, -2597, -3357, -3707, -4189, -4521, -4975, -5477, -5478, -5585, -5445, + -5353, -5327, -4971, -4580, -4431, -4469, -4432, -4422, -4275, -4227, -4507, + -4745, -4758, -4752, -4845, -4933, -5118, -5117, -5124, -5324, -5673, -5971, + -6152, -6366, -6702, -6970, -7159, -7136, -6929, -6917, -6703, -6520, -6302, + -5794, -5484, -5123, -4694, -4254, -3722, -3334, -2917, -2410, -1721, -1010, + -584, -312, 27, 321, 327, 214, -17, -363, -402, -550, -638, + -469, -315, -86, 142, 242, 387, 448, 458, 423, 321, 194, + 285, 417, 717, 1176, 1673, 2402, 3144, 3985, 4764, 5406, 6056, + 6507, 6783, 6891, 6868, 6850, 6717, 6532, 6359, 6248, 6303, 6279, + 6140, 6071, 5927, 5687, 5480, 5146, 4835, 4572, 4447, 4481, 4578, + 4840, 4936, 5246, 5659, 5732, 5856, 5658, 5403, 5282, 5004, 4949, + 4843, 4681, 4884, 4886, 4967, 5108, 4781, 4647, 4240, 3443, 2768, + 1830, 983, 309, -769, -1382, -1987, -2553, -2750, -3346, -3555, -4052, + -4400, -4599, -5196, -5437, -5945, -6340, -6343, -6554, -6611, -6381, -6184, + -5681, -5398, -5098, -4751, -4529, -4138, -4100, -4088, -4044, -4186, -4189, + -4263, -4453, -4465, -4598, -4651, -4726, -4919, -4926, -5142, -5286, -5490, + -5831, -6002, -6341, -6492, -6562, -6710, 
-6553, -6506, -6219, -5766, -5521, + -5008, -4556, -4002, -3293, -2769, -2069, -1467, -824, -34, 509, 1034, + 1385, 1560, 1650, 1664, 1419, 1016, 834, 511, 353, 381, 299, + 523, 833, 956, 1280, 1492, 1425, 1547, 1350, 1143, 1114, 931, + 1054, 1217, 1583, 2217, 2917, 4017, 4965, 5827, 6816, 7393, 7875, + 8197, 8175, 7924, 7578, 7040, 6566, 6242, 5746, 5530, 5334, 5222, + 5237, 5074, 5146, 5011, 4902, 4753, 4442, 4482, 4254, 4247, 4319, + 4187, 4516, 4690, 4935, 5193, 5229, 5350, 5332, 5486, 5386, 5143, + 4999, 4494, 4304, 3961, 3421, 2781, 2032, 1404, 614, -88, -956, + -1714, -2155, -2684, -3038, -3237, -3368, -3423, -3569, -3809, -4213, -4533, + -4973, -5514, -6011, -6663, -7084, -7258, -7158, -6947, -6639, -6111, -5548, + -4887, -4362, -4043, -3895, -3940, -4107, -4452, -4836, -5143, -5500, -5532, + -5510, -5485, -5096, -4739, -4375, -4065, -4063, -4094, -4252, -4576, -4904, + -5431, -5837, -6190, -6402, -6310, -6292, -5992, -5516, -5025, -4342, -3899, + -3386, -2697, -2077, -1493, -994, -392, 232, 931, 1608, 1988, 2360, + 2589, 2639, 2623, 2471, 2121, 1708, 1478, 1181, 1167, 1296, 1279, + 1648, 1859, 2107, 2368, 2359, 2390, 2122, 1904, 1629, 1418, 1502, + 1524, 1859, 2357, 3041, 3909, 4810, 5751, 6449, 7128, 7534, 7767, + 7908, 7699, 7460, 7032, 6647, 6301, 5876, 5556, 5190, 4948, 4762, + 4576, 4464, 4370, 4338, 4275, 4287, 4265, 4320, 4221, 4066, 3947, + 3514, 3379, 3003, 2635, 2534, 2078, 2040, 1950, 1958, 2152, 2085, + 2390, 2321, 2319, 2359, 1851, 1643, 877, 168, -527, -1245, -1704, + -2519, -2739, -3251, -3382, -3236, -3527, -3294, -3523, -3732, -3916, -4434, + -4888, -5615, -6161, -6729, -7283, -7543, -7920, -7865, -7660, -7430, -7034, + -6758, -6224, -5866, -5441, -5076, -4998, -4760, -4673, -4539, -4410, -4308, + -4131, -3992, -3791, -3611, -3448, -3213, -3070, -3046, -3048, -3168, -3244, + -3354, -3607, -3834, -4170, -4439, -4648, -4864, -4892, -4928, -4821, -4524, + -4211, -3576, -2819, -1968, -929, -19, 1029, 2064, 2949, 3716, 4159, + 4450, 4536, 4503, 
4301, 3968, 3655, 3242, 2979, 2856, 2744, 2750, + 2771, 2749, 2859, 2850, 2793, 2702, 2402, 2179, 1877, 1672, 1581, + 1543, 1769, 1967, 2485, 3089, 3783, 4662, 5406, 6246, 6950, 7542, + 8016, 8200, 8245, 8027, 7584, 6958, 6241, 5494, 4710, 3974, 3255, + 2653, 2274, 2038, 1986, 1964, 2141, 2321, 2513, 2772, 2756, 2743, + 2636, 2406, 2125, 1836, 1456, 1247, 1145, 995, 1077, 1140, 1290, + 1561, 1685, 1762, 1609, 1391, 1147, 544, 84, -754, -1546, -2107, + -2806, -3137, -3522, -3732, -3826, -3834, -3609, -3493, -3340, -3254, -3499, + -3621, -3981, -4455, -4859, -5513, -6080, -6626, -7061, -7372, -7556, -7573, + -7515, -7366, -7091, -6799, -6366, -5887, -5484, -5098, -4746, -4334, -3941, + -3558, -3269, -3053, -2844, -2663, -2497, -2314, -2227, -2185, -2141, -2139, + -2070, -2037, -2031, -2062, -2205, -2348, -2544, -2774, -2979, -3298, -3520, + -3647, -3622, -3395, -3054, -2513, -1829, -948, 64, 1090, 2169, 3127, + 3987, 4712, 5229, 5560, 5754, 5741, 5619, 5401, 5005, 4666, 4287, + 3967, 3734, 3476, 3322, 3203, 3147, 3144, 3116, 3080, 3011, 2871, + 2735, 2544, 2363, 2245, 2075, 2032, 2118, 2263, 2688, 3066, 3605, + 4244, 4746, 5384, 5819, 6151, 6319, 6194, 5938, 5495, 4929, 4305, + 3581, 2924, 2279, 1713, 1372, 1086, 1006, 983, 1006, 1146, 1249, + 1349, 1360, 1231, 1084, 794, 502, 264, -85, -238, -411, -504, + -394, -322, -51, 188, 420, 589, 624, 666, 573, 338, -86, + -564, -1056, -1560, -1925, -2434, -2806, -3017, -3341, -3320, -3375, -3480, + -3410, -3567, -3553, -3595, -3805, -3919, -4284, -4482, -4754, -5190, -5354, + -5806, -6050, -6136, -6387, -6343, -6330, -6206, -5851, -5468, -4960, -4549, + -4080, -3542, -3150, -2698, -2440, -2318, -2132, -2067, -2081, -2017, -2099, + -2151, -2060, -2067, -1916, -1823, -1718, -1523, -1386, -1221, -1189, -1141, + -1014, -1008, -966, -996, -1015, -916, -809, -648, -467, -128, 237, + 735, 1358, 1969, 2697, 3399, 4060, 4732, 5295, 5720, 6077, 6169, + 6139, 5928, 5614, 5292, 4766, 4247, 3705, 3262, 3030, 2827, 2702, + 2684, 2728, 
2887, 3092, 3216, 3310, 3313, 3214, 3098, 2873, 2620, + 2343, 2031, 1799, 1589, 1491, 1537, 1645, 1913, 2210, 2548, 2922, + 3295, 3650, 3951, 4100, 4099, 3972, 3740, 3421, 2948, 2427, 1762, + 1136, 574, 44, -330, -642, -846, -852, -751, -520, -229, 44, + 272, 446, 502, 443, 329, 66, -191, -492, -841, -1002, -1240, + -1237, -1199, -1177, -936, -867, -660, -456, -508, -464, -706, -997, + -1265, -1780, -2178, -2724, -3270, -3735, -4142, -4378, -4609, -4666, -4749, + -4575, -4355, -4137, -3767, -3563, -3218, -2970, -2834, -2630, -2716, -2776, + -2920, -3210, -3363, -3764, -4023, -4125, -4268, -4194, -4223, -4005, -3639, + -3258, -2891, -2644, -2297, -1987, -1751, -1587, -1570, -1485, -1415, -1342, + -1194, -1100, -889, -613, -267, 161, 482, 865, 1269, 1639, 2005, + 2202, 2381, 2549, 2628, 2700, 2625, 2559, 2481, 2357, 2319, 2192, + 2142, 2199, 2283, 2514, 2670, 2919, 3214, 3510, 3830, 3971, 4080, + 4073, 3911, 3700, 3359, 2954, 2549, 2094, 1766, 1556, 1442, 1462, + 1560, 1808, 2070, 2357, 2606, 2730, 2831, 2737, 2582, 2309, 1931, + 1585, 1178, 834, 529, 288, 214, 218, 302, 470, 679, 944, + 1211, 1420, 1562, 1674, 1631, 1548, 1355, 1072, 776, 375, 25, + -320, -614, -818, -992, -991, -906, -755, -525, -291, -17, 225, + 447, 528, 546, 466, 270, 96, -205, -536, -861, -1148, -1383, + -1586, -1688, -1814, -1783, -1772, -1745, -1630, -1611, -1505, -1488, -1462, + -1409, -1519, -1489, -1609, -1723, -1755, -1977, -2042, -2132, -2215, -2184, + -2268, -2205, -2170, -2107, -1978, -1990, -1909, -1886, -1943, -1997, -2152, + -2326, -2500, -2762, -2987, -3227, -3392, -3522, -3630, -3579, -3469, -3262, + -2916, -2555, -2103, -1581, -1090, -531, -20, 457, 873, 1228, 1561, + 1809, 1999, 2105, 2139, 2196, 2201, 2149, 2113, 2038, 1990, 1913, + 1787, 1705, 1595, 1490, 1372, 1201, 1113, 998, 917, 917, 894, + 961, 1007, 1098, 1321, 1470, 1681, 1882, 2067, 2317, 2465, 2626, + 2750, 2777, 2783, 2694, 2569, 2431, 2142, 1843, 1597, 1306, 1069, + 824, 622, 532, 430, 388, 357, 377, 438, 414, 481, 
468, + 431, 454, 383, 374, 305, 207, 187, 133, 157, 115, 113, + 206, 244, 382, 475, 591, 753, 821, 916, 908, 855, 754, + 577, 399, 123, -159, -399, -647, -784, -923, -1010, -965, -918, + -806, -647, -504, -355, -253, -179, -130, -138, -156, -262, -339, + -401, -552, -600, -671, -697, -662, -673, -616, -597, -522, -495, + -513, -490, -624, -701, -804, -961, -1073, -1328, -1503, -1656, -1798, + -1801, -1913, -1863, -1785, -1720, -1453, -1309, -1051, -846, -715, -487, + -457, -357, -331, -400, -427, -627, -765, -873, -1021, -1105, -1255, + -1312, -1357, -1370, -1288, -1261, -1165, -1139, -1062, -917, -808, -680, + -597, -452, -277, -104, 122, 312, 558, 771, 919, 1110, 1205, + 1312, 1355, 1302, 1280, 1151, 1049, 946, 818, 733, 569, 451, + 429, 388, 408, 387, 376, 426, 463, 542, 576, 632, 666, + 673, 740, 766, 791, 845, 829, 857, 841, 822, 835, 796, + 773, 671, 600, 560, 484, 460, 371, 311, 284, 242, 277, + 261, 261, 277, 273, 358, 380, 410, 433, 435, 471, 432, + 414, 386, 330, 294, 194, 149, 108, 69, 84, 69, 92, + 83, 75, 88, 53, 12, -96, -194, -269, -369, -438, -523, + -553, -528, -500, -392, -277, -136, 53, 240, 466, 678, 870, + 1050, 1178, 1294, 1336, 1310, 1247, 1080, 916, 677, 387, 120, + -182, -471, -740, -972, -1148, -1273, -1343, -1402, -1363, -1263, -1129, + -922, -724, -518, -288, -79, 111, 250, 364, 405, 405, 395, + 284, 199, 83, -43, -126, -244, -313, -400, -451, -497, -610, + -672, -807, -951, -1087, -1325, -1517, -1736, -1929, -2086, -2260, -2318, + -2356, -2271, -2125, -1967, -1685, -1379, -1000, -598, -238, 149, 481, + 790, 1042, 1185, 1287, 1274, 1195, 1068, 868, 654, 386, 138, + -65, -273, -450, -598, -665, -670, -669, -620, -553, -425, -288, + -179, -72, 15, 122, 205, 263, 324, 357, 435, 518, 603, + 709, 779, 892, 1006, 1107, 1170, 1183, 1190, 1173, 1116, 1016, + 890, 750, 628, 488, 331, 197, 95, 43, 25, 1, 22, + 97, 209, 363, 495, 615, 724, 833, 937, 984, 990, 933, + 884, 851, 747, 678, 573, 497, 469, 401, 391, 352, 339, + 352, 337, 354, 361, 370, 
402, 411, 418, 440, 468, 526, + 576, 619, 683, 766, 857, 965, 1038, 1114, 1159, 1172, 1167, + 1106, 1006, 840, 644, 426, 177, -110, -390, -665, -929, -1160, + -1375, -1497, -1550, -1592, -1553, -1507, -1394, -1201, -1084, -863, -685, + -540, -322, -234, -68, 29, 59, 160, 141, 170, 140, 79, + 77, -11, -53, -179, -274, -327, -480, -564, -736, -884, -995, + -1185, -1300, -1461, -1617, -1711, -1832, -1831, -1863, -1865, -1776, -1691, + -1516, -1353, -1168, -954, -729, -490, -305, -93, 81, 211, 322, + 364, 392, 384, 332, 264, 146, 29, -101, -230, -357, -486, + -616, -705, -752, -801, -809, -788, -750, -654, -546, -456, -328, + -200, -78, 45, 137, 232, 316, 388, 447, 485, 528, 578, + 630, 697, 760, 835, 910, 988, 1068, 1124, 1154, 1157, 1166, + 1163, 1116, 1070, 1024, 994, 986, 988, 1030, 1110, 1212, 1303, + 1411, 1498, 1551, 1599, 1587, 1565, 1481, 1336, 1212, 1028, 847, + 669, 466, 330, 187, 61, -9, -54, -55, -20, 11, 69, + 133, 195, 244, 253, 225, 182, 133, 62, -11, -96, -168, + -199, -214, -213, -197, -167, -127, -105, -86, -83, -109, -140, + -217, -323, -448, -588, -717, -854, -971, -1086, -1185, -1211, -1227, + -1180, -1135, -1099, -992, -918, -788, -704, -651, -562, -542, -470, + -421, -431, -391, -429, -386, -344, -336, -260, -257, -162, -61, + -6, 100, 120, 178, 215, 179, 132, 15, -106, -238, -416, + -595, -765, -929, -1066, -1170, -1252, -1278, -1290, -1258, -1173, -1114, + -1012, -945, -868, -741, -695, -612, -547, -494, -388, -332, -225, + -110, 22, 182, 318, 496, 677, 835, 992, 1104, 1162, 1166, + 1133, 1054, 916, 709, 430, 164, -90, -340, -600, -853, -1033, + -1135, -1177, -1146, -1079, -946, -746, -500, -208, 83, 377, 673, + 950, 1183, 1356, 1503, 1627, 1707, 1735, 1708, 1678, 1668, 1645, + 1588, 1494, 1419, 1354, 1291, 1194, 1052, 900, 718, 524, 325, + 110, -114, -330, -500, -630, -729, -803, -834, -795, -727, -627, + -492, -325, -125, 54, 238, 393, 528, 642, 691, 706, 661, + 585, 504, 380, 245, 87, -61, -195, -320, -435, -556, -663, + -742, -814, -883, 
-952, -1009, -1038, -1047, -1067, -1063, -1050, -1020, + -949, -888, -795, -698, -574, -405, -257, -70, 68, 203, 381, + 479, 580, 619, 623, 645, 565, 492, 364, 206, 106, -71, + -191, -331, -460, -469, -527, -471, -441, -386, -222, -123, 60, + 168, 245, 404, 470, 596, 605, 581, 633, 548, 562, 468, + 355, 334, 192, 161, 62, -36, -39, -146, -121, -167, -243, + -229, -302, -276, -327, -415, -419, -444, -396, -433, -455, -407, + -357, -244, -221, -158, -63, 36, 172, 210, 296, 326, 351, + 424, 367, 369, 300, 224, 235, 124, 54, -39, -122, -118, + -239, -304, -360, -403, -361, -418, -427, -394, -342, -259, -232, + -176, -110, -48, 27, 48, 78, 90, 86, 91, 76, 57, + -1, -34, -53, -103, -151, -209, -239, -261, -319, -354, -372, + -382, -385, -411, -432, -428, -431, -446, -471, -496, -512, -532, + -562, -570, -567, -543, -499, -457, -379, -290, -204, -94, -11, + 78, 155, 196, 234, 222, 198, 160, 113, 64, 5, -57, + -108, -136, -175, -186, -196, -184, -125, -90, -25, 58, 146, + 271, 372, 472, 562, 636, 709, 741, 760, 752, 730, 710, + 688, 655, 608, 595, 570, 556, 540, 517, 513, 511, 497, + 481, 449, 417, 401, 347, 325, 295, 248, 261, 238, 250, + 294, 295, 367, 380, 416, 454, 430, 479, 443, 431, 430, + 386, 397, 333, 292, 238, 176, 153, 54, 24, -37, -84, + -109, -172, -155, -199, -220, -219, -261, -227, -255, -280, -266, + -293, -277, -273, -243, -214, -221, -179, -153, -130, -109, -154, + -149, -151, -155, -186, -243, -253, -311, -326, -358, -434, -427, + -491, -533, -554, -598, -596, -655, -668, -679, -714, -671, -694, + -643, -607, -602, -532, -496, -409, -408, -377, -309, -289, -211, + -223, -196, -145, -147, -104, -157, -123, -125, -177, -152, -229, + -192, -204, -243, -213, -259, -194, -190, -172, -98, -123, -43, + -12, 41, 103, 87, 148, 150, 166, 154, 113, 118, 80, + 54, 8, 4, 25, 12, 59, 70, 162, 260, 305, 387, + 427, 501, 549, 564, 571, 517, 488, 423, 355, 294, 206, + 165, 113, 92, 77, 62, 115, 116, 154, 162, 171, 218, + 210, 221, 208, 192, 215, 176, 169, 114, 89, 89, 
52, + 62, 29, 35, 73, 98, 167, 195, 261, 325, 349, 401, + 382, 393, 368, 302, 254, 174, 104, 6, -78, -136, -203, + -229, -291, -303, -284, -294, -241, -235, -222, -186, -187, -156, + -160, -149, -122, -114, -71, -44, -28, 6, 20, 47, 57, + 54, 52, 55, 53, 23, 9, -16, -59, -86, -158, -223, + -292, -372, -421, -498, -532, -561, -570, -531, -512, -456, -367, + -297, -206, -125, -37, 26, 88, 147, 157, 188, 169, 152, + 152, 131, 99, 62, 44, 46, 53, 61, 61, 79, 110, + 159, 175, 185, 237, 220, 278, 276, 239, 264, 203, 190, + 138, 70, 34, -9, 18, 1, 10, 71, 115, 191, 220, + 255, 265, 296, 319, 270, 266, 214, 189, 187, 155, 145, + 123, 149, 166, 172, 186, 179, 195, 213, 201, 182, 161, + 150, 116, 76, 41, -29, -58, -101, -183, -209, -269, -314, + -342, -385, -379, -380, -348, -304, -273, -197, -144, -88, -28, + -5, 11, 20, 27, -5, -24, -22, -61, -73, -87, -124, + -118, -133, -150, -160, -198, -196, -219, -228, -239, -281, -276, + -275, -288, -277, -305, -324, -302, -294, -292, -266, -261, -224, + -203, -210, -190, -198, -176, -180, -201, -196, -198, -175, -166, + -151, -127, -114, -59, -48, -8, 39, 75, 126, 131, 168, + 160, 152, 142, 82, 36, -13, -49, -81, -105, -105, -103, + -65, -38, -16, 19, 33, 67, 82, 95, 110, 98, 111, + 98, 87, 67, 54, 66, 52, 49, 53, 71, 106, 139, + 186, 224, 270, 320, 361, 413, 433, 462, 473, 478, 480, + 459, 441, 391, 339, 298, 239, 206, 159, 149, 120, 114, + 117, 95, 106, 81, 67, 61, 30, 11, -29, -42, -76, + -97, -98, -124, -107, -107, -103, -69, -71, -36, -12, 23, + 69, 86, 129, 152, 158, 162, 152, 127, 81, 48, -9, + -80, -120, -172, -201, -225, -276, -297, -311, -330, -339, -361, + -375, -389, -376, -365, -374, -378, -375, -370, -358, -347, -355, + -338, -314, -289, -244, -212, -168, -129, -80, -26, -12, 47, + 79, 92, 105, 105, 113, 99, 85, 29, -18, -53, -110, + -133, -167, -186, -196, -199, -176, -177, -150, -122, -106, -73, + -61, -30, -34, -29, -40, -68, -63, -85, -84, -71, -65, + -40, -16, 23, 56, 87, 144, 167, 196, 206, 221, 243, + 226, 233, 
210, 192, 190, 150, 140, 110, 91, 77, 43, + 27, -10, -5, -5, -22, -9, -7, 27, 48, 59, 64, + 70, 87, 104, 139, 151, 188, 239, 270, 317, 311, 336, + 349, 341, 330, 274, 254, 223, 195, 163, 102, 81, 43, + 20, 8, -37, -28, -31, -29, -21, -39, -16, -22, -11, + -21, -41, -32, -47, -39, -60, -75, -71, -94, -98, -131, + -147, -139, -145, -146, -165, -150, -136, -112, -90, -106, -86, + -91, -87, -98, -136, -121, -135, -124, -132, -144, -114, -108, + -87, -74, -75, -50, -30, -5, -18, -24, -3, -3, -6, + -41, -76, -98, -127, -159, -215, -257, -263, -268, -266, -262, + -237, -194, -144, -113, -99, -61, -28, 12, 21, 46, 76, + 92, 130, 115, 123, 132, 135, 149, 134, 133, 132, 135, + 138, 94, 76, 51, 19, -15, -72, -98, -125, -135, -154, + -174, -171, -164, -139, -130, -99, -74, -40, 9, 34, 86, + 129, 176, 214, 226, 245, 250, 280, 271, 256, 250, 226, + 234, 212, 187, 178, 148, 144, 104, 79, 64, 37, 36, + 9, -10, -23, -38, -35, -62, -67, -67, -82, -70, -80, + -75, -59, -34, -3, 9, 48, 76, 101, 120, 120, 123, + 126, 131, 112, 92, 77, 61, 54, 32, 3, -18, -28, + -39, -56, -71, -91, -92, -100, -124, -134, -142, -144, -155, + -177, -178, -175, -171, -168, -160, -141, -123, -89, -73, -64, + -46, -39, -18, -19, -34, -32, -46, -51, -63, -74, -73, + -81, -70, -83, -71, -49, -39, -12, -1, 30, 48, 65, + 94, 100, 125, 136, 148, 156, 138, 140, 124, 115, 86, + 58, 57, 32, 43, 40, 44, 63, 60, 83, 90, 99, + 115, 113, 135, 140, 148, 164, 172, 187, 182, 190, 183, + 171, 171, 146, 139, 121, 105, 94, 61, 46, 17, -6, + -34, -70, -89, -121, -138, -158, -178, -190, -206, -206, -210, + -214, -204, -196, -173, -154, -128, -97, -81, -58, -51, -46, + -38, -47, -49, -57, -58, -57, -59, -49, -58, -58, -54, + -60, -48, -65, -72, -72, -78, -70, -77, -73, -76, -79, + -76, -90, -90, -91, -88, -76, -67, -43, -16, 6, 27, + 39, 55, 69, 71, 74, 65, 56, 60, 47, 37, 27, + 8, -5, -29, -50, -71, -89, -96, -114, -111, -113, -115, + -105, -112, -90, -78, -68, -49, -46, -26, -14, 5, 18, + 10, 14, 3, 5, -9, -20, -15, -30, -26, 
-33, -31, + -23, -23, -12, -21, -20, -16, -23, -20, -13, -7, 6, + 28, 47, 69, 96, 115, 134, 147, 154, 166, 174, 186, + 196, 202, 204, 198, 193, 181, 164, 144, 125, 113, 102, + 96, 90, 92, 91, 96, 99, 99, 100, 99, 99, 93, + 94, 86, 68, 55, 44, 36, 22, 13, 15, 13, 15, + 21, 16, 11, 3, -15, -31, -50, -75, -105, -125, -145, + -154, -155, -164, -178, -189, -186, -177, -174, -169, -152, -134, + -114, -93, -65, -42, -23, -4, -1, 6, 6, 2, -4, + -18, -26, -25, -25, -23, -32, -31, -33, -39, -50, -68, + -69, -74, -79, -78, -83, -85, -85, -77, -71, -61, -42, + -27, -3, 28, 59, 95, 123, 146, 155, 160, 162, 144, + 130, 112, 94, 82, 67, 60, 46, 35, 35, 22, 4, + -14, -27, -35, -45, -52, -61, -62, -65, -68, -55, -52, + -43, -38, -34, -20, -8, 8, 18, 24, 34, 36, 37, + 42, 46, 51, 50, 58, 76, 75, 70, 67, 58, 53, + 48, 36, 23, 18, 10, 3, 9, 14, 24, 39, 43, + 53, 62, 63, 66, 62, 66, 64, 59, 51, 25, 19, + 6, -10, -19, -26, -35, -43, -44, -37, -47, -43, -50, + -54, -60, -69, -75, -84, -91, -93, -98, -96, -99, -91, + -87, -91, -88, -84, -80, -75, -61, -48, -44, -40, -37, + -34, -45, -52, -58, -72, -82, -84, -78, -68, -65, -63, + -51, -42, -27, -22, -13, -3, 8, 20, 26, 31, 31, + 37, 33, 29, 33, 31, 32, 31, 34, 44, 55, 68, + 74, 69, 75, 73, 72, 65, 63, 67, 70, 83, 81, + 81, 85, 84, 80, 75, 69, 53, 44, 36, 27, 20, + 11, 1, -4, -19, -26, -27, -25, -21, -14, -12, -12, + -14, -9, -21, -29, -40, -50, -50, -54, -46, -35, -17, + -4, -1, 7, 20, 28, 26, 22, 23, 21, 23, 18, + 13, 12, 7, 6, 3, 2, -1, -1, 4, 6, 17, + 29, 35, 34, 34, 32, 28, 33, 26, 22, 16, 16, + 22, 20, 13, -1, -1, -7, -15, -20, -30, -32, -38, + -39, -45, -45, -53, -63, -70, -83, -96, -107, -113, -122, + -122, -118, -114, -114, -113, -112, -111, -110, -107, -103, -102, + -94, -80, -71, -58, -52, -47, -40, -43, -47, -48, -50, + -39, -46, -44, -44, -44, -43, -45, -41, -40, -34, -32, + -23, -12, -6, -1, -1, 6, 12, 18, 20, 22, 32, + 48, 65, 80, 93, 109, 122, 128, 131, 135, 135, 129, + 126, 130, 127, 124, 125, 121, 122, 115, 118, 122, 128, 
+ 137, 143, 143, 141, 142, 134, 131, 121, 109, 105, 97, + 93, 99, 96, 96, 94, 83, 84, 80, 77, 66, 59, + 46, 42, 44, 32, 28, 20, 12, 8, 4, 4, 5, + 3, -4, -7, -6, -14, -19, -24, -34, -40, -45, -52, + -61, -62, -60, -57, -57, -61, -63, -61, -65, -73, -81, + -89, -94, -93, -89, -87, -82, -82, -84, -81, -86, -82, + -84, -86, -90, -86, -83, -82, -81, -80, -80, -76, -75, + -76, -70, -69, -68, -61, -53, -50, -43, -38, -42, -43, + -41, -41, -39, -34, -27, -21, -16, -20, -22, -27, -36, + -39, -38, -40, -37, -35, -28, -14, -6, -3, -2, 2, + 4, 5, 15, 18, 25, 35, 36, 41, 45, 48, 52, + 54, 52, 50, 60, 67, 76, 85, 85, 90, 86, 83, + 84, 77, 77, 72, 77, 81, 89, 91, 93, 99, 101, + 102, 98, 94, 87, 77, 70, 69, 63, 62, 55, 59, + 58, 54, 51, 53, 57, 62, 65, 60, 54, 48, 45, + 40, 29, 17, 8, -3, -14, -17, -18, -20, -25, -34, + -40, -44, -53, -56, -63, -71, -71, -69, -66, -62, -66, + -67, -68, -71, -75, -79, -79, -73, -67, -60, -49, -46, + -45, -45, -46, -55, -64, -67, -72, -74, -70, -68, -67, + -69, -70, -64, -56, -55, -54, -51, -41, -30, -26, -28, + -29, -30, -28, -25, -27, -20, -12, -5, -2, 2, 3, + -3, 0, -7, -8, -14, -15, -9, -7, 4, 12, 24, + 36, 41, 52, 58, 59, 51, 45, 48, 44, 46, 43, + 40, 42, 47, 53, 52, 52, 63, 69, 74, 75, 80, + 78, 69, 68, 59, 60, 54, 54, 54, 58, 66, 71, + 78, 78, 75, 78, 72, 71, 61, 55, 53, 42, 36, + 31, 28, 29, 23, 19, 25, 27, 27, 23, 29, 29, + 20, 11, 5, -4, -10, -31, -38, -39, -36, -33, -27, + -17, -15, -14, -17, -13, -14, -25, -33, -44, -51, -61, + -63, -63, -65, -67, -66, -63, -59, -52, -48, -45, -44, + -50, -62, -74, -84, -89, -100, -101, -102, -96, -95, -85, + -76, -78, -72, -71, -66, -61, -63, -60, -62, -72, -69, + -69, -58, -56, -50, -37, -28, -17, -17, -16, -17, -18, + -18, -13, -7, -4, 6, 17, 23, 25, 28, 24, 21, + 17, 21, 27, 30, 33, 35, 46, 49, 48, 54, 56, + 57, 58, 60, 64, 62, 64, 66, 67, 64, 70, 77, + 83, 82, 84, 88, 89, 95, 86, 75, 64, 51, 36, + 29, 26, 21, 26, 31, 38, 40, 55, 63, 65, 65, + 64, 60, 54, 54, 49, 41, 34, 26, 21, 9, 6, + 6, 5, -1, 3, 
5, 3, 2, -4, -13, -13, -24, + -32, -33, -36, -33, -24, -18, -15, -9, -5, -5, -14, + -17, -24, -34, -36, -42, -43, -36, -42, -43, -43, -38, + -36, -27, -20, -23, -21, -28, -25, -22, -24, -25, -23, + -22, -30, -31, -26, -25, -20, -15, -8, -10, -11, -13, + -18, -22, -30, -36, -35, -39, -35, -34, -27, -24, -19, + -15, -7, -6, -7, -2, 0, 7, 12, 14, 19, 20, + 26, 26, 24, 16, 10, 4, 1, 3, 2, 9, 11, + 17, 19, 27, 31, 31, 32, 30, 27, 25, 28, 27, + 25, 22, 23, 23, 20, 21, 25, 36, 38, 40, 43, + 40, 32, 27, 20, 9, 4, 1, 12, 27, 37, 49, + 63, 73, 72, 73, 70, 67, 53, 39, 33, 26, 23, + 13, 9, 6, 0, -2, -3, 0, -1, 0, -1, -4, + -9, -16, -22, -21, -24, -21, -19, -12, -3, 0, 12, + 14, 13, 3, -6, -13, -27, -34, -42, -41, -44, -42, + -43, -46, -42, -40, -39, -36, -31, -29, -30, -22, -19, + -21, -20, -17, -17, -22, -31, -41, -45, -54, -65, -64, + -68, -70, -74, -70, -64, -62, -61, -60, -58, -52, -46, + -43, -37, -35, -40, -41, -47, -52, -58, -62, -61, -53, + -54, -46, -41, -40, -34, -29, -20, -15, -8, 2, 12, + 28, 35, 41, 42, 42, 43, 41, 43, 39, 45, 44, + 46, 55, 54, 55, 55, 51, 48, 42, 43, 39, 40, + 46, 54, 65, 70, 76, 81, 86, 89, 79, 73, 70, + 62, 56, 52, 39, 32, 28, 17, 18, 19, 18, 15, + 19, 20, 15, 13, 13, 10, 6, 5, 12, 10, 15, + 20, 24, 30, 31, 28, 22, 17, 2, -15, -24, -39, + -52, -53, -55, -46, -40, -34, -26, -21, -22, -31, -32, + -38, -36, -35, -32, -33, -34, -30, -28, -27, -35, -40, + -42, -45, -44, -45, -44, -52, -54, -57, -57, -53, -60, + -63, -63, -65, -51, -45, -40, -40, -39, -39, -43, -44, + -46, -52, -46, -51, -49, -45, -45, -47, -47, -45, -50, + -47, -40, -35, -32, -24, -17, -19, -14, -13, -9, -7, + -7, -7, -9, 0, 3, 7, 13, 12, 14, 15, 13, + 6, -1, -3, -9, -10, -5, -2, 6, 9, 11, 12, + 15, 19, 24, 37, 47, 47, 56, 53, 51, 52, 52, + 47, 39, 38, 40, 41, 43, 44, 42, 43, 42, 41, + 43, 40, 41, 35, 37, 39, 40, 41, 38, 30, 21, + 14, 5, 2, -1, -2, 1, -2, 6, 2, 4, 2, + -1, -11, -16, -23, -25, -20, -18, -25, -27, -32, -27, + -24, -16, -15, -11, -9, -3, -4, -2, -9, -10, -18, + -28, 
-33, -38, -37, -41, -41, -33, -24, -22, -25, -25, + -25, -24, -33, -38, -42, -52, -57, -55, -50, -51, -53, + -52, -48, -49, -49, -53, -55, -58, -51, -34, -19, -12, + -12, -5, 1, 1, 0, -6, -2, -10, -11, -11, -6, + 0, -6, 2, -2, -6, 2, 5, 16, 18, 18, 21, + 16, 18, 18, 20, 20, 13, 18, 9, 7, 12, 7, + 8, 10, 16, 17, 18, 23, 26, 36, 44, 51, 55, + 60, 64, 69, 68, 71, 70, 62, 58, 52, 44, 35, + 31, 34, 32, 33, 36, 37, 38, 41, 47, 55, 56, + 58, 60, 60, 57, 48, 41, 29, 19, 7, 4, 8, + 9, 10, 8, 13, 15, 13, 8, 8, 6, 4, 10, + 8, -4, -6, -9, -20, -28, -39, -38, -27, -24, -22, + -19, -23, -32, -35, -36, -41, -48, -51, -50, -52, -55, + -60, -67, -72, -76, -84, -82, -80, -81, -75, -64, -50, + -36, -28, -18, -14, -12, -15, -12, -18, -24, -21, -22, + -19, -21, -19, -22, -20, -18, -16, -17, -19, -15, -7, + 1, 0, 0, 9, 14, 20, 24, 20, 16, 17, 20, + 20, 25, 27, 26, 32, 33, 35, 38, 42, 38, 37, + 39, 46, 44, 43, 45, 45, 42, 37, 34, 25, 21, + 22, 33, 44, 49, 54, 53, 58, 54, 51, 46, 40, + 37, 37, 39, 34, 37, 39, 31, 39, 38, 36, 35, + 32, 33, 33, 32, 28, 23, 18, 22, 28, 31, 27, + 18, 3, 4, 0, -4, -7, -15, -18, -24, -32, -34, + -39, -42, -36, -31, -24, -12, -10, -10, -13, -20, -28, + -34, -44, -49, -50, -53, -56, -54, -52, -53, -47, -43, + -41, -45, -41, -38, -38, -33, -32, -34, -35, -33, -40, + -45, -53, -62, -61, -67, -72, -70, -67, -68, -59, -51, + -47, -38, -31, -20, -13, -13, -13, -14, -17, -21, -22, + -29, -31, -27, -23, -13, -6, 4, 12, 17, 25, 23, + 23, 25, 30, 30, 32, 31, 28, 27, 18, 14, 13, + 3, 5, 7, 19, 35, 47, 61, 70, 84, 90, 95, + 92, 94, 89, 77, 71, 66, 59, 50, 51, 50, 51, + 53, 56, 65, 67, 69, 75, 74, 69, 67, 56, 51, + 44, 34, 25, 17, 10, 6, 7, 7, 4, 6, -1, + -1, -2, -9, -9, -9, -7, -5, 1, -2, -5, -11, + -19, -27, -39, -38, -44, -45, -48, -48, -54, -59, -53, + -51, -49, -52, -50, -50, -47, -42, -32, -28, -28, -26, + -27, -34, -40, -40, -36, -37, -37, -34, -37, -36, -41, + -36, -40, -46, -48, -52, -47, -44, -40, -40, -38, -43, + -43, -47, -59, -62, -59, -59, -51, -41, -29, -19, 
-8, + -2, 1, 1, -4, -9, -19, -23, -29, -29, -25, -23, + -15, -7, -2, 6, 8, 15, 27, 35, 43, 40, 36, + 35, 32, 25, 22, 19, 17, 13, 13, 21, 25, 28, + 36, 44, 50, 57, 56, 58, 59, 62, 66, 70, 73, + 69, 66, 66, 66, 62, 53, 48, 44, 38, 39, 44, + 52, 51, 55, 57, 52, 49, 44, 36, 26, 16, 13, + 13, 14, 14, 17, 14, 10, 6, -5, -14, -23, -24, + -21, -28, -25, -27, -29, -29, -33, -33, -39, -42, -43, + -41, -40, -43, -46, -45, -43, -42, -41, -41, -46, -46, + -52, -52, -52, -59, -63, -70, -68, -73, -77, -73, -68, + -66, -62, -64, -66, -58, -54, -51, -52, -48, -47, -43, + -40, -39, -33, -26, -19, -17, -16, -17, -14, -9, -10, + -3, 5, 5, 9, 5, 9, 8, 4, 3, 0, -5, + -10, -3, 2, 8, 14, 16, 20, 27, 39, 40, 44, + 48, 43, 39, 34, 29, 22, 12, 8, 5, 0, -2, + -3, 5, 12, 16, 19, 22, 25, 28, 35, 28, 30, + 31, 30, 39, 43, 47, 43, 42, 41, 41, 41, 37, + 37, 39, 37, 38, 43, 44, 41, 43, 34, 28, 25, + 23, 30, 34, 32, 33, 29, 21, 18, 13, 14, 11, + 3, 2, 1, 3, 1, -1, 0, -3, -1, -3, -8, + -9, -7, -9, -2, 0, -3, 0, 1, 5, 0, -1, + -9, -13, -8, -11, -18, -23, -25, -29, -29, -26, -27, + -29, -25, -24, -23, -18, -19, -18, -17, -21, -22, -30, + -38, -42, -42, -42, -40, -41, -43, -39, -38, -37, -36, + -33, -31, -28, -27, -18, -15, -7, -8, -8, -1, 1, + 3, -5, 0, -4, -5, -4, -8, -10, -14, -21, -24, + -25, -20, -11, -4, 3, 6, 13, 15, 12, 17, 16, + 17, 17, 15, 21, 28, 33, 36, 35, 35, 29, 31, + 29, 28, 23, 21, 14, 15, 27, 36, 40, 40, 43, + 51, 56, 62, 69, 77, 80, 88, 88, 88, 82, 76, + 63, 52, 44, 36, 26, 23, 25, 24, 27, 26, 31, + 21, 13, 8, -8, -8, -11, -14, -18, -28, -28, -30, + -32, -29, -26, -26, -27, -24, -20, -14, -8, -6, -8, + -5, -10, -14, -18, -26, -34, -36, -38, -44, -51, -57, + -66, -64, -68, -72, -75, -75, -70, -68, -65, -64, -62, + -68, -63, -60, -65, -65, -69, -68, -67, -57, -46, -41, + -38, -34, -31, -39, -40, -45, -45, -48, -47, -40, -39, + -32, -26, -24, -14, -9, -7, -3, -2, 3, 4, 0, + -2, -2, -2, 1, 3, 2, 3, 8, 13, 20, 25, + 29, 31, 26, 17, 11, 3, -5, 2, 6, 9, 11, + 19, 26, 40, 51, 61, 60, 58, 61, 
55, 55, 57, + 60, 54, 40, 42, 38, 34, 38, 37, 34, 32, 35, + 36, 35, 41, 36, 32, 29, 23, 22, 23, 22, 14, + 13, 19, 19, 20, 22, 22, 17, 13, 6, 9, 13, + 15, 17, 19, 11, 15, 8, 4, 6, -1, -3, 3, + 7, 11, 8, 10, 7, 6, 4, -4, -5, -11, -9, + -16, -14, -14, -16, -16, -22, -19, -19, -13, -9, -4, + 1, 1, 2, -6, -14, -25, -32, -41, -46, -50, -49, + -42, -39, -34, -24, -14, -18, -15, -17, -21, -23, -21, + -19, -21, -20, -19, -20, -19, -16, -17, -19, -20, -20, + -20, -20, -22, -22, -23, -22, -22, -14, -5, 5, 8, + 13, 16, 19, 23, 19, 21, 16, 16, 18, 13, 18, + 13, 15, 18, 12, 12, 6, 11, 8, 5, 5, 9, + 17, 14, 15, 14, 16, 14, 14, 12, 9, 7, 9, + 11, 13, 15, 15, 19, 17, 14, 8, 7, 4, 0, + 3, 8, 10, 7, 8, 19, 15, 19, 18, 19, 17, + 9, 14, 10, 4, -3, -11, -19, -25, -31, -35, -36, + -28, -21, -8, 5, 8, 11, 13, 7, 4, 1, -7, + -15, -17, -17, -21, -28, -33, -37, -40, -39, -41, -45, + -46, -44, -40, -41, -36, -31, -41, -40, -42, -44, -47, + -50, -49, -55, -52, -52, -52, -45, -50, -52, -56, -58, + -60, -69, -75, -82, -86, -91, -87, -80, -80, -72, -58, + -52, -45, -33, -21, -13, -12, -10, -6, -1, -2, -7, + -7, -5, -6, -3, 9, 15, 25, 36, 35, 39, 28, + 16, 11, 8, 11, 17, 27, 34, 36, 47, 49, 52, + 52, 42, 46, 49, 55, 65, 66, 67, 62, 56, 53, + 49, 50, 55, 53, 62, 69, 72, 73, 68, 61, 54, + 46, 43, 38, 34, 39, 43, 42, 39, 36, 31, 26, + 24, 17, 13, 14, 14, 21, 26, 29, 28, 26, 24, + 18, 19, 16, 11, 6, 2, -2, 1, 3, 2, -4, + -3, -1, -3, -2, -2, -5, -3, 0, 3, -3, -6, + -6, -15, -19, -25, -30, -35, -39, -34, -34, -34, -31, + -17, -17, -8, -2, -2, 8, 14, 25, 24, 26, 22, + 16, 10, 2, -3, -5, -12, -15, -11, -14, -16, -17, + -17, -16, -21, -18, -18, -21, -23, -21, -15, -11, -4, + -2, 3, 8, 10, 17, 18, 25, 24, 24, 24, 21, + 24, 23, 24, 22, 23, 31, 39, 49, 58, 64, 67, + 63, 57, 53, 52, 44, 45, 43, 40, 45, 42, 49, + 50, 49, 52, 51, 48, 46, 38, 37, 35, 36, 37, + 37, 37, 44, 45, 47, 42, 42, 36, 35, 44, 40, + 40, 28, 24, 23, 18, 12, 9, 8, 10, 17, 17, + 18, 12, 5, -2, -12, -16, -20, -27, -29, -29, -26, + -22, -17, 
-16, -15, -14, -15, -11, -11, -15, -19, -15, + -20, -22, -24, -37, -52, -62, -63, -68, -64, -59, -51, + -43, -42, -36, -32, -33, -33, -33, -41, -48, -51, -49, + -48, -47, -42, -45, -42, -41, -40, -39, -33, -29, -25, + -14, -1, -4, -6, -11, -16, -19, -26, -29, -28, -25, + -17, -10, -1, -1, 3, 7, -1, -3, -8, -18, -20, + -20, -16, -13, -11, -8, 0, 6, 8, 11, 14, 15, + 20, 26, 26, 26, 24, 23, 24, 30, 34, 41, 52, + 61, 70, 80, 85, 86, 89, 84, 87, 79, 67, 60, + 57, 59, 63, 68, 74, 78, 84, 89, 91, 87, 81, + 74, 69, 63, 59, 59, 56, 58, 60, 60, 59, 54, + 49, 41, 40, 34, 25, 19, 11, 1, 0, -1, -4, + -8, -12, -12, -17, -22, -31, -44, -54, -58, -68, -74, + -80, -80, -73, -65, -61, -61, -55, -50, -50, -59, -65, + -69, -73, -73, -78, -79, -83, -87, -87, -88, -94, -103, + -107, -107, -109, -106, -113, -115, -110, -105, -100, -100, -92, + -78, -62, -49, -39, -35, -27, -26, -25, -24, -22, -23, + -28, -26, -22, -15, -11, -4, 4, 13, 21, 32, 31, + 28, 30, 30, 28, 23, 25, 23, 21, 25, 21, 26, + 27, 32, 40, 48, 53, 55, 54, 55, 55, 54, 48, + 44, 47, 48, 54, 60, 71, 79, 79, 74, 72, 59, + 48, 42, 32, 26, 22, 21, 23, 22, 31, 42, 44, + 41, 36, 30, 30, 33, 38, 35, 30, 28, 20, 15, + 8, 4, 6, 9, 16, 26, 27, 23, 19, 16, 10, + 4, -4, -12, -12, -16, -16, -19, -24, -23, -23, -31, + -34, -38, -40, -41, -39, -39, -36, -36, -40, -45, -48, + -53, -66, -73, -76, -76, -78, -75, -71, -65, -59, -58, + -59, -56, -60, -62, -62, -62, -64, -68, -73, -79, -80, + -85, -87, -85, -78, -72, -66, -56, -48, -42, -37, -35, + -32, -33, -31, -25, -26, -27, -16, -18, -18, -13, -14, + -17, -22, -24, -25, -23, -19, -14, -12, -11, -7, -4, + -1, 2, 5, 8, 10, 10, 18, 28, 29, 25, 22, + 29, 21, 20, 21, 22, 30, 32, 41, 41, 45, 46, + 49, 52, 57, 59, 58, 52, 46, 47, 56, 58, 49, + 49, 46, 40, 33, 23, 14, 11, 16, 29, 34, 37, + 41, 42, 48, 54, 60, 61, 62, 62, 69, 79, 76, + 71, 72, 71, 64, 59, 54, 49, 40, 42, 34, 23, + 27, 18, 13, 9, 3, -4, -8, -16, -18, -20, -26, + -28, -30, -32, -29, -32, -35, -39, -41, -38, -34, -31, + -26, -18, -21, 
-20, -22, -28, -35, -34, -31, -33, -31, + -31, -40, -43, -45, -53, -64, -67, -74, -75, -74, -75, + -70, -61, -56, -45, -37, -30, -33, -35, -32, -31, -27, + -25, -19, -17, -14, -9, -4, -1, -3, -4, 1, 8, + 14, 20, 24, 25, 18, 11, 7, -3, -9, -3, 4, + 15, 30, 29, 33, 33, 36, 35, 31, 33, 34, 42, + 43, 42, 47, 49, 53, 61, 69, 73, 74, 79, 81, + 84, 76, 69, 62, 47, 39, 31, 19, 8, 2, -6, + -5, -3, -3, -1, 1, -2, -3, -3, -6, -12, -13, + -15, -11, -5, -4, -8, -14, -9, -3, 0, -3, -4, + 0, 3, 0, -6, -14, -23, -33, -38, -41, -38, -38, + -34, -30, -29, -29, -26, -31, -33, -41, -49, -50, -56, + -57, -58, -54, -46, -39, -39, -34, -31, -28, -30, -30, + -31, -29, -27, -16, -18, -17, -15, -13, -15, -12, -7, + -11, -9, -9, -4, -11, -7, -7, -8, -9, -10, -7, + -9, 1, 9, 15, 12, 19, 19, 18, 17, 13, 11, + 8, 6, 10, 17, 20, 26, 28, 33, 39, 30, 25, + 25, 18, 16, 21, 26, 30, 33, 32, 36, 42, 49, + 46, 39, 44, 44, 37, 35, 30, 24, 22, 23, 26, + 23, 25, 21, 24, 24, 22, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.h new file mode 100644 index 0000000000000000000000000000000000000000..4cc8030cdac7c4e8364e0fcd7dcc5fff63617908 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/no_1000ms_sample_data.h @@ -0,0 +1,29 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This data was created from the PCM data in a WAV file held in v2 of the +// Speech Commands test dataset, at the path: +// speech_commands_test_set_v0.02/no/f9643d42_nohash_4.wav +// This should contain all 16,000 samples from the one-second file. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_1000MS_SAMPLE_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_1000MS_SAMPLE_DATA_H_ + +#include + +extern const int g_no_1000ms_sample_data_size; +extern const int16_t g_no_1000ms_sample_data[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_1000MS_SAMPLE_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/osx/audio_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/osx/audio_provider.cc index 892757e799f3832db725424163e613bea35ab9e7..6468c1a95a9cd3f844595bf2c6e88c1e2833823b 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/osx/audio_provider.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/osx/audio_provider.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.h" namespace { diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc index 9366dc71e0d76d087a3dad9b9c4c206a0749e235..8187962c3e780a76413134771dc63ba30910f3b6 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc @@ -28,8 +28,8 @@ RecognizeCommands::RecognizeCommands(tflite::ErrorReporter* error_reporter, suppression_ms_(suppression_ms), minimum_count_(minimum_count), previous_results_(error_reporter) { - previous_top_label_ = "_silence_"; - previous_top_label_time_ = 0; + previous_top_label_ = "silence"; + previous_top_label_time_ = std::numeric_limits::min(); } TfLiteStatus RecognizeCommands::ProcessLatestResults( diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h index adefffe850076821dd1e0bf683fdd2180d6999ea..292cd3e88dcd63f925cb16995b5e8a16554a8547 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h @@ -19,7 +19,7 @@ limitations under the License. 
#include #include "tensorflow/lite/c/c_api_internal.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" #include "tensorflow/lite/experimental/micro/micro_error_reporter.h" // Partial implementation of std::dequeue, just providing the functionality @@ -129,8 +129,8 @@ class RecognizeCommands { // help reduce spurious recognitions. explicit RecognizeCommands(tflite::ErrorReporter* error_reporter, int32_t average_window_duration_ms = 1000, - uint8_t detection_threshold = 51, - int32_t suppression_ms = 500, + uint8_t detection_threshold = 200, + int32_t suppression_ms = 1500, int32_t minimum_count = 3); // Call this with the results of running a model on sample data. @@ -149,8 +149,6 @@ class RecognizeCommands { // Working variables PreviousResultsQueue previous_results_; - int previous_results_head_; - int previous_results_tail_; const char* previous_top_label_; int32_t previous_top_label_time_; }; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc index f0cc73f10b3dadfdf06cb0f2935140b792635add..6582c948d16f9493a4b1e5bdf43bdc1f30e6dc31 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc @@ -118,7 +118,9 @@ TF_LITE_MICRO_TEST(RecognizeCommandsTestFindCommands) { } } TF_LITE_MICRO_EXPECT(has_found_new_command); - TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("yes", new_command)); + if (has_found_new_command) { + TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("yes", new_command)); + } TfLiteTensor no_results = tflite::testing::CreateQuantizedTensor( {0, 0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 4}), @@ -141,8 +143,10 @@ 
TF_LITE_MICRO_TEST(RecognizeCommandsTestFindCommands) { } } TF_LITE_MICRO_EXPECT(has_found_new_command); - TF_LITE_MICRO_EXPECT_EQ(231, score); - TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("no", new_command)); + if (has_found_new_command) { + TF_LITE_MICRO_EXPECT_EQ(231, score); + TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("no", new_command)); + } } TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputLength) { diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/preprocessor.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/CMSIS/simple_features_generator.cc similarity index 85% rename from tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/preprocessor.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/CMSIS/simple_features_generator.cc index 5c6978b5edef635af58873bf537a251fa4510ef4..403976e222fe549f6f8c755bf7460d245d9370e8 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/preprocessor.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/CMSIS/simple_features_generator.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.h" + extern "C" { -#define ARM_MATH_CM4 #define IFFT_FLAG_R 0 #define BIT_REVERSE_FLAG 1 #define FFT_SIZE 512 @@ -24,8 +25,6 @@ extern "C" { #include "tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/hanning.h" } -#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h" - void quantize(q15_t* bufA, q15_t* bufB, uint8_t* output); q15_t bufA[FFT_SIZE]; @@ -42,9 +41,9 @@ constexpr int kOutputSize = ((kInputSize / 2) + (kAverageWindowSize - 1)) / kAverageWindowSize; } // namespace -TfLiteStatus Preprocess(tflite::ErrorReporter* error_reporter, - const int16_t* input, int input_size, int output_size, - uint8_t* output) { +TfLiteStatus GenerateSimpleFeatures(tflite::ErrorReporter* error_reporter, + const int16_t* input, int input_size, + int output_size, uint8_t* output) { if (input_size > kInputSize) { error_reporter->Report("Input size %d larger than %d", input_size, kInputSize); @@ -94,12 +93,3 @@ void quantize(q15_t* bufA, q15_t* bufB, uint8_t* output) { output[i] = (uint8_t)(bufA[i] >> 5); } } - -TfLiteStatus Preprocess_1sec(tflite::ErrorReporter* error_reporter, - const int16_t* input, uint8_t* output) { - int i; - for (i = 0; i < 49; i++) { - Preprocess(error_reporter, input + i * 320, 480, 43, output + i * 43); - } - return kTfLiteOk; -} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/fixed_point/preprocessor.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/fixed_point/simple_features_generator.cc similarity index 96% rename from tensorflow/lite/experimental/micro/examples/micro_speech/fixed_point/preprocessor.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/fixed_point/simple_features_generator.cc index 
b623d8d11b75d59600cc6a029527d3957084a328..ad11684b0a94e630580aa9a95d4b1db92f914d6f 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/fixed_point/preprocessor.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/fixed_point/simple_features_generator.cc @@ -27,11 +27,11 @@ limitations under the License. // instead of floating point, to help show how this can work on platforms that // don't have good float support. -#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.h" #include -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.h" namespace { @@ -118,9 +118,9 @@ void CalculatePeriodicHann(int window_length, int16_t* window_function) { } // namespace -TfLiteStatus Preprocess(tflite::ErrorReporter* error_reporter, - const int16_t* input, int input_size, int output_size, - uint8_t* output) { +TfLiteStatus GenerateSimpleFeatures(tflite::ErrorReporter* error_reporter, + const int16_t* input, int input_size, + int output_size, uint8_t* output) { // Ensure our input and output data arrays are valid. 
if (input_size > kMaxAudioSampleSize) { error_reporter->Report("Input size %d larger than %d", input_size, diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.cc similarity index 95% rename from tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.cc index c4fc5c33bb329cba4e1abcf6d36b01f14e9e2b27..0b20f2f86fb6455d4251cb81d3e70c3c15de7c6b 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.cc @@ -15,7 +15,7 @@ limitations under the License. // See the header for documentation on the meaning of this data. -#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.h" const uint8_t g_no_power_spectrum_data[g_no_power_spectrum_data_size] = { 255, 7, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.h similarity index 90% rename from tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.h index fa39d3c70d78ce261db81cf8ad7c416efd2c468c..9693950fb5ee1d56242b83c6265e9e2315ec8971 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.h @@ -18,12 
+18,12 @@ limitations under the License. // This is the expected result of running the sample data in // no_30ms_sample_data.cc through through the preprocessing pipeline. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_POWER_SPECTRUM_DATA_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_POWER_SPECTRUM_DATA_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_NO_POWER_SPECTRUM_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_NO_POWER_SPECTRUM_DATA_H_ #include constexpr int g_no_power_spectrum_data_size = 43; extern const uint8_t g_no_power_spectrum_data[]; -#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_POWER_SPECTRUM_DATA_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_NO_POWER_SPECTRUM_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_simple_features_data.cc similarity index 97% rename from tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_simple_features_data.cc index e98c84f7ed2e678eb91580a2b6fb69514cee4740..3d3a9538fb527888e3bdf0e1aa9ca00d4d5f1544 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_simple_features_data.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_simple_features_data.h" /* File automatically created by * tensorflow/examples/speech_commands/wav_to_features.py \ @@ -22,15 +22,15 @@ limitations under the License. * --window_size_ms=30 \ * --window_stride_ms=20 \ * --feature_bin_count=40 \ - * --quantize \ + * --quantize=1 \ * --preprocess="average" \ * --input_wav="speech_commands_test_set_v0.02/no/f9643d42_nohash_4.wav" \ - * --output_c_file="no_features_data.cc" \ + * --output_c_file="no_simple_features_data.cc" \ */ -const int g_no_f9643d42_nohash_4_width = 43; -const int g_no_f9643d42_nohash_4_height = 49; -const unsigned char g_no_f9643d42_nohash_4_data[] = { +const int g_no_simple_f9643d42_nohash_4_width = 43; +const int g_no_simple_f9643d42_nohash_4_height = 49; +const unsigned char g_no_simple_f9643d42_nohash_4_data[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_simple_features_data.h similarity index 73% rename from tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_simple_features_data.h index 39a3bb914cc1986aa851ace0e39ce63ed1a93282..30332b30c5c8325edb53713d572fcf987446844a 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_simple_features_data.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and 
limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_NO_SIMPLE_FEATURES_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_NO_SIMPLE_FEATURES_DATA_H_ -extern const int g_yes_f2e59fea_nohash_1_width; -extern const int g_yes_f2e59fea_nohash_1_height; -extern const unsigned char g_yes_f2e59fea_nohash_1_data[]; +extern const int g_no_simple_f9643d42_nohash_4_width; +extern const int g_no_simple_f9643d42_nohash_4_height; +extern const unsigned char g_no_simple_f9643d42_nohash_4_data[]; -#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_NO_SIMPLE_FEATURES_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.cc similarity index 92% rename from tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.cc index f8858aad72f3c141d20077ffa927e30bd9492987..3aa05b7bf1d5d1762c9c6744ac8a5fe99f922332 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.cc @@ -24,11 +24,11 @@ limitations under the License. // functions used here, for example replacing the DFT with an FFT, so this // version shouldn't be used where performance is critical. 
-#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.h" #include -#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.h" namespace { @@ -72,9 +72,9 @@ void CalculatePeriodicHann(int window_length, float* window_function) { } // namespace -TfLiteStatus Preprocess(tflite::ErrorReporter* error_reporter, - const int16_t* input, int input_size, int output_size, - uint8_t* output) { +TfLiteStatus GenerateSimpleFeatures(tflite::ErrorReporter* error_reporter, + const int16_t* input, int input_size, + int output_size, uint8_t* output) { // Ensure our input and output data arrays are valid. if (input_size > kMaxAudioSampleSize) { error_reporter->Report("Input size %d larger than %d", input_size, @@ -146,12 +146,3 @@ TfLiteStatus Preprocess(tflite::ErrorReporter* error_reporter, } return kTfLiteOk; } - -TfLiteStatus Preprocess_1sec(tflite::ErrorReporter* error_reporter, - const int16_t* input, uint8_t* output) { - int i; - for (i = 0; i < 49; i++) { - Preprocess(error_reporter, input + i * 320, 480, 43, output + i * 43); - } - return kTfLiteOk; -} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.h similarity index 77% rename from tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.h index d710beeceea6a7b6fb7fca748e5795f602276e32..f4e86b18a4c3d1c0a5beb32eb6806faaf1c11c14 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h +++ 
b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_PREPROCESSOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_PREPROCESSOR_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_SIMPLE_FEATURES_GENERATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_SIMPLE_FEATURES_GENERATOR_H_ #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/experimental/micro/micro_error_reporter.h" @@ -24,11 +24,8 @@ limitations under the License. // both floating point and fixed point available, but because the calculations // involved can be time-consuming, it's recommended that you use or write // specialized versions for your platform. 
-TfLiteStatus Preprocess(tflite::ErrorReporter* error_reporter, - const int16_t* input, int input_size, int output_size, - uint8_t* output); +TfLiteStatus GenerateSimpleFeatures(tflite::ErrorReporter* error_reporter, + const int16_t* input, int input_size, + int output_size, uint8_t* output); -TfLiteStatus Preprocess_1sec(tflite::ErrorReporter* error_reporter, - const int16_t* input, uint8_t* output); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_PREPROCESSOR_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_SIMPLE_FEATURES_GENERATOR_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator_test.cc similarity index 90% rename from tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor_test.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator_test.cc index e8b49f67e3d72faa4700c4bdec7f94a5b79cd72e..65e526327c77c727ec88cee421a466f0df34ee76 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor_test.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator_test.cc @@ -13,23 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_features_generator.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/no_power_spectrum_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.h" #include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h" #include "tensorflow/lite/experimental/micro/micro_error_reporter.h" #include "tensorflow/lite/experimental/micro/testing/micro_test.h" TF_LITE_MICRO_TESTS_BEGIN -TF_LITE_MICRO_TEST(TestPreprocessor) { +TF_LITE_MICRO_TEST(TestSimpleFeaturesGenerator) { tflite::MicroErrorReporter micro_error_reporter; tflite::ErrorReporter* error_reporter = µ_error_reporter; uint8_t yes_calculated_data[g_yes_power_spectrum_data_size]; - TfLiteStatus yes_status = Preprocess( + TfLiteStatus yes_status = GenerateSimpleFeatures( error_reporter, g_yes_30ms_sample_data, g_yes_30ms_sample_data_size, g_yes_power_spectrum_data_size, yes_calculated_data); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, yes_status); @@ -45,7 +45,7 @@ TF_LITE_MICRO_TEST(TestPreprocessor) { } uint8_t no_calculated_data[g_yes_power_spectrum_data_size]; - TfLiteStatus no_status = Preprocess( + TfLiteStatus no_status = GenerateSimpleFeatures( error_reporter, g_no_30ms_sample_data, g_no_30ms_sample_data_size, g_no_power_spectrum_data_size, no_calculated_data); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, no_status); diff --git 
a/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.cc new file mode 100644 index 0000000000000000000000000000000000000000..4842f8dbd907dbbd73aab14c7767a8d64476b52d --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.cc @@ -0,0 +1,23 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.h" + +const char* kCategoryLabels[kCategoryCount] = { + "silence", + "unknown", + "yes", + "no", +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.h similarity index 93% rename from tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.h index f48252d14d251673f0070e63dfa4169ca3a89025..d31d6b33622b3a15c90fab4c52d7452960a54930 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/simple_model_settings.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MODEL_SETTINGS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MODEL_SETTINGS_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_SIMPLE_MODEL_SETTINGS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_SIMPLE_MODEL_SETTINGS_H_ // Keeping these as constant expressions allow us to allocate fixed-sized arrays // on the stack for our working memory. 
@@ -40,4 +40,4 @@ constexpr int kSilenceIndex = 0; constexpr int kUnknownIndex = 1; extern const char* kCategoryLabels[kCategoryCount]; -#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MODEL_SETTINGS_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_SIMPLE_MODEL_SETTINGS_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/tiny_conv_simple_features_model_data.cc similarity index 99% rename from tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/tiny_conv_simple_features_model_data.cc index 62e4359859a422c96ec368b6f91cba99e3c4a4eb..a14412edc941e8a7df0aef9dd66b79b1d9a1d7a6 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/tiny_conv_simple_features_model_data.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ // Automatically created from a TensorFlow Lite flatbuffer using the command: -// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc +// xxd -i tiny_conv.tflite > tiny_conv_simple_features_model_data.cc // See the README for a full description of the creation process. 
-#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/tiny_conv_simple_features_model_data.h" -const unsigned char g_tiny_conv_model_data[] = { +const unsigned char g_tiny_conv_simple_features_model_data[] = { 0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x08, 0x4d, 0x00, 0x00, @@ -1670,4 +1670,4 @@ const unsigned char g_tiny_conv_model_data[] = { 0x04, 0x00, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xff, 0x00, 0x19, 0x06, 0x00, 0x06, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x09, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04}; -const int g_tiny_conv_model_data_len = 19800; +const int g_tiny_conv_simple_features_model_data_len = 19800; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/tiny_conv_simple_features_model_data.h similarity index 74% rename from tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/tiny_conv_simple_features_model_data.h index a465dbfabf7cbba44473ae7e2ff94b1de2092b20..cadf7d0de754e032ae9ff77cdd8deec43bc03847 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/tiny_conv_simple_features_model_data.h @@ -16,12 +16,12 @@ limitations under the License. // This is a standard TensorFlow Lite model file that has been converted into a // C data array, so it can be easily compiled into a binary for devices that // don't have a file system. 
It was created using the command: -// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc +// xxd -i tiny_conv.tflite > tiny_conv_simple_features_model_data.cc -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_TINY_CONV_SIMPLE_FEATURES_MODEL_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_TINY_CONV_SIMPLE_FEATURES_MODEL_DATA_H_ -extern const unsigned char g_tiny_conv_model_data[]; -extern const int g_tiny_conv_model_data_len; +extern const unsigned char g_tiny_conv_simple_features_model_data[]; +extern const int g_tiny_conv_simple_features_model_data_len; -#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_TINY_CONV_SIMPLE_FEATURES_MODEL_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.cc similarity index 95% rename from tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.cc index 9a34a2045a221e2eee8c51f23000e819b1638499..cd46408c0fb5c2c5dad12ae67c5456c8cb178b2d 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.cc @@ -15,7 +15,7 @@ limitations under the License. // See the header for documentation on the meaning of this data. 
-#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.h" const uint8_t g_yes_power_spectrum_data[g_yes_power_spectrum_data_size] = { 8, 89, 8, 0, 0, 0, 0, 0, 0, 0, 0, 4, 13, 1, 6, 23, 20, 6, 4, 0, 0, 0, diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.h similarity index 90% rename from tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.h index 5c8c00ac1116dcbd7ad4aeda1828603e962c2001..77e52d58b54763ec8df46729ab6f8dd84086d59b 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_power_spectrum_data.h @@ -18,12 +18,12 @@ limitations under the License. // This is the expected result of running the sample data in // yes_30ms_sample_data.cc through through the preprocessing pipeline. 
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_POWER_SPECTRUM_DATA_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_POWER_SPECTRUM_DATA_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_YES_POWER_SPECTRUM_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_YES_POWER_SPECTRUM_DATA_H_ #include constexpr int g_yes_power_spectrum_data_size = 43; extern const uint8_t g_yes_power_spectrum_data[]; -#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_POWER_SPECTRUM_DATA_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_YES_POWER_SPECTRUM_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_simple_features_data.cc similarity index 97% rename from tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_simple_features_data.cc index 2eb737fb8e1204a02f7ea4852016e85d03980bfd..2d660bb8b5c5b825eb48490699c89e5ba241369f 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_simple_features_data.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_simple_features_data.h" /* File automatically created by * tensorflow/examples/speech_commands/wav_to_features.py \ @@ -22,15 +22,15 @@ limitations under the License. 
* --window_size_ms=30 \ * --window_stride_ms=20 \ * --feature_bin_count=40 \ - * --quantize \ + * --quantize=1 \ * --preprocess="average" \ * --input_wav="speech_commands_test_set_v0.02/yes/f2e59fea_nohash_1.wav" \ - * --output_c_file="yes_features_data.cc" \ + * --output_c_file="yes_simple_features_data.cc" \ */ -const int g_yes_f2e59fea_nohash_1_width = 43; -const int g_yes_f2e59fea_nohash_1_height = 49; -const unsigned char g_yes_f2e59fea_nohash_1_data[] = { +const int g_yes_simple_f2e59fea_nohash_1_width = 43; +const int g_yes_simple_f2e59fea_nohash_1_height = 49; +const unsigned char g_yes_simple_f2e59fea_nohash_1_data[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_simple_features_data.h similarity index 73% rename from tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_simple_features_data.h index e2ee0c46cf13b00b310bd22b7ca1cb5a9751c6e6..87ea4a4aea89d02189bca9c37872e27b95672190 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/simple_features/yes_simple_features_data.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_YES_SIMPLE_FEATURES_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_YES_SIMPLE_FEATURES_DATA_H_ -extern const int g_no_f9643d42_nohash_4_width; -extern const int g_no_f9643d42_nohash_4_height; -extern const unsigned char g_no_f9643d42_nohash_4_data[]; +extern const int g_yes_simple_f2e59fea_nohash_1_width; +extern const int g_yes_simple_f2e59fea_nohash_1_height; +extern const unsigned char g_yes_simple_f2e59fea_nohash_1_data[]; -#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_SIMPLE_FEATURES_YES_SIMPLE_FEATURES_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/sparkfun_edge/audio_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/sparkfun_edge/audio_provider.cc new file mode 100644 index 0000000000000000000000000000000000000000..520a46ef59811263bcae4cca739ddc26c215b202 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/sparkfun_edge/audio_provider.cc @@ -0,0 +1,358 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h" + +#include + +// These are headers from Ambiq's Apollo3 SDK. +#include "am_bsp.h" // NOLINT +#include "am_mcu_apollo.h" // NOLINT +#include "am_util.h" // NOLINT + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h" + +namespace { + +// These are the raw buffers that are filled by the ADC during DMA +constexpr int kAdcNumSlots = 2; +constexpr int kAdcSamplesPerSlot = 1024; +constexpr int kAdcSampleBufferSize = (kAdcNumSlots * kAdcSamplesPerSlot); +uint32_t g_ui32ADCSampleBuffer0[kAdcSampleBufferSize]; +uint32_t g_ui32ADCSampleBuffer1[kAdcSampleBufferSize]; +// Controls the double buffering between the two DMA buffers. +int g_dma_destination_index = 0; +// ADC Device Handle. +static void* g_adc_handle; +// ADC DMA error flag. +volatile bool g_adc_dma_error; +// So the interrupt can use the passed-in error handler to report issues. +tflite::ErrorReporter* g_adc_dma_error_reporter = nullptr; + +// Holds a longer history of audio samples in a ring buffer. +constexpr int kAudioCaptureBufferSize = 16000; +int16_t g_audio_capture_buffer[kAudioCaptureBufferSize] = {}; +int g_audio_capture_buffer_start = 0; +int64_t g_total_samples_captured = 0; +int32_t g_latest_audio_timestamp = 0; + +// Copy of audio samples returned to the caller. +int16_t g_audio_output_buffer[kMaxAudioSampleSize]; +bool g_is_audio_initialized = false; + +// Start the DMA fetch of ADC samples. +void adc_start_dma(tflite::ErrorReporter* error_reporter) { + am_hal_adc_dma_config_t ADCDMAConfig; + + // Configure the ADC to use DMA for the sample transfer. 
+ ADCDMAConfig.bDynamicPriority = true; + ADCDMAConfig.ePriority = AM_HAL_ADC_PRIOR_SERVICE_IMMED; + ADCDMAConfig.bDMAEnable = true; + ADCDMAConfig.ui32SampleCount = kAdcSampleBufferSize; + if (g_dma_destination_index == 0) { + ADCDMAConfig.ui32TargetAddress = (uint32_t)g_ui32ADCSampleBuffer0; + } else { + ADCDMAConfig.ui32TargetAddress = (uint32_t)g_ui32ADCSampleBuffer1; + } + if (AM_HAL_STATUS_SUCCESS != + am_hal_adc_configure_dma(g_adc_handle, &ADCDMAConfig)) { + error_reporter->Report("Error - configuring ADC DMA failed."); + } + + // Reset the ADC DMA flags. + g_adc_dma_error = false; + g_adc_dma_error_reporter = error_reporter; +} + +// Configure the ADC. +void adc_config0(tflite::ErrorReporter* error_reporter) { + am_hal_adc_config_t ADCConfig; + am_hal_adc_slot_config_t ADCSlotConfig; + + // Initialize the ADC and get the handle. + if (AM_HAL_STATUS_SUCCESS != am_hal_adc_initialize(0, &g_adc_handle)) { + error_reporter->Report("Error - reservation of the ADC0 instance failed."); + } + + // Power on the ADC. + if (AM_HAL_STATUS_SUCCESS != + am_hal_adc_power_control(g_adc_handle, AM_HAL_SYSCTRL_WAKE, false)) { + error_reporter->Report("Error - ADC0 power on failed."); + } + + // Set up the ADC configuration parameters. These settings are reasonable + // for accurate measurements at a low sample rate. 
+ ADCConfig.eClock = AM_HAL_ADC_CLKSEL_HFRC_DIV2; + ADCConfig.ePolarity = AM_HAL_ADC_TRIGPOL_RISING; + ADCConfig.eTrigger = AM_HAL_ADC_TRIGSEL_SOFTWARE; + ADCConfig.eReference = + AM_HAL_ADC_REFSEL_INT_2P0; // AM_HAL_ADC_REFSEL_INT_1P5; + ADCConfig.eClockMode = AM_HAL_ADC_CLKMODE_LOW_LATENCY; + ADCConfig.ePowerMode = AM_HAL_ADC_LPMODE0; + ADCConfig.eRepeat = AM_HAL_ADC_REPEATING_SCAN; + if (AM_HAL_STATUS_SUCCESS != am_hal_adc_configure(g_adc_handle, &ADCConfig)) { + error_reporter->Report("Error - configuring ADC0 failed."); + } + + // Set up an ADC slot (2) + ADCSlotConfig.eMeasToAvg = AM_HAL_ADC_SLOT_AVG_1; + ADCSlotConfig.ePrecisionMode = AM_HAL_ADC_SLOT_14BIT; + ADCSlotConfig.eChannel = AM_HAL_ADC_SLOT_CHSEL_SE2; + ADCSlotConfig.bWindowCompare = false; + ADCSlotConfig.bEnabled = true; + if (AM_HAL_STATUS_SUCCESS != + am_hal_adc_configure_slot(g_adc_handle, 2, &ADCSlotConfig)) { + error_reporter->Report("Error - configuring ADC Slot 2 failed."); + } + + // Set up an ADC slot (1) + ADCSlotConfig.eMeasToAvg = AM_HAL_ADC_SLOT_AVG_1; + ADCSlotConfig.ePrecisionMode = AM_HAL_ADC_SLOT_14BIT; + ADCSlotConfig.eChannel = AM_HAL_ADC_SLOT_CHSEL_SE1; + ADCSlotConfig.bWindowCompare = false; + ADCSlotConfig.bEnabled = true; + if (AM_HAL_STATUS_SUCCESS != + am_hal_adc_configure_slot(g_adc_handle, 1, &ADCSlotConfig)) { + error_reporter->Report("Error - configuring ADC Slot 1 failed."); + } + + // Configure the ADC to use DMA for the sample transfer. + adc_start_dma(error_reporter); + + // For this example, the samples will be coming in slowly. This means we + // can afford to wake up for every conversion. + am_hal_adc_interrupt_enable(g_adc_handle, + AM_HAL_ADC_INT_DERR | AM_HAL_ADC_INT_DCMP); + + // Enable the ADC. + if (AM_HAL_STATUS_SUCCESS != am_hal_adc_enable(g_adc_handle)) { + error_reporter->Report("Error - enabling ADC0 failed."); + } +} + +// Initialize the ADC repetitive sample timer A3. 
+void init_timerA3_for_ADC() { + // Start a timer to trigger the ADC periodically (1 second). + am_hal_ctimer_config_single(3, AM_HAL_CTIMER_TIMERA, + AM_HAL_CTIMER_HFRC_12MHZ | + AM_HAL_CTIMER_FN_REPEAT | + AM_HAL_CTIMER_INT_ENABLE); + + am_hal_ctimer_int_enable(AM_HAL_CTIMER_INT_TIMERA3); + + // 750 = 12,000,000 (clock rate) / 16,000 (desired sample rate). + am_hal_ctimer_period_set(3, AM_HAL_CTIMER_TIMERA, 750, 374); + + // Enable the timer A3 to trigger the ADC directly + am_hal_ctimer_adc_trigger_enable(); + + // Start the timer. + am_hal_ctimer_start(3, AM_HAL_CTIMER_TIMERA); +} + +// Make sure the CPU is running as fast as possible. +void enable_burst_mode(tflite::ErrorReporter* error_reporter) { + am_hal_burst_avail_e eBurstModeAvailable; + am_hal_burst_mode_e eBurstMode; + + // Check that the Burst Feature is available. + if (AM_HAL_STATUS_SUCCESS == + am_hal_burst_mode_initialize(&eBurstModeAvailable)) { + if (AM_HAL_BURST_AVAIL == eBurstModeAvailable) { + error_reporter->Report("Apollo3 Burst Mode is Available\n"); + } else { + error_reporter->Report("Apollo3 Burst Mode is Not Available\n"); + } + } else { + error_reporter->Report("Failed to Initialize for Burst Mode operation\n"); + } + + // Put the MCU into "Burst" mode. + if (AM_HAL_STATUS_SUCCESS == am_hal_burst_mode_enable(&eBurstMode)) { + if (AM_HAL_BURST_MODE == eBurstMode) { + error_reporter->Report("Apollo3 operating in Burst Mode (96MHz)\n"); + } + } else { + error_reporter->Report("Failed to Enable Burst Mode operation\n"); + } +} + +} // namespace + +// Interrupt handler for the ADC. +extern "C" void am_adc_isr(void) { + uint32_t ui32IntMask; + + // Read the interrupt status. + if (AM_HAL_STATUS_SUCCESS != + am_hal_adc_interrupt_status(g_adc_handle, &ui32IntMask, false)) { + g_adc_dma_error_reporter->Report("Error reading ADC0 interrupt status."); + } + + // Clear the ADC interrupt. 
+ if (AM_HAL_STATUS_SUCCESS != + am_hal_adc_interrupt_clear(g_adc_handle, ui32IntMask)) { + g_adc_dma_error_reporter->Report("Error clearing ADC0 interrupt status."); + } + + // If we got a DMA complete, set the flag. + if (ui32IntMask & AM_HAL_ADC_INT_DCMP) { + uint32_t* source_buffer; + if (g_dma_destination_index == 0) { + source_buffer = g_ui32ADCSampleBuffer0; + g_dma_destination_index = 1; + } else { + source_buffer = g_ui32ADCSampleBuffer1; + g_dma_destination_index = 0; + } + adc_start_dma(g_adc_dma_error_reporter); + + // For slot 1: + uint32_t slotCount = 0; + for (uint32_t indi = 0; indi < kAdcSampleBufferSize; indi++) { + am_hal_adc_sample_t temp; + + temp.ui32Slot = AM_HAL_ADC_FIFO_SLOT(source_buffer[indi]); + temp.ui32Sample = AM_HAL_ADC_FIFO_SAMPLE(source_buffer[indi]); + + if (temp.ui32Slot == 1) { + g_audio_capture_buffer[g_audio_capture_buffer_start] = temp.ui32Sample; + g_audio_capture_buffer_start = + (g_audio_capture_buffer_start + 1) % kAudioCaptureBufferSize; + slotCount++; + } + } + + g_total_samples_captured += slotCount; + g_latest_audio_timestamp = + (g_total_samples_captured / (kAudioSampleFrequency / 1000)); + } + + // If we got a DMA error, set the flag. + if (ui32IntMask & AM_HAL_ADC_INT_DERR) { + g_adc_dma_error = true; + } +} + +TfLiteStatus InitAudioRecording(tflite::ErrorReporter* error_reporter) { + // Set the clock frequency. + if (AM_HAL_STATUS_SUCCESS != + am_hal_clkgen_control(AM_HAL_CLKGEN_CONTROL_SYSCLK_MAX, 0)) { + error_reporter->Report("Error - configuring the system clock failed."); + return kTfLiteError; + } + + // Set the default cache configuration and enable it. 
+ if (AM_HAL_STATUS_SUCCESS != + am_hal_cachectrl_config(&am_hal_cachectrl_defaults)) { + error_reporter->Report("Error - configuring the system cache failed."); + return kTfLiteError; + } + if (AM_HAL_STATUS_SUCCESS != am_hal_cachectrl_enable()) { + error_reporter->Report("Error - enabling the system cache failed."); + return kTfLiteError; + } + + // Ensure the CPU is running as fast as possible. + enable_burst_mode(error_reporter); + + // Start the CTIMER A3 for timer-based ADC measurements. + init_timerA3_for_ADC(); + + // Enable interrupts. + NVIC_EnableIRQ(ADC_IRQn); + am_hal_interrupt_master_enable(); + + // Edge Board Pin Definitions + constexpr int kSfEdgePinMic0 = 11; + const am_hal_gpio_pincfg_t g_sf_edge_pin_mic0 = { + .uFuncSel = AM_HAL_PIN_11_ADCSE2, + }; + constexpr int kSfEdgePinMic1 = 29; + const am_hal_gpio_pincfg_t g_sf_edge_pin_mic1 = { + .uFuncSel = AM_HAL_PIN_29_ADCSE1, + }; + + // Set pins to act as our ADC input + am_hal_gpio_pinconfig(kSfEdgePinMic0, g_sf_edge_pin_mic0); + am_hal_gpio_pinconfig(kSfEdgePinMic1, g_sf_edge_pin_mic1); + + // Configure the ADC + adc_config0(error_reporter); + + // Trigger the ADC sampling for the first time manually. + if (AM_HAL_STATUS_SUCCESS != am_hal_adc_sw_trigger(g_adc_handle)) { + error_reporter->Report("Error - triggering the ADC0 failed."); + return kTfLiteError; + } + + // Enable the LED outputs. 
+ am_hal_gpio_pinconfig(AM_BSP_GPIO_LED_RED, g_AM_HAL_GPIO_OUTPUT_12); + am_hal_gpio_pinconfig(AM_BSP_GPIO_LED_YELLOW, g_AM_HAL_GPIO_OUTPUT_12); + + am_hal_gpio_output_set(AM_BSP_GPIO_LED_RED); + am_hal_gpio_output_set(AM_BSP_GPIO_LED_YELLOW); + + return kTfLiteOk; +} + +TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter, + int start_ms, int duration_ms, + int* audio_samples_size, int16_t** audio_samples) { + if (!g_is_audio_initialized) { + TfLiteStatus init_status = InitAudioRecording(error_reporter); + if (init_status != kTfLiteOk) { + return init_status; + } + g_is_audio_initialized = true; + } + + // This is the 'zero' level of the microphone when no audio is present, and + // should be recalibrated if the hardware configuration ever changes. It was + // generated experimentally by averaging some samples captured on a board. + const int16_t kAdcSampleDC = 6003; + + // Temporary gain emulation to deal with too-quiet audio on prototype boards. + const int16_t kAdcSampleGain = 10; + + // This should only be called when the main thread notices that the latest + // audio sample data timestamp has changed, so that there's new data in the + // capture ring buffer. The ring buffer will eventually wrap around and + // overwrite the data, but the assumption is that the main thread is checking + // often enough and the buffer is large enough that this call will be made + // before that happens. 
+ const int start_offset = start_ms * (kAudioSampleFrequency / 1000); + const int duration_sample_count = + duration_ms * (kAudioSampleFrequency / 1000); + for (int i = 0; i < duration_sample_count; ++i) { + const int capture_index = (start_offset + i) % kAudioCaptureBufferSize; + const int32_t capture_value = g_audio_capture_buffer[capture_index]; + int32_t output_value = capture_value - kAdcSampleDC; + output_value *= kAdcSampleGain; + if (output_value < std::numeric_limits::min()) { + output_value = std::numeric_limits::min(); + } + if (output_value > std::numeric_limits::max()) { + output_value = std::numeric_limits::max(); + } + g_audio_output_buffer[i] = output_value; + } + + *audio_samples_size = kMaxAudioSampleSize; + *audio_samples = g_audio_output_buffer; + return kTfLiteOk; +} + +int32_t LatestAudioTimestamp() { return g_latest_audio_timestamp; } diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/sparkfun_edge/command_responder.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/sparkfun_edge/command_responder.cc new file mode 100644 index 0000000000000000000000000000000000000000..78469f2b7d78d70caaf0f890970d7ff666a3c452 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/sparkfun_edge/command_responder.cc @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/command_responder.h" + +#include "am_bsp.h" // NOLINT + +// This implementation will light up the LEDs on the board in response to +// different commands. +void RespondToCommand(tflite::ErrorReporter* error_reporter, + int32_t current_time, const char* found_command, + uint8_t score, bool is_new_command) { + static bool is_initialized = false; + if (!is_initialized) { + // Setup LED's as outputs + am_hal_gpio_pinconfig(AM_BSP_GPIO_LED_RED, g_AM_HAL_GPIO_OUTPUT_12); + am_hal_gpio_pinconfig(AM_BSP_GPIO_LED_BLUE, g_AM_HAL_GPIO_OUTPUT_12); + am_hal_gpio_pinconfig(AM_BSP_GPIO_LED_GREEN, g_AM_HAL_GPIO_OUTPUT_12); + am_hal_gpio_pinconfig(AM_BSP_GPIO_LED_YELLOW, g_AM_HAL_GPIO_OUTPUT_12); + is_initialized = true; + } + static int count = 0; + + // Toggle the blue LED every time an inference is performed. + ++count; + if (count & 1) { + am_hal_gpio_output_set(AM_BSP_GPIO_LED_BLUE); + } else { + am_hal_gpio_output_clear(AM_BSP_GPIO_LED_BLUE); + } + + // Turn on the yellow LED if 'yes' was heard. 
+ am_hal_gpio_output_clear(AM_BSP_GPIO_LED_RED); + am_hal_gpio_output_clear(AM_BSP_GPIO_LED_YELLOW); + am_hal_gpio_output_clear(AM_BSP_GPIO_LED_GREEN); + if (is_new_command) { + error_reporter->Report("Heard %s (%d) @%dms", found_command, score, + current_time); + if (found_command[0] == 'y') { + am_hal_gpio_output_set(AM_BSP_GPIO_LED_YELLOW); + } + if (found_command[0] == 'n') { + am_hal_gpio_output_set(AM_BSP_GPIO_LED_RED); + } + if (found_command[0] == 'u') { + am_hal_gpio_output_set(AM_BSP_GPIO_LED_GREEN); + } + } +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..e5f6ceb3f0b3935d084fa9463c72e98d4e0cad83 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.cc @@ -0,0 +1,1800 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See the header for documentation on the meaning of this data. 
+ +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.h" + +const int g_yes_1000ms_sample_data_size = 16000; +const int16_t g_yes_1000ms_sample_data[16000] = { + -7, -12, -18, -20, -20, -21, -21, -25, -29, + -31, -31, -30, -30, -29, -30, -30, -29, -28, + -24, -22, -17, -12, -8, -7, -6, -1, 2, + 5, 7, 8, 11, 15, 18, 19, 23, 24, + 24, 27, 27, 26, 25, 28, 30, 32, 33, + 31, 29, 27, 28, 30, 28, 26, 26, 24, + 22, 17, 16, 15, 13, 10, 5, 0, -4, + -4, -7, -9, -12, -14, -14, -13, -11, -10, + -8, -6, -3, 3, 7, 8, 12, 15, 18, + 21, 19, 19, 21, 23, 24, 23, 22, 19, + 17, 11, 5, -3, -12, -22, -28, -35, -45, + -54, -62, -69, -76, -84, -92, -100, -109, -116, + -117, -120, -120, -120, -122, -124, -126, -123, -121, + -116, -113, -107, -97, -88, -75, -61, -50, -41, + -27, -12, 4, 21, 37, 58, 76, 93, 108, + 121, 137, 156, 172, 184, 196, 205, 215, 224, + 235, 242, 245, 242, 240, 238, 231, 223, 214, + 205, 195, 178, 158, 135, 112, 90, 69, 46, + 19, -11, -45, -76, -105, -133, -159, -186, -211, + -236, -260, -280, -294, -308, -320, -331, -336, -338, + -335, -326, -316, -301, -286, -267, -246, -225, -203, + -180, -154, -124, -91, -59, -34, -8, 19, 42, + 64, 87, 103, 119, 134, 148, 162, 174, 182, + 188, 190, 189, 187, 184, 180, 177, 171, 162, + 154, 144, 137, 129, 118, 106, 95, 81, 69, + 58, 48, 37, 26, 14, 3, -7, -22, -31, + -42, -52, -62, -69, -75, -79, -82, -87, -88, + -92, -94, -91, -87, -85, -81, -74, -70, -64, + -55, -47, -40, -33, -25, -19, -12, -6, -4, + -1, 1, 1, -2, -9, -15, -17, -18, -20, + -22, -22, -26, -31, -33, -35, -31, -26, -17, + -4, 8, 19, 31, 44, 54, 64, 71, 79, + 86, 92, 102, 109, 111, 109, 104, 96, 84, + 70, 60, 51, 38, 27, 13, 4, -3, -9, + -13, -18, -26, -33, -32, -27, -20, -10, -4, + 2, 6, 10, 14, 16, 21, 25, 29, 31, + 33, 35, 37, 33, 22, 15, 13, 11, 12, + 9, 5, 2, 1, -3, -9, -17, -27, -32, + -35, -36, -36, -42, -50, -56, -66, -77, -85, + -96, -100, -106, -113, -118, -121, -119, -117, -119, + -122, -124, -123, -112, -94, -77, 
-64, -51, -37, + -22, -3, 17, 37, 54, 68, 86, 100, 114, + 134, 154, 167, 174, 178, 182, 189, 189, 187, + 185, 179, 177, 174, 171, 157, 138, 123, 108, + 94, 76, 50, 25, 6, -8, -20, -37, -59, + -86, -110, -132, -147, -159, -169, -178, -191, -203, + -213, -217, -215, -208, -199, -194, -195, -190, -178, + -165, -155, -144, -134, -123, -103, -80, -56, -35, + -18, -4, 11, 23, 36, 50, 65, 78, 93, + 111, 122, 129, 132, 131, 127, 125, 126, 126, + 128, 127, 125, 122, 118, 111, 108, 104, 99, + 93, 89, 90, 87, 82, 78, 75, 68, 65, + 67, 69, 66, 61, 54, 39, 28, 15, 3, + -7, -18, -25, -29, -35, -42, -52, -66, -78, + -83, -85, -86, -86, -82, -83, -84, -83, -81, + -75, -62, -57, -53, -49, -46, -41, -34, -26, + -16, -10, -7, -2, 2, 6, 12, 15, 19, + 18, 15, 17, 21, 24, 30, 33, 27, 22, + 21, 20, 23, 24, 21, 15, 13, 8, 3, + 1, -1, -3, -4, -6, -9, -11, -11, -8, + -10, -13, -15, -19, -17, -11, -2, 1, 2, + 6, 9, 10, 12, 13, 9, 8, 10, 13, + 20, 18, 13, 10, 4, 1, -2, -6, -11, + -13, -16, -18, -15, -18, -21, -21, -22, -23, + -25, -23, -22, -20, -19, -16, -12, -10, -9, + -11, -15, -19, -22, -19, -14, -11, -9, -11, + -17, -20, -18, -19, -15, -11, -8, -2, 8, + 19, 30, 36, 37, 36, 38, 45, 57, 69, + 77, 81, 79, 75, 76, 74, 69, 66, 60, + 53, 45, 36, 28, 22, 17, 10, 0, -5, + -11, -15, -18, -26, -31, -33, -34, -34, -35, + -37, -37, -35, -28, -24, -29, -37, -45, -46, + -41, -36, -31, -32, -33, -37, -37, -36, -36, + -34, -27, -19, -14, -11, -8, -1, 6, 14, + 19, 21, 25, 30, 34, 38, 38, 33, 26, + 22, 19, 20, 18, 17, 15, 10, 2, -3, + -5, -10, -13, -13, -13, -16, -16, -16, -15, + -13, -14, -13, -16, -19, -20, -18, -17, -18, + -16, -16, -24, -28, -28, -28, -23, -21, -21, + -20, -24, -27, -23, -18, -14, -7, 4, 11, + 15, 19, 21, 25, 33, 39, 41, 45, 47, + 50, 56, 58, 57, 59, 59, 55, 50, 47, + 39, 34, 30, 24, 18, 11, 8, 3, 0, + -3, -8, -14, -15, -13, -13, -12, -14, -17, + -17, -12, -10, -4, -7, -12, -10, -14, -17, + -17, -19, -25, -28, -27, -29, -30, -31, -35, + -38, -43, -47, -51, -52, -50, -49, -48, -47, + 
-45, -39, -32, -30, -31, -35, -35, -31, -24, + -17, -12, -11, -14, -15, -17, -16, -9, -5, + -3, -1, 0, 1, 0, 3, 12, 21, 26, + 33, 35, 38, 45, 50, 53, 53, 54, 58, + 61, 64, 69, 67, 66, 64, 58, 54, 51, + 46, 44, 45, 41, 35, 31, 27, 25, 27, + 25, 20, 13, 12, 16, 17, 17, 12, 7, + 3, 2, -2, -4, -8, -14, -19, -25, -29, + -38, -49, -60, -69, -73, -71, -74, -82, -89, + -98, -103, -104, -103, -99, -98, -98, -98, -99, + -97, -94, -91, -85, -82, -78, -74, -74, -71, + -68, -61, -54, -52, -47, -41, -36, -32, -21, + -12, -3, 11, 26, 36, 44, 48, 55, 64, + 77, 92, 100, 108, 117, 120, 122, 128, 130, + 129, 130, 127, 124, 122, 121, 118, 114, 110, + 102, 92, 85, 80, 77, 68, 55, 46, 39, + 36, 34, 31, 27, 15, 5, -1, -5, -11, + -20, -29, -37, -43, -46, -47, -54, -61, -65, + -74, -82, -84, -91, -94, -96, -104, -109, -111, + -111, -112, -113, -111, -112, -110, -104, -99, -96, + -93, -89, -87, -81, -71, -63, -54, -45, -43, + -37, -30, -24, -17, -12, -8, -2, 2, 15, + 23, 28, 35, 41, 42, 44, 52, 58, 66, + 74, 78, 80, 82, 85, 88, 90, 92, 92, + 88, 87, 87, 79, 73, 69, 64, 62, 55, + 50, 45, 41, 36, 29, 24, 20, 16, 12, + 8, 5, 2, 1, 1, 0, 1, -4, -4, + -4, -4, -1, 1, 2, 1, -3, -6, -1, + 5, 6, 7, 8, 4, 2, 0, -2, -3, + 0, -3, -4, -3, -4, -5, -8, -15, -20, + -25, -28, -32, -37, -38, -39, -43, -48, -55, + -62, -69, -75, -75, -78, -81, -83, -89, -89, + -92, -91, -91, -89, -83, -81, -74, -66, -63, + -54, -45, -39, -31, -23, -15, -4, 6, 14, + 23, 29, 35, 41, 45, 49, 55, 61, 69, + 75, 75, 76, 75, 74, 74, 73, 74, 72, + 69, 69, 65, 62, 57, 52, 44, 35, 33, + 29, 24, 14, 7, 3, -4, -12, -17, -20, + -22, -27, -32, -34, -39, -42, -43, -42, -43, + -40, -38, -36, -36, -37, -36, -33, -31, -27, + -24, -23, -22, -17, -11, -7, -7, -7, -3, + 5, 13, 19, 25, 27, 25, 27, 35, 40, + 40, 41, 45, 47, 50, 54, 52, 50, 45, + 43, 44, 40, 34, 28, 24, 18, 11, 6, + -2, -9, -14, -21, -27, -35, -39, -43, -50, + -57, -62, -66, -68, -71, -72, -73, -74, -76, + -76, -77, -75, -75, -74, -67, -61, -55, -49, + -45, -40, -30, -21, -11, -4, 
4, 13, 23, + 34, 44, 52, 59, 65, 70, 77, 84, 87, + 88, 90, 91, 90, 89, 85, 80, 75, 72, + 71, 64, 56, 48, 41, 34, 27, 21, 12, + 1, -11, -19, -28, -33, -39, -46, -50, -53, + -58, -63, -66, -71, -73, -76, -76, -74, -73, + -71, -67, -65, -62, -60, -55, -51, -45, -39, + -35, -31, -27, -20, -13, -6, -3, 1, 8, + 12, 18, 24, 26, 30, 35, 38, 44, 47, + 47, 51, 53, 52, 53, 52, 50, 51, 49, + 50, 51, 50, 48, 48, 45, 43, 42, 37, + 34, 31, 31, 30, 26, 24, 21, 15, 12, + 11, 7, 4, 1, -3, -5, -7, -9, -15, + -21, -26, -28, -31, -35, -39, -46, -48, -49, + -53, -58, -63, -67, -69, -71, -72, -74, -75, + -77, -77, -73, -72, -69, -65, -60, -55, -50, + -47, -43, -38, -30, -25, -20, -12, -4, 4, + 9, 16, 20, 24, 28, 35, 43, 50, 58, + 61, 65, 72, 74, 74, 76, 79, 78, 76, + 78, 76, 76, 74, 70, 64, 59, 52, 46, + 41, 33, 26, 19, 12, 5, -2, -8, -15, + -20, -26, -31, -37, -39, -41, -44, -44, -47, + -51, -52, -52, -48, -45, -46, -48, -45, -42, + -40, -36, -32, -27, -24, -22, -18, -16, -11, + -10, -5, 0, 3, 8, 11, 16, 18, 21, + 23, 25, 26, 27, 28, 30, 31, 31, 30, + 29, 27, 26, 23, 19, 17, 13, 10, 6, + 0, -2, -5, -10, -12, -15, -19, -23, -26, + -29, -30, -30, -32, -33, -34, -35, -34, -31, + -29, -29, -28, -28, -23, -19, -17, -12, -12, + -10, -5, -2, 3, 7, 10, 13, 14, 19, + 22, 26, 31, 34, 34, 35, 36, 39, 43, + 45, 47, 47, 48, 49, 51, 48, 47, 50, + 45, 41, 41, 38, 34, 34, 30, 23, 17, + 11, 7, 4, -4, -9, -15, -23, -28, -32, + -35, -39, -45, -46, -49, -53, -52, -53, -55, + -56, -56, -55, -54, -53, -53, -51, -47, -44, + -42, -40, -37, -33, -28, -25, -23, -18, -15, + -8, -6, -2, 3, 8, 15, 18, 23, 26, + 27, 32, 36, 36, 36, 39, 38, 38, 40, + 39, 35, 31, 29, 25, 23, 19, 15, 11, + 7, 5, 3, 1, -1, -6, -8, -7, -10, + -9, -10, -11, -10, -7, -6, -8, -6, -5, + -4, 1, 2, 4, 7, 7, 9, 11, 11, + 9, 9, 10, 11, 13, 17, 15, 15, 15, + 17, 19, 17, 17, 17, 15, 15, 13, 11, + 12, 8, 7, 5, 3, 0, -4, -4, -6, + -9, -12, -14, -15, -15, -16, -20, -19, -20, + -20, -20, -18, -18, -21, -22, -21, -21, -23, + -20, -20, -23, -24, -23, 
-25, -25, -25, -25, + -26, -24, -23, -23, -23, -23, -22, -19, -18, + -15, -14, -10, -8, -4, -1, 1, 3, 6, + 8, 9, 14, 19, 22, 24, 26, 29, 32, + 31, 34, 39, 42, 42, 46, 49, 50, 50, + 52, 53, 52, 49, 49, 48, 48, 46, 45, + 40, 34, 30, 25, 21, 17, 13, 10, 6, + 2, -4, -9, -12, -15, -18, -21, -26, -28, + -31, -32, -33, -35, -35, -38, -37, -36, -34, + -35, -35, -33, -33, -34, -30, -26, -27, -25, + -23, -22, -18, -15, -16, -12, -9, -9, -6, + -1, 2, 3, 5, 8, 7, 9, 12, 15, + 17, 18, 18, 19, 18, 20, 19, 18, 21, + 20, 19, 18, 16, 15, 15, 15, 14, 12, + 9, 9, 10, 8, 6, 4, 2, 1, -1, + -3, -1, -3, -2, -4, -5, -5, -8, -8, + -10, -10, -8, -8, -8, -7, -8, -8, -8, + -9, -11, -12, -11, -9, -7, -8, -8, -8, + -10, -8, -7, -8, -7, -6, -7, -5, -3, + -3, -3, -3, -2, 0, 3, 3, 5, 7, + 10, 11, 10, 10, 12, 13, 16, 16, 16, + 17, 15, 16, 17, 16, 14, 16, 13, 11, + 11, 9, 9, 6, 4, 4, 3, 0, -2, + -4, -7, -7, -7, -13, -15, -13, -14, -16, + -15, -15, -17, -16, -16, -18, -19, -19, -20, + -19, -16, -15, -13, -12, -10, -7, -6, -4, + -4, -2, 0, 2, 6, 8, 10, 12, 14, + 15, 14, 13, 13, 13, 15, 15, 17, 17, + 17, 18, 17, 16, 15, 15, 14, 11, 9, + 8, 8, 9, 8, 5, 5, 3, -1, -1, + -4, -5, -7, -8, -8, -8, -9, -10, -8, + -11, -12, -12, -12, -12, -13, -11, -11, -9, + -8, -7, -8, -7, -6, -7, -6, -5, -4, + -4, -2, -2, -3, -2, -2, -3, 0, -1, + -3, 1, 1, 2, 4, 3, 5, 6, 3, + 3, 4, 3, 3, 4, 5, 4, 6, 7, + 7, 7, 6, 3, 3, 5, 3, 3, 6, + 6, 7, 6, 4, 5, 2, 1, 1, 0, + 0, 2, 1, 1, 1, -1, -2, -3, -5, + -4, -5, -4, -4, -6, -4, -4, -4, -5, + -6, -5, -6, -5, -4, -5, -4, -3, -4, + 0, 2, 2, 2, 2, 2, 2, 3, 3, + 5, 6, 6, 5, 6, 7, 6, 8, 6, + 5, 5, 5, 6, 6, 6, 5, 5, 2, + 2, 1, 2, 0, -1, -1, -1, -1, 0, + -1, -4, -6, -8, -8, -9, -8, -7, -6, + -5, -5, -6, -3, -4, -5, -4, -7, -6, + -4, -2, -1, -1, 1, 1, 1, 1, 1, + 2, 2, 1, 3, 4, 4, 6, 6, 6, + 6, 4, 4, 4, 4, 3, 2, 2, 2, + 2, 1, 1, 1, 0, 1, 1, 0, -2, + -2, -3, -3, -3, -3, -5, -4, -3, -5, + -5, -3, -5, -4, -4, -2, -2, -2, -1, + -3, -2, -2, -1, -3, -2, -1, -2, -2, + -2, 0, 0, 0, 0, 0, 1, 0, 0, + 1, 
2, 3, 3, 3, 4, 5, 4, 3, + 4, 5, 5, 7, 7, 6, 9, 8, 6, + 7, 8, 6, 5, 7, 8, 8, 8, 7, + 6, 5, 4, 4, 4, 5, 4, 2, 1, + 2, 1, 0, -2, -3, -2, -4, -6, -6, + -7, -7, -8, -9, -9, -9, -9, -9, -9, + -9, -10, -10, -10, -8, -7, -8, -6, -5, + -4, -3, -5, -2, -2, -2, -1, -1, 0, + 1, 1, 2, 3, 2, 4, 3, 3, 5, + 3, 3, 5, 4, 5, 6, 5, 4, 5, + 3, 2, 2, 3, 4, 4, 4, 4, 4, + 3, 4, 4, 4, 3, 2, 2, 2, 2, + 2, 2, 2, 2, 1, 1, 1, 2, 1, + 1, 2, 1, 1, 2, 1, 1, 1, -1, + 0, 1, 0, -1, 1, -1, -1, -1, -2, + -1, -1, -1, -1, -1, -1, -1, -1, -2, + -1, 0, -1, -1, 1, 1, 2, 0, -1, + 0, -1, -1, 0, 0, 1, 2, 2, 2, + 1, 1, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, -1, -2, -1, -3, -4, + -4, -4, -4, -4, -4, -4, -3, -3, -5, + -6, -4, -2, -2, -1, -1, -1, -2, 1, + -1, 1, 0, 0, 1, 1, 1, 1, 2, + 1, 2, 2, 3, 3, 3, 3, 4, 5, + 5, 5, 5, 5, 5, 5, 5, 6, 6, + 5, 5, 5, 6, 6, 5, 3, 6, 5, + 4, 5, 3, 2, 2, 2, 2, 1, 1, + 2, 0, -1, 0, -1, -1, -1, -1, -1, + -1, -1, -3, -3, -3, -3, -4, -4, -5, + -6, -6, -6, -6, -6, -6, -5, -5, -6, + -5, -4, -4, -4, -4, -2, -2, -2, -1, + -2, 0, 1, 0, 1, 3, 4, 4, 4, + 4, 4, 4, 5, 4, 4, 4, 5, 7, + 5, 4, 4, 4, 4, 3, 2, 2, 2, + 2, 2, 0, 1, 1, 0, 1, 1, -1, + 0, -1, -2, -1, -3, -4, -4, -3, -5, + -5, -5, -5, -5, -5, -4, -3, -3, -2, + -3, -2, -2, -5, -3, -3, -3, -2, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 3, + 3, 4, 4, 4, 4, 5, 5, 2, 3, + 4, 3, 5, 4, 3, 4, 3, 3, 5, + 5, 3, 4, 2, 1, 1, 3, 4, 3, + 1, 3, 2, 1, 2, 1, 0, 1, 0, + 1, 0, 1, 1, 1, 1, 0, -1, 0, + 0, -1, -1, -2, -1, -1, -2, 0, -1, + -2, -1, -1, -2, -2, -1, -3, -3, -3, + -3, -3, -4, -3, -5, -6, -4, -4, -5, + -4, -3, -5, -6, -4, -5, -6, -4, -3, + -5, -4, -3, -4, -3, -2, -2, -2, 0, + 0, 1, 1, 0, 0, 0, 1, 1, 3, + 3, 3, 4, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, + 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, -2, -1, 1, 0, -1, -2, -2, + 0, 1, 0, 1, 1, 1, 1, 0, 0, + 1, 0, 0, 2, 1, 0, 1, 1, 1, + 1, 3, 3, 3, 4, 3, 3, 4, 2, + 2, 2, 2, 2, 2, 2, 1, 2, 2, + 2, 2, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -3, -3, -3, -5, -4, -5, -5, + -5, -5, -7, -7, -7, -8, -7, -8, -7, + -8, -8, -7, -8, -8, 
-8, -8, -7, -6, + -6, -6, -7, -6, -6, -5, -5, -3, -2, + -2, -1, 0, -1, 0, 1, 2, 2, 3, + 3, 3, 6, 7, 7, 7, 8, 9, 8, + 10, 10, 9, 10, 11, 9, 10, 12, 11, + 10, 9, 9, 9, 9, 10, 9, 6, 6, + 5, 5, 6, 3, 1, 1, 0, 1, 0, + 0, 1, -1, -2, -2, -1, -3, -3, -2, + -4, -4, -3, -2, -4, -4, -4, -5, -3, + -3, -5, -3, -3, -5, -4, -2, -2, -3, + -3, -1, 0, -1, 0, 0, 0, -2, -1, + 0, -1, -2, -2, -2, -2, -1, -3, -2, + -3, -4, -3, -3, -3, -3, -3, -3, -3, + -2, -4, -6, -5, -3, -2, -4, -3, -2, + -4, -4, -4, -3, -4, -5, -4, -5, -3, + -2, -5, -2, -4, -4, -3, -2, -1, -1, + -1, 0, 2, 2, 1, 1, 3, 3, 3, + 3, 4, 4, 5, 6, 5, 5, 6, 7, + 7, 7, 8, 8, 7, 9, 9, 9, 9, + 10, 9, 9, 9, 9, 9, 9, 8, 7, + 9, 9, 6, 7, 5, 2, 3, 2, 1, + 1, 0, -2, -2, -2, -3, -3, -2, -2, + -4, -5, -4, -4, -4, -4, -5, -4, -4, + -5, -4, -5, -4, -5, -6, -4, -4, -5, + -5, -5, -5, -6, -4, -4, -4, -3, -2, + -3, -3, -2, -2, -1, -2, -3, -1, 0, + -1, 0, 0, 0, 0, 1, 0, 0, 0, + 0, -1, 1, 1, 1, 0, -2, -2, -3, + -3, -4, -4, -6, -7, -5, -4, -5, -5, + -4, -6, -8, -7, -6, -5, -5, -5, -4, + -4, -5, -4, -3, -3, 0, 0, -2, -1, + 0, 0, 1, 1, 2, 2, 2, 2, 2, + 4, 5, 5, 5, 6, 7, 7, 9, 10, + 10, 10, 12, 12, 13, 14, 14, 14, 15, + 15, 15, 15, 15, 15, 14, 15, 15, 12, + 13, 13, 12, 10, 11, 11, 11, 10, 8, + 6, 5, 7, 6, 6, 4, 3, 4, 5, + 3, 2, 2, 1, 1, 2, 3, 1, 0, + 0, 1, 0, -2, -1, -2, -3, -3, -3, + -3, -4, -6, -8, -9, -9, -10, -12, -14, + -15, -18, -21, -21, -21, -21, -22, -24, -26, + -26, -27, -27, -28, -26, -25, -26, -28, -27, + -24, -23, -23, -24, -21, -17, -17, -15, -12, + -12, -12, -12, -9, -7, -6, -5, -3, -3, + -2, 0, 0, 1, 3, 7, 6, 4, 6, + 7, 8, 11, 10, 10, 13, 15, 14, 13, + 18, 20, 18, 19, 21, 23, 24, 23, 22, + 24, 26, 26, 26, 27, 25, 23, 25, 27, + 28, 28, 28, 23, 19, 23, 24, 20, 20, + 21, 15, 13, 15, 16, 14, 11, 8, 7, + 8, 11, 11, 6, 4, 8, 7, 6, 7, + 6, 4, 7, 13, 12, 7, 8, 8, 4, + 1, 1, 1, 2, -4, -12, -18, -24, -25, + -25, -32, -41, -55, -59, -61, -75, -87, -96, + -109, -122, -133, -141, -148, -157, -168, -180, -191, + -198, -202, -207, -206, -207, -211, 
-211, -208, -203, + -189, -171, -153, -132, -114, -96, -75, -54, -30, + -5, 19, 43, 61, 77, 93, 106, 123, 143, + 161, 182, 198, 202, 201, 209, 229, 242, 240, + 235, 239, 249, 258, 255, 242, 233, 245, 268, + 278, 256, 223, 223, 253, 263, 235, 198, 178, + 188, 215, 230, 200, 143, 113, 128, 158, 158, + 128, 99, 90, 82, 70, 56, 32, 7, 14, + 46, 36, -23, -71, -76, -54, -36, -39, -74, + -118, -134, -122, -101, -104, -129, -164, -174, -129, + -86, -109, -184, -219, -191, -147, -141, -183, -249, + -290, -269, -236, -266, -346, -394, -366, -325, -353, + -431, -472, -406, -313, -316, -398, -449, -401, -287, + -194, -164, -193, -245, -212, -55, 75, 67, 26, + 67, 165, 237, 269, 293, 319, 333, 368, 414, + 432, 463, 488, 448, 404, 391, 377, 361, 365, + 376, 308, 197, 150, 129, 73, 53, 91, 43, + -107, -165, -54, 1, -148, -312, -273, -125, -62, + -128, -258, -294, -141, 70, 57, -217, -378, -145, + 198, 289, 169, -47, -219, -101, 264, 458, 217, + -163, -199, 13, 121, 101, -51, -293, -319, -62, + 24, -274, -474, -296, -170, -336, -422, -285, -248, + -302, -130, 98, -11, -257, -146, 184, 278, 264, + 331, 192, -35, 235, 805, 830, 315, 82, 322, + 503, 522, 619, 557, 242, 163, 399, 507, 489, + 618, 602, 156, -164, 112, 476, 406, 94, -154, + -242, -132, 56, 5, -325, -566, -527, -478, -624, + -692, -561, -551, -744, -836, -671, -520, -626, -736, + -647, -581, -639, -687, -702, -739, -665, -383, -236, + -414, -513, -321, -114, -43, 32, 65, -98, -236, + 34, 608, 924, 680, 218, 56, 329, 847, 1214, + 1006, 341, 11, 340, 667, 553, 353, 355, 415, + 416, 364, 257, 108, 6, 113, 293, 233, 46, + 4, 25, -10, -12, 55, 40, -65, -56, -26, + -101, -61, 143, 229, 78, -161, -210, 103, 424, + 377, 86, -274, -491, -328, -37, 60, 128, 188, + -105, -625, -823, -464, 138, 389, 111, -343, -526, + -306, 13, 205, 250, -35, -554, -764, -498, -42, + 167, -210, -639, -448, -101, -110, -171, -74, -39, + 47, 424, 616, 324, 98, 367, 853, 942, 416, + -184, -130, 339, 472, 369, 239, -165, -418, 101, + 742, 659, 325, 365, 
476, 233, -14, 270, 785, + 719, -29, -533, -220, 237, 305, 179, -190, -644, + -610, -380, -526, -601, -237, 48, -36, -124, -49, + -6, 23, 117, 55, -199, -428, -512, -338, -238, + -424, -323, -135, -464, -657, -189, 100, -379, -964, + -893, -346, -64, -322, -650, -480, 32, 238, 201, + 386, 616, 611, 400, 195, 357, 842, 1051, 832, + 712, 829, 1070, 1307, 1081, 551, 363, 544, 623, + 239, -374, -609, -230, 375, 486, -52, -446, -270, + 181, 645, 601, -135, -654, -256, 567, 840, 380, + -54, 18, 334, 386, 21, -214, 83, 243, -316, + -937, -1074, -1006, -896, -674, -424, -331, -354, -380, + -481, -392, 80, 358, 171, -170, -624, -796, -130, + 706, 803, 381, 152, 367, 620, 685, 655, 347, + 36, 180, 417, 412, 358, 288, 189, 150, 16, + -240, -428, -428, -266, -335, -819, -1150, -946, -587, + -437, -580, -961, -1218, -1065, -704, -431, -350, -315, + -214, -162, -81, 26, -8, -52, -117, -226, -40, + 285, 241, -2, -69, 57, 207, 81, -144, -69, + 65, 84, 49, -168, -248, 126, 502, 472, 192, + 120, 442, 667, 551, 512, 634, 814, 1014, 1098, + 1156, 1112, 974, 1144, 1330, 1099, 825, 847, 877, + 555, 2, -243, -102, -196, -471, -377, -235, -439, + -622, -547, -470, -495, -431, -197, -21, 21, -9, + -246, -438, -238, -31, 0, 96, 137, -25, -211, + -181, -149, -350, -368, -33, 21, -308, -323, 32, + 379, 605, 531, 85, -374, -367, 9, 277, 147, + -356, -698, -494, -140, -126, -354, -549, -673, -642, + -428, -269, -273, -246, -216, -349, -323, -16, 32, + -387, -742, -662, -434, -223, 41, 140, -58, -227, + -80, 93, 20, -166, -360, -536, -555, -305, -33, + -23, -86, -75, -9, 82, -1, -156, 24, 532, + 916, 956, 835, 901, 1127, 1279, 1417, 1435, 1144, + 822, 862, 1214, 1352, 1001, 611, 539, 532, 369, + 189, 170, 308, 465, 430, 232, 64, 14, 51, + -37, -244, -321, -276, -144, 57, 77, -215, -467, + -335, -186, -245, -133, -81, -588, -1130, -959, -520, + -631, -1122, -1270, -971, -873, -1118, -1157, -1078, -1296, + -1365, -1010, -873, -1138, -1061, -379, 89, 51, 177, + 372, 185, -14, 63, 197, 125, -123, 
-60, 243, + 195, 88, 201, 115, -63, -12, -79, -492, -751, + -489, 49, 163, -293, -424, -52, 229, 302, 212, + 217, 315, 70, -207, -210, -173, 129, 619, 556, + 213, 181, 170, 112, 167, 322, 451, 206, -136, + 58, 426, 526, 524, 394, 387, 568, 481, 297, + 164, 8, 263, 664, 777, 943, 989, 934, 1283, + 1495, 1153, 861, 738, 582, 614, 692, 655, 629, + 432, 127, -119, -338, -313, -138, -204, -561, -994, + -1168, -948, -700, -658, -788, -1053, -1027, -684, -566, + -528, -355, -335, -323, -28, 206, 87, 56, 387, + 585, 296, 24, 261, 492, 248, -132, -469, -674, + -502, -235, -255, -517, -847, -1038, -965, -707, -630, + -767, -639, -298, -193, -290, -310, -118, 74, -77, + -337, -324, -120, 187, 323, -72, -552, -454, -14, + 29, -427, -803, -735, -586, -762, -918, -783, -649, + -723, -857, -786, -626, -591, -417, -83, 167, 262, + 49, -161, 157, 842, 1298, 1356, 1206, 1041, 1194, + 1461, 1323, 1070, 1221, 1687, 2051, 2002, 1673, 1464, + 1550, 1851, 1907, 1531, 1327, 1399, 1342, 1287, 1264, + 1152, 1030, 878, 716, 601, 454, 264, 264, 352, + 151, -193, -296, -161, -93, -215, -423, -617, -668, + -547, -416, -464, -807, -1175, -1174, -1045, -1076, -1023, + -829, -710, -745, -1069, -1443, -1417, -1099, -939, -1165, + -1307, -1056, -843, -638, -304, -190, -334, -578, -770, + -705, -675, -947, -957, -565, -437, -617, -843, -1015, + -813, -489, -584, -904, -1054, -797, -229, -26, -208, + -66, 398, 710, 644, 390, 413, 726, 992, 1204, + 1337, 1234, 1104, 1038, 1001, 1043, 982, 847, 885, + 1024, 1098, 1138, 1108, 1038, 966, 885, 882, 878, + 929, 1005, 944, 1008, 1284, 1415, 1289, 1007, 760, + 812, 947, 806, 455, 111, -72, -290, -611, -626, + -559, -765, -1034, -1375, -1632, -1565, -1588, -1728, -1585, + -1477, -1547, -1533, -1371, -1103, -995, -1090, -1102, -947, + -686, -403, -295, -250, -107, -86, -171, -150, 12, + 234, 283, 185, 300, 461, 393, 382, 434, 378, + 306, 202, 195, 253, -8, -307, -105, 264, 342, + 212, 34, -57, 78, 435, 571, 180, -165, -51, + 339, 705, 683, 464, 658, 958, 825, 
579, 465, + 390, 241, 61, 202, 429, 128, -122, 241, 406, + 39, -167, -60, 15, -31, -68, 146, 402, 344, + 227, 208, 87, -25, -31, -66, -169, -249, -87, + 75, -181, -438, -249, 49, 87, -40, -16, 53, + -86, -74, 98, 78, 110, 169, -84, -323, -251, + -102, -172, -513, -750, -675, -568, -587, -583, -523, + -450, -302, -245, -356, -480, -590, -495, -183, -105, + -191, -215, -308, -206, 39, 4, -77, -21, 74, + 186, 218, 356, 611, 489, 83, 13, 246, 371, + 348, 240, 61, -66, -107, -170, -205, -74, 200, + 277, 45, -11, 180, 263, 100, -74, 102, 246, + 6, -154, -162, -197, -128, -189, -227, -49, -238, + -490, -333, -188, 1, 215, 150, 144, 128, -33, + 187, 532, 676, 911, 773, 283, 351, 673, 620, + 349, 105, 205, 425, 325, 295, 372, 340, 511, + 628, 394, 224, 187, 91, -174, -556, -482, -37, + -9, -226, -382, -568, -466, -208, -241, -426, -656, + -814, -788, -902, -1065, -946, -860, -896, -831, -744, + -672, -685, -743, -723, -783, -813, -570, -341, -239, + -57, 137, 348, 576, 593, 454, 429, 503, 449, + 238, 173, 350, 423, 419, 530, 501, 272, 156, + 207, 295, 404, 568, 676, 419, 30, 113, 463, + 550, 473, 349, 126, 33, 144, 207, 193, 267, + 304, 81, -252, -401, -368, -347, -404, -452, -408, + -272, -40, 234, 281, 48, -72, -18, 54, 208, + 309, 285, 245, 164, 38, -20, 148, 430, 563, + 655, 679, 453, 300, 319, 219, 25, -15, 54, + -117, -444, -431, -135, -147, -468, -667, -722, -593, + -301, -217, -428, -642, -598, -400, -422, -602, -628, + -554, -509, -501, -541, -488, -250, -129, -284, -441, + -358, -161, -82, 4, 134, 157, 290, 516, 582, + 702, 859, 871, 858, 759, 615, 616, 754, 839, + 725, 464, 259, 187, 127, 150, 280, 238, 92, + 78, 5, -86, 6, 67, -14, -92, -143, -211, + -89, 213, 300, 107, -91, -154, -153, -238, -355, + -314, -227, -168, -92, -142, -219, -156, -47, 53, + -15, -195, -161, -186, -382, -395, -297, -238, -240, + -390, -502, -336, -97, -29, -116, -290, -289, -67, + 74, 112, 119, 182, 358, 382, 315, 341, 290, + 218, 190, 101, -51, -168, -132, -41, -39, -15, + 104, 186, 
151, 68, 89, 154, 67, 10, 143, + 120, -185, -382, -365, -263, -145, -111, -159, -190, + -53, 151, 177, 179, 384, 553, 502, 490, 572, + 600, 573, 442, 119, -212, -260, -166, -318, -506, + -413, -279, -285, -354, -390, -278, -142, -85, -18, + -19, -121, -143, -32, 88, 118, 42, -96, -187, + -167, -113, -172, -270, -256, -178, -192, -249, -128, + 103, 132, -47, -147, -104, -56, -9, 45, 35, + 109, 315, 381, 326, 336, 457, 667, 786, 675, + 489, 460, 569, 595, 470, 303, 272, 448, 620, + 545, 226, -92, -128, 91, 172, -98, -385, -378, + -264, -284, -362, -314, -148, -72, -198, -350, -353, + -344, -389, -353, -292, -327, -413, -473, -519, -588, + -577, -546, -737, -989, -1030, -997, -1010, -861, -683, + -731, -690, -419, -197, -47, 112, 167, 74, 41, + 176, 309, 438, 671, 781, 793, 868, 904, 991, + 1099, 987, 812, 816, 869, 766, 605, 633, 728, + 592, 424, 460, 405, 170, 75, 30, -105, -58, + 63, -58, -242, -359, -415, -255, -44, -127, -266, + -191, -187, -296, -273, -260, -341, -345, -324, -384, + -467, -421, -233, -125, -227, -341, -256, -168, -217, + -249, -302, -447, -425, -274, -289, -299, -229, -275, + -272, -103, -57, -117, -106, -162, -256, -184, -31, + 51, 69, 31, -19, 72, 256, 318, 331, 254, + 28, -7, 121, 48, -64, 58, 183, 152, 161, + 201, 167, 190, 287, 278, 157, 56, 103, 332, + 460, 299, 166, 238, 308, 374, 508, 509, 373, + 275, 270, 298, 229, 185, 192, 23, -160, -80, + 67, 31, -170, -378, -384, -330, -500, -648, -615, + -686, -716, -510, -510, -771, -752, -475, -434, -556, + -480, -403, -515, -464, -255, -177, -105, 29, 95, + 152, 210, 190, 180, 279, 408, 325, 225, 462, + 607, 537, 759, 1022, 973, 945, 964, 846, 818, + 952, 907, 584, 313, 302, 428, 533, 479, 260, + 178, 262, 185, 18, -77, -263, -370, -208, -240, + -589, -739, -572, -444, -405, -357, -475, -738, -771, + -542, -441, -529, -651, -803, -823, -556, -285, -227, + -233, -202, -168, -110, -78, -220, -302, -56, 129, + -60, -149, 54, 130, 169, 324, 231, 24, 89, + 269, 320, 262, 231, 225, 138, 67, 153, 310, 
+ 399, 269, -21, -197, -183, -59, 144, 234, -13, + -274, -168, 32, -37, -277, -417, -441, -416, -324, + -312, -467, -540, -373, -166, -161, -297, -365, -341, + -246, -69, 81, 99, -3, 11, 305, 540, 449, + 394, 586, 667, 606, 685, 665, 425, 410, 585, + 509, 360, 424, 538, 583, 482, 250, 159, 310, + 423, 217, -131, -280, -204, -51, -12, -204, -338, + -232, -143, -201, -306, -374, -336, -229, -257, -453, + -576, -497, -379, -326, -302, -372, -504, -453, -229, + -133, -226, -328, -326, -261, -151, -6, 97, 143, + 164, 143, 138, 267, 433, 500, 470, 297, 143, + 279, 504, 556, 475, 333, 233, 225, 228, 198, + 128, 24, -17, 4, -55, -187, -251, -213, -119, + -94, -214, -357, -349, -246, -195, -183, -261, -440, + -533, -476, -341, -213, -170, -220, -299, -220, -8, + 51, -11, 19, 172, 292, 189, 9, -6, 102, + 238, 384, 477, 448, 353, 304, 354, 473, 543, + 400, 229, 275, 380, 425, 415, 371, 398, 460, + 377, 202, 154, 199, 110, -123, -365, -524, -524, + -360, -134, -47, -182, -348, -453, -542, -503, -376, + -398, -521, -595, -621, -560, -439, -284, -115, -80, + -123, -57, 28, -15, -60, -9, 47, 119, 203, + 288, 435, 571, 635, 706, 750, 627, 436, 345, + 330, 398, 460, 368, 213, 127, 140, 215, 202, + 58, -99, -244, -387, -470, -527, -637, -754, -791, + -768, -742, -739, -735, -704, -649, -552, -479, -491, + -494, -454, -433, -422, -398, -315, -115, 75, 175, + 244, 307, 360, 398, 460, 532, 529, 446, 422, + 497, 541, 504, 541, 702, 803, 744, 645, 621, + 727, 877, 873, 734, 593, 513, 523, 516, 412, + 336, 334, 274, 199, 163, 123, 125, 117, 107, + 140, 72, -73, -114, -68, -15, 13, -122, -338, + -367, -325, -386, -497, -608, -634, -546, -477, -427, + -377, -412, -464, -436, -343, -276, -327, -390, -313, + -149, -17, 2, -93, -146, -104, -76, -87, -131, + -224, -280, -194, -46, 12, -76, -189, -151, 18, + 160, 200, 99, -81, -149, -95, -31, -6, -45, + -93, -97, -71, 0, 73, 34, -82, -129, -102, + -84, -96, -107, -69, -5, 6, 18, 48, 35, + 27, 32, -4, -71, -30, 119, 205, 266, 352, + 325, 237, 
282, 352, 358, 342, 265, 203, 200, + 159, 120, 159, 195, 185, 133, 37, 20, 152, + 312, 363, 316, 255, 251, 259, 211, 160, 86, + -4, -30, -79, -154, -213, -271, -243, -146, -147, + -211, -283, -319, -219, -157, -207, -237, -252, -245, + -136, 0, 42, -22, -108, -82, 34, 130, 179, + 152, 98, 105, 110, 116, 180, 175, 66, -9, + -9, 36, 82, 75, 12, -39, -14, 23, 1, + 12, 31, -61, -155, -184, -158, -86, -60, -67, + -63, -84, -100, -81, -115, -171, -157, -150, -179, + -191, -209, -245, -217, -128, -54, -42, -73, -100, + -88, -10, 104, 199, 249, 227, 201, 204, 151, + 83, 75, 87, 84, 67, 34, 18, 44, 110, + 218, 275, 232, 190, 209, 263, 294, 256, 174, + 108, 37, -54, -110, -129, -179, -293, -360, -339, + -282, -190, -135, -188, -239, -234, -227, -182, -127, + -89, -51, -73, -136, -151, -85, 0, 72, 129, + 122, 65, 44, 103, 202, 272, 252, 170, 148, + 167, 152, 130, 127, 79, 14, 70, 157, 142, + 109, 70, -25, -57, -6, 46, 98, 135, 135, + 82, 16, 10, 68, 87, -20, -120, -116, -98, + -102, -129, -204, -271, -282, -252, -216, -215, -221, + -156, -70, -66, -120, -156, -146, -126, -84, -15, + -21, -76, -8, 131, 146, 86, 42, 12, 44, + 110, 169, 171, 91, 68, 173, 262, 248, 160, + 36, -90, -109, -24, -12, -57, -64, -78, -89, + -75, -87, -101, -82, -72, -76, -81, -63, -34, + -4, 61, 87, 46, 23, -1, -8, 40, 63, + 46, 45, 39, 14, -11, -25, -16, 36, 78, + 85, 110, 120, 132, 189, 228, 217, 154, 89, + 57, 14, -14, -6, 0, 13, 8, -50, -68, + -60, -107, -140, -126, -122, -151, -147, -118, -105, + -85, -83, -100, -139, -195, -194, -168, -183, -173, + -148, -166, -168, -123, -59, -11, 20, 64, 98, + 80, 58, 83, 111, 143, 176, 171, 152, 146, + 165, 174, 143, 93, 30, 5, 21, 42, 35, + -37, -94, -61, -12, -5, -27, -58, -85, -81, + -11, 79, 65, -14, -17, 15, -4, -2, 39, + 20, -29, -19, 3, -11, -39, -62, -43, -34, + -60, -77, -119, -163, -128, -5, 87, 73, 51, + 116, 189, 217, 240, 234, 177, 192, 295, 344, + 313, 263, 236, 240, 230, 179, 99, 19, -25, + -16, -9, -35, -66, -53, -16, -40, -70, -81, + -102, 
-86, -87, -156, -225, -228, -145, -52, -22, + -57, -171, -255, -247, -208, -165, -187, -242, -275, + -261, -168, -75, -13, 8, -62, -125, -136, -133, + -81, -11, -17, -80, -115, -103, -27, 71, 134, + 137, 44, -48, -24, 69, 156, 194, 175, 112, + 55, 54, 101, 148, 157, 142, 100, 44, 27, + 63, 106, 107, 89, 67, 37, 17, 30, 63, + 69, 61, 21, -37, -55, -72, -53, -26, -53, + -77, -87, -109, -119, -80, -36, -29, -38, -48, + -57, -65, -16, 52, 83, 83, 24, -27, -14, + 9, 27, 52, 50, 45, 90, 132, 117, 75, + 16, -1, 60, 95, 55, 25, 26, 20, 61, + 119, 89, 1, -61, -68, -46, -36, -40, -39, + -49, -58, -16, 30, 13, -12, 18, 35, 6, + 3, 30, 22, 25, 52, 32, 12, 9, -5, + -16, -25, -33, -38, -44, -76, -118, -118, -96, + -54, -3, 9, -31, -82, -84, -35, 18, 25, + -26, -72, -48, 8, 25, 8, -20, -66, -105, + -102, -80, -73, -79, -80, -70, -59, -55, -82, + -113, -85, -51, -59, -57, -38, -13, -7, -18, + -6, 20, 51, 55, 18, -8, -7, 24, 78, + 119, 137, 135, 139, 153, 144, 155, 179, 166, + 128, 56, 8, 38, 85, 94, 72, 20, -32, + -9, 25, 17, -15, -84, -123, -106, -82, -62, + -60, -43, -4, -12, -45, -68, -108, -100, -47, + -49, -64, -50, -9, 37, 59, 68, 62, 53, + 49, 25, 13, 32, 40, 60, 109, 82, 18, + 10, -1, 21, 102, 111, 40, -10, -9, 20, + 31, 0, -51, -108, -135, -89, -21, 1, -54, + -125, -129, -113, -144, -205, -227, -167, -118, -114, + -100, -71, 5, 34, -51, -119, -120, -72, 10, + 56, 51, 58, 65, 98, 135, 84, 20, -3, + -1, 57, 135, 137, 90, 88, 107, 102, 45, + -4, 9, 48, 95, 99, 65, 42, 44, 78, + 80, 29, 11, 39, 27, 0, 7, 19, 10, + -45, -99, -86, -77, -74, -57, -74, -84, -92, + -134, -114, -65, -73, -76, -96, -105, -50, -31, + -17, 17, 9, 18, 62, 75, 55, 63, 76, + 61, 61, 80, 103, 107, 110, 131, 134, 120, + 94, 66, 70, 78, 59, 52, 57, 53, 72, + 76, 31, -18, -53, -57, -35, -17, -9, -27, + -34, -7, -17, -26, -13, -60, -86, -53, -42, + -36, -36, -46, -13, 19, -16, -47, -15, 11, + -9, -18, -26, -24, 14, 8, -53, -54, 15, + 43, 15, -9, -5, 5, -12, -40, -57, -74, + -94, -105, -91, -20, 30, -10, -50, 
-58, -52, + -42, -47, -54, -61, -83, -64, -30, -3, 31, + 9, -35, -43, -31, 6, 50, 54, 55, 67, + 53, 43, 30, 27, 62, 37, -26, -52, -54, + -29, 3, -12, -23, 11, 26, 23, 31, 57, + 66, 46, 32, 35, 83, 124, 111, 124, 157, + 143, 101, 80, 60, 27, 11, 21, 22, 9, + -4, -26, -41, -35, -50, -103, -138, -116, -90, + -89, -90, -79, -74, -58, -18, -12, -29, -36, + -17, 22, 30, -1, -8, 8, 10, 19, 31, + 36, 38, 41, 28, -7, -14, -6, -20, -30, + -11, -2, -9, 0, 25, 56, 78, 68, 40, + 34, 47, 50, 40, 37, 26, 28, 53, 61, + 57, 25, -35, -75, -65, -48, -65, -81, -67, + -53, -41, 3, 19, -3, -9, -2, -1, -24, + -36, -23, -26, -29, -9, 0, -15, -17, -9, + 12, 50, 45, 14, 19, 37, 24, 9, 16, + 13, -16, -19, 3, -3, -12, -10, -23, -43, + -47, -38, -46, -44, -7, 3, -19, -13, -26, + -52, -29, -19, -32, 0, 11, -26, -24, -20, + -41, -30, -24, -53, -67, -26, 23, 20, 9, + 6, -8, 3, 16, 7, 3, -5, 2, 33, + 53, 72, 94, 86, 69, 96, 118, 95, 91, + 78, 32, 26, 48, 48, 37, 21, 7, -6, + -8, 8, 1, -17, -2, 18, 1, -28, -51, + -84, -93, -74, -46, -18, -19, -31, -10, 10, + 10, 7, -5, -30, -39, -28, -9, 10, 17, + 11, 14, 20, -1, 2, 18, 7, 15, 40, + 40, 32, 27, 23, 31, 43, 33, 7, -3, + 18, 51, 53, 31, 21, 14, 16, 14, 4, + 11, 16, 1, -24, -38, -33, -27, -50, -74, + -70, -60, -54, -44, -22, -22, -43, -33, -16, + -35, -36, -18, -27, -42, -46, -36, -17, -15, + -22, -21, -20, -2, 15, 12, 22, 27, 22, + 41, 57, 60, 63, 54, 56, 65, 62, 68, + 58, 34, 53, 70, 58, 60, 51, 33, 41, + 39, 16, -3, -16, -18, -15, -18, -32, -76, + -85, -62, -82, -87, -68, -84, -75, -40, -48, + -55, -45, -42, -24, -14, -1, 27, 23, -1, + -2, 12, 15, 32, 55, 52, 55, 82, 81, + 58, 62, 59, 37, 24, 20, 17, 18, 19, + 15, 14, 5, -18, -27, -20, -19, -34, -39, + -29, -30, -27, -27, -48, -52, -54, -77, -48, + -18, -36, -34, -13, -21, -38, -28, -15, -7, + -6, -20, -18, 2, 4, -11, -5, 7, 1, + 1, 12, -2, -17, 7, 15, 2, 15, 34, + 48, 78, 94, 82, 66, 66, 64, 47, 44, + 57, 64, 74, 65, 34, 26, 31, 32, 33, + 18, 5, -1, -18, -22, -31, -54, -37, -32, + -74, -89, -77, 
-73, -65, -72, -75, -39, -21, + -31, -31, -24, -19, -8, -4, 7, 26, 22, + 15, 13, 11, 28, 47, 42, 35, 28, 5, + 18, 55, 55, 45, 44, 18, 9, 18, -2, + -5, 6, -15, -16, -12, -20, -4, 4, -15, + -18, -10, -5, -2, -16, -24, -14, -7, -14, + -33, -33, -20, -17, -17, -18, -30, -37, -35, + -34, -13, -3, -28, -28, -10, -21, -17, -4, + -12, -16, -20, -27, -16, -8, -4, 14, 24, + 11, 17, 30, 27, 14, 7, 28, 30, 22, + 45, 47, 23, 31, 23, -5, 10, 17, -5, + 2, 15, 9, 20, 29, 11, -9, -8, 8, + 10, -1, -14, -30, -30, -8, -9, -20, -17, + -17, -12, 1, 6, -7, -18, -6, 10, -6, + -7, 29, 35, 21, 16, 9, 25, 44, 26, + 21, 34, 28, 40, 41, 9, -2, 1, 12, + 34, 18, -12, -10, -16, -29, -24, -25, -20, + -17, -35, -29, -12, -29, -39, -32, -30, -17, + -12, -28, -20, -5, -4, 7, 14, 10, 3, + -3, 0, 19, 27, 4, -21, -18, -7, -4, + 0, 1, -6, -17, -30, -24, -11, -9, 0, + -1, 0, -3, -12, 1, 15, -2, 3, 16, + -3, -8, 7, 3, 13, 32, 23, 10, -6, + -11, 8, 4, -12, -9, 3, 12, -2, -31, + -36, -33, -37, -17, -5, -20, -14, 4, 5, + 4, 6, 17, 31, 27, 23, 16, -1, -4, + 15, 24, 21, 18, 7, -7, -14, 18, 41, + 25, 14, 13, 2, 5, 12, 8, 15, 10, + 2, 13, 10, 3, 5, -1, 0, 11, 10, + 6, 2, 7, 10, -4, -3, 2, -13, -4, + 14, -4, -17, -11, -4, 8, 3, -8, -1, + -7, -20, -4, 23, 23, 8, 5, 24, 21, + -5, -2, 7, -9, -15, -8, -6, 6, 2, + -26, -19, 1, -19, -31, -27, -34, -41, -47, + -39, -12, -12, -29, -32, -41, -36, -26, -36, + -35, -33, -29, -1, 5, -13, -21, -21, -3, + 12, 1, -7, -1, 2, 12, 9, -1, 15, + 21, 18, 25, 4, -13, 5, 12, 16, 33, + 33, 19, 21, 26, 30, 30, 24, 23, 19, + 22, 34, 39, 28, 15, 14, 24, 24, 18, + 12, 10, 4, 8, 28, 29, 2, -7, 6, + 8, 10, 2, -13, -8, -2, 0, 12, 13, + -1, 3, 21, 26, 24, 17, 11, 15, 19, + 19, 19, 11, 1, 3, 3, 0, -5, -11, + -16, -26, -18, 3, -5, -17, 2, 10, 6, + 6, -8, -11, 4, -3, -17, -10, -17, -37, + -31, -17, -26, -37, -42, -53, -49, -34, -40, + -39, -21, -17, -23, -23, -25, -30, -24, -13, + -10, -10, 1, 1, -7, 7, 19, 11, 4, + -3, -8, 1, 6, 7, 25, 22, -5, 3, + 20, 7, -1, 14, 17, 18, 20, 12, 25, + 41, 23, 19, 
37, 39, 21, 17, 23, 17, + 6, 9, 15, 4, -15, -8, 8, 7, 1, + -12, -18, -14, -15, -10, 0, -3, 3, 13, + -8, -21, -8, -26, -29, -1, -9, -24, -19, + -22, -24, -18, -25, -27, -28, -34, -26, -9, + -14, -14, -8, -8, -5, 4, 4, -10, -12, + -7, -8, -10, -15, -19, -10, -5, -9, -9, + -19, -33, -27, -14, -15, -14, -16, -25, -10, + 5, -7, -11, 2, 3, 7, 17, 28, 33, + 32, 33, 39, 49, 57, 63, 62, 64, 67, + 59, 55, 67, 71, 58, 53, 53, 44, 38, + 44, 51, 51, 45, 35, 34, 46, 55, 48, + 36, 21, 3, -5, 2, 7, 0, -17, -30, + -34, -48, -62, -64, -66, -66, -62, -79, -90, + -85, -88, -88, -85, -88, -103, -112, -112, -102, + -99, -102, -103, -110, -100, -80, -60, -57, -68, + -59, -45, -35, -6, 9, -3, 2, 32, 45, + 48, 51, 40, 51, 78, 85, 83, 87, 94, + 101, 104, 105, 100, 86, 82, 96, 102, 96, + 85, 68, 63, 65, 55, 50, 46, 28, 32, + 43, 33, 30, 27, 8, 18, 36, 27, 20, + 13, -14, -19, 8, 12, 0, -1, -12, -24, + -20, -27, -39, -39, -39, -44, -38, -32, -42, + -38, -33, -43, -55, -57, -60, -61, -56, -57, + -55, -43, -46, -58, -55, -50, -50, -51, -48, + -46, -44, -36, -26, -20, -13, -11, -8, 1, + 5, 0, 8, 21, 31, 42, 39, 43, 56, + 48, 37, 45, 45, 47, 52, 46, 40, 26, + 18, 28, 30, 22, 14, 0, -3, 8, 0, + -7, 0, -10, -13, -9, -13, -13, -18, -33, + -32, -26, -37, -41, -32, -26, -30, -34, -31, + -38, -40, -24, -25, -29, -15, -18, -23, -4, + 2, -7, 0, 5, 10, 22, 23, 25, 31, + 33, 37, 38, 39, 43, 46, 41, 44, 46, + 37, 35, 46, 63, 67, 52, 38, 30, 35, + 41, 41, 41, 29, 15, 16, 4, -4, 3, + -12, -18, -13, -27, -39, -47, -55, -44, -43, + -53, -45, -36, -37, -37, -38, -40, -49, -57, + -41, -24, -28, -31, -26, -20, -15, -21, -23, + -18, -19, -14, -10, -11, 1, -6, -26, -14, + -1, -7, -10, -11, -9, 0, -4, -9, 3, + 8, 0, -2, 1, 16, 20, 7, 9, 10, + 8, 18, 12, 11, 17, -6, -19, 0, 0, + -10, -6, -12, -14, -11, -9, -2, -10, -19, + -9, -11, -4, 18, 7, -3, 9, 17, 23, + 28, 25, 19, 19, 24, 33, 37, 30, 28, + 35, 44, 43, 33, 31, 30, 26, 33, 39, + 35, 31, 27, 19, 23, 24, 19, 13, 0, + 0, 2, -7, -9, -10, -13, -6, -6, -23, + -28, -15, 
-9, -20, -34, -30, -15, -12, -11, + -3, -4, -4, 6, 15, 9, -11, -20, 3, + 26, 23, 1, -16, -3, 12, 2, -22, -36, + -35, -28, -20, -13, -19, -38, -43, -29, -11, + -5, -15, -37, -40, -9, 12, -1, -23, -30, + -16, 12, 21, -1, -25, -21, 4, 34, 55, + 34, -12, -11, 47, 99, 107, 58, 0, 8, + 78, 148, 151, 56, -40, -2, 142, 215, 99, + -67, -64, 76, 153, 99, -21, -107, -92, -1, + 106, 107, -123, -395, -334, 60, 274, -69, -597, + -626, -126, 238, 18, -447, -577, -312, -34, 20, + -89, -242, -332, -222, 74, 262, 64, -285, -232, + 259, 563, 294, -138, -130, 312, 642, 515, 189, + 57, 187, 415, 538, 467, 277, 109, 134, 334, + 441, 299, 59, -7, 128, 228, 146, -20, -99, + -34, 60, 24, -108, -188, -147, -57, -48, -142, + -224, -210, -144, -122, -175, -212, -176, -150, -199, + -256, -210, -100, -79, -195, -298, -248, -107, -48, + -110, -192, -224, -189, -112, -40, -31, -124, -238, + -193, -3, 87, -53, -221, -165, 48, 132, -2, + -150, -109, 61, 147, 83, -20, -60, -13, 85, + 157, 130, 17, -68, -10, 147, 217, 116, -20, + -21, 103, 200, 158, 52, 35, 105, 155, 132, + 81, 74, 110, 114, 74, 48, 68, 100, 77, + 27, 30, 48, 19, -15, 7, 63, 53, -56, + -123, -41, 81, 75, -61, -154, -84, 45, 68, + -24, -105, -76, 22, 53, -13, -63, -21, 54, + 59, -1, -34, 16, 80, 81, 48, 37, 61, + 89, 88, 101, 134, 132, 100, 83, 125, 188, + 173, 101, 95, 172, 214, 149, 68, 94, 181, + 177, 103, 83, 132, 165, 122, 83, 140, 191, + 153, 92, 106, 198, 226, 138, 85, 146, 215, + 187, 110, 77, 115, 146, 115, 91, 96, 78, + 27, -3, 42, 102, 71, -23, -46, 30, 95, + 63, -18, -25, 77, 174, 138, 13, -25, 96, + 218, 181, 34, -70, -45, 17, 2, -67, -174, + -346, -516, -553, -446, -455, -789, -1213, -1308, -1046, + -878, -1179, -1691, -1839, -1528, -1219, -1292, -1623, -1772, + -1538, -1147, -921, -951, -1038, -929, -549, -95, 155, + 127, 97, 387, 931, 1339, 1380, 1234, 1276, 1661, + 2102, 2223, 2027, 1848, 1942, 2198, 2295, 2119, 1856, + 1725, 1745, 1752, 1601, 1335, 1102, 993, 952, 830, + 570, 286, 139, 133, 85, -135, -436, -638, 
-645, + -571, -620, -835, -1064, -1151, -1069, -951, -964, -1109, + -1209, -1162, -1044, -961, -944, -977, -1001, -912, -687, + -517, -623, -887, -897, -469, 10, -35, -590, -934, + -545, 184, 427, -53, -619, -563, 40, 489, 339, + -128, -306, -6, 403, 497, 232, -55, 0, 388, + 704, 584, 145, -76, 260, 816, 942, 485, 2, + 65, 575, 923, 744, 290, 76, 276, 596, 662, + 419, 134, 92, 280, 434, 344, 88, -66, 8, + 151, 126, -81, -239, -176, -29, -74, -351, -574, + -487, -208, -132, -426, -780, -797, -577, -595, -978, + -1169, -667, -36, -548, -2285, -3281, -1756, 927, 1236, + -1911, -5006, -4073, -66, 2017, -295, -3701, -3797, -892, + 975, -165, -1978, -1636, 374, 1482, 679, -567, -591, + 706, 2337, 3224, 2743, 1269, 287, 1221, 3597, 5083, + 4106, 1858, 972, 2334, 4096, 4167, 2806, 1916, 2383, + 3045, 2508, 1220, 820, 1784, 2669, 1981, 204, -876, + -470, 510, 803, 170, -787, -1568, -1893, -1598, -1027, + -992, -1803, -2610, -2484, -1905, -2113, -3113, -3399, -2267, + -1261, -2007, -3637, -3909, -2340, -893, -1158, -2272, -2486, + -1639, -915, -777, -596, -91, 196, 85, 210, 875, + 1373, 1247, 1219, 1958, 2718, 2328, 1196, 1008, 2350, + 3677, 3269, 1503, 366, 922, 2264, 2810, 1996, 608, + -168, 75, 680, 811, 395, -56, -318, -607, -966, + -1108, -925, -613, -368, -369, -919, -1926, -2460, -1685, + -300, 155, -611, -1524, -2204, -3227, -3859, -2037, 1622, + 2382, -2583, -8448, -7544, -84, 4814, 915, -6423, -7558, + -1746, 2515, -59, -4587, -3858, 1260, 3625, 187, -4148, + -3500, 1542, 5467, 4780, 1256, -1127, -403, 2481, 5332, + 6346, 5014, 2536, 1216, 2467, 5039, 6238, 5070, 3381, + 3269, 4173, 3905, 2248, 1586, 3299, 5240, 4362, 1004, + -1382, -489, 2113, 3168, 1620, -742, -1824, -1435, -897, + -1058, -1500, -1545, -1398, -1965, -3266, -4136, -3756, -2609, + -1804, -1986, -3087, -4599, -5296, -4051, -1731, -781, -2228, + -4092, -3977, -2325, -1353, -1568, -1490, -428, 178, -672, + -1650, -1058, 749, 2039, 2079, 1540, 897, 310, 572, + 2266, 4265, 4265, 1869, -231, 559, 3332, 
4752, 3229, + 768, 101, 1364, 2463, 1984, 819, 411, 723, 675, + -162, -923, -743, -32, 185, -516, -1653, -2359, -2103, + -986, 42, -205, -1702, -2870, -2337, -809, -221, -982, + -1544, -946, -598, -2117, -4291, -4100, -857, 1948, 338, + -4799, -7972, -5403, 173, 2371, -1063, -5533, -5578, -1777, + 605, -985, -3249, -2213, 1184, 2691, 560, -2356, -2288, + 1233, 5244, 6441, 4004, 370, -663, 2555, 7404, 9282, + 6573, 2612, 1836, 4662, 7467, 7393, 5421, 4262, 4741, + 5362, 4705, 3163, 2397, 3337, 4887, 4810, 2254, -749, + -1316, 772, 2706, 2016, -573, -2552, -2746, -2012, -1647, + -1978, -2579, -3105, -3473, -3911, -4484, -4891, -4795, -4163, + -3543, -3538, -4275, -5356, -5743, -4637, -2614, -1301, -1825, + -3341, -4011, -2937, -751, 1007, 1245, 235, -639, -61, + 1626, 2864, 2967, 2734, 3013, 3329, 2914, 2312, 2666, + 3839, 4308, 3162, 1453, 768, 1255, 1887, 2006, 1715, + 1031, -297, -1660, -1690, -277, 813, -30, -2137, -3370, + -2854, -1553, -593, -413, -1146, -2567, -3440, -2369, -205, + 379, -1258, -2315, -812, 262, -3205, -8576, -7894, 738, + 7492, 1951, -11595, -17098, -6934, 7139, 8065, -4575, -14199, + -8946, 3606, 7504, -547, -8242, -5113, 4406, 8113, 2134, + -5040, -4089, 4157, 10934, 10158, 4167, -565, -192, 4428, + 9765, 12201, 9861, 4512, 1225, 3451, 8483, 10133, 6497, + 2574, 3333, 6806, 6986, 2487, -1214, 623, 5416, 6647, + 2204, -3289, -4556, -1565, 1544, 1525, -1236, -4293, -5695, + -5174, -3995, -3403, -3449, -3750, -4505, -6014, -7296, -6523, + -3849, -2096, -3288, -5722, -6004, -3581, -1497, -1960, -3330, + -2800, -434, 964, -111, -1739, -1136, 1736, 4151, 3736, + 1274, -451, 469, 3386, 5833, 5898, 3646, 1085, 272, + 1743, 4061, 5108, 3837, 1490, 246, 967, 1866, 859, + -1069, -974, 1542, 2835, 47, -4285, -5068, -1567, 1781, + 1223, -1997, -4227, -3747, -1720, 41, 245, -1228, -2972, + -2673, 22, 1980, -930, -7721, -11271, -5725, 4974, 8484, + -2007, -16979, -19255, -4670, 11057, 9690, -6417, -17537, -10841, + 4262, 9292, 206, -9128, -6224, 4828, 
10018, 3699, -5183, + -5121, 4702, 14279, 14466, 5778, -2633, -2185, 7036, 16118, + 16305, 8081, 390, 499, 6580, 11150, 10036, 5704, 2902, + 3378, 4664, 3786, 863, -796, 1216, 4609, 4493, -338, + -5670, -6486, -2751, 884, 571, -3095, -6446, -6997, -5770, + -5041, -5016, -4216, -2579, -2468, -5088, -8129, -7964, -4228, + -323, 497, -1556, -3653, -3615, -1718, 464, 1808, 2386, + 2832, 3085, 2905, 2676, 3473, 5501, 7094, 6442, 3929, + 1663, 1436, 3254, 5807, 7100, 5044, -34, -4091, -2992, + 2149, 5333, 2562, -3067, -5877, -4480, -2080, -1793, -3026, + -3838, -3735, -3663, -4472, -5756, -5753, -3576, -640, -274, + -3965, -7787, -6757, -717, 4380, 3595, -1553, -5936, -8603, + -10223, -8952, -922, 9700, 9355, -7788, -25795, -22413, 2268, + 20887, 12133, -11291, -20129, -5899, 10236, 8585, -3645, -6300, + 4667, 14216, 9346, -3593, -8558, 715, 15085, 21179, 14887, + 3733, -2703, -675, 7170, 15131, 18360, 13959, 4205, -2825, + -656, 7594, 11845, 7182, 319, -439, 3255, 3213, -3299, + -8972, -6318, 2300, 7190, 2254, -9247, -17334, -15064, -4452, + 5160, 5127, -4268, -14501, -17256, -11145, -1830, 3786, 2984, + -2498, -8101, -9587, -5703, 622, 4570, 4035, 1442, 729, + 2493, 3534, 2433, 2239, 5944, 11438, 12371, 6496, -211, + -156, 7092, 13566, 11979, 3928, -2545, -2226, 2713, 6150, + 5117, 1270, -1851, -2859, -2376, -1909, -2364, -3401, -4183, + -3897, -2875, -3205, -5503, -7822, -7501, -3934, -942, -1572, + -4262, -5939, -4671, -2353, -1387, -1159, -1270, -1328, -606, + 474, 1044, -2647, -11603, -17081, -10374, 5922, 14849, 2056, + -22033, -31238, -14612, 11094, 17910, 1778, -15538, -15417, -2045, + 6690, 2855, -2559, 473, 8823, 11423, 3782, -4649, -2775, + 9111, 20847, 21610, 11572, 962, -1465, 5731, 15559, 20008, + 16950, 9230, 2204, 114, 3088, 8130, 10523, 7643, 2045, + -2107, -2945, -2538, -3593, -5210, -4403, -857, 1328, -2497, + -11667, -18881, -16866, -6286, 3400, 2835, -7811, -18322, -19279, + -10025, 1525, 6930, 3766, -4647, -11401, -9904, -322, 10100, + 12428, 5874, 
-274, 926, 6762, 9360, 6778, 5904, 10509, + 15077, 12681, 3846, -1653, 2460, 11036, 14737, 8967, -1021, + -6168, -3899, 2328, 6041, 3404, -2878, -7672, -6869, -1918, + 801, -2188, -7419, -8083, -2687, 1898, -692, -8121, -11198, + -5642, 2830, 5915, 1120, -5666, -8314, -5770, 118, 4614, + 4713, 1482, -2544, -3331, -3779, -8931, -13840, -10273, 3355, + 13432, 2906, -20058, -30890, -17080, 7759, 16047, 2886, -12525, + -15117, -5998, 1614, 2294, 2684, 4610, 6236, 5486, 2514, + 1346, 1962, 4564, 11022, 17438, 18182, 10179, -796, -3019, + 5456, 15942, 18468, 11176, 2796, -143, 1670, 3922, 3836, + 3337, 3330, 1623, -2609, -7177, -7654, -4250, -2210, -3491, + -5312, -4380, -3103, -6738, -13209, -14278, -6529, 3346, 4931, + -2861, -11176, -12097, -5552, 2679, 7102, 6050, 1301, -3350, + -3378, 1785, 7413, 9059, 7013, 5043, 5331, 5197, 3143, + 1862, 3790, 8037, 10159, 7236, 1450, -3393, -3980, 598, + 6251, 7410, 1502, -7144, -10260, -5116, 2386, 4197, -894, + -6255, -6026, -1493, 873, -1639, -4426, -2720, 2252, 4206, + 158, -4631, -4466, 537, 4709, 4528, 1691, -828, -1394, + -455, 756, 2662, 3101, 1730, -3579, -12987, -18531, -12998, + 1944, 11963, 1503, -19826, -29919, -18138, 2254, 7644, -1829, + -9260, -6516, 134, -793, -5234, -2336, 6264, 12828, 11829, + 6589, 3429, 2592, 4795, 11433, 19490, 21681, 13136, 379, + -4138, 3585, 14812, 17633, 10124, 623, -2287, 696, 2273, + -926, -5000, -4391, -386, 139, -4657, -11003, -13946, -11930, + -7460, -1932, 1277, -2311, -10543, -16920, -14512, -4039, 4987, + 7518, 3175, -4213, -7535, -4747, 3590, 12231, 13419, 8429, + 2377, 1080, 5563, 8497, 7304, 5331, 5656, 8235, 6997, + 998, -3131, -1857, 3017, 5883, 3744, -408, -4503, -6489, + -4796, -374, 3254, 1651, -2830, -5206, -3690, -681, -969, + -2819, -2616, 19, 3379, 2359, -2476, -6413, -6111, -463, + 4664, 4106, -565, -4801, -4960, -1242, 2479, 3706, 2168, + -1104, -3048, -1563, 1217, 2013, -5714, -17921, -21743, -10839, + 7751, 13091, -4648, -26509, -29653, -9872, 10100, 9523, -4335, 
+ -12121, -5509, 4923, 6380, 1839, -508, 3312, 10704, 14545, + 12317, 5508, -243, 2421, 11485, 19096, 18306, 8626, -1357, + -5542, -1695, 7815, 13549, 10229, -23, -8373, -7496, -2775, + -1016, -2900, -4868, -4103, -4535, -6851, -8099, -8137, -6414, + -4023, -1790, -45, -1513, -4791, -6160, -4105, 1060, 5970, + 7099, 3934, -996, -2213, 1973, 6975, 7927, 4726, 2474, + 3951, 5221, 2642, -2359, -3579, 1362, 6614, 6282, 116, + -5643, -5733, -1884, 2107, 3418, 2566, 684, -2319, -3803, + -2133, 1512, 2943, 475, -1004, 753, 3095, 1652, -3074, + -4562, -932, 3815, 4486, -22, -4199, -4666, -2201, 284, + 316, -914, -2297, -2441, -1538, -435, 909, 626, -1222, + -1534, -429, 1711, 2386, -1786, -10676, -18200, -16272, -3805, + 9505, 8238, -9397, -24577, -22256, -4907, 8659, 5940, -3701, + -6764, 40, 6190, 4239, 208, 238, 7081, 14458, 15143, + 10726, 3479, -706, 1700, 9131, 17577, 17708, 7959, -5009, + -11508, -5347, 5635, 10789, 6499, -3121, -9303, -9814, -6625, + -3333, -3193, -4349, -5615, -6188, -5123, -4441, -4550, -4074, + -2769, -61, 2441, 2881, 1395, -578, -341, 2509, 6034, + 8202, 6377, 2696, 1272, 2589, 4787, 4611, 2378, 2124, + 3911, 4872, 2049, -3374, -5770, -2705, 3179, 5905, 2589, + -2792, -5419, -3176, 1056, 2875, 2483, 1205, 605, 856, + 1012, 892, 105, -411, 707, 2924, 4184, 1755, -2553, + -4857, -3556, 401, 2466, 945, -2315, -5556, -5549, -2241, + 534, 601, -1774, -3034, -1962, -886, -448, -720, -467, + 864, 760, -22, -2546, -10211, -17121, -15877, -4803, 7993, + 7254, -6563, -18374, -17755, -6143, 3291, 4322, 1822, 416, + 2788, 5190, 4256, 2627, 2590, 6398, 12709, 15757, 12829, + 5542, -667, 167, 7241, 14346, 14826, 6392, -3516, -7434, + -4607, 1054, 2988, 847, -1549, -2641, -3046, -5363, -8256, + -9130, -6906, -1460, 2260, 1568, -2911, -8580, -9418, -3675, + 5021, 10127, 7909, 1478, -4015, -3331, 2450, 7291, 7632, + 2567, -2022, -899, 3418, 5544, 1349, -4117, -3409, 1758, + 6000, 3526, -3975, -7331, -3931, 2747, 7037, 4962, -21, + -2902, -2008, 1306, 4461, 
6364, 5956, 3623, 1734, 793, + 44, -893, -1041, 1633, 5264, 4870, -943, -7404, -8611, + -4974, -1192, 185, -1334, -3672, -4910, -5132, -4387, -3532, + -3233, -2430, -469, 1245, 892, -969, -2441, -2140, 320, + 4999, 5954, -4638, -20056, -24424, -8954, 13558, 16089, -3145, + -20665, -19447, -4802, 4488, 3733, 943, 683, 3109, 6219, + 9247, 7736, 782, -1410, 8024, 20877, 20174, 4723, -7148, + -2758, 11240, 17896, 11462, 414, -6134, -4913, 113, 2818, + 98, -5900, -8369, -4446, 924, 1657, -3389, -10569, -13223, + -7690, 2339, 7741, 1634, -9014, -10982, -1172, 9642, 9098, + 1310, -2795, -1040, 2790, 3808, 3559, 3064, -527, -3160, + -1391, 3120, 5224, -144, -6714, -6416, -719, 5630, 7253, + 2735, -2973, -4325, 679, 7146, 8220, 4055, -42, 814, + 5288, 7658, 6592, 3051, -746, -541, 3401, 6030, 1953, + -6340, -8619, -2689, 4076, 3217, -4875, -9612, -7826, -4293, + -2441, -4080, -5740, -5529, -3656, -506, -1035, -5787, -9518, + -7034, 2323, 9287, 6495, -1853, -6110, -3281, -1708, -8958, + -19544, -18870, -2771, 13029, 10762, -7491, -21837, -18923, -4183, + 8733, 12580, 9779, 4597, 738, 1460, 6302, 9711, 8375, + 8143, 12512, 15808, 11272, 389, -5554, 161, 11080, 15851, + 10426, 692, -6372, -6808, -2525, 652, 827, -219, -349, + -622, -3328, -7883, -11020, -8961, -3240, 1884, 4155, 1995, + -3530, -7816, -6444, -218, 6086, 9279, 7901, 3113, -2352, + -5757, -3836, 2022, 4572, 894, -3519, -3311, -534, -618, + -3716, -5515, -3290, 1495, 4374, 4455, 2961, -645, -3247, + -656, 5273, 9838, 9751, 5755, 1863, 158, 1457, 4585, + 6390, 5379, 2894, 2284, 1867, -2279, -7051, -6578, 70, + 4745, 1660, -4524, -8007, -7088, -5690, -5467, -4178, -2679, + -2218, -3422, -4167, -4313, -6105, -6633, -4202, 864, 5119, + 4084, -163, -5331, -8699, -8710, -7313, -4649, -2471, -1419, + -1136, -3199, -6428, -8048, -4902, 1089, 4681, 5723, 5535, + 5146, 4006, 2052, 2314, 5274, 8680, 9907, 8776, 6722, + 2548, -2403, -3303, 1224, 7406, 9468, 5089, -1197, -4384, + -3570, -298, 1776, 2005, 2041, 1326, 971, -180, 
-2334, + -1170, 1913, 4281, 4732, 2874, 1174, -1341, -3384, -2503, + 368, 4031, 3270, -986, -3519, -5360, -6004, -5576, -3603, + 208, 708, -2137, -4940, -5349, -3588, -2796, -1399, 1017, + 3144, 4196, 2483, 828, 338, 919, 3842, 6202, 7189, + 7499, 6330, 4847, 3252, 2136, 3698, 5845, 5566, 3019, + 267, -55, -1091, -4220, -5041, -3430, -280, 171, -4649, + -8723, -9280, -5975, -3192, -3974, -3912, -4053, -3748, -3570, + -5871, -5499, -3552, -1691, 320, 341, 748, -313, -3436, + -4687, -3681, 21, 2550, 643, -2123, -3254, -2226, -1044, + -1617, -1510, 183, 1250, 726, -1662, -3388, -1759, 933, + 3817, 5242, 3025, 248, -1339, -514, 2022, 3410, 3970, + 3324, 2632, 2603, 2240, 2166, 1271, 487, 1076, 2039, + 3296, 3836, 3610, 2913, 2718, 4213, 5555, 6023, 4769, + 2442, 2067, 2173, 1623, 1201, 348, 52, -124, -1528, + -2834, -3604, -3463, -2357, -2564, -3775, -3801, -1929, -465, + -2109, -3743, -2657, 200, 2580, 954, -1304, -95, 1549, + 2303, 1795, 1633, 3356, 3699, 2361, 792, 1148, 4045, + 4820, 3851, 3197, 2449, 2704, 1722, -652, -1154, -393, + 113, -1010, -3328, -4342, -3939, -3345, -3697, -5115, -5610, + -4202, -3639, -5088, -5351, -3216, -862, -414, -1839, -3996, + -4831, -2467, 147, 1055, 1288, -247, -2225, -2233, -1562, + -1278, -936, -961, -935, -367, -323, -459, -1940, -3974, + -2262, -13, 2, -401, -1825, -2308, -1124, 448, 2154, + 2434, 1300, -812, -1337, 1325, 3374, 3466, 2500, 2156, + 3439, 3549, 2068, 1392, 1986, 3025, 3944, 3898, 3259, + 4467, 6347, 5356, 2893, 1690, 2072, 4136, 5313, 2776, + -236, -1063, -794, 524, 802, -1377, -2879, -2167, -1439, + -1595, -1539, -1666, -2495, -2375, -1253, -515, -187, -1409, + -2847, -511, 2411, 1761, 492, -18, 607, 2350, 3288, + 3505, 2741, 1099, 699, 2017, 3214, 3333, 1567, 33, + 1260, 1925, 808, -377, -2558, -3781, -1677, 164, -580, + -1727, -2619, -3421, -3586, -3957, -4562, -3646, -2285, -3437, + -5293, -4792, -4128, -4012, -2920, -2249, -2439, -3737, -5607, + -4427, -1259, 71, 609, 555, -1039, -3354, -5388, -3760, + 415, 
2513, 2513, 819, -1436, -2780, -2740, -501, 2727, + 3936, 1491, -965, -766, -484, -223, 361, 695, 1771, + 1130, -1839, -1764, 797, -31, -2549, -1790, 2108, 4043, + 887, -154, 2411, 2605, 2012, 1977, 3923, 6630, 4176, + 107, -311, 1731, 1910, 1011, 3119, 3219, 998, -1282, + -2832, -1645, -685, 945, 2574, 2543, -267, -5015, -3819, + -342, 1228, 2055, -619, -1233, 2069, 2896, 1095, 62, + 1365, 3366, 4584, 4956, 3323, -19, -50, 4024, 5222, + 3695, 3118, 1933, 1256, 1443, 128, -119, 2043, 2477, + 1823, 1324, 30, -1363, -3023, -3074, -188, 621, -1775, + -2806, -2961, -2753, -4359, -5350, -1220, -116, -4157, -4811, + -2793, -1040, -1957, -2862, -1901, -3192, -3720, -2357, -1727, + -387, -2131, -5011, -3650, -454, 596, -1298, -3716, -3122, + 496, 136, -2415, -1675, -811, -837, 140, -1243, -187, + -1431, -5320, -2121, 100, -467, 2465, 681, -2093, 1224, + 1632, 1428, 1776, 648, 2480, 3622, 876, 259, 1403, + 2139, 3117, 497, -763, -170, 279, 1769, 342, -871, + -25, -1549, -2290, 290, 1042, -796, -4291, -3895, 159, + 1264, -540, -2328, -702, 1972, 852, -2274, -798, 1126, + -579, -480, 3481, 3833, 1004, 901, 1536, 1809, 3103, + 2521, 3183, 5220, 1800, -266, 4663, 4230, -790, 159, + 2274, 5114, 4304, -1998, 344, 4921, -343, -2048, 1180, + 2112, 3109, -10, -1818, 552, -1360, -2889, -1302, -1918, + -37, 1406, -1762, -3054, -1446, -2073, -4292, -3214, 1163, + 2333, -712, -2583, -2058, -1034, -600, -3796, -2395, 2137, + -1122, -1927, 702, -2196, -4374, -3257, -1558, -256, -728, + -395, -176, -1529, -2772, -1121, -340, -1147, -250, -4079, + -473, 4241, -2818, -3523, 3255, 2355, -2550, -1082, 1197, + 2213, -94, -237, 3123, 1314, -1075, 977, 1081, 2045, + 2966, -1328, -1069, -741, -524, -380, -2766, -986, 926, + -3281, -1554, 2554, -3620, -6394, -1680, -321, 2889, 243, + -1567, 2276, -1294, -525, 2010, -4883, -1495, 6778, 2085, + -873, 2496, 418, -1156, -1179, 1604, 6173, 1190, -2381, + 5788, 2431, -4941, -242, 1248, 1023, 4426, 3399, 2726, + 1388, -922, 595, 392, 1414, 6260, 2673, 
-973, 2237, + 1776, -2393, -757, 4158, 2842, -2327, 505, 1230, -3623, + -917, 336, -1400, -1018, 1771, 2696, -570, -2435, 886, + 2309, -2865, -1328, 2077, -1967, -3486, -411, 961, -1661, + -1979, 1179, -493, -2597, 1995, 284, -3300, -2213, 184, + 312, -1665, -641, -1325, -1276, 90, 69, 476, -778, + -1099, 853, 1515, 1630, 1188, -877, -1751, 702, 2983, + -201, 664, 4018, -352, -1864, 875, 2367, 813, -2463, + -702, 886, -2204, -2216, 399, -1729, -2408, 1412, -2757, + -3530, 449, -2554, -3910, 906, 697, -1696, 566, -1360, + -1991, 81, -1756, -159, 1180, -667, -584, -359, 183, + 1943, -412, -1747, 1659, 1961, 280, 294, 222, 2000, + 2076, 829, -43, -880, 3353, 3615, 1279, 1746, -1031, + 1301, 3477, -777, 2567, 1215, -2344, 3556, 561, -2166, + 1119, 2377, -391, -1825, -2359, 49, 1764, 391, -291, + 325, 1223, 1443, -624, -2828, 1381, 2438, 28, -652, + -166, 581, -2039, -374, -20, -2459, -1149, 1505, 2008, + -1798, -3848, -1796, -2208, -2224, -878, 728, -154, -534, + 1061, 538, -1465, 73, 1147, 82, -119, 3800, 4797, + -873, 784, 1458, -148, 3180, 1319, 908, 4951, 584, + -57, 2394, -967, 586, 405, -1601, 3566, -285, -3949, + -1301, -1953, -1223, -1831, -3477, -779, -389, -3169, -1828, + -1496, -1451, -556, -3327, -209, 534, -4908, 131, -386, + -5232, 1373, 2129, -1740, -1957, -1102, 76, 396, -1426, + -179, 1357, -3276, -1420, 3819, -44, 56, 2777, -1202, + 1908, 1410, 2031, 3495, -2197, -163, 1565, 239, 2803, + 480, -1636, 1180, 616, 1206, 1166, -1579, 1572, 814, + -774, 2310, 740, -2606, 1234, -603, -362, 1562, -2134, + 652, -777, -2353, 5464, 377, -2490, 1012, 157, 680, + -1389, -1898, 1135, -1, -1730, 1800, -1466, -1687, -1469, + -3250, -1081, 1381, -81, -204, -26, 353, 1941, 174, + 104, 2009, 1032, -871, 3280, 3398, -651, -154, 3309, + 1964, 448, 812, -17, 887, 2405, 3295, -54, -2396, + 1410, 1380, -1156, 296, -1706, -1729, 401, -970, -878, + -723, -2285, 1259, 1320, -1960, -1039, -211, -661, -763, + -1599, -43, 308, -1841, 72, -2075, -3010, -497, 506, + -377, 247, 1932, 
-1788, -2419, 257, 208, -2176, 488, + 2827, -1720, -1649, -619, 520, 1103, -1231, -1327, 2162, + 1535, -383, 315, -1488, -235, 1761, -27, -232, 515, + 127, -2239, 654, 2871, -379, -1274, 2445, 874, -2444, + 514, -206, -1289, 1314, 1869, 1316, 1878, -1454, -982, + 476, 359, 2084, -708, 405, -246, -1071, 1757, -866, + -2331, 783, 501, -853, 896, 36, -2468, -1138, 1445, + -613, -687, 1999, -449, -731, 1478, 384, -45, 96, + 1530, 1919, 186, -94, 1347, -329, -348, 1631, 574, + 1062, 735, -1652, 675, 244, 1241, 1137, -2469, 621, + 45, -612, 1308, -2015, -208, 2392, -1646, -67, 77, + -1558, 113, 1263, -236, -971, -333, -733, -555, 2024, + -135, -3817, -398, 1696, -1179, -1473, 1175, -166, 618, + 1132, -2504, -575, 146, -688, 1323, 150, -2021, 15, + 1673, 347, -1535, -106, 235, -32, 1167, -471, -503, + -1260, 416, -13, -1082, 1036, -790, -1676, 487, 985, + 77, 57, -1175, 1146, 2023, -1706, -404, 3249, -739, + -979, 3044, -514, -168, 2201, -2863, 1009, 1833, -2309, + 1565, 476, -1698, 1667, -496, -2193, 1686, 532, 336, + -1095, -1655, 578, -909, -1263, 2569, -2833, -1808, 2860, + -822, 27, 1098, -1371, 1585, -284, -1074, 2944, -764, + -2871, 2484, 1179, -1213, -670, -1226, 1112, 1837, -299, + -388, -51, 1, 992, -723, -361, 1723, -1115, -2012, + 1261, -9, -127, -510, -1550, 1448, 957, -1930, 171, + 776, -2104, 14, 764, -599, -745, -438, -371, -659, + 1075, 282, -3116, 684, 3747, 22, -2139, 816, 1413, + -333, 458, 906, 483, -1084, 797, 1039, -467, -377, + 1386, -1182, 610, 1787, -1354, -2800, 2638, 424, -2372, + 1153, -51, -689, 290, -2199, 818, 3755, -2674, -1689, + 3497, -507, -1978, 1729, 1413, 215, -76, 53, 759, + 371, -1529, 1005, -770, -685, 1754, -908, -653, 1047, + -1066, -784, -199, -526, 86, -1750, -916, 1839, 580, + -1884, 319, 226, -977, 212, 202, -741, -1013, 2057, + 69, -2961, 974, 1964, -512, -224, 1554, -79, -1142, + 1853, -71, 1009, 1174, -718, 2040, -158, -1508, 1042, + 0, -1219, 1212, 448, -208, -47, -779, -867, 1924, + -254, -1085, -221, -1283, 1543, -584, 
-951, 225, -1089, + -464, -853, -615, 1576, -2313, -1214, 950, -2548, -314, + 1201, -1527, 952, 764, -1915, 528, 169, -1676, 1742, + 425, -2346, 932, 290, 109, 492, -379, 932, 70, + 582, 135, 769, 1665, -1751, 576, 1013, 366, 2339, + 71, 637, 1500, 576, 111, 494, 765, 1170, 1421, + -5, -892, 2054, -640, 160, 1426, -651, 348, -841, + -558, 1563, 277, -408, -1468, 482, -1538, -2255, 968, + -1307, -454, 1306, -3085, -1680, 2624, -2191, -1719, 1891, + -3826, -1441, 2736, -3694, -266, 1897, -4468, 841, 2828, + -4060, -318, 2305, -1662, 528, 3056, -2429, -156, 2045, + -753, 475, 419, -597, 1100, 1845, 504, 1067, -402, + -824, 1807, 1192, 459, 200, 1728, 50, -497, 678, + -355, 938, 1239, -1223, 360, 1251, -95, 981, 1029, + -1940, 260, 1627, -2387, 3426, 519, -3141, 1822, -506, + -1471, 1101, -2137, 1069, 885, -2618, 1673, -463, -1558, + 1439, -386, -1923, 1538, -1313, -1735, 540, -1433, -915, + 494, -839, -1527, -1143, 480, -1081, 27, 1732, -1285, + -1833, 1952, -667, -1626, 1819, -1293, -1323, 2139, -376, + -1392, 1277, -1172, -240, 2907, -1875, -238, 2573, -1068, + -471, 2065, -686, -1315, 2575, 233, -1005, 1135, 706, + 534, 278, -182, 1091, -21, -222, 1413, -371, -54, + 1108, -103, 382, -70, 787, 894, -108, 1308, 1113, + -1412, 574, 1140, -2032, 500, 569, -1251, 951, -50, + -1398, 772, -474, -1536, 1297, 251, -2321, 109, -703, + -425, 40, -1354, -773, -225, -1743, -1839, 1244, 261, + -3082, -424, 1162, -937, 123, -322, -407, -561, -331, + 1369, -1142, -1050, 1024, 1116, -213, -752, 1521, -383, + -415, 1011, 947, -713, 743, 1945, -237, 881, 600, + -757, 885, -835, 756, 2454, -1985, 699, 1572, -1652, + 673, 232, -42, 1975, -736, -270, 1660, -704, -96, + 1264, -428, 278, 774, -954, -1325, 756, 1275, -594, + -353, 204, -1130, -782, -432, -979, 268, 378, 20, + -870, 405, -357, -1661, 637, 473, 293, -314, -895, + 3, -175, -1016, -643, 204, -588, -1007, -131, 401, + -849, -476, 271, 320, -198, 533, -25, -1994, 1421, + 525, -1611, 1261, 507, -488, 1093, 361, -1814, 2230, + 
312, -196, 3242, -803, -962, 1714, -1479, 1426, 1612, + -1953, 1376, -581, -669, 1370, -1251, 426, 1274, -470, + 1757, 807, -589, 1275, 126, -871, 1025, -1331, 287, + 1258, -1813, 146, -839, -1471, 828, -402, -281, 1704, + -1341, -231, 939, -1035, -472, -197, -764, -380, -816, + -266, 382, -497, -1708, -591, 1119, -1941, 178, 969, + -1656, 685, 1004, -1114, -127, -1473, -678, 1610, -1253, + 277, 1807, -1642, -461, 2033, -1449, 392, 98, -157, + 1525, -860, 2455, 413, -2159, 2457, 475, -374, 1532, + -981, 843, 973, 324, 1168, 225, -407, 1487, 681, + -680, 1098, 117, 245, 1238, -223, 1076, -428, -466, + 2593, -663, -1225, 1303, -933, -561, 1190, -1071, -1229, + 406, -284, -13, 198, -1494, -637, 352, -1960, 420, + 49, -1472, -761, -234, -2213, -1750, -521, -1554, -813, + 662, -633, -1388, -15, -947, -391, -152, -894, 631, + -461, -885, 633, -51, -1063, 218, 1149, -61, -274, + 988, -140, 7, 1774, 1558, -623, 755, 1352, -511, + 1106, 744, 17, 2640, -91, 697, 1547, -1757, 1832, + 1859, -206, 1505, 575, -444, 556, 250, 1786, 792, + -125, -266, 407, 501, 798, -536, -1214, 58, 6, + 354, -685, 613, 99, -2022, -116, -236, -182, 263, + -824, -1187, -142, -138, -1228, -1008, 786, -1421, -1127, + -269, -2278, 841, 222, -2423, 678, -1153, -2082, 574, + -570, -729, 180, -777, 212, 270, -274, 1077, -493, + 118, 804, -1260, 349, 799, 545, 481, 971, 1099, + 1146, -273, 34, 1728, 1128, 411, 758, 308, -808, + 950, 1490, 209, -265, 1154, -11, -460, 2644, -122, + -728, 2033, -1100, -305, 1774, -208, -1567, -57, -140, + -670, -454, -1390, -80, 978, -438, -731, -684, 344, + -458, -199, -126, -1663, -883, 642, -1517, -1144, -375, + -422, -452, -1815, -791, 763, -1502, -205, 684, -1641, + 448, 1399, -2160, 804, 1088, -2214, 1030, 1585, -1093, + -11, 1718, -360, -81, 1294, 398, 218, 1225, 644, + 505, 2090, -385, 526, 2111, -303, -316, 1550, 1323, + -459, 881, 1874, -1256, 1429, 2485, -1003, -552, 14, + 432, 952, 471, -633, 408, -358, 140, 554, -1260, + -404, 245, -2572, 954, 1005, -1621, -82, 
-175, -957, + 112, 106, -1117, -819, -62, -785, 71, 93, -1296, + -1680, 242, -956, -2696, 302, -204, -1404, 254, -558, + -201, -630, 16, -436, -1647, 1649, -1096, -1267, 2273, + -1270, 20, 1749, -2509, 780, 942, -1859, 2762, 304, + -300, 2617, -947, 861, 2601, -1153, 754, 1629, -681, + 686, 1443, -235, 1900, 5, -565, 1559, 285, -170, + 757, 480, 547, 752, -427, 50, 839, -95, -791, + -1698, -291, -62, -1730, 524, 1008, -2176, -369, 165, + -749, -972, -287, 889, -1218, -1712, 833, -855, -995, + -14, -793, -1815, 605, -607, -1890, 769, -781, 230, + 1155, -2000, 876, 1835, -1617, 9, 1058, -1232, 859, + 1486, -1301, 1595, 501, -951, 2935, -921, -634, 2826, + -793, 655, 2660, -232, 235, 1879, 481, -51, 804, + 987, -360, -331, 2099, -302, -149, 1966, -1233, -12, + 1330, -2265, 1256, -116, -1394, 2937, -995, -1572, 2964, + -2257, -2587, 1820, -2132, -1609, 778, -1596, -486, 560, + -1749, 274, -706, -1714, 1304, -360, -2657, 1833, -750, + -1729, 433, -1461, -794, -1545, -892, 385, -891, -374, + 1261, -589, 235, 815, -773, -669, 636, -471, 136, + 871, -392, 782, 677, -472, 1130, 1029, -1262, 1070, + 2171, 575, 675, 600, 2104, 1077, -182, 2621, -604, + -30, 3302, -1331, 599, 742, 291, 1329, -551, 1043, + 1729, -1754, 1220, 1113, -2174, 1281, 743, -2027, 851, + -205, -1576, 214, -1629, -605, -394, -1508, -254, -63, + -489, -847, -26, -997, -1065, -120, -376, -1283, -1393, + 83, -212, -1610, 419, -1120, -590, 395, -1210, -21, + -273, -622, 899, -196, -1059, 1130, 616, -529, -166, + 794, 22, -216, 862, 664, -390, 980, 228, 789, + 182, 402, 2149, -1133, 799, 2637, -799, 176, 1306, + 905, -93, 677, 338, 121, 483, 297, 339, 347, + 249, 731, 40, 66, 112, -889, -128, 582, -1191, + -67, -1364, -233, 488, -1734, -634, 1517, -1657, -1015, + 594, -1422, 1396, -1357, -1617, 1254, -1596, -941, 789, + -1860, -77, 245, -327, 569, -723, 104, 905, -543, + -918, 1387, -42, -440, 619, 68, 45, 1364, -880, + 19, 1491, -561, 1174, 1403, -1411, 1351, 1222, -612, + 864, 877, -658, 382, 864, -552, 
1286, 309, -105, + 1083, -170, -289, 1049, -248, -537, 625, -48, 337, + -385, 532, -315, -1398, 588, -628, -1192, 649, -806, + -170, 541, -2267, 1052, 274, -1970, 833, 253, -1345, + -290, -120, -959, -94, -189, -1397, -136, -155, -654, + 207, -706, 617, 415, -1962, 1169, 670, -1132, 319, + 297, -589, 100, 510, -620, 610, -153, -15, 1327, + -99, 229, 281, 169, 1015, -106, 1197, 577, -698, + 577, 931, -964, 1605, 505, -1713, 2369, 115, -1585, + 1839, 664, -1411, 867, 620, 329, 491, -1119, 420, + 266, -1708, 499, -69, -1037, 795, -321, -959, 32, + 235, -1748, 295, -249, -230, 485, -1185, -97, 489, + -2036, 711, 405, -2800, 593, 434, -1038, 536, 347, + -570, 705, -806, -290, 818, -999, 53, 1585, -756, + -657, 1180, 115, -364, 217, -226, 1033, 347, -20, + 611, 658, 590, -128, -451, 1676, -660, -21, 805, + -880, 1481, 412, -1534, 1522, 221, -132, 662, -407, + 613, 1132, -551, -187, 1184, -577, -444, 953, -1034, + -472, 461, -865, -99, 637, -572, 300, 450, -591, + 137, 404, -972, 306, -524, -1167, 433, 124, -1326, + -368, -305, -917, 452, -626, -695, 656, 258, -1401, + 270, 446, -1045, 636, -357, -1072, 913, 512, -1732, + 489, 952, -747, 58, 673, -453, 1125, -488, 46, + 1723, -1244, 417, 1803, -1215, 623, 659, -560, 676, + -9, 92, 701, 1100, -623, 142, 283, -512, 547, + 576, -525, -155, 1143, -1286, -329, 1959, -1302, -459, + 1188, -1199, 1020, -118, -1303, 956, -905, -647, 595, + -356, -1354, -74, 750, -791, -335, 56, -862, -36, + 276, -279, 46, -485, -181, 196, -584, -238, 259, + -314, -77, 383, 509, -386, -180, 859, -542, 955, + 372, -362, 1458, 113, -106, 1495, -534, 63, 1295, + -505, 846, 983, -1097, 1764, 320, -185, 1061, -525, + 115, 217, -328, 326, 312, 374, 179, -683, 485, + -1286, 147, -583, -979, 888, -504, -1235, 715, -1050, + -1111, 848, -828, -1043, -115, -327, 22, -451, -1008, + 98, -262, -545, -363, -48, -257, -731, 878, 96, + -1186, 426, 359, -1101, 1074, -267, 521, -375, -166, + 1398, -994, 780, 550, 124, -298, 581, 236, 305, + -111, 396, 741, -10, 662, 
155, 271, 563, 65, + -318, 812, -483, 843, 75, -714, 1152, -26, -190, + -97, 533, -111, -564, 724, -24, -820, 835, -473, + -632, 154, -104, -932, 919, -606, -619, 496, -310, + -271, -360, 120, -630, 126, 65, -931, 548, -207, + -455, 410, -282, -931, 944, -354, 69, 412, -661, + 1068, -969, -443, 1894, -1281, -442, 2003, -1640, 713, + 852, -1344, 1338, -457, 243, 498, -697, -129, 993, + -388, -76, 1039, -768, 492, -104, -58, 951, -854, + 181, 1093, -1111, 491, 544, -1061, 118, 586, -477, + -411, 392, 233, 91, -908, 532, 218, -1176, 670, + -74, -674, 696, -801, 194, 592, -1790, 762, -564, + -791, 595, -145, -727, 228, 434, -246, -232, -169, + 281, -324, 289, -120, -270, -49, 282, 250, -56, + -405, 507, 27, -1060, 1329, -203, -204, 1677, -767, + -313, 1272, -968, 717, 183, -1652, 2157, -75, -1906, + 2590, -428, -1614, 2564, -1511, -240, 1421, -1911, 1420, + 396, -1397, 1691, -694, -1500, 1942, -823, -784, 841, + -635, 759, -447, 351, 44, -946, 227, 441, -564, + 155, -719, 182, 509, -320, -300, 205, -662, 726, + 469, -1240, 191, 664, -269, -152, -18, 214, -149, + -257, 347, 76, -79, -384, 874, -387, -269, 892, + -783, 537, 46, 27, 251, -332, 133, 377, -522, + 232, 626, -362, -499, 1112, -342, -522, 362, -187, + 547, -384, -155, 517, -551, 227, 651, -825, -88, + 579, -758, -40, 456, -774, 542, -164, -482, 968, + -1000, -394, 1094, -885, 431, 74, -348, 403, -959, + 831, -465, -330, 762, -717, -645, 1342, -499, -416, + 944, -417, -438, 737, -368, -42, 740, -1234, 689, + 29, -106, 619, -824, -10, 1047, -824, 146, -59, + 210, 163, -43, 522, -352, 213, 460, -1049, 599, + 308, -843, 632, 223, -504, 296, 530, -931, 751, + -176, -524, 379, 236, -626, 66, 662, -575, 191, + -175, -619, 660, -424, -217, 704, -498, 200, 62, + -543, 280, 91, -378, 54, 168, -554, 670, -215, + -1097, 1805, -1015, -617, 1642, -1560, 727, 61, 7, + -48, -659, 1308, -752, -613, 914, 160, -469, 164, + -167, 274, 326, -667, 497, 333, -757, 1252, -481, + -1257, 2019, -949, -719, 1676, -1078, 250, 323, -1100, 
+ 1550, 145, -1697, 972, 522, -966, 374, -365, 846, + -276, -756, 629, -278, 302, -151, -243, -363, 841, + -7, -1092, 476, 45, 201, -378, -456, 1113, -926, + 97, 178, -240, 326, -597, 472, -10, -190, 394, + -501, -259, 307, 133, 240, -433, -192, 472, -190, + 12, 398, -191, -605, 1295, -576, -154, 474, -661, + 866, -968, 172, 887, -736, 36, 259, -201, 265, + 460, -859, 622, 102, -690, 776, -80, -745, 919, + 140, -750, 224, 134, -236, -196, 456, 409, -1069, + 600, 239, -306, -383, 541, -213, -323, -121, 700, + -735, 179, 222, -613, 653, -711, -81, 592, -694, + 117, 703, -772, -264, 644, -117, -422, 276, 64, + -355, -430, 800, -74, -619, 1207, -1057, 4, 960, + -1219, 977, -78, -1186, 1536, 267, -1388, 1144, -90, + -1052, 1889, -1255, -387, 1815, -1763, 1037, 421, -1003, + 767, -24, -277, -54, 759, -285, -1015, 1422, -581, + -121, 547, -687, 288, 440, -626, -623, 1261, -248, + -1133, 1204, -714, 382, 219, -851, 240, -161, 672, + -261, -855, 1043, -599, 111, -362, 225, 641, -913, + -122, 1075, -1165, 432, 131, -803, 978, 33, -1291, + 992, 224, -1054, 789, -121, -215, 262, -11, 89, + -174, 365, -240, 114, 406, -813, 291, 233, 158, + -377, 194, 216, -477, 635, -228, -512, 599, 23, + -273, 71, 258, 10, -155, -198, 354, 61, -749, + 768, -19, -709, 596, 97, -276, 164, 69, -144, + -20, 529, -897, 188, 480, -703, 836, -874, 259, + 917, -1044, -7, 566, -97, -439, 256, -466, 998, + -360, -1134, 1619, -762, -752, 1446, -707, -177, 652, + -899, 579, 253, -410, 146, -262, 275, 353, -610, + 52, 671, -862, 419, -140, 273, 247, -1062, 1005, + -175, -497, 772, -431, -101, 450, -598, 266, 428, + -842, 477, -11, -554, 642, 17, -787, 544, 445, + -625, -205, 796, -222, -733, 764, -572, 423, 166, + -994, 931, -228, -303, 362, -214, 104, 448, -1091, + 722, 570, -1311, 773, 259, -648, 477, 193, -682, + 302, 459, -464, -383, 1120, -561, -564, 1083, -372, + -354, 864, -586, -200, 502, -331, 27, 446, -657, + 281, 571, -888, 502, 251, -423, 116, 277, -263, + 118, -170, 168, 367, -723, 202, 438, 
-793, 451, + -30, -292, 202, 38, -188, -66, 221, -90, -105, + 7, 346, -578, 337, 247, -371, -14, 22, 36, + 151, -322, -244, 692, -556, -5, 550, -560, 200, + 161, -347, 191, 258, -520, 441, -212, -215, 584, + -428, -251, 213, 90, -187, 109, 138, -211, -17, + 191, 111, -259, 161, -141, 232, -175, 0, 154, + -369, 539, -171, -438, 484, 43, -375, -37, 249, + 196, -328, -106, 541, -531, 103, 240, -191, 186, + -363, 40, 585, -573, 258, 170, -593, 515, -261, + -86, 407, -339, 164, -214, -34, 464, -377, -206, + 336, -230, 239, -85, -69, 322, -503, 322, 142, + -748, 867, -160, -753, 836, -249, -362, 750, -374, + -222, 448, -82, -246, 399, 13, -429, 441, -47, + -127, -29, 337, -502, 318, 132, -457, 498, -145, + -91, 98, 208, -179, 54, 62, -260, 237, 96, + -161, 32, -150, 93, 21, -31, 74, 75, -322, + 164, 168, -191, 119, -121, -66, -195, 296, -128, + -251, 381, -56, -338, 281, -29, -472, 664, -301, + -275, 423, -285, -77, 258, -82, -139, 160, -54, + -26, 27, 75, -49, -196, 305, -131, -187, 262, + -37, -206, 65, 269, -240, -144, 261, 54, -338, + 355, 3, -503, 535, -253, -210, 433, -290, -33, + 381, -546, 173, 252, -364, 271, -329, 166, 266, + -564, 507, -32, -648, 861, -400, -357, 819, -519, + -74, 392, -423, 426, -306, -93, 691, -991, 537, + 467, -992, 614, 426, -823, 491, 182, -371, 174, + 84, -64, 98, -96, 23, 182, -69, -211, 226, + 18, -134, 334, -514, 352, 378, -623, 363, 266, + -592, 493, -46, -369, 594, -440, -10, 295, -368, + 326, -192, -140, 306, -305, 140, 198, -396, 202, + 154, -341, 208, -8, -169, -76, 106, 20, -347, + 233, 30, -193, 117, -9, -165, 182, -4, -195, + 96, 131, -188, -106, 166, -71, -99, 57, 4, + -31, -131, 101, 63, -199, 225, -25, -281, 342, + -247, -170, 516, -289, -263, 422, -158, -148, 363, + -192, -138, 122, 62, -105, 7, 194, -53, -224, + 83, 173, -182, 20, 178, -274, 182, 74, -109, + -5, 319, -303, -72, 428, -371, 50, 271, -204, + 17, 161, -256, 169, 93, -169, 94, -89, 139, + 80, -199, 325, -67, -83, 202, -154, 16, 202, + -325, 162, 61, -93, 201, 
-278, 236, 108, -477, + 594, -145, -370, 647, -261, -356, 669, -369, -181, + 420, -266, -154, 159, -25, 53, -40, -22, 68, + -203, 144, -2, -173, 88, -3, -62, 2, 75, + 55, -95, -130, 219, -142, -191, 164, -170, 44, + 0, -246, 249, -27, -413, 461, 27, -490, 292, + 19, -145, 13, 99, 91, -466, 209, 295, -773, + 465, 210, -680, 410, 163, -358, 399, -201, 87, + 23, -212, 270, -230, 86, 159, -353, 381, -73, + -456, 726, -353, -357, 754, -367, -344, 657, -59, + -417, 432, 35, -309, 153, 97, -69, 89, -101, + 63, 107, -127, 106, 112, -26, -236, 376, 43, + -479, 544, -57, -407, 447, -148, -103, 195, -198, + 80, 156, -228, 35, 145, -77, -55, 130, -33, + -190, 123, 41, -170, 74, 114, -241, 67, 192, + -195, -76, 186, -136, -133, 213, -105, -110, 144, + -51, -126, 154, -59, -124, 147, -49, -132, 82, + 26, -130, 63, 68, -211, 97, 131, -224, 59, + 184, -250, 59, 205, -225, -67, 163, -135, -24, + 74, -22, -4, -81, 21, 71, -137, 71, 47, + -120, 71, 34, -65, 138, -6, -116, 112, -47, + -39, 20, -75, 64, -7, 2, 35, 52, -61, + -29, 81, -61, -30, 195, -91, -136, 261, -11, + -186, 162, -86, -35, 152, -106, -32, 126, -4, + 49, 33, -9, -11, 46, 111, -132, -3, 204, + -175, -10, 281, -146, -94, 226, -126, -36, 58, + -14, 61, -172, 48, 193, -221, 83, 149, -279, + 195, 130, -357, 226, 102, -260, 191, 16, -223, + 124, 14, -144, 90, -31, -81, -66, 54, 103, + -181, 29, 174, -281, 92, 81, -226, 139, -133, + -41, 167, -147, 44, 27, -132, 107, -34, -122, + 105, -54, 17, 52, -131, 138, 33, -206, 158, + 43, -80, 24, 10, -27, 33, 43, -71, 15, + 71, -42, 14, 18, 0, -3, -14, -14, 58, + 46, -99, 122, 105, -202, 125, 119, -238, 112, + 133, -242, 113, 129, -301, 52, 161, -177, 82, + 73, -139, 46, 122, -119, 22, 155, -230, 23, + 242, -211, -12, 182, -184, -57, 190, -34, -101, + 58, -20, 6, 103, -61, -78, 12, 18, 12, + 86, -71, -27, 43, -24, 8, 39, -109, 21, + -4, -44, 66, 13, -59, 61, -39, 35, 113, + -179, 19, 171, -158, 14, 112, -133, 26, 9, + -43, -9, 6, 41, -77, 22, 80, -61, -63, + 65, -32, -32, 125, -105, 
-11, 114, -120, 42, + 42, -92, 45, -56, -25, 131, -83, -24, 97, + -51, -5, 67, -69, 7, 41, -27, 8, 3, + -10, 8, -3, -87, -28, 122, -33, -58, 124, + -53, -50, 67, -115, -17, 111, -112, -30, 101, + -24, -13, 41, 3, 45, -13, -34, 23, 23, + -19, 13, -49, -49, 68, -68, -32, 91, -58, + -18, 73, -19, -27, 17, -33, -35, 99, -38, + -99, 78, -31, -62, 95, -71, -124, 184, -15, + -146, 160, -27, -109, 140, -25, -63, 84, -34, + -18, 58, -68, -16, 22, -87, 86, 23, -130, + 61, 62, -132, 51, 168, -139, 35, 133, -121, + 50, 102, -120, 40, 126, -87, -40, 119, -14, + -59, 78, 11, -68, 41, 24, -25, 55, -2, + 15, 21, -73, 56, 88, -74, -41, 4, -10, + -4, 5, 7, -39, -3, -4, -39, 94, 52, + -135, 42, 90, -86, 12, 21, -55, -70, -37, + 55, -63, -35, 50, -100, 21, 84, -151, 24, + 87, -94, 51, 2, -58, 104, -61, -70, 60, + -25, -42, -31, 55, 35, -129, 47, 69, -65, + 77, 2, -60, 110, -32, -69, 84, -54, -26, + 98, -28, -7, 49, -49, -19, 119, -11, -157, + 20, 106, 29, -8, -38, -30, 72, 30, -3, + 1, -32, -11, -9, 52, 46, -144, -38, 86, + -31, -9, -42, -75, 142, 34, -64, 79, -109, + -55, 195, -69, -80, 48, -49, 62, 25, -111, + -42, 52, 19, -41, 1, -16, -33, 44, 30, + -21, 17, -2, -30, 111, 34, -111, 83, 55, + -119, 66, 62, -89, 63, -39, -143, 168, 21, + -158, 158, 32, -132, 134, -3, -77, 88, -45, + -18, 117, -51, -71, 10, 30, 35, -27, -63, + 13, 34, 23, -23, 19, -4, -92, 34, 74, + -69, -15, 20, -36, 56, -36, -96, 69, -34, + -122, 32, 31, -51, -3, -21, 4, 43, -44, + 6, 81, -39, -35, 26, -38, -24, 29, -16, + -47, -6, 19, -7, -9, 41, 32, 13, -2, + -21, 3, 24, 49, -3, -66, 14, 95, -7, + -52, 80, 68, -72, -14, 39, 2, 24, -6, + -53, 86, 21, -78, 67, 28, -34, 16, -23, + -1, 70, -3, -58, 45, 33, -94, -34, 62, + 41, -11, -27, 27, 46, 14, -33, -12, 44, + -16, -59, 6, 45, -3, -42, 2, 13, 19, + -1, -71, 3, 42, -36, 6, 17, 26, 5, + -46, 6, -68, -75, 86, -20, -90, 80, 4, + -86, 5, 2, -33, -15, -2, -8, -18, 15, + -7, -25, 27, -28, -88, 39, -2, -85, 58, + 40, -45, 3, 17, 0, 11, -4, -3, 84, + 22, -113, 8, 94, 10, 9, 
28, 6, -3, + 5, -2, 23, 23, -1, -40, 20, 48, -40, + -21, 72, 7, -40, -1, 27, 16, 30, 31, + -16, 11, 9, -71, -7, 62, 21, -61, -19, + 78, -2, -22, 67, -42, -12, 75, -79, 47, + 86, -124, -42, 21, 4, 23, -32, -7, 19, + 1, -13, -46, 2, 32, -43, -7, 86, -16, + -22, 46, -61, -35, 11, -64, -38, 17, -12, + -27, 20, 41, 6, -58, -61, 58, -51, -77, + 36, -25, 19, 93, -76, 1, 72, -92, 15, + 40, -56, 65, 13, -29, 82, -9, -21, 24, + -83, -5, 4, -63, 77, 80, -58, -6, -19, + -43, 100, 5, -36, 63, 33, -26, -48, 26, + -18, -75, 34, 24, -45, -1, 6, -35, -24, + -23, -22, 47, -15, -46, 31, -40, -41, 74, + -32, -73, 59, -51, -26, 143, -29, -42, 93, + -44, -21, 56, -7, 55, 51, -61, 74, 111, + -71, 35, 124, -123, -3, 62, -79, 100, 49, + -122, 143, 79, -137, 72, 30, -82, 75, -10, + -48, 35, -23, -25, 34, 0, -54, -6, 34, + -46, -59, -7, -72, -6, 70, -41, -39, 23, + -33, 11, 104, -44, -30, 54, -69, -20, 62, + -75, 1, 45, -69, 1, 40, -59, -15, 18, + -16, 38, -1, -52, 8, 14, -32, 11, -15, + -58, 18, -22, -44, 69, 40, -50, -21, 1, + -35, -3, -5, -20, 40, 36, -41, -36, -43, + -11, 48, -34, -40, 51, -10, -9, 30, 10, + 12, 51, 51, -8, -16, 32, -6, 31, 24, + -38, 43, 18, -15, 53, -10, -55, 9, 8, + -28, 21, 10, -26, 21, 10, -9, 5, -29, + -13, 38, -1, -11, 49, 0, -41, 10, 23, + -25, -35, -2, -32, -10, 58, -6, -18, 16, + -9, 4, 11, 17, 21, 21, 12, -2, 49, + -16, -128, 21, 75, -32, 22, 34, -59, 48, + 75, -69, -11, -2, -65, 39, 57, -54, -79, + -11, -20, -13, 38, 4, -9, -22, -22, 33, + -7, -52, 10, -10, -19, 54, 47, -21, -35, + -6, -4, 11, 8, -28, 1, 8, -4, 30, + 1, -22, 26, -7, -24, 56, 25, -45, 13, + 24, -32, 13, 22, -46, -2, 15, -39, 28, + 32, -69, 0, 27, -69, 0, 39, -40, 28, + 55, -27, -13, 0, -14, 37, 25, -25, 34, + -3, -69, 26, 39, -41, -6, 29, -7, 5, + 66, 41, -27, -17, 6, -14, -21, 0, 29, + -9, -26, 32, -5, -34, 60, 15, -60, 20, + 13, 11, 43, -48, -15, 88, -13, -55, 26, + -32, -46, 35, 14, -37, -11, 12, -20, 11, + 9, -64, -16, 17, 5, 38, 7, -30, -9, + -49, -11, 52, -15, -38, -27, -12, 36, 53, + 1, 
-37, -17, -12, 0, 31, 1, 13, 40, + -15, 2, 47, -15, -17, 28, -2, -4, 25, + -6, -12, 2, -17, -9, 5, -15, 17, 21, + -28, 0, 15, -43, -63, -6, -14, -8, 37, + -34, -40, 30, -12, -14, 37, -13, -16, 26, + -15, -2, 13, -37, -13, 32, 13, -8, -2, + -12, -8, 9, 9, -3, 4, 13, 34, -2, + -22, 40, 19, 29, 25, -48, -17, 23, 17, + 7, 3, 0, 12, 37, -1, -25, 30, 41, + -7, 7, 29, -31, -31, -23, -27, 5, 2, + -18, -2, 22, 9, -6, 5, -7, -24, 9, + 0, -28, 19, 61, -11, -45, 21, -28, -65, + 28, 33, -44, -27, -6, -26, -8, 4, 5, + 9, -10, -46, -20, 20, -7, -7, -33, -26, + 50, 9, -65, -22, -3, -20, 15, 21, 20, + 24, -16, -27, -13, 14, 21, -38, -48, 9, + 35, 28, 21, 3, -31, -8, 57, 32, -35, + -22, 20, 14, 12, 28, 39, 0, -18, 44, + -2, -17, 53, 0, -27, 33, 43, 5, -10, + 25, 47, -3, -4, 36, 15, -12, -3, 29, + 41, 23, 23, -8, -32, 15, 37, 0, 3, + 22, 31, 1, -20, 27, 2, -50, 0, 33, + 16, -16, -17, 18, -26, -34, 31, -27, -84, + -33, 4, -5, -22, -17, -28, -66, -24, 8, + -16, -25, -51, -13, 45, -11, -49, -26, -49, + -38, 21, 10, -52, -58, -19, -4, 9, -31, + -29, 55, 2, -45, 29, 10, -22, 49, 33, + -27, -19, -5, 30, 47, 11, -11, -2, 8, + 5, 17, 8, 3, 57, 63, 28, 24, 11, + 2, 14, 22, 7, 7, 2, 23, 33, -2, + -8, 14, 7, 20, 57, 32, -5, 12, 23, + 10, 17, 26, -18, -72, -6, 74, 61, 13, + -17, -21, -7, 29, 45, 5, -52, -49, 1, + 10, 35, 40, -46, -66, 7, 31, -27, -44, + -12, -41, -22, 32, -12, -32, -3, -17, -22, + -22, -31, -30, -23, -13, 3, 0, -21, -19, + -7, -17, -9, 18, -40, -64, 1, 4, -4, + 8, -17, -28, -1, 9, -7, -9, 27, 6, + -63, -32, 52, 25, -46, -23, -6, -11, 35, + 29, -50, -44, 17, -6, -12, 53, 28, -17, + -9, 28, 34, -20, -18, 22, 43, 28, -6, + 8, 14, 19, 28, 14, 27, 26, 12, 76, + 66, -18, -2, 18, -12, -1, -2, -1, 51, + 30, -18, 5, 14, -12, 2, 13, -25, -9, + 32, 7, -5, 15, -12, -33, -18, -13, 6, + 0, -25, -12, 1, -17, 0, 13, -24, -27, + 4, 35, 14, -22, 5, 13, -18, -30, -10, + -7, -7, 31, 23, -27, -26, 9, 47, 6, + -50, -11, 19, 1, 11, 12, -19, -43, -18, + 10, -6, -3, 12, 2, -12, -16, 10, 9, + -25, -21, 
-10, -13, 0, 8, -1, -9, 10, + 4, -34, 14, 46, 5, 18, 24, -15, -7, + 20, -1, -13, 7, 11, 14, 11, -2, 8, + 27, 10, -1, 13, -2, -7, 48, 44, -15, + -16, -6, 3, 7, -35, -25, 8, -31, -16, + 30, 36, 22, -13, -21, -10, 8, 2, -58, + -37, 32, 25, -1, -25, -21, 3, 3, -6, + -11, -3, 2, 4, 34, 22, -25, -19, 0, + -6, -10, -8, -35, -32, 8, -3, -20, -11, + -6, 3, 8, -8, 3, 25, 23, -7, -35, + -15, 8, -20, -6, 15, -44, -29, 19, -5, + -1, 18, 28, 6, -21, 9, 11, -20, -10, + 18, 22, 6, -2, 12, 6, 23, 34, -20, + -19, 1, -10, 34, 41, 13, 6, 3, 22, + 11, -4, 4, -12, -8, 17, 18, 12, -1, + 5, 9, -6, -2, 4, 1, 3, 2, -6, + -32, -25, 9, 18, 27, -4, -54, -29, 2, + -3, -18, -38, -28, -10, 9, 20, 5, -9, + -15, -3, 2, -14, -15, -6, 5, 10, 6, + 3, -11, -9, -5, -20, -13, 8, 3, -14, + 6, 20, -15, -21, 9, 19, 21, 12, -4, + -21, -17, 16, 27, -4, -28, -2, 26, 9, + -12, -16, -28, -28, -4, 4, -15, -9, 3, + -10, -16, 2, 17, -10, -26, 3, 16, 26, + 17, -12, -9, 2, -2, -5, -11, 5, 28, + 1, -14, 13, 14, 5, 18, 6, -17, -5, + 7, 2, -3, 11, 10, -1, 50, 36, -28, + 21, 39, -9, -6, 2, 10, 36, 20, -2, + -3, -11, -10, -6, -5, -4, -8, 2, 17, + 1, -13, 11, -13, -36, 11, 14, -19, -6, + 3, 0, 20, -5, -24, 12, 7, -11, 2, + -15, -28, -1, 6, -14, -31, -39, -19, 19, + 37, 3, -32, -27, -6, 13, 31, 15, -41, + -41, 25, 35, -3, -16, -25, -19, -10, -3, + 19, 10, -4, 7, -4, -19, -12, -13, -9, + 6, 2, -12, -6, 12, 6, -1, -5, -19, + -7, 7, 40, 56, -3, -13, 21, 24, 7, + -11, -9, -3, 24, 28, -10, 1, 12, 21, + 24, -16, -15, 4, -7, -2, 19, 13, -11, + -7, -8, 15, 41, 5, -16, -18, -11, 26, + 26, -5, -12, -14, -6, 10, 8, -8, -16, + -16, -3, 10, 1, -3, -3, -2, -15, -18, + 6, -4, -4, 21, 4, -2, 15, 13, 0, + -2, 12, 7, -15, -9, 1, -2, 2, -1, + -9, -15, -17, -14, -10, 1, -4, -16, -17, + -1, 18, 8, 1, 22, 11, -19, -10, 4, + -23, -29, 0, -2, -14, -6, 13, 7, -23, + -13, 10, 9, 11, 10, 4, -4, -4, 1, + 6, 14, 9, 2, 0, 2, 6, 4, -9, + -18, -8, 8, 18, 8, 13, 9, -27, -22, + -10, -24, -9, 17, 11, 2, 9, 3, -13, + -10, -1, -7, -1, 10, -4, 1, 16, 12, + -6, 
-14, -2, -5, -1, 0, -1, 6, -9, + -3, 12, 4, 1, -2, 2, 17, 24, 22, + 9, 8, 21, 14, -2, -2, 4, -1, -7, + -7, -6, -1, -6, 17, 30, -7, -10, -3, + -19, -18, 2, 21, 4, -20, -6, -1, -18, + -14, -6, -7, -1, 6, 10, 8, -5, 0, + 10, -22, -40, -22, 4, 34, 16, -19, -16, + -12, -17, -16, -17, -29, -28, -4, 10, 16, + 22, 13, 4, -1, -5, 16, 15, -11, -6, + 9, 3, -14, -22, -19, -12, 5, -5, -15, + 3, 9, 27, 17, -4, 8, -2, 1, 16, + 11, 9, 9, 8, -14, -16, 7, -5, -15, + -11, -5, 19, 25, 25, 43, 21, -9, -9, + -19, -10, 14, -11, -19, 8, 3, 1, 11, + -1, -24, -20, -1, 2, 7, 24, 22, 11, + 8, 6, -2, -11, -3, -2, -4, 0, -7, + 0, 6, -1, -16, -35, -8, 8, -11, -6, + 6, 18, 16, 7, 12, 5, -2, -3, -10, + -21, -27, -10, -3, -3, 8, 0, -9, -10, + -3, 0, -5, 6, 9, 19, 23, 8, -5, + -19, -16, -5, -6, -27, -22, 1, 6, 8, + 2, -9, -13, -15, -18, -13, 4, 25, 29, + 26, -2, -22, 1, 8, 1, -6, -6, -7, + -20, 0, 13, -14, -24, -24, -21, 2, 14, + 16, 23, 15, 10, 10, 5, 0, -26, -32, + 3, 19, 5, -8, -7, -8, -3, 17, 27, + -7, -28, 10, 32, 10, 1, 10, 3, -4, + 22, 24, -31, -40, 0, 6, 5, 17, 17, + 1, 10, 30, 8, -12, -6, 9, 6, -12, + -5, 1, -4, 6, 11, 0, -9, -4, -3, + -4, -3, 2, 0, -2, -9, -27, -23, 2, + 13, -6, -9, -3, -12, -2, 10, 6, -7, + -19, -31, -13, 16, 11, -3, -13, -15, 0, + 7, -3, -7, -1, -4, 7, 15, 0, -12, + -8, -1, -7, -12, -21, -17, 5, 30, 25, + -6, -6, 0, -12, -8, 2, 13, 11, 1, + 5, 4, 4, 10, -1, -20, -12, -4, 3, + 15, 11, -7, -24, -4, 8, -2, -14, -25, + -17, 7, 21, 14, 1, 0, 12, 17, 13, + 6, 1, 6, 14, 11, -10, -21, -12, -4, + 3, -2, -21, -24, -2, 12, 14, 17, 4, + -2, 11, 11, 11, 1, -34, -32, -5, 10, + 7, -11, -12, 6, 7, -4, -10, -15, -5, + 17, 21, 0, -15, -15, -1, 5, -18, -18, + -10, -9, 24, 27, -9, -14, 0, 9, 25, + 22, 1, -7, -2, 16, 13, -14, -10, 7, + 0, 2, 15, 2, -9, 5, 10, -5, -3, + 10, 3, 0, 15, 15, -1, -3, 8, 6, + -7, -7, 2, 0, -4, 5, -8, -37, -28, + -1, 8, 6, 10, -1, -12, 12, 28, 8, + -17, -16, -15, -17, 1, 6, -4, -8, -4, + -15, -15, 6, -9, -15, 10, 9, -13, -8, + 5, -2, -10, 5, 12, -27, -33, 9, 8, + 
-16, -3, 16, -3, -7, 22, 22, 10, 5, + -11, -16, -4, 9, 12, 6, -3, 2, 2, + -1, 4, -7, -8, 1, 8, 19, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.h new file mode 100644 index 0000000000000000000000000000000000000000..33aeea516fb8c7fcb080b3b971bf5d69b81b9c4c --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_1000ms_sample_data.h @@ -0,0 +1,29 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This data was created from the PCM data in a WAV file held in v2 of the +// Speech Commands test dataset, at the path: +// speech_commands_test_set_v0.02/yes/f2e59fea_nohash_1.wav +// This should contain all 16,000 samples from the one-second file. 
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_1000MS_SAMPLE_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_1000MS_SAMPLE_DATA_H_ + +#include + +extern const int g_yes_1000ms_sample_data_size; +extern const int16_t g_yes_1000ms_sample_data[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_1000MS_SAMPLE_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/kernels/BUILD b/tensorflow/lite/experimental/micro/kernels/BUILD index 47ac85c605488bdaa30515325122019a2d88678f..451eed28528fa58c56af879bf556c395aaf7b6cf 100644 --- a/tensorflow/lite/experimental/micro/kernels/BUILD +++ b/tensorflow/lite/experimental/micro/kernels/BUILD @@ -22,7 +22,6 @@ cc_library( copts = tflite_copts(), deps = [ "//tensorflow/lite/c:c_api_internal", - "//tensorflow/lite/experimental/micro:micro_framework", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels:padding", @@ -43,7 +42,42 @@ cc_library( copts = tflite_copts(), deps = [ ":micro_ops", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "portable_optimized_micro_ops", + srcs = [ + "fully_connected.cc", + "portable_optimized/depthwise_conv.cc", + "softmax.cc", + ], + hdrs = [ + ], + copts = tflite_copts(), + deps = [ "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels:op_macros", + "//tensorflow/lite/kernels:padding", + "//tensorflow/lite/kernels/internal:quantization_util", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_library( + name = "portable_optimized_ops_resolver", + srcs = [ + "all_ops_resolver.cc", + ], + hdrs = [ + "all_ops_resolver.h", + ], + copts = tflite_copts(), + deps = [ + ":portable_optimized_micro_ops", "//tensorflow/lite/experimental/micro:micro_framework", ], ) @@ -61,6 +95,19 @@ tflite_micro_cc_test( ], ) 
+tflite_micro_cc_test( + name = "portable_optimized_depthwise_conv_test", + srcs = [ + "depthwise_conv_test.cc", + ], + deps = [ + ":portable_optimized_ops_resolver", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + tflite_micro_cc_test( name = "fully_connected_test", srcs = [ diff --git a/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc index 05ba8798c0dc34eab5c563489cf9fc928325d00f..ff952b39c001be0d0f757d2de130f9c0df27c543 100644 --- a/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc @@ -50,6 +50,11 @@ void TestDepthwiseConvFloat(std::initializer_list input_dims_data, CreateFloatTensor(output_data, output_dims, "output_tensor"), }; + // Place a unique value in the uninitialized output buffer. + for (int i = 0; i < output_dims_count; ++i) { + output_data[i] = 23; + } + TfLiteContext context; PopulateContext(tensors, tensors_size, &context); @@ -403,4 +408,84 @@ TF_LITE_MICRO_TEST(SimpleTestReluQuantized) { kTfLiteActRelu, output_data); } +TF_LITE_MICRO_TEST(SimpleTestOptimizedFilterWidth) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = 0; + const float input_max = 255.0f; + const float filter_min = -63.5f; + const float filter_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 128.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 9; + uint8_t output_data[output_dims_count]; + + tflite::testing::TestDepthwiseConvQuantized( // + {4, 1, 1, 9, 1}, // Input shape. + { + // Input values. 
+ F2Q(1, input_min, input_max), + F2Q(2, input_min, input_max), + F2Q(7, input_min, input_max), + F2Q(8, input_min, input_max), + F2Q(3, input_min, input_max), + F2Q(4, input_min, input_max), + F2Q(9, input_min, input_max), + F2Q(10, input_min, input_max), + F2Q(5, input_min, input_max), + F2Q(6, input_min, input_max), + F2Q(11, input_min, input_max), + F2Q(12, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {4, 2, 1, 8, 1}, // Filter shape. + { + // Filter values. + F2Q(1, filter_min, filter_max), + F2Q(2, filter_min, filter_max), + F2Q(3, filter_min, filter_max), + F2Q(4, filter_min, filter_max), + F2Q(-9, filter_min, filter_max), + F2Q(10, filter_min, filter_max), + F2Q(-11, filter_min, filter_max), + F2Q(12, filter_min, filter_max), + F2Q(5, filter_min, filter_max), + F2Q(6, filter_min, filter_max), + F2Q(7, filter_min, filter_max), + F2Q(8, filter_min, filter_max), + F2Q(13, filter_min, filter_max), + F2Q(-14, filter_min, filter_max), + F2Q(15, filter_min, filter_max), + F2Q(-16, filter_min, filter_max), + }, + filter_min, filter_max, // Filter quantization range. + {1, 1}, // Bias shape. + { + // Bias values. + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + F2Q32(4, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + 220, + 184, + 140, + 150, + 161, + 200, + 172, + 148, + 133, + }, + {4, 1, 1, 9, 1}, // Output shape. + output_min, output_max, // Output quantization range. 
+ kTfLiteActNone, output_data); +} + TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/kernels/portable_optimized/depthwise_conv.cc b/tensorflow/lite/experimental/micro/kernels/portable_optimized/depthwise_conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1ddf6b0733b07fa57c505e95e79d03c9ca0f1ce --- /dev/null +++ b/tensorflow/lite/experimental/micro/kernels/portable_optimized/depthwise_conv.cc @@ -0,0 +1,439 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" + +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace depthwise_conv { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +// Size of the cached buffer we'll be using to hold reordered weights. 
+constexpr int kReshapedFilterDataSize = 1 * 1024; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, int width, + int height, int filter_width, int filter_height, + int out_width, int out_height, + const TfLiteType data_type, OpData* data) { + data->padding.height = ComputePadding(params->stride_height, 1, height, + filter_height, out_height); + data->padding.width = + ComputePadding(params->stride_width, 1, width, filter_width, out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. 
+ if (data_type != kTfLiteFloat32) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + return kTfLiteOk; +} + +// Specialized implementation of the depthwise convolution operation designed to +// work with the particular filter width of eight used by the default micro +// speech sample code. It uses 1KB of RAM to hold reordered weight parameters, +// converted from TFLite's NHWC format to NCHW format, and expressed as signed +// eight bit integers, rather than unsigned. Care must be taken when calling +// this not to use it for more than one node since there's only a single static +// buffer holding the weights. You should use this implementation if depthwise +// convolutions are a performance bottleneck, you have a layer that meets the +// parameter requirements, and the extra RAM usage and additional code size are +// not an issue. 
+static inline void DepthwiseConvOptimizedForFilterWidthEight( + TfLiteContext* context, const DepthwiseParams& params, + const RuntimeShape& input_shape, const uint8* input_data, + const RuntimeShape& filter_shape, const uint8* filter_data, + const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& output_shape, uint8* output_data) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + const int depth_multiplier = params.depth_multiplier; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = input_shape.Dims(3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + + static int8_t reshaped_filter_data[kReshapedFilterDataSize]; + const int needed_size = + output_depth * filter_width * filter_height 
* input_depth; + if (needed_size > kReshapedFilterDataSize) { + context->ReportError( + context, + "Size too large for reshaped weight buffer (%d needed, %d available)", + needed_size, kReshapedFilterDataSize); + return; + } + + RuntimeShape reshaped_filter_shape; + reshaped_filter_shape.BuildFrom( + {1, output_depth, filter_height, filter_width}); + + // If this is the first time through, repack the weights into a cached buffer + // so that they can be accessed sequentially. + static bool is_reshaped_filter_initialized = false; + if (!is_reshaped_filter_initialized) { + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int oc = 0; oc < output_depth; ++oc) { + const uint8* current_filter = + filter_data + Offset(filter_shape, 0, filter_y, filter_x, oc); + int8* reshaped_filter = + reshaped_filter_data + + Offset(reshaped_filter_shape, 0, oc, filter_y, filter_x); + *reshaped_filter = (int32_t)(*current_filter) + filter_offset; + } + } + } + is_reshaped_filter_initialized = true; + } + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + int in_y_start = in_y_origin; + int filter_y_start = 0; + if (in_y_origin < 0) { + in_y_start = 0; + filter_y_start = 0 - in_y_origin; + } + int filter_y_end = filter_height; + if ((in_y_origin + filter_height) >= input_height) { + filter_y_end -= (in_y_origin + filter_height) - input_height; + } + int in_y = in_y_start; + int in_x_start = in_x_origin; + int filter_x_start = 0; + bool is_out_of_x_bounds = false; + if (in_x_origin < 0) { + in_x_start = 0; + filter_x_start = 0 - 
in_x_origin; + is_out_of_x_bounds = true; + } + int filter_x_end = filter_width; + if ((in_x_origin + filter_width) >= input_width) { + filter_x_end -= (in_x_origin + filter_width) - input_width; + is_out_of_x_bounds = true; + } + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y, ++in_y) { + const uint8* current_input = + input_data + Offset(input_shape, b, in_y, in_x_start, ic); + if ((filter_width == 8) && !is_out_of_x_bounds) { + int8* current_filter = + reshaped_filter_data + Offset(reshaped_filter_shape, 0, oc, + filter_y, filter_x_start); + const uint32_t input_vals0 = + *reinterpret_cast(current_input); + current_input += 4; + const int32_t filter_vals0 = + *reinterpret_cast(current_filter); + current_filter += 4; + const uint8 input_val0 = input_vals0 & 0xff; + const int8 filter_val0 = filter_vals0 & 0xff; + acc += filter_val0 * input_val0; + const uint8 input_val1 = (input_vals0 >> 8) & 0xff; + const int8 filter_val1 = (filter_vals0 >> 8) & 0xff; + acc += filter_val1 * input_val1; + const uint8 input_val2 = (input_vals0 >> 16) & 0xff; + const int8 filter_val2 = (filter_vals0 >> 16) & 0xff; + acc += filter_val2 * input_val2; + const uint8 input_val3 = (input_vals0 >> 24) & 0xff; + const int8 filter_val3 = (filter_vals0 >> 24) & 0xff; + acc += filter_val3 * input_val3; + + const uint32_t input_vals1 = + *reinterpret_cast(current_input); + const int32_t filter_vals1 = + *reinterpret_cast(current_filter); + const uint8 input_val4 = input_vals1 & 0xff; + const int8 filter_val4 = filter_vals1 & 0xff; + acc += filter_val4 * input_val4; + const uint8 input_val5 = (input_vals1 >> 8) & 0xff; + const int8 filter_val5 = (filter_vals1 >> 8) & 0xff; + acc += filter_val5 * input_val5; + const uint8 input_val6 = (input_vals1 >> 16) & 0xff; + const int8 filter_val6 = (filter_vals1 >> 16) & 0xff; + acc += filter_val6 * input_val6; + const uint8 input_val7 = (input_vals1 >> 24) & 0xff; + const int8 filter_val7 = (filter_vals1 >> 24) & 0xff; + acc += 
filter_val7 * input_val7; + } else { + const uint8* current_filter = + filter_data + + Offset(filter_shape, 0, filter_y, filter_x_start, oc); + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + int32 input_val = *current_input; + current_input += input_depth; + int32 filter_val = *current_filter; + current_filter += output_depth; + acc += + (filter_val + filter_offset) * (input_val + input_offset); + } + } + } + if (bias_data) { + acc += bias_data[oc]; + } + acc = reference_ops::depthwise_conv::DepthwiseConvRound< + DepthwiseConvOutputRounding::kAwayFromZero>( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_shape, b, out_y, out_x, oc)] = + static_cast(acc); + } + } + } + } + } +} // namespace + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. 
+ op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = 1; + op_params.dilation_height_factor = 1; + op_params.depth_multiplier = params->depth_multiplier; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); +} + +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. 
+ op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = 1; + op_params.dilation_height_factor = 1; + op_params.depth_multiplier = params->depth_multiplier; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + + // Figure out if we can use the optimized path for this set of parameters. + const int filter_width = GetTensorShape(filter).Dims(2); + const int input_depth = GetTensorShape(input).Dims(3); + const int output_depth = GetTensorShape(filter).Dims(3); + const int filter_height = GetTensorShape(filter).Dims(1); + const int needed_size = + output_depth * filter_width * filter_height * input_depth; + bool use_optimized_path = false; + if ((filter_width == 8) && (input_offset == 0) && (filter_offset == -127) && + (input_depth == 1) && (needed_size <= kReshapedFilterDataSize)) { + // FIXME(petewarden) - We need a more robust way of handling this, ideally + // with an allocation mechanism available through the context API. + // Use the address of the node as a proxy for its identity, since we need + // to ensure the weight values are consistent between calls, and there's + // no easy way to do that quickly other than relying on the identity of + // the owning node. 
+ static TfLiteNode* initialized_node_address = node; + if (initialized_node_address == node) { + use_optimized_path = true; + } else { + static bool has_warned = false; + if (!has_warned) { + context->ReportError( + context, + "Multiple depthwise conv ops match optimization parameters, but " + "only the first will use the fast path, because there's only one " + "RAM cache available"); + has_warned = true; + } + } + } + if (use_optimized_path) { + DepthwiseConvOptimizedForFilterWidthEight( + context, op_params, GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); + } else { + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + (NumInputs(node) == 3) ? 
GetInput(context, node, kBiasTensor) : nullptr; + + const TfLiteType data_type = input->type; + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + int out_width = ComputeOutSize(params->padding, width, filter_width, + params->stride_width); + int out_height = ComputeOutSize(params->padding, height, filter_height, + params->stride_height); + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height, + filter_width, filter_height, out_width, + out_height, data_type, data)); + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, bias, output); + break; + default: + context->ReportError(context, "Type %d not currently supported.", + input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { + static TfLiteRegistration r = {depthwise_conv::Init, depthwise_conv::Free, + depthwise_conv::Prepare, depthwise_conv::Eval}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/experimental/micro/testing/micro_test.h b/tensorflow/lite/experimental/micro/testing/micro_test.h index 2f20dd5ac77dfd3f304c7cc93be0b865a0c2f0cb..32e9a57f76ecc055c67c0ede8d1c83550c602aab 100644 --- a/tensorflow/lite/experimental/micro/testing/micro_test.h +++ b/tensorflow/lite/experimental/micro/testing/micro_test.h @@ -107,13 +107,13 @@ extern tflite::ErrorReporter* reporter; } \ } while (false) -#define TF_LITE_MICRO_EXPECT_EQ(x, 
y) \ - do { \ - if ((x) != (y)) { \ - micro_test::reporter->Report(#x " == " #y " failed at %s:%d", __FILE__, \ - __LINE__); \ - micro_test::did_test_fail = true; \ - } \ +#define TF_LITE_MICRO_EXPECT_EQ(x, y) \ + do { \ + if ((x) != (y)) { \ + micro_test::reporter->Report(#x " == " #y " failed at %s:%d (%d vs %d)", \ + __FILE__, __LINE__, (x), (y)); \ + micro_test::did_test_fail = true; \ + } \ } while (false) #define TF_LITE_MICRO_EXPECT_NE(x, y) \ diff --git a/tensorflow/lite/experimental/micro/testing/test_ecm3531_binary.sh b/tensorflow/lite/experimental/micro/testing/test_ecm3531_binary.sh new file mode 100755 index 0000000000000000000000000000000000000000..1647cf82a276d7c1725c7c3334693e0e1b7e057c --- /dev/null +++ b/tensorflow/lite/experimental/micro/testing/test_ecm3531_binary.sh @@ -0,0 +1,16 @@ +#!/bin/bash -e +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + diff --git a/tensorflow/lite/experimental/micro/tools/ci_build/ci_build_micro_projects.sh b/tensorflow/lite/experimental/micro/tools/ci_build/ci_build_micro_projects.sh new file mode 100755 index 0000000000000000000000000000000000000000..dcec7269bdc95ab57204f3b4cbc17f9d3cacadc0 --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/ci_build/ci_build_micro_projects.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Creates the project file distributions for the TensorFlow Lite Micro test and +# example targets aimed at embedded platforms. +# +# Usage: ci_build_micro_projects.sh +# +# For example: +# ci_build_micro_projects.sh mbed "CMSIS disco_f746ng" + +set -e +set -x + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR=${SCRIPT_DIR}/../../../../../.. +cd ${ROOT_DIR} +pwd + +tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh + +make -f tensorflow/lite/experimental/micro/tools/make/Makefile \ + TARGET=${1} \ + TAGS="${2}" \ + generate_projects diff --git a/tensorflow/lite/experimental/micro/tools/make/Makefile b/tensorflow/lite/experimental/micro/tools/make/Makefile index fde195118b18ca308940292c7bd5706ecace8563..e11e8a8cf092e0244e2d5b04dd48ae51560f4ca8 100644 --- a/tensorflow/lite/experimental/micro/tools/make/Makefile +++ b/tensorflow/lite/experimental/micro/tools/make/Makefile @@ -30,18 +30,26 @@ TARGET_ARCH := $(HOST_ARCH) # STM32F746NG board, using the CMSIS library's implementations where possible. ALL_TAGS := $(TAGS) $(TARGET) +# This is obviously horrible. We need to generate these 3 versions of the +# include directories from one source. INCLUDES := \ -I. 
\ --I$(MAKEFILE_DIR)/../../../../../ \ --I$(MAKEFILE_DIR)/../../../../../../ \ --I$(MAKEFILE_DIR)/../../../../../../../ \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ --I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ --I$(OBJDIR) -# This is at the end so any globally-installed frameworks like protobuf don't -# override local versions in the source tree. -INCLUDES += -I/usr/local/include +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include + +# Same list of paths, but now relative to the generated project files. +GENERATED_PROJECT_INCLUDES := \ +-I. \ +-I./third_party/gemmlowp \ +-I./third_party/flatbuffers/include + +# Same list of paths, but now in the format the generate_keil_project.py +# script expects them. +PROJECT_INCLUDES := \ +. \ +third_party/gemmlowp \ +third_party/flatbuffers/include TEST_SCRIPT := tensorflow/lite/experimental/micro/testing/test_linux_binary.sh @@ -79,7 +87,6 @@ tensorflow/lite/core/api/op_resolver.cc \ tensorflow/lite/kernels/kernel_util.cc \ tensorflow/lite/kernels/internal/quantization_util.cc MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_TEST_SRCS), $(MICROLITE_CC_BASE_SRCS)) -MICROLITE_CC_SRCS := $(call specialize,$(MICROLITE_CC_SRCS)) MICROLITE_CC_HDRS := \ $(wildcard tensorflow/lite/experimental/micro/*.h) \ @@ -118,12 +125,24 @@ third_party/flatbuffers/LICENSE.txt MAKE_PROJECT_FILES := \ README_MAKE.md \ - Makefile + Makefile \ + .vscode/tasks.json MBED_PROJECT_FILES := \ README_MBED.md \ mbed-os.lib \ - mbed_app.json + mbed_app.json \ + .vscode/tasks.json + +KEIL_PROJECT_FILES := \ + README_KEIL.md \ + keil_project.uvprojx + +ALL_PROJECT_TARGETS := + +KEIL_PROJECT_FILES := \ + README_KEIL.md \ + keil_project.uvprojx # These target-specific makefiles should modify or replace options like # CXXFLAGS or LIBS to work for a specific targetted architecture. All logic @@ -131,6 +150,9 @@ MBED_PROJECT_FILES := \ # keep this main makefile focused on the sources and dependencies. 
include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) +# Call specialize here so that platform-specific tags can be taken into account. +MICROLITE_CC_SRCS := $(call specialize,$(MICROLITE_CC_SRCS)) + ALL_TAGS += $(TARGET_ARCH) ALL_SRCS := \ @@ -156,8 +178,8 @@ include $(wildcard tensorflow/lite/experimental/micro/examples/*/Makefile.inc) MICROLITE_LIB_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICROLITE_CC_SRCS)))) -MICROLITE_TEST_TARGETS := $(addprefix $(BINDIR), \ -$(patsubst %_test.cc,%.test_target,$(MICROLITE_TEST_SRCS))) +MICROLITE_LIB_OBJS += $(addprefix $(OBJDIR), \ +$(patsubst %.S,%.o,$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(THIRD_PARTY_CC_SRCS))))) # For normal manually-created TensorFlow C++ source files. $(OBJDIR)%.o: %.cc @@ -197,11 +219,21 @@ $(BINDIR)%_test : $(OBJDIR)%_test.o $(MICROLITE_LIB_PATH) $(BINDIR)%.test_target: $(BINDIR)%_test $(TEST_SCRIPT) $< '~~~ALL TESTS PASSED~~~' +# snease: Add %.bin rule here since BINDIR is now defined +# These are microcontroller-specific rules for converting the ELF output +# of the linker into a binary image that can be loaded directly. +OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy +$(BINDIR)%.bin: $(BINDIR)% + @mkdir -p $(dir $@) + $(OBJCOPY) $< $@ -O binary + # Generate standalone makefile projects for all of the test targets. $(foreach TEST_TARGET,$(MICROLITE_TEST_SRCS),\ $(eval $(call microlite_test,$(notdir $(basename $(TEST_TARGET))),$(TEST_TARGET)))) -test: test_micro_speech $(MICROLITE_TEST_TARGETS) +test: $(MICROLITE_TEST_TARGETS) + +generate_projects: $(ALL_PROJECT_TARGETS) # Gets rid of all generated files. 
clean: diff --git a/tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh b/tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh index 82c15e32f6572f36588945431918cf75299d3a64..7a434e5ca81596b651992219d433cd856c94e7b1 100755 --- a/tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh +++ b/tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh @@ -34,10 +34,19 @@ FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff68 CMSIS_URL="https://github.com/ARM-software/CMSIS_5/archive/5.4.0.zip" STM32_BARE_LIB_URL="https://github.com/google/stm32_bare_lib/archive/c07d611fb0af58450c5a3e0ab4d52b47f99bc82d.zip" SIFIVE_FE310_LIB_URL="https://github.com/sifive/freedom-e-sdk/archive/baeeb8fd497a99b3c141d7494309ec2e64f19bdf.zip" -RISCV_TOOLCHAIN_URL="https://static.dev.sifive.com/dev-tools/riscv64-unknown-elf-gcc-20181030-x86_64-linux-ubuntu14.tar.gz" +AM_SDK_URL="http://s3.asia.ambiqmicro.com/downloads/AmbiqSuite-Rel2.0.0.zip" AP3_URL="https://github.com/AmbiqMicro/TFLiteMicro_Apollo3/archive/dfbcef9a57276c087d95aab7cb234f1d4c9eaaba.zip" CUST_CMSIS_URL="https://github.com/AmbiqMicro/TFLiteMicro_CustCMSIS/archive/8f63966c5692e6a3a83956efd2e4aed77c4c9949.zip" -GCC_EMBEDDED_URL="https://developer.arm.com/-/media/Files/downloads/gnu-rm/7-2018q2/gcc-arm-none-eabi-7-2018-q2-update-linux.tar.bz2" +KISSFFT_URL="https://github.com/mborgerding/kissfft/archive/v130.zip" +SPARKFUN_EDGE_BSP_URL="https://github.com/sparkfun/SparkFun_Edge_BSP/archive/620f5f7a69fc69e38cda8132b69302d9c28ba0dd.zip" + +if [[ "$OSTYPE" == "darwin"* ]]; then + RISCV_TOOLCHAIN_URL="https://static.dev.sifive.com/dev-tools/riscv64-unknown-elf-gcc-8.1.0-2019.01.0-x86_64-apple-darwin.tar.gz" + GCC_EMBEDDED_URL="https://developer.arm.com/-/media/Files/downloads/gnu-rm/7-2018q2/gcc-arm-none-eabi-7-2018-q2-update-mac.tar.bz2" +else + 
RISCV_TOOLCHAIN_URL="https://static.dev.sifive.com/dev-tools/riscv64-unknown-elf-gcc-20181030-x86_64-linux-ubuntu14.tar.gz" + GCC_EMBEDDED_URL="https://developer.arm.com/-/media/Files/downloads/gnu-rm/7-2018q2/gcc-arm-none-eabi-7-2018-q2-update-linux.tar.bz2" +fi download_and_extract() { local usage="Usage: download_and_extract URL DIR" @@ -72,37 +81,58 @@ download_and_extract() { find "${dir}" -type f -name '*BUILD' -delete } -patch_apollo3_sdk() { - local ap3_dir="${1}" - if [ ! -f ${ap3_dir}/VERSION.txt ]; then - echo "Could not find ${ap3_dir}, skipping Apollo3 SDK"; +patch_am_sdk() { + local am_dir="${1}" + if [ ! -f ${am_dir}/VERSION.txt ]; then + echo "Could not find ${am_dir}, skipping AmbiqMicro SDK patch"; return; fi - local src_dir=${ap3_dir}/boards/apollo3_evb/examples/hello_world/gcc - local dest_dir=${ap3_dir}/boards/apollo3_evb/examples/hello_world/gcc_patched + + local src_dir=${am_dir}/boards/apollo3_evb/examples/hello_world/gcc + local dest_dir=${am_dir}/boards/apollo3_evb/examples/hello_world/gcc_patched + rm -rf ${dest_dir} mkdir ${dest_dir} + cp "${src_dir}/startup_gcc.c" "${dest_dir}/startup_gcc.c" cp "${src_dir}/hello_world.ld" "${dest_dir}/apollo3evb.ld" - sed -i -e '131s/1024/1024\*20/g' "${dest_dir}/startup_gcc.c" - sed -i -e 's/main/_main/g' "${dest_dir}/startup_gcc.c" + + sed -i -e '114s/1024/1024\*20/g' "${dest_dir}/startup_gcc.c" + #sed -i -e 's/main/_main/g' "${dest_dir}/startup_gcc.c" + sed -i -e '3s/hello_world.ld/apollo3evb.ld/g' "${dest_dir}/apollo3evb.ld" sed -i -e '3s/startup_gnu/startup_gcc/g' "${dest_dir}/apollo3evb.ld" - sed -i -e '6s/am_reset_isr/Reset_Handler/g' "${dest_dir}/apollo3evb.ld" - sed -i -e '22s/\*(.text\*)/\*(.text\*)\n\n\t\/\* These are the C++ global constructors. Stick them all here and\n\t \* then walk through the array in main() calling them all.\n\t \*\/\n\t_init_array_start = .;\n\tKEEP (\*(SORT(.init_array\*)))\n\t_init_array_end = .;\n\n\t\/\* XXX Currently not doing anything for global destructors. 
\*\/\n/g' "${dest_dir}/apollo3evb.ld" - sed -i -e "70s/} > SRAM/} > SRAM\n \/\* Add this to satisfy reference to symbol 'end' from libnosys.a(sbrk.o)\n \* to denote the HEAP start.\n \*\/\n end = .;/g" "${dest_dir}/apollo3evb.ld" + sed -i -e $'22s/\*(.text\*)/\*(.text\*)\\\n\\\n\\\t\/\* These are the C++ global constructors. Stick them all here and\\\n\\\t \* then walk through the array in main() calling them all.\\\n\\\t \*\/\\\n\\\t_init_array_start = .;\\\n\\\tKEEP (\*(SORT(.init_array\*)))\\\n\\\t_init_array_end = .;\\\n\\\n\\\t\/\* XXX Currently not doing anything for global destructors. \*\/\\\n/g' "${dest_dir}/apollo3evb.ld" + sed -i -e $'70s/} > SRAM/} > SRAM\\\n \/\* Add this to satisfy reference to symbol "end" from libnosys.a(sbrk.o)\\\n \* to denote the HEAP start.\\\n \*\/\\\n end = .;/g' "${dest_dir}/apollo3evb.ld" + + # Workaround for bug in 2.0.0 SDK, remove once that's fixed. + sed -i -e $'s/#ifndef AM_HAL_GPIO_H/#ifdef __cplusplus\\\nextern "C" {\\\n#endif\\\n#ifndef AM_HAL_GPIO_H/g' ${am_dir}/mcu/apollo3/hal/am_hal_gpio.h + echo "Finished preparing Apollo3 files" } +patch_kissfft() { + sed -i -E $'s@#ifdef FIXED_POINT@// Patched automatically by download_dependencies.sh so default is 16 bit.\\\n#ifndef FIXED_POINT\\\n#define FIXED_POINT (16)\\\n#endif\\\n// End patch.\\\n\\\n#ifdef FIXED_POINT@g' tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/kiss_fft.h + sed -i -E "s@#define KISS_FFT_MALLOC malloc@#define KISS_FFT_MALLOC(X) (void*)(0) /* Patched. */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/kiss_fft.h + sed -i -E "s@#define KISS_FFT_FREE free@#define KISS_FFT_FREE(X) /* Patched. 
*/@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/kiss_fft.h + sed -ir -E "s@(fprintf.*\);)@/* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c + sed -ir -E "s@(exit.*\);)@return; /* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c + echo "Finished patching kissfft" +} + download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" download_and_extract "${CMSIS_URL}" "${DOWNLOADS_DIR}/cmsis" download_and_extract "${STM32_BARE_LIB_URL}" "${DOWNLOADS_DIR}/stm32_bare_lib" download_and_extract "${SIFIVE_FE310_LIB_URL}" "${DOWNLOADS_DIR}/sifive_fe310_lib" download_and_extract "${RISCV_TOOLCHAIN_URL}" "${DOWNLOADS_DIR}/riscv_toolchain" +download_and_extract "${AM_SDK_URL}" "${DOWNLOADS_DIR}/AmbiqSuite-Rel2.0.0" +patch_am_sdk "${DOWNLOADS_DIR}/AmbiqSuite-Rel2.0.0" download_and_extract "${AP3_URL}" "${DOWNLOADS_DIR}/apollo3_ext" -patch_apollo3_sdk "${DOWNLOADS_DIR}/Apollo3-SDK-2018.08.13" download_and_extract "${CUST_CMSIS_URL}" "${DOWNLOADS_DIR}/CMSIS_ext" download_and_extract "${GCC_EMBEDDED_URL}" "${DOWNLOADS_DIR}/gcc_embedded" +download_and_extract "${KISSFFT_URL}" "${DOWNLOADS_DIR}/kissfft" +patch_kissfft "${DOWNLOADS_DIR}/kissfft" +download_and_extract "${SPARKFUN_EDGE_BSP_URL}" "${DOWNLOADS_DIR}/AmbiqSuite-Rel2.0.0/boards/SparkFun_TensorFlow_Apollo3_BSP" echo "download_dependencies.sh completed successfully." >&2 diff --git a/tensorflow/lite/experimental/micro/tools/make/generate_keil_project.py b/tensorflow/lite/experimental/micro/tools/make/generate_keil_project.py new file mode 100644 index 0000000000000000000000000000000000000000..fb0c632e24b139a5a3e27fadbfb850a53fff531d --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/generate_keil_project.py @@ -0,0 +1,117 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generates a Keil uVision project file from a template.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os.path +import re + + +def sanitize_xml(unsanitized): + """Uses a whitelist to avoid generating bad XML.""" + return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', unsanitized) + + +def main(unused_args, flags): + """Generates a Keil project file from a template source.""" + with open(flags.input_template, 'r') as input_template_file: + template_file_text = input_template_file.read() + + template_file_text = re.sub(r'%{EXECUTABLE}%', flags.executable, + template_file_text) + + srcs_list = flags.srcs.split(' ') + hdrs_list = flags.hdrs.split(' ') + all_srcs_list = srcs_list + hdrs_list + all_srcs_list.sort() + + replace_srcs = '' + for src in all_srcs_list: + if not src: + continue + ext = os.path.splitext(src)[1] + # These extension indexes are used by uVision to keep track of the type + # of files. I determined them by experimentation, since the file format + # isn't documented. 
+ if ext == '.h': + ext_index = '5' + elif ext == '.c': + ext_index = '1' + elif ext == '.cc' or ext == '.cpp': + ext_index = '8' + else: + ext_index = '5' + basename = sanitize_xml(os.path.basename(src)) + clean_src = sanitize_xml(src) + replace_srcs += ' \n' + replace_srcs += ' ' + basename + '\n' + replace_srcs += ' ' + ext_index + '\n' + replace_srcs += ' ' + clean_src + '\n' + replace_srcs += ' \n' + template_file_text = re.sub(r'%{SRCS}%', replace_srcs, template_file_text) + + include_paths = re.sub(' ', ';', flags.include_paths) + template_file_text = re.sub(r'%{INCLUDE_PATHS}%', include_paths, + template_file_text) + + with open(flags.output_file, 'w') as output_file: + output_file.write(template_file_text) + + +def parse_args(): + """Converts the raw arguments into accessible flags.""" + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--input_template', + type=str, + default='', + help='Path to template project file to build from.') + parser.add_argument( + '--output_file', + type=str, + default='', + help='Path to write the completed project file to.') + parser.add_argument( + '--executable', + type=str, + default='', + help='Name of the executable the project will build.') + parser.add_argument( + '--hdrs', + type=str, + default='', + help='Space-separated list of C or C++ source files to compile.') + parser.add_argument( + '--srcs', + type=str, + default='', + help='Space-separated list of C or C++ header files to include.') + parser.add_argument( + '--include_paths', + type=str, + default='', + help='Space-separated list of paths to look for header files on.') + flags, unparsed = parser.parse_known_args() + + main(unparsed, flags) + + +if __name__ == '__main__': + parse_args() diff --git a/tensorflow/lite/experimental/micro/tools/make/generate_keil_project_test.sh b/tensorflow/lite/experimental/micro/tools/make/generate_keil_project_test.sh new file mode 100755 index 
0000000000000000000000000000000000000000..22b68e4f68360cc28d6dd7e751381709a72892e7 --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/generate_keil_project_test.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Bash unit tests for the TensorFlow Lite Micro project generator. + +set -e + +INPUT_TEMPLATE=${TEST_SRCDIR}/tensorflow/lite/experimental/micro/tools/make/templates/keil_project.uvprojx.tpl +OUTPUT_FILE=${TEST_TMPDIR}/keil_project.uvprojx +EXECUTABLE=test_executable + +${TEST_SRCDIR}/tensorflow/lite/experimental/micro/tools/make/generate_keil_project \ + --input_template=${INPUT_TEMPLATE} \ + --output_file=${OUTPUT_FILE} \ + --executable=${EXECUTABLE} \ + --hdrs="foo.h bar.h" \ + --srcs="foo.c bar.cc some/bad_makefile target that you # can invoke to create the standalone project. define generate_project @@ -68,9 +69,22 @@ $(PRJDIR)$(3)/$(1)/%: tensorflow/lite/experimental/micro/tools/make/templates/%. 
sed -E 's#\%\{SRCS\}\%#$(4)#g' $$< | \ sed -E 's#\%\{EXECUTABLE\}\%#$(3)#g' | \ sed -E 's#\%\{LINKER_FLAGS\}\%#$(6)#g' | \ - sed -E 's#\%\{CXX_FLAGS\}\%#$(7)#g' > $$@ + sed -E 's#\%\{CXX_FLAGS\}\%#$(7)#g' | \ + sed -E 's#\%\{CC_FLAGS\}\%#$(8)#g' > $$@ + +$(PRJDIR)$(3)/$(1)/keil_project.uvprojx: tensorflow/lite/experimental/micro/tools/make/templates/keil_project.uvprojx.tpl + @mkdir -p $$(dir $$@) + python tensorflow/lite/experimental/micro/tools/make/generate_keil_project.py \ + --input_template=$$< --output_file=$$@ --executable=$(3) \ + --srcs="$(4)" --hdrs="$(5)" --include_paths="$$(PROJECT_INCLUDES)" + +$(PRJDIR)$(3)/$(1)/.vscode/tasks.json : tensorflow/lite/experimental/micro/tools/make/templates/tasks.json.$(1).tpl + @mkdir -p $$(dir $$@) + cp $$< $$@ generate_$(3)_$(1)_project: $(addprefix $(PRJDIR)$(3)/$(1)/, $(4) $(5) $(2)) + +ALL_PROJECT_TARGETS += generate_$(3)_$(1)_project endef # Specialized version of generate_project for TF Lite Micro test targets that @@ -83,8 +97,9 @@ endef # Calling eval on the output will create targets that you can invoke to # generate the standalone project. 
define generate_microlite_projects -$(call generate_project,make,$(MAKE_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS)) -$(call generate_project,mbed,$(MBED_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS)) +$(call generate_project,make,$(MAKE_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(LDFLAGS) $(MICROLITE_LIBS),$(CXXFLAGS) $(GENERATED_PROJECT_INCLUDES), $(CCFLAGS) $(GENERATED_PROJECT_INCLUDES)) +$(call generate_project,mbed,$(MBED_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS)) +$(call generate_project,keil,$(KEIL_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS)) endef @@ -103,7 +118,6 @@ $(1)_LOCAL_HDRS := $(3) $(1)_LOCAL_OBJS := $$(addprefix $$(OBJDIR), \ $$(patsubst %.cc,%.o,$$(patsubst %.c,%.o,$$($(1)_LOCAL_SRCS)))) $(1)_BINARY := $$(BINDIR)$(1) -ALL_BINARIES += $$($(1)_BINARY) $$($(1)_BINARY): $$($(1)_LOCAL_OBJS) $$(MICROLITE_LIB_PATH) @mkdir -p $$(dir $$@) $$(CXX) $$(CXXFLAGS) $$(INCLUDES) \ @@ -113,5 +127,8 @@ $(1): $$($(1)_BINARY) $(1)_bin: $$($(1)_BINARY).bin test_$(1): $$($(1)_BINARY) $$(TEST_SCRIPT) $$($(1)_BINARY) '~~~ALL TESTS PASSED~~~' +ifneq (,$(findstring _test,$(1))) + MICROLITE_TEST_TARGETS += test_$(1) +endif $(eval $(call generate_microlite_projects,$(1),$(call specialize,$(2)),$(3))) endef diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc 
b/tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc index 6ed402a623188a7c39a007a1cfd7dbc67b775103..4df26a7bf704eeee0acb8fa3386ab42589204595 100644 --- a/tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc +++ b/tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc @@ -1,15 +1,18 @@ -# Settings for apollo3 evb platforms. -ifeq ($(TARGET), apollo3evb) +# Settings for apollo3 evb and SparkFun Edge platforms. +ifeq ($(TARGET),$(filter $(TARGET),apollo3evb sparkfun_edge)) export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) TARGET_ARCH := cortex-m4 TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- # Download the Ambiq Apollo3 SDK and set this variable to find the header # files: - APOLLO3_SDK := $(MAKEFILE_DIR)/downloads/Apollo3-SDK-2018.08.13 + APOLLO3_SDK := $(MAKEFILE_DIR)/downloads/AmbiqSuite-Rel2.0.0 # Need a pointer to the GNU ARM toolchain for crtbegin.o for the fp functions - # with the softfp interfaces. + # with the hard interfaces. GCC_ARM := $(MAKEFILE_DIR)/downloads/gcc_embedded/ + # Use the faster depthwise conv implementation. 
+ ALL_TAGS += portable_optimized + PLATFORM_FLAGS = \ -DPART_apollo3 \ -DAM_PACKAGE_BGA \ @@ -31,7 +34,7 @@ ifeq ($(TARGET), apollo3evb) -mcpu=cortex-m4 \ -mthumb \ -mfpu=fpv4-sp-d16 \ - -mfloat-abi=softfp \ + -mfloat-abi=hard \ -std=gnu++11 \ -Wvla \ -Wall \ @@ -49,7 +52,7 @@ ifeq ($(TARGET), apollo3evb) CXXFLAGS += $(PLATFORM_FLAGS) CCFLAGS += $(PLATFORM_FLAGS) LDFLAGS += \ - -mthumb -mcpu=cortex-m4 -mfpu=fpv4-sp-d16 -mfloat-abi=softfp \ + -mthumb -mcpu=cortex-m4 -mfpu=fpv4-sp-d16 -mfloat-abi=hard \ -nostartfiles -static \ -Wl,--gc-sections -Wl,--entry,Reset_Handler \ -Wl,--start-group -lm -lc -lgcc -Wl,--end-group \ @@ -58,10 +61,16 @@ ifeq ($(TARGET), apollo3evb) -Wl,-T,$(APOLLO3_SDK)/boards/apollo3_evb/examples/hello_world/gcc_patched/apollo3evb.ld \ -Wl,-Map=$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref BUILD_TYPE := micro + ifeq ($(TARGET), apollo3evb) + BOARD_BSP_PATH := $(APOLLO3_SDK)/boards/apollo3_evb/bsp + endif + ifeq ($(TARGET), sparkfun_edge) + BOARD_BSP_PATH := $(APOLLO3_SDK)/boards/SparkFun_TensorFlow_Apollo3_BSP/bsp + endif MICROLITE_LIBS := \ - $(APOLLO3_SDK)/boards/apollo3_evb/bsp/gcc/bin/libam_bsp.a \ + $(BOARD_BSP_PATH)/gcc/bin/libam_bsp.a \ $(APOLLO3_SDK)/mcu/apollo3/hal/gcc/bin/libam_hal.a \ - $(GCC_ARM)/lib/gcc/arm-none-eabi/7.3.1/thumb/v7e-m/fpv4-sp/softfp/crtbegin.o \ + $(GCC_ARM)/lib/gcc/arm-none-eabi/7.3.1/thumb/v7e-m/fpv4-sp/hard/crtbegin.o \ -lm INCLUDES += \ -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ @@ -70,7 +79,7 @@ ifeq ($(TARGET), apollo3evb) -I$(GCC_ARM)/arm-none-eabi/ \ -I$(APOLLO3_SDK)/mcu/apollo3/ \ -I$(APOLLO3_SDK)/CMSIS/AmbiqMicro/Include/ \ - -I$(APOLLO3_SDK)/boards/apollo3_evb/bsp/ \ + -I$(BOARD_BSP_PATH) \ -I$(APOLLO3_SDK)/devices/ \ -I$(APOLLO3_SDK)/utils/ diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/README.md b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..3e339fe635d4af2e9e884d0c3bdb56d9d210e9ad --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/README.md @@ -0,0 +1,9 @@ +Compiling instructions here +https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/micro + +CONTACT INFORMATION: + +Contact info@etacompute.com for more information on obtaining the Eta Compute +SDK and evaluation board. + +www.etacompute.com diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/README.txt b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/README.txt deleted file mode 100644 index 3cb74a72437be8017527c0ea05a1b82eb1a4ac9e..0000000000000000000000000000000000000000 --- a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/README.txt +++ /dev/null @@ -1,32 +0,0 @@ -Running The Micro Speech Example On Eta Compute's ECM3531EVB - -This code will enable you to compile and execute the Tensorflow Lite Micro Speech Example on Eta Computes's low power ECM3531 chip. - - -GETTING STARTED: - -1. Download the Tensorflow code from Github and follow instructions there to download other dependencies. - -2. Download the Eta Compute SDK, version 0.0.17. - -3. Install the Arm compiler arm-none-eabi-gcc, version = arm-none-eabi-gcc (GNU Tools for Arm Embedded Processors 7-2018-q2-update) 7.3.1 20180622 (release) [ARM/embedded-7-branch revision 261907] - -4. Edit the file tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc so that the variable ETA_SDK points to the location where the Eta Compute SDK is installed, and the variable GCC_ARM points to the Arm compiler. - -5. Compile the code with the command "make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=ecm3531 test". This will create the executable tensorflow/lite/experimental/micro/tools/make/gen/ecm3531_cortex-m3/bin/micro_speech_test. - -6. 
Connect the board to the host computer, start PuTTY (Connection type = Serial, Speed = 11520, Data bits = 8, Stop bits = 1, Parity = None), and load the executable with ocd. A sample script for loading the image is provided in tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/load_program. - -The following will be printed on the Uart: - -Testing TestInvoke -Ran successfully - -/ tests passed -~~~ALL TESTS PASSED~~~ - - - -CONTACT INFORMATION: - -Contact info@etacompute.com for more information on obtaining the Eta Compute SDK and evalution board. diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/_main.c b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/_main.c index 2764f3ba50de699fa72717585114369cf833d76e..25d3e7c169d5f7419a892d35bd30aa2d9a128160 100644 --- a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/_main.c +++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/_main.c @@ -51,12 +51,6 @@ void EtaPrintExecutionTime(uint64_t); //***************************************************************************** extern int main(int argc, char** argv); -void DebugLog(const char* s) { EtaCspIoPrintf("%s", s); } -void DebugLogInt32(int32_t i) { EtaCspIoPrintf("%d", i); } -void DebugLogUInt32(uint32_t i) { EtaCspIoPrintf("%d", i); } -void DebugLogHex(uint32_t i) { EtaCspIoPrintf("0x%8x", i); } -void DebugLogFloat(float i) { EtaCspIoPrintf("%f", i); } - int _main(void) { uint64_t time_ms; diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/ecm3531.lds b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/ecm3531.lds index af34f988f2d04a0c1c87f20d6058df560db7e2c5..383b7f924408b484c8ee2ada5c4d6ec66edb059a 100644 --- a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/ecm3531.lds +++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/ecm3531.lds @@ -55,7 +55,7 @@ SECTIONS { _datax = .; KEEP(*(.mainStack)) - . += 12288; + . 
+= 16384; _edatax = .; _stack_top = .; . += 4; diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/ecm3531_flash.lds b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/ecm3531_flash.lds new file mode 100644 index 0000000000000000000000000000000000000000..9cbbea3569ba05b8fc9269ff6c5500fb386c03a3 --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/ecm3531_flash.lds @@ -0,0 +1,85 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +/* + * linker script for use with ECM3531 chip. + * .text and .ro map to FLASH all else to SRAM. + * + */ + + /* + * Indicate to the linker the entry point. + */ +ENTRY(ResetISR) + +/* + * FLASH is at 0x01000000 of length 0x00080000 512KB + * SRAM is at 0x10000000 of length 0x00020000 128KB + */ +MEMORY +{ + FLASH (RX) : ORIGIN = 0x01000000, LENGTH = 0x00080000 + SRAM (RWX) : ORIGIN = 0x10000000, LENGTH = 0x00020000 +} + +SECTIONS +{ + .text : + { + _text = .; + KEEP(*(.vectors)) + . = ALIGN(0x4); + *(.text*) + . = ALIGN(0x4); + *(.rodata*) + . = ALIGN(0x4); + _etext = .; + } > FLASH= 0 + .dummy : + { + . = ALIGN(0x4); + _eftext = .; + } > FLASH +/* put the stack at the bottom of SRAM*/ + .datax (NOLOAD) : + { + _datax = .; + KEEP(*(.mainStack)) + . = ALIGN(0x4); + . 
+= 16384; + _edatax = .; + _stack_top = .; + } > SRAM + .data : + { + _data = .; + *(.data*) + KEEP(*(.mainHeap)) + _edata = .; + } > SRAM AT > FLASH + + .bss (NOLOAD) : + { + _bss = .; + *(.bss*) + *(COMMON) + _ebss = .; + } > SRAM + + + +} + diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/flash_erase b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/flash_erase new file mode 100755 index 0000000000000000000000000000000000000000..5395b3d9965e98572fb12d61d7b862f4ce926a0f --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/flash_erase @@ -0,0 +1,47 @@ +#!/usr/bin/python3 +#Usage: cd to the directory tensorflow/lite/experimental/micro/tools/make/targets/ecm3531 and type ./flash_erase to erase the flash. +# +# +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+#============================================================================== + + +import os +import telnetlib + +def send_ocd_cmd(line): + ocd_sock.write(bytes(line,encoding = 'utf-8')) + print(ocd_sock.read_until(b'> ').decode('utf-8'), end='') + +def get_ocd_response(): + print(ocd_sock.read_until(b'> ').decode('utf-8'), end='') + +#get hooked up to openocd daemon +ocd_sock = telnetlib.Telnet(host='localhost', port=4444) +get_ocd_response() # clean it out + +#ocd command +ocd_commands = ["halt\n", + "flash erase_sector 0 0 127\n", + "mww 0x1001fff8 0\n", + "mdw 0x01000000 16\n", + "reset\n"] + +# OK now do what we came here for!!! +for x in ocd_commands: + print(x) + send_ocd_cmd(x) + + diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/flash_program b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/flash_program new file mode 100755 index 0000000000000000000000000000000000000000..bc3fe5cb21aa2a89d8dda41a68185f03e43c674e --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/flash_program @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +#Usage: cd to the directory tensorflow/lite/experimental/micro/tools/make/targets/ecm3531 and type ./flash_program executable_name to load an executable from the directory tensorflow/lite/experimental/micro/tools/make/gen/ecm3531_cortex-m3/bin/ into flash +# +# +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+#============================================================================== + + +import sys, getopt +import os +import telnetlib + +def send_ocd_cmd(line): + ocd_sock.write(bytes(line,encoding = 'utf-8')) + print(ocd_sock.read_until(b'> ').decode('utf-8'), end='') + +def get_ocd_response(): + print(ocd_sock.read_until(b'> ').decode('utf-8'), end='') + +#get hooked up to openocd daemon +ocd_sock = telnetlib.Telnet(host='localhost', port=4444) +get_ocd_response() # clean it out + +# get path to project elf file +cur_dir = os.getcwd() +#elf_file = cur_dir + '/../../gen/ecm3531_cortex-m3/bin/' + 'micro_speech' +elf_file = cur_dir + '/../../gen/ecm3531_cortex-m3/bin/' + sys.argv[1] +print("elf_file = ",elf_file) + + +# use these to download and run the elf file +ocd_commands = ["halt\n", + "flash erase_sector 0 0 127\n", + "flash write_image {}\n".format(elf_file), + "mww 0x1001fff8 0\n", + "reset\n"] + +# OK now do what we came here for!!! +for x in ocd_commands: + print(x) + send_ocd_cmd(x) diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/load_program b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/load_program index ac1f49962a61756ccbde02300c612bd7b4f48e84..781231480aa2f1dec18cc468e1ea0129604c71e7 100755 --- a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/load_program +++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/load_program @@ -1,5 +1,5 @@ #!/usr/bin/python3 -#Usage: cd to the directory tensorflow/lite/experimental/micro/tools/make/targets/ecm3531 and type ./load_prgram to load the executable tensorflow/lite/experimental/micro/tools/make/gen/ecm3531_cortex-m3/bin/micro_speech_test into SRAM +#Usage: cd to the directory tensorflow/lite/experimental/micro/tools/make/targets/ecm3531 and type ./load_program executable_name to load an executable from the directory tensorflow/lite/experimental/micro/tools/make/gen/ecm3531_cortex-m3/bin/ # # # Copyright 2015 The TensorFlow 
Authors. 
All Rights Reserved. @@ -18,6 +18,7 @@ #============================================================================== +import sys, getopt import os import telnetlib @@ -34,7 +35,8 @@ get_ocd_response() # clean it out # git path to project elf file cur_dir = os.getcwd() -elf_file = cur_dir + '/../../gen/ecm3531_cortex-m3/bin/' + 'micro_speech_test' +#elf_file = cur_dir + '/../../gen/ecm3531_cortex-m3/bin/' + 'preprocessor_test' +elf_file = cur_dir + '/../../gen/ecm3531_cortex-m3/bin/' + sys.argv[1] print("elf_file = ",elf_file) diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc index baae58f87e1761c978a87256fda8b7e90edb79e5..897a2b66d21668c4a28573e9d068b865c8f008a8 100644 --- a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc +++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc @@ -5,6 +5,10 @@ ifeq ($(TARGET), ecm3531) ETA_SDK := /home/hari/TensaiSDK-v0.0.17/soc/ GCC_ARM := /home/hari/Downloads/gcc-arm-none-eabi-7-2018-q2-update/ +#Pick the appropriate lds file depending on whether you are running from SRAM or flash + ETA_LDS_FILE := ecm3531.lds +# ETA_LDS_FILE := ecm3531_flash.lds + ifeq ($(wildcard $(ETA_SDK)),) $(error Path to ETA SDK is not set (ETA_SDK)) endif @@ -14,6 +18,7 @@ ifeq ($(TARGET), ecm3531) endif PLATFORM_FLAGS = \ + -DARM_MATH_CM3 \ -DFIRMWARE_BUILD \ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ -DTF_LITE_STATIC_MEMORY \ @@ -58,14 +63,13 @@ ifeq ($(TARGET), ecm3531) -Wl,--start-group -lm -lc -lgcc -Wl,--end-group \ -fno-exceptions \ -nostdlib --specs=nano.specs -t -lstdc++ -lc -lnosys -lm \ - -Wl,-T,$(MAKEFILE_DIR)/targets/ecm3531/ecm3531.lds \ - -Wl,-Map=$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref + -Wl,-T,$(MAKEFILE_DIR)/targets/ecm3531/$(ETA_LDS_FILE) \ + -Wl,-Map=$(MAKEFILE_DIR)/targets/ecm3531/ecm3531.map,--cref BUILD_TYPE := micro MICROLITE_LIBS := \
$(GCC_ARM)/lib/gcc/arm-none-eabi/7.3.1/thumb/v7e-m/fpv4-sp/softfp/crtbegin.o \ -lm - INCLUDES += \ - -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ + ECM3531_INCLUDES := \ -I$(GCC_ARM)/arm-none-eabi/include/ \ -I$(ETA_SDK)/ecm3531/boards/eta_evb/projects/m3/common/inc/ \ -I$(ETA_SDK)/ecm3531/m3/reg/inc/ \ @@ -75,6 +79,9 @@ ifeq ($(TARGET), ecm3531) -I$(ETA_SDK)/../utils/inc/ \ -I$(ETA_SDK)/ecm3531/boards/eta_evb/eta_bsp/inc + INCLUDES += $(ECM3531_INCLUDES) + GENERATED_PROJECT_INCLUDES += $(ECM3531_INCLUDES) + # _main.c contains application and target specific initialization, like # setting clock speed, default uart setups, etc. and an implementation # of the DebugLog interfaces. @@ -83,7 +90,13 @@ ifeq ($(TARGET), ecm3531) $(MAKEFILE_DIR)/targets/ecm3531/_main.c \ $(wildcard $(ETA_SDK)/ecm3531/boards/eta_evb/projects/m3/common/src/*.c) \ $(wildcard $(ETA_SDK)/ecm3531/m3/csp/src/*.c) \ - $(wildcard $(ETA_SDK)/ecm3531/m3/csp/src/*.s) \ + $(wildcard $(ETA_SDK)/ecm3531/m3/csp/src/*.s) + + # The linker script isn't a header, but it needs to get copied to the gen/ + # directory for generated projects. This is similar to the behavior needed + # for headers. + MICROLITE_CC_HDRS += \ + $(MAKEFILE_DIR)/targets/ecm3531/$(ETA_LDS_FILE) TEST_SCRIPT := tensorflow/lite/experimental/micro/testing/test_ecm3531_binary.sh # These are tests that don't currently work on the blue pill. diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/osx_makefile.inc b/tensorflow/lite/experimental/micro/tools/make/targets/osx_makefile.inc index 3b91eeff9fd5f2df06caa9a5f73b221815f9bbdf..090b4fa101d765adb6a7eed181752021f55ebd1b 100644 --- a/tensorflow/lite/experimental/micro/tools/make/targets/osx_makefile.inc +++ b/tensorflow/lite/experimental/micro/tools/make/targets/osx_makefile.inc @@ -1,6 +1,9 @@ # Settings for Mac OS platforms. ifeq ($(TARGET), osx) + # Make sure we can find the embedded GCC compiler. 
+ export PATH := ${PATH}:tensorflow/lite/experimental/micro/tools/make/downloads/gcc_embedded/bin/ + PLATFORM_FLAGS = \ -DTF_LITE_DISABLE_X86_NEON diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/osx_x86_64_makefile.inc b/tensorflow/lite/experimental/micro/tools/make/targets/osx_x86_64_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..78febaf5dddda19f082a21b7a4c5b9409677f769 --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/targets/osx_x86_64_makefile.inc @@ -0,0 +1,10 @@ +# Settings for x86 on Mac +ifeq ($(TARGET), osx) + ifeq ($(TARGET_ARCH), x86_64) + PLATFORM_FLAGS = \ + -DTF_LITE_DISABLE_X86_NEON + CXXFLAGS += $(PLATFORM_FLAGS) + CCFLAGS += $(PLATFORM_FLAGS) + endif +endif + diff --git a/tensorflow/lite/experimental/micro/tools/make/templates/Makefile.tpl b/tensorflow/lite/experimental/micro/tools/make/templates/Makefile.tpl index 74d54f1ebee12d7773edfd1b073ddf17dd3791d6..ca6519c1390b5b783e4b6f26cac40a6b7ef32f46 100644 --- a/tensorflow/lite/experimental/micro/tools/make/templates/Makefile.tpl +++ b/tensorflow/lite/experimental/micro/tools/make/templates/Makefile.tpl @@ -4,12 +4,8 @@ SRCS := \ OBJS := \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(SRCS))) -INCLUDES := \ --I. 
\ --I./third_party/gemmlowp \ --I./third_party/flatbuffers/include - CXXFLAGS += %{CXX_FLAGS}% +CCFLAGS += %{CC_FLAGS}% LDFLAGS += %{LINKER_FLAGS}% @@ -20,7 +16,6 @@ LDFLAGS += %{LINKER_FLAGS}% $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ %{EXECUTABLE}% : $(OBJS) - $(CXX) $(LDFLAGS) $(OBJS) \ - -o $@ + $(CXX) $(CXXFLAGS) -o $@ $(OBJS) $(LDFLAGS) all: %{EXECUTABLE}% diff --git a/tensorflow/lite/experimental/micro/tools/make/templates/README_KEIL.md.tpl b/tensorflow/lite/experimental/micro/tools/make/templates/README_KEIL.md.tpl new file mode 100644 index 0000000000000000000000000000000000000000..945b9f9c1ae4c5761afb80febe57803d1e7fcab2 --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/templates/README_KEIL.md.tpl @@ -0,0 +1,5 @@ +# TensorFlow Lite Micro Keil Project + +This folder has been autogenerated by TensorFlow, and contains source, header, +and project files needed to build a single TensorFlow Lite Micro target using +the Keil uVision IDE. diff --git a/tensorflow/lite/experimental/micro/tools/make/templates/keil_project.uvprojx.tpl b/tensorflow/lite/experimental/micro/tools/make/templates/keil_project.uvprojx.tpl new file mode 100644 index 0000000000000000000000000000000000000000..440d4b6b3e8a7894bc2b0c6afbd5ff78b54f198b --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/templates/keil_project.uvprojx.tpl @@ -0,0 +1,418 @@ + + + + 2.1 + + + + %{EXECUTABLE}% + 0x4 + ARM-ADS + 6100001::V6.10.1::.\ARMCLANG + 1 + + + STM32F746NGHx + STMicroelectronics + Keil.STM32F7xx_DFP.2.11.0 + http://www.keil.com/pack + IRAM(0x20010000,0x40000) IRAM2(0x20000000,0x10000) IROM(0x08000000,0x100000) IROM2(0x00200000,0x100000) CPUTYPE("Cortex-M7") FPU3(SFPU) CLOCK(12000000) ELITTLE + + + UL2CM3(-S0 -C0 -P0 -FD20010000 -FC1000 -FN2 -FF0STM32F7x_1024 -FS08000000 -FL0100000 -FF1STM32F7xTCM_1024 -FS1200000 -FL1100000 -FP0($$Device:STM32F746NGHx$CMSIS\Flash\STM32F7x_1024.FLM) -FP1($$Device:STM32F746NGHx$CMSIS\Flash\STM32F7xTCM_1024.FLM)) + 0 + 
$$Device:STM32F746NGHx$Drivers\CMSIS\Device\ST\STM32F7xx\Include\stm32f7xx.h + + + + + + + + + + $$Device:STM32F746NGHx$CMSIS\SVD\STM32F7x6_v1r1.svd + 0 + 0 + + + + + + + 0 + 0 + 0 + 0 + 1 + + .\Objects\ + %{EXECUTABLE}% + 1 + 0 + 0 + 1 + 1 + .\Listings\ + 1 + 0 + 0 + + 0 + 0 + + + 0 + 0 + 0 + 0 + + + 0 + 0 + + + 0 + 0 + 0 + 0 + + + 0 + 0 + + + 0 + 0 + 0 + 0 + + 0 + + + + 0 + 0 + 0 + 0 + 0 + 1 + 0 + 0 + 0 + 0 + 3 + + + 1 + + + SARMCM3.DLL + -REMAP -MPU + DCM.DLL + -pCM7 + SARMCM3.DLL + -MPU + TCM.DLL + -pCM7 + + + + 1 + 0 + 0 + 0 + 16 + + + + + 1 + 0 + 0 + 1 + 1 + -1 + + 1 + BIN\UL2CM3.DLL + + + + + + 0 + + + + 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 0 + 1 + 1 + 0 + 1 + 1 + 0 + 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 0 + 0 + "Cortex-M7" + + 0 + 0 + 0 + 1 + 1 + 0 + 0 + 2 + 0 + 1 + 1 + 8 + 0 + 0 + 0 + 0 + 4 + 4 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 1 + 0 + 0 + 0 + 0 + 1 + 1 + + + 0 + 0x0 + 0x0 + + + 0 + 0x0 + 0x0 + + + 0 + 0x0 + 0x0 + + + 0 + 0x0 + 0x0 + + + 0 + 0x0 + 0x0 + + + 0 + 0x0 + 0x0 + + + 0 + 0x20010000 + 0x40000 + + + 1 + 0x8000000 + 0x100000 + + + 0 + 0x0 + 0x0 + + + 1 + 0x0 + 0x0 + + + 1 + 0x0 + 0x0 + + + 1 + 0x0 + 0x0 + + + 1 + 0x8000000 + 0x100000 + + + 1 + 0x200000 + 0x100000 + + + 0 + 0x0 + 0x0 + + + 0 + 0x0 + 0x0 + + + 0 + 0x0 + 0x0 + + + 0 + 0x20010000 + 0x40000 + + + 0 + 0x20000000 + 0x10000 + + + + + + 1 + 7 + 0 + 0 + 1 + 0 + 0 + 0 + 0 + 0 + 3 + 0 + 0 + 0 + 1 + 0 + 3 + 3 + 1 + 1 + 0 + 0 + 0 + + + + + %{INCLUDE_PATHS}% + + + + 1 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + + + + + + + + + 0 + 0 + 0 + 0 + 1 + 0 + 0x08000000 + 0x20010000 + + + + + + + + + + + + + Source + +%{SRCS}% + + + + ::Compiler + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/lite/experimental/micro/tools/make/templates/tasks.json.make.tpl b/tensorflow/lite/experimental/micro/tools/make/templates/tasks.json.make.tpl new file mode 100644 index 0000000000000000000000000000000000000000..141994d854565dc2ad2152e440c1d29526acb3dc --- /dev/null +++ 
b/tensorflow/lite/experimental/micro/tools/make/templates/tasks.json.make.tpl @@ -0,0 +1,16 @@ +{ + // See https://go.microsoft.com/fwlink/?LinkId=733558 + // for the documentation about the tasks.json format + "version": "2.0.0", + "tasks": [ + { + "label": "Make Build", + "type": "shell", + "command": "make", + "group": { + "kind": "build", + "isDefault": true + } + } + ] +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/micro/tools/make/templates/tasks.json.mbed.tpl b/tensorflow/lite/experimental/micro/tools/make/templates/tasks.json.mbed.tpl new file mode 100644 index 0000000000000000000000000000000000000000..616f3b23188df4af934433772c86c5c1a9452539 --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/templates/tasks.json.mbed.tpl @@ -0,0 +1,39 @@ +{ + // See https://go.microsoft.com/fwlink/?LinkId=733558 + // for the documentation about the tasks.json format + "version": "2.0.0", + "tasks": [ + { + "label": "Mbed Config Root", + "type": "shell", + "command": "mbed config root .", + }, + { + "label": "Mbed Deploy", + "type": "shell", + "command": "mbed deploy", + }, + { + "label": "Mbed Patch C++11", + "type": "shell", + "command": "python", + "args": [ + "-c", + "import fileinput, glob;\nfor filename in glob.glob(\"mbed-os/tools/profiles/*.json\"):\n for line in fileinput.input(filename, inplace=True):\n print line.replace(\"\\\"-std=gnu++98\\\"\",\"\\\"-std=c++11\\\", \\\"-fpermissive\\\"\")" + ] + }, + { + "label": "Mbed Init", + "dependsOn": ["Mbed Config Root", "Mbed Deploy", "Mbed Patch C++11"] + }, + { + "label": "Mbed build", + "type": "shell", + "command": "mbed compile -m auto -t GCC_ARM", + "group": { + "kind": "build", + "isDefault": true + } + } + ] +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/microfrontend/lib/BUILD b/tensorflow/lite/experimental/microfrontend/lib/BUILD index a055e52f71001295cf95dfcbe790bc4118140fed..8dd42fc38290dbf5be8f9f1a850ad88cbf326ace 100644 --- 
a/tensorflow/lite/experimental/microfrontend/lib/BUILD +++ b/tensorflow/lite/experimental/microfrontend/lib/BUILD @@ -6,6 +6,11 @@ package( licenses(["notice"]) # Apache 2.0 +load( + "//tensorflow/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + cc_library( name = "bits", hdrs = ["bits.h"], @@ -117,72 +122,65 @@ cc_library( ], ) -cc_test( +tflite_micro_cc_test( name = "fft_test", - size = "small", srcs = ["fft_test.cc"], deps = [ ":fft", - "@com_google_googletest//:gtest_main", + "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) -cc_test( +tflite_micro_cc_test( name = "filterbank_test", - size = "small", srcs = ["filterbank_test.cc"], deps = [ ":filterbank", - "@com_google_googletest//:gtest_main", + "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) -cc_test( +tflite_micro_cc_test( name = "frontend_test", - size = "small", srcs = ["frontend_test.cc"], deps = [ ":frontend", - "@com_google_googletest//:gtest_main", + "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) -cc_test( +tflite_micro_cc_test( name = "log_scale_test", - size = "small", srcs = ["log_scale_test.cc"], deps = [ ":log_scale", - "@com_google_googletest//:gtest_main", + "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) -cc_test( +tflite_micro_cc_test( name = "noise_reduction_test", - size = "small", srcs = ["noise_reduction_test.cc"], deps = [ ":noise_reduction", - "@com_google_googletest//:gtest_main", + "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) -cc_test( +tflite_micro_cc_test( name = "pcan_gain_control_test", - size = "small", srcs = ["pcan_gain_control_test.cc"], deps = [ ":pcan_gain_control", - "@com_google_googletest//:gtest_main", + "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) -cc_test( +tflite_micro_cc_test( name = "window_test", - size = "small", srcs = ["window_test.cc"], deps = [ ":window", - "@com_google_googletest//:gtest_main", + 
"//tensorflow/lite/experimental/micro/testing:micro_test", ], ) diff --git a/tensorflow/lite/experimental/microfrontend/lib/bits.h b/tensorflow/lite/experimental/microfrontend/lib/bits.h index bf15466a3d6484c3059a1ded1bb51e4d4287b1bf..04b3ba6f055f956720b58720c78083b1529fb065 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/bits.h +++ b/tensorflow/lite/experimental/microfrontend/lib/bits.h @@ -63,14 +63,14 @@ static inline int CountLeadingZeros64Slow(uint64_t n) { static inline int CountLeadingZeros64(uint64_t n) { #if defined(_MSC_VER) && defined(_M_X64) - // MSVC does not have __buitin_clzll. Use _BitScanReverse64. + // MSVC does not have __builtin_clzll. Use _BitScanReverse64. unsigned long result = 0; // NOLINT(runtime/int) if (_BitScanReverse64(&result, n)) { return 63 - result; } return 64; #elif defined(_MSC_VER) - // MSVC does not have __buitin_clzll. Compose two calls to _BitScanReverse + // MSVC does not have __builtin_clzll. Compose two calls to _BitScanReverse unsigned long result = 0; // NOLINT(runtime/int) if ((n >> 32) && _BitScanReverse(&result, n >> 32)) { return 31 - result; diff --git a/tensorflow/lite/experimental/microfrontend/lib/fft_test.cc b/tensorflow/lite/experimental/microfrontend/lib/fft_test.cc index 1b754c1b4c244edf1b091a581e5ae9399c2ac2e3..ec1f247ba24ad27917330708d6f9c754515a686b 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/fft_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/fft_test.cc @@ -15,8 +15,7 @@ limitations under the License. 
#include "tensorflow/lite/experimental/microfrontend/lib/fft.h" #include "tensorflow/lite/experimental/microfrontend/lib/fft_util.h" -#include -#include +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace { @@ -25,9 +24,13 @@ const int16_t kFakeWindow[] = { 0, -28328, 0, 21447, 0, -13312, 0, 5943, 0, -1152, 0}; const int kScaleShift = 0; -TEST(FftTest, CheckOutputValues) { +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(FftTest_CheckOutputValues) { struct FftState state; - ASSERT_TRUE( + TF_LITE_MICRO_EXPECT( FftPopulateState(&state, sizeof(kFakeWindow) / sizeof(kFakeWindow[0]))); FftInit(&state); @@ -37,14 +40,15 @@ TEST(FftTest, CheckOutputValues) { {0, 0}, {-10, 9}, {-20, 0}, {-9, -10}, {0, 25}, {-119, 119}, {-887, 0}, {3000, 3000}, {0, -6401}, {-3000, 3000}, {886, 0}, {118, 119}, {0, 25}, {9, -10}, {19, 0}, {9, 9}, {0, 0}}; - ASSERT_EQ(state.fft_size / 2 + 1, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.fft_size / 2 + 1, + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i <= state.fft_size / 2; ++i) { - EXPECT_EQ(state.output[i].real, expected[i].real); - EXPECT_EQ(state.output[i].imag, expected[i].imag); + TF_LITE_MICRO_EXPECT_EQ(state.output[i].real, expected[i].real); + TF_LITE_MICRO_EXPECT_EQ(state.output[i].imag, expected[i].imag); } FftFreeStateContents(&state); } -} // namespace +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/microfrontend/lib/filterbank_test.cc b/tensorflow/lite/experimental/microfrontend/lib/filterbank_test.cc index 41f0064d4f1674471fa731e72464b1d40fce4216..16257aa11a5ca3e82aeff60f8cc0176de3c519ab 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/filterbank_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/filterbank_test.cc @@ -17,8 +17,7 @@ limitations under the License. 
#include -#include -#include +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace { @@ -33,9 +32,9 @@ const uint64_t kWork[] = {1835887, 61162970173, 258694800000}; const int kScaleShift = 0; // Test filterbank generation using scaled-down defaults. -class FilterbankTest : public ::testing::Test { - protected: - FilterbankTest() { +class FilterbankTestConfig { + public: + FilterbankTestConfig() { config_.num_channels = 2; config_.lower_band_limit = 8.0; config_.upper_band_limit = 450.0; @@ -44,110 +43,124 @@ class FilterbankTest : public ::testing::Test { struct FilterbankConfig config_; }; -TEST_F(FilterbankTest, CheckStartIndex) { +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(FilterbankTest_CheckStartIndex) { + FilterbankTestConfig config; struct FilterbankState state; - ASSERT_TRUE( - FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + TF_LITE_MICRO_EXPECT(FilterbankPopulateState(&config.config_, &state, + kSampleRate, kSpectrumSize)); - EXPECT_EQ(state.start_index, kStartIndex); + TF_LITE_MICRO_EXPECT_EQ(state.start_index, kStartIndex); FilterbankFreeStateContents(&state); } -TEST_F(FilterbankTest, CheckEndIndex) { +TF_LITE_MICRO_TEST(FilterbankTest_CheckEndIndex) { + FilterbankTestConfig config; struct FilterbankState state; - ASSERT_TRUE( - FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + TF_LITE_MICRO_EXPECT(FilterbankPopulateState(&config.config_, &state, + kSampleRate, kSpectrumSize)); - EXPECT_EQ(state.end_index, kEndIndex); + TF_LITE_MICRO_EXPECT_EQ(state.end_index, kEndIndex); FilterbankFreeStateContents(&state); } -TEST_F(FilterbankTest, CheckChannelFrequencyStarts) { +TF_LITE_MICRO_TEST(FilterbankTest_CheckChannelFrequencyStarts) { + FilterbankTestConfig config; struct FilterbankState state; - ASSERT_TRUE( - FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + TF_LITE_MICRO_EXPECT(FilterbankPopulateState(&config.config_, &state, + 
kSampleRate, kSpectrumSize)); const int16_t expected[] = {0, 4, 8}; - ASSERT_EQ(state.num_channels + 1, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.num_channels + 1, + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i <= state.num_channels; ++i) { - EXPECT_EQ(state.channel_frequency_starts[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(state.channel_frequency_starts[i], expected[i]); } FilterbankFreeStateContents(&state); } -TEST_F(FilterbankTest, CheckChannelWeightStarts) { +TF_LITE_MICRO_TEST(FilterbankTest_CheckChannelWeightStarts) { + FilterbankTestConfig config; struct FilterbankState state; - ASSERT_TRUE( - FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + TF_LITE_MICRO_EXPECT(FilterbankPopulateState(&config.config_, &state, + kSampleRate, kSpectrumSize)); const int16_t expected[] = {0, 8, 16}; - ASSERT_EQ(state.num_channels + 1, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.num_channels + 1, + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i <= state.num_channels; ++i) { - EXPECT_EQ(state.channel_weight_starts[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(state.channel_weight_starts[i], expected[i]); } FilterbankFreeStateContents(&state); } -TEST_F(FilterbankTest, CheckChannelWidths) { +TF_LITE_MICRO_TEST(FilterbankTest_CheckChannelWidths) { + FilterbankTestConfig config; struct FilterbankState state; - ASSERT_TRUE( - FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + TF_LITE_MICRO_EXPECT(FilterbankPopulateState(&config.config_, &state, + kSampleRate, kSpectrumSize)); const int16_t expected[] = {8, 8, 8}; - ASSERT_EQ(state.num_channels + 1, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.num_channels + 1, + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i <= state.num_channels; ++i) { - EXPECT_EQ(state.channel_widths[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(state.channel_widths[i], expected[i]); } 
FilterbankFreeStateContents(&state); } -TEST_F(FilterbankTest, CheckWeights) { +TF_LITE_MICRO_TEST(FilterbankTest_CheckWeights) { + FilterbankTestConfig config; struct FilterbankState state; - ASSERT_TRUE( - FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + TF_LITE_MICRO_EXPECT(FilterbankPopulateState(&config.config_, &state, + kSampleRate, kSpectrumSize)); const int16_t expected[] = {0, 3277, 2217, 1200, 222, 0, 0, 0, 0, 3376, 2468, 1591, 744, 0, 0, 0, 0, 4020, 3226, 2456, 1708, 983, 277, 0}; - ASSERT_EQ(state.channel_weight_starts[state.num_channels] + - state.channel_widths[state.num_channels], - sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.channel_weight_starts[state.num_channels] + + state.channel_widths[state.num_channels], + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { - EXPECT_EQ(state.weights[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(state.weights[i], expected[i]); } FilterbankFreeStateContents(&state); } -TEST_F(FilterbankTest, CheckUnweights) { +TF_LITE_MICRO_TEST(FilterbankTest_CheckUnweights) { + FilterbankTestConfig config; struct FilterbankState state; - ASSERT_TRUE( - FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + TF_LITE_MICRO_EXPECT(FilterbankPopulateState(&config.config_, &state, + kSampleRate, kSpectrumSize)); const int16_t expected[] = {0, 819, 1879, 2896, 3874, 0, 0, 0, 0, 720, 1628, 2505, 3352, 0, 0, 0, 0, 76, 870, 1640, 2388, 3113, 3819, 0}; - ASSERT_EQ(state.channel_weight_starts[state.num_channels] + - state.channel_widths[state.num_channels], - sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.channel_weight_starts[state.num_channels] + + state.channel_widths[state.num_channels], + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { - EXPECT_EQ(state.unweights[i], expected[i]); + 
TF_LITE_MICRO_EXPECT_EQ(state.unweights[i], expected[i]); } FilterbankFreeStateContents(&state); } -TEST_F(FilterbankTest, CheckConvertFftComplexToEnergy) { +TF_LITE_MICRO_TEST(FilterbankTest_CheckConvertFftComplexToEnergy) { struct FilterbankState state; state.start_index = kStartIndex; state.end_index = kEndIndex; @@ -161,42 +174,46 @@ TEST_F(FilterbankTest, CheckConvertFftComplexToEnergy) { int i; for (i = state.start_index; i < state.end_index; ++i) { - EXPECT_EQ(energy[i], kEnergy[i]); + TF_LITE_MICRO_EXPECT_EQ(energy[i], kEnergy[i]); } } -TEST_F(FilterbankTest, CheckAccumulateChannels) { +TF_LITE_MICRO_TEST(FilterbankTest_CheckAccumulateChannels) { + FilterbankTestConfig config; struct FilterbankState state; - ASSERT_TRUE( - FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + TF_LITE_MICRO_EXPECT(FilterbankPopulateState(&config.config_, &state, + kSampleRate, kSpectrumSize)); FilterbankAccumulateChannels(&state, kEnergy); - ASSERT_EQ(state.num_channels + 1, sizeof(kWork) / sizeof(kWork[0])); + TF_LITE_MICRO_EXPECT_EQ(state.num_channels + 1, + sizeof(kWork) / sizeof(kWork[0])); int i; for (i = 0; i <= state.num_channels; ++i) { - EXPECT_EQ(state.work[i], kWork[i]); + TF_LITE_MICRO_EXPECT_EQ(state.work[i], kWork[i]); } FilterbankFreeStateContents(&state); } -TEST_F(FilterbankTest, CheckSqrt) { +TF_LITE_MICRO_TEST(FilterbankTest_CheckSqrt) { + FilterbankTestConfig config; struct FilterbankState state; - ASSERT_TRUE( - FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + TF_LITE_MICRO_EXPECT(FilterbankPopulateState(&config.config_, &state, + kSampleRate, kSpectrumSize)); std::memcpy(state.work, kWork, sizeof(kWork)); uint32_t* scaled_filterbank = FilterbankSqrt(&state, kScaleShift); const uint32_t expected[] = {247311, 508620}; - ASSERT_EQ(state.num_channels, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.num_channels, + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < 
state.num_channels; ++i) { - EXPECT_EQ(scaled_filterbank[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(scaled_filterbank[i], expected[i]); } FilterbankFreeStateContents(&state); } -} // namespace +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c b/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c index ce8b4acc0f696f9c4123bab9daeb1e8802c3e828..5a668fa92f17f97cb97ed8c3f640d65b1821e3de 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c +++ b/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c @@ -28,9 +28,7 @@ void FilterbankFillConfigWithDefaults(struct FilterbankConfig* config) { config->output_scale_shift = 7; } -static float FreqToMel(float freq) { - return 1127.0 * log(1.0 + (freq / 700.0)); -} +static float FreqToMel(float freq) { return 1127.0 * log1p(freq / 700.0); } static void CalculateCenterFrequencies(const int num_channels, const float lower_frequency_limit, diff --git a/tensorflow/lite/experimental/microfrontend/lib/frontend_test.cc b/tensorflow/lite/experimental/microfrontend/lib/frontend_test.cc index a6faa1fc1f51360e295253fb2b3cfdf01ada74ad..568484f14dde6b958d5c9e144ab8dfd7a68a0fb0 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/frontend_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend_test.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" #include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" -#include -#include +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace { @@ -29,9 +28,9 @@ const int16_t kFakeAudioData[] = { 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768}; // Test end-to-end frontend behaviors. 
-class FrontendTest : public ::testing::Test { - protected: - FrontendTest() { +class FrontendTestConfig { + public: + FrontendTestConfig() { config_.window.size_ms = 25; config_.window.step_size_ms = 10; config_.noise_reduction.smoothing_bits = 10; @@ -53,9 +52,15 @@ class FrontendTest : public ::testing::Test { struct FrontendConfig config_; }; -TEST_F(FrontendTest, CheckOutputValues) { +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(FrontendTest_CheckOutputValues) { + FrontendTestConfig config; struct FrontendState state; - ASSERT_TRUE(FrontendPopulateState(&config_, &state, kSampleRate)); + TF_LITE_MICRO_EXPECT( + FrontendPopulateState(&config.config_, &state, kSampleRate)); size_t num_samples_read; struct FrontendOutput output = FrontendProcessSamples( @@ -63,18 +68,20 @@ TEST_F(FrontendTest, CheckOutputValues) { sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read); const uint16_t expected[] = {479, 425}; - ASSERT_EQ(output.size, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(output.size, sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < output.size; ++i) { - EXPECT_EQ(output.values[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(output.values[i], expected[i]); } FrontendFreeStateContents(&state); } -TEST_F(FrontendTest, CheckConsecutiveWindow) { +TF_LITE_MICRO_TEST(FrontendTest_CheckConsecutiveWindow) { + FrontendTestConfig config; struct FrontendState state; - ASSERT_TRUE(FrontendPopulateState(&config_, &state, kSampleRate)); + TF_LITE_MICRO_EXPECT( + FrontendPopulateState(&config.config_, &state, kSampleRate)); size_t num_samples_read; FrontendProcessSamples(&state, kFakeAudioData, @@ -86,18 +93,20 @@ TEST_F(FrontendTest, CheckConsecutiveWindow) { &num_samples_read); const int16_t expected[] = {436, 378}; - ASSERT_EQ(output.size, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(output.size, sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < output.size; ++i) { 
- EXPECT_EQ(output.values[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(output.values[i], expected[i]); } FrontendFreeStateContents(&state); } -TEST_F(FrontendTest, CheckNotEnoughSamples) { +TF_LITE_MICRO_TEST(FrontendTest_CheckNotEnoughSamples) { + FrontendTestConfig config; struct FrontendState state; - ASSERT_TRUE(FrontendPopulateState(&config_, &state, kSampleRate)); + TF_LITE_MICRO_EXPECT( + FrontendPopulateState(&config.config_, &state, kSampleRate)); size_t num_samples_read; FrontendProcessSamples(&state, kFakeAudioData, @@ -113,10 +122,10 @@ TEST_F(FrontendTest, CheckNotEnoughSamples) { kStepSamples, &num_samples_read); - EXPECT_EQ(output.size, 0); - EXPECT_EQ(output.values, nullptr); + TF_LITE_MICRO_EXPECT_EQ(output.size, 0); + TF_LITE_MICRO_EXPECT_EQ(output.values, nullptr); FrontendFreeStateContents(&state); } -} // namespace +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/microfrontend/lib/log_scale_test.cc b/tensorflow/lite/experimental/microfrontend/lib/log_scale_test.cc index 1ea0842ec2ad1065782198b635bf8b4858d6bf3a..be52fd426a23a389aac84e4b2dac832924716f83 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/log_scale_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/log_scale_test.cc @@ -15,15 +15,18 @@ limitations under the License. 
#include "tensorflow/lite/experimental/microfrontend/lib/log_scale.h" #include "tensorflow/lite/experimental/microfrontend/lib/log_scale_util.h" -#include -#include +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace { const int kScaleShift = 6; const int kCorrectionBits = -1; -TEST(LogScaleTest, CheckOutputValues) { +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(LogScaleTest_CheckOutputValues) { struct LogScaleState state; state.enable_log = true; state.scale_shift = kScaleShift; @@ -36,11 +39,11 @@ TEST(LogScaleTest, CheckOutputValues) { const uint16_t expected[] = {479, 425}; int i; for (i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { - EXPECT_EQ(output[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(output[i], expected[i]); } } -TEST(LogScaleTest, CheckOutputValuesNoLog) { +TF_LITE_MICRO_TEST(LogScaleTest_CheckOutputValuesNoLog) { struct LogScaleState state; state.enable_log = false; state.scale_shift = kScaleShift; @@ -53,8 +56,8 @@ TEST(LogScaleTest, CheckOutputValuesNoLog) { const uint16_t expected[] = {65535, 45998}; int i; for (i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { - EXPECT_EQ(output[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(output[i], expected[i]); } } -} // namespace +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_test.cc b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_test.cc index 13d58b2476762d89ee79be554be12a9b7a897ad5..ba864c427ced36748167c9412fe2966d72d3cb0e 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_test.cc @@ -15,17 +15,16 @@ limitations under the License. 
#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h" #include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.h" -#include -#include +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace { const int kNumChannels = 2; // Test noise reduction using default config values. -class NoiseReductionTest : public ::testing::Test { - protected: - NoiseReductionTest() { +class NoiseReductionTestConfig { + public: + NoiseReductionTestConfig() { config_.smoothing_bits = 10; config_.even_smoothing = 0.025; config_.odd_smoothing = 0.06; @@ -35,38 +34,48 @@ class NoiseReductionTest : public ::testing::Test { struct NoiseReductionConfig config_; }; -TEST_F(NoiseReductionTest, TestNoiseReductionEstimate) { +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(NoiseReductionTest_TestNoiseReductionEstimate) { + NoiseReductionTestConfig config; struct NoiseReductionState state; - ASSERT_TRUE(NoiseReductionPopulateState(&config_, &state, kNumChannels)); + TF_LITE_MICRO_EXPECT( + NoiseReductionPopulateState(&config.config_, &state, kNumChannels)); uint32_t signal[] = {247311, 508620}; NoiseReductionApply(&state, signal); const uint32_t expected[] = {6321887, 31248341}; - ASSERT_EQ(state.num_channels, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.num_channels, + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < state.num_channels; ++i) { - EXPECT_EQ(state.estimate[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(state.estimate[i], expected[i]); } NoiseReductionFreeStateContents(&state); } -TEST_F(NoiseReductionTest, TestNoiseReduction) { +TF_LITE_MICRO_TEST(NoiseReductionTest_TestNoiseReduction) { + NoiseReductionTestConfig config; struct NoiseReductionState state; - ASSERT_TRUE(NoiseReductionPopulateState(&config_, &state, kNumChannels)); + TF_LITE_MICRO_EXPECT( + NoiseReductionPopulateState(&config.config_, &state, kNumChannels)); uint32_t signal[] = {247311, 508620}; 
NoiseReductionApply(&state, signal); const uint32_t expected[] = {241137, 478104}; - ASSERT_EQ(state.num_channels, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.num_channels, + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < state.num_channels; ++i) { - EXPECT_EQ(signal[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(signal[i], expected[i]); } NoiseReductionFreeStateContents(&state); } -} // namespace +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc index 7c92d2d29d0e41d5e378a596c5a06e8418edfa8d..93d7a8bcb94d5e0145b9ee701b413194f2946a7b 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.h" #include "tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" -#include -#include +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace { @@ -25,9 +24,9 @@ const int kSmoothingBits = 10; const int kCorrectionBits = -1; // Test pcan auto gain control using default config values. 
-class PcanGainControlTest : public ::testing::Test { - protected: - PcanGainControlTest() { +class PcanGainControlTestConfig { + public: + PcanGainControlTestConfig() { config_.enable_pcan = 1; config_.strength = 0.95; config_.offset = 80.0; @@ -37,24 +36,30 @@ class PcanGainControlTest : public ::testing::Test { struct PcanGainControlConfig config_; }; -TEST_F(PcanGainControlTest, TestPcanGainControl) { +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(PcanGainControlTest_TestPcanGainControl) { uint32_t estimate[] = {6321887, 31248341}; + PcanGainControlTestConfig config; struct PcanGainControlState state; - ASSERT_TRUE(PcanGainControlPopulateState(&config_, &state, estimate, - kNumChannels, kSmoothingBits, - kCorrectionBits)); + TF_LITE_MICRO_EXPECT(PcanGainControlPopulateState( + &config.config_, &state, estimate, kNumChannels, kSmoothingBits, + kCorrectionBits)); uint32_t signal[] = {241137, 478104}; PcanGainControlApply(&state, signal); const uint32_t expected[] = {3578, 1533}; - ASSERT_EQ(state.num_channels, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.num_channels, + sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < state.num_channels; ++i) { - EXPECT_EQ(signal[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(signal[i], expected[i]); } PcanGainControlFreeStateContents(&state); } -} // namespace +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/microfrontend/lib/window_test.cc b/tensorflow/lite/experimental/microfrontend/lib/window_test.cc index 60f11440f56ea39a25a5aa2beb23eb25a83048b3..cf9df523b8f5e540d47c8e6d3d42e37540f62ce0 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/window_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/window_test.cc @@ -15,8 +15,7 @@ limitations under the License. 
#include "tensorflow/lite/experimental/microfrontend/lib/window.h" #include "tensorflow/lite/experimental/microfrontend/lib/window_util.h" -#include -#include +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace { @@ -29,9 +28,9 @@ const int16_t kFakeAudioData[] = { 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768}; // Test window function behaviors using default config values. -class WindowTest : public ::testing::Test { - protected: - WindowTest() { +class WindowTestConfig { + public: + WindowTestConfig() { config_.size_ms = 25; config_.step_size_ms = 10; } @@ -39,84 +38,98 @@ class WindowTest : public ::testing::Test { struct WindowConfig config_; }; -TEST_F(WindowTest, CheckCoefficients) { +} // namespace + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(WindowState_CheckCoefficients) { + WindowTestConfig config; struct WindowState state; - ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + TF_LITE_MICRO_EXPECT( + WindowPopulateState(&config.config_, &state, kSampleRate)); const int16_t expected[] = {16, 144, 391, 743, 1176, 1664, 2177, 2681, 3145, 3541, 3843, 4032, 4096, 4032, 3843, 3541, 3145, 2681, 2177, 1664, 1176, 743, 391, 144, 16}; - ASSERT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < state.size; ++i) { - EXPECT_EQ(state.coefficients[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(state.coefficients[i], expected[i]); } WindowFreeStateContents(&state); } -TEST_F(WindowTest, CheckResidualInput) { +TF_LITE_MICRO_TEST(WindowState_CheckResidualInput) { + WindowTestConfig config; struct WindowState state; - ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + TF_LITE_MICRO_EXPECT( + WindowPopulateState(&config.config_, &state, kSampleRate)); size_t num_samples_read; - ASSERT_TRUE(WindowProcessSamples( + TF_LITE_MICRO_EXPECT(WindowProcessSamples( &state, kFakeAudioData, 
sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); int i; for (i = kStepSamples; i < kWindowSamples; ++i) { - EXPECT_EQ(state.input[i - kStepSamples], kFakeAudioData[i]); + TF_LITE_MICRO_EXPECT_EQ(state.input[i - kStepSamples], kFakeAudioData[i]); } WindowFreeStateContents(&state); } -TEST_F(WindowTest, CheckOutputValues) { +TF_LITE_MICRO_TEST(WindowState_CheckOutputValues) { + WindowTestConfig config; struct WindowState state; - ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + TF_LITE_MICRO_EXPECT( + WindowPopulateState(&config.config_, &state, kSampleRate)); size_t num_samples_read; - ASSERT_TRUE(WindowProcessSamples( + TF_LITE_MICRO_EXPECT(WindowProcessSamples( &state, kFakeAudioData, sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); const int16_t expected[] = { 0, 1151, 0, -5944, 0, 13311, 0, -21448, 0, 28327, 0, -32256, 0, 32255, 0, -28328, 0, 21447, 0, -13312, 0, 5943, 0, -1152, 0}; - ASSERT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < state.size; ++i) { - EXPECT_EQ(state.output[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(state.output[i], expected[i]); } WindowFreeStateContents(&state); } -TEST_F(WindowTest, CheckMaxAbsValue) { +TF_LITE_MICRO_TEST(WindowState_CheckMaxAbsValue) { + WindowTestConfig config; struct WindowState state; - ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + TF_LITE_MICRO_EXPECT( + WindowPopulateState(&config.config_, &state, kSampleRate)); size_t num_samples_read; - ASSERT_TRUE(WindowProcessSamples( + TF_LITE_MICRO_EXPECT(WindowProcessSamples( &state, kFakeAudioData, sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); - EXPECT_EQ(state.max_abs_output_value, 32256); + TF_LITE_MICRO_EXPECT_EQ(state.max_abs_output_value, 32256); WindowFreeStateContents(&state); } -TEST_F(WindowTest, CheckConsecutiveWindow) { 
+TF_LITE_MICRO_TEST(WindowState_CheckConsecutiveWindow) { + WindowTestConfig config; struct WindowState state; - ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + TF_LITE_MICRO_EXPECT( + WindowPopulateState(&config.config_, &state, kSampleRate)); size_t num_samples_read; - ASSERT_TRUE(WindowProcessSamples( + TF_LITE_MICRO_EXPECT(WindowProcessSamples( &state, kFakeAudioData, sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); - ASSERT_TRUE(WindowProcessSamples( + TF_LITE_MICRO_EXPECT(WindowProcessSamples( &state, kFakeAudioData + kWindowSamples, sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, &num_samples_read)); @@ -124,38 +137,41 @@ TEST_F(WindowTest, CheckConsecutiveWindow) { const int16_t expected[] = { 0, -1152, 0, 5943, 0, -13312, 0, 21447, 0, -28328, 0, 32255, 0, -32256, 0, 28327, 0, -21448, 0, 13311, 0, -5944, 0, 1151, 0}; - ASSERT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); + TF_LITE_MICRO_EXPECT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); int i; for (i = 0; i < state.size; ++i) { - EXPECT_EQ(state.output[i], expected[i]); + TF_LITE_MICRO_EXPECT_EQ(state.output[i], expected[i]); } WindowFreeStateContents(&state); } -TEST_F(WindowTest, CheckNotEnoughSamples) { +TF_LITE_MICRO_TEST(WindowState_CheckNotEnoughSamples) { + WindowTestConfig config; struct WindowState state; - ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + TF_LITE_MICRO_EXPECT( + WindowPopulateState(&config.config_, &state, kSampleRate)); size_t num_samples_read; - ASSERT_TRUE(WindowProcessSamples( + TF_LITE_MICRO_EXPECT(WindowProcessSamples( &state, kFakeAudioData, sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); - ASSERT_TRUE(WindowProcessSamples( + TF_LITE_MICRO_EXPECT(WindowProcessSamples( &state, kFakeAudioData + kWindowSamples, sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, &num_samples_read)); - ASSERT_FALSE(WindowProcessSamples( - &state, 
kFakeAudioData + kWindowSamples + kStepSamples, - sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples - - kStepSamples, - &num_samples_read)); - - EXPECT_EQ( + TF_LITE_MICRO_EXPECT_EQ( + false, WindowProcessSamples( + &state, kFakeAudioData + kWindowSamples + kStepSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - + kWindowSamples - kStepSamples, + &num_samples_read)); + + TF_LITE_MICRO_EXPECT_EQ( state.input_used, sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - 2 * kStepSamples); WindowFreeStateContents(&state); } -} // namespace +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py b/tensorflow/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py index 3d49482f4ecd34ec47df1d3baa3e6dccf8ae4bef..52bad311436db5374b2c85da8cda30e962cf0dba 100644 --- a/tensorflow/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py +++ b/tensorflow/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py @@ -96,7 +96,7 @@ def audio_microfrontend(audio, Raises: ValueError: If the audio tensor is not explicitly a vector. 
""" - audio_shape = audio.get_shape() + audio_shape = audio.shape if audio_shape.ndims is None: raise ValueError("Input to `AudioMicrofrontend` should have known rank.") if len(audio_shape) > 1: diff --git a/tensorflow/lite/experimental/objc/BUILD.apple b/tensorflow/lite/experimental/objc/BUILD.apple new file mode 100644 index 0000000000000000000000000000000000000000..eef9e322311dd96738580a4af9ec61f6e2e745db --- /dev/null +++ b/tensorflow/lite/experimental/objc/BUILD.apple @@ -0,0 +1,110 @@ +# TensorFlow Lite for Objective-C + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") + +SOURCES = glob([ + "sources/*.h", + "sources/*.m", + "sources/*.mm", +]) + +API_HEADERS = glob([ + "apis/*.h", +]) + +MINIMUM_OS_VERSION = "9.0" + +# Compiler flags for building regular non-test libraries. +RELEASE_COPTS = [ + # Enables language-specific warnings for Objective-C, Objective-C++, C, and C++. + "-Wall", + # Warns if functions, variables, and types marked with the deprecated attribute are being used. + "-Wdeprecated-declarations", + # Warns for errors in documentation. + "-Wdocumentation", + # Turns all warnings into errors. + "-Werror", + # Enables extra warning flags that are not enabled by -Wall. + "-Wextra", + # Warns if a global function is defined without a previous prototype declaration. + "-Wmissing-prototypes", + # From -Wextra. Disables warning when signed value is converted to unsigned value during comparison. + "-Wno-sign-compare", + # From -Wextra. Disables warning for unused parameters, which are common in delegate methods and block callbacks. + "-Wno-unused-parameter", + # Warns if a global or local variable or type declaration shadows another variable, parameter, type, class member, or instance variable. + "-Wshadow", + # Warns if a function is declared or defined without specifying the argument types. For a block with no args, use (void) instead of (). 
+ "-Wstrict-prototypes", + # Warns if an @selector() expression is encountered with a method name that hasn't been defined yet. + "-Wundeclared-selector", + # Turn off warnings for headers not part of the TensorFlow Lite Objective-C API. + "--system-header-prefix=tensorflow/lite/experimental/c/", +] + +# Compiler flags for building test libraries. +TEST_COPTS = RELEASE_COPTS + [ + # From -Wall. Disables warning when passing nil to a callee that requires a non-null argument. + "-Wno-nonnull", + # Disables warning when a global or local variable or type declaration shadows another. + "-Wno-shadow", +] + +# Default tags for filtering targets. Targets in this file are restricted to Apple platforms. +DEFAULT_TAGS = [ + "apple", +] + +objc_library( + name = "TensorFlowLite", + srcs = SOURCES, + hdrs = API_HEADERS, + copts = RELEASE_COPTS, + tags = DEFAULT_TAGS, + deps = [ + "//tensorflow/lite/experimental/c:c_api", + ], + alwayslink = 1, +) + +ios_unit_test( + name = "TensorFlowLiteTests", + size = "small", + minimum_os_version = MINIMUM_OS_VERSION, + tags = DEFAULT_TAGS + [ + # These sanitizer tests are not supported by the iOS build toolchain (b/74292221). + # Disable them for iOS test targets.
+ "noasan", + "notsan", + "nomsan", + ], + deps = [ + ":TestsLib", + ], +) + +objc_library( + name = "TestsLib", + testonly = 1, + srcs = glob([ + "tests/*.m", + ]), + hdrs = glob([ + "apis/*.h", + "sources/*.h", + "tests/*.h", + ]), + copts = TEST_COPTS, + data = [ + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/add_quantized.bin", + ], + tags = DEFAULT_TAGS, + deps = [ + ":TensorFlowLite", + ], +) diff --git a/tensorflow/lite/experimental/objc/README.md b/tensorflow/lite/experimental/objc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9dd5aa697392d2867fb9d57159a5f9b79ba3847b --- /dev/null +++ b/tensorflow/lite/experimental/objc/README.md @@ -0,0 +1,54 @@ +# TensorFlow Lite for Objective-C + +[TensorFlow Lite](https://www.tensorflow.org/lite/) is TensorFlow's lightweight +solution for Objective-C developers. It enables low-latency inference of +on-device machine learning models with a small binary size and fast performance +supporting hardware acceleration. + +## Getting Started + +### Bazel + +In your `BUILD` file, add the `TensorFlowLite` dependency: + +```python +objc_library( + deps = [ + "//tensorflow/lite/experimental/objc:TensorFlowLite", + ], +) +``` + +If you would like to build the Objective-C TensorFlow Lite library using Bazel on Apple +platforms, clone or download the [TensorFlow GitHub repo](https://github.com/tensorflow/tensorflow), +then navigate to the root `tensorflow` directory and execute the `configure.py` script: + +```shell +python configure.py +``` + +Follow the prompts and when asked to configure the Bazel rules for Apple +platforms, enter `y`. 
+ +Build the `TensorFlowLite` Objective-C library target: + +```shell +bazel build tensorflow/lite/experimental/objc:TensorFlowLite +``` + +Build the `TensorFlowLiteTests` target: + +```shell +bazel test tensorflow/lite/experimental/objc:TensorFlowLiteTests +``` + +### Tulsi + +Open the `TensorFlowLite.tulsiproj` using the +[TulsiApp](https://github.com/bazelbuild/tulsi) or by running the +[`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh) +script from the root `tensorflow` directory: + +```shell +generate_xcodeproj.sh --genconfig tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj +``` diff --git a/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen b/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen new file mode 100644 index 0000000000000000000000000000000000000000..04a2a2c19cdb0bd259705e98b43ea6980305524e --- /dev/null +++ b/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen @@ -0,0 +1,60 @@ +{ + "sourceFilters" : [ + "tensorflow/lite", + "tensorflow/lite/experimental/c", + "tensorflow/lite/experimental/objc", + "tensorflow/lite/experimental/objc/apis", + "tensorflow/lite/experimental/objc/sources", + "tensorflow/lite/experimental/objc/tests", + "tensorflow/lite/kernels", + "tensorflow/lite/kernels/internal", + "tensorflow/lite/nnapi", + "tensorflow/lite/schema", + ], + "buildTargets" : [ + "//tensorflow/lite/experimental/objc:TensorFlowLite", + "//tensorflow/lite/experimental/objc:TensorFlowLiteTests", + ], + "projectName" : "TensorFlowLite", + "optionSet" : { + "LaunchActionPreActionScript" : { + "p" : "$(inherited)" + }, + "BazelBuildStartupOptionsRelease" : { + "p" : "$(inherited)" + }, + "BazelBuildOptionsRelease" : { + "p" : "$(inherited)" + }, + "BazelBuildOptionsDebug" : { + "p" : "$(inherited)" + }, 
+ "EnvironmentVariables" : { + "p" : "$(inherited)" + }, + "BuildActionPreActionScript" : { + "p" : "$(inherited)" + }, + "CommandlineArguments" : { + "p" : "$(inherited)" + }, + "TestActionPreActionScript" : { + "p" : "$(inherited)" + }, + "BazelBuildStartupOptionsDebug" : { + "p" : "$(inherited)" + }, + "BuildActionPostActionScript" : { + "p" : "$(inherited)" + }, + "TestActionPostActionScript" : { + "p" : "$(inherited)" + }, + "LaunchActionPostActionScript" : { + "p" : "$(inherited)" + } + }, + "additionalFilePaths" : [ + "tensorflow/lite/experimental/objc/BUILD", + ] +} diff --git a/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/project.tulsiconf b/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/project.tulsiconf new file mode 100644 index 0000000000000000000000000000000000000000..e92cb481386d88a8210b14f19ed92e7e47d9033f --- /dev/null +++ b/tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj/project.tulsiconf @@ -0,0 +1,17 @@ +{ + "configDefaults" : { + "optionSet" : { + "BazelBuildOptionsDebug" : { + + }, + "BazelBuildOptionsRelease" : { + + }, + } + }, + "projectName" : "TensorFlowLite", + "packages" : [ + "tensorflow/lite/experimental/objc" + ], + "workspaceRoot" : "../../../../.." +} diff --git a/tensorflow/lite/experimental/objc/apis/TFLInterpreter.h b/tensorflow/lite/experimental/objc/apis/TFLInterpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..3c06a4bc82f752baabdb6db100ee96e9ce29d29f --- /dev/null +++ b/tensorflow/lite/experimental/objc/apis/TFLInterpreter.h @@ -0,0 +1,179 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@class TFLInterpreterOptions; +@class TFLTensor; + +NS_ASSUME_NONNULL_BEGIN + +/** + * @enum TFLInterpreterErrorCode + * This enum specifies various error codes related to `TFLInterpreter`. + */ +typedef NS_ENUM(NSUInteger, TFLInterpreterErrorCode) { + /** Provided tensor index is invalid. */ + TFLInterpreterErrorCodeInvalidTensorIndex, + + /** Input data has invalid byte size. */ + TFLInterpreterErrorCodeInvalidInputByteSize, + + /** Provided shape is invalid. It must be a non-empty array of positive unsigned integers. */ + TFLInterpreterErrorCodeInvalidShape, + + /** Provided model cannot be loaded. */ + TFLInterpreterErrorCodeFailedToLoadModel, + + /** Failed to create `TFLInterpreter`. */ + TFLInterpreterErrorCodeFailedToCreateInterpreter, + + /** Failed to invoke `TFLInterpreter`. */ + TFLInterpreterErrorCodeFailedToInvoke, + + /** Failed to retrieve a tensor. */ + TFLInterpreterErrorCodeFailedToGetTensor, + + /** Invalid tensor. */ + TFLInterpreterErrorCodeInvalidTensor, + + /** Failed to resize an input tensor. */ + TFLInterpreterErrorCodeFailedToResizeInputTensor, + + /** Failed to copy data into an input tensor. */ + TFLInterpreterErrorCodeFailedToCopyDataToInputTensor, + + /** Copying data into an output tensor not allowed. */ + TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed, + + /** Failed to get data from a tensor. */ + TFLInterpreterErrorCodeFailedToGetDataFromTensor, + + /** Failed to allocate memory for tensors. 
*/ + TFLInterpreterErrorCodeFailedToAllocateTensors, + + /** Operation not allowed without allocating memory for tensors first. */ + TFLInterpreterErrorCodeAllocateTensorsRequired, + + /** Operation not allowed without invoking the interpreter first. */ + TFLInterpreterErrorCodeInvokeInterpreterRequired, +}; + +/** + * A TensorFlow Lite model interpreter. + */ +@interface TFLInterpreter : NSObject + +/** The total number of input tensors. 0 if the interpreter creation failed. */ +@property(nonatomic, readonly) NSUInteger inputTensorCount; + +/** The total number of output tensors. 0 if the interpreter creation failed. */ +@property(nonatomic, readonly) NSUInteger outputTensorCount; + +/** Unavailable. */ +- (instancetype)init NS_UNAVAILABLE; + +/** + * Initializes a new TensorFlow Lite interpreter instance with the given model file path and the + * default interpreter options. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * interpreter. + * + * @return A new instance of `TFLInterpreter` with the given model and the default interpreter + * options. `nil` if there is an error in initializing the interpreter. + */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Initializes a new TensorFlow Lite interpreter instance with the given model file path and + * options. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param options Options to use for configuring the TensorFlow Lite interpreter. + * @param error An optional error parameter populated when there is an error in initializing the + * interpreter. + * + * @return A new instance of `TFLInterpreter` with the given model and options. `nil` if there is an + * error in initializing the interpreter.
+ */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath + options:(TFLInterpreterOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Invokes the interpreter to run inference. + * + * @param error An optional error parameter populated when there is an error in invoking the + * interpreter. + * + * @return Whether the invocation is successful. Returns NO if an error occurred. + */ +- (BOOL)invokeWithError:(NSError **)error; + +/** + * Returns the input tensor at the given index. + * + * @param index The index of an input tensor. + * @param error An optional error parameter populated when there is an error in looking up the input + * tensor. + * + * @return The input tensor at the given index. `nil` if there is an error. See the `TFLTensor` + * class documentation for more details on the life expectancy between the returned tensor and + * this interpreter. + */ +- (nullable TFLTensor *)inputTensorAtIndex:(NSUInteger)index error:(NSError **)error; + +/** + * Returns the output tensor at the given index. + * + * @param index The index of an output tensor. + * @param error An optional error parameter populated when there is an error in looking up the + * output tensor. + * + * @return The output tensor at the given index. `nil` if there is an error. See the `TFLTensor` + * class documentation for more details on the life expectancy between the returned tensor and + * this interpreter. + */ +- (nullable TFLTensor *)outputTensorAtIndex:(NSUInteger)index error:(NSError **)error; + +/** + * Resizes the input tensor at the given index to the specified shape (an array of positive unsigned + * integers). + * + * @param index The index of an input tensor. + * @param shape Shape that the given input tensor should be resized to. It should be an array of + * positive unsigned integer(s) containing the size of each dimension. + * @param error An optional error parameter populated when there is an error in resizing the input + * tensor. 
+ * + * @return Whether the input tensor was resized successfully. Returns NO if an error occurred. + */ +- (BOOL)resizeInputTensorAtIndex:(NSUInteger)index + toShape:(NSArray *)shape + error:(NSError **)error; + +/** + * Allocates memory for tensors. + * + * @param error An optional error parameter populated when there is an error in allocating memory. + * + * @return Whether memory allocation is successful. Returns NO if an error occurred. + */ +- (BOOL)allocateTensorsWithError:(NSError **)error; + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h b/tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..6461fbf0178b1e72afb81e91d58109a2d7b0226b --- /dev/null +++ b/tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h @@ -0,0 +1,37 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** Custom configuration options for a TensorFlow Lite interpreter. */ +@interface TFLInterpreterOptions : NSObject + +/** + * Maximum number of threads that the interpreter should run on. Defaults to 0 (unspecified, letting + * TensorFlow Lite optimize the threading decision). + */ +@property(nonatomic) NSUInteger numberOfThreads; + +/** + * Initializes a new instance of `TFLInterpreterOptions`.
+ * + * @return A new instance of `TFLInterpreterOptions`. + */ +- (instancetype)init NS_DESIGNATED_INITIALIZER; + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h b/tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h new file mode 100644 index 0000000000000000000000000000000000000000..3d5cf793c5bed984debe3a36fdec4f0945cd7c64 --- /dev/null +++ b/tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h @@ -0,0 +1,36 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * Parameters for asymmetric quantization. Quantized values can be converted to float values using: + * `realValue = scale * (quantizedValue - zeroPoint)`. + */ +@interface TFLQuantizationParameters : NSObject + +/** Scale of asymmetric quantization. */ +@property(nonatomic, readonly) float scale; + +/** Zero point of asymmetric quantization. */ +@property(nonatomic, readonly) int32_t zeroPoint; + +/** Unavailable. 
*/ +- (instancetype)init NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow/lite/experimental/objc/apis/TFLTensor.h b/tensorflow/lite/experimental/objc/apis/TFLTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..dc710abf4e2ea99126be2fb359412287f3c37a33 --- /dev/null +++ b/tensorflow/lite/experimental/objc/apis/TFLTensor.h @@ -0,0 +1,111 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@class TFLQuantizationParameters; + +NS_ASSUME_NONNULL_BEGIN + +/** + * @enum TFLTensorDataType + * This enum specifies supported TensorFlow Lite tensor data types. + */ +typedef NS_ENUM(NSUInteger, TFLTensorDataType) { + /** Tensor data type not available. This indicates an error with the model. */ + TFLTensorDataTypeNoType, + + /** 32-bit single precision floating point. */ + TFLTensorDataTypeFloat32, + + /** 32-bit signed integer. */ + TFLTensorDataTypeInt32, + + /** 8-bit unsigned integer. */ + TFLTensorDataTypeUInt8, + + /** 64-bit signed integer. */ + TFLTensorDataTypeInt64, + + /** Boolean. */ + TFLTensorDataTypeBool, + + /** 16-bit signed integer. */ + TFLTensorDataTypeInt16, + + /** 8-bit signed integer. */ + TFLTensorDataTypeInt8, +}; + +/** + * An input or output tensor in a TensorFlow Lite model. + * + * @warning Each `TFLTensor` instance is associated with a `TFLInterpreter` instance. 
Multiple + * `TFLTensor` instances of the same TensorFlow Lite model are associated with the same + * `TFLInterpreter` instance. As long as a `TFLTensor` instance is still in use, its associated + * `TFLInterpreter` instance will not be deallocated. + */ +@interface TFLTensor : NSObject + +/** Name of the tensor. */ +@property(nonatomic, readonly, copy) NSString *name; + +/** Data type of the tensor. */ +@property(nonatomic, readonly) TFLTensorDataType dataType; + +/** Parameters for asymmetric quantization. `nil` if the tensor does not use quantization. */ +@property(nonatomic, readonly, nullable) TFLQuantizationParameters *quantizationParameters; + +/** Unavailable. */ +- (instancetype)init NS_UNAVAILABLE; + +/** + * Copies the given data into an input tensor. This is allowed only for an input tensor and only + * before the interpreter is invoked; otherwise an error will be returned. + * + * @param data The data to set. The byte size of the data must match what's required by the input + * tensor. + * @param error An optional error parameter populated when there is an error in copying the data. + * + * @return Whether the data was copied into the input tensor successfully. Returns NO if an error + * occurred. + */ +- (BOOL)copyData:(NSData *)data error:(NSError **)error; + +/** + * Retrieves a copy of data in the tensor. For an output tensor, the data is only available after + * the interpreter invocation has successfully completed; otherwise an error will be returned. + * + * @param error An optional error parameter populated when there is an error in retrieving the data. + * + * @return A copy of data in the tensor. `nil` if there is an error in retrieving the data or the + * data is not available. + */ +- (nullable NSData *)dataWithError:(NSError **)error; + +/** + * Retrieves the shape of the tensor, an array of positive unsigned integers containing the size + * of each dimension. 
For example: the shape of [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] is + * [2, 2, 3] (i.e. an array of 2 arrays of 2 arrays of 3 numbers). + * + * @param error An optional error parameter populated when there is an error in retrieving the + * shape. + * + * @return The shape of the tensor. `nil` if there is an error in retrieving the shape. + */ +- (nullable NSArray *)shapeWithError:(NSError **)error; + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow/lite/experimental/objc/sources/TFLErrorUtil.h b/tensorflow/lite/experimental/objc/sources/TFLErrorUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..ce8d50c896e6d5716308c7dc8818258fa38dbd72 --- /dev/null +++ b/tensorflow/lite/experimental/objc/sources/TFLErrorUtil.h @@ -0,0 +1,40 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Helper utility for error reporting. */ +@interface TFLErrorUtil : NSObject + +/** + * Creates and saves an interpreter error with the given error code and description. + * + * @param code Error code. + * @param description Error description. + * @param error Pointer to where to save the created error. If `nil`, no error will be saved. 
+ */ ++ (void)saveInterpreterErrorWithCode:(TFLInterpreterErrorCode)code + description:(NSString *)description + error:(NSError **)error; + +/** Unavailable. */ +- (instancetype)init NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow/lite/experimental/objc/sources/TFLErrorUtil.m b/tensorflow/lite/experimental/objc/sources/TFLErrorUtil.m new file mode 100644 index 0000000000000000000000000000000000000000..aa973c780060f4fa67573ff1e224ab0aed2bc92b --- /dev/null +++ b/tensorflow/lite/experimental/objc/sources/TFLErrorUtil.m @@ -0,0 +1,38 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "TFLErrorUtil.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Error domain of TensorFlow Lite interpreter related errors. 
*/ +static NSString *const TFLInterpreterErrorDomain = @"org.tensorflow.lite.interpreter"; + +@implementation TFLErrorUtil + +#pragma mark - Public + ++ (void)saveInterpreterErrorWithCode:(TFLInterpreterErrorCode)code + description:(NSString *)description + error:(NSError **)error { + if (error) { + *error = [NSError errorWithDomain:TFLInterpreterErrorDomain + code:code + userInfo:@{NSLocalizedDescriptionKey : description}]; + } +} + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow/lite/experimental/objc/sources/TFLInterpreter+Internal.h b/tensorflow/lite/experimental/objc/sources/TFLInterpreter+Internal.h new file mode 100644 index 0000000000000000000000000000000000000000..9b900c4f050451061a5d1a02b8be4dc51cade175 --- /dev/null +++ b/tensorflow/lite/experimental/objc/sources/TFLInterpreter+Internal.h @@ -0,0 +1,63 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h" + +@class TFLTensor; + +NS_ASSUME_NONNULL_BEGIN + +@interface TFLInterpreter (Internal) + +/** + * Copies the given data into the input tensor at the given index. This is allowed only before the + * interpreter is invoked. + * + * @param data The data to set. The byte size of the data must match what's required by the input + * tensor at the given index. + * @param index An input tensor index. 
+ * @param error An optional error parameter populated when there is an error in setting the data. + * + * @return Whether the data was copied into the input tensor at the given index successfully. + * Returns NO if an error occurred. + */ +- (BOOL)copyData:(NSData *)data toInputTensorAtIndex:(NSUInteger)index error:(NSError **)error; + +/** + * Retrieves a copy of the data from the given tensor. For an output tensor, the interpreter + * invocation has to complete before the data can be retrieved. + * + * @param tensor A tensor. + * @param error An optional error parameter populated when there is an error in getting the data. + * + * @return The data of the given tensor. `nil` if there is an error or data is not available. + */ +- (nullable NSData *)dataFromTensor:(TFLTensor *)tensor error:(NSError **)error; + +/** + * Retrieves the shape of the given tensor, an array of positive unsigned integer(s) containing the + * size of each dimension. For example: shape of [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] is + * [2, 2, 3]. + * + * @param tensor An input or output tensor. + * @param error An optional error parameter populated when there is an error in retrieving the + * shape. + * + * @return The shape of the tensor. `nil` if there is an error in retrieving the shape. + */ +- (nullable NSArray *)shapeOfTensor:(TFLTensor *)tensor error:(NSError **)error; + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm new file mode 100644 index 0000000000000000000000000000000000000000..a8ca982f6dd619f9a01bd67cc028ee6fb583a75d --- /dev/null +++ b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm @@ -0,0 +1,407 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h" + +#import "TFLErrorUtil.h" +#import "TFLQuantizationParameters+Internal.h" +#import "TFLTensor+Internal.h" +#import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h" +#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h" + +#include "tensorflow/lite/experimental/c/c_api.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Error reporter for TFLInterpreter. + * + * @param user_data User data. Not used. + * @param format Error message which may contain argument formatting specifiers. + * @param args Values of the arguments in the error message. + */ +static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_list args) { + NSLog(@"%@", [[NSString alloc] initWithFormat:@(format) arguments:args]); +} + +@interface TFLInterpreter () + +/** TFL_Interpreter backed by C API. 
*/ +@property(nonatomic, nullable) TFL_Interpreter *interpreter; + +@end + +@implementation TFLInterpreter + +#pragma mark - NSObject + +- (void)dealloc { + TFL_DeleteInterpreter(_interpreter); +} + +#pragma mark - Public + +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + return [self initWithModelPath:modelPath + options:[[TFLInterpreterOptions alloc] init] + error:error]; +} + +- (nullable instancetype)initWithModelPath:(NSString *)modelPath + options:(TFLInterpreterOptions *)options + error:(NSError **)error { + self = [super init]; + + if (self != nil) { + TFL_Model *model = nullptr; + TFL_InterpreterOptions *cOptions = nullptr; + + @try { + const char *modelPathCString = modelPath.UTF8String; + NSString *pathErrorString = + [NSString stringWithFormat:@"Cannot load model from path (%@).", modelPath]; + if (modelPathCString == nullptr) { + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToLoadModel + description:pathErrorString + error:error]; + return nil; + } + + model = TFL_NewModelFromFile(modelPathCString); + if (model == nullptr) { + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToLoadModel + description:pathErrorString + error:error]; + return nil; + } + + cOptions = TFL_NewInterpreterOptions(); + if (cOptions == nullptr) { + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter + description:@"Failed to create the interpreter." + error:error]; + return nil; + } + + if (options.numberOfThreads > 0) { + TFL_InterpreterOptionsSetNumThreads(cOptions, (int32_t)options.numberOfThreads); + } + TFL_InterpreterOptionsSetErrorReporter(cOptions, TFLInterpreterErrorReporter, nullptr); + + _interpreter = TFL_NewInterpreter(model, cOptions); + if (_interpreter == nullptr) { + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter + description:@"Failed to create the interpreter." 
+ error:error]; + return nil; + } + + _inputTensorCount = (NSUInteger)TFL_InterpreterGetInputTensorCount(_interpreter); + _outputTensorCount = (NSUInteger)TFL_InterpreterGetOutputTensorCount(_interpreter); + if (_inputTensorCount <= 0 || _outputTensorCount <= 0) { + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter + description:@"Failed to create the interpreter." + error:error]; + return nil; + } + } @finally { + TFL_DeleteInterpreterOptions(cOptions); + TFL_DeleteModel(model); + } + } + + return self; +} + +- (BOOL)invokeWithError:(NSError **)error { + if (TFL_InterpreterInvoke(self.interpreter) != kTfLiteOk) { + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToInvoke + description:@"Failed to invoke the interpreter." + error:error]; + return NO; + } + + return YES; +} + +- (nullable TFLTensor *)inputTensorAtIndex:(NSUInteger)index error:(NSError **)error { + if (![self isValidTensorIndex:index belowLimit:self.inputTensorCount error:error]) { + return nil; + } + + return [self tensorOfType:TFLTensorTypeInput atIndex:index error:error]; +} + +- (nullable TFLTensor *)outputTensorAtIndex:(NSUInteger)index error:(NSError **)error { + if (![self isValidTensorIndex:index belowLimit:self.outputTensorCount error:error]) { + return nil; + } + + return [self tensorOfType:TFLTensorTypeOutput atIndex:index error:error]; +} + +- (BOOL)resizeInputTensorAtIndex:(NSUInteger)index + toShape:(NSArray *)shape + error:(NSError **)error { + if (![self isValidTensorIndex:index belowLimit:self.inputTensorCount error:error]) { + return NO; + } + + if (shape.count == 0) { + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidShape + description:@"Invalid shape. Must not be empty." 
+ error:error]; + return NO; + } + + // The buffer holds one entry per requested dimension, so it is sized by `shape.count`. Sizing + // it by the input tensor count overflows the stack whenever the shape's rank is larger. + int cDimensions[shape.count]; + for (NSUInteger dimIndex = 0; dimIndex < shape.count; ++dimIndex) { + int dimension = shape[dimIndex].intValue; + if (dimension <= 0) { + NSString *errorDescription = @"Invalid shape. Dimensions must be positive integers."; + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidShape + description:errorDescription + error:error]; + return NO; + } + cDimensions[dimIndex] = dimension; + } + + if (TFL_InterpreterResizeInputTensor(self.interpreter, (int32_t)index, cDimensions, + (int32_t)shape.count) != kTfLiteOk) { + NSString *errorDescription = [NSString + stringWithFormat:@"Failed to resize input tensor at index (%lu).", (unsigned long)index]; + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToResizeInputTensor + description:errorDescription + error:error]; + return NO; + } + + return YES; +} + +- (BOOL)allocateTensorsWithError:(NSError **)error { + if (TFL_InterpreterAllocateTensors(self.interpreter) != kTfLiteOk) { + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToAllocateTensors + description:@"Failed to allocate memory for tensors."
+ error:error]; + return NO; + } + return YES; +} + +#pragma mark - TFLInterpreter (Internal) + +- (BOOL)copyData:(NSData *)data toInputTensorAtIndex:(NSUInteger)index error:(NSError **)error { + const TFL_Tensor *cTensor = [self cTensorOfType:TFLTensorTypeInput atIndex:index error:error]; + if (cTensor == nullptr) { + return NO; + } + + NSUInteger byteSize = (NSUInteger)TFL_TensorByteSize(cTensor); + if (data.length != byteSize) { + NSString *errorDescription = [NSString + stringWithFormat:@"Input tensor at index (%lu) expects data size (%lu), but got (%lu).", + (unsigned long)index, (unsigned long)byteSize, (unsigned long)data.length]; + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidInputByteSize + description:errorDescription + error:error]; + return NO; + } + + if (TFL_TensorCopyFromBuffer((TFL_Tensor *)cTensor, data.bytes, data.length) != kTfLiteOk) { + NSString *errorDescription = + [NSString stringWithFormat:@"Failed to copy data into input tensor at index (%lu).", + (unsigned long)index]; + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCopyDataToInputTensor + description:errorDescription + error:error]; + return NO; + } + + return YES; +} + +- (nullable NSData *)dataFromTensor:(TFLTensor *)tensor error:(NSError **)error { + const TFL_Tensor *cTensor = [self cTensorOfType:tensor.type atIndex:tensor.index error:error]; + if (cTensor == nullptr) { + return nil; + } + + void *bytes = TFL_TensorData(cTensor); + NSUInteger byteSize = (NSUInteger)TFL_TensorByteSize(cTensor); + if (bytes == nullptr || byteSize == 0) { + NSString *tensorType = [TFLTensor stringForTensorType:tensor.type]; + // Report the tensor's own index; this method has no `index` parameter or local. + NSString *errorDescription = + [NSString stringWithFormat:@"Failed to get data from %@ tensor at index (%lu).", tensorType, + (unsigned long)tensor.index]; + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToGetDataFromTensor + description:errorDescription + error:error]; + return nil; + } + + return [NSData
dataWithBytes:bytes length:byteSize]; +} + +- (nullable NSArray *)shapeOfTensor:(TFLTensor *)tensor error:(NSError **)error { + const TFL_Tensor *cTensor = [self cTensorOfType:tensor.type atIndex:tensor.index error:error]; + if (cTensor == nullptr) { + return nil; + } + + // Error messages below report the tensor's own index; this method has no `index` parameter. + NSString *tensorType = [TFLTensor stringForTensorType:tensor.type]; + int32_t rank = TFL_TensorNumDims(cTensor); + if (rank <= 0) { + NSString *errorDescription = + [NSString stringWithFormat:@"%@ tensor at index (%lu) has invalid rank (%d).", tensorType, + (unsigned long)tensor.index, rank]; + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor + description:errorDescription + error:error]; + return nil; + } + + NSMutableArray *shape = [NSMutableArray arrayWithCapacity:rank]; + for (int32_t dimIndex = 0; dimIndex < rank; dimIndex++) { + int32_t dimension = TFL_TensorDim(cTensor, dimIndex); + if (dimension <= 0) { + NSString *errorDescription = + [NSString stringWithFormat:@"%@ tensor at index (%lu) has invalid %d-th dimension (%d).", + tensorType, (unsigned long)tensor.index, dimIndex, dimension]; + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor + description:errorDescription + error:error]; + return nil; + } + shape[dimIndex] = @((NSUInteger)dimension); + } + + return shape; +} + +#pragma mark - Private + +- (const TFL_Tensor *)cTensorOfType:(TFLTensorType)type + atIndex:(NSUInteger)index + error:(NSError **)error { + const TFL_Tensor *tensor = nullptr; + + switch (type) { + case TFLTensorTypeInput: + tensor = TFL_InterpreterGetInputTensor(self.interpreter, (int32_t)index); + break; + case TFLTensorTypeOutput: + tensor = TFL_InterpreterGetOutputTensor(self.interpreter, (int32_t)index); + break; + } + + if (tensor == nullptr) { + NSString *tensorType = [TFLTensor stringForTensorType:type]; + NSString *errorDescription = + [NSString stringWithFormat:@"Failed to get %@ tensor at index (%lu).", tensorType, + (unsigned long)index]; + [TFLErrorUtil
saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToGetTensor + description:errorDescription + error:error]; + } + + return tensor; +} + +- (nullable TFLTensor *)tensorOfType:(TFLTensorType)type + atIndex:(NSUInteger)index + error:(NSError **)error { + const TFL_Tensor *tensor = [self cTensorOfType:type atIndex:index error:error]; + + if (tensor == nullptr) { + return nil; + } + + NSString *tensorType = [TFLTensor stringForTensorType:type]; + const char *cName = TFL_TensorName(tensor); + if (cName == nullptr) { + NSString *errorDescription = + [NSString stringWithFormat:@"Failed to get name of %@ tensor at index (%lu).", tensorType, + (unsigned long)index]; + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensor + description:errorDescription + error:error]; + return nil; + } + NSString *name = [NSString stringWithUTF8String:cName]; + + TFLTensorDataType dataType = [self tensorDataTypeFromCTensorType:TFL_TensorType(tensor)]; + + TFL_QuantizationParams cParams = TFL_TensorQuantizationParams(tensor); + TFLQuantizationParameters *quantizationParams; + + // TODO(b/119735362): Update this check once the TFL_QuantizationParams struct has a mode. + if (cParams.scale != 0.0) { + quantizationParams = [[TFLQuantizationParameters alloc] initWithScale:cParams.scale + zeroPoint:cParams.zero_point]; + } + + // TODO: Set quantization parameters when C API supports it. 
+ return [[TFLTensor alloc] initWithInterpreter:self + type:type + index:index + name:name + dataType:dataType + quantizationParameters:quantizationParams]; +} + +- (TFLTensorDataType)tensorDataTypeFromCTensorType:(TFL_Type)cTensorType { + switch (cTensorType) { + case kTfLiteFloat32: + return TFLTensorDataTypeFloat32; + case kTfLiteInt32: + return TFLTensorDataTypeInt32; + case kTfLiteUInt8: + return TFLTensorDataTypeUInt8; + case kTfLiteInt8: + return TFLTensorDataTypeInt8; + case kTfLiteInt64: + return TFLTensorDataTypeInt64; + case kTfLiteBool: + return TFLTensorDataTypeBool; + case kTfLiteInt16: + return TFLTensorDataTypeInt16; + case kTfLiteNoType: + case kTfLiteString: + case kTfLiteComplex64: + // kTfLiteString and kTfLiteComplex64 are not supported in TensorFlow Lite Objc API. + return TFLTensorDataTypeNoType; + } +} + +- (BOOL)isValidTensorIndex:(NSUInteger)index + belowLimit:(NSUInteger)totalTensorCount + error:(NSError **)error { + if (index >= totalTensorCount) { + NSString *errorDescription = + [NSString stringWithFormat:@"Invalid tensor index (%lu) exceeds max (%lu).", + (unsigned long)index, (unsigned long)(totalTensorCount - 1)]; + [TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeInvalidTensorIndex + description:errorDescription + error:error]; + return NO; + } + + return YES; +} + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow/lite/experimental/objc/sources/TFLInterpreterOptions.m b/tensorflow/lite/experimental/objc/sources/TFLInterpreterOptions.m new file mode 100644 index 0000000000000000000000000000000000000000..d129befecabc5af752ccff70e84a4a66c7ee4bca --- /dev/null +++ b/tensorflow/lite/experimental/objc/sources/TFLInterpreterOptions.m @@ -0,0 +1,30 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at:
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+@implementation TFLInterpreterOptions
+
+#pragma mark - Public
+
+- (instancetype)init {
+  self = [super init];  // No extra setup here; the override only forwards to the superclass.
+  return self;
+}
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow/lite/experimental/objc/sources/TFLQuantizationParameters+Internal.h b/tensorflow/lite/experimental/objc/sources/TFLQuantizationParameters+Internal.h
new file mode 100644
index 0000000000000000000000000000000000000000..37d9ef0bb4761c9ff93111ba3158d4c4d68a9ec2
--- /dev/null
+++ b/tensorflow/lite/experimental/objc/sources/TFLQuantizationParameters+Internal.h
@@ -0,0 +1,33 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+@interface TFLQuantizationParameters (Internal)
+
+/**
+ * Initializes a `TFLQuantizationParameters` instance that holds the given scale and zero point.
+ *
+ * @param scale Scale of asymmetric quantization.
+ * @param zeroPoint Zero point of asymmetric quantization.
+ *
+ * @return A new instance of `TFLQuantizationParameters` with the given scale and zero point.
+ */
+- (instancetype)initWithScale:(float)scale zeroPoint:(int32_t)zeroPoint;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow/lite/experimental/objc/sources/TFLQuantizationParameters.m b/tensorflow/lite/experimental/objc/sources/TFLQuantizationParameters.m
new file mode 100644
index 0000000000000000000000000000000000000000..44cb90d3323a73c1f79a27f319ac263c84e94408
--- /dev/null
+++ b/tensorflow/lite/experimental/objc/sources/TFLQuantizationParameters.m
@@ -0,0 +1,36 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h"
+
+#import "TFLQuantizationParameters+Internal.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+@implementation TFLQuantizationParameters
+
+#pragma mark - TFLQuantizationParameters (Internal)
+
+- (instancetype)initWithScale:(float)scale zeroPoint:(int32_t)zeroPoint {
+  self = [super init];
+  if (self != nil) {
+    _scale = scale;          // Stored as given; no validation is performed here.
+    _zeroPoint = zeroPoint;  // Stored as given; no validation is performed here.
+  }
+  return self;
+}
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow/lite/experimental/objc/sources/TFLTensor+Internal.h b/tensorflow/lite/experimental/objc/sources/TFLTensor+Internal.h
new file mode 100644
index 0000000000000000000000000000000000000000..3d5c51caabd8e44ab2b30a7b44259f6878865586
--- /dev/null
+++ b/tensorflow/lite/experimental/objc/sources/TFLTensor+Internal.h
@@ -0,0 +1,74 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
+
+@class TFLInterpreter;
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * @enum TFLTensorType
+ * This enum specifies whether a tensor is an input or an output of the model.
+ */
+typedef NS_ENUM(NSUInteger, TFLTensorType) {
+  /** Input tensor type. */
+  TFLTensorTypeInput,
+
+  /** Output tensor type. */
+  TFLTensorTypeOutput,
+};
+
+@interface TFLTensor (Internal)
+
+/** Input or output tensor type. */
+@property(nonatomic, readonly) TFLTensorType type;
+
+/** Index of the tensor among the interpreter's tensors of the same type. */
+@property(nonatomic, readonly) NSUInteger index;
+
+/**
+ * Initializes a `TFLTensor` with the given interpreter, type, index, name, data type, and
+ * quantization parameters.
+ *
+ * @param interpreter Interpreter backing the tensor.
+ * @param type Input or output tensor type.
+ * @param index Index of the tensor.
+ * @param name Name of the tensor.
+ * @param dataType Data type of the tensor.
+ * @param quantizationParameters Quantization parameters of the tensor. `nil` if the tensor does not
+ *     use quantization.
+ *
+ * @return A new instance of `TFLTensor` with the given interpreter, type, index, name, data type,
+ *     and quantization parameters.
+ */
+- (instancetype)initWithInterpreter:(TFLInterpreter *)interpreter
+                               type:(TFLTensorType)type
+                              index:(NSUInteger)index
+                               name:(NSString *)name
+                           dataType:(TFLTensorDataType)dataType
+             quantizationParameters:(nullable TFLQuantizationParameters *)quantizationParameters;
+
+/**
+ * Returns the string name of the given input or output tensor type.
+ *
+ * @param type Input or output tensor type.
+ *
+ * @return The string name of the given input or output tensor type.
+ */
++ (NSString *)stringForTensorType:(TFLTensorType)type;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow/lite/experimental/objc/sources/TFLTensor.m b/tensorflow/lite/experimental/objc/sources/TFLTensor.m
new file mode 100644
index 0000000000000000000000000000000000000000..2eaebfd6bec0483817bd4c1c3e540113cca75f5e
--- /dev/null
+++ b/tensorflow/lite/experimental/objc/sources/TFLTensor.m
@@ -0,0 +1,103 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
+
+#import "TFLErrorUtil.h"
+#import "TFLInterpreter+Internal.h"
+#import "TFLTensor+Internal.h"
+
+#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+// String names of input or output tensor types, as returned by `+stringForTensorType:`.
+static NSString *const kTFLInputTensorTypeString = @"input";
+static NSString *const kTFLOutputTensorTypeString = @"output";
+
+@interface TFLTensor ()
+
+// Redeclares the readonly properties without `readonly` for use inside this implementation.
+@property(nonatomic) TFLTensorType type;
+@property(nonatomic) NSUInteger index;
+@property(nonatomic, copy) NSString *name;
+@property(nonatomic) TFLTensorDataType dataType;
+@property(nonatomic, nullable) TFLQuantizationParameters *quantizationParameters;
+
+/**
+ * The backing interpreter. It's a strong reference to ensure that the interpreter is never released
+ * before this tensor is released.
+ *
+ * @warning Never let the interpreter hold a strong reference to the tensor to avoid retain cycles.
+ */
+@property(nonatomic) TFLInterpreter *interpreter;
+
+@end
+
+@implementation TFLTensor
+
+#pragma mark - Public
+
+- (BOOL)copyData:(NSData *)data error:(NSError **)error {
+  if (self.type == TFLTensorTypeOutput) {  // Only input tensors accept data.
+    [TFLErrorUtil
+        saveInterpreterErrorWithCode:TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed
+                         description:@"Cannot copy data into an output tensor."
+                               error:error];
+    return NO;
+  }
+
+  // The backing interpreter performs the size check and the actual copy.
+  return [self.interpreter copyData:data toInputTensorAtIndex:self.index error:error];
+}
+
+- (nullable NSData *)dataWithError:(NSError **)error {
+  return [self.interpreter dataFromTensor:self error:error];  // Delegates to the interpreter.
+}
+
+- (nullable NSArray *)shapeWithError:(NSError **)error {
+  return [self.interpreter shapeOfTensor:self error:error];  // Delegates to the interpreter.
+}
+
+#pragma mark - TFLTensor (Internal)
+
+- (instancetype)initWithInterpreter:(TFLInterpreter *)interpreter
+                               type:(TFLTensorType)type
+                              index:(NSUInteger)index
+                               name:(NSString *)name
+                           dataType:(TFLTensorDataType)dataType
+             quantizationParameters:(nullable TFLQuantizationParameters *)quantizationParameters {
+  self = [super init];
+  if (self != nil) {
+    _interpreter = interpreter;  // Strong reference; see the property comment above.
+    _type = type;
+    _index = index;
+    _name = [name copy];  // Copied so a mutable string passed in cannot change the name later.
+    _dataType = dataType;
+    _quantizationParameters = quantizationParameters;
+  }
+  return self;
+}
+
++ (NSString *)stringForTensorType:(TFLTensorType)type {
+  switch (type) {
+    case TFLTensorTypeInput:
+      return kTFLInputTensorTypeString;
+    case TFLTensorTypeOutput:
+      return kTFLOutputTensorTypeString;
+  }  // NOTE(review): nothing is returned after the switch if `type` matches neither case.
+}
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow/lite/experimental/objc/tests/TFLInterpreterOptionsTests.m b/tensorflow/lite/experimental/objc/tests/TFLInterpreterOptionsTests.m
new file mode 100644
index 0000000000000000000000000000000000000000..00b800d6af96636054f2a79f3d4c8d007dd89ea3
--- /dev/null
+++ b/tensorflow/lite/experimental/objc/tests/TFLInterpreterOptionsTests.m
@@ -0,0 +1,49 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h"
+
+#import <XCTest/XCTest.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Unit tests for TFLInterpreterOptions.
+ */
+@interface TFLInterpreterOptionsTests : XCTestCase
+@end
+
+@implementation TFLInterpreterOptionsTests
+
+#pragma mark - Tests
+
+- (void)testInit {
+  TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
+  XCTAssertNotNil(options);
+  XCTAssertEqual(options.numberOfThreads, 0);  // A fresh instance reports zero threads.
+}
+
+- (void)testSetNumberOfThread {
+  TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
+  options.numberOfThreads = 2;
+  XCTAssertEqual(options.numberOfThreads, 2);
+  options.numberOfThreads = 0;  // Setting the value back to zero must also be honored.
+  XCTAssertEqual(options.numberOfThreads, 0);
+  options.numberOfThreads = 3;
+  XCTAssertEqual(options.numberOfThreads, 3);
+}
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow/lite/experimental/objc/tests/TFLInterpreterTests.m b/tensorflow/lite/experimental/objc/tests/TFLInterpreterTests.m
new file mode 100644
index 0000000000000000000000000000000000000000..eefa9b9f05826a0782c0b236a2d7e145428b1ca1
--- /dev/null
+++ b/tensorflow/lite/experimental/objc/tests/TFLInterpreterTests.m
@@ -0,0 +1,358 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "tensorflow/lite/experimental/objc/apis/TFLInterpreter.h"
+
+#import <XCTest/XCTest.h>
+
+#import "tensorflow/lite/experimental/objc/apis/TFLInterpreterOptions.h"
+#import "tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h"
+#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/** Float model resource name. */
+static NSString *const kAddFloatModelResourceName = @"add";
+
+/** Quantized model resource name. */
+static NSString *const kAddQuantizedModelResourceName = @"add_quantized";
+
+/** Model resource type. */
+static NSString *const kAddModelResourceType = @"bin";
+
+/** Rank of the input and output tensor in the Add model. */
+static const NSUInteger kAddModelTensorRank = 1U;
+
+/** Size of the first (and only) dimension of the input and output tensor in the Add model. */
+static const NSUInteger kAddModelTensorFirstDimensionSize = 2U;
+
+/** Quantization scale of the quantized model. */
+static const float kAddQuantizedModelScale = 0.003922F;
+
+/** Quantization zero point of the quantized model. */
+static const int32_t kAddQuantizedModelZeroPoint = 0;
+
+/** Invalid input tensor index (the tests expect exactly one input tensor, at index 0). */
+static const NSUInteger kInvalidInputTensorIndex = 1U;
+
+/** Invalid output tensor index (the tests expect exactly one output tensor, at index 0). */
+static const NSUInteger kInvalidOutputTensorIndex = 1U;
+
+/** Accuracy used when comparing floating-point numbers. */
+static const float kTestAccuracy = 1E-5F;
+
+/**
+ * Unit tests for TFLInterpreter.
+ */
+@interface TFLInterpreterTests : XCTestCase
+
+/** Absolute path of the Add float model resource. */
+@property(nonatomic, nullable) NSString *floatModelPath;
+
+/** Default interpreter using the Add model. */
+@property(nonatomic, nullable) TFLInterpreter *interpreter;
+
+@end
+
+@implementation TFLInterpreterTests
+
+#pragma mark - XCTestCase
+
+- (void)setUp {
+  [super setUp];
+
+  NSBundle *bundle = [NSBundle bundleForClass:[self class]];
+  self.floatModelPath = [bundle pathForResource:kAddFloatModelResourceName
+                                         ofType:kAddModelResourceType];
+  NSError *error;
+  self.interpreter = [[TFLInterpreter alloc] initWithModelPath:self.floatModelPath error:&error];
+  XCTAssertNil(error);
+  XCTAssertNotNil(self.interpreter);
+  XCTAssertTrue([self.interpreter allocateTensorsWithError:nil]);
+}
+
+- (void)tearDown {
+  self.floatModelPath = nil;
+  self.interpreter = nil;
+
+  [super tearDown];
+}
+
+#pragma mark - Tests
+
+- (void)testSuccessfulFullRunAddFloatModel {
+  // Shape for both input and output tensor.
+  NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
+  shape[0] = [NSNumber numberWithUnsignedInteger:kAddModelTensorFirstDimensionSize];
+
+  // Creates the interpreter options.
+  TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
+  XCTAssertNotNil(options);
+  options.numberOfThreads = 2;
+
+  // Creates the interpreter.
+  NSError *error;
+  TFLInterpreter *customInterpreter = [[TFLInterpreter alloc] initWithModelPath:self.floatModelPath
+                                                                        options:options
+                                                                          error:&error];
+  XCTAssertNil(error);
+  XCTAssertNotNil(customInterpreter);
+
+  // Allocates memory for tensors.
+  XCTAssertTrue([customInterpreter allocateTensorsWithError:&error]);
+  XCTAssertNil(error);
+
+  // Verifies input and output tensor counts.
+  XCTAssertEqual(customInterpreter.inputTensorCount, 1);
+  XCTAssertEqual(customInterpreter.outputTensorCount, 1);
+
+  // Resizes the input tensor.
+  XCTAssertTrue([customInterpreter resizeInputTensorAtIndex:0 toShape:shape error:&error]);
+  XCTAssertNil(error);
+
+  // Re-allocates memory for tensors.
+  XCTAssertTrue([customInterpreter allocateTensorsWithError:&error]);
+  XCTAssertNil(error);
+
+  // Verifies the input tensor.
+  TFLTensor *inputTensor = [customInterpreter inputTensorAtIndex:0 error:&error];
+  XCTAssertNotNil(inputTensor);
+  XCTAssertNil(error);
+  XCTAssertTrue([inputTensor.name isEqualToString:@"input"]);
+  XCTAssertEqual(inputTensor.dataType, TFLTensorDataTypeFloat32);
+  NSArray *inputTensorShape = [inputTensor shapeWithError:&error];
+  XCTAssertNil(error);
+  XCTAssertTrue([shape isEqualToArray:inputTensorShape]);
+
+  // Copies the input data.
+  NSMutableData *inputData = [NSMutableData dataWithCapacity:0];
+  float one = 1.f;
+  float three = 3.f;
+  [inputData appendBytes:&one length:sizeof(float)];
+  [inputData appendBytes:&three length:sizeof(float)];
+  XCTAssertTrue([inputTensor copyData:inputData error:&error]);
+  XCTAssertNil(error);
+
+  // Invokes the interpreter.
+  XCTAssertTrue([customInterpreter invokeWithError:&error]);
+  XCTAssertNil(error);
+
+  // Verifies the output tensor.
+  TFLTensor *outputTensor = [customInterpreter outputTensorAtIndex:0 error:&error];
+  XCTAssertNotNil(outputTensor);
+  XCTAssertNil(error);
+  XCTAssertTrue([outputTensor.name isEqualToString:@"output"]);
+  XCTAssertEqual(outputTensor.dataType, TFLTensorDataTypeFloat32);
+  NSArray *outputTensorShape = [outputTensor shapeWithError:&error];
+  XCTAssertNil(error);
+  XCTAssertTrue([shape isEqualToArray:outputTensorShape]);
+
+  // Tries to query an invalid output tensor index.
+  TFLTensor *invalidOutputTensor = [customInterpreter outputTensorAtIndex:kInvalidOutputTensorIndex
+                                                                    error:&error];
+  XCTAssertNil(invalidOutputTensor);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidTensorIndex);
+
+  // Gets the output tensor data. The error from the failed lookup above is cleared first.
+  error = nil;
+  NSData *outputData = [outputTensor dataWithError:&error];
+  XCTAssertNotNil(outputData);
+  XCTAssertNil(error);
+  float output[kAddModelTensorFirstDimensionSize];
+  [outputData getBytes:output length:(sizeof(float) * kAddModelTensorFirstDimensionSize)];
+  XCTAssertEqualWithAccuracy(output[0], 3.f, kTestAccuracy);
+  XCTAssertEqualWithAccuracy(output[1], 9.f, kTestAccuracy);
+}
+
+- (void)testSuccessfulFullRunQuantizedModel {
+  // Shape for both input and output tensor.
+  NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
+  shape[0] = [NSNumber numberWithUnsignedInteger:kAddModelTensorFirstDimensionSize];
+
+  // Creates the interpreter options.
+  TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
+  XCTAssertNotNil(options);
+  options.numberOfThreads = 2;
+
+  NSBundle *bundle = [NSBundle bundleForClass:[self class]];
+  NSString *quantizedModelPath = [bundle pathForResource:kAddQuantizedModelResourceName
+                                                  ofType:kAddModelResourceType];
+
+  // Creates the interpreter.
+  NSError *error;
+  TFLInterpreter *customInterpreter =
+      [[TFLInterpreter alloc] initWithModelPath:quantizedModelPath options:options error:&error];
+  XCTAssertNil(error);
+  XCTAssertNotNil(customInterpreter);
+
+  // Allocates memory for tensors.
+  XCTAssertTrue([customInterpreter allocateTensorsWithError:&error]);
+  XCTAssertNil(error);
+
+  // Verifies input and output tensor counts.
+  XCTAssertEqual(customInterpreter.inputTensorCount, 1);
+  XCTAssertEqual(customInterpreter.outputTensorCount, 1);
+
+  // Resizes the input tensor.
+  XCTAssertTrue([customInterpreter resizeInputTensorAtIndex:0 toShape:shape error:&error]);
+  XCTAssertNil(error);
+
+  // Re-allocates memory for tensors.
+  XCTAssertTrue([customInterpreter allocateTensorsWithError:&error]);
+  XCTAssertNil(error);
+
+  // Verifies the input tensor.
+  TFLTensor *inputTensor = [customInterpreter inputTensorAtIndex:0 error:&error];
+  XCTAssertNotNil(inputTensor);
+  XCTAssertNil(error);
+  XCTAssertTrue([inputTensor.name isEqualToString:@"input"]);
+  XCTAssertEqual(inputTensor.dataType, TFLTensorDataTypeUInt8);
+  XCTAssertEqualWithAccuracy(inputTensor.quantizationParameters.scale, kAddQuantizedModelScale,
+                             kTestAccuracy);
+  XCTAssertEqual(inputTensor.quantizationParameters.zeroPoint, kAddQuantizedModelZeroPoint);
+  NSArray *inputTensorShape = [inputTensor shapeWithError:&error];
+  XCTAssertNil(error);
+  XCTAssertTrue([shape isEqualToArray:inputTensorShape]);
+
+  // Copies the input data.
+  NSMutableData *inputData = [NSMutableData dataWithCapacity:0];
+  uint8_t one = 1;
+  uint8_t three = 3;
+  [inputData appendBytes:&one length:sizeof(uint8_t)];
+  [inputData appendBytes:&three length:sizeof(uint8_t)];
+  XCTAssertTrue([inputTensor copyData:inputData error:&error]);
+  XCTAssertNil(error);
+
+  // Invokes the interpreter.
+  XCTAssertTrue([customInterpreter invokeWithError:&error]);
+  XCTAssertNil(error);
+
+  // Verifies the output tensor.
+  TFLTensor *outputTensor = [customInterpreter outputTensorAtIndex:0 error:&error];
+  XCTAssertNotNil(outputTensor);
+  XCTAssertNil(error);
+  XCTAssertTrue([outputTensor.name isEqualToString:@"output"]);
+  XCTAssertEqual(outputTensor.dataType, TFLTensorDataTypeUInt8);
+  XCTAssertEqualWithAccuracy(outputTensor.quantizationParameters.scale, kAddQuantizedModelScale,
+                             kTestAccuracy);
+  XCTAssertEqual(outputTensor.quantizationParameters.zeroPoint, kAddQuantizedModelZeroPoint);
+  NSArray *outputTensorShape = [outputTensor shapeWithError:&error];
+  XCTAssertNil(error);
+  XCTAssertTrue([shape isEqualToArray:outputTensorShape]);
+
+  // Tries to query an invalid output tensor index.
+  TFLTensor *invalidOutputTensor = [customInterpreter outputTensorAtIndex:kInvalidOutputTensorIndex
+                                                                    error:&error];
+  XCTAssertNil(invalidOutputTensor);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidTensorIndex);
+
+  // Gets the output tensor data. The error from the failed lookup above is cleared first.
+  error = nil;
+  NSData *outputData = [outputTensor dataWithError:&error];
+  XCTAssertNotNil(outputData);
+  XCTAssertNil(error);
+  uint8_t output[kAddModelTensorFirstDimensionSize];
+  [outputData getBytes:output length:(sizeof(uint8_t) * kAddModelTensorFirstDimensionSize)];
+  XCTAssertEqual(output[0], 3);
+  XCTAssertEqual(output[1], 9);
+}
+
+- (void)testInitWithModelPath_invalidPath {
+  // Shape for both input and output tensor. NOTE(review): `shape` is not used by this test.
+  NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
+  shape[0] = [NSNumber numberWithUnsignedInteger:kAddModelTensorFirstDimensionSize];
+
+  // Creates the interpreter.
+  NSError *error;
+  TFLInterpreter *brokenInterpreter = [[TFLInterpreter alloc] initWithModelPath:@"InvalidPath"
+                                                                          error:&error];
+  XCTAssertNil(brokenInterpreter);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeFailedToLoadModel);
+}
+
+- (void)testInvoke_beforeAllocation {
+  NSError *error;
+  TFLInterpreter *interpreterWithoutAllocation =
+      [[TFLInterpreter alloc] initWithModelPath:self.floatModelPath error:&error];
+  XCTAssertNotNil(interpreterWithoutAllocation);
+  XCTAssertNil(error);
+
+  XCTAssertFalse([interpreterWithoutAllocation invokeWithError:&error]);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeFailedToInvoke);
+}
+
+- (void)testInputTensorAtIndex_invalidIndex {
+  NSError *error;
+  TFLTensor *inputTensor = [self.interpreter inputTensorAtIndex:kInvalidInputTensorIndex
+                                                          error:&error];
+  XCTAssertNil(inputTensor);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidTensorIndex);
+}
+
+- (void)testResizeInputTensorAtIndex_invalidIndex {
+  NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
+  shape[0] = [NSNumber numberWithUnsignedInteger:kAddModelTensorFirstDimensionSize];
+  NSError *error;
+  XCTAssertFalse([self.interpreter resizeInputTensorAtIndex:kInvalidInputTensorIndex
+                                                    toShape:shape
+                                                      error:&error]);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidTensorIndex);
+}
+
+- (void)testResizeInputTensorAtIndex_emptyShape {
+  NSMutableArray *emptyShape = [NSMutableArray arrayWithCapacity:0];
+  NSError *error;
+  XCTAssertFalse([self.interpreter resizeInputTensorAtIndex:0 toShape:emptyShape error:&error]);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidShape);
+}
+
+- (void)testResizeInputTensorAtIndex_zeroDimensionSize {
+  NSMutableArray *shape = [NSMutableArray arrayWithCapacity:kAddModelTensorRank];
+  shape[0] = [NSNumber numberWithUnsignedInteger:0];
+  NSError *error;
+  XCTAssertFalse([self.interpreter resizeInputTensorAtIndex:0 toShape:shape error:&error]);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidShape);
+}
+
+- (void)testCopyDataToInputTensorAtIndex_invalidInputDataByteSize {
+  NSMutableData *inputData = [NSMutableData dataWithCapacity:0];
+  float one = 1.f;
+  float three = 3.f;
+  [inputData appendBytes:&one length:sizeof(float)];
+  [inputData appendBytes:&three length:(sizeof(float) - 1)];  // One byte short on purpose.
+  NSError *error;
+  TFLTensor *inputTensor = [self.interpreter inputTensorAtIndex:0 error:&error];
+  XCTAssertNotNil(inputTensor);
+  XCTAssertNil(error);
+  XCTAssertFalse([inputTensor copyData:inputData error:&error]);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeInvalidInputByteSize);
+}
+
+- (void)testCopyDataToOutputTensorAtIndex_notAllowed {
+  NSMutableData *data = [NSMutableData dataWithCapacity:0];
+  float one = 1.f;
+  float three = 3.f;
+  [data appendBytes:&one length:sizeof(float)];
+  [data appendBytes:&three length:(sizeof(float) - 1)];
+  NSError *error;
+  TFLTensor *outputTensor = [self.interpreter outputTensorAtIndex:0 error:&error];
+  XCTAssertNotNil(outputTensor);
+  XCTAssertNil(error);
+  XCTAssertFalse([outputTensor copyData:data error:&error]);
+  XCTAssertEqual(error.code, TFLInterpreterErrorCodeCopyDataToOutputTensorNotAllowed);
+}
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow/lite/experimental/objc/tests/TFLQuantizationParametersTests.m b/tensorflow/lite/experimental/objc/tests/TFLQuantizationParametersTests.m
new file mode 100644
index 0000000000000000000000000000000000000000..239e0bcb0dee8b6d2258be6f7e1ae2591611f501
--- /dev/null
+++ b/tensorflow/lite/experimental/objc/tests/TFLQuantizationParametersTests.m
@@ -0,0 +1,48 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "tensorflow/lite/experimental/objc/apis/TFLQuantizationParameters.h"
+
+#import <XCTest/XCTest.h>
+
+#import "tensorflow/lite/experimental/objc/sources/TFLQuantizationParameters+Internal.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/** Test scale of quantization parameters. */
+static const float kTestScale = 2.0;
+
+/** Test zero point of quantization parameters. */
+static const int32_t kTestZeroPoint = 128;
+
+/**
+ * Unit tests for TFLQuantizationParameters.
+ */
+@interface TFLQuantizationParametersTests : XCTestCase
+@end
+
+@implementation TFLQuantizationParametersTests
+
+#pragma mark - Tests
+
+- (void)testInitWithScaleAndZeroPoint {
+  TFLQuantizationParameters *params =
+      [[TFLQuantizationParameters alloc] initWithScale:kTestScale zeroPoint:kTestZeroPoint];
+  XCTAssertEqual(params.scale, kTestScale);
+  XCTAssertEqual(params.zeroPoint, kTestZeroPoint);
+}
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow/lite/experimental/swift/BUILD.apple b/tensorflow/lite/experimental/swift/BUILD.apple
new file mode 100644
index 0000000000000000000000000000000000000000..8f44546a4b8d98a023cda9eae33fba2ce87258b9
--- /dev/null
+++ b/tensorflow/lite/experimental/swift/BUILD.apple
@@ -0,0 +1,105 @@
+# TensorFlow Lite for Swift: library, unit-test, and sample-app targets.
+
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("@build_bazel_rules_apple//apple:ios.bzl", "ios_application", "ios_unit_test")
+load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
+
+MINIMUM_OS_VERSION = "9.0"
+
+# Default tags for filtering targets. Targets in this file are restricted to Apple platforms.
+DEFAULT_TAGS = [
+    "apple",
+]
+
+swift_library(
+    name = "TensorFlowLite",
+    srcs = glob(["Sources/*.swift"]),
+    module_name = "TensorFlowLite",
+    tags = DEFAULT_TAGS,
+    deps = [
+        "//tensorflow/lite/experimental/c:c_api",
+    ],
+)
+
+ios_unit_test(
+    name = "TensorFlowLiteTests",
+    size = "small",
+    minimum_os_version = MINIMUM_OS_VERSION,
+    tags = DEFAULT_TAGS + [
+        # DISABLED: The following sanitizer configurations are not supported by iOS test targets.
+        "noasan",
+        "nomsan",
+        "notsan",
+    ],
+    deps = [
+        ":TestsLib",
+    ],
+)
+
+ios_application(
+    name = "TensorFlowLiteApp",
+    app_icons = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/AppIcon.appiconset/**"]),
+    bundle_id = "com.tensorflow.lite.swift.TensorFlowLite",
+    families = [
+        "ipad",
+        "iphone",
+    ],
+    infoplists = ["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Info.plist"],
+    launch_storyboard = "TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/LaunchScreen.storyboard",
+    minimum_os_version = MINIMUM_OS_VERSION,
+    sdk_frameworks = [
+        "CoreGraphics",
+    ],
+    tags = DEFAULT_TAGS + ["manual"],
+    deps = [
+        ":AppLib",
+    ],
+)
+
+swift_library(
+    name = "TestsLib",
+    testonly = 1,
+    srcs = glob(["Tests/*.swift"]),
+    tags = DEFAULT_TAGS,
+    deps = [
+        ":Resources",
+        ":TensorFlowLite",
+    ],
+)
+
+swift_library(
+    name = "AppLib",
+    srcs = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/*.swift"]),
+    module_name = "AppLib",
+    tags = DEFAULT_TAGS + ["manual"],
+    deps = [
+        ":AppResources",
+        ":TensorFlowLite",
+    ],
+)
+
+objc_library(
+    name = "Resources",  # Model files from //tensorflow/lite:testdata, carried as data only.
+    data = [
+        "//tensorflow/lite:testdata/add.bin",
+        "//tensorflow/lite:testdata/add_quantized.bin",
+        "//tensorflow/lite:testdata/multi_add.bin",
+    ],
+    tags = DEFAULT_TAGS,
+)
+
+objc_library(
+    name = "AppResources",  # App storyboards, plus the model files via ":Resources".
+    data = glob([
+        "TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/*.storyboard",
+    ]),
+    tags = DEFAULT_TAGS + ["manual"],
+    deps = [
+        ":Resources",
+    ],
+)
diff --git a/tensorflow/lite/experimental/swift/LICENSE b/tensorflow/lite/experimental/swift/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7
--- /dev/null
+++ b/tensorflow/lite/experimental/swift/LICENSE
@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/tensorflow/lite/experimental/swift/README.md b/tensorflow/lite/experimental/swift/README.md new file mode 100644 index 0000000000000000000000000000000000000000..716ab33f8d6acaa4a4896f79c1b8e5662698c4c3 --- /dev/null +++ b/tensorflow/lite/experimental/swift/README.md @@ -0,0 +1,63 @@ +# TensorFlow Lite for Swift + +[TensorFlow Lite](https://www.tensorflow.org/lite/) is TensorFlow's lightweight +solution for Swift developers. It enables low-latency inference of on-device +machine learning models with a small binary size and fast performance supporting +hardware acceleration. 
+ +## Getting Started + +### Bazel + +In your `BUILD` file, add the `TensorFlowLite` dependency: + +```python +swift_library( + deps = [ + "//tensorflow/lite/experimental/swift:TensorFlowLite", + ], +) +``` + +In your Swift files, import the module: + +```swift +import TensorFlowLite +``` + +If you would like to build the Swift TensorFlow Lite library using Bazel on Apple +platforms, clone or download the [TensorFlow GitHub repo](https://github.com/tensorflow/tensorflow), +then navigate to the root `tensorflow` directory and execute the `configure.py` script: + +```shell +python configure.py +``` + +Follow the prompts and when asked to configure the Bazel rules for Apple +platforms, enter `y`. + +Build the `TensorFlowLite` Swift library target: + +```shell +bazel build tensorflow/lite/experimental/swift:TensorFlowLite +``` + +Build and run the `TensorFlowLiteTests` target: + +```shell +bazel test tensorflow/lite/experimental/swift:TensorFlowLiteTests --swiftcopt=-enable-testing +``` + +Note that `--swiftcopt=-enable-testing` is required for optimized builds (`-c opt`). + +### Tulsi + +Open the `TensorFlowLite.tulsiproj` using the +[TulsiApp](https://github.com/bazelbuild/tulsi) +or by running the +[`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh) +script from the root `tensorflow` directory: + +```shell +generate_xcodeproj.sh --genconfig tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj +``` diff --git a/tensorflow/lite/experimental/swift/Sources/Interpreter.swift b/tensorflow/lite/experimental/swift/Sources/Interpreter.swift new file mode 100644 index 0000000000000000000000000000000000000000..a14b5966b1a24946137fddae0ddea16ed43ba46c --- /dev/null +++ b/tensorflow/lite/experimental/swift/Sources/Interpreter.swift @@ -0,0 +1,265 @@ +// Copyright 2018 Google Inc. All rights reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import TensorFlowLiteCAPI + +/// A TensorFlow Lite interpreter that performs inference from a given model. +public final class Interpreter { + + /// The `TFL_Interpreter` C pointer type represented as an `UnsafePointer`. + private typealias CInterpreter = OpaquePointer + + /// Total number of input tensors associated with the model. + public var inputTensorCount: Int { + return Int(TFL_InterpreterGetInputTensorCount(cInterpreter)) + } + + /// Total number of output tensors associated with the model. + public var outputTensorCount: Int { + return Int(TFL_InterpreterGetOutputTensorCount(cInterpreter)) + } + + /// The underlying `TFL_Interpreter` C pointer. + private var cInterpreter: CInterpreter? + + /// Creates a new model interpreter instance. + /// + /// - Parameters: + /// - modelPath: Local file path to a TensorFlow Lite model. + /// - options: Custom configurations for the interpreter. The default is `nil` indicating that + /// interpreter will determine the configuration options. + /// - Throws: An error if the model could not be loaded or the interpreter could not be created. + public init(modelPath: String, options: InterpreterOptions? = nil) throws { + guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel } + + let cInterpreterOptions: OpaquePointer? 
= try options.map { options in + guard let cOptions = TFL_NewInterpreterOptions() else { + throw InterpreterError.failedToCreateInterpreter + } + if let threadCount = options.threadCount, threadCount > 0 { + TFL_InterpreterOptionsSetNumThreads(cOptions, Int32(threadCount)) + } + if options.isErrorLoggingEnabled { + TFL_InterpreterOptionsSetErrorReporter( + cOptions, + { (_, format, arguments) in + guard let cFormat = format, + let message = String(cFormat: cFormat, arguments: arguments) + else { + return + } + print(String(describing: InterpreterError.tensorFlowLiteError(message))) + }, + nil + ) + } + return cOptions + } + defer { TFL_DeleteInterpreterOptions(cInterpreterOptions) } + + guard let cInterpreter = TFL_NewInterpreter(model.cModel, cInterpreterOptions) else { + throw InterpreterError.failedToCreateInterpreter + } + self.cInterpreter = cInterpreter + } + + deinit { + TFL_DeleteInterpreter(cInterpreter) + } + + /// Invokes the interpreter to perform inference from the loaded graph. + /// + /// - Throws: An error if the model was not ready because tensors were not allocated. + public func invoke() throws { + guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else { + // TODO(b/117510052): Determine which error to throw. + throw InterpreterError.allocateTensorsRequired + } + } + + /// Returns the input tensor at the given index. + /// + /// - Parameters: + /// - index: The index for the input tensor. + /// - Throws: An error if the index is invalid or the tensors have not been allocated. + /// - Returns: The input tensor at the given index. 
+ public func input(at index: Int) throws -> Tensor { + let maxIndex = inputTensorCount - 1 + guard case 0...maxIndex = index else { + throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) + } + guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)), + let bytes = TFL_TensorData(cTensor), + let nameCString = TFL_TensorName(cTensor) + else { + throw InterpreterError.allocateTensorsRequired + } + guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else { + throw InterpreterError.invalidTensorDataType + } + + let name = String(cString: nameCString) + let rank = TFL_TensorNumDims(cTensor) + let dimensions = (0.. Tensor { + let maxIndex = outputTensorCount - 1 + guard case 0...maxIndex = index else { + throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) + } + guard let cTensor = TFL_InterpreterGetOutputTensor(cInterpreter, Int32(index)), + let bytes = TFL_TensorData(cTensor), + let nameCString = TFL_TensorName(cTensor) + else { + // TODO(b/117510052): Determine which error to throw. + throw InterpreterError.invokeInterpreterRequired + } + guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else { + throw InterpreterError.invalidTensorDataType + } + + let name = String(cString: nameCString) + let rank = TFL_TensorNumDims(cTensor) + let dimensions = (0.. 
Tensor { + let maxIndex = inputTensorCount - 1 + guard case 0...maxIndex = index else { + throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) + } + guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)) else { + throw InterpreterError.allocateTensorsRequired + } + + let byteCount = TFL_TensorByteSize(cTensor) + guard data.count == byteCount else { + throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount) + } + + let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) } + guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor } + return try input(at: index) + } + + /// Allocates memory for all input tensors based on their `TensorShape`s. + /// + /// - Note: This is a relatively expensive operation and should only be called after creating the + /// interpreter and/or resizing any input tensors. + /// - Throws: An error if memory could not be allocated for the input tensors. + public func allocateTensors() throws { + guard TFL_InterpreterAllocateTensors(cInterpreter) == kTfLiteOk else { + throw InterpreterError.failedToAllocateTensors + } + } +} + +// MARK: - Extensions + +extension String { + /// Returns a new `String` initialized by using the given format C array as a template into which + /// the remaining argument values are substituted according to the user’s default locale. + /// + /// - Note: Returns `nil` if a new `String` could not be constructed from the given values. + /// - Parameters: + /// - cFormat: The format C array as a template for substituting values. + /// - arguments: A C pointer to a `va_list` of arguments to substitute into `cFormat`. + init?(cFormat: UnsafePointer, arguments: CVaListPointer) { + var buffer: UnsafeMutablePointer? 
+ guard vasprintf(&buffer, cFormat, arguments) != 0, let cString = buffer else { return nil } + self.init(validatingUTF8: cString) + } +} diff --git a/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift b/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift new file mode 100644 index 0000000000000000000000000000000000000000..5de58b997a76b6bf9493525694bc9f9e4e6b6c1c --- /dev/null +++ b/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift @@ -0,0 +1,99 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// TensorFlow Lite interpreter errors. +public enum InterpreterError: Error { + case invalidTensorIndex(index: Int, maxIndex: Int) + case invalidTensorDataCount(provided: Int, required: Int) + case invalidTensorDataType + case failedToLoadModel + case failedToCreateInterpreter + case failedToResizeInputTensor(index: Int) + case failedToCopyDataToInputTensor + case failedToAllocateTensors + case allocateTensorsRequired + case invokeInterpreterRequired + case tensorFlowLiteError(String) +} + +// MARK: - Extensions + +extension InterpreterError: LocalizedError { + /// Localized description of the interpreter error. + public var errorDescription: String? { + switch self { + case .invalidTensorIndex(let index, let maxIndex): + return "Invalid tensor index \(index), max index is \(maxIndex)." 
+ case .invalidTensorDataCount(let providedCount, let requiredCount): + return "Provided data count \(providedCount) must match the required count \(requiredCount)." + case .invalidTensorDataType: + return "Tensor data type is unsupported or could not be determined because of a model error." + case .failedToLoadModel: + return "Failed to load the given model." + case .failedToCreateInterpreter: + return "Failed to create the interpreter." + case .failedToResizeInputTensor(let index): + return "Failed to resize input tesnor at index \(index)." + case .failedToCopyDataToInputTensor: + return "Failed to copy data to input tensor." + case .failedToAllocateTensors: + return "Failed to allocate memory for input tensors." + case .allocateTensorsRequired: + return "Must call allocateTensors()." + case .invokeInterpreterRequired: + return "Must call invoke()." + case .tensorFlowLiteError(let message): + return "TensorFlow Lite Error: \(message)" + } + } +} + +extension InterpreterError: CustomStringConvertible { + /// Textual representation of the TensorFlow Lite interpreter error. + public var description: String { + return errorDescription ?? "Unknown error." 
+ } +} + +#if swift(>=4.2) +extension InterpreterError: Equatable {} +#else +extension InterpreterError: Equatable { + public static func == (lhs: InterpreterError, rhs: InterpreterError) -> Bool { + switch (lhs, rhs) { + case (.invalidTensorDataType, .invalidTensorDataType), + (.failedToLoadModel, .failedToLoadModel), + (.failedToCreateInterpreter, .failedToCreateInterpreter), + (.failedToAllocateTensors, .failedToAllocateTensors), + (.allocateTensorsRequired, .allocateTensorsRequired), + (.invokeInterpreterRequired, .invokeInterpreterRequired): + return true + case (.invalidTensorIndex(let lhsIndex, let lhsMaxIndex), + .invalidTensorIndex(let rhsIndex, let rhsMaxIndex)): + return lhsIndex == rhsIndex && lhsMaxIndex == rhsMaxIndex + case (.invalidTensorDataCount(let lhsProvidedCount, let lhsRequiredCount), + .invalidTensorDataCount(let rhsProvidedCount, let rhsRequiredCount)): + return lhsProvidedCount == rhsProvidedCount && lhsRequiredCount == rhsRequiredCount + case (.failedToResizeInputTensor(let lhsIndex), .failedToResizeInputTensor(let rhsIndex)): + return lhsIndex == rhsIndex + case (.tensorFlowLiteError(let lhsMessage), .tensorFlowLiteError(let rhsMessage)): + return lhsMessage == rhsMessage + default: + return false + } + } +} +#endif // swift(>=4.2) diff --git a/tensorflow/lite/experimental/swift/Sources/InterpreterOptions.swift b/tensorflow/lite/experimental/swift/Sources/InterpreterOptions.swift new file mode 100644 index 0000000000000000000000000000000000000000..2365fd7ade0f9562250b239308f6a13b16c35784 --- /dev/null +++ b/tensorflow/lite/experimental/swift/Sources/InterpreterOptions.swift @@ -0,0 +1,29 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// Custom configuration options for a TensorFlow Lite interpreter. +public struct InterpreterOptions: Equatable { + + /// Maximum number of CPU threads that the interpreter should run on. Default is `nil` which + /// indicates that the `Interpreter` will decide the number of threads to use. + public var threadCount: Int? = nil + + /// Whether error logging to the console is enabled. The default is `false`. + public var isErrorLoggingEnabled = false + + /// Creates a new instance of interpreter options. + public init() {} +} diff --git a/tensorflow/lite/experimental/swift/Sources/Model.swift b/tensorflow/lite/experimental/swift/Sources/Model.swift new file mode 100644 index 0000000000000000000000000000000000000000..e8c49ff1ae10cc20d1c50b8e8340950cb1491722 --- /dev/null +++ b/tensorflow/lite/experimental/swift/Sources/Model.swift @@ -0,0 +1,40 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import Foundation +import TensorFlowLiteCAPI + +/// A TensorFlow Lite model used by the `Interpreter` to perform inference. +final class Model { + + /// The `TFL_Model` C pointer type represented as an `OpaquePointer`. + typealias CModel = OpaquePointer + + /// The underlying `TFL_Model` C pointer. + let cModel: CModel? + + /// Creates a new model instance. + /// + /// - Precondition: Initialization can fail if the given `filePath` is invalid. + /// - Parameters: + /// - filePath: Local file path to a TensorFlow Lite model. + init?(filePath: String) { + guard !filePath.isEmpty, let cModel = TFL_NewModelFromFile(filePath) else { return nil } + self.cModel = cModel + } + + deinit { + TFL_DeleteModel(cModel) + } +} diff --git a/tensorflow/lite/experimental/swift/Sources/QuantizationParameters.swift b/tensorflow/lite/experimental/swift/Sources/QuantizationParameters.swift new file mode 100644 index 0000000000000000000000000000000000000000..f36787564478115e19584b933a10fb0458e06c71 --- /dev/null +++ b/tensorflow/lite/experimental/swift/Sources/QuantizationParameters.swift @@ -0,0 +1,38 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// Parameters that determine the mapping of quantized values to real values. Quantized values can +/// be mapped to float values using the following conversion: +/// `realValue = scale * (quantizedValue - zeroPoint)`.
+public struct QuantizationParameters { + + /// Difference between real values corresponding to consecutive quantized values differing by 1. + /// For example, the range of quantized values for `UInt8` data type is [0, 255]. + public let scale: Float + + /// Quantized value that corresponds to the real 0 value. + public let zeroPoint: Int + + /// Creates a new quantization parameters instance. + /// + /// - Parameters: + /// - scale: Scale value for asymmetric quantization. + /// - zeroPoint: Zero point for asymmetric quantization. + init(scale: Float, zeroPoint: Int) { + self.scale = scale + self.zeroPoint = zeroPoint + } +} diff --git a/tensorflow/lite/experimental/swift/Sources/Tensor.swift b/tensorflow/lite/experimental/swift/Sources/Tensor.swift new file mode 100644 index 0000000000000000000000000000000000000000..b738d8754914e20ac4c1cb991c92b029828f66d2 --- /dev/null +++ b/tensorflow/lite/experimental/swift/Sources/Tensor.swift @@ -0,0 +1,138 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import TensorFlowLiteCAPI + +/// An input or output tensor in a TensorFlow Lite graph. +public struct Tensor { + + /// Name of the tensor. + public let name: String + + /// Data type of the tensor. + public let dataType: TensorDataType + + /// Shape of the tensor. + public let shape: TensorShape + + /// Data in the input or output tensor. 
+ public let data: Data + + /// Quantization parameters for the tensor if using a quantized model. + public let quantizationParameters: QuantizationParameters? + + /// Creates a new input or output tensor instance. + /// + /// - Parameters: + /// - name: Name of the tensor. + /// - dataType: Data type of the tensor. + /// - data: Data in the input or output tensor. + /// - quantizationParameters: Quantization parameters for the tensor if using a quantized model. + /// The default is `nil`. + init( + name: String, + dataType: TensorDataType, + shape: TensorShape, + data: Data, + quantizationParameters: QuantizationParameters? = nil + ) { + self.name = name + self.dataType = dataType + self.shape = shape + self.data = data + self.quantizationParameters = quantizationParameters + } +} + +/// Supported TensorFlow Lite tensor data types. +public enum TensorDataType: Equatable { + /// 32-bit single precision floating point tensor data type. + case float32 + /// 8-bit unsigned integer tensor data type. + case uInt8 + /// 16-bit signed integer tensor data type. + case int16 + /// 32-bit signed integer tensor data type. + case int32 + /// 64-bit signed integer tensor data type. + case int64 + /// Boolean tensor data type. + case bool + + /// Creates a new tensor data type from the given `TFL_Type` or `nil` if the data type is + /// unsupported or could not be determined because there was an error. + /// + /// - Parameter type: A data type supported by a tensor. + init?(type: TFL_Type) { + switch type { + case kTfLiteFloat32: + self = .float32 + case kTfLiteUInt8: + self = .uInt8 + case kTfLiteInt16: + self = .int16 + case kTfLiteInt32: + self = .int32 + case kTfLiteInt64: + self = .int64 + case kTfLiteBool: + self = .bool + case kTfLiteNoType: + fallthrough + default: + return nil + } + } +} + +/// The shape of a TensorFlow Lite tensor. +public struct TensorShape { + + /// The number of dimensions of the tensor. + public let rank: Int + + /// Array of dimensions for the tensor.
+ public let dimensions: [Int] + + /// Array of `Int32` dimensions for the tensor. + var int32Dimensions: [Int32] { return dimensions.map(Int32.init) } + + /// Creates a new tensor shape instance with the given array of dimensions. + /// + /// - Parameters: + /// - dimensions: Dimensions for the tensor. + public init(_ dimensions: [Int]) { + self.rank = dimensions.count + self.dimensions = dimensions + } + + /// Creates a new tensor shape instance with the given elements representing the dimensions. + /// + /// - Parameters: + /// - elements: Dimensions for the tensor. + public init(_ elements: Int...) { + self.init(elements) + } +} + +extension TensorShape: ExpressibleByArrayLiteral { + /// Creates a new tensor shape instance with the given array literal representing the dimensions. + /// + /// - Parameters: + /// - arrayLiteral: Dimensions for the tensor. + public init(arrayLiteral: Int...) { + self.init(arrayLiteral) + } +} diff --git a/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen b/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen new file mode 100644 index 0000000000000000000000000000000000000000..16bc6cbfe8f554caad2cba3cae11b364b34ed64d --- /dev/null +++ b/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/Configs/TensorFlowLite.tulsigen @@ -0,0 +1,57 @@ +{ + "sourceFilters" : [ + "tensorflow/lite/experimental/c", + "tensorflow/lite/experimental/swift", + "tensorflow/lite/experimental/swift/Sources", + "tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp", + "tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj", + "tensorflow/lite/experimental/swift/Tests", + ], + "buildTargets" : [ + "//tensorflow/lite/experimental/swift:TensorFlowLite", + "//tensorflow/lite/experimental/swift:TensorFlowLiteApp", + "//tensorflow/lite/experimental/swift:TensorFlowLiteTests", + ], + "projectName" : 
"TensorFlowLite", + "optionSet" : { + "LaunchActionPreActionScript" : { + "p" : "$(inherited)" + }, + "BazelBuildStartupOptionsRelease" : { + "p" : "$(inherited)" + }, + "BazelBuildOptionsRelease" : { + "p" : "$(inherited)" + }, + "BazelBuildOptionsDebug" : { + "p" : "$(inherited)" + }, + "EnvironmentVariables" : { + "p" : "$(inherited)" + }, + "BuildActionPreActionScript" : { + "p" : "$(inherited)" + }, + "CommandlineArguments" : { + "p" : "$(inherited)" + }, + "TestActionPreActionScript" : { + "p" : "$(inherited)" + }, + "BazelBuildStartupOptionsDebug" : { + "p" : "$(inherited)" + }, + "BuildActionPostActionScript" : { + "p" : "$(inherited)" + }, + "TestActionPostActionScript" : { + "p" : "$(inherited)" + }, + "LaunchActionPostActionScript" : { + "p" : "$(inherited)" + } + }, + "additionalFilePaths" : [ + "tensorflow/lite/experimental/swift/BUILD" + ] +} diff --git a/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/project.tulsiconf b/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/project.tulsiconf new file mode 100644 index 0000000000000000000000000000000000000000..82ac8aa38126021c176773e4093352bcbecd8603 --- /dev/null +++ b/tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj/project.tulsiconf @@ -0,0 +1,14 @@ +{ + "configDefaults" : { + "optionSet" : { + "ProjectPrioritizesSwift" : { + "p" : "YES" + } + } + }, + "projectName" : "TensorFlowLite", + "packages" : [ + "tensorflow/lite/experimental/swift" + ], + "workspaceRoot" : "../../../../.." 
+} diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp.xcodeproj/project.pbxproj b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..fbbf9a1de2c8e82ab486b99b9e9b8c6dfe80868e --- /dev/null +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp.xcodeproj/project.pbxproj @@ -0,0 +1,345 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + 4A7304B421500B8400C90B21 /* Data+TensorFlowLite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */; }; + 4AA72B732146ED64006C3AEF /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA72B722146ED64006C3AEF /* AppDelegate.swift */; }; + 4AA72B752146ED64006C3AEF /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA72B742146ED64006C3AEF /* ViewController.swift */; }; + 4AA72B782146ED64006C3AEF /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B762146ED64006C3AEF /* Main.storyboard */; }; + 4AA72B7A2146ED66006C3AEF /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B792146ED66006C3AEF /* Assets.xcassets */; }; + 4AA72B7D2146ED66006C3AEF /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */; }; + 4ADDE0CE2176600E00FF07A2 /* Array+TensorFlowLite.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Data+TensorFlowLite.swift"; sourceTree = ""; }; + 4AA72B6F2146ED64006C3AEF 
/* TensorFlowLiteApp.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TensorFlowLiteApp.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 4AA72B722146ED64006C3AEF /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = ""; }; + 4AA72B742146ED64006C3AEF /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = ""; }; + 4AA72B772146ED64006C3AEF /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + 4AA72B792146ED66006C3AEF /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 4AA72B7C2146ED66006C3AEF /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = ""; }; + 4AA72B7E2146ED66006C3AEF /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Array+TensorFlowLite.swift"; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 4AA72B6C2146ED64006C3AEF /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 4AA72B662146ED64006C3AEF = { + isa = PBXGroup; + children = ( + 4AA72B712146ED64006C3AEF /* TensorFlowLiteApp */, + 4AA72B702146ED64006C3AEF /* Products */, + ); + sourceTree = ""; + }; + 4AA72B702146ED64006C3AEF /* Products */ = { + isa = PBXGroup; + children = ( + 4AA72B6F2146ED64006C3AEF /* 
TensorFlowLiteApp.app */, + ); + name = Products; + sourceTree = ""; + }; + 4AA72B712146ED64006C3AEF /* TensorFlowLiteApp */ = { + isa = PBXGroup; + children = ( + 4AA72B722146ED64006C3AEF /* AppDelegate.swift */, + 4ADDE0CD2176600900FF07A2 /* Array+TensorFlowLite.swift */, + 4A7304B321500B8300C90B21 /* Data+TensorFlowLite.swift */, + 4AA72B742146ED64006C3AEF /* ViewController.swift */, + 4AA72B762146ED64006C3AEF /* Main.storyboard */, + 4AA72B792146ED66006C3AEF /* Assets.xcassets */, + 4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */, + 4AA72B7E2146ED66006C3AEF /* Info.plist */, + ); + path = TensorFlowLiteApp; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 4AA72B6E2146ED64006C3AEF /* TensorFlowLiteApp */ = { + isa = PBXNativeTarget; + buildConfigurationList = 4AA72B812146ED66006C3AEF /* Build configuration list for PBXNativeTarget "TensorFlowLiteApp" */; + buildPhases = ( + 4AA72B6B2146ED64006C3AEF /* Sources */, + 4AA72B6C2146ED64006C3AEF /* Frameworks */, + 4AA72B6D2146ED64006C3AEF /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = TensorFlowLiteApp; + productName = TensorFlowLiteApp; + productReference = 4AA72B6F2146ED64006C3AEF /* TensorFlowLiteApp.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 4AA72B672146ED64006C3AEF /* Project object */ = { + isa = PBXProject; + attributes = { + LastSwiftUpdateCheck = 0940; + LastUpgradeCheck = 0940; + ORGANIZATIONNAME = Google; + TargetAttributes = { + 4AA72B6E2146ED64006C3AEF = { + CreatedOnToolsVersion = 9.4.1; + }; + }; + }; + buildConfigurationList = 4AA72B6A2146ED64006C3AEF /* Build configuration list for PBXProject "TensorFlowLiteApp" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 4AA72B662146ED64006C3AEF; + productRefGroup = 
4AA72B702146ED64006C3AEF /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 4AA72B6E2146ED64006C3AEF /* TensorFlowLiteApp */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 4AA72B6D2146ED64006C3AEF /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 4AA72B7D2146ED66006C3AEF /* LaunchScreen.storyboard in Resources */, + 4AA72B7A2146ED66006C3AEF /* Assets.xcassets in Resources */, + 4AA72B782146ED64006C3AEF /* Main.storyboard in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 4AA72B6B2146ED64006C3AEF /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 4AA72B732146ED64006C3AEF /* AppDelegate.swift in Sources */, + 4ADDE0CE2176600E00FF07A2 /* Array+TensorFlowLite.swift in Sources */, + 4A7304B421500B8400C90B21 /* Data+TensorFlowLite.swift in Sources */, + 4AA72B752146ED64006C3AEF /* ViewController.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + 4AA72B762146ED64006C3AEF /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 4AA72B772146ED64006C3AEF /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; + 4AA72B7B2146ED66006C3AEF /* LaunchScreen.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 4AA72B7C2146ED66006C3AEF /* Base */, + ); + name = LaunchScreen.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + 4AA72B7F2146ED66006C3AEF /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + 
CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.4; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + }; + name = Debug; + }; + 4AA72B802146ED66006C3AEF /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = 
YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.4; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_OPTIMIZATION_LEVEL = "-O"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 4AA72B822146ED66006C3AEF /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + INFOPLIST_FILE = TensorFlowLiteApp/Info.plist; + LD_RUNPATH_SEARCH_PATHS = 
( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = com.tensorflow.lite.swift.TensorFlowLite; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_VERSION = 4.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 4AA72B832146ED66006C3AEF /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + INFOPLIST_FILE = TensorFlowLiteApp/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = com.tensorflow.lite.swift.TensorFlowLite; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_VERSION = 4.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 4AA72B6A2146ED64006C3AEF /* Build configuration list for PBXProject "TensorFlowLiteApp" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 4AA72B7F2146ED66006C3AEF /* Debug */, + 4AA72B802146ED66006C3AEF /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 4AA72B812146ED66006C3AEF /* Build configuration list for PBXNativeTarget "TensorFlowLiteApp" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 4AA72B822146ED66006C3AEF /* Debug */, + 4AA72B832146ED66006C3AEF /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 4AA72B672146ED64006C3AEF /* Project object */; +} diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/AppDelegate.swift b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/AppDelegate.swift new file mode 100644 index 0000000000000000000000000000000000000000..ffa90a06adb0b9f93575c8390cd30bd589e43ac7 --- /dev/null +++ 
b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/AppDelegate.swift @@ -0,0 +1,24 @@ +import UIKit + +@UIApplicationMain + +final class AppDelegate: UIResponder, UIApplicationDelegate { + + /// The main window of the app. + var window: UIWindow? + + func application( + _ application: UIApplication, + didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]? = nil + ) -> Bool { + return true + } +} + +// MARK: - Extensions + +#if !swift(>=4.2) +extension UIApplication { + typealias LaunchOptionsKey = UIApplicationLaunchOptionsKey +} +#endif // !swift(>=4.2) diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Array+TensorFlowLite.swift b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Array+TensorFlowLite.swift new file mode 100644 index 0000000000000000000000000000000000000000..56df1ce6597aacf307f7a89a084527ea93c303c2 --- /dev/null +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Array+TensorFlowLite.swift @@ -0,0 +1,22 @@ +import Foundation + +extension Array { + /// Creates a new array from the bytes of the given unsafe data. + /// + /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit + /// with no indirection or reference-counting operations; otherwise, copying the raw bytes in + /// the `unsafeData`'s buffer to a new array returns an unsafe copy. + /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of + /// `MemoryLayout.stride`. + /// - Parameter unsafeData: The data containing the bytes to turn into an array. 
+ init?(unsafeData: Data) { + guard unsafeData.count % MemoryLayout.stride == 0 else { return nil } + let elements = unsafeData.withUnsafeBytes { + UnsafeBufferPointer( + start: $0, + count: unsafeData.count / MemoryLayout.stride + ) + } + self.init(elements) + } +} diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/AppIcon.appiconset/Contents.json b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..d8db8d65fd79fd541b2b7eba75c7378af3448f9c --- /dev/null +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,98 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : 
"ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/Contents.json b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601 --- /dev/null +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/LaunchScreen.storyboard b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..a07a1321be2e65323fadeca51487671c88f462c8 --- /dev/null +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/Main.storyboard b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..b9e8bfb822f90ea1e1db31a21d482d6e522e374f --- /dev/null +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/Main.storyboard @@ -0,0 +1,95 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff 
--git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Data+TensorFlowLite.swift b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Data+TensorFlowLite.swift new file mode 100644 index 0000000000000000000000000000000000000000..bc8a70c848390ad7ba584629563d7d75a9e32341 --- /dev/null +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Data+TensorFlowLite.swift @@ -0,0 +1,13 @@ +import Foundation + +extension Data { + /// Creates a new buffer by copying the buffer pointer of the given array. + /// + /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit + /// for bit with no indirection or reference-counting operations; otherwise, reinterpreting + /// data from the resulting buffer has undefined behavior. + /// - Parameter array: An array with elements of type `T`. + init(copyingBufferOf array: [T]) { + self = array.withUnsafeBufferPointer(Data.init) + } +} diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Info.plist b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..3ca3875f04e5789da9cfb34a44151cd06226a8f3 --- /dev/null +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Info.plist @@ -0,0 +1,46 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 0.0.1 + LSRequiresIPhoneOS + + NSCameraUsageDescription + NSCameraUsageDescription + NSPhotoLibraryUsageDescription + Select a photo to detect objects in. 
+ UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + + + diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/ViewController.swift b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/ViewController.swift new file mode 100644 index 0000000000000000000000000000000000000000..73c74fd19c996653d988977d551fcef683f18697 --- /dev/null +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/ViewController.swift @@ -0,0 +1,299 @@ +import TensorFlowLite +import UIKit + +class ViewController: UIViewController { + + // MARK: - Properties + + /// TensorFlowLite interpreter object for performing inference from a given model. + private var interpreter: Interpreter? + + /// Serial dispatch queue for managing `Interpreter` calls. + private let interpreterQueue = DispatchQueue( + label: Constant.dispatchQueueLabel, + qos: .userInitiated + ) + + /// The currently selected model. + private var currentModel: Model { + guard let currentModel = Model(rawValue: modelControl.selectedSegmentIndex) else { + preconditionFailure("Invalid model for selected segment index.") + } + return currentModel + } + + /// A description of the current model. + private var modelDescription: String { + guard let interpreter = interpreter else { return "" } + let inputCount = interpreter.inputTensorCount + let outputCount = interpreter.outputTensorCount + let inputTensors = (0.. String = { + guard let results = [Float32](unsafeData: outputTensor.data) else { return "No results." 
} + return resultsText + results.description + } + self.updateResultsText(results()) + } catch let error { + self.updateResultsText( + "Failed to invoke the interpreter with error: \(error.localizedDescription)" + ) + return + } + } + } + + private func invokeAddQuantized() { + interpreterQueue.async { + guard let interpreter = self.interpreter else { + self.updateResultsText(Constant.nilInterpreterErrorMessage) + return + } + do { + try interpreter.resizeInput(at: 0, to: [2]) + try interpreter.allocateTensors() + let input: [UInt8] = [1, 3] + let resultsText = self.modelDescription + "\n\n" + + "Performing 2 add operations on quantized input \(input.description) equals: " + self.updateResultsText(resultsText) + let data = Data(input) + try interpreter.copy(data, toInputAt: 0) + try interpreter.invoke() + let outputTensor = try interpreter.output(at: 0) + let results: () -> String = { + guard let quantizationParameters = outputTensor.quantizationParameters else { + return "No results." + } + let quantizedResults = [UInt8](outputTensor.data) + let dequantizedResults = quantizedResults.map { + quantizationParameters.scale * Float(Int($0) - quantizationParameters.zeroPoint) + } + return resultsText + quantizedResults.description + + ", dequantized results: " + dequantizedResults.description + } + self.updateResultsText(results()) + } catch let error { + self.updateResultsText( + "Failed to invoke the interpreter with error: \(error.localizedDescription)" + ) + return + } + } + } + + private func invokeMultiAdd() { + interpreterQueue.async { + guard let interpreter = self.interpreter else { + self.updateResultsText(Constant.nilInterpreterErrorMessage) + return + } + do { + let shape = TensorShape(2) + try (0.. 
[Float32] in + let input = [Float32(index + 1), Float32(index + 2)] + let data = Data(copyingBufferOf: input) + try interpreter.copy(data, toInputAt: index) + return input + } + let resultsText = self.modelDescription + "\n\n" + + "Performing 3 add operations on inputs \(inputs.description) equals: " + self.updateResultsText(resultsText) + try interpreter.invoke() + let results = try (0.. [Float32] in + let tensor = try interpreter.output(at: index) + return [Float32](unsafeData: tensor.data) ?? [] + } + self.updateResultsText(resultsText + results.description) + } catch let error { + self.updateResultsText( + "Failed to invoke the interpreter with error: \(error.localizedDescription)" + ) + return + } + } + } + + private func updateResultsText(_ text: String? = nil) { + safeDispatchOnMain { self.resultsTextView.text = text } + } +} + +// MARK: - Constants + +private enum Constant { + static let dispatchQueueLabel = "TensorFlowLiteInterpreterQueue" + static let nilInterpreterErrorMessage = + "Failed to invoke the interpreter because the interpreter was nil." +} + +/// Models that can be loaded by the TensorFlow Lite `Interpreter`. +private enum Model: Int, CustomStringConvertible { + /// A float model that performs two add operations on one input tensor and returns the result in + /// one output tensor. + case add = 0 + /// A quantized model that performs two add operations on one input tensor and returns the result + /// in one output tensor. + case addQuantized = 1 + /// A float model that performs three add operations on four input tensors and returns the results + /// in 2 output tensors. 
+ case multiAdd = 2 + + var fileInfo: (name: String, extension: String) { + switch self { + case .add: + return Add.fileInfo + case .addQuantized: + return AddQuantized.fileInfo + case .multiAdd: + return MultiAdd.fileInfo + } + } + + // MARK: - CustomStringConvertible + + var description: String { + switch self { + case .add: + return Add.name + case .addQuantized: + return AddQuantized.name + case .multiAdd: + return MultiAdd.name + } + } +} + +/// Values for the `Add` model. +private enum Add { + static let name = "Add" + static let fileInfo = (name: "add", extension: "bin") +} + +/// Values for the `AddQuantized` model. +private enum AddQuantized { + static let name = "AddQuantized" + static let fileInfo = (name: "add_quantized", extension: "bin") +} + +/// Values for the `MultiAdd` model. +private enum MultiAdd { + static let name = "MultiAdd" + static let fileInfo = (name: "multi_add", extension: "bin") +} + +// MARK: - Fileprivate + +/// Safely dispatches the given block on the main queue. If the current thread is `main`, the block +/// is executed synchronously; otherwise, the block is executed asynchronously on the main thread. +fileprivate func safeDispatchOnMain(_ block: @escaping () -> Void) { + if Thread.isMainThread { block(); return } + DispatchQueue.main.async { block() } +} diff --git a/tensorflow/lite/experimental/swift/Tests/InterpreterOptionsTests.swift b/tensorflow/lite/experimental/swift/Tests/InterpreterOptionsTests.swift new file mode 100644 index 0000000000000000000000000000000000000000..54b4f59b28942fe2398aba1a19443857e9617458 --- /dev/null +++ b/tensorflow/lite/experimental/swift/Tests/InterpreterOptionsTests.swift @@ -0,0 +1,54 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@testable import TensorFlowLite +import XCTest + +class InterpreterOptionsTests: XCTestCase { + + func testInterpreterOptions_InitWithDefaultValues() { + let options = InterpreterOptions() + XCTAssertNil(options.threadCount) + XCTAssertFalse(options.isErrorLoggingEnabled) + } + + func testInterpreterOptions_InitWithCustomValues() { + var options = InterpreterOptions() + options.threadCount = 2 + XCTAssertEqual(options.threadCount, 2) + options.isErrorLoggingEnabled = true + XCTAssertTrue(options.isErrorLoggingEnabled) + } + + func testInterpreterOptions_Equatable() { + var options1 = InterpreterOptions() + var options2 = InterpreterOptions() + XCTAssertEqual(options1, options2) + + options1.threadCount = 2 + options2.threadCount = 2 + XCTAssertEqual(options1, options2) + + options2.threadCount = 3 + XCTAssertNotEqual(options1, options2) + options2.threadCount = 2 + + options1.isErrorLoggingEnabled = true + options2.isErrorLoggingEnabled = true + XCTAssertEqual(options1, options2) + + options2.isErrorLoggingEnabled = false + XCTAssertNotEqual(options1, options2) + } +} diff --git a/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift b/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift new file mode 100644 index 0000000000000000000000000000000000000000..e98da5f951e9bc6bfebaf6a1bd76b3c8c8bb9e83 --- /dev/null +++ b/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift @@ -0,0 +1,315 @@ +// Copyright 2018 Google Inc. All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@testable import TensorFlowLite +import XCTest + +class InterpreterTests: XCTestCase { + + var interpreter: Interpreter! + + override func setUp() { + super.setUp() + + interpreter = try! Interpreter(modelPath: AddModel.path) + } + + override func tearDown() { + interpreter = nil + + super.tearDown() + } + + func testInterpreter_InitWithModelPath() { + XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path)) + } + + func testInterpreter_Init_ThrowsFailedToLoadModel() { + XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in + self.assertEqualErrors(actual: error, expected: .failedToLoadModel) + } + } + + func testInterpreter_InitWithModelPathAndOptions() { + var options = InterpreterOptions() + options.threadCount = 2 + XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path, options: options)) + } + + func testInterpreter_InputTensorCount() { + XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount) + } + + func testInterpreter_OutputTensorCount() { + XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount) + } + + func testInterpreter_Invoke() throws { + try interpreter.allocateTensors() + XCTAssertNoThrow(try interpreter.invoke()) + } + + func testInterpreter_Invoke_ThrowsAllocateTensorsRequired_ModelNotReady() { + XCTAssertThrowsError(try interpreter.invoke()) { error in + self.assertEqualErrors(actual: error, expected: 
.allocateTensorsRequired) + } + } + + func testInterpreter_InputTensorAtIndex() throws { + try setUpAddModelInputTensor() + let inputTensor = try interpreter.input(at: AddModel.validIndex) + XCTAssertEqual(inputTensor, AddModel.inputTensor) + } + + func testInterpreter_InputTensorAtIndex_QuantizedModel() throws { + interpreter = try Interpreter(modelPath: AddQuantizedModel.path) + try setUpAddQuantizedModelInputTensor() + let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex) + XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor) + } + + func testInterpreter_InputTensorAtIndex_ThrowsInvalidIndex() throws { + try interpreter.allocateTensors() + XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in + let maxIndex = AddModel.inputTensorCount - 1 + self.assertEqualErrors( + actual: error, + expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex) + ) + } + } + + func testInterpreter_InputTensorAtIndex_ThrowsAllocateTensorsRequired() { + XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in + self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired) + } + } + + func testInterpreter_OutputTensorAtIndex() throws { + try setUpAddModelInputTensor() + try interpreter.invoke() + let outputTensor = try interpreter.output(at: AddModel.validIndex) + XCTAssertEqual(outputTensor, AddModel.outputTensor) + let expectedResults = [Float32](unsafeData: outputTensor.data) + XCTAssertEqual(expectedResults, AddModel.results) + } + + func testInterpreter_OutputTensorAtIndex_QuantizedModel() throws { + interpreter = try Interpreter(modelPath: AddQuantizedModel.path) + try setUpAddQuantizedModelInputTensor() + try interpreter.invoke() + let outputTensor = try interpreter.output(at: AddQuantizedModel.inputOutputIndex) + XCTAssertEqual(outputTensor, AddQuantizedModel.outputTensor) + let expectedResults = [UInt8](outputTensor.data) + XCTAssertEqual(expectedResults, 
AddQuantizedModel.results) + } + + func testInterpreter_OutputTensorAtIndex_ThrowsInvalidIndex() throws { + try interpreter.allocateTensors() + try interpreter.invoke() + XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in + let maxIndex = AddModel.outputTensorCount - 1 + self.assertEqualErrors( + actual: error, + expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex) + ) + } + } + + func testInterpreter_OutputTensorAtIndex_ThrowsInvokeInterpreterRequired() { + XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in + self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired) + } + } + + func testInterpreter_ResizeInputTensorAtIndexToShape() { + XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3])) + XCTAssertNoThrow(try interpreter.allocateTensors()) + } + + func testInterpreter_ResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() { + XCTAssertThrowsError(try interpreter.resizeInput( + at: AddModel.invalidIndex, + to: [2, 2, 3] + )) { error in + let maxIndex = AddModel.inputTensorCount - 1 + self.assertEqualErrors( + actual: error, + expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex) + ) + } + } + + func testInterpreter_CopyDataToInputTensorAtIndex() throws { + try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape) + try interpreter.allocateTensors() + let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex) + XCTAssertEqual(inputTensor.data, AddModel.inputData) + } + + func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidIndex() { + XCTAssertThrowsError(try interpreter.copy( + AddModel.inputData, + toInputAt: AddModel.invalidIndex + )) { error in + let maxIndex = AddModel.inputTensorCount - 1 + self.assertEqualErrors( + actual: error, + expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex) + ) + } + } + + func 
testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws { + try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape) + try interpreter.allocateTensors() + let invalidData = Data(count: AddModel.dataCount - 1) + XCTAssertThrowsError(try interpreter.copy( + invalidData, + toInputAt: AddModel.validIndex + )) { error in + self.assertEqualErrors( + actual: error, + expected: .invalidTensorDataCount(provided: invalidData.count, required: AddModel.dataCount) + ) + } + } + + func testInterpreter_AllocateTensors() { + XCTAssertNoThrow(try interpreter.allocateTensors()) + } + + // MARK: - Private + + private func setUpAddModelInputTensor() throws { + precondition(interpreter != nil) + try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape) + try interpreter.allocateTensors() + try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex) + } + + private func setUpAddQuantizedModelInputTensor() throws { + precondition(interpreter != nil) + try interpreter.resizeInput(at: AddQuantizedModel.inputOutputIndex, to: AddQuantizedModel.shape) + try interpreter.allocateTensors() + try interpreter.copy(AddQuantizedModel.inputData, toInputAt: AddQuantizedModel.inputOutputIndex) + } + + private func assertEqualErrors(actual: Error, expected: InterpreterError) { + guard let actual = actual as? InterpreterError else { + XCTFail("Actual error should be of type InterpreterError.") + return + } + XCTAssertEqual(actual, expected) + } +} + +// MARK: - Constants + +/// Values for the `add.bin` model. 
+private enum AddModel { + static let info = (name: "add", extension: "bin") + static let inputTensorCount = 1 + static let outputTensorCount = 1 + static let invalidIndex = 1 + static let validIndex = 0 + static let shape: TensorShape = [2] + static let dataCount = inputData.count + static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)]) + static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)]) + static let results = [Float32(3.0), Float32(9.0)] + + static let inputTensor = Tensor( + name: "input", + dataType: .float32, + shape: shape, + data: inputData + ) + static let outputTensor = Tensor( + name: "output", + dataType: .float32, + shape: shape, + data: outputData + ) + + static var path: String = { + let bundle = Bundle(for: InterpreterTests.self) + guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" } + return path + }() +} + +/// Values for the `add_quantized.bin` model. +private enum AddQuantizedModel { + static let info = (name: "add_quantized", extension: "bin") + static let inputOutputIndex = 0 + static let shape: TensorShape = [2] + static let inputData = Data([1, 3]) + static let outputData = Data([3, 9]) + static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0) + static let results: [UInt8] = [3, 9] + + static let inputTensor = Tensor( + name: "input", + dataType: .uInt8, + shape: shape, + data: inputData, + quantizationParameters: quantizationParameters + ) + static let outputTensor = Tensor( + name: "output", + dataType: .uInt8, + shape: shape, + data: outputData, + quantizationParameters: quantizationParameters + ) + + static var path: String = { + let bundle = Bundle(for: InterpreterTests.self) + guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" } + return path + }() +} + +// MARK: - Extensions + +extension Array { + /// Creates a new array from the bytes of the given unsafe data. 
+ /// + /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of + /// `MemoryLayout.stride`. + /// - Parameter unsafeData: The data containing the bytes to turn into an array. + init?(unsafeData: Data) { + guard unsafeData.count % MemoryLayout.stride == 0 else { return nil } + let elements = unsafeData.withUnsafeBytes { + UnsafeBufferPointer( + start: $0, + count: unsafeData.count / MemoryLayout.stride + ) + } + self.init(elements) + } +} + +extension Data { + /// Creates a new buffer by copying the buffer pointer of the given array. + /// + /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit + /// for bit with no indirection or reference-counting operations; otherwise, reinterpreting + /// data from the resulting buffer has undefined behavior. + /// - Parameter array: An array with elements of type `T`. + init(copyingBufferOf array: [T]) { + self = array.withUnsafeBufferPointer(Data.init) + } +} diff --git a/tensorflow/lite/experimental/swift/Tests/ModelTests.swift b/tensorflow/lite/experimental/swift/Tests/ModelTests.swift new file mode 100644 index 0000000000000000000000000000000000000000..025db1890607641d49304ae22da1fc33fed084ef --- /dev/null +++ b/tensorflow/lite/experimental/swift/Tests/ModelTests.swift @@ -0,0 +1,59 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +@testable import TensorFlowLite +import XCTest + +class ModelTests: XCTestCase { + + var modelPath: String! + + override func setUp() { + super.setUp() + + let bundle = Bundle(for: type(of: self)) + guard let modelPath = bundle.path( + forResource: Constant.modelInfo.name, + ofType: Constant.modelInfo.extension) + else { + XCTFail("Failed to get the model file path.") + return + } + self.modelPath = modelPath + } + + override func tearDown() { + modelPath = nil + + super.tearDown() + } + + func testModel_InitWithFilePath() { + XCTAssertNotNil(Model(filePath: modelPath)) + } + + func testModel_InitWithEmptyFilePath_FailsInitialization() { + XCTAssertNil(Model(filePath: "")) + } + + func testModel_InitWithInvalidFilePath_FailsInitialization() { + XCTAssertNil(Model(filePath: "invalid/path")) + } +} + +// MARK: - Constants + +private enum Constant { + static let modelInfo = (name: "add", extension: "bin") +} diff --git a/tensorflow/lite/experimental/swift/Tests/QuantizationParametersTests.swift b/tensorflow/lite/experimental/swift/Tests/QuantizationParametersTests.swift new file mode 100644 index 0000000000000000000000000000000000000000..65648c26982daa0cab2a40d111d72e10563373cf --- /dev/null +++ b/tensorflow/lite/experimental/swift/Tests/QuantizationParametersTests.swift @@ -0,0 +1,43 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +@testable import TensorFlowLite +import XCTest + +class QuantizationParametersTests: XCTestCase { + + func testQuantizationParameters_InitWithCustomValues() { + let parameters = QuantizationParameters(scale: 0.5, zeroPoint: 1) + XCTAssertEqual(parameters.scale, 0.5) + XCTAssertEqual(parameters.zeroPoint, 1) + } + + func testQuantizationParameters_Equatable() { + let parameters1 = QuantizationParameters(scale: 0.5, zeroPoint: 1) + let parameters2 = QuantizationParameters(scale: 0.5, zeroPoint: 1) + XCTAssertEqual(parameters1, parameters2) + + let parameters3 = QuantizationParameters(scale: 0.4, zeroPoint: 1) + XCTAssertNotEqual(parameters1, parameters3) + XCTAssertNotEqual(parameters2, parameters3) + } +} + +// MARK: - Extensions + +extension QuantizationParameters: Equatable { + public static func == (lhs: QuantizationParameters, rhs: QuantizationParameters) -> Bool { + return lhs.scale == rhs.scale && lhs.zeroPoint == rhs.zeroPoint + } +} diff --git a/tensorflow/lite/experimental/swift/Tests/TensorTests.swift b/tensorflow/lite/experimental/swift/Tests/TensorTests.swift new file mode 100644 index 0000000000000000000000000000000000000000..4540043a1636f43834ec496ffef1e78444ba312b --- /dev/null +++ b/tensorflow/lite/experimental/swift/Tests/TensorTests.swift @@ -0,0 +1,83 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +@testable import TensorFlowLite +import XCTest + +class TensorTests: XCTestCase { + + // MARK: - Tensor + + func testTensor_Init() { + let name = "InputTensor" + let dataType: TensorDataType = .uInt8 + let shape = TensorShape(Constant.dimensions) + guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return } + let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1) + let inputTensor = Tensor( + name: name, + dataType: dataType, + shape: shape, + data: data, + quantizationParameters: quantizationParameters + ) + XCTAssertEqual(inputTensor.name, name) + XCTAssertEqual(inputTensor.dataType, dataType) + XCTAssertEqual(inputTensor.shape, shape) + XCTAssertEqual(inputTensor.data, data) + XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters) + } + + // MARK: - TensorShape + + func testTensorShape_InitWithArray() { + let shape = TensorShape(Constant.dimensions) + XCTAssertEqual(shape.rank, Constant.dimensions.count) + XCTAssertEqual(shape.dimensions, Constant.dimensions) + } + + func testTensorShape_InitWithElements() { + let shape = TensorShape(2, 2, 3) + XCTAssertEqual(shape.rank, Constant.dimensions.count) + XCTAssertEqual(shape.dimensions, Constant.dimensions) + } + + func testTensorShape_InitWithArrayLiteral() { + let shape: TensorShape = [2, 2, 3] + XCTAssertEqual(shape.rank, Constant.dimensions.count) + XCTAssertEqual(shape.dimensions, Constant.dimensions) + } +} + +// MARK: - Constants + +private enum Constant { + /// Array of 2 arrays of 2 arrays of 3 numbers: [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]. 
+ static let dimensions = [2, 2, 3] +} + +// MARK: - Extensions + +extension TensorShape: Equatable { + public static func == (lhs: TensorShape, rhs: TensorShape) -> Bool { + return lhs.rank == rhs.rank && lhs.dimensions == rhs.dimensions + } +} + +extension Tensor: Equatable { + public static func == (lhs: Tensor, rhs: Tensor) -> Bool { + return lhs.name == rhs.name && lhs.dataType == rhs.dataType && lhs.shape == rhs.shape && + lhs.data == rhs.data && lhs.quantizationParameters == rhs.quantizationParameters + } +} diff --git a/tensorflow/lite/experimental/writer/BUILD b/tensorflow/lite/experimental/writer/BUILD index 57ce63636714aa616cb50e04fe2c15210cc2eb1c..9ba74d0e9114a5e47bef6c6b146f121d254e0f92 100644 --- a/tensorflow/lite/experimental/writer/BUILD +++ b/tensorflow/lite/experimental/writer/BUILD @@ -33,7 +33,6 @@ cc_library( "//tensorflow/lite:builtin_op_data", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", - "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs_with_reflection", ], ) @@ -48,6 +47,16 @@ cc_binary( ], ) +cc_binary( + name = "writer_test", + srcs = ["writer_test.cc"], + deps = [ + ":writer_lib", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", + ], +) + cc_test( name = "writer_lib_test", size = "small", diff --git a/tensorflow/lite/experimental/writer/enum_mapping.h b/tensorflow/lite/experimental/writer/enum_mapping.h index cb6ec3e0d7e0f1b53cc8b84e10cb1be4b1f023c0..949a255abaf63cbc1cc0b3f718f6aaca4f38fd8d 100644 --- a/tensorflow/lite/experimental/writer/enum_mapping.h +++ b/tensorflow/lite/experimental/writer/enum_mapping.h @@ -112,5 +112,29 @@ inline LSHProjectionType LSHProjectionTypeToSchema( } } +inline MirrorPadMode MirrorPaddingModeToSchema(TfLiteMirrorPaddingMode mode) { + switch (mode) { + case kTfLiteMirrorPaddingUnknown: + return MirrorPadMode_REFLECT; // TODO(aselle): consider an error + case kTfLiteMirrorPaddingReflect: + return MirrorPadMode_REFLECT; + case 
kTfLiteMirrorPaddingSymmetric: + return MirrorPadMode_SYMMETRIC; + } +} + +inline CombinerType CombinerTypeToSchema(TfLiteCombinerType type) { + switch (type) { + case kTfLiteCombinerTypeSum: + return CombinerType_SUM; + case kTfLiteCombinerTypeMean: + return CombinerType_MEAN; + case kTfLiteCombinerTypeSqrtn: + return CombinerType_SQRTN; + } +} + +// int + } // namespace tflite #endif // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc index 73742494762b8af9a9a08cd24c6eae1ac25fd426..a5c5dc8709969eccb03250ff194127f47592896a 100644 --- a/tensorflow/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc @@ -22,54 +22,60 @@ limitations under the License. namespace tflite { namespace { // This is generated by grepping -// cat third_party/tensorflow/lite/builtin_op_data.h -//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}" -static const char* param_structs[] = {"TfLiteConvParams", - "TfLitePoolParams", - "TfLiteDepthwiseConvParams", - "TfLiteSVDFParams", - "TfLiteRNNParams", - "TfLiteSequenceRNNParams", - "TfLiteFullyConnectedParams", - "TfLiteLSHProjectionParams", - "TfLiteSoftmaxParams", - "TfLiteConcatenationParams", - "TfLiteAddParams", - "TfLiteSpaceToBatchNDParams", +// cat third_party/tensorflow/lite/c/builtin_op_data.h | grep "^} TfLite" | +// sed 's/^} \(TfLite.*\)Params;/\1Params/g' | grep -v "^}" | sed +// 's/\(.*\)/"\1",/g' | sort +static const char* param_structs[] = {"TfLiteAddParams", + "TfLiteArgMaxParams", + "TfLiteArgMinParams", "TfLiteBatchToSpaceNDParams", - "TfLiteMulParams", - "TfLiteSubParams", + "TfLiteBidirectionalSequenceLSTMParams", + "TfLiteBidirectionalSequenceRNNParams", + "TfLiteCastParams", + "TfLiteConcatenationParams", + "TfLiteConvParams", + "TfLiteDepthwiseConvParams", "TfLiteDivParams", + 
"TfLiteEmbeddingLookupSparseParams", + "TfLiteFakeQuantParams", + "TfLiteFullyConnectedParams", + "TfLiteGatherParams", "TfLiteL2NormParams", + "TfLiteLeakyReluParams", "TfLiteLocalResponseNormParams", + "TfLiteLSHProjectionParams", "TfLiteLSTMParams", - "TfLiteResizeBilinearParams", - "TfLiteResizeNearestNeighborParams", + "TfLiteMirrorPaddingParams", + "TfLiteMulParams", + "TfLiteOneHotParams", + "TfLitePackParams", "TfLitePadParams", "TfLitePadV2Params", + "TfLitePoolParams", + "TfLiteReducerParams", "TfLiteReshapeParams", + "TfLiteResizeBilinearParams", + "TfLiteResizeNearestNeighborParams", + "TfLiteRNNParams", + "TfLiteSequenceRNNParams", + "TfLiteShapeParams", "TfLiteSkipGramParams", + "TfLiteSoftmaxParams", + "TfLiteSpaceToBatchNDParams", "TfLiteSpaceToDepthParams", - "TfLiteCastParams", - "TfLiteEmbeddingLookupSparseParams", - "TfLiteGatherParams", - "TfLiteTransposeParams", - "TfLiteReducerParams", + "TfLiteSparseToDenseParams", "TfLiteSplitParams", "TfLiteSplitVParams", "TfLiteSqueezeParams", "TfLiteStridedSliceParams", - "TfLiteArgMaxParams", - "TfLiteArgMinParams", + "TfLiteSubParams", + "TfLiteSVDFParams", "TfLiteTransposeConvParams", - "TfLiteSparseToDenseParams", - "TfLiteShapeParams", - "TfLiteFakeQuantParams", - "TfLitePackParams", - "TfLiteOneHotParams", - "TfLiteLeakyReluParams", - "TfLiteMirrorPaddingParams", + "TfLiteTransposeParams", + "TfLiteUnidirectionalSequenceLSTMParams", "TfLiteUniqueParams", + "TfLiteUnpackParams", + "TfLiteReverseSequenceParams", nullptr}; } // namespace @@ -142,7 +148,6 @@ class OpOptionData { op_to_option_["REDUCE_MAX"] = "ReducerOptions"; op_to_option_["REDUCE_MIN"] = "ReducerOptions"; op_to_option_["REDUCE_ANY"] = "ReducerOptions"; - op_to_option_["UNPACK"] = ""; op_to_option_["SUM"] = "ReducerOptions"; op_to_option_["REDUCE_MAX"] = "ReducerOptions"; op_to_option_["REDUCE_PROD"] = "ReducerOptions"; @@ -151,33 +156,32 @@ class OpOptionData { op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions"; 
op_to_option_["MAX_POOL_2D"] = "Pool2DOptions"; op_to_option_["L2_NORMALIZATION"] = "L2NormOptions"; - op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; - op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; - op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; - op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; - op_to_option_["MIRROR_PAD"] = ""; // TODO(karimnosseir): MirrorPadOptions. - op_to_option_["UNIQUE"] = ""; // TODO(karimnosseir): UniqueOptions. - // Manually specified mappings between ops and options (none) - op_to_option_["EMBEDDING_LOOKUP"] = - ""; // TODO(aselle): maybe something else. + op_to_option_["MAXIMUM"] = "MaximumMinimumOptions"; + op_to_option_["MINIMUM"] = "MaximumMinimumOptions"; + op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else. + op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else. + + // Manually specified mappings between ops to "none" options -- these are + // ops without a corresponding Options message in schema as yet. If these + // options do get assigned an Options message in future, they need to be + // updated here as well. + op_to_option_["EMBEDDING_LOOKUP"] = ""; op_to_option_["FLOOR"] = ""; - op_to_option_["HASHTABLE_LOOKUP"] = - ""; // TODO(aselle): maybe something else. + op_to_option_["CEIL"] = ""; + op_to_option_["HASHTABLE_LOOKUP"] = ""; op_to_option_["LOGISTIC"] = ""; op_to_option_["RELU"] = ""; op_to_option_["RELU_N1_TO_1"] = ""; op_to_option_["RELU6"] = ""; op_to_option_["TANH"] = ""; - op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else. - op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else. 
op_to_option_["PRELU"] = ""; - op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions - op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions op_to_option_["SIN"] = ""; op_to_option_["LOG"] = ""; op_to_option_["SQRT"] = ""; op_to_option_["RSQRT"] = ""; + op_to_option_["ELU"] = ""; + op_to_option_["REVERSE_SEQUENCE"] = ""; // TODO(aselle): These are undesirable hacks. Consider changing C structs option_to_struct_["Pool2DOptions"] = "TfLitePoolParams"; @@ -185,6 +189,7 @@ class OpOptionData { option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams"; option_to_struct_["LocalResponseNormalizationOptions"] = "TfLiteLocalResponseNormParams"; + option_to_struct_["MirrorPadOptions"] = "TfLiteMirrorPaddingParams"; // Now for every op, try to find an option. bool fatal = false; for (auto op_name : ops_) { @@ -224,13 +229,15 @@ class OpOptionData { if (!param_struct_found) { std::cerr << "Failed to get param struct for option " << option_name << std::endl; - fatal = true; } else { option_to_struct_.insert(std::make_pair(option_name, params_guess)); } } } } + if (fatal) { + exit(1); + } } private: @@ -241,16 +248,28 @@ class OpOptionData { option_to_type_function_; }; +void GenerateImportForResizeBilinearOp(FILE* fp) { + fprintf(fp, + " case BuiltinOperator_RESIZE_BILINEAR: {\n" + " const auto* params = reinterpret_cast(builtin_op_data);\n" + " auto union_type = CreateResizeBilinearOptions(*fbb, " + "params->align_corners).Union();\n" + " return std::make_pair(BuiltinOptions_ResizeBilinearOptions, " + "union_type);\n" + " }\n break;\n"); +} + void GenerateImportForOp(FILE* fp, const std::string& op_name, const std::string& option_name, const std::string& option_type, const flatbuffers::TypeTable* options, const std::string& struct_name) { - // Skip tricky ones for now - if (struct_name == "TfLiteResizeBilinearParams") return; - if (struct_name == "TfLiteSqueezeParams") return; - if (struct_name == 
"TfLiteEmbeddingLookupSparseParams") return; - if (struct_name == "TfLiteReshapeParams") return; + // Special-case ResizeBilinear which has some deprecated fields. + if (struct_name == "TfLiteResizeBilinearParams") { + GenerateImportForResizeBilinearOp(fp); + return; + } fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str()); fprintf(fp, @@ -260,6 +279,9 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name, for (size_t i = 0; i < options->num_elems; i++) { std::string elem_name = options->names[i]; + bool is_int_vector = false; + std::string vector_name = elem_name; + std::string vector_size; // TODO(aselle): Irregular naming in builtins if (elem_name == "fused_activation_function") elem_name = "activation"; @@ -271,8 +293,26 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name, elem_name = "dilation_height_factor"; else if (elem_name == "dilation_w_factor") elem_name = "dilation_width_factor"; - else if (elem_name == "new_shape") - elem_name = "shape"; + else if (elem_name == "idx_out_type") + elem_name = "index_out_type"; + + // Vector fields treated specially. 
+ if (elem_name == "new_shape") { + is_int_vector = true; + vector_name = "shape"; + vector_size = "num_dimensions"; + } else if (elem_name == "squeeze_dims") { + is_int_vector = true; + vector_size = "num_squeeze_dims"; + } + + if (is_int_vector) { + fprintf(fp, + " auto val%zu = fbb->CreateVector(" + "std::vector(params->%s, params->%s + params->%s));\n", + i, vector_name.c_str(), vector_name.c_str(), vector_size.c_str()); + continue; + } flatbuffers::TypeCode code = options->type_codes[i]; auto contained_type = code.sequence_ref != -1 @@ -291,6 +331,10 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name, mapper = "LSTMKernelTypeToSchema"; } else if (contained_type == LSHProjectionTypeTypeTable) { mapper = "LSHProjectionTypeToSchema"; + } else if (contained_type == MirrorPadModeTypeTable) { + mapper = "MirrorPaddingModeToSchema"; + } else if (contained_type == CombinerTypeTypeTable) { + mapper = "CombinerTypeToSchema"; } fprintf(fp, diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/experimental/writer/writer_lib.cc index a0ce4b716d62c5a24342f5a3863e58eb203f7441..2bdc41bae84341949631f77a1be8631b007f2985 100644 --- a/tensorflow/lite/experimental/writer/writer_lib.cc +++ b/tensorflow/lite/experimental/writer/writer_lib.cc @@ -219,6 +219,11 @@ std::vector InterpreterWriter::RemapTensorIndicesToWritten( std::vector output; output.reserve(input.size()); for (int x : input) { + // Special value representing an optional tensor which is not present. 
+ if (x == -1) { + output.push_back(x); + continue; + } if (tensor_to_written_tensor_[x] != -1) { output.push_back(tensor_to_written_tensor_[x]); } diff --git a/tensorflow/lite/experimental/writer/writer_test.cc b/tensorflow/lite/experimental/writer/writer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc9b25db33094ed06e6ab464b8b9dcf5209f1488 --- /dev/null +++ b/tensorflow/lite/experimental/writer/writer_test.cc @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Loads the input tflite file into interpreter, serializes it back to a tflite +// buffer, and then verifies that the generated output can be loaded back into +// an interpreter and the model prepared (i.e., AllocateTensors returns ok). 
+// +// Usage: +// writer_test + +#include + +#include "tensorflow/lite/experimental/writer/writer_lib.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" + +int main(int argc, char* argv[]) { + if (argc != 2) { + fprintf(stderr, "Usage: %s input_file\n", argv[0]); + return 1; + } + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromFile(argv[1]); + std::unique_ptr interpreter; + tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver; + tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter); + tflite::InterpreterWriter writer(interpreter.get()); + std::unique_ptr output_buffer; + size_t output_buffer_size; + writer.GetBuffer(&output_buffer, &output_buffer_size); + + // Verify the generated model. + model = tflite::FlatBufferModel::BuildFromBuffer( + reinterpret_cast(output_buffer.get()), output_buffer_size); + tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter); + if (interpreter->AllocateTensors() != kTfLiteOk) { + fprintf(stderr, "AllocateTensors failed on the round-tripped model.\n"); + return 1; + } + return 0; +} diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index 9c48e1e54d153b9ff043e43f75f25cc36398bc60..82e9de3d26847494e343e4a892a4159c612d0957 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -1,116 +1,124 @@ upper_tabs: -# Tabs left of dropdown menu -- include: /_upper_tabs_left.yaml -- include: /api_docs/_upper_tabs_api.yaml -# Dropdown menu -- name: Resources - path: /resources +- name: "Install" + lower_tabs: + guides: + - include: /install/_toc.yaml + +- name: "Learn" + path: /learn/ is_default: true menu: - - include: /resources/_menu_toc.yaml + - include: /learn/_menu_toc.yaml lower_tabs: # Subsite tabs other: - - name: Guide + - name: "Guide" contents: - - title: Overview - path: /lite/overview - - title: Developer guide - path: /lite/devguide - - title: Android demo app - path: /lite/demo_android - 
- title: iOS demo app - path: /lite/demo_ios - - break: true - - title: TensorFlow Lite APIs - path: /lite/apis - - title: Custom operators - path: /lite/custom_operators - - title: TensorFlow Lite ops versioning - path: /lite/ops_versioning - - title: TensorFlow Lite compatibility guide - path: /lite/tf_ops_compatibility - - title: List of hosted models - path: /lite/models - - title: TensorFlow Lite for iOS - path: /lite/ios - - title: TensorFlow Lite for Raspberry Pi - path: /lite/rpi + - title: "TensorFlow Lite guide" + path: /lite/guide + + - heading: "Get started" + - title: "Overview" + path: /lite/guide/get_started + - title: "Android quickstart" + path: /lite/guide/android + - title: "iOS quickstart" + path: /lite/guide/ios + - title: "FAQ" + path: /lite/guide/faq + - title: "Roadmap" + path: /lite/guide/roadmap - - heading: TF Lite converter - - title: Overview + - heading: "Convert a model" + - title: "TensorFlow Lite converter" path: /lite/convert/ - - title: Python API guide - path: /lite/convert/python_api - - title: Command line examples + - title: "Command line examples" path: /lite/convert/cmdline_examples - - title: Command line reference + - title: "Command line reference" path: /lite/convert/cmdline_reference + - title: "Python API" + path: /lite/convert/python_api + + - heading: "Inference" + - title: "Overview" + path: /lite/guide/inference + - title: "Custom operators" + path: /lite/guide/ops_custom + - title: "Operator versions" + path: /lite/guide/ops_version + - title: "Operator compatibility" + path: /lite/guide/ops_compatibility + - title: "Select operators from TensorFlow" + path: /lite/guide/ops_select + - title: "List of hosted models" + path: /lite/guide/hosted_models - - heading: Performance - - title: Best practices + - heading: "Performance" + - title: "Best practices" path: /lite/performance/best_practices - - title: Benchmarks + - title: "Benchmarks" path: /lite/performance/benchmarks - - title: Model optimization + - title: 
"Model optimization" path: /lite/performance/model_optimization - - title: Post-training quantization + - title: "Post-training quantization" path: /lite/performance/post_training_quantization - - title: Post-training quantization example + - title: "Post-training quantization example" path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb status: external - - title: GPU delegate + - title: "Delegates" + path: /lite/performance/delegates + - title: "GPU delegate" path: /lite/performance/gpu - - title: Advanced GPU + - title: "Advanced GPU" path: /lite/performance/gpu_advanced - - title: TF Mobile - style: accordion - status: deprecated - section: - - title: Overview - path: /lite/tfmobile/ - - title: Building TensorFlow on Android - path: /lite/tfmobile/android_build - - title: Building TensorFlow on IOS - path: /lite/tfmobile/ios_build - - title: Integrating TensorFlow libraries - path: /lite/tfmobile/linking_libs - - title: Preparing models for mobile deployment - path: /lite/tfmobile/prepare_models - - title: Optimizing for mobile - path: /lite/tfmobile/optimizing + - heading: "Build TensorFlow Lite" + - title: "Build for iOS" + path: /lite/guide/build_ios + - title: "Build for ARM64" + path: /lite/guide/build_arm64 + - title: "Build for Raspberry Pi" + path: /lite/guide/build_rpi - # - name: Models - # contents: - # - title: Overview - # path: /lite/models/ - # - heading: Beginner - # style: divider - # - title: Image labeling - # section: - # - title: Overview - # path: /lite/models/image/label/overview - # - title: Android - # path: /lite/models/image/label/android - # - title: iOS - # path: /lite/models/image/label/ios - # - heading: Advanced - # style: divider - # - heading: Image - # - title: Image classification - # path: /lite/models/image/classification/ - # - heading: Audio - # - title: Hot word detection - # path: /lite/models/audio/hot_word/ - # - heading: Text - # - title: Text classification - 
# path: /lite/models/text/classification/ + - heading: "Microcontroller" + - title: "Overview" + path: /lite/guide/microcontroller + - title: "TensorFlow Codelab" + path: https://g.co/codelabs/sparkfunTF + status: external + + - name: "Examples" + contents: + - title: "Examples" + path: /lite/examples - - name: API + - name: "Models" + contents: + - title: "Overview" + path: /lite/models/ + - heading: "Image classification" + - title: "Overview" + path: /lite/models/image_classification/overview + - title: "Android" + path: /lite/models/image_classification/android + - title: "iOS" + path: /lite/models/image_classification/ios + - heading: "Other techniques" + - title: "Object detection" + path: /lite/models/object_detection/overview + - title: "Pose estimation" + path: /lite/models/pose_estimation/overview + - title: "Segmentation" + path: /lite/models/segmentation/overview + - title: "Smart reply" + path: /lite/models/smart_reply/overview + + - name: "API" skip_translation: true contents: - - title: API + - title: "API" path: /api_docs/python/tf/lite +- include: /api_docs/_upper_tabs_api.yaml +- include: /resources/_upper_tabs_resources.yaml - include: /_upper_tabs_right.yaml diff --git a/tensorflow/lite/g3doc/_index.yaml b/tensorflow/lite/g3doc/_index.yaml deleted file mode 100644 index 7153b7c6f670375df8183a9269bb7eaf096ac0c2..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/_index.yaml +++ /dev/null @@ -1,202 +0,0 @@ -project_path: /lite/_project.yaml -book_path: /lite/_book.yaml -description: -landing_page: - custom_css_path: /site-assets/css/style.css - rows: - - heading: TensorFlow Lite is for mobile and embedded devices - description: > -

- TensorFlow Lite is the official solution for running machine learning - models on mobile and embedded devices. It enables on‑device machine - learning inference with low latency and a small binary size on Android, - iOS, and other operating systems. -

- - - - classname: tfo-landing-row-heading tfo-landing-row-heading-list - heading: Many benefits - description: > - On-device ML inference is difficult because of the many constraints—TensorFlow Lite can solve these: - items: - - list: - - heading: Performance - description: > - TF Lite is fast with no noticeable accuracy loss—see the metrics. - icon: - icon_name: lens - foreground: theme - - heading: Portability - description: > - Android, - iOS, and more specialized IoT devices. - icon: - icon_name: lens - foreground: theme - - list: - - heading: Low latency - description: > - Optimized float- and fixed-point CPU kernels, op‑fusing, and more. - icon: - icon_name: lens - foreground: theme - - heading: Acceleration - description: > - Integration with GPU and internal/external accelerators. - icon: - icon_name: lens - foreground: theme - - list: - - heading: Small model size - description: > - Controlled dependencies, quantization, - and op registration. - icon: - icon_name: lens - foreground: theme - - heading: Tooling - description: > - Conversion, compression, benchmarking, power-consumption, and more. - icon: - icon_name: lens - foreground: theme - - - heading: How it works - items: - - heading: Pick a model - icon: - icon_name: build - description: > - Pick a new model or retrain an existing one. - buttons: - - label: Pick - path: /lite/devguide#1_choose_a_model - classname: button button-primary tfo-button-primary - - heading: Convert - icon: - icon_name: autorenew - description: > - Convert a TensorFlow model into a compressed flat buffer with the - TensorFlow Lite Converter. - buttons: - - label: Convert - path: /lite/devguide#2_convert_the_model_format - classname: button button-primary tfo-button-primary - - heading: Deploy - icon: - icon_name: settings_cell - description: > - Take the compressed .tflite file and load it into a mobile or embedded device. 
- buttons: - - label: Deploy - path: /lite/devguide#3_use_the_tensorflow_lite_model_for_inference_in_a_mobile_app - classname: button button-primary tfo-button-primary - - heading: Optimize - icon: - icon_name: bolt - description: > - [optional] Quantize by converting 32-bit floats to more efficient 8-bit integers or run on GPU. - buttons: - - label: Optimize - path: /lite/devguide#4_optimize_your_model_optional - classname: button button-primary tfo-button-primary - - - heading: Build your first TensorFlow Lite app with Codelabs - background: grey - items: - - classname: tfo-landing-row-item-inset-white - heading: Get started - description: > - - - classname: tfo-landing-row-item-inset-white - heading: Share your TensorFlow Lite story - description: > - We love to hear what you're working on—it may even get highlighted on - our social media! Tell us. - - - classname: devsite-landing-row-logos tfo-landing-row-heading - heading: TensorFlow Lite users - items: - - custom_image: - path: ./images/landing-page/photos_logo.png - - custom_image: - path: ./images/landing-page/gboard_logo.png - - custom_image: - path: ./images/landing-page/gmail_logo.png - - custom_image: - path: ./images/landing-page/assistant_logo.png - - - classname: devsite-landing-row-logos - items: - - custom_image: - path: ./images/landing-page/vsco_logo.png - - custom_image: - path: ./images/landing-page/shazam_logo.png - - custom_image: - path: ./images/landing-page/nest_logo.png - - custom_image: - path: ./images/landing-page/loseit_logo.png - - - - classname: devsite-landing-row-cards - background: grey - heading: Updates - items: - - heading: "TensorFlow Lite Now Faster with Mobile GPUs (Developer Preview)" - image_path: ./images/landing-page/facial_contour_detection.png - path: https://medium.com/tensorflow/tensorflow-lite-now-faster-with-mobile-gpus-developer-preview-e15797e6dee7 - buttons: - - label: Read more - path: 
https://medium.com/tensorflow/tensorflow-lite-now-faster-with-mobile-gpus-developer-preview-e15797e6dee7 - - heading: "AI in motion: react in the real world" - image_path: ./images/landing-page/ai_in_motion.png - path: https://cloud.google.com/blog/products/ai-machine-learning/ai-motion-designing-simple-system-see-understand-and-react-real-world-part-ii - buttons: - - label: Read more - path: https://cloud.google.com/blog/products/ai-machine-learning/ai-motion-designing-simple-system-see-understand-and-react-real-world-part-ii - - heading: "Introducing the Model Optimization Toolkit" - image_path: /resources/images/tf-logo-card-16x9.png - path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3 - buttons: - - label: Read on TensorFlow blog - path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3 - - heading: "East Africa Cassava App" - image_path: ./images/landing-page/detect_crop_disease_in_africa.png - path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5 - buttons: - - label: Read more - path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5 - - - classname: devsite-landing-row-cards - background: grey - items: - - heading: "Using TensorFlow Lite on Android" - image_path: /resources/images/tf-logo-card-16x9.png - path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d - buttons: - - label: Read on TensorFlow blog - path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d - - heading: "TensorFlow Lite at the Dev Summit" - youtube_id: FAMfy7izB6A - buttons: - - label: Watch the video - path: https://www.youtube.com/watch?v=FAMfy7izB6A - - heading: "TensorFlow Lite on GitHub" - image_path: /resources/images/github-card-16x9.png - path: 
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite - buttons: - - label: View on GitHub - path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite - - classname: devsite-landing-row-item-hidden diff --git a/tensorflow/lite/g3doc/_project.yaml b/tensorflow/lite/g3doc/_project.yaml index 3ce698639647d9e105b6748512314aeca148b0a0..768edeba0a12ab100da82110cd96301e8cd76b98 100644 --- a/tensorflow/lite/g3doc/_project.yaml +++ b/tensorflow/lite/g3doc/_project.yaml @@ -1,5 +1,5 @@ name: TensorFlow Lite -breadcrumb_name: TensorFlow Lite +breadcrumb_name: For Mobile & IoT home_url: /lite/ parent_project_metadata_path: /_project.yaml description: > @@ -8,3 +8,4 @@ use_site_branding: true hide_from_products_list: true content_license: cc3-apache2 buganizer_id: 316308 +include: /_project_included.yaml diff --git a/tensorflow/lite/g3doc/convert/cmdline_examples.md b/tensorflow/lite/g3doc/convert/cmdline_examples.md index 169f2d91d8a72278ff61f170f0b450885e4c2c93..139a6c9b521c060d058c5243776dcd03896ce1de 100644 --- a/tensorflow/lite/g3doc/convert/cmdline_examples.md +++ b/tensorflow/lite/g3doc/convert/cmdline_examples.md @@ -1,4 +1,4 @@ -# Converter command-line examples +# Converter command line examples This page shows how to use the TensorFlow Lite Converter in the command line. @@ -68,9 +68,9 @@ tflite_convert \ has fewer required flags than frozen graphs due to access to additional data contained within the SavedModel. The values for `--input_arrays` and `--output_arrays` are an aggregated, alphabetized list of the inputs and outputs -in the [SignatureDefs](https://www.tensorflow.org/serving/signature_defs) within +in the [SignatureDefs](../../serving/signature_defs.md) within the -[MetaGraphDef](https://www.tensorflow.org/guide/saved_model#apis_to_build_and_load_a_savedmodel) +[MetaGraphDef](https://www.tensorflow.org/saved_model#apis_to_build_and_load_a_savedmodel) specified by `--saved_model_tag_set`. 
As with the GraphDef, the value for `input_shapes` is automatically determined whenever possible. diff --git a/tensorflow/lite/g3doc/convert/cmdline_reference.md b/tensorflow/lite/g3doc/convert/cmdline_reference.md index d72a46760d48dae46d63f1e914d8afda3f527e27..609ab3fdedeb92979e6fc93da60f9ed461819f7a 100644 --- a/tensorflow/lite/g3doc/convert/cmdline_reference.md +++ b/tensorflow/lite/g3doc/convert/cmdline_reference.md @@ -1,4 +1,4 @@ -# Converter command-line reference +# Converter command line reference This page is complete reference of command-line flags used by the TensorFlow Lite Converter's command line starting from TensorFlow 1.9 up until the most @@ -38,7 +38,7 @@ The following flags specify optional parameters when using SavedModels. Specifies a comma-separated set of tags identifying the MetaGraphDef within the SavedModel to analyze. All tags in the tag set must be specified. * `--saved_model_signature_key`. Type: string. Default: - [DEFAULT_SERVING_SIGNATURE_DEF_KEY](https://www.tensorflow.org/api_docs/python/tf/saved_model/signature_constants). + `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. Specifies the key identifying the SignatureDef containing inputs and outputs. diff --git a/tensorflow/lite/g3doc/convert/index.md b/tensorflow/lite/g3doc/convert/index.md index 60fa265c295174453b1a910f5279807dd0be32cb..fac658aa6c4d9f469063cacb8ed843df9766679c 100644 --- a/tensorflow/lite/g3doc/convert/index.md +++ b/tensorflow/lite/g3doc/convert/index.md @@ -1,25 +1,37 @@ -# TensorFlow Lite Converter +# TensorFlow Lite converter -The TensorFlow Lite Converter takes a TensorFlow graph file and creates a graph -file used by the TensorFlow Lite interpreter. +TensorFlow Lite uses the optimized +[FlatBuffer](https://google.github.io/flatbuffers/) format to represent graphs. 
+Therefore, a TensorFlow model +([protocol buffer](https://developers.google.com/protocol-buffers/)) needs to be +converted into a `FlatBuffer` file before deploying to clients. ## From model training to device deployment -After a TensorFlow model is trained, the TensorFlow Lite converter uses that -model to generate a TensorFlow Lite -[FlatBuffer](https://google.github.io/flatbuffers/) file (`.tflite`). The -converter supports as input: -[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators), -frozen graphs (models generated by -[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)), -and `tf.keras` HDF5 models. The TensorFlow Lite `FlatBuffer` file is deployed to -a client device (generally a mobile or embedded device), and the TensorFlow Lite -interpreter uses the compressed model for on-device inference. This conversion -process is shown in the diagram below: +The TensorFlow Lite converter generates a TensorFlow Lite +[FlatBuffer](https://google.github.io/flatbuffers/) file (`.tflite`) from a +TensorFlow model. + +The converter supports the following input formats: + +* [SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators) +* Frozen `GraphDef`: Models generated by + [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). +* `tf.keras` HDF5 models. +* Any model taken from a `tf.Session` (Python API only). + +The TensorFlow Lite `FlatBuffer` file is then deployed to a client device +(generally a mobile or embedded device), and the TensorFlow Lite interpreter +uses the compressed model for on-device inference. This conversion process is +shown in the diagram below: ![TFLite converter workflow](../images/convert/workflow.svg) -The TensorFlow Lite Converter can be used either from [Python](python_api.md) or -from the [command line](cmdline_examples.md). 
This allows you to integrate the -conversion step into the model design workflow, ensuring the model is easy to -convert to a mobile inference graph. +## Options + +The TensorFlow Lite converter can be used in either of two ways: + +* [Python](python_api.md) (**Preferred**): Using the Python API makes it + easier to convert models as part of a model development pipeline, and helps + mitigate [compatibility](../guide/ops_compatibility.md) issues early on. +* [Command line](cmdline_examples.md) diff --git a/tensorflow/lite/g3doc/convert/quantization.md b/tensorflow/lite/g3doc/convert/quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..895f3e637e7527d35e656d9ab137dacc2a3f2709 --- /dev/null +++ b/tensorflow/lite/g3doc/convert/quantization.md @@ -0,0 +1,47 @@ +# Converting Quantized Models + +This page provides information for how to convert quantized TensorFlow Lite +models. For more details, please see the +[model optimization](../performance/model_optimization.md) guide. + +# Post-training: Quantizing models for CPU model size + +The simplest way to create a small model is to quantize the weights to 8 bits +and quantize the inputs/activations "on-the-fly", during inference. This +has latency benefits, but prioritizes size reduction. + +During conversion, set the `optimizations` flag to optimize for size: + +``` +converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) +converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] +tflite_quant_model = converter.convert() +``` + +# During training: Quantizing models for integer-only execution. + +Quantizing models for integer-only execution produces a model with even lower +latency, smaller size, and compatibility with integer-only accelerators. +Currently, this requires training a model with +["fake-quantization" nodes](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize).
+ +Convert the graph: + +``` +converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) +converter.inference_type = tf.lite.constants.QUANTIZED_UINT8 +input_arrays = converter.get_input_arrays() +converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev +tflite_model = converter.convert() +``` + +For fully integer models, the inputs are uint8. The `mean` and `std_dev` values +specify how those uint8 values map to the float input values used while training +the model. + +`mean` is the integer value from 0 to 255 that maps to floating point 0.0f. +`std_dev` is 255 / (float_max - float_min). + +For most users, we recommend using post-training quantization. We are working on +new tools for post-training and during training quantization that we hope will +simplify generating quantized models. diff --git a/tensorflow/lite/g3doc/demo_ios.md b/tensorflow/lite/g3doc/demo_ios.md deleted file mode 100644 index f4b481dc6192db703dea4161ed28e2fd63812ebf..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/demo_ios.md +++ /dev/null @@ -1,76 +0,0 @@ - -# iOS Demo App - -The TensorFlow Lite demo is a camera app that continuously classifies whatever -it sees from your device's back camera, using a quantized MobileNet model. These -instructions walk you through building and running the demo on an iOS device. - -## Prerequisites - -* You must have [Xcode](https://developer.apple.com/xcode/) installed and have - a valid Apple Developer ID, and have an iOS device set up and linked to your - developer account with all of the appropriate certificates. For these - instructions, we assume that you have already been able to build and deploy - an app to an iOS device with your current developer environment. - -* The demo app requires a camera and must be executed on a real iOS device. - You can build it and run with the iPhone Simulator but it won't have any - camera information to classify.
- -* You don't need to build the entire TensorFlow library to run the demo, but - you will need to clone the TensorFlow repository if you haven't already: - - git clone https://github.com/tensorflow/tensorflow - cd tensorflow - -* You'll also need the Xcode command-line tools: - - xcode-select --install - - If this is a new install, you will need to run the Xcode application once to - agree to the license before continuing. - -## Building the iOS Demo App - -1. Install CocoaPods if you don't have it: - - sudo gem install cocoapods - -2. Download the model files used by the demo app (this is done from inside the - cloned directory): - - sh tensorflow/lite/examples/ios/download_models.sh - -3. Install the pod to generate the workspace file: - - cd tensorflow/lite/examples/ios/camera - pod install - - If you have installed this pod before and that command doesn't work, try - - pod repo update - - At the end of this step you should have a file called - `tflite_camera_example.xcworkspace`. - -4. Open the project in Xcode by typing this on the command line: - - open tflite_camera_example.xcworkspace - - This launches Xcode if it isn't open already and opens the - `tflite_camera_example` project. - -5. Under `Project navigator -> tflite_camera_example -> Targets -> - tflite_camera_example -> General` change the bundle identifier by - pre-pending your name: - - ![pre-pend your name to the bundle identifier](images/ios/bundle_identifier.png) - -6. Build and run the app in Xcode. - - Note that as mentioned earlier, you must already have a device set up and - linked to your Apple Developer account in order to deploy the app on a - device. - -You'll have to grant permissions for the app to use the device's camera. Point -the camera at various objects and enjoy seeing how the model classifies things! 
diff --git a/tensorflow/lite/g3doc/demo_android.md b/tensorflow/lite/g3doc/guide/android.md similarity index 99% rename from tensorflow/lite/g3doc/demo_android.md rename to tensorflow/lite/g3doc/guide/android.md index 772598d5cfd36a388d253becd7fc3026f31375c9..4b2f38a5d32ffe30996d2149c81c74be70b7798a 100644 --- a/tensorflow/lite/g3doc/demo_android.md +++ b/tensorflow/lite/g3doc/guide/android.md @@ -1,5 +1,4 @@ - -# Android Demo App +# Android quickstart An example Android application using TensorFLow Lite is available [on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java/demo). diff --git a/tensorflow/lite/g3doc/guide/build_arm64.md b/tensorflow/lite/g3doc/guide/build_arm64.md new file mode 100644 index 0000000000000000000000000000000000000000..0daa45abd02ea0e5f3a122a818c80ca16e03d796 --- /dev/null +++ b/tensorflow/lite/g3doc/guide/build_arm64.md @@ -0,0 +1,62 @@ +# Build TensorFlow Lite for ARM64 boards + +## Cross compiling + +### Installing the toolchain + +```bash +sudo apt-get update +sudo apt-get install crossbuild-essential-arm64 +``` + +> If you are using Docker, you may need to omit `sudo`. + +### Building + +Clone the TensorFlow repository. Run this script at the root of the repository +to download all the dependencies: + +> If you are using the `tensorflow/tensorflow:nightly-devel` Docker image, the +> TensorFlow repository is located in `/tensorflow`. + +```bash +./tensorflow/lite/tools/make/download_dependencies.sh +``` + +Note that you only need to do this once. + +Compile: + +```bash +./tensorflow/lite/tools/make/build_aarch64_lib.sh +``` + +This should compile a static library in: +`tensorflow/lite/gen/gen/aarch64_armv8-a/lib/libtensorflow-lite.a`. + +## Native compiling + +These steps were tested on HardKernel Odroid C2, gcc version 5.4.0. + +Log in to your board and install the toolchain. + +```bash +sudo apt-get install build-essential +``` + +First, clone the TensorFlow repository.
Run this at the root of the repository: + +```bash +./tensorflow/lite/tools/make/download_dependencies.sh +``` + +Note that you only need to do this once. + +Compile: + +```bash +./tensorflow/lite/tools/make/build_aarch64_lib.sh +``` + +This should compile a static library in: +`tensorflow/lite/gen/gen/aarch64_armv8-a/lib/libtensorflow-lite.a`. diff --git a/tensorflow/lite/g3doc/ios.md b/tensorflow/lite/g3doc/guide/build_ios.md similarity index 98% rename from tensorflow/lite/g3doc/ios.md rename to tensorflow/lite/g3doc/guide/build_ios.md index c195b6abf4f76f88d1f60b192cd19165aefe9a11..40f2ac2fdfdac4ef7c9fb958bb125afca51ab148 100644 --- a/tensorflow/lite/g3doc/ios.md +++ b/tensorflow/lite/g3doc/guide/build_ios.md @@ -3,7 +3,7 @@ This document describes how to build TensorFlow Lite iOS library. If you just want to use it, the easiest way is using the TensorFlow Lite CocoaPod releases. -See [TensorFlow Lite iOS Demo](demo_ios.md) for examples. +See [TensorFlow Lite iOS Demo](ios.md) for examples. 
## Building diff --git a/tensorflow/lite/g3doc/rpi.md b/tensorflow/lite/g3doc/guide/build_rpi.md similarity index 97% rename from tensorflow/lite/g3doc/rpi.md rename to tensorflow/lite/g3doc/guide/build_rpi.md index 708d9e328cbdfffb491d487e4592d789b4fd06af..cb0cabc2d3e5d7ef3100a74e8dcb82db214f7287 100644 --- a/tensorflow/lite/g3doc/rpi.md +++ b/tensorflow/lite/g3doc/guide/build_rpi.md @@ -1,4 +1,4 @@ -# TensorFlow Lite for Raspberry Pi +# Build TensorFlow Lite for Raspberry Pi ## Cross compiling diff --git a/tensorflow/lite/g3doc/guide/faq.md b/tensorflow/lite/g3doc/guide/faq.md new file mode 100644 index 0000000000000000000000000000000000000000..a0e4d09ef1e6f3bb1214d6a7130f672e53f32396 --- /dev/null +++ b/tensorflow/lite/g3doc/guide/faq.md @@ -0,0 +1,135 @@ +# Frequently Asked Questions + +If you don't find an answer to your question here, please look through our +detailed documentation for the topic or file a +[GitHub issue](https://github.com/tensorflow/tensorflow/issues). + +## Model Conversion + +#### What formats are supported for conversion from TensorFlow to TensorFlow Lite? + +The TensorFlow Lite converter supports the following formats: + +* SavedModels: + [TFLiteConverter.from_saved_model](../convert/python_api.md#exporting_a_savedmodel_) +* Frozen GraphDefs generated by + [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py): + [TFLiteConverter.from_frozen_graph](../convert/python_api.md#exporting_a_graphdef_from_file_) +* tf.keras HDF5 models: + [TFLiteConverter.from_keras_model_file](../convert/python_api.md#exporting_a_tfkeras_file_) +* tf.Session: + [TFLiteConverter.from_session](../convert/python_api.md#exporting_a_graphdef_from_tfsession_) + +The recommended approach is to integrate the +[Python converter](../convert/python_api.md) into your model pipeline in order to +detect compatibility issues early on. + +#### Why doesn't my model convert? 
+ +Since the number of TensorFlow Lite operations is smaller than TensorFlow's, +some inference models may not be able to convert. For unimplemented operations, +take a look at the question on +[missing operators](faq.md#why-are-some-operations-not-implemented-in-tensorflow-lite). +Unsupported operators include embeddings and LSTM/RNNs. For conversion issues +not related to missing operations, search our +[GitHub issues](https://github.com/tensorflow/tensorflow/issues?q=label%3Acomp%3Alite+) +or file a [new one](https://github.com/tensorflow/tensorflow/issues). + +#### How do I determine the inputs/outputs for GraphDef protocol buffer? + +The easiest way to inspect a graph from a `.pb` file is to use the +[summarize_graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md#inspecting-graphs) +tool. + +If that approach yields an error, you can visualize the GraphDef with +[TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard) and +look for the inputs and outputs in the graph. To visualize a `.pb` file, use the +[`import_pb_to_tensorboard.py`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/import_pb_to_tensorboard.py) +script like below: + +``` +python import_pb_to_tensorboard.py --model_dir --log_dir +``` + +#### How do I inspect a `.tflite` file? + +TensorFlow Lite models can be visualized using the +[visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py) +script in our repository. + +* [Clone the TensorFlow repository](https://www.tensorflow.org/install/source) +* Run the `visualize.py` script with bazel: + +``` +bazel run //tensorflow/lite/tools:visualize model.tflite visualized_model.html +``` + +## Models & Operations + +#### Why are some operations not implemented in TensorFlow Lite? + +In order to keep TensorFlow Lite lightweight, only certain operations were used +in the converter. 
The [Compatibility Guide](ops_compatibility.md) provides a +list of operations currently supported by TensorFlow Lite. + +If you don’t see a specific operation (or an equivalent) listed, it's likely +that it has not been prioritized. The team tracks requests for new operations on +GitHub [issue #21526](https://github.com/tensorflow/tensorflow/issues/21526). +Leave a comment if your request hasn’t already been mentioned. + +In the meantime, you could try implementing a +[custom operator](ops_custom.md) or using a different model that only +contains supported operators. If binary size is not a constraint, try using +TensorFlow Lite with [select TensorFlow ops](ops_select.md). + +#### How do I test that a TensorFlow Lite model behaves the same as the original TensorFlow model? + +The best way to test the behavior of a TensorFlow Lite model is to use our API +with test data and compare the outputs to TensorFlow for the same inputs. Take a +look at our [Python Interpreter example](../convert/python_api.md) that generates +random data to feed to the interpreter. + +## Optimization + +#### How do I reduce the size of my converted TensorFlow Lite model? + +[Post-training quantization](../performance/post_training_quantization.md) can be +used during conversion to TensorFlow Lite to reduce the size of the model. +Post-training quantization quantizes weights to 8 bits of precision from +floating-point and dequantizes them at runtime to perform floating point +computations. However, note that this could have some accuracy implications. + +If retraining the model is an option, consider +[Quantization-aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize). +However, note that quantization-aware training is only available for a subset of +convolutional neural network architectures. + +For a deeper understanding of different optimization methods, look at +[Model optimization](../performance/model_optimization.md).
+ +#### How do I optimize TensorFlow Lite performance for my machine learning task? + +The high-level process to optimize TensorFlow Lite performance looks something +like this: + +* *Make sure that you have the right model for the task.* For image + classification, check out our [list of hosted models](hosted_models.md). +* *Tweak the number of threads.* Many TensorFlow Lite operators support + multi-threaded kernels. You can use `SetNumThreads()` in the + [C++ API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/interpreter.h#L345) + to do this. However, increasing threads results in performance variability + depending on the environment. +* *Use Hardware Accelerators.* TensorFlow Lite supports model acceleration for + specific hardware using delegates. For example, to use Android’s Neural + Networks API, call + [`UseNNAPI`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/interpreter.h#L343) + on the interpreter. Or take a look at our + [GPU delegate tutorial](../performance/gpu.md). +* *(Advanced) Profile Model.* The Tensorflow Lite + [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark) + has a built-in profiler that can show per-operator statistics. If you know + how you can optimize an operator’s performance for your specific platform, + you can implement a [custom operator](ops_custom.md). + +For a more in-depth discussion on how to optimize performance, take a look at +[Best Practices](../performance/best_practices.md). 
diff --git a/tensorflow/lite/g3doc/devguide.md b/tensorflow/lite/g3doc/guide/get_started.md similarity index 85% rename from tensorflow/lite/g3doc/devguide.md rename to tensorflow/lite/g3doc/guide/get_started.md index cbad036407fabea9d49910e22b4c968470566211..daa551ff7d65a14010eea74076f93a8ecccb78f5 100644 --- a/tensorflow/lite/g3doc/devguide.md +++ b/tensorflow/lite/g3doc/guide/get_started.md @@ -1,4 +1,4 @@ -# TF Lite Developer Guide +# Get started with TensorFlow Lite Using a TensorFlow Lite model in your mobile app requires multiple considerations: you must choose a pre-trained or custom model, convert the model @@ -35,7 +35,7 @@ by suggesting contextually relevant messages. The model is built specifically fo memory constrained devices, such as watches and phones, and has been successfully used in Smart Replies on Android Wear. Currently, this model is Android-specific. -These pre-trained models are [available for download](models.md). +These pre-trained models are [available for download](hosted_models.md). ### Re-train Inception-V3 or MobileNet for a custom data set @@ -54,7 +54,7 @@ both floating point and quantized inference. ### Train a custom model A developer may choose to train a custom model using Tensorflow (see the -[TensorFlow tutorials](../tutorials/) for examples of building and training +[TensorFlow tutorials](https://www.tensorflow.org/tutorials/) for examples of building and training models). If you have already written a model, the first step is to export this to a `tf.GraphDef` file. This is required because some formats do not store the model structure outside the code, and we must communicate with other parts of @@ -63,24 +63,24 @@ the framework. See to create file for the custom model. TensorFlow Lite currently supports a subset of TensorFlow operators. 
Refer to -the [TensorFlow Lite & TensorFlow Compatibility Guide](tf_ops_compatibility.md) +the [TensorFlow Lite & TensorFlow Compatibility Guide](ops_compatibility.md) for supported operators and their usage. This set of operators will continue to grow in future Tensorflow Lite releases. ## 2. Convert the model format -The [TensorFlow Lite Converter](convert/index.md) accepts the following file +The [TensorFlow Lite Converter](../convert/index.md) accepts the following file formats: * `SavedModel` — A `GraphDef` and checkpoint with a signature that labels input and output arguments to a model. See the documentation for converting - SavedModels using [Python](convert/python_api.md#basic_savedmodel) or using - the [command line](convert/cmdline_examples.md#savedmodel). + SavedModels using [Python](../convert/python_api.md#basic_savedmodel) or using + the [command line](../convert/cmdline_examples.md#savedmodel). * `tf.keras` - A HDF5 file containing a model with weights and input and output arguments generated by `tf.Keras`. See the documentation for converting HDF5 models using - [Python](convert/python_api.md#basic_keras_file) or using the - [command line](convert/cmdline_examples.md#keras). + [Python](../convert/python_api.md#basic_keras_file) or using the + [command line](../convert/cmdline_examples.md#keras). * `frozen tf.GraphDef` — A subclass of `tf.GraphDef` that does not contain variables. A `GraphDef` can be converted to a `frozen GraphDef` by taking a checkpoint and a `GraphDef`, and converting each variable into a constant @@ -92,7 +92,7 @@ formats: TensorFlow models may be saved as a .pb or .pbtxt `tf.GraphDef` file. In order to convert the `tf.GraphDef` file to TensorFlow Lite, the model must first be -frozen. This process invovles several file formats including the `frozen +frozen. 
This process involves several file formats including the `frozen GraphDef`: * `tf.GraphDef` (.pb or .pbtxt) — A protobuf that represents the TensorFlow @@ -154,19 +154,19 @@ the arguments for specifying the output nodes for inference in the ### Full converter reference -The [TensorFlow Lite Converter](convert/index.md) can be -[Python](convert/python_api.md) or from the -[command line](convert/cmdline_examples.md). This allows you to integrate the +The [TensorFlow Lite Converter](../convert/index.md) can be used from +[Python](../convert/python_api.md) or from the +[command line](../convert/cmdline_examples.md). This allows you to integrate the conversion step into the model design workflow, ensuring the model is easy to convert to a mobile inference graph. ### Ops compatibility -Refer to the [ops compatibility guide](tf_ops_compatibility.md) for +Refer to the [ops compatibility guide](ops_compatibility.md) for troubleshooting help, and if that doesn't help, please [file an issue](https://github.com/tensorflow/tensorflow/issues). -### Graph vizualization tool +### Graph Visualization tool The [development repo](https://github.com/tensorflow/tensorflow) contains a tool to visualize TensorFlow Lite models after conversion. To build the @@ -195,15 +195,15 @@ The open source Android demo app uses the JNI interface and is available [on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java/demo/app). You can also download a [prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). -See the Android demo guide for details. +See the Android demo guide for details. -The Android mobile guide has instructions for +The Android mobile guide has instructions for installing TensorFlow on Android and setting up `bazel` and Android Studio. ### iOS To integrate a TensorFlow model in an iOS app, see the -[TensorFlow Lite for iOS](ios.md) guide and iOS demo +[TensorFlow Lite for iOS](ios.md) guide and iOS demo guide.
#### Core ML support @@ -215,11 +215,14 @@ trained Tensorflow models to the devices. To use the converter, refer to the [Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml). -### Raspberry Pi +### ARM32 and ARM64 Linux Compile Tensorflow Lite for a Raspberry Pi by following the -[RPi build instructions](rpi.md) This compiles a static library file (`.a`) used -to build your app. There are plans for Python bindings and a demo app. +[RPi build instructions](build_rpi.md). Compile Tensorflow Lite for a generic aarch64 +board such as Odroid C2, Pine64, NanoPi, and others by following the +[ARM64 Linux build instructions](build_arm64.md). This compiles a static +library file (`.a`) used to build your app. There are plans for Python bindings +and a demo app. ## 4. Optimize your model (optional) @@ -250,7 +253,9 @@ tflite_quantized_model=converter.convert() open(“quantized_model.tflite”, “wb”).write(tflite_quantized_model) ``` -Read the full documentation [here](performance/post_training_quantization) and see a tutorial [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb). +Read the full documentation [here](../performance/post_training_quantization.md) +and see a tutorial +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb). ### GPU Run on GPU GPUs are designed to have high throughput for massively @@ -263,4 +268,4 @@ Another benefit with GPU inference is its power efficiency. GPUs carry out the computations in a very efficient and optimized manner, so that they consume less power and generate less heat than when the same task is run on CPUs. -Read the tutorial [here](performance/gpu) and full documentation [here](performance/gpu_advanced). +Read the tutorial [here](../performance/gpu.md) and full documentation [here](../performance/gpu_advanced.md).
diff --git a/tensorflow/lite/g3doc/models.md b/tensorflow/lite/g3doc/guide/hosted_models.md similarity index 64% rename from tensorflow/lite/g3doc/models.md rename to tensorflow/lite/g3doc/guide/hosted_models.md index 62b3f17c79aa3688011a1452da18e098008f414e..69f196782eac1de7eff1c326693f93fbcb601b94 100644 --- a/tensorflow/lite/g3doc/models.md +++ b/tensorflow/lite/g3doc/guide/hosted_models.md @@ -1,64 +1,27 @@ +# Hosted models -# List of Hosted Models - -# AutoML mobile image classification models (Float Models) - -Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^ -------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ---------------------: -MnasNet_0.50_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.5_224_09_07_2018.tgz) | 8.5 Mb | 68.03% | 87.79% | 37 ms -MnasNet_0.75_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.75_224_09_07_2018.tgz) | 12 Mb | 71.72% | 90.17% | 61 ms -MnasNet_1.0_96| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_96_09_07_2018.tgz) | 17 Mb | 62.33% | 83.98% | 23 ms -MnasNet_1.0_128| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_128_09_07_2018.tgz) | 17 Mb | 67.32% | 87.70% | 34 ms -MnasNet_1.0_160| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_160_09_07_2018.tgz) | 17 Mb | 70.63% | 89.58% | 51 ms -MnasNet_1.0_192| 
[paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_192_09_07_2018.tgz) | 17 Mb | 72.56% | 90.76% | 70 ms -MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms -MnasNet_1.3_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms - - -^ Performance numbers are generated on Pixel-1 using single thread large BIG core. - - -## Image classification (Float Models) - -Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance ---------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------: -DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms -SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms -NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 73.9% | 91.5% | 261 ms | 389 ms -NASNet large | [paper](https://arxiv.org/abs/1707.07012), 
[tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.6% | 96.1% | 6697 ms | 7940 ms -ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 76.8% | 93.6% | 1880 ms | 1970 ms -Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 77.9% | 93.8% | 1433 ms | 1522 ms -Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.1% | 95.1% | 2986 ms | 3139 ms -Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.5% | 94.0% | 2731 ms | 2926 ms -Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.4% | 66.2% | 6.2 ms | 13.0 ms -Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.4% | 70.2% | 8.6 ms | 19.5 ms -Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.1% | 72.0% | 12.1 ms | 27.8 ms -Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.7% | 74.1% | 16.2 ms | 
37.3 ms -Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.2% | 79.3% | 18.1 ms | 29.9 ms -Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.0% | 81.8% | 26.8 ms | 45.9 ms -Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.5% | 35.6 ms | 65.3 ms -Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.2% | 84.9% | 47.6 ms | 164.2 ms -Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.0% | 83.8% | 34.6 ms | 48.7 ms -Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.2% | 85.9% | 51.3 ms | 75.2 ms -Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.1% | 87.2% | 71.7 ms | 107.0 ms -Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.3% | 88.1% | 95.7 ms | 143.4 ms -Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.7% | 57.4 ms | 76.8 ms -Mobilenet_V1_1.0_160 | 
[paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms -Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 69.9% | 89.1% | 118.6 ms | 167.3 ms -Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.0% | 89.9% | 160.1 ms | 224.3 ms -Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.8% | 90.6% | 117 ms | - -^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph. - -^^ The performance numbers are generated in the benchmark on Pixel-2 using -single thread large core. - -^^ Accuracy numbers were computed using the -[TFLite accuracy tool](../tools/accuracy/ilsvrc) . - -## Image classification (Quantized Models) - -Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance +The following is an incomplete list of pre-trained models optimized to work with +TensorFlow Lite. + +To get started choosing a model, visit Models. + +Note: The best model for a given application depends on your requirements. For +example, some applications might benefit from higher accuracy, while others +require a small model size. You should test your application with a variety of +models to find the optimal balance between size, performance, and accuracy. + +## Image classification + +For more information about image classification, see +Image classification. + +### Quantized models + +Quantized image +classification models offer the smallest model size and fastest performance, at +the expense of accuracy. 
+ +Model name | Paper and model | Model size | Top-1 accuracy | Top-5 accuracy | TF Lite performance --------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 42.8% | 68.1% | 5.5 ms @@ -82,9 +45,104 @@ Inception_V2_quant | [paper](https://arxiv.org/abs/1512.00567), [tflite Inception_V3_quant | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz) | 23 Mb | 77.5% | 93.7% | 637 ms Inception_V4_quant | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](http://download.tensorflow.org/models/inception_v4_299_quant_20181026.tgz) | 41 Mb | 79.5% | 93.9% | 1250.8 ms -## Other models +Note: The model files include both TF Lite FlatBuffer and Tensorflow frozen +Graph. + +Note: Performance numbers were benchmarked on Pixel-2 using single thread large +core. Accuracy numbers were computed using the +[TFLite accuracy tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/accuracy/ilsvrc). + +### Floating point models + +Floating point models offer the best accuracy, at the expense of model size and +performance. GPU acceleration requires the +use of floating point models. 
+ +Model name | Paper and model | Model size | Top-1 accuracy | Top-5 accuracy | TF Lite performance | Tensorflow performance +--------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: | ---------------------: +DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms +SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms +NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 73.9% | 91.5% | 261 ms | 389 ms +NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.6% | 96.1% | 6697 ms | 7940 ms +ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 76.8% | 93.6% | 1880 ms | 1970 ms +Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 77.9% | 93.8% | 1433 ms | 1522 ms +Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), 
[tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.1% | 95.1% | 2986 ms | 3139 ms +Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.5% | 94.0% | 2731 ms | 2926 ms +Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.4% | 66.2% | 6.2 ms | 13.0 ms +Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.4% | 70.2% | 8.6 ms | 19.5 ms +Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.1% | 72.0% | 12.1 ms | 27.8 ms +Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.7% | 74.1% | 16.2 ms | 37.3 ms +Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.2% | 79.3% | 18.1 ms | 29.9 ms +Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.0% | 81.8% | 26.8 ms | 45.9 ms +Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.5% | 35.6 ms | 65.3 ms +Mobilenet_V1_0.50_224 | 
[paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.2% | 84.9% | 47.6 ms | 164.2 ms +Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.0% | 83.8% | 34.6 ms | 48.7 ms +Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.2% | 85.9% | 51.3 ms | 75.2 ms +Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.1% | 87.2% | 71.7 ms | 107.0 ms +Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.3% | 88.1% | 95.7 ms | 143.4 ms +Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.7% | 57.4 ms | 76.8 ms +Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms +Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 69.9% | 89.1% | 118.6 ms | 167.3 ms +Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.0% | 89.9% | 160.1 ms | 224.3 ms +Mobilenet_V2_1.0_224 | 
[paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.8% | 90.6% | 117 ms | + +### AutoML mobile models + +The following image classification models were created using +Cloud AutoML. + +Model Name | Paper and model | Model size | Top-1 accuracy | Top-5 accuracy | TF Lite performance +---------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: +MnasNet_0.50_224 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.5_224_09_07_2018.tgz) | 8.5 Mb | 68.03% | 87.79% | 37 ms +MnasNet_0.75_224 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.75_224_09_07_2018.tgz) | 12 Mb | 71.72% | 90.17% | 61 ms +MnasNet_1.0_96 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_96_09_07_2018.tgz) | 17 Mb | 62.33% | 83.98% | 23 ms +MnasNet_1.0_128 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_128_09_07_2018.tgz) | 17 Mb | 67.32% | 87.70% | 34 ms +MnasNet_1.0_160 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_160_09_07_2018.tgz) | 17 Mb | 70.63% | 89.58% | 51 ms +MnasNet_1.0_192 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_192_09_07_2018.tgz) | 17 Mb | 72.56% | 90.76% | 70 ms +MnasNet_1.0_224 | [paper](https://arxiv.org/abs/1807.11626), 
[tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms +MnasNet_1.3_224 | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms + +Note: Performance numbers were benchmarked on Pixel-1 using single thread large +BIG core. + +## Object detection + +For more information about object detection, see +Object detection. + +The object detection model we currently host is +**coco_ssd_mobilenet_v1_1.0_quant_2018_06_29**. + +Download +model and labels + +## Pose estimation + +For more information about pose estimation, see +Pose estimation. + +The pose estimation model we currently host is +**multi_person_mobilenet_v1_075_float**. + +Download +model + +## Image segmentation + +For more information about image segmentation, see +Segmentation. + +The image segmentation model we currently host is **deeplabv3_257_mv_gpu**. + +Download +model + +## Smart reply + +For more information about smart reply, see +Smart reply. + +The smart reply model we currently host is **smartreply_1.0_2017_11_01**. 
-Model | TF Lite FlatBuffer ------------------------ | :----------------: -[reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), -[tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip) +Download +model diff --git a/tensorflow/lite/g3doc/overview.md b/tensorflow/lite/g3doc/guide/index.md similarity index 98% rename from tensorflow/lite/g3doc/overview.md rename to tensorflow/lite/g3doc/guide/index.md index 2d747a9b59f734a007ef54d13223aed22f38cb1d..288f7a07576165959f79b6796ec52286fdc60123 100644 --- a/tensorflow/lite/g3doc/overview.md +++ b/tensorflow/lite/g3doc/guide/index.md @@ -1,5 +1,5 @@ -# Introduction to TensorFlow Lite +# TensorFlow Lite guide TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded devices. It enables on-device machine learning inference with low latency and a @@ -118,7 +118,7 @@ TensorFlow Lite provides: to all first-party and third-party apps. Also see the complete list of - [TensorFlow Lite's supported models](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models.md), + [TensorFlow Lite's supported models](hosted_models.md), including the model sizes, performance numbers, and downloadable model files. - Quantized versions of the MobileNet model, which runs faster than the diff --git a/tensorflow/lite/g3doc/apis.md b/tensorflow/lite/g3doc/guide/inference.md similarity index 74% rename from tensorflow/lite/g3doc/apis.md rename to tensorflow/lite/g3doc/guide/inference.md index 1a05142bc44b824e090fd6eb513360837eac2c69..b0107ece0b1d137ec75dc871bff4284d55cbf2cc 100644 --- a/tensorflow/lite/g3doc/apis.md +++ b/tensorflow/lite/g3doc/guide/inference.md @@ -1,77 +1,82 @@ -# TensorFlow Lite APIs +# TensorFlow Lite inference -TensorFlow Lite provides programming APIs in C++ and Java, and in both cases -the API design reflects a preference for performance over ease of use. 
-TensorFlow Lite is designed for fast inference on small devices so it should be -no surprise that the APIs try to avoid unnecessary copies at the expense of -convenience. Similarly, consistency with TensorFlow APIs was not an explicit -goal and some variance is to be expected. +[TOC] -There is also a [Python API for TensorFlow Lite](g3doc/convert/python_api.md). +## Overview -## C++ +TensorFlow Lite inference is the process of executing a TensorFlow Lite +model on-device and extracting meaningful results from it. Inference is the +final step in using the model on-device in the +[architecture](index.md#tensorflow_lite_architecture). -In order to run the inference model in TensorFlow Lite, one has to load the -model into a `FlatBufferModel` object which then can be executed by an -`Interpreter`. The `FlatBufferModel` needs to remain valid for the whole -lifetime of the `Interpreter`, and a single `FlatBufferModel` can be -simultaneously used by more than one `Interpreter`. In concrete terms, the -`FlatBufferModel` object must be created before any `Interpreter` objects that -use it, and must be kept around until they have all been destroyed. +Inference for TensorFlow Lite models is run through an interpreter. This +document outlines the various APIs for the interpreter along with the +[supported platforms](#supported-platforms). -The simplest usage of TensorFlow Lite will look like this: +### Important Concepts -```c++ -tflite::FlatBufferModel model(path_to_model); -tflite::ops::builtin::BuiltinOpResolver resolver; -std::unique_ptr interpreter; -tflite::InterpreterBuilder(*model, resolver)(&interpreter); -// Resize input tensors, if desired. -interpreter->AllocateTensors(); -float* input = interpreter->typed_input_tensor(0); -// Fill `input`. -interpreter->Invoke(); -float* output = interpreter->typed_output_tensor(0); -``` -### Data Alignment +TensorFlow Lite inference on device typically follows the following steps. 
-TensorFlow Lite data is usually aligned to 16-byte boundaries. It is recommended -that all data provided to TensorFlow Lite be aligned that way. +1. **Loading a Model** -### Error Reporting + The user loads the `.tflite` model into memory which contains the model's + execution graph. -In many places TensorFlow Lite returns status information through -`TfLiteStatus` objects: +1. **Transforming Data** + Input data acquired by the user generally may not match the input data format + expected by the model. For example, a user may need to resize an image or change + the image format to be used by the model. -```c++ -typedef enum { - kTfLiteOk = 0, - kTfLiteError = 1 -} TfLiteStatus; +1. **Running Inference** -``` + This step involves using the API to execute the model. It involves a few + steps such as building the interpreter, and allocating tensors as explained + in detail in [Running a Model](#running_a_model). -Failures can be easily verified with: +1. **Interpreting Output** -```c++ -if (status != kTfLiteOk) { - // ... error handling here ... -} -``` + The user retrieves results from model inference and interprets the tensors in + a meaningful way to be used in the application. -In order to obtain detailed error information an ErrorReporter must be -provided: + For example, a model may only return a list of probabilities. It is up to the + application developer to meaningfully map them to relevant categories and + present it to their user. -```c++ -class ErrorReporter { - virtual int Report(const char* format, va_list args) = 0; -}; -``` +### Supported Platforms +TensorFlow inference APIs are provided for most common mobile/embedded platforms +such as Android, iOS and Linux. + +#### Android +On Android, TensorFlow Lite inference can be performed using either Java or C++ +APIs. The Java APIs provide convenience and can be used directly within your +Android Activity classes. 
The C++ APIs on the other hand may offer more +flexibility and speed, but may require writing JNI wrappers to move data between +Java and C++ layers. You can find an example [here](android.md). + +#### iOS +TensorFlow Lite provides Swift/Objective C++ APIs for inference on iOS. An +example can be found [here](ios.md). + +#### Linux +On Linux platforms such as [Raspberry Pi](build_rpi.md), TensorFlow Lite C++ +and Python APIs can be used to run inference. -The `DefaultErrorReporter` takes care of reporting to `stderr`. + +## API Guides + +TensorFlow Lite provides programming APIs in C++, Java and Python, with +experimental bindings for several other languages (C, Swift, Objective-C). In +most cases, the API design reflects a preference for performance over ease of +use. TensorFlow Lite is designed for fast inference on small devices so it +should be no surprise that the APIs try to avoid unnecessary copies at the +expense of convenience. Similarly, consistency with TensorFlow APIs was not an +explicit goal and some variance is to be expected. + +There is also a [Python API for TensorFlow Lite](../convert/python_api.md). ### Loading a Model +#### C++ The `FlatBufferModel` class encapsulates a model and can be built in a couple of slightly different ways depending on where the model is stored: @@ -92,11 +97,42 @@ class FlatBufferModel { }; ``` +```c++ +tflite::FlatBufferModel model(path_to_model); +``` + Note that if TensorFlow Lite detects the presence of Android's NNAPI it will automatically try to use shared memory to store the FlatBufferModel. -### Running a Model +#### Java + +TensorFlow Lite's Java API supports on-device inference and is provided as an +Android Studio Library that allows loading models, feeding inputs, and +retrieving inference outputs. + +The `Interpreter` class drives model inference with TensorFlow Lite. In +most of the cases, this is the only class an app developer will need. 
+ +The `Interpreter` can be initialized with a model file using the constructor: +```java +public Interpreter(@NotNull File modelFile); +``` + +or with a `MappedByteBuffer`: + +```java +public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer); +``` + +In both cases a valid TensorFlow Lite model must be provided or an +`IllegalArgumentException` will be thrown. If a `MappedByteBuffer` is used to +initialize an Interpreter, it should remain unchanged for the whole lifetime of +the `Interpreter`. + +### Running a Model {#running_a_model} + +#### C++ Running a model involves a few simple steps: * Build an `Interpreter` based on an existing `FlatBufferModel` @@ -106,7 +142,7 @@ Running a model involves a few simple steps: * Read output tensor values The important parts of public interface of the `Interpreter` are provided -below. It should be noted that: +below. It should be noted that: * Tensors are represented by integers, in order to avoid string comparisons (and any fixed dependency on string libraries). @@ -114,41 +150,129 @@ below. It should be noted that: * Memory allocation for input and output tensors must be triggered by calling AllocateTensors() right after resizing tensors. +In order to run the inference model in TensorFlow Lite, one has to load the +model into a `FlatBufferModel` object which then can be executed by an +`Interpreter`. The `FlatBufferModel` needs to remain valid for the whole +lifetime of the `Interpreter`, and a single `FlatBufferModel` can be +simultaneously used by more than one `Interpreter`. In concrete terms, the +`FlatBufferModel` object must be created before any `Interpreter` objects that +use it, and must be kept around until they have all been destroyed. + +The simplest usage of TensorFlow Lite will look like this: + ```c++ -class Interpreter { - Interpreter(ErrorReporter* error_reporter); +tflite::FlatBufferModel model(path_to_model); - // Read only access to list of inputs. 
- const std::vector& inputs() const; +tflite::ops::builtin::BuiltinOpResolver resolver; +std::unique_ptr interpreter; +tflite::InterpreterBuilder(*model, resolver)(&interpreter); - // Read only access to list of outputs. - const std::vector& outputs() const; +// Resize input tensors, if desired. +interpreter->AllocateTensors(); + +float* input = interpreter->typed_input_tensor(0); +// Fill `input`. - // Change the dimensionality of a given tensor. - TfLiteStatus ResizeInputTensor(int tensor_index, - const std::vector& dims); +interpreter->Invoke(); - // Returns status of success or failure. - TfLiteStatus AllocateTensors(); +float* output = interpreter->typed_output_tensor(0); +``` - // Return a pointer into the data of a given input tensor. - template - T* typed_input_tensor(int index) { - return typed_tensor(inputs_[index]); - } +#### Java - // Return a pointer into the data of a given output tensor. - template - T* typed_output_tensor(int index) { - return typed_tensor(outputs_[index]); - } +The simplest usage of Tensorflow Lite Java API looks like this: - // Execute the model, populating output tensors. - TfLiteStatus Invoke(); -}; +```java +try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) { + interpreter.run(input, output); +} ``` -### Writing Custom Operators +If a model takes only one input and returns only one output, the following will +trigger an inference run: + +```java +interpreter.run(input, output); +``` + +For models with multiple inputs, or multiple outputs, use: + +```java +interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs); +``` + +where each entry in `inputs` corresponds to an input tensor and +`map_of_indices_to_outputs` maps indices of output tensors to the corresponding +output data. In both cases the tensor indices should correspond to the values +given to the +[TensorFlow Lite Optimized Converter](../convert/cmdline_examples.md) when the +model was created. 
Be aware that the order of tensors in `input` must match the +order given to the `TensorFlow Lite Optimized Converter`. + +The Java API also provides convenient functions for app developers to get the +index of any model input or output using a tensor name: + +```java +public int getInputIndex(String tensorName); +public int getOutputIndex(String tensorName); +``` + +If tensorName is not a valid name in model, an `IllegalArgumentException` will +be thrown. + +##### Releasing Resources After Use + +An `Interpreter` owns resources. To avoid memory leak, the resources must be +released after use by: + +```java +interpreter.close(); +``` + +##### Supported Data Types + +To use TensorFlow Lite, the data types of the input and output tensors must be +one of the following primitive types: + +* `float` +* `int` +* `long` +* `byte` + +`String` types are also supported, but they are encoded differently than the +primitive types. In particular, the shape of a string Tensor dictates the number +and arrangement of strings in the Tensor, with each element itself being a +variable length string. In this sense, the (byte) size of the Tensor cannot be +computed from the shape and type alone, and consequently strings cannot be +provided as a single, flat `ByteBuffer` argument. + +If other data types, including boxed types like `Integer` and `Float`, are used, +an `IllegalArgumentException` will be thrown. + +##### Inputs + +Each input should be an array or multi-dimensional array of the supported +primitive types, or a raw `ByteBuffer` of the appropriate size. If the input is +an array or multi-dimensional array, the associated input tensor will be +implicitly resized to the array's dimensions at inference time. If the input is +a ByteBuffer, the caller should first manually resize the associated input +tensor (via `Interpreter.resizeInput()`) before running inference. 
+ +When using 'ByteBuffer', prefer using direct byte buffers, as this allows the +`Interpreter` to avoid unnecessary copies. If the `ByteBuffer` is a direct byte +buffer, its order must be `ByteOrder.nativeOrder()`. After it is used for a +model inference, it must remain unchanged until the model inference is finished. + +##### Outputs + +Each output should be an array or multi-dimensional array of the supported +primitive types, or a ByteBuffer of the appropriate size. Note that some models +have dynamic outputs, where the shape of output tensors can vary depending on +the input. There's no straightforward way of handling this with the existing +Java inference API, but planned extensions will make this possible. + + +## Writing Custom Operators All TensorFlow Lite operators (both custom and builtin) are defined using a simple pure-C interface that consists of four functions: @@ -251,127 +375,3 @@ be code-generated based on a given subset of ops, possibly only the ones contained in a given model. This is the equivalent of TensorFlow's selective registration (and a simple version of it is available in the `tools` directory). - -## Java - -TensorFlow Lite's Java API supports on-device inference and is provided as an -Android Studio Library that allows loading models, feeding inputs, and -retrieving inference outputs. - -The simplest usage of Tensorflow Lite Java API looks like this: - -```java -try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) { - interpreter.run(input, output); -} -``` - -### Loading a Model - -The `Interpreter.java` class drives model inference with TensorFlow Lite. In -most of the cases, this is the only class an app developer will need. 
- -#### Initializing an `Interpreter` With a Model File - -The `Interpreter` can be initialized with a model file using the constructor: - -```java -public Interpreter(@NotNull File modelFile); -``` - -or with a `MappedByteBuffer`: - -```java -public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer); -``` - -In both cases a valid TensorFlow Lite must be provided or an -`IllegalArgumentException` with be thrown. If a `MappedByteBuffer` is used to -initialize an Interpreter, it should remain unchanged for the whole lifetime of -the `Interpreter`. - -### Running a Model - -#### Supported Data Types - -To use TensorFlow Lite, the data types of the input and output tensors must be -one of the following primitive types: - -* `float` -* `int` -* `long` -* `byte` - -`String` types are also supported, but they are encoded differently than the -primitive types. In particular, the shape of a string Tensor dictates the number -and arrangement of strings in the Tensor, with each element itself being a -variable length string. In this sense, the (byte) size of the Tensor cannot be -computed from the shape and type alone, and consequently strings cannot be -provided as a single, flat `ByteBuffer` argument. - -If other data types, including boxed types like `Integer` and `Float`, are used, -an `IllegalArgumentException` will be thrown. - -#### Inputs - -Each input should be an array or multi-dimensional array of the supported -primitive types, or a raw `ByteBuffer` of the appropriate size. If the input is -an array or multi-dimensional array, the associated input tensor will be -implicitly resized to the array's dimensions at inference time. If the input is -a ByteBuffer, the caller should first manually resize the associated input -tensor (via `Interpreter.resizeInput()`) before running inference. - -When using 'ByteBuffer', prefer using direct byte buffers, as this allows the -`Interpreter` to avoid unnecessary copies. 
If the `ByteBuffer` is a direct byte -buffer, its order must be `ByteOrder.nativeOrder()`. After it is used for a -model inference, it must remain unchanged until the model inference is finished. - -#### Outputs - -Each output should be an array or multi-dimensional array of the supported -primitive types, or a ByteBuffer of the appropriate size. Note that some models -have dynamic outputs, where the shape of output tensors can vary depending on -the input. There's no straightforward way of handling this with the existing -Java inference API, but planned extensions will make this possible. - -#### Running Model Inference - -If a model takes only one input and returns only one output, the following will -trigger an inference run: - -```java -interpreter.run(input, output); -``` - -For models with multiple inputs, or multiple outputs, use: - -```java -interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs); -``` - -where each entry in `inputs` corresponds to an input tensor and -`map_of_indices_to_outputs` maps indices of output tensors to the corresponding -output data. In both cases the tensor indices should correspond to the values -given to the [TensorFlow Lite Optimized Converter](convert/cmdline_examples.md) -when the model was created. Be aware that the order of tensors in `input` must -match the order given to the `TensorFlow Lite Optimized Converter`. - -The Java API also provides convenient functions for app developers to get the -index of any model input or output using a tensor name: - -```java -public int getInputIndex(String tensorName); -public int getOutputIndex(String tensorName); -``` - -If tensorName is not a valid name in model, an `IllegalArgumentException` will -be thrown. - -### Releasing Resources After Use - -An `Interpreter` owns resources. 
To avoid memory leak, the resources must be -released after use by: - -```java -interpreter.close(); -``` diff --git a/tensorflow/lite/g3doc/guide/ios.md b/tensorflow/lite/g3doc/guide/ios.md new file mode 100644 index 0000000000000000000000000000000000000000..3565ce71df3095d24352f564bef2714a3582adae --- /dev/null +++ b/tensorflow/lite/g3doc/guide/ios.md @@ -0,0 +1,229 @@ +# iOS quickstart + +This tutorial provides a simple iOS mobile application to classify images using +the iOS device camera. In this tutorial, you will download the demo application +from the Tensorflow repository, build it on your computer, and install it on +your iOS Device. You will also learn how to customize the application to suit +your requirements. + +## Prerequisites + +* You must have [Xcode](https://developer.apple.com/xcode/) installed and have + a valid Apple Developer ID, and have an iOS device set up and linked to your + developer account with all of the appropriate certificates. For these + instructions, we assume that you have already been able to build and deploy + an app to an iOS device with your current developer environment. + +* The demo app requires a camera and must be executed on a real iOS device. + You can build it and run with the iPhone Simulator but it won't have any + camera information to classify. + +* You don't need to build the entire TensorFlow library to run the demo, but + you will need to clone the TensorFlow repository if you haven't already: + + git clone https://github.com/tensorflow/tensorflow + cd tensorflow + +* You'll also need the Xcode command-line tools: + + xcode-select --install + + If this is a new install, you will need to run the Xcode application once to + agree to the license before continuing. + +* Install CocoaPods if you don't have it: + + sudo gem install cocoapods + +### Step 1. Clone the TensorFlow source code + +First, we clone the GitHub repository on the computer in a folder to get the +demo application. 
+ +``` +git clone https://github.com/tensorflow/tensorflow +``` + +### Step 2. Download required dependencies + +Execute the shell script to download the model files used by the demo app (this +is done from inside the cloned directory): + +``` + tensorflow/lite/examples/ios/download_models.sh +``` + +Run the following command to install TensorFlow Lite pod: + +``` + cd tensorflow/lite/examples/ios/camera + pod install +``` + +If you have installed this pod before and that command doesn't work, try + +``` + pod repo update +``` + +### Step 3. Build the XCode project + +Open the `tflite_camera_example.xcworkspace` project file generated in the last +step: + +``` + open tflite_camera_example.xcworkspace +``` + +Under `Project navigator -> tflite_camera_example -> Targets -> +tflite_camera_example -> General` change the bundle identifier by pre-pending +your name: + +![pre-pend your name to the bundle identifier](../images/ios/bundle_identifier.png) + +Plug in your iOS device. Note the app must be executed with a real device with +camera. Select the iOS device from the drop-down menu. + +![Device selection](../images/ios/device_selection.png) + +Click the "Run" button to build and run the app + +![Build and execute](../images/ios/build_and_execute.png) + +Note that as mentioned earlier, you must already have a device set up and linked +to your Apple Developer account in order to deploy the app on a device. + +You'll have to grant permissions for the app to use the device's camera. Point +the camera at various objects and enjoy seeing how the model classifies things! + +## Understanding iOS App Code + +### Get camera input + +The main logic of this app is in the Objective C++ source file +`tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm`. + +The `setupAVCapture` method constructs a `AVCaptureSession` and set itself as a +delegate. The `captureOutput:didOutputSampleBuffer:fromConnection:` method is +called for every captured frame. 
It calls `runModelOnFrame` to run the model for +every frame. + +### Create an interpreter + +To create the interpreter, we need to load the model file. The following code +will load a model and create an interpreter. + +``` +model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]); +``` + +Behind the scenes, the model is loaded as a memory-mapped file. It offers faster +load times and reduces the dirty pages in memory. + +Construct a `BuiltinOpResolver` to use the TensorFlow Lite builtin ops. Then, +create the interpreter object using `InterpreterBuilder` that takes the model +file as argument as shown below. + +``` +tflite::ops::builtin::BuiltinOpResolver resolver; +tflite::InterpreterBuilder(*model, resolver)(&interpreter); +``` + +### Obtain the input buffer + +By default, the app uses quantized model since it's smaller and faster. The +buffer is a raw pointer to an array of 8 bit unsigned integers (`uint8_t`). The +following code obtains the input buffer from the interpreter: + +``` +// Get the index of first input tensor. +int input_tensor_index = interpreter->inputs()[0]; +// Get the pointer to the input buffer. +uint8_t* buffer = interpreter->typed_tensor(input_tensor_index); +``` + +Throughout this document, it's assumed a quantized model is used. + +### Pre-process of bitmap image + +The MobileNet model we're using takes 224x224x3 inputs, where the dimensions are +width, height, and colors (RGB). The images returned from `AVCaptureSession` are +bigger, and have 4 color channels (RGBA). + +Many image classification models (like MobileNet) take fixed-size inputs. It's +required to scale or crop the image before feeding it into the model, and change +the channels from RGBA to RGB. + +The code to pre-process the images is in `ProcessInputWithQuantizedModel` +function in +`tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm`. 
It's a +simple implementation for nearest neighbor color sampling, and it only copies +the first 3 bytes for each pixel. + +``` +void ProcessInputWithQuantizedModel( + uint8_t* input, uint8_t* output, int image_width, int image_height, int image_channels) { + for (int y = 0; y < wanted_input_height; ++y) { + uint8_t* out_row = output + (y * wanted_input_width * wanted_input_channels); + for (int x = 0; x < wanted_input_width; ++x) { + const int in_x = (y * image_width) / wanted_input_width; + const int in_y = (x * image_height) / wanted_input_height; + uint8_t* in_pixel = input + (in_y * image_width * image_channels) + (in_x * image_channels); + uint8_t* out_pixel = out_row + (x * wanted_input_channels); + for (int c = 0; c < wanted_input_channels; ++c) { + out_pixel[c] = in_pixel[c]; + } + } + } +} +``` + +Note the code is preprocessing and preparing the model input from the camera +data. Therefore the first parameter `input` should be the camera buffer. The +second parameter `output` should be the buffer of model input. + +### Run inference and obtain output buffer + +After preprocessing and filling the data into the input buffer of the +interpreter, it's really easy to run the interpreter: + +``` +if (interpreter->Invoke() != kTfLiteOk) { + NSLog("Failed to invoke!"); +} +``` + +The result is stored in the output tensor buffer of the interpreter. The +following code obtains the pointer to the buffer: + +``` +// Get the index of first output tensor. +const int output_tensor_index = interpreter->outputs()[0]; +// Get the pointer to the output buffer. +uint8_t* buffer = interpreter->typed_tensor(output_tensor_index); +``` + +### Post-process values + +The output buffer contains an array of `uint8_t`, and the value range is 0-255. +We need to convert the value to float to get the probabilities with value range +0.0-1.0. 
The formula of the quantization value mapping is: + + float_value = (quantized_value - zero_point) * scale + +The following code converts quantized values back to float values, using the +quantization parameters in tensors: + +``` +uint8_t* quantized_output = interpreter->typed_output_tensor(0); +int32_t zero_point = input_tensor->params.zero_point; +float scale = input_tensor->params.scale; +float output[output_size]; +for (int i = 0; i < output_size; ++i) { + output[i] = (quantized_output[i] - zero_point) * scale; +} +``` + +Finally, we find the best set of classifications by storing them in a priority +queue based on their confidence scores. See the `GetTopN` function in +`tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm`. diff --git a/tensorflow/lite/g3doc/guide/microcontroller.md b/tensorflow/lite/g3doc/guide/microcontroller.md new file mode 100644 index 0000000000000000000000000000000000000000..635132004ee45fd9f566a970641bc624b85351bb --- /dev/null +++ b/tensorflow/lite/g3doc/guide/microcontroller.md @@ -0,0 +1,21 @@ +# Microcontrollers + +## Overview + +Microcontrollers are compact integrated circuits with very limited resources. Currently, they only perform simple functions. + +With the onset of TensorFlow Lite, hence smaller binary sizes, these devices will be able to support machine learning applications, opening the industry up to a myriad of use cases. + +## Getting started + +Note: This is an experimental release aimed at microcontrollers and other devices with only kilobytes of memory. It doesn't require any operating system support, any standard C or C++ libraries, or dynamic memory allocation, so it's designed to be portable even to 'bare metal' systems. + +One of the challenges of embedded software development is that there are a lot of different architectures, devices, operating systems, and build systems. We aim to support as many of the popular combinations as we can and make it as easy as possible to add support for others. 
+ +Read more about [how to get started](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/micro#getting-started). + +## Goals + +The design goals are to make the framework readable, easy to modify, well-tested, easy to integrate, and compatible (e.g. consistent file schema, interpreter, API, kernel interface). + +Read more about [goals and tradeoffs](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/micro#goals). diff --git a/tensorflow/lite/g3doc/tf_ops_compatibility.md b/tensorflow/lite/g3doc/guide/ops_compatibility.md similarity index 90% rename from tensorflow/lite/g3doc/tf_ops_compatibility.md rename to tensorflow/lite/g3doc/guide/ops_compatibility.md index 8b85c08a32588e210889b8fe2c8e4e2a11223295..000cca5873c5b299ce144aa95b5067ed8f7df64e 100644 --- a/tensorflow/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/lite/g3doc/guide/ops_compatibility.md @@ -1,4 +1,4 @@ -# TensorFlow Lite & TensorFlow Compatibility Guide +# TensorFlow Lite and TensorFlow operator compatibility TensorFlow Lite supports a number of TensorFlow operations used in common inference models. As they are processed by the TensorFlow Lite Optimizing @@ -8,7 +8,9 @@ operations are mapped to their TensorFlow Lite counterparts. Since the set of TensorFlow Lite operations is smaller than TensorFlow's, not every model is convertible. Even for supported operations, very specific usage patterns are sometimes expected, for performance reasons. We expect to expand -the set of supported operations in future TensorFlow Lite releases. +the set of supported operations in future TensorFlow Lite releases. Additional +ops can be included by [using select TensorFlow ops](ops_select.md), at the cost +of binary size. 
The best way to understand how to build a TensorFlow model that can be used with TensorFlow Lite is to carefully consider how operations are converted and @@ -17,15 +19,15 @@ optimized, along with the limitations imposed by this process. ## Supported Types Most TensorFlow Lite operations target both floating-point (float32) and -quantized (uint8) inference, but usually there is little or no support for other -types like tf.float16 and strings. +quantized (uint8, int8) inference, but many ops do not yet support other types like +tf.float16 and strings. Apart from using different version of the operations, the other difference between floating-point and quantized models lies in the way they are converted. -Quantized conversion expect the models to be annotated with "fake quantization" -nodes that record the dynamic range of the tensors. Without that information TF -Lite is not able to accurately quantize a model, which means that proper -quantized training is necessary before conversion. +Quantized conversion requires dynamic range information for tensors. This +requires "fake-quantization" during model training, getting range information +via a calibration data set, or doing "on-the-fly" range estimation. See +[quantization](../performance/model_optimization.md). ## Data Format and Broadcasting @@ -89,7 +91,8 @@ be simply removed from the graph (tf.identity), replaced by tensors some supported operations may sometimes be removed through one of these processes. 
-Here is a list of TensorFlow operations that are usually removed from the graph: +Here is a non-exhaustive list of TensorFlow operations that are usually removed +from the graph: * [tf.add](https://www.tensorflow.org/api_docs/python/tf/add) * [tf.check_numerics](https://www.tensorflow.org/api_docs/python/tf/check_numerics) @@ -165,6 +168,17 @@ Options { } ``` +**ADD_N** + +``` +Inputs { + 0-N: any number of tensors (must have same size and shape) +} +Outputs { + 0: elementwise sum of the input tensors +} +``` + **ARG_MAX** ``` @@ -291,6 +305,17 @@ Options { } ``` +**ELU** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a tensor equivalent to exp(features) - 1 if < 0, features otherwise. +} +``` + **EQUAL** ``` @@ -362,6 +387,17 @@ Outputs { } ``` +**CEIL** + +``` +inputs { + 0: tensor +} +outputs: { + 0: result of computing element-wise ceil of the input tensor +} +``` + **FULLY_CONNECTED** ``` @@ -392,6 +428,18 @@ Outputs { } ``` +**GATHER_ND** + +``` +Inputs { + 0: params tensor + 1: indices tensor +} +Outputs { + 0: a tensor with same type as the params tensor. +} +``` + **GREATER** ``` @@ -463,20 +511,6 @@ Options { } ``` -**LEAKY_RELU** - -``` -Inputs { - 0: a tensor -} -Outputs { - 0: a tensor equivalent to max(input, input * alpha) -} -Options { - alpha -} -``` - **LESS** ``` @@ -691,6 +725,17 @@ Options { } ``` +**RANK** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a 0-D int32 Tensor representing the rank of input +} +``` + **RELU** ``` @@ -747,7 +792,7 @@ Inputs { 1: a 1D tensor with 2 elements } Outputs { - 0: A tensor of type `tensor 0` resized according to `tensor 1` heigh/width values + 0: A tensor of type `tensor 0` resized according to `tensor 1` height/width values using nearest neighbors interpolation. 
} Options { @@ -766,6 +811,25 @@ Outputs { } ``` +**REVERSE_SEQUENCE** + +``` +Inputs { + 0: a tensor + 1: a 1-D tensor which specifies the length of sequence to be reversed in each + dim +} +Outputs { + 0: a tensor with the same shape as the input tensor +} +Options { + seq_dim: a 0-D int tensor (scalar). The dimension which is partially + reversed. + batch_dim: a 0-D int tensor (scalar). Defaults to 0. The dimension along + which reversal is performed. +} +``` + **SHAPE** ``` @@ -978,6 +1042,22 @@ Outputs { } ``` +**WHERE** + +``` +Inputs { + 0: A tensor of type bool. + 1: A tensor which may have the same shape as condition. If condition is rank + 1, x may have higher rank, but its first dimension must match the size of + condition. + 2: A tensor with the same shape and type as x. +} +Outputs { + 0: A tensor with the same type and shape as x, y if they are non-None, or + a tensor with shape (num_true, dim_size(condition)). +} +``` + **ZEROS_LIKE** ``` diff --git a/tensorflow/lite/g3doc/custom_operators.md b/tensorflow/lite/g3doc/guide/ops_custom.md similarity index 57% rename from tensorflow/lite/g3doc/custom_operators.md rename to tensorflow/lite/g3doc/guide/ops_custom.md index 2d80668f37d645054596e1150f1eee6249122f75..1365f449b9c1ca9b86de8051c8866c4cd6c467d2 100644 --- a/tensorflow/lite/g3doc/custom_operators.md +++ b/tensorflow/lite/g3doc/guide/ops_custom.md @@ -1,13 +1,37 @@ +# Custom operators -# How to use custom operators +TensorFlow Lite currently supports a subset of TensorFlow operators. It supports +the use of user-provided implementations (as known as custom implementations) if +the model contains an operator that is not supported. Providing custom kernels +is also a way of evaluating a series of TensorFlow operations as a single fused +TensorFlow Lite operations. -TensorFlow Lite currently supports a subset of TensorFlow operators. 
However, it -does support the use of user-provided implementations (as known as custom -implementations) if the model contains an operator that is not supported. +Using custom operators consists of three steps. -Let’s walk through this via an example. Assume we are using the `Sin` operator -and that we are building a very simple model for a function `y = sin(x + -offset)`, where `offset` is trainable. +* Making sure the TensorFlow Graph Def or SavedModel refers to the correctly + named TensorFlow Lite operator. + +* Registering a custom kernel with TensorFlow Lite so that the runtime knows + how to map your operator and parameters in your graph to executable C/C++ + code. + +* Testing and profiling your operator correctness and performance, + respectively. If you wish to test just your custom operator it is best to + create a model with just your custom operator and use the + [benchmark_model](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/benchmark/benchmark_model_test.cc) + program. + +Below we describe a complete example of defining Sin and some links to existing +conversion process involving custom operators. + +## Making a custom operator for Sin + +Let’s walk through an example of supporting a TensorFlow operator that +TensorFlow Lite does not have. Assume we are using the `Sin` operator and that +we are building a very simple model for a function `y = sin(x + offset)`, where +`offset` is trainable. + +### Generating the model from TensorFlow The code to train the TensorFlow model will be something like: @@ -30,6 +54,8 @@ Didn't find custom op for name 'Sin' Registration failed. ``` +### Defining the kernel in the TensorFlow Lite runtime + All we need to do to use the op in TensorFlow Lite is define two functions (`Prepare` and `Eval`), and construct a `TfLiteRegistration`.
This code would look something like this: @@ -79,19 +105,34 @@ TfLiteRegistration* Register_SIN() { } ``` -When initializing the OpResolver, add the custom op into the resolver, this will -register the operator with Tensorflow Lite so that TensorFlow Lite can use the -new implementation. +When initializing the `OpResolver`, add the custom op into the resolver, this +will register the operator with Tensorflow Lite so that TensorFlow Lite can use +the new implementation. Note that the last two arguments in TfLiteRegistration +correspond to the `SinPrepare` and `SinEval()` functions you defined for the +custom op. If you used two functions to initialize variables used in the op and +free up space: `Init()` and `Free()`, then they would be added to the first two +arguments of TfLiteRegistration; they are set to nullptr in this example. ```cpp tflite::ops::builtin::BuiltinOpResolver builtins; builtins.AddCustom("Sin", Register_SIN()); ``` +If you want to make your custom operators in Java, you would currently need to +build your own custom JNI layer and compile your own AAR +[in this jni code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc). +Similarly, if you wish to make these operators available in Python you can place +your registrations in the +[Python wrapper code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc). + Note that a similar process as above can be followed for supporting for a set of -operations instead of a single operator. +operations instead of a single operator. Just add as many `AddCustom` operators +as you need. In addition, `BuiltinOpResolver` also allows you to override +implementations of builtins by using the `AddBuiltin`. + +## Best Practices -## Best Practices for writing custom operators +### Writing TensorFlow Lite kernels best practices 1. Optimize memory allocations and de-allocations cautiously. 
It is more efficient to allocate memory in Prepare() instead of Invoke(), and allocate @@ -135,13 +176,35 @@ operations instead of a single operator. Your code must not leave memory hanging when TF_LITE_ENSURE is done, i.e., these should be done before any resources are allocated that will leak. -## Special TF Graph Attributes +### Conversion best practices + +The example above was easy to convert since it was a builtin operator in +TensorFlow. If you are defining a new operator that fuses many operators or you +have complicated shapes or types, you might need to provide more information and +use graph transformations to rewrite an existing graph to use your operator +instead of the builtin TensorFlow one. + +#### Converting TensorFlow models to convert graphs + +In TensorFlow you can use the `tf.lite.OpHint` class to encapsulate groups of +operators when you create a TensorFlow graph. This allows you then to extract a +graph def that has references to those operators. This is currently experimental +and should only be used by advanced users. There is a full example of how to use +this in the +[OpHint code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/op_hint.py). + +In addition, you can also use a manual graph substitution approach to rewrite +Tensorflow graphs. There is an example of how this is done in single shot object +based detection models +[export script](https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_ssd_graph.py). + +### TF Graph Attributes When `tflite_convert` converts a TensorFlow graph into TFLite format, it makes -some assumption about custom operations that might be not correct. In this case, +some assumption about custom operations that might not be correct. In this case, the generated graph may not execute. 
-It is possible to add aditional information about your custom op output to TF +It is possible to add additional information about your custom op output to TF graph before it is converted. The following attributes are supported: - **_output_quantized** a boolean attribute, true if the operation outputs are @@ -149,7 +212,7 @@ graph before it is converted. The following attributes are supported: - **_output_types** a list of types for output tensors - **_output_shapes** a list of shapes for output tensors -### Setting the Attributes +#### Setting the Attributes This is an example how the attributes can be set: diff --git a/tensorflow/lite/g3doc/using_select_tf_ops.md b/tensorflow/lite/g3doc/guide/ops_select.md similarity index 92% rename from tensorflow/lite/g3doc/using_select_tf_ops.md rename to tensorflow/lite/g3doc/guide/ops_select.md index 269774a4b10648f92aab5ee6bf5ae3687c263f75..51320140e1c085f3b44238a3586efe92819dc170 100644 --- a/tensorflow/lite/g3doc/using_select_tf_ops.md +++ b/tensorflow/lite/g3doc/guide/ops_select.md @@ -1,4 +1,6 @@ -# [Experimental] Using TensorFlow Lite with select TensorFlow ops +# Select TensorFlow operators to use in TensorFlow Lite + +Caution: This feature is experimental. The TensorFlow Lite builtin op library has grown rapidly, and will continue to grow, but there remains a long tail of TensorFlow ops that are not yet natively @@ -13,7 +15,7 @@ please send feedback about models that work and issues you are facing to tflite@tensorflow.org. TensorFlow Lite will continue to have -[TensorFlow Lite builtin ops](tf_ops_compatibility.md) optimized for mobile and +[TensorFlow Lite builtin ops](ops_compatibility.md) optimized for mobile and embedded devices. However, TensorFlow Lite models can now use a subset of TensorFlow ops when TFLite builtin ops are not sufficient. @@ -32,7 +34,7 @@ choice. 
It also discusses some [known limitations](#known-limitations), the To convert a TensorFlow model to a TensorFlow Lite model with TensorFlow ops, use the `target_ops` argument in the -[TensorFlow Lite converter](https://www.tensorflow.org/lite/convert/). The +[TensorFlow Lite converter](../convert/). The following values are valid options for `target_ops`: * `TFLITE_BUILTINS` - Converts models using TensorFlow Lite builtin ops. @@ -49,8 +51,7 @@ partially supported by TensorFlow Lite, and one would like to avoid those limitations. The following example shows how to use `target_ops` in the -[`TFLiteConverter`](https://www.tensorflow.org/lite/convert/python_api) Python -API. +[`TFLiteConverter`](../convert/python_api.md) Python API. ``` import tensorflow as tf @@ -63,7 +64,7 @@ open("converted_model.tflite", "wb").write(tflite_model) ``` The following example shows how to use `target_ops` in the -[`tflite_convert`](https://www.tensorflow.org/lite/convert/cmdline_examples) +[`tflite_convert`](../convert/cmdline_examples.md) command line tool. ``` @@ -96,7 +97,7 @@ includes the necessary library of TensorFlow ops. ### Android AAR A new Android AAR target with select TensorFlow ops has been added for -convenience. Assuming a working TensorFlow Lite +convenience. Assuming a working TensorFlow Lite build environment, build the Android AAR with select TensorFlow ops as follows: @@ -151,8 +152,8 @@ TensorFlow Lite XCode project with support for select TensorFlow ops has been added to `tensorflow/lite/examples/ios/camera/tflite_camera_example_with_select_tf_ops.xcodeproj`.
-To use this feature in a your own project, either clone the example project or -set the project settings for a new or existing project to the following: +To use this feature in your own project, either clone the example project or set +the project settings for a new or existing project to the following: * In Build Phases -> Link Binary With Libraries, add the static libraries under `tensorflow/contrib/makefile/gen/lib/` directory: @@ -197,9 +198,7 @@ Python support is actively under development. When using a mixture of both builtin and select TensorFlow ops, all of the same TensorFlow Lite optimizations and optimized builtin kernels will be be available -and usable with the converted model. For the TensorFlow ops, performance should -generally be comparable to that of -[TensorFlow Mobile](https://www.tensorflow.org/lite/tfmobile/). +and usable with the converted model. The following table describes the average time taken to run inference on MobileNet on a Pixel 2. The listed times are an average of 100 runs. These diff --git a/tensorflow/lite/g3doc/ops_versioning.md b/tensorflow/lite/g3doc/guide/ops_version.md similarity index 99% rename from tensorflow/lite/g3doc/ops_versioning.md rename to tensorflow/lite/g3doc/guide/ops_version.md index 0d571ce54779547a5e3457b089b791abca858930..9418ce4e92af3f43378181eaa836461edfed987a 100644 --- a/tensorflow/lite/g3doc/ops_versioning.md +++ b/tensorflow/lite/g3doc/guide/ops_version.md @@ -1,5 +1,4 @@ - -# TensorFlow Lite Ops Versioning +# TensorFlow Lite operator versions This document describes TensorFlow Lite's op versioning schema. 
Op versioning enables developers to add new functionalities and parameters into diff --git a/tensorflow/lite/g3doc/guide/roadmap.md b/tensorflow/lite/g3doc/guide/roadmap.md new file mode 100644 index 0000000000000000000000000000000000000000..a93591813c3e5e58b7d65b27f2d3f40b0e9e723a --- /dev/null +++ b/tensorflow/lite/g3doc/guide/roadmap.md @@ -0,0 +1,77 @@ +# TensorFlow Lite 2019 Roadmap + +**Updated: March 6th, 2019** + +The following represents a high-level overview of our 2019 plan. You should be +conscious that this roadmap may change at any time relative to a range of factors +and the order below does not reflect any type of priority. As a matter of +principle, we typically prioritize issues that the majority of our users are +asking for and so this list fundamentally reflects that. + +We break our roadmap into four key segments: usability, performance, +optimization and portability. We strongly encourage you to comment on our +roadmap and provide us feedback in the TF Lite discussion groups and forums. + +## Usability + +* **More ops coverage** + * Prioritize many more ops based on user feedback +* **Op versioning & signatures** + * Op kernels will get version numbers + * Op kernels will be identifiable by signature +* **New Converter** + * Implementing a new TensorFlow Lite converter that will better handle + graph conversion (e.g. control flow, conditionals etc) and replace TOCO +* **Continue to improve TF Select Ops** + * Support more types of conversion utilizing TF Selects such as hash + tables, strings etc.
+ * Support smaller binary size when using select TF ops via op stripping +* **LSTM / RNN support** + * Add full support of conversion for LSTMs and RNNs +* **Graph Visualization Tooling** + * Provide enhanced graph visualization tooling +* **Pre-and-post processing support** + * Add more support for pre-and-post processing of inference +* **Control Flow & Training on-device** + * Add support for control flow related ops + * Add support for training on-device +* **New APIs** + * New C API as core for language bindings and most clients + * Objective-C API for iOS + * SWIFT API for iOS + * Updated Java API for Android + * C# Unity language bindings +* **Add more Models** + * Add more models to the support section of the site + +## Performance + +* **More hardware delegates** + * Add support for more hardware delegates +* **Support NN API** + * Continually support and improve support for NN API +* **Framework Extensibility** + * Enable simplistic overwriting of CPU kernels with customized optimized + versions +* **GPU Delegate** + * Continue to extend the total support ops for OpenGL and Metal ops + * Open-source +* **Improve TFLite CPU performance** + * Optimizations for float and quantized models + +## Optimization + +* **Model Optimization Toolkit** + * Post training quantization + hybrid kernels + * Post Training quantization + fixed-point kernels + * Training with quantization +* **More support for more techniques** + * RNN Support + * Sparsity/Pruning + * Lower bit-width support + +## Portability + +* **Microcontroller Support** + * Add support for a range of 8-bit, 16-bit and 32-bit MCU architecture use + cases for Speech and Image Classification diff --git a/tensorflow/lite/g3doc/images/ios/build_and_execute.png b/tensorflow/lite/g3doc/images/ios/build_and_execute.png new file mode 100644 index 0000000000000000000000000000000000000000..a305350cb7b413b70d2095cfd8fea6d396c63695 Binary files /dev/null and b/tensorflow/lite/g3doc/images/ios/build_and_execute.png 
differ diff --git a/tensorflow/lite/g3doc/images/ios/device_selection.png b/tensorflow/lite/g3doc/images/ios/device_selection.png new file mode 100644 index 0000000000000000000000000000000000000000..1565fa0f2b65d48f4308ba50caacf745125c7431 Binary files /dev/null and b/tensorflow/lite/g3doc/images/ios/device_selection.png differ diff --git a/tensorflow/lite/g3doc/images/performance/tflite_delegate_graph_1.png b/tensorflow/lite/g3doc/images/performance/tflite_delegate_graph_1.png new file mode 100644 index 0000000000000000000000000000000000000000..29b883a40f3a18f3db183887dda253a5a86d7c13 Binary files /dev/null and b/tensorflow/lite/g3doc/images/performance/tflite_delegate_graph_1.png differ diff --git a/tensorflow/lite/g3doc/images/performance/tflite_delegate_graph_2.png b/tensorflow/lite/g3doc/images/performance/tflite_delegate_graph_2.png new file mode 100644 index 0000000000000000000000000000000000000000..197fa216e29a7671b35dc57fbd517c4c2b543784 Binary files /dev/null and b/tensorflow/lite/g3doc/images/performance/tflite_delegate_graph_2.png differ diff --git a/tensorflow/lite/g3doc/models/_index.yaml b/tensorflow/lite/g3doc/models/_index.yaml deleted file mode 100644 index f4d8bc40a9325b12734022e005996e13dba0a0d6..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/models/_index.yaml +++ /dev/null @@ -1,125 +0,0 @@ -project_path: /lite/_project.yaml -book_path: /lite/_book.yaml -description: -landing_page: - body_class: tfo-hide-page-nav - custom_css_path: /site-assets/css/style.css - show_side_navs: true - rows: - - # Hero - - classname: > - devsite-landing-row-50 - devsite-landing-row-large-headings - devsite-landing-row-no-image-background - foreground: theme - items: - - heading: Models marketplace - description: > - The TensorFlow Lite models marketplace, your neighborhood model shoppe. 
- image_path: /resources/images/tflite-card-16x9.png - - # Features - - background: grey - items: - - heading: Optimized for mobile - description: > - Machine learning can make your apps more engaging, personalized, and - helpful, and provides solutions that are optimized to run on-device. - - heading: Built with Google expertise - description: > - Models offer the technologies that have long powered Google's own - experiences on mobile. - - heading: Approachable and comprehensive - description: > - Use out-of-the-box solutions (base APIs) or custom models, running - on-device or in the Cloud, depending on your specific needs. - - # Beginner models - - classname: devsite-landing-row-100 - heading: "Build machine learning into your apps" - items: - - heading: > - Image labeling - description: > - Identify objects, locations, activities, animal species, products, and - more - icon: - path: ../images/landing-page/assistant_logo.png - path: /lite/image/labeling/ - - heading: > - Text recognition (OCR) - description: > - Recognize and extract text from images - icon: - path: ../images/landing-page/assistant_logo.png - path: /lite/image/labeling/ - - heading: > - Face detection - description: > - Detect faces and facial landmarks - icon: - path: ../images/landing-page/assistant_logo.png - path: /lite/image/labeling/ - - - items: - - heading: > - Barcode scanning - description: > - Scan and process barcodes - icon: - path: ../images/landing-page/assistant_logo.png - path: /lite/image/labeling/ - - heading: > - Landmark detection - description: > - Identify popular landmarks in an image - icon: - path: ../images/landing-page/assistant_logo.png - path: /lite/image/labeling/ - - heading: > - Smart reply - description: > - Provide suggested text snippet that fits context - icon: - path: ../images/landing-page/assistant_logo.png - path: /lite/image/labeling/ - - # Custom models - - classname: > - devsite-landing-row-no-image-background - devsite-landing-row-50 - 
devsite-landing-row-large-headings - foreground: theme - background: grey - items: - - heading: Custom models - description: > -

If models don’t cover your use cases, you can always - bring your own existing TensorFlow Lite models. Just upload your model, - and we’ll take care of hosting and serving it to your app.

- -

Models acts as an API layer to your custom model, making it easy to - run and use. In addition to deploying your models, we are releasing an - experimental model compression flow that aims to reduce model size (up - to orders of magnitudes) while maintaining similar accuracy. Sign up at - g.co/firebase/signup

- -

And if you’re new to machine learning and want more information on - custom models for mobile, you can learn more about TensorFlow - Lite.

- image_path: /resources/images/tflite-card-16x9.png - image_left: true - - classname: devsite-landing-row-large-headings - foreground: theme - items: - - heading: Just the beginning - description: > - Our ultimate goal is to reduce idea–to–implementation cycles and make AI - an essential and intuitive part of a developer's toolkit. We will do so - by continuing to add new Base APIs that leverage Google’s machine - learning expertise. Base APIs will ultimately cover significantly more - use cases in the vision, speech, and text fields. We will also continue - to simplify use of custom models, adding tools to deploy, compress, and - create them. diff --git a/tensorflow/lite/g3doc/models/image/label/ios.md b/tensorflow/lite/g3doc/models/image/label/ios.md deleted file mode 100644 index 904c6450ac7272e67c1982b56099b608b91e2237..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/models/image/label/ios.md +++ /dev/null @@ -1,3 +0,0 @@ -# iOS - -lorem diff --git a/tensorflow/lite/g3doc/models/image/label/overview.md b/tensorflow/lite/g3doc/models/image/label/overview.md deleted file mode 100644 index b3d9133bb2123012f2ddd2db768347305d224744..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/models/image/label/overview.md +++ /dev/null @@ -1,8 +0,0 @@ -# Overview - -Image labeling gives you insight into the content of images. When you use the -API, you get a list of the entities that were recognized: people, things, -places, activities, and so on. Each label found comes with a score that -indicates the confidence the ML model has in its relevance. With this -information, you can perform tasks such as automatic metadata generation -and content moderation. 
diff --git a/tensorflow/lite/g3doc/models/image/label/android.md b/tensorflow/lite/g3doc/models/image_classification/android.md similarity index 85% rename from tensorflow/lite/g3doc/models/image/label/android.md rename to tensorflow/lite/g3doc/models/image_classification/android.md index c755328ac059013d2d45bbeb3c67516dafbb0ff1..61606096f77ce810c2b1a686cd05599b35200f57 100644 --- a/tensorflow/lite/g3doc/models/image/label/android.md +++ b/tensorflow/lite/g3doc/models/image_classification/android.md @@ -3,21 +3,20 @@ This tutorial provides a simple Android mobile application to classify images using the Android device camera. In this tutorial, you will download the demo application from the Tensorflow repository, build it on your computer, and -install it on your Android Device. You will also learn how to customize the +install it on your Android device. You will also learn how to customize the application to suit your requirements. ### Prerequisites * Android Studio 3.2 (installed on a Linux, Mac or Windows machine) -* Android Device +* Android device * USB cable (to connect Android device to your computer) ### Step 1. Clone the TensorFlow source code -First, we clone the GitHub repository on the computer in a folder to get the -demo application. +Clone the GitHub repository to your computer to get the demo application. ``` @@ -29,21 +28,21 @@ Open the TensorFlow source code in Android Studio. To do this, open Android Studio and select `Open an existing project` setting the folder to `tensorflow/lite/examples/android` -![Step 1](images/classifydemo_img1.png) + This folder contains the demo application for image classification, object detection, and speech hotword detection. ### Step 2. Build the Android Studio project -In this step, Select `Build -> Make Project` and check that the project builds +Select `Build -> Make Project` and check that the project builds successfully. You will need Android SDK configured in the settings. You'll need at least SDK version 23. 
The gradle file will prompt you to download any missing libraries. -![Step 2](images/classifydemo_img4.png) + -![Step 2a](images/classifydemo_img2.png) + #### TensorFlow Lite AAR from JCenter: @@ -53,25 +52,25 @@ build. If you see a build error related to compatibility with Tensorflow Lite's Java API (example: method X is undefined for type Interpreter), there has likely been a backwards compatible change to the API. You will need to pull new app code -that's compatible with the nightly build by running git pull. +that's compatible with the nightly build by running `git pull`. -### Step 3. Install and Run the app +### Step 3. Install and run the app -Connect the Android device to the computer, and be sure to approve any ADB +Connect the Android device to the computer and be sure to approve any ADB permission prompts that appear on your phone. Select `Run -> Run app.` Select -the deployment target in the connected devices to the device on which app will +the deployment target in the connected devices to the device on which the app will be installed. This will install the app on the device. -![Step 3](images/classifydemo_img5.png) + -![Step 3a](images/classifydemo_img6.png) + -![Step 3b](images/classifydemo_img7.png) + -![Step 3c](images/classifydemo_img8.png) + -To test the app, open the app named `TFL Classify` on the device. When you run -the app first time, the app will request permission to access the camera. +To test the app, open the app called `TFL Classify` on your device. When you run +the app the first time, the app will request permission to access the camera. Re-installing the app may require you to uninstall the previous installations. ## Understanding Android App Code @@ -85,7 +84,7 @@ This file depends on `AndroidManifest.xml` in the folder `tensorflow/tensorflow/lite/examples/android/app/src/main` to set the camera orientation. 
-### Pre-process of bitmap image +### Pre-process bitmap image The mobile application code that pre-processes the images and runs inference is in @@ -102,7 +101,7 @@ DIM_PIXEL_SIZE); c.imgData.order(ByteOrder.nativeOrder()); ``` -While running the application, we preprocess the incoming bitmap images from the +While running the application, we pre-process the incoming bitmap images from the camera to a Bytebuffer. Since this model is quantized 8-bit, we will put a single byte for each channel. `imgData` will contain an encoded `Color` for each pixel in ARGB format, so we need to mask the least significant 8 bits to get diff --git a/tensorflow/lite/g3doc/models/image_classification/images/android_banana.png b/tensorflow/lite/g3doc/models/image_classification/images/android_banana.png new file mode 100644 index 0000000000000000000000000000000000000000..a25dffe3a070a8af509efab823193de3d2f80f49 Binary files /dev/null and b/tensorflow/lite/g3doc/models/image_classification/images/android_banana.png differ diff --git a/tensorflow/lite/g3doc/models/image_classification/images/build_and_execute.png b/tensorflow/lite/g3doc/models/image_classification/images/build_and_execute.png new file mode 100644 index 0000000000000000000000000000000000000000..a305350cb7b413b70d2095cfd8fea6d396c63695 Binary files /dev/null and b/tensorflow/lite/g3doc/models/image_classification/images/build_and_execute.png differ diff --git a/tensorflow/lite/g3doc/models/image_classification/images/bundle_identifier.png b/tensorflow/lite/g3doc/models/image_classification/images/bundle_identifier.png new file mode 100644 index 0000000000000000000000000000000000000000..398763916b353e61f236392e2b8898aad2aafe8e Binary files /dev/null and b/tensorflow/lite/g3doc/models/image_classification/images/bundle_identifier.png differ diff --git a/tensorflow/lite/g3doc/models/image/label/images/classifydemo_img1.png b/tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img1.png similarity index 100% 
rename from tensorflow/lite/g3doc/models/image/label/images/classifydemo_img1.png rename to tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img1.png diff --git a/tensorflow/lite/g3doc/models/image/label/images/classifydemo_img2.png b/tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img2.png similarity index 100% rename from tensorflow/lite/g3doc/models/image/label/images/classifydemo_img2.png rename to tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img2.png diff --git a/tensorflow/lite/g3doc/models/image/label/images/classifydemo_img4.png b/tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img4.png similarity index 100% rename from tensorflow/lite/g3doc/models/image/label/images/classifydemo_img4.png rename to tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img4.png diff --git a/tensorflow/lite/g3doc/models/image/label/images/classifydemo_img5.png b/tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img5.png similarity index 100% rename from tensorflow/lite/g3doc/models/image/label/images/classifydemo_img5.png rename to tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img5.png diff --git a/tensorflow/lite/g3doc/models/image/label/images/classifydemo_img6.png b/tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img6.png similarity index 100% rename from tensorflow/lite/g3doc/models/image/label/images/classifydemo_img6.png rename to tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img6.png diff --git a/tensorflow/lite/g3doc/models/image/label/images/classifydemo_img7.png b/tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img7.png similarity index 100% rename from tensorflow/lite/g3doc/models/image/label/images/classifydemo_img7.png rename to tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img7.png diff --git 
a/tensorflow/lite/g3doc/models/image/label/images/classifydemo_img8.png b/tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img8.png similarity index 100% rename from tensorflow/lite/g3doc/models/image/label/images/classifydemo_img8.png rename to tensorflow/lite/g3doc/models/image_classification/images/classifydemo_img8.png diff --git a/tensorflow/lite/g3doc/models/image_classification/images/device_selection.png b/tensorflow/lite/g3doc/models/image_classification/images/device_selection.png new file mode 100644 index 0000000000000000000000000000000000000000..1565fa0f2b65d48f4308ba50caacf745125c7431 Binary files /dev/null and b/tensorflow/lite/g3doc/models/image_classification/images/device_selection.png differ diff --git a/tensorflow/lite/g3doc/models/image_classification/images/dog.png b/tensorflow/lite/g3doc/models/image_classification/images/dog.png new file mode 100644 index 0000000000000000000000000000000000000000..65c6eb551468be3b53dc805009688c7b5808c660 Binary files /dev/null and b/tensorflow/lite/g3doc/models/image_classification/images/dog.png differ diff --git a/tensorflow/lite/g3doc/models/image_classification/ios.md b/tensorflow/lite/g3doc/models/image_classification/ios.md new file mode 100644 index 0000000000000000000000000000000000000000..63e3abd779355b842964ae8836f24a1cd7a8832f --- /dev/null +++ b/tensorflow/lite/g3doc/models/image_classification/ios.md @@ -0,0 +1,229 @@ +# TensorFlow Lite iOS Image Classifier App Example + +This tutorial provides a simple iOS mobile application to classify images using +the iOS device camera. In this tutorial, you will download the demo application +from the Tensorflow repository, build it on your computer, and install it on +your iOS Device. You will also learn how to customize the application to suit +your needs. 
+ +## Prerequisites + +* You must have [Xcode](https://developer.apple.com/xcode/) installed and have + a valid Apple Developer ID, and have an iOS device set up and linked to your + developer account with all of the appropriate certificates. For these + instructions, we assume that you have already been able to build and deploy + an app to an iOS device with your current developer environment. + +* The demo app requires a camera and must be executed on a real iOS device. + You can build and run it with the iPhone Simulator but it won't have any + camera information to classify. + +* You don't need to build the entire TensorFlow library to run the demo, but + you will need to clone the TensorFlow repository if you haven't already: + + git clone https://github.com/tensorflow/tensorflow + cd tensorflow + +* You'll also need the Xcode command-line tools: + + xcode-select --install + + If this is a new install, you will need to run the Xcode application once to + agree to the license before continuing. + +* Install CocoaPods if you don't have it: + + sudo gem install cocoapods + +### Step 1. Clone the TensorFlow source code + +Clone the GitHub repository onto your computer to get the +demo application. + +``` +git clone https://github.com/tensorflow/tensorflow +``` + +### Step 2. Download required dependencies + +Execute the shell script to download the model files used by the demo app (this +is done from inside the cloned directory): + +``` + tensorflow/lite/examples/ios/download_models.sh +``` + +Run the following command to install the TensorFlow Lite pod: + +``` + cd tensorflow/lite/examples/ios/camera + pod install +``` + +If you have installed this pod before and that command doesn't work, try + +``` + pod repo update +``` + +### Step 3.
Build the Xcode project + +Open the `tflite_camera_example.xcworkspace` project file generated in the last +step: + +``` + open tflite_camera_example.xcworkspace +``` + +Under `Project navigator -> tflite_camera_example -> Targets -> +tflite_camera_example -> General` change the bundle identifier by prepending +your name: + +![pre-pend your name to the bundle identifier](images/bundle_identifier.png) + +Plug in your iOS device. Note that the app must be executed with a real device with +a camera. Select the iOS device from the drop-down menu. + +![Device selection](images/device_selection.png) + +Click the "Run" button to build and run the app. + +![Build and execute](images/build_and_execute.png) + +Note that, as mentioned earlier, you must already have a device set up and linked +to your Apple Developer account in order to deploy the app onto a device. + +You'll have to grant permissions for the app to use the device's camera. Point +the camera at various objects and enjoy seeing how the model classifies things! + +## Understanding iOS App Code + +### Get camera input + +The main logic of this app is in the Objective-C++ source file +`tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm`. + +The `setupAVCapture` method constructs an `AVCaptureSession` and sets itself as a +delegate. The `captureOutput:didOutputSampleBuffer:fromConnection:` method is +called for every captured frame. It calls `runModelOnFrame` to run the model for +every frame. + +### Create an interpreter + +To create the interpreter, we need to load the model file. The following code +will load a model and create an interpreter. + +``` +model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]); +``` + +Behind the scenes, the model is loaded as a memory-mapped file. It offers faster +load times and reduces the dirty pages in memory. + +Construct a `BuiltinOpResolver` to use the TensorFlow Lite built-in ops. 
Then, +create the interpreter object using `InterpreterBuilder` that takes the model +file as argument as shown below. + +``` +tflite::ops::builtin::BuiltinOpResolver resolver; +tflite::InterpreterBuilder(*model, resolver)(&interpreter); +``` + +### Obtain the input buffer + +By default, the app uses a quantized model since it's smaller and faster. The +buffer is a raw pointer to an array of 8 bit unsigned integers (`uint8_t`). The +following code obtains the input buffer from the interpreter: + +``` +// Get the index of first input tensor. +int input_tensor_index = interpreter->inputs()[0]; +// Get the pointer to the input buffer. +uint8_t* buffer = interpreter->typed_tensor<uint8_t>(input_tensor_index); +``` + +Throughout this document, it's assumed that a quantized model is used. + +### Pre-process bitmap image + +The MobileNet model that we're using takes 224x224x3 inputs, where the dimensions are +width, height, and colors (RGB). The images returned from `AVCaptureSession` are +bigger and have 4 color channels (RGBA). + +Many image classification models (like MobileNet) take fixed-size inputs. It's +required to scale or crop the image before feeding it into the model and change +the channels from RGBA to RGB. + +The code to pre-process the images is in the `ProcessInputWithQuantizedModel` +function in +`tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm`. It's a +simple implementation for nearest neighbor color sampling and it only copies +the first 3 bytes for each pixel. 
+ +``` +void ProcessInputWithQuantizedModel( + uint8_t* input, uint8_t* output, int image_width, int image_height, int image_channels) { + for (int y = 0; y < wanted_input_height; ++y) { + uint8_t* out_row = output + (y * wanted_input_width * wanted_input_channels); + for (int x = 0; x < wanted_input_width; ++x) { + const int in_x = (y * image_width) / wanted_input_width; + const int in_y = (x * image_height) / wanted_input_height; + uint8_t* in_pixel = input + (in_y * image_width * image_channels) + (in_x * image_channels); + uint8_t* out_pixel = out_row + (x * wanted_input_channels); + for (int c = 0; c < wanted_input_channels; ++c) { + out_pixel[c] = in_pixel[c]; + } + } + } +} +``` + +Note that the code pre-processes and prepares the model input from the camera +data. Therefore, the first parameter `input` should be the camera buffer. The +second parameter `output` should be the buffer of model input. + +### Run inference and obtain output buffer + +After pre-processing and filling the data into the input buffer of the +interpreter, it's really easy to run the interpreter: + +``` +if (interpreter->Invoke() != kTfLiteOk) { + NSLog(@"Failed to invoke!"); +} +``` + +The result is stored in the output tensor buffer of the interpreter. The +following code obtains the pointer to the buffer: + +``` +// Get the index of first output tensor. +const int output_tensor_index = interpreter->outputs()[0]; +// Get the pointer to the output buffer. +uint8_t* buffer = interpreter->typed_tensor<uint8_t>(output_tensor_index); +``` + +### Post-process values + +The output buffer contains an array of `uint8_t`, and the value range is from 0-255. +We need to convert the value to float to get the probabilities with a value range from +0.0-1.0. 
The formula of the quantization value mapping is: + + float_value = (quantized_value - zero_point) * scale + +The following code converts quantized values back to float values, using the +quantization parameters in tensors: + +``` +uint8_t* quantized_output = interpreter->typed_output_tensor<uint8_t>(0); +int32_t zero_point = input_tensor->params.zero_point; +float scale = input_tensor->params.scale; +float output[output_size]; +for (int i = 0; i < output_size; ++i) { + output[i] = (quantized_output[i] - zero_point) * scale; +} +``` + +Finally, we find the best set of classifications by storing them in a priority +queue based on their confidence scores. See the `GetTopN` function in +`tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm`. diff --git a/tensorflow/lite/g3doc/models/image_classification/overview.md b/tensorflow/lite/g3doc/models/image_classification/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..9ddbaf43ef092456b23d65684e1fdd7609c58472 --- /dev/null +++ b/tensorflow/lite/g3doc/models/image_classification/overview.md @@ -0,0 +1,285 @@ +# Image classification + + + +Use a pre-trained and optimized model to identify hundreds of classes of +objects, including people, activities, animals, plants, and places. + +## Get started + +If you are unfamiliar with the concept of image classification, you should start +by reading What is image +classification? + +If you understand image classification, you’re new to TensorFlow Lite, and +you’re working with Android or iOS, we recommend following the corresponding +tutorial that will walk you through our sample code. + +Android +iOS + +We also provide example applications you can +use to get started. + +If you are using a platform other than Android or iOS, or you are already +familiar with the TensorFlow Lite APIs, you can +download our starter image classification model and the accompanying labels. 
+ +Download +starter model and labels + +Once you have the starter model running on your target device, you can +experiment with different models to find the optimal balance between +performance, accuracy, and model size. For guidance, see +Choose a different model. + +If you are using a platform other than Android or iOS, or you are already +familiar with the TensorFlow Lite APIs, you can +download our starter image classification model and the accompanying labels. + +Download +starter model and labels + +### Example applications + +We have example applications for image classification for both Android and iOS. + +Android +example +iOS +example + +The following screenshot shows the Android image classification example: + +Screenshot of Android example + +## What is image classification? + +A common use of machine learning is to identify what an image represents. For +example, we might want to know what type of animal appears in the following +photograph. + +dog + +The task of predicting what an image represents is called _image +classification_. An image classification model is trained to recognize various +classes of images. For example, a model might be trained to recognize photos +representing three different types of animals: rabbits, hamsters, and dogs. + +When we subsequently provide a new image as input to the model, it will output +the probabilities of the image representing each of the types of animal it was +trained on. An example output might be as follows: + + + + + + + + + + + + + + + + + + + + + + +
Animal typeProbability
Rabbit0.07
Hamster0.02
Dog0.91
+ +Based on the output, we can see that the classification model has predicted that +the image has a high probability of representing a dog. + +Note: Image classification can only tell you the probability that an image +represents one or more of the classes that the model was trained on. It cannot +tell you the position or identity of objects within the image. If you need to +identify objects and their positions within images, you should use an +object detection model. + +### Training, labels, and inference + +During training, an image classification model is fed images and their +associated _labels_. Each label is the name of a distinct concept, or class, +that the model will learn to recognize. + +Given sufficient training data (often hundreds or thousands of images per +label), an image classification model can learn to predict whether new images +belong to any of the classes it has been trained on. This process of prediction +is called _inference_. + +To perform inference, an image is passed as input to a model. The model will +then output an array of probabilities between 0 and 1. With our example model, +this process might look like the following: + + + + + + +
dog[0.07, 0.02, 0.91]
+ +Each number in the output corresponds to a label in our training data. +Associating our output with the three labels the model was trained on, we can +see the model has predicted a high probability that the image represents a dog. + + + + + + + + + + + + + + + + + + + + + + +
LabelProbability
rabbit0.07
hamster0.02
dog0.91
+ +You might notice that the sum of all the probabilities (for rabbit, hamster, and +dog) is equal to 1. This is a common type of output for models with multiple +classes (see +Softmax +for more information). + +### Ambiguous results + +Since the probabilities will always sum to 1, if the image is not confidently +recognized as belonging to any of the classes the model was trained on you may +see the probability distributed throughout the labels without any one value +being significantly larger. + +For example, the following might indicate an ambiguous result: + + + + + + + + + + + + + + + + + + + + + + +
LabelProbability
rabbit0.31
hamster0.35
dog0.34
+ +### Uses and limitations + +The image classification models that we provide are useful for single-label +classification, which means predicting which single label the image is most +likely to represent. They are trained to recognize 1000 classes of image. For a +full list of classes, see the labels file in the +model +zip. + +If you want to train a model to recognize new classes, see +Customize model. + +For the following use cases, you should use a different type of model: + +
    +
  • Predicting the type and position of one or more objects within an image (see object detection)
  • +
  • Predicting the composition of an image, for example subject versus background (see segmentation)
  • +
+ +Once you have the starter model running on your target device, you can +experiment with different models to find the optimal balance between +performance, accuracy, and model size. For guidance, see +Choose a different model. + +## Choose a different model + +There are a large number of image classification models available on our +List of hosted models. You should aim to choose the +optimal model for your application based on performance, accuracy and model +size. There are trade-offs between each of them. + +### Performance + +We measure performance in terms of the amount of time it takes for a model to +run inference on a given piece of hardware. The less time, the faster the model. + +The performance you require depends on your application. Performance can be +important for applications like real-time video, where it may be important to +analyze each frame in the time before the next frame is drawn (e.g. inference +must be faster than 33ms to perform real-time inference on a 30fps video +stream). + +Our quantized Mobilenet models’ performance ranges from 3.7ms to 80.3 ms. + +### Accuracy + +We measure accuracy in terms of how often the model correctly classifies an +image. For example, a model with a stated accuracy of 60% can be expected to +classify an image correctly an average of 60% of the time. + +Our list of hosted models provides Top-1 and Top-5 +accuracy statistics. Top-1 refers to how often the correct label appears as the +label with the highest probability in the model’s output. Top-5 refers to how +often the correct label appears in the top 5 highest probabilities in the +model’s output. + +Our quantized Mobilenet models’ Top-5 accuracy ranges from 64.4 to 89.9%. + +### Size + +The size of a model on-disk varies with its performance and accuracy. Size may +be important for mobile development (where it might impact app download sizes) +or when working with hardware (where available storage might be limited). 
+ +Our quantized Mobilenet models’ size ranges from 0.5 to 3.4 MB. + +### Architecture + +There are several different architectures of models available on +List of hosted models, indicated by the model’s name. +For example, you can choose between Mobilenet, Inception, and others. + +The architecture of a model impacts its performance, accuracy, and size. All of +our hosted models are trained on the same data, meaning you can use the provided +statistics to compare them and choose which is optimal for your application. + +Note: The image classification models we provide accept varying sizes of input. For some models, this is indicated in the filename. For example, the Mobilenet_V1_1.0_224 model accepts an input of 224x224 pixels.

All of the models require three color channels per pixel (red, green, and blue). Quantized models require 1 byte per channel, and float models require 4 bytes per channel.

Our Android and iOS code samples demonstrate how to process full-sized camera images into the required format for each model. + +## Customize model + +The pre-trained models we provide are trained to recognize 1000 classes of +image. For a full list of classes, see the labels file in the +model +zip. + +You can use a technique known as _transfer learning_ to re-train a model to +recognize classes not in the original set. For example, you could re-train the +model to distinguish between different species of tree, despite there being no +trees in the original training data. To do this, you will need a set of training +images for each of the new labels you wish to train. + +Learn how to perform transfer learning in the +TensorFlow +for Poets codelab. diff --git a/tensorflow/lite/g3doc/models/images/audio.png b/tensorflow/lite/g3doc/models/images/audio.png new file mode 100644 index 0000000000000000000000000000000000000000..ce6b25c442016a21600eb8249eafa55bacbba4e9 Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/audio.png differ diff --git a/tensorflow/lite/g3doc/models/images/blank.png b/tensorflow/lite/g3doc/models/images/blank.png new file mode 100644 index 0000000000000000000000000000000000000000..d099da5da07271410883554e07e37765ca048590 Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/blank.png differ diff --git a/tensorflow/lite/g3doc/models/images/camera.png b/tensorflow/lite/g3doc/models/images/camera.png new file mode 100644 index 0000000000000000000000000000000000000000..95a9218d47864aba12255bd32b67bb74b0d6704e Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/camera.png differ diff --git a/tensorflow/lite/g3doc/models/images/detection.png b/tensorflow/lite/g3doc/models/images/detection.png new file mode 100644 index 0000000000000000000000000000000000000000..30e10f59cd53af21fe9b6a86aa5b45ca07131b1b Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/detection.png differ diff --git 
a/tensorflow/lite/g3doc/models/images/image.png b/tensorflow/lite/g3doc/models/images/image.png new file mode 100644 index 0000000000000000000000000000000000000000..e72aac9b25eec69e8c0252f441d125340b88cab5 Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/image.png differ diff --git a/tensorflow/lite/g3doc/models/images/object.png b/tensorflow/lite/g3doc/models/images/object.png new file mode 100644 index 0000000000000000000000000000000000000000..aa8ed428ed15e7b166bdde560669563a224e6f6c Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/object.png differ diff --git a/tensorflow/lite/g3doc/models/images/output_stride.png b/tensorflow/lite/g3doc/models/images/output_stride.png new file mode 100644 index 0000000000000000000000000000000000000000..5d4663f8675eef733e18b2a5cb05670cd40d8293 Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/output_stride.png differ diff --git a/tensorflow/lite/g3doc/models/images/pose.png b/tensorflow/lite/g3doc/models/images/pose.png new file mode 100644 index 0000000000000000000000000000000000000000..f071d789963d0f48efb5ba20633391403f75ddf8 Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/pose.png differ diff --git a/tensorflow/lite/g3doc/models/images/segmentation.png b/tensorflow/lite/g3doc/models/images/segmentation.png new file mode 100644 index 0000000000000000000000000000000000000000..8c61330687cc9a388a443bc6b771027d15b66d98 Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/segmentation.png differ diff --git a/tensorflow/lite/g3doc/models/images/sentiment.png b/tensorflow/lite/g3doc/models/images/sentiment.png new file mode 100644 index 0000000000000000000000000000000000000000..2ba494fcb6e62a90015d2aead4779fcacab70529 Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/sentiment.png differ diff --git a/tensorflow/lite/g3doc/models/images/smart_reply.png b/tensorflow/lite/g3doc/models/images/smart_reply.png new file mode 100644 index 
0000000000000000000000000000000000000000..802cc80feebe2a46b059b23d52ccf794701e4d99 Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/smart_reply.png differ diff --git a/tensorflow/lite/g3doc/models/images/tabular.png b/tensorflow/lite/g3doc/models/images/tabular.png new file mode 100644 index 0000000000000000000000000000000000000000..2eac8f4c4ac74029c755a207b1f8a25592f468ac Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/tabular.png differ diff --git a/tensorflow/lite/g3doc/models/images/text.png b/tensorflow/lite/g3doc/models/images/text.png new file mode 100644 index 0000000000000000000000000000000000000000..227594f07e3d38fd4110249eb2c4c6541fb89baa Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/text.png differ diff --git a/tensorflow/lite/g3doc/models/images/tflite_models.png b/tensorflow/lite/g3doc/models/images/tflite_models.png new file mode 100644 index 0000000000000000000000000000000000000000..f60cd26a3177f95e40875ed92aa4a30c59a7623f Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/tflite_models.png differ diff --git a/tensorflow/lite/g3doc/models/images/video.png b/tensorflow/lite/g3doc/models/images/video.png new file mode 100644 index 0000000000000000000000000000000000000000..88b3b7d3c76840625abec821220413a03d384a45 Binary files /dev/null and b/tensorflow/lite/g3doc/models/images/video.png differ diff --git a/tensorflow/lite/g3doc/models/object_detection/images/android_apple_banana.png b/tensorflow/lite/g3doc/models/object_detection/images/android_apple_banana.png new file mode 100644 index 0000000000000000000000000000000000000000..f7a9fe5af89f89772248e8ea8d89904719d310dd Binary files /dev/null and b/tensorflow/lite/g3doc/models/object_detection/images/android_apple_banana.png differ diff --git a/tensorflow/lite/g3doc/models/object_detection/images/false_positive.png b/tensorflow/lite/g3doc/models/object_detection/images/false_positive.png new file mode 100644 index 
0000000000000000000000000000000000000000..39d2103a3a8749233ebfab9fff3fea533f0f00fd Binary files /dev/null and b/tensorflow/lite/g3doc/models/object_detection/images/false_positive.png differ diff --git a/tensorflow/lite/g3doc/models/object_detection/overview.md b/tensorflow/lite/g3doc/models/object_detection/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..ffa6381ef3df28d7b035d585a2813496ea2ea7e0 --- /dev/null +++ b/tensorflow/lite/g3doc/models/object_detection/overview.md @@ -0,0 +1,270 @@ +# Object detection + + + +Detect multiple objects within an image, with bounding boxes. Recognize 80 +different classes of objects. + +## Get started + +If you are new to TensorFlow Lite and are working with Android or iOS, we +recommend exploring the following example applications that can help you get +started. + +Android +example +iOS +example + +If you are using a platform other than Android or iOS, or you are already +familiar with the TensorFlow Lite APIs, you can +download our starter object detection model and the accompanying labels. + +Download +starter model and labels + +For more information about the starter model, see +Starter model. + +## What is object detection? + +Given an image or a video stream, an object detection model can identify which +of a known set of objects might be present and provide information about their +positions within the image. + +For example, this screenshot of our example +application shows how two objects have been recognized and their positions +annotated: + +Screenshot of Android example + +An object detection model is trained to detect the presence and location of +multiple classes of objects. For example, a model might be trained with images +that contain various pieces of fruit, along with a _label_ that specifies the +class of fruit they represent (e.g. an apple, a banana, or a strawberry), and +data specifying where each object appears in the image. 
+ +When we subsequently provide an image to the model, it will output a list of the +objects it detects, the location of a bounding box that contains each object, +and a score that indicates the confidence that detection was correct. + +### Model output + +Imagine a model has been trained to detect apples, bananas, and strawberries. +When we pass it an image, it will output a set number of detection results - in +this example, 5. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ClassScoreLocation
Apple0.92[18, 21, 57, 63]
Banana0.88[100, 30, 180, 150]
Strawberry0.87[7, 82, 89, 163]
Banana0.23[42, 66, 57, 83]
Apple0.11[6, 42, 31, 58]
+ +### Confidence score + +To interpret these results, we can look at the score and the location for each +detected object. The score is a number between 0 and 1 that indicates confidence +that the object was genuinely detected. The closer the number is to 1, the more +confident the model is. + +Depending on your application, you can decide a cut-off threshold below which +you will discard detection results. For our example, we might decide a sensible +cut-off is a score of 0.5 (meaning a 50% probability that the detection is +valid). In that case, we would ignore the last two objects in the array, because +those confidence scores are below 0.5: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ClassScoreLocation
Apple0.92[18, 21, 57, 63]
Banana0.88[100, 30, 180, 150]
Strawberry0.87[7, 82, 89, 163]
Banana0.23[42, 66, 57, 83]
Apple0.11[6, 42, 31, 58]
+ +The cut-off you use should be based on whether you are more comfortable with +false positives (objects that are wrongly identified, or areas of the image that +are erroneously identified as objects when they are not), or false negatives +(genuine objects that are missed because their confidence was low). + +For example, in the following image, a pear (which is not an object that the +model was trained to detect) was misidentified as a "person". This is an example +of a false positive that could be ignored by selecting an appropriate cut-off. +In this case, a cut-off of 0.6 (or 60%) would comfortably exclude the false +positive. + +Screenshot of Android example showing a false positive + +### Location + +For each detected object, the model will return an array of four numbers +representing a bounding rectangle that surrounds its position. For the starter +model we provide, the numbers are ordered as follows: + + + + + + + + + + + + +
[top,left,bottom,right]
+ +The top value represents the distance of the rectangle’s top edge from the top +of the image, in pixels. The left value represents the left edge’s distance from +the left of the input image. The other values represent the bottom and right +edges in a similar manner. + +Note: Object detection models accept input images of a specific size. This is likely to be different from the size of the raw image captured by your device’s camera, and you will have to write code to crop and scale your raw image to fit the model’s input size (there are examples of this in our example applications).

The pixel values output by the model refer to the position in the cropped and scaled image, so you must scale them to fit the raw image in order to interpret them correctly. + +## Starter model + +We recommend starting with this pre-trained quantized COCO SSD MobileNet v1 +model. + +Download +starter model and labels + +### Uses and limitations + +The object detection model we provide can identify and locate up to 10 objects +in an image. It is trained to recognize 80 classes of object. For a full list of +classes, see the labels file in the +model +zip. + +If you want to train a model to recognize new classes, see +Customize model. + +For the following use cases, you should use a different type of model: + +
    +
  • Predicting which single label the image most likely represents (see image classification)
  • +
  • Predicting the composition of an image, for example subject versus background (see segmentation)
  • +
+ +### Input + +The model takes an image as input. The expected image is 300x300 pixels, with +three channels (red, green, and blue) per pixel. This should be fed to the model +as a flattened buffer of 270,000 byte values (300x300x3). Since the model is +quantized, each +value should be a single byte representing a value between 0 and 255. + +### Output + +The model outputs four arrays, mapped to the indices 0-3. Arrays 0, 1, and 2 +describe 10 detected objects, with one element in each array corresponding to +each object. There will always be 10 objects detected. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
IndexNameDescription
0LocationsMultidimensional array of [10][4] floating point values between 0 and 1, the inner arrays representing bounding boxes in the form [top, left, bottom, right]
1ClassesArray of 10 integers (output as floating point values) each indicating the index of a class label from the labels file
2ScoresArray of 10 floating point values between 0 and 1 representing probability that a class was detected
3Number of detectionsArray of length 1 containing a floating point value expressing the total number of detection results
+ +## Customize model + +The pre-trained models we provide are trained to detect 80 classes of object. +For a full list of classes, see the labels file in the +model +zip. + +You can use a technique known as transfer learning to re-train a model to +recognize classes not in the original set. For example, you could re-train the +model to detect multiple types of vegetable, despite there only being one +vegetable in the original training data. To do this, you will need a set of +training images for each of the new labels you wish to train. + +Learn how to perform transfer learning in +Training +and serving a real-time mobile object detector in 30 minutes. diff --git a/tensorflow/lite/g3doc/models/pose_estimation/overview.md b/tensorflow/lite/g3doc/models/pose_estimation/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..981a2553f701ac0302e880dd5445defa14bece37 --- /dev/null +++ b/tensorflow/lite/g3doc/models/pose_estimation/overview.md @@ -0,0 +1,153 @@ +# Pose estimation + + + +## Get started + +_PoseNet_ is a vision model that can be used to estimate the pose of a person in +an image or video by estimating where key body joints are. + +Download +starter model + +Android and iOS end-to-end tutorials are coming soon. In the meantime, if you +want to experiment with this in a web browser, check out the +TensorFlow.js +GitHub repository. + +## How it works + +Pose estimation refers to computer vision techniques that detect human figures +in images and videos, so that one could determine, for example, where someone’s +elbow shows up in an image. + +To be clear, this technology is not recognizing who is in an image. The +algorithm is simply estimating where key body joints are. + +The key points detected are indexed by "Part ID", with a confidence score +between 0.0 and 1.0, 1.0 being the highest. 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
IdPart
0nose
1leftEye
2rightEye
3leftEar
4rightEar
5leftShoulder
6rightShoulder
7leftElbow
8rightElbow
9leftWrist
10rightWrist
11leftHip
12rightHip
13leftKnee
14rightKnee
15leftAnkle
16rightAnkle
+ +## Example output + +Animation showing pose estimation + +## How it performs + +Performance varies based on your device and output stride (heatmaps and offset +vectors). The PoseNet model is image size invariant, which means it can predict +pose positions in the same scale as the original image regardless of whether the +image is downscaled. This means PoseNet can be configured to have a higher +accuracy at the expense of performance. + +The output stride determines how much we’re scaling down the output relative to +the input image size. It affects the size of the layers and the model outputs. +The higher the output stride, the smaller the resolution of layers in the +network and the outputs, and correspondingly their accuracy. In this +implementation, the output stride can have values of 8, 16, or 32. In other +words, an output stride of 32 will result in the fastest performance but lowest +accuracy, while 8 will result in the highest accuracy but slowest performance. +We recommend starting with 16. + +The following image shows how the output stride determines how much we’re +scaling down the output relative to the input image size. A higher output stride +is faster but results in lower accuracy. 
+ +Output stride and heatmap resolution + +## Read more about pose estimation + + + +### Use cases + + diff --git a/tensorflow/lite/g3doc/models/segmentation/images/segmentation.gif b/tensorflow/lite/g3doc/models/segmentation/images/segmentation.gif new file mode 100644 index 0000000000000000000000000000000000000000..e664adf700b396b9cd06d48378b782efbbca4282 Binary files /dev/null and b/tensorflow/lite/g3doc/models/segmentation/images/segmentation.gif differ diff --git a/tensorflow/lite/g3doc/models/segmentation/overview.md b/tensorflow/lite/g3doc/models/segmentation/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..0bd268ada1fcefbdad39c9951c0d471e32b16b03 --- /dev/null +++ b/tensorflow/lite/g3doc/models/segmentation/overview.md @@ -0,0 +1,43 @@ +# Segmentation + + + +## Get started + +_DeepLab_ is a state-of-the-art deep learning model for semantic image segmentation, +where the goal is to assign semantic labels (e.g. person, dog, cat) to every +pixel in the input image. + +Download +starter model + +## How it works + +Semantic image segmentation predicts whether each pixel of an image is +associated with a certain class. This is in contrast to +object detection, which detects +objects in rectangular regions, and +image classification, which +classifies the overall image. + +The current implementation includes the following features: +
    +
  1. DeepLabv1: We use atrous convolution to explicitly control the resolution at which feature responses are computed within Deep Convolutional Neural Networks.
  2. +
  3. DeepLabv2: We use atrous spatial pyramid pooling (ASPP) to robustly segment objects at multiple scales with filters at multiple sampling rates and effective fields-of-view.
  4. +
  5. DeepLabv3: We augment the ASPP module with image-level features [5, 6] to capture longer-range information. We also include batch normalization [7] parameters to facilitate the training. In particular, we apply atrous convolution to extract output features at different output strides during training and evaluation, which efficiently enables training BN at output stride = 16 and attains a high performance at output stride = 8 during evaluation.
  6. +
  7. DeepLabv3+: We extend DeepLabv3 to include a simple yet effective decoder module to refine the segmentation results, especially along object boundaries. Furthermore, in this encoder-decoder structure one can arbitrarily control the resolution of extracted encoder features by atrous convolution to trade off precision and runtime.
  8. +
+ +## Example output + +The model will create a mask over the target objects with high accuracy. + +Animation showing image segmentation + +## Read more about segmentation + + diff --git a/tensorflow/lite/g3doc/models/smart_reply/images/smart_reply.gif b/tensorflow/lite/g3doc/models/smart_reply/images/smart_reply.gif new file mode 100644 index 0000000000000000000000000000000000000000..4a61691fd8714102409d290e7f6d6e361d9cbf13 Binary files /dev/null and b/tensorflow/lite/g3doc/models/smart_reply/images/smart_reply.gif differ diff --git a/tensorflow/lite/g3doc/models/smart_reply/overview.md b/tensorflow/lite/g3doc/models/smart_reply/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..20c359ec9ff9c79d48df8f2af381d98e27a5cc84 --- /dev/null +++ b/tensorflow/lite/g3doc/models/smart_reply/overview.md @@ -0,0 +1,52 @@ +# Smart reply + + + +## Get started + +Our smart reply model generates reply suggestions based on chat messages. The +suggestions are intended to be contextually relevant, one-touch responses that +help the user to easily reply to an incoming message. + +Download +starter model and labels + +### Sample application + +We have provided a pre-built APK that demonstrates the smart reply model on +Android. + +Go to the +GitHub +page for instructions and list of supported ops and functionalities. + +## How it works + +The model generates reply suggestions to conversational chat messages. + +The on-device model comes with several benefits. It is: +
    +
  • Fast: The model resides on the device and does not require internet connectivity. Thus, inference is very fast and has an average latency of only a few milliseconds.
  • +
  • Resource efficient: The model has a small memory footprint on the device.
  • +
  • Privacy-friendly: User data never leaves the device.
  • +
+ +## Example output + +Animation showing smart reply + +## Read more about this + + + +## Users + + diff --git a/tensorflow/lite/g3doc/performance/benchmarks.md b/tensorflow/lite/g3doc/performance/benchmarks.md index 5a1e5586beecad4876c9d0390a0fa31e78705195..a51fdb4080718d7d9dac9a01ee4935a1631933d6 100644 --- a/tensorflow/lite/g3doc/performance/benchmarks.md +++ b/tensorflow/lite/g3doc/performance/benchmarks.md @@ -1,5 +1,4 @@ - -# Performance +# Performance benchmarks This document lists TensorFlow Lite performance benchmarks when running well known models on some Android and iOS devices. @@ -42,7 +41,7 @@ Pixel xl | 0c | Model Name Device - Mean inference time (std dev) + Mean inference time @@ -50,66 +49,66 @@ Pixel xl | 0c | Mobilenet_1.0_224(float) Pixel 2 - 166.5 ms (2.6 ms) + 123.3 ms - Pixel xl - 122.9 ms (1.8 ms) + Pixel XL + 113.3 ms Mobilenet_1.0_224 (quant) Pixel 2 - 69.5 ms (0.9 ms) + 65.4 ms - Pixel xl - 78.9 ms (2.2 ms) + Pixel XL + 74.6 ms NASNet mobile Pixel 2 - 273.8 ms (3.5 ms) + 273.8 ms - Pixel xl - 210.8 ms (4.2 ms) + Pixel XL + 210.8 ms SqueezeNet Pixel 2 - 234.0 ms (2.1 ms) + 234.0 ms - Pixel xl - 158.0 ms (2.1 ms) + Pixel XL + 158.0 ms Inception_ResNet_V2 Pixel 2 - 2846.0 ms (15.0 ms) + 2846.0 ms - Pixel xl - 1973.0 ms (15.0 ms) + Pixel XL + 1973.0 ms Inception_V4 Pixel 2 - 3180.0 ms (11.7 ms) + 3180.0 ms - Pixel xl - 2262.0 ms (21.0 ms) + Pixel XL + 2262.0 ms @@ -126,7 +125,7 @@ modified to set `num_threads` to 1. Model Name Device - Mean inference time (std dev) + Mean inference time @@ -134,41 +133,41 @@ modified to set `num_threads` to 1. 
Mobilenet_1.0_224(float) iPhone 8 - 32.2 ms (0.8 ms) + 32.2 ms Mobilenet_1.0_224 (quant) iPhone 8 - 24.4 ms (0.8 ms) + 24.4 ms NASNet mobile iPhone 8 - 60.3 ms (0.6 ms) + 60.3 ms SqueezeNet iPhone 8 - 44.3 (0.7 ms) + 44.3 Inception_ResNet_V2 iPhone 8 - 562.4 ms (18.2 ms) + 562.4 ms Inception_V4 iPhone 8 - 661.0 ms (29.2 ms) + 661.0 ms diff --git a/tensorflow/lite/g3doc/performance/best_practices.md b/tensorflow/lite/g3doc/performance/best_practices.md index 5f41a7027538f571601c85a0a367208200155dd6..45aa17f07a9dad1cd4eb140f45abbf21e3d30fe7 100644 --- a/tensorflow/lite/g3doc/performance/best_practices.md +++ b/tensorflow/lite/g3doc/performance/best_practices.md @@ -1,19 +1,27 @@ # Performance best practices -Mobile and embedded devices have limited computational resources and it is +Mobile and embedded devices have limited computational resources, so it is important to keep your application resource efficient. We have compiled a list of best practices and strategies that you can use to optimize your model and application when using TensorFlow Lite. ## Choose the best model for the task -Depending on the task you will need to make a tradeoff between model complexity and size. If your task requires high accuracy then you may need a large and complex model. Some tasks may work with a less precise model, for these tasks it is better to use a smaller but less precise model. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. For example, graphs below show accuracy and latency tradeoff for some common image classification models. -![accuracy vs model size](../images/performance/model_size_vs_accuracy.png "Accuracy vs Model size") +Depending on the task, you will need to make a tradeoff between model complexity +and size. If your task requires high accuracy, then you may need a large and +complex model. 
For tasks that require less precision, it is better to use a +smaller model because they not only use less disk space and memory, but they are +also generally faster and more energy efficient. For example, graphs below show +accuracy and latency tradeoffs for some common image classification models. +![Graph of model size vs accuracy](../images/performance/model_size_vs_accuracy.png "Model Size vs Accuracy") -![latency vs model size](../images/performance/model_size_vs_latency.png "Latency vs Model size") +![Graph of model size vs latency](../images/performance/model_size_vs_latency.png "Model Size vs Latency") -One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. TensorFlow Lite [models page](../models.md) lists several other models that have been optimized specifically for mobile and embedded devices. +One example of models optimized for mobile devices are +[MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile +vision applications. [Hosted models](../models/hosted.md) lists several other +models that have been optimized specifically for mobile and embedded devices. You can retrain the listed models on your own dataset by using transfer learning. Check out our transfer learning tutorial for [image classification](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and @@ -24,33 +32,58 @@ You can retrain the listed models on your own dataset by using transfer learning Once you have selected a candidate model that is right for your task, it is a good practice to profile and benchmark your model. TensorFlow Lite [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time. 
## Profile and optimize operators in the graph -If a particular operator appears frequently in the model and based on profiling you find the operator consuming the most amount of time, you can look into optimizing the operator. - This scenario should be rare as TensorFlow Lite has optimized versions for most ops. However you may be able to write a faster version of a custom op, if you know the constraints in which the operator is executed. Check out our [custom operator documentation](../custom_operators.md). -## Quantize your model -If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](model_optimization.md) for details about optimizing your model. +If a particular operator appears frequently in the model and, based on +profiling, you find that the operator consumes the most amount of time, you can +look into optimizing that operator. This scenario should be rare as TensorFlow +Lite has optimized versions for most operators. However, you may be able to +write a faster version of a custom op if you know the constraints in which the +operator is executed. Check out our +[custom operator documentation](../custom_operators.md). + +## Optimize your model + +Model compression aims to create smaller models that are generally faster and +more energy efficient, so that they can be deployed on mobile devices. + +### Quantization + +If your model uses floating-point weights or activations, then it may be +possible to reduce the size of model up to ~4x by using quantization, which +effectively turns the float weights to 8-bit. There are two flavors of +quantization: [post-training quantization](post_training_quantization.md) and +[quantized training](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/README.md){:.external}. 
+The former does not require model re-training, but, in rare cases, may have +accuracy loss. When accuracy loss is beyond acceptable thresholds, quantized +training should be used instead. + +We strongly recommend running benchmarks to make sure that the accuracy is not +impacted during model compression. Check out our +[model optimization toolkit](model_optimization.md) for details. ## Tweak the number of threads TensorFlow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing -the number of threads will however make your model use more resources and power. -For some applications latency may be more important than energy efficiency. You -can increase the number of threads by setting the number of -[interpreter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/interpreter.h#L333) -threads. Multi-threaded execution however comes at the cost of increased -performance variability depending on what else is been executed concurrently. -This is particularly the case for mobile apps. For example, isolated tests may -show 2x speed up vs single-threaded but if another app is executing at the same -time may result in worst performance than single-threaded. +the number of threads will, however, make your model use more resources and +power. + +For some applications, latency may be more important than energy efficiency. You +can increase the number of threads by setting the number of interpreter +[threads](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/interpreter.h#L346). +Multi-threaded execution, however, comes at the cost of increased performance +variability depending on what else is executed concurrently. This is +particularly the case for mobile apps. For example, isolated tests may show 2x +speed-up vs single-threaded, but, if another app is executing at the same time, +it may result in worse performance than single-threaded. 
## Eliminate redundant copies -If your application is not careful, there can be redundant copies when feeding -the input to the model and reading output from the model. Make sure to eliminate -redundant copies. If you are using higher level APIs like Java API, make sure to -carefully check the documentation for performance caveats. For example, the Java -API is a lot faster if ByteBuffers are used as +If your application is not carefully designed, there can be redundant copies +when feeding the input to and reading the output from the model. Make sure to +eliminate redundant copies. If you are using higher level APIs, like Java, make +sure to carefully check the documentation for performance caveats. For example, +the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L175). ## Profile your application with platform specific tools @@ -58,23 +91,23 @@ Platform specific tools like [Android profiler](https://developer.android.com/st ## Evaluate whether your model benefits from using hardware accelerators available on the device -TensorFlow Lite has added been new ways to accelerate models with faster -hardware like GPUs, DSPs, and neural accelerators. Typically, these accelerators -are exposed through *delegate* submodules that take over parts of the +TensorFlow Lite has added new ways to accelerate models with faster hardware +like GPUs, DSPs, and neural accelerators. Typically, these accelerators are +exposed through [delegate](delegates.md) submodules that take over parts of the interpreter execution. TensorFlow Lite can use delegates by: * Using Android's [Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/). You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. 
To enable the Neural Networks API, call - [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/interpreter.h#L330) + [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/interpreter.h#L343) on the interpreter instance. -* A binary-only GPU delegate has been released for Android and iOS—using +* A binary-only GPU delegate has been released for Android and iOS, using OpenGL and Metal, respectively. To try them out, see the [GPU delegate tutorial](gpu.md) and [documentation](gpu_advanced.md). * It is possible to create your own delegate if you have access to - non-standard hardware. View the NN API delegate in the source code as an - example. + non-standard hardware. See [TensorFlow Lite delegates](delegates.md) for + more information. Be aware that some accelerators work better for different types of models. It is important to benchmark each delegate to see if it is a good choice for your diff --git a/tensorflow/lite/g3doc/performance/delegates.md b/tensorflow/lite/g3doc/performance/delegates.md new file mode 100644 index 0000000000000000000000000000000000000000..d5abfb034386435a4f2ab9c4f4ba816417d7d7b8 --- /dev/null +++ b/tensorflow/lite/g3doc/performance/delegates.md @@ -0,0 +1,206 @@ +## TensorFlow Lite delegates + +_Note: Delegate API is still experimental and is subject to change._ + + +## What is a TensorFlow Lite delegate? + +A TensorFlow Lite delegate is a way to delegate part or all of graph execution to another executor. + + +## Why should I use delegates? + +Running inference on compute-heavy machine learning models on mobile devices is resource demanding due to the devices' limited processing and power. + +Instead of relying on the CPU, some devices have hardware accelerators, such as GPU or DSP, that allows for better performance and higher energy efficiency. 
+ + +## Using the experimental GPU delegate + +TensorFlow Lite provides an experimental GPU delegate that can be used to accelerate models on devices that have a GPU available. + +For an overview of the experimental GPU delegate, see [TensorFlow Lite on GPU](https://www.tensorflow.org/lite/performance/gpu_advanced). For step-by-step tutorials on using the GPU delegate with Android and iOS, see [TensorFlow Lite GPU Delegate Tutorial](https://www.tensorflow.org/lite/performance/gpu). + + +## How do delegates work? + +Let's say we have a simple model graph such as the following: + +![Original graph](../images/performance/tflite_delegate_graph_1.png "Original Graph") + +If a delegate is provided for specific operations, then TensorFlow Lite splits the graph into multiple subgraphs, where each subgraph is handled by a delegate. + +Let's assume that there is a delegate "MyDelegate," which has a faster implementation for Conv2D and Mean operations. The resulting main graph will be updated to look like the following. + +![Graph with delegate](../images/performance/tflite_delegate_graph_2.png "Graph with delegate") + +Each subgraph that is handled by a delegate will be replaced with a node that evaluates the subgraph when it is invoked. + +Depending on the model, the final graph can end up with a single node, which means that the entire graph was delegated, or with multiple nodes that each handle a delegated subgraph. In general, you don't want to have multiple subgraphs handled by the delegate, since each time you switch from delegate to the main graph, there is an overhead for passing the results from the subgraph to the main graph, because it is not always safe to share memory. + + +## How to add a delegate + +_Note that the API used below is experimental and is subject to change._ + +Based on the previous section, to add a delegate, we need to do the following: + + + +1. Define a kernel node that is responsible for evaluating the delegate subgraph +1. 
Create an instance of [TfLiteDelegate](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/c_api_internal.h#L545), which is responsible for registering the kernel node and claiming the nodes that the delegate can execute + +To see it in code, let's define a delegate and call it "MyDelegate," which can execute Conv2D and Mean operations faster. + +``` +// This is where the execution of the operations or whole graph happens. +// The class below has an empty implementation just as a guideline +// on the structure. +class MyDelegate { + public: + // Returns true if my delegate can handle this type of op. + static bool SupportedOp(const TfLiteRegistration* registration) { + switch (registration->builtin_code) { + case kTfLiteBuiltinConv2d: + case kTfLiteBuiltinMean: + return true; + default: + return false; + } + } + + // Any initialization code needed + bool Init() {} + // Any preparation work needed (e.g. allocate buffers) + bool Prepare(TfLiteContext* context, TfLiteNode* node) {} + // Actual running of the delegate subgraph. + bool Invoke(TfLiteContext* context, TfLiteNode* node) {} + // ... Add any other methods needed. +}; + +// Create the TfLiteRegistration for the Kernel node which will replace +// the subrgaph in the main TfLite graph. +TfLiteRegistration GetMyDelegateNodeRegistration() { + // This is the registration for the Delegate Node that gets added to + // the TFLite graph instead of the subGraph it replaces. + // It is treated as a an OP node. But in our case + // Init will initialize the delegate + // Invoke will run the delegate graph. + // Prepare for preparing the delegate. + // Free for any cleaning needed by the delegate. 
+ TfLiteRegistration kernel_registration; + kernel_registration.builtin_code = kTfLiteBuiltinDelegate; + kernel_registration.custom_name = "MyDelegate"; + kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void { + delete reinterpret_cast(buffer); + }; + kernel_registration.init = [](TfLiteContext* context, const char* buffer, + size_t) -> void* { + // In the node init phase, initialize MyDelegate instance + const TfLiteDelegateParams* delegate_params = + reinterpret_cast(buffer); + MyDelegate* my_delegate = new MyDelegate; + if (!my_delegate->Init(context, params)) { + return nullptr; + } + return my_delegate; + }; + kernel_registration.invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + MyDelegate* kernel = reinterpret_cast(node->user_data); + return kernel->Invoke(context, node); + }; + kernel_registration.prepare = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + MyDelegate* kernel = reinterpret_cast(node->user_data); + return kernel->Prepare(context, node); + }; + + return kernel_registration; +} + +// TfLiteDelegate methods + +TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { + // Claim all nodes that can be evaluated by the delegate and ask the + // framework to update the graph with delegate kernel instead. + // Reserve 1 element, since we need first element to be size. + std::vector supported_nodes(1); + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + TfLiteNode* node; + TfLiteRegistration* registration; + for (int node_index : TfLiteIntArrayView(plan)) { + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + if (MyDelegate::SupportedOp(registration)) { + supported_nodes.push_back(node_index); + } + } + // Set first element to the number of nodes to replace. 
+ supported_nodes[0] = supported_nodes.size() - 1; + TfLiteRegistration my_delegate_kernel_registration = + GetMyDelegateNodeRegistration(); + + // This call split the graphs into subgraphs, for subgraphs that can be + // handled by the delegate, it will replace it with a + // 'my_delegate_kernel_registration' + return context->ReplaceNodeSubsetsWithDelegateKernels( + context, my_delegate_kernel_registration, + reinterpret_cast(supported_nodes.data()), delegate); +} + +void FreeBufferHandle(TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle* handle) { + // Do any cleanups. +} + +TfLiteStatus CopyToBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* tensor) { + // Copies data from tensor to delegate buffer if needed. + return kTfLiteOk; +} + +TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* tensor) { + // Copies the data from delegate buffer into the tensor raw memory. + return kTfLiteOk; +} + +// Caller takes ownership of the returned pointer. +TfLiteDelegate* CreateMyDelegate() { + TfLiteDelegate* delegate = new TfLiteDelegate; + + delegate->data_ = nullptr; + delegate->flags = kTfLiteDelegateFlagsNone; + delegate->Prepare = &DelegatePrepare; + // This cannot be null. + delegate->CopyFromBufferHandle = &CopyFromBufferHandle; + // This can be null. + delegate->CopyToBufferHandle = &CopyToBufferHandle; + // This can be null. + delegate->FreeBufferHandle = &FreeBufferHandle; + + return delegate; +} + + +// To add the delegate you need to call + +auto* my_delegate = CreateMyDelegate(); +if (interpreter->ModifyGraphWithDelegate(my_delegate) != + kTfLiteOk) { + // Handle error +} else { + interpreter->Invoke(); +} +... 
+// Don't forget to delete your delegate +delete my_delegate; + + +``` diff --git a/tensorflow/lite/g3doc/performance/gpu.md b/tensorflow/lite/g3doc/performance/gpu.md index c7389226123746180c8c5e6020431ffe579112a7..3b2cca9cf3211abfd4deb6bc59fb57f0225677d3 100644 --- a/tensorflow/lite/g3doc/performance/gpu.md +++ b/tensorflow/lite/g3doc/performance/gpu.md @@ -1,4 +1,4 @@ -# TensorFlow Lite GPU Delegate Tutorial +# TensorFlow Lite GPU delegate [TensorFlow Lite](https://www.tensorflow.org/lite) supports several hardware accelerators. This document describes how to preview the experimental GPU backend using the diff --git a/tensorflow/lite/g3doc/performance/model_optimization.md b/tensorflow/lite/g3doc/performance/model_optimization.md index 2eb432c008168794c3722fcd4f9ab6df0771e48f..dcfc2bed3a0e55ee0d7a1af643eb500faf3960ba 100644 --- a/tensorflow/lite/g3doc/performance/model_optimization.md +++ b/tensorflow/lite/g3doc/performance/model_optimization.md @@ -1,12 +1,14 @@ # Model optimization -Inference efficiency is a critical issue when deploying machine learning -models to mobile devices. Where the computational demand for *training* -grows with the number of models trained on different architectures, the -computational demand for *inference* grows in proportion to the number of -users. The *Tensorflow Model Optimization Toolkit* minimizes the complexity -of inference—the model size, the latency and power consumption. +The *Tensorflow Model Optimization Toolkit* minimizes the complexity +of optimizing inference. Inference efficiency +is a critical issue when deploying machine learning +models to mobile devices because of the model size, latency, and power consumption. +Computational demand for *training* +grows with the number of models trained on different architectures, whereas the +computational demand for *inference* grows in proportion to the number of +users. 
## Use cases @@ -23,9 +25,11 @@ Model optimization is useful for: Model optimization uses multiple techniques: -* Reduced parameter count, for example, pruning and structured pruning. -* Reduced representational precision, for example, quantization. -* Update the original model topology to a more efficient one, with reduced parameters or faster execution, for example, tensor decomposition methods and distillation. +* Reduce parameter count with pruning and structured pruning. +* Reduce representational precision with quantization. +* Update the original model topology to a more efficient one with reduced parameters or faster execution. For example, tensor decomposition methods and distillation. + +We support quantization, and are working to add support for other techniques. ## Model quantization @@ -34,19 +38,17 @@ representations of weights and, optionally, activations for both storage and computation. Quantization provides several benefits: * Support on existing CPU platforms. -* Quantizing activations reduces memory access costs for reading and storing intermediate activations. +* Quantization of activations reduces memory access costs for reading and storing intermediate activations. * Many CPU and hardware accelerator implementations provide SIMD instruction capabilities, which are especially beneficial for quantization. TensorFlow Lite provides several levels of support for quantization. -[Post-training quantization](post_training_quantization.md) quantizes weights and activations post training and is very easy to use. -[Quantization-aware training](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/README.md){:.external} allows for training networks that can be quantized with minimal accuracy drop and is only available -for a subset of convolutional neural network architectures. - +* [Post-training quantization](post_training_quantization.md) quantizes weights and activations post training and is very easy to use. 
+* [Quantization-aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize){:.external} allows for training networks that can be quantized with minimal accuracy drop and is only available for a subset of convolutional neural network architectures. ### Latency and accuracy results -Below are the results of the latency and accuracy of post-training quantization and +Below are the latency and accuracy results for post-training quantization and quantization-aware training on a few models. All latency numbers are measured on Pixel 2 devices using a single big core. As the toolkit improves, so will the numbers here: @@ -62,13 +64,12 @@ Pixel 2 devices using a single big core. As the toolkit improves, so will t Latency (Quantization Aware Training) (ms) Size (Original) (MB) Size (Optimized) (MB) - - Mobilenet-v1-1-2240.7090.6570.70 - 18014580.216.94.3 + Mobilenet-v1-1-2240.7090.6570.70 + 1241126416.94.3 Mobilenet-v2-1-2240.7190.6370.709 - 11712180.3143.6 + 899854143.6 Inception_v30.780.7720.775 - 1585118763795.723.9 + 113084554395.723.9 Resnet_v2_1010.7700.768N/A 39732868N/A178.344.9 @@ -79,10 +80,10 @@ Pixel 2 devices using a single big core. As the toolkit improves, so will t ## Choice of quantization tool -As a starting point, check if the models in the TensorFlow Lite model repository can work for -your application. If not, we recommend that users start with the post-training quantization tool +As a starting point, check if the models in [hosted models](../guide/hosted_models.md) can work for +your application. If not, we recommend that users start with the [post-training quantization tool](post_training_quantization.md) since this is broadly applicable and does not require training data. For cases where the accuracy -and latency targets are not met, or hardware accelerator support is important, quantization-aware -training is the better option. 
+and latency targets are not met, or hardware accelerator support is important, [quantization-aware +training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize){:.external} is the better option. Note: Quantization-aware training supports a subset of convolutional neural network architectures. diff --git a/tensorflow/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md index cf4d70b2deb3370d0acdde1fcaa8d7fce0cf3bf2..5fdf77b40112d2ba67db0b6f319f07279291da8f 100644 --- a/tensorflow/lite/g3doc/performance/post_training_quantization.md +++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md @@ -1,26 +1,24 @@ # Post-training quantization -Post-training quantization is a general technique to reduce the model size while also +Post-training quantization is a general technique to reduce model size while also providing up to 3x lower latency with little degradation in model accuracy. Post-training -quantization quantizes weights to 8-bits of precision from floating-point. This technique -is enabled as an option in [TensorFlow Lite model converter](../convert): +quantization quantizes weights from floating point to 8-bits of precision. This technique +is enabled as an option in the [TensorFlow Lite converter](../convert/): ``` import tensorflow as tf -converter = tf.lite.TocoConverter.from_saved_model(saved_model_dir) -converter.post_training_quantize = True -tflite_quantized_model = converter.convert() -open("quantized_model.tflite", "wb").write(tflite_quantized_model) - +converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) +converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] +tflite_quant_model = converter.convert() ``` -At inference, weights are converted from 8-bits of precision to floating-point and -computed using floating point kernels. This conversion is done once and cached to reduce latency. 
+At inference, weights are converted from 8-bits of precision to floating point and +computed using floating-point kernels. This conversion is done once and cached to reduce latency. To further improve latency, hybrid operators dynamically quantize activations to 8-bits and perform computations with 8-bit weights and activations. This optimization provides latencies close to fully fixed-point inference. However, the outputs are still stored using -floating-point, so the speedup with hybrid ops is less than a full fixed-point computation. +floating point, so that the speedup with hybrid ops is less than a full fixed-point computation. Hybrid ops are available for the most compute-intensive operators in a network: * [tf.contrib.layers.fully_connected](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/fully_connected) @@ -31,46 +29,55 @@ Hybrid ops are available for the most compute-intensive operators in a network: * [tf.nn.dynamic_rnn for LSTM and BasicRNN Cell types](https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn) -Since weights are quantized post-training, there could be an accuracy loss, particularly for +Since weights are quantized post training, there could be an accuracy loss, particularly for smaller networks. Pre-trained fully quantized models are provided for specific networks in -the [TensorFlow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models.md#image-classification-quantized-models){:.external}. It is important to check the accuracy of the quantized model to verify that any degradation +the [TensorFlow Lite model repository](../models/). It is important to check the accuracy of the quantized model to verify that any degradation in accuracy is within acceptable limits. There is a tool to evaluate [TensorFlow Lite model accuracy](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/accuracy/README.md){:.external}. 
-If the accuracy drop is too high, consider using [quantization aware training](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/README.md){:.external}. +If the accuracy drop is too high, consider using [quantization aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize){:.external}. ### Representation for quantized tensors TensorFlow approaches the conversion of floating-point arrays of numbers into 8-bit representations as a compression problem. Since the weights and activation tensors in trained neural network models tend to have values that are distributed -across comparatively small ranges (for example, -15 to +15 for weights or -500 to -1000 for image model activations). And since neural nets tend to be robust -handling noise, the error introduced by quantizing to a small set of values -maintains the precision of the overall results within an acceptable threshold. A -chosen representation must perform fast calculations, especially the large matrix -multiplications that comprise the bulk of the computations while running a model. +across comparatively small ranges (e.g. -15 to +15 for weights or -500 to +1000 for image model activations). + +Since neural networks tend to be robust at handling noise, the error introduced +by quantizing to a small set of values maintains the precision of the overall +results within an acceptable threshold. A chosen representation must perform +fast calculations, especially with large matrix multiplications that comprise +the bulk of the computations while running a model. This is represented with two floats that store the overall minimum and maximum values corresponding to the lowest and highest quantized value. Each entry in the quantized array represents a float value in that range, distributed linearly -between the minimum and maximum. 
For example, with a minimum of -10.0 and maximum -of 30.0f, and an 8-bit array, the quantized values represent the following: +between the minimum and maximum. + +With our post-training quantization tooling, we use symmetric quantization for +our weights, meaning we expand the represented range and force the min and max +to be the negative of each other. + +For example, with an overall minimum of -10.0 and a maximum +of 30.0f, we instead represent a minimum of -30.0 and maximum of 30.0f. In an +8-bit array, the quantized values would be represented as follows:
- - - + + + +
QuantizedFloat
0-10.0
12810.0
25530.0
-42-10.0
00
12730.0
-127-30.0 (this value does not ever show up)
- Table 2: Example quantized value range + Table 2: Quantized value range example
The advantages of this representation format are: * It efficiently represents an arbitrary magnitude of ranges. -* The values don't have to be symmetrical. -* The format represents both signed and unsigned values. * The linear spread makes multiplications straightforward. +* A symmetric range for weights enables downstream hardware optimizations. diff --git a/tensorflow/lite/g3doc/r2/convert/concrete_function.md b/tensorflow/lite/g3doc/r2/convert/concrete_function.md new file mode 100644 index 0000000000000000000000000000000000000000..c17981353d97e24fb215ba024d203efaaf909eb9 --- /dev/null +++ b/tensorflow/lite/g3doc/r2/convert/concrete_function.md @@ -0,0 +1,208 @@ +# Generating a concrete function + +In order to convert TensorFlow 2.0 models to TensorFlow Lite, the model needs to +be exported as a concrete function. This document outlines what a concrete +function is and how to generate one for an existing model. + +[TOC] + +## Background + +In TensorFlow 2.0, eager execution is on by default. TensorFlow's eager +execution is an imperative programming environment that evaluates operations +immediately, without building graphs. Operations return concrete values instead +of constructing a computational graph to run later. A detailed guide on eager +execution is available +[here](https://github.com/tensorflow/docs/blob/master/site/en/r2/guide/eager.ipynb). + +While running imperatively with eager execution makes development and debugging +more interactive, it doesn't allow for deploying on-device. The `tf.function` +API makes it possible to save models as graphs, which is required to run +TensorFlow Lite in 2.0. All operations wrapped in the `tf.function` decorator +can be exported as a graph which can then be converted to the TensorFlow Lite +FlatBuffer format. + +## Terminology + +The following terminology is used in this document: + +* **Signature** - The inputs and outputs for a set of operations. +* **Concrete function** - Graph with a single signature. 
+* **Polymorphic function** - Python callable that encapsulates several + concrete function graphs behind one API. + +## Methodology + +This section describes how to export a concrete function. + +### Annotate functions with `tf.function` + +Annotating a function with `tf.function` generates a *polymorphic function* +containing those operations. All operations that are not annotated with +`tf.function` will be evaluated with eager execution. The examples below show +how to use `tf.function`. + +```python +@tf.function +def pow(x): + return x ** 2 +``` + +```python +tf.function(lambda x : x ** 2) +``` + +### Create an object to save + +The `tf.function` can be optionally stored as part of a `tf.Module` object. +Variables should only be defined once within the `tf.Module`. The examples below +show two different approaches for creating such a `tf.Module` object. + +```python +class BasicModel(tf.Module): + + def __init__(self): + self.const = None + + @tf.function + def pow(self, x): + if self.const is None: + self.const = tf.Variable(2.) + return x ** self.const + +root = BasicModel() +``` + +```python +root = tf.Module() +root.const = tf.Variable(2.) +root.pow = tf.function(lambda x : x ** root.const) +``` + +### Exporting the concrete function + +The concrete function defines a graph that can be converted to a TensorFlow Lite +model or be exported to a SavedModel. In order to export a concrete function +from the polymorphic function, the signature needs to be defined. The signature +can be defined in the following ways: + +* Define `input_signature` parameter in `tf.function`. +* Pass in `tf.TensorSpec` into `get_concrete_function`: e.g. + `tf.TensorSpec(shape=[1], dtype=tf.float32)`. +* Pass in a sample input tensor into `get_concrete_function`: e.g. + `tf.constant(1., shape=[1])`. + +The following example shows how to define the `input_signature` parameter for +`tf.function`. 
+ +```python +class BasicModel(tf.Module): + + def __init__(self): + self.const = None + + @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)]) + def pow(self, x): + if self.const is None: + self.const = tf.Variable(2.) + return x ** self.const + +# Create the tf.Module object. +root = BasicModel() + +# Get the concrete function. +concrete_func = root.pow.get_concrete_function() +``` + +The example below passes in a sample input tensor into `get_concrete_function`. + +```python +# Create the tf.Module object. +root = tf.Module() +root.const = tf.Variable(2.) +root.pow = tf.function(lambda x : x ** root.const) + +# Get the concrete function. +input_data = tf.constant(1., shape=[1]) +concrete_func = root.pow.get_concrete_function(input_data) +``` + +## Example program + +```python +import tensorflow as tf + +# Initialize the tf.Module object. +root = tf.Module() + +# Instantiate the variable once. +root.var = None + +# Define a function so that the operations aren't computed in advance. +@tf.function +def exported_function(x): + # Each variable can only be defined once. The variable can be defined within + # the function but needs to contain a reference outside of the function. + if root.var is None: + root.var = tf.Variable(tf.random.uniform([2, 2])) + root.const = tf.constant([[37.0, -23.0], [1.0, 4.0]]) + root.mult = tf.matmul(root.const, root.var) + return root.mult * x + +# Save the function as part of the tf.Module object. +root.func = exported_function + +# Get the concrete function. +concrete_func = root.func.get_concrete_function( + tf.TensorSpec([1, 1], tf.float32)) +``` + +## Common Questions + +### How do I save a concrete function as a SavedModel? + +Users who want to save their TensorFlow model before converting it to TensorFlow +Lite should save it as a SavedModel. After getting the concrete function, call +`tf.saved_model.save` to save the model. The example above can be saved using +the following instruction. 
+ +```python +tf.saved_model.save(root, export_dir, concrete_func) +``` + +Reference the +[SavedModel guide](https://github.com/tensorflow/docs/blob/master/site/en/r2/guide/saved_model.ipynb) +for detailed instructions on using SavedModels. + +### How do I get a concrete function from the SavedModel? + +Each concrete function within a SavedModel can be identified by a signature key. +The default signature key is `tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. +The example below shows how to get the concrete function from a model. + +```python +model = tf.saved_model.load(export_dir) +concrete_func = model.signatures[ + tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] +``` + +### How do I get a concrete function for a `tf.Keras` model? + +There are two approaches that you can use: + +1. Save the model as a SavedModel. A concrete function will be generated during + the saving process, which can be accessed upon loading the model. +2. Annotate the model with `tf.function` as seen below. + +```python +model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])]) +model.compile(optimizer='sgd', loss='mean_squared_error') +model.fit(x=[-1, 0, 1, 2, 3, 4], y=[-3, -1, 1, 3, 5, 7], epochs=50) + +# Get the concrete function from the Keras model. +run_model = tf.function(lambda x : model(x)) + +# Save the concrete function. 
+concrete_func = run_model.get_concrete_function( + tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)) +``` diff --git a/tensorflow/lite/g3doc/r2/convert/index.md b/tensorflow/lite/g3doc/r2/convert/index.md new file mode 100644 index 0000000000000000000000000000000000000000..f1e763e027c8caf7c34c01e991ff4d93076aff32 --- /dev/null +++ b/tensorflow/lite/g3doc/r2/convert/index.md @@ -0,0 +1,24 @@ +# TensorFlow Lite converter + +The TensorFlow Lite converter takes a TensorFlow model represented as a +[concrete function](concrete_function.md), and generates a TensorFlow Lite +[`FlatBuffer`](https://google.github.io/flatbuffers/) file (`.tflite`). + +Note: This page contains documentation on the converter API for TensorFlow 2.0. +The API for TensorFlow 1.X is available +[here](https://www.tensorflow.org/lite/convert/). + +## Device deployment + +The TensorFlow Lite `FlatBuffer` file is then deployed to a client device (e.g. +mobile, embedded) and run locally using the TensorFlow Lite interpreter. This +conversion process is shown in the diagram below: + +![TFLite converter workflow](../images/convert/workflow.svg) + +## Converting models + +The TensorFlow Lite converter can be used from the [Python API](python_api.md). +Using the Python API makes it easier to convert models as part of a model +development pipeline and helps mitigate +[compatibility](../guide/ops_compatibility.md) issues early on. diff --git a/tensorflow/lite/g3doc/r2/convert/python_api.md b/tensorflow/lite/g3doc/r2/convert/python_api.md new file mode 100644 index 0000000000000000000000000000000000000000..882346fb8cf7fc0db7bbd44e1d8d6a2c95ac5f59 --- /dev/null +++ b/tensorflow/lite/g3doc/r2/convert/python_api.md @@ -0,0 +1,252 @@ +# Converter Python API guide + +This page provides examples on how to use the +[TensorFlow Lite converter](index.md) using the Python API in TensorFlow 2.0. 
+ +[TOC] + +## Python API + +The Python API for converting TensorFlow models to TensorFlow Lite in TensorFlow +2.0 is +[`tf.lite.TFLiteConverter.from_concrete_function()`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/lite/TFLiteConverter). +Documentation on concrete functions is available [here](concrete_function.md). + +This document contains [example usages](#examples) of the API, a detailed list +of [changes in the API between 1.X and 2.0](#differences), and +[instructions](#versioning) on running the different versions of TensorFlow. + +## Examples + +### Exporting a concrete function + +The following example shows how to convert a TensorFlow concrete function into a +TensorFlow Lite `FlatBuffer`. + +```python +import tensorflow as tf + +# Construct a basic model. +root = tf.train.Checkpoint() +root.v1 = tf.Variable(3.) +root.v2 = tf.Variable(2.) +root.f = tf.function(lambda x: root.v1 * root.v2 * x) + +# Create the concrete function. +input_data = tf.constant(1., shape=[1, 1]) +concrete_func = root.f.get_concrete_function(input_data) + +# Convert the model. +converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func) +tflite_model = converter.convert() +``` + +### Exporting a SavedModel + +The following example shows how to convert a SavedModel into a TensorFlow Lite +`FlatBuffer`. + +Note: Due to a known issue with preserving input shapes with SavedModels, +`set_shape` needs to be called for all input tensors. + +```python +import tensorflow as tf + +# Construct a basic model. +root = tf.train.Checkpoint() +root.v1 = tf.Variable(3.) +root.v2 = tf.Variable(2.) +root.f = tf.function(lambda x: root.v1 * root.v2 * x) + +# Save the model. +export_dir = "/tmp/test_saved_model" +input_data = tf.constant(1., shape=[1, 1]) +to_save = root.f.get_concrete_function(input_data) +tf.saved_model.save(root, export_dir, to_save) + +# Load model and get the concrete function. 
+model = tf.saved_model.load(export_dir) +concrete_func = model.signatures[ + tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + +# Set the shape manually. +concrete_func.inputs[0].set_shape(input_data.shape) + +# Convert the model. +converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func) +tflite_model = converter.convert() +``` + +### Exporting a Keras model + +The following example shows how to convert a `tf.keras` model into a TensorFlow +Lite `FlatBuffer`. + +```python +import tensorflow as tf + +# Create a simple Keras model. +x = [-1, 0, 1, 2, 3, 4] +y = [-3, -1, 1, 3, 5, 7] + +model = tf.keras.models.Sequential( + [tf.keras.layers.Dense(units=1, input_shape=[1])]) +model.compile(optimizer='sgd', loss='mean_squared_error') +model.fit(x, y, epochs=50) + +# Get the concrete function from the Keras model. +run_model = tf.function(lambda x : model(x)) +concrete_func = run_model.get_concrete_function( + tf.TensorSpec([None, 1], tf.float32)) + +# Convert the model. +converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func) +tflite_model = converter.convert() +``` + +### End-to-end MobileNet conversion + +The following example shows how to convert and run inference on a pre-trained +`tf.Keras` MobileNet model to TensorFlow Lite. In order to load the model from +file, use `model_path` instead of `model_content`. + +```python +import numpy as np +import tensorflow as tf + +# Load the MobileNet tf.keras model. +model = tf.keras.applications.MobileNetV2( + weights="imagenet", input_shape=(224, 224, 3)) + +# Save and load the model to generate the concrete function to export. +export_dir = "/tmp/test_model/mobilenet" +tf.saved_model.save(model, export_dir) +model = tf.saved_model.load(export_dir) +concrete_func = model.signatures[ + tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + +# Set the shape manually. +concrete_func.inputs[0].set_shape([1, 224, 224, 3]) + +# Convert the model. 
+converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func) +tflite_model = converter.convert() + +# Load TFLite model and allocate tensors. +interpreter = tf.lite.Interpreter(model_content=tflite_model) +interpreter.allocate_tensors() + +# Get input and output tensors. +input_details = interpreter.get_input_details() +output_details = interpreter.get_output_details() + +# Test model on random input data. +input_shape = input_details[0]['shape'] +input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32) +interpreter.set_tensor(input_details[0]['index'], input_data) + +interpreter.invoke() +output_data = interpreter.get_tensor(output_details[0]['index']) +print(output_data) +``` + +## Summary of changes in `TFLiteConverter` between 1.X and 2.0 + +The following section summarizes the changes in `TFLiteConverter` from 1.X to +2.0. If any of the changes raise concerns, please file a +[GitHub issue](https://github.com/tensorflow/tensorflow/issues). + +### Supported formats + +`TFLiteConverter` in 2.0 supports SavedModels and Keras model files generated in +both 1.X and 2.0. However, the conversion process no longer supports frozen +`GraphDefs` generated in 1.X. Users who want to convert frozen `GraphDefs` to +TensorFlow Lite should use `tensorflow.compat.v1`. + +### Quantization-aware training + +The following attributes and methods associated with +[quantization-aware training](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize) +have been removed from `TFLiteConverter` in TensorFlow 2.0: + +* `inference_type` +* `inference_input_type` +* `quantized_input_stats` +* `default_ranges_stats` +* `reorder_across_fake_quant` +* `change_concat_input_ranges` +* `post_training_quantize` - Deprecated in the 1.X API +* `get_input_arrays()` + +The rewriter function that supports quantization-aware training does not support +models generated by TensorFlow 2.0. 
Additionally, TensorFlow Lite’s quantization +API is being reworked and streamlined in a direction that supports +quantization-aware training through the Keras API. These attributes will +remain unavailable in the 2.0 API until the new quantization API is launched. Users who +want to convert models generated by the rewriter function can use +`tensorflow.compat.v1`. + +### Changes to attributes + +The `target_ops` attribute has become an attribute of `TargetSpec` and has been renamed +to `supported_ops` in line with future additions to the optimization framework. + +Additionally, the following attributes have been removed: + +* `drop_control_dependency` (default: `True`) - Control flow is currently not + supported by TFLite so it is always `True`. +* _Graph visualization_ - The recommended approach for visualizing a + TensorFlow Lite graph in TensorFlow 2.0 will be to use + [visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py). + Unlike GraphViz, it enables users to visualize the graph after post training + quantization has occurred. The following attributes related to graph + visualization will be removed: + * `output_format` + * `dump_graphviz_dir` + * `dump_graphviz_video` + +### Deprecated APIs + +The following methods that were previously deprecated in 1.X will no longer be +exported in 2.0: + +* `lite.toco_convert` +* `lite.TocoConverter` + +## Installing TensorFlow + +### Installing the TensorFlow 2.0 nightly + +The TensorFlow 2.0 nightly can be installed using the following command: + +``` +pip install tf-nightly-2.0-preview +``` + +### Using TensorFlow 2.0 from a 1.X installation + +TensorFlow 2.0 can be enabled from recent 1.X installations using the following +code snippet. + +```python +import tensorflow.compat.v2 as tf + +tf.enable_v2_behavior() +``` + +### Using TensorFlow 1.X from a 2.0 installation + +TensorFlow 1.X can be enabled from a 2.0 installation. 
This can be useful if you +are using features that are no longer supported in 2.0. + +```python +import tensorflow.compat.v1 as tf +``` + +### Build from source code + +In order to run the latest version of the TensorFlow Lite Converter Python API, +either install the nightly build with +[pip](https://www.tensorflow.org/install/pip) (recommended) or +[Docker](https://www.tensorflow.org/install/docker), or +[build the pip package from source](https://www.tensorflow.org/install/source). diff --git a/tensorflow/lite/g3doc/r2/images/convert/workflow.svg b/tensorflow/lite/g3doc/r2/images/convert/workflow.svg new file mode 100644 index 0000000000000000000000000000000000000000..2d8339f35f3292d964367a1f788187e81178e44d --- /dev/null +++ b/tensorflow/lite/g3doc/r2/images/convert/workflow.svg @@ -0,0 +1 @@ + diff --git a/tensorflow/lite/g3doc/tfmobile/android_build.md b/tensorflow/lite/g3doc/tfmobile/android_build.md deleted file mode 100644 index f8c0243298e435382a7514e72ada89880fb00c1c..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/tfmobile/android_build.md +++ /dev/null @@ -1,195 +0,0 @@ -# Building TensorFlow on Android - -Warning: We expect to deprecate TensorFlow Mobile in early 2019 - -
-

- TensorFlow Lite is our main mobile and embedded offering. We are - working hard to close the feature gap between TensorFlow Mobile and - TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We - will give ample notice to our users when we get to that point and will - provide help and support to ensure easy migrations. -

-

- In the meantime, please use TensorFlow Lite. If you have a feature request, - such as a missing op, please post to our GitHub. -

-
- -To get you started working with TensorFlow on Android, we'll walk through two -ways to build our TensorFlow mobile demos and deploying them on an Android -device. The first is Android Studio, which lets you build and deploy in an -IDE. The second is building with Bazel and deploying with ADB on the command -line. - -Why choose one or the other of these methods? - -The simplest way to use TensorFlow on Android is to use Android Studio. If you -aren't planning to customize your TensorFlow build at all, or if you want to use -Android Studio's editor and other features to build an app and just want to add -TensorFlow to it, we recommend using Android Studio. - -If you are using custom ops, or have some other reason to build TensorFlow from -scratch, scroll down and see our instructions -for [building the demo with Bazel](#build_the_demo_using_bazel). - -## Build the demo using Android Studio - -**Prerequisites** - -If you haven't already, do the following two things: - -- Install [Android Studio](https://developer.android.com/studio/index.html), - following the instructions on their website. - -- Clone the TensorFlow repository from GitHub: - - git clone https://github.com/tensorflow/tensorflow - -**Building** - -1. Open Android Studio, and from the Welcome screen, select **Open an existing - Android Studio project**. - -2. From the **Open File or Project** window that appears, navigate to and select - the `tensorflow/examples/android` directory from wherever you cloned the - TensorFlow GitHub repo. Click OK. - - If it asks you to do a Gradle Sync, click OK. - - You may also need to install various platforms and tools, if you get - errors like "Failed to find target with hash string 'android-23' and similar. - -3. Open the `build.gradle` file (you can go to **1:Project** in the side panel - and find it under the **Gradle Scripts** zippy under **Android**). 
Look for - the `nativeBuildSystem` variable and set it to `none` if it isn't already: - - // set to 'bazel', 'cmake', 'makefile', 'none' - def nativeBuildSystem = 'none' - -4. Click the *Run* button (the green arrow) or select *Run > Run 'android'* from the - top menu. You may need to rebuild the project using *Build > Rebuild Project*. - - If it asks you to use Instant Run, click **Proceed Without Instant Run**. - - Also, you need to have an Android device plugged in with developer options - enabled at this - point. See [here](https://developer.android.com/studio/run/device.html) for - more details on setting up developer devices. - -This installs three apps on your phone that are all part of the TensorFlow -Demo. See [Android Sample Apps](#android_sample_apps) for more information about -them. - -## Adding TensorFlow to your apps using Android Studio - -To add TensorFlow to your own apps on Android, the simplest way is to add the -following lines to your Gradle build file: - - allprojects { - repositories { - jcenter() - } - } - - dependencies { - implementation 'org.tensorflow:tensorflow-android:+' - } - -This automatically downloads the latest stable version of TensorFlow as an AAR -and installs it in your project. - -## Build the demo using Bazel - -Another way to use TensorFlow on Android is to build an APK -using [Bazel](https://bazel.build/) and load it onto your device -using [ADB](https://developer.android.com/studio/command-line/adb.html). This -requires some knowledge of build systems and Android developer tools, but we'll -guide you through the basics here. - -- First, follow our instructions for - installing from sources. - This will also guide you through installing Bazel and cloning the - TensorFlow code. - -- Download the Android [SDK](https://developer.android.com/studio/index.html) - and [NDK](https://developer.android.com/ndk/downloads/index.html) if you do - not already have them. You need at least version 12b of the NDK, and 23 of the - SDK. 
- -- In your copy of the TensorFlow source, update the - [WORKSPACE](https://github.com/tensorflow/tensorflow/blob/master/WORKSPACE) - file with the location of your SDK and NDK, where it says <PATH_TO_NDK> - and <PATH_TO_SDK>. - -- Run Bazel to build the demo APK: - - bazel build -c opt //tensorflow/examples/android:tensorflow_demo - -- Use [ADB](https://developer.android.com/studio/command-line/adb.html#move) to - install the APK onto your device: - - adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk - -Note: In general when compiling for Android with Bazel you need -`--config=android` on the Bazel command line, though in this case this -particular example is Android-only, so you don't need it here. - -This installs three apps on your phone that are all part of the TensorFlow -Demo. See [Android Sample Apps](#android_sample_apps) for more information about -them. - -## Android Sample Apps - -The -[Android example code](https://www.tensorflow.org/code/tensorflow/examples/android/) is -a single project that builds and installs three sample apps which all use the -same underlying code. The sample apps all take video input from a phone's -camera: - -- **TF Classify** uses the Inception v3 model to label the objects it’s pointed - at with classes from Imagenet. There are only 1,000 categories in Imagenet, - which misses most everyday objects and includes many things you’re unlikely to - encounter often in real life, so the results can often be quite amusing. For - example there’s no ‘person’ category, so instead it will often guess things it - does know that are often associated with pictures of people, like a seat belt - or an oxygen mask. If you do want to customize this example to recognize - objects you care about, you can use - the - [TensorFlow for Poets codelab](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#0) as - an example for how to train a model based on your own data. 
- -- **TF Detect** uses a multibox model to try to draw bounding boxes around the - locations of people in the camera. These boxes are annotated with the - confidence for each detection result. Results will not be perfect, as this - kind of object detection is still an active research topic. The demo also - includes optical tracking for when objects move between frames, which runs - more frequently than the TensorFlow inference. This improves the user - experience since the apparent frame rate is faster, but it also gives the - ability to estimate which boxes refer to the same object between frames, which - is important for counting objects over time. - -- **TF Stylize** implements a real-time style transfer algorithm on the camera - feed. You can select which styles to use and mix between them using the - palette at the bottom of the screen, and also switch out the resolution of the - processing to go higher or lower rez. - -When you build and install the demo, you'll see three app icons on your phone, -one for each of the demos. Tapping on them should open up the app and let you -explore what they do. You can enable profiling statistics on-screen by tapping -the volume up button while they’re running. - -### Android Inference Library - -Because Android apps need to be written in Java, and core TensorFlow is in C++, -TensorFlow has a JNI library to interface between the two. Its interface is aimed -only at inference, so it provides the ability to load a graph, set up inputs, -and run the model to calculate particular outputs. You can see the full -documentation for the minimal set of methods in -[TensorFlowInferenceInterface.java](https://www.tensorflow.org/code/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java) - -The demos applications use this interface, so they’re a good place to look for -example usage. 
You can download prebuilt binary jars -at -[ci.tensorflow.org](https://ci.tensorflow.org/view/Nightly/job/nightly-android/). diff --git a/tensorflow/lite/g3doc/tfmobile/index.md b/tensorflow/lite/g3doc/tfmobile/index.md deleted file mode 100644 index 15f0fd396134e40e89266182cb308080d9d250cb..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/tfmobile/index.md +++ /dev/null @@ -1,298 +0,0 @@ -# Overview - -Warning: We expect to deprecate TensorFlow Mobile in early 2019 - -
-

- TensorFlow Lite is our main mobile and embedded offering. We are - working hard to close the feature gap between TensorFlow Mobile and - TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We - will give ample notice to our users when we get to that point and will - provide help and support to ensure easy migrations. -

-

- In the meantime, please use TensorFlow Lite. If you have a feature request, - such as a missing op, please post to our GitHub. -

-
- -TensorFlow was designed to be a good deep learning solution for mobile -platforms. Currently we have two solutions for deploying machine learning -applications on mobile and embedded devices: TensorFlow for Mobile and -TensorFlow Lite. - -## TensorFlow Lite versus TensorFlow Mobile - -Here are a few of the differences between the two: - -- TensorFlow Lite is an evolution of TensorFlow Mobile. In most cases, apps - developed with TensorFlow Lite will have a smaller binary size, fewer - dependencies, and better performance. - -- TensorFlow Lite is in developer preview, so not all use cases are covered yet. - We expect you to use TensorFlow Mobile to cover production cases. - -- TensorFlow Lite supports only a limited set of operators, so not all models - will work on it by default. TensorFlow for Mobile has a fuller set of - supported functionality. - -TensorFlow Lite provides better performance and a small binary size on mobile -platforms as well as the ability to leverage hardware acceleration if available -on their platforms. In addition, it has many fewer dependencies so it can be -built and hosted on simpler, more constrained device scenarios. TensorFlow Lite -also allows targeting accelerators through the [Neural Networks -API](https://developer.android.com/ndk/guides/neuralnetworks/index.html). - -TensorFlow Lite currently has coverage for a limited set of operators. While -TensorFlow for Mobile supports only a constrained set of ops by default, in -principle if you use an arbitrary operator in TensorFlow, it can be customized -to build that kernel. Thus use cases which are not currently supported by -TensorFlow Lite should continue to use TensorFlow for Mobile. As TensorFlow Lite -evolves, it will gain additional operators, and the decision will be easier to -make. - - -## Introduction to TensorFlow Mobile - -TensorFlow was designed from the ground up to be a good deep learning solution -for mobile platforms like Android and iOS. 
This mobile guide should help you -understand how machine learning can work on mobile platforms and how to -integrate TensorFlow into your mobile apps effectively and efficiently. - -## About this Guide - -This guide is aimed at developers who have a TensorFlow model that’s -successfully working in a desktop environment, who want to integrate it into -a mobile application, and cannot use TensorFlow Lite. Here are the -main challenges you’ll face during that process: - -- Understanding how to use Tensorflow for mobile. -- Building TensorFlow for your platform. -- Integrating the TensorFlow library into your application. -- Preparing your model file for mobile deployment. -- Optimizing for latency, RAM usage, model file size, and binary size. - -## Common use cases for mobile machine learning - -**Why run TensorFlow on mobile?** - -Traditionally, deep learning has been associated with data centers and giant -clusters of high-powered GPU machines. However, it can be very expensive and -time-consuming to send all of the data a device has access to across a network -connection. Running on mobile makes it possible to deliver very interactive -applications in a way that’s not possible when you have to wait for a network -round trip. - -Here are some common use cases for on-device deep learning: - -### Speech Recognition - -There are a lot of interesting applications that can be built with a -speech-driven interface, and many of these require on-device processing. Most of -the time a user isn’t giving commands, and so streaming audio continuously to a -remote server would be a waste of bandwidth, since it would mostly be silence or -background noises. To solve this problem it’s common to have a small neural -network running on-device -[listening out for a particular keyword](../tutorials/sequences/audio_recognition). 
-Once that keyword has been spotted, the rest of the -conversation can be transmitted over to the server for further processing if -more computing power is needed. - -### Image Recognition - -It can be very useful for a mobile app to be able to make sense of a camera -image. If your users are taking photos, recognizing what’s in them can help your -camera apps apply appropriate filters, or label the photos so they’re easily -findable. It’s important for embedded applications too, since you can use image -sensors to detect all sorts of interesting conditions, whether it’s spotting -endangered animals in the wild -or -[reporting how late your train is running](https://svds.com/tensorflow-image-recognition-raspberry-pi/). - -TensorFlow comes with several examples of recognizing the types of objects -inside images along with a variety of different pre-trained models, and they can -all be run on mobile devices. You can try out -our -[Tensorflow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#0) and -[Tensorflow for Poets 2: Optimize for Mobile](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/index.html#0) codelabs to -see how to take a pretrained model and run some very fast and lightweight -training to teach it to recognize specific objects, and then optimize it to -run on mobile. - -### Object Localization - -Sometimes it’s important to know where objects are in an image as well as what -they are. There are lots of augmented reality use cases that could benefit a -mobile app, such as guiding users to the right component when offering them -help fixing their wireless network or providing informative overlays on top of -landscape features. Embedded applications often need to count objects that are -passing by them, whether it’s pests in a field of crops, or people, cars and -bikes going past a street lamp. 
- -TensorFlow offers a pretrained model for drawing bounding boxes around people -detected in images, together with tracking code to follow them over time. The -tracking is especially important for applications where you’re trying to count -how many objects are present over time, since it gives you a good idea when a -new object enters or leaves the scene. We have some sample code for this -available for Android [on -GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android), -and also a [more general object detection -model](https://github.com/tensorflow/models/tree/master/research/object_detection/README.md) -available as well. - -### Gesture Recognition - -It can be useful to be able to control applications with hand or other -gestures, either recognized from images or through analyzing accelerometer -sensor data. Creating those models is beyond the scope of this guide, but -TensorFlow is an effective way of deploying them. - -### Optical Character Recognition - -Google Translate’s live camera view is a great example of how effective -interactive on-device detection of text can be. - -
- -
- -There are multiple steps involved in recognizing text in images. You first have -to identify the areas where the text is present, which is a variation on the -object localization problem, and can be solved with similar techniques. Once you -have an area of text, you then need to interpret it as letters, and then use a -language model to help guess what words they represent. The simplest way to -estimate what letters are present is to segment the line of text into individual -letters, and then apply a simple neural network to the bounding box of each. You -can get good results with the kind of models used for MNIST, which you can find -in TensorFlow’s tutorials, though you may want a higher-resolution input. A -more advanced alternative is to use an LSTM model to process a whole line of -text at once, with the model itself handling the segmentation into different -characters. - -### Translation - -Translating from one language to another quickly and accurately, even if you -don’t have a network connection, is an important use case. Deep networks are -very effective at this sort of task, and you can find descriptions of a lot of -different models in the literature. Often these are sequence-to-sequence -recurrent models where you’re able to run a single graph to do the whole -translation, without needing to run separate parsing stages. - -### Text Classification - -If you want to suggest relevant prompts to users based on what they’re typing or -reading, it can be very useful to understand the meaning of the text. This is -where text classification comes in. Text classification is an umbrella term -that covers everything from sentiment analysis to topic discovery. You’re likely -to have your own categories or labels that you want to apply, so the best place -to start is with an example -like -[Skip-Thoughts](https://github.com/tensorflow/models/tree/master/research/skip_thoughts/), -and then train on your own examples. 
- -### Voice Synthesis - -A synthesized voice can be a great way of giving users feedback or aiding -accessibility, and recent advances such as -[WaveNet](https://deepmind.com/blog/wavenet-generative-model-raw-audio/) show -that deep learning can offer very natural-sounding speech. - -## Mobile machine learning and the cloud - -These examples of use cases give an idea of how on-device networks can -complement cloud services. Cloud has a great deal of computing power in a -controlled environment, but running on devices can offer higher interactivity. -In situations where the cloud is unavailable, or your cloud capacity is limited, -you can provide an offline experience, or reduce cloud workload by processing -easy cases on device. - -Doing on-device computation can also signal when it's time to switch to working -on the cloud. A good example of this is hotword detection in speech. Since -devices are able to constantly listen out for the keywords, this then triggers a -lot of traffic to cloud-based speech recognition once one is recognized. Without -the on-device component, the whole application wouldn’t be feasible, and this -pattern exists across several other applications as well. Recognizing that some -sensor input is interesting enough for further processing makes a lot of -interesting products possible. - -## What hardware and software should you have? - -TensorFlow runs on Ubuntu Linux, Windows 10, and OS X. For a list of all -supported operating systems and instructions to install TensorFlow, see -Installing Tensorflow. - -Note that some of the sample code we provide for mobile TensorFlow requires you -to compile TensorFlow from source, so you’ll need more than just `pip install` -to work through all the sample code. - -To try out the mobile examples, you’ll need a device set up for development, -using -either [Android Studio](https://developer.android.com/studio/install.html), -or [XCode](https://developer.apple.com/xcode/) if you're developing for iOS. 
- -## What should you do before you get started? - -Before thinking about how to get your solution on mobile: - -1. Determine whether your problem is solvable by mobile machine learning -2. Create a labelled dataset to define your problem -3. Pick an effective model for the problem - -We'll discuss these in more detail below. - -### Is your problem solvable by mobile machine learning? - -Once you have an idea of the problem you want to solve, you need to make a plan -of how to build your solution. The most important first step is making sure that -your problem is actually solvable, and the best way to do that is to mock it up -using humans in the loop. - -For example, if you want to drive a robot toy car using voice commands, try -recording some audio from the device and listening back to it to see if you can -make sense of what’s being said. Often you’ll find there are problems in the -capture process, such as the motor drowning out speech or not being able to hear -at a distance, and you should tackle these problems before investing in the -modeling process. - -Another example would be giving photos taken from your app to people to see if they -can classify what’s in them, in the way you’re looking for. If they can’t do -that (for example, trying to estimate calories in food from photos may be -impossible because all white soups look the same), then you’ll need to redesign -your experience to cope with that. A good rule of thumb is that if a human can’t -handle the task then it will be difficult to train a computer to do better. - -### Create a labelled dataset - -After you’ve solved any fundamental issues with your use case, you need to -create a labeled dataset to define what problem you’re trying to solve. This -step is extremely important, more than picking which model to use. You want it -to be as representative as possible of your actual use case, since the model -will only be effective at the task you teach it.
It’s also worth investing in -tools to make labeling the data as efficient and accurate as possible. For -example, if you’re able to switch from having to click a button on a web -interface to simple keyboard shortcuts, you may be able to speed up the -generation process a lot. You should also start by doing the initial labeling -yourself, so you can learn about the difficulties and likely errors, and -possibly change your labeling or data capture process to avoid them. Once you -and your team are able to consistently label examples (that is once you -generally agree on the same labels for most examples), you can then try and -capture your knowledge in a manual and teach external raters how to run the same -process. - -### Pick an effective model - -The next step is to pick an effective model to use. You might be able to avoid -training a model from scratch if someone else has already implemented a model -similar to what you need; we have a repository of models implemented in -TensorFlow [on GitHub](https://github.com/tensorflow/models) that you can look -through. Lean towards the simplest model you can find, and try to get started as -soon as you have even a small amount of labelled data, since you’ll get the best -results when you’re able to iterate quickly. The shorter the time it takes to -try training a model and running it in its real application, the better overall -results you’ll see. It’s common for an algorithm to get great training accuracy -numbers but then fail to be useful within a real application because there’s a -mismatch between the dataset and real usage. Prototype end-to-end usage as soon -as possible to create a consistent user experience. 
diff --git a/tensorflow/lite/g3doc/tfmobile/ios_build.md b/tensorflow/lite/g3doc/tfmobile/ios_build.md deleted file mode 100644 index d922907cdc5fe5ccec8864b456586fce0293a0af..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/tfmobile/ios_build.md +++ /dev/null @@ -1,124 +0,0 @@ -# Building TensorFlow on iOS - -Warning: We expect to deprecate TensorFlow Mobile in early 2019 - -
-

- TensorFlow Lite is our main mobile and embedded offering. We are - working hard to close the feature gap between TensorFlow Mobile and - TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We - will give ample notice to our users when we get to that point and will - provide help and support to ensure easy migrations. -

-

- In the meantime, please use TensorFlow Lite. If you have a feature request, - such as a missing op, please post to our GitHub. -

-
- -## Using CocoaPods - -The simplest way to get started with TensorFlow on iOS is using the CocoaPods -package management system. You can add the `TensorFlow-experimental` pod to your -Podfile, which installs a universal binary framework. This makes it easy to get -started but has the disadvantage of being hard to customize, which is important -in case you want to shrink your binary size. If you do need the ability to -customize your libraries, see later sections on how to do that. - -## Creating your own app - -If you'd like to add TensorFlow capabilities to your own app, do the following: - -- Create your own app or load your already-created app in XCode. - -- Add a file named Podfile at the project root directory with the following content: - - target 'YourProjectName' - pod 'TensorFlow-experimental' - -- Run `pod install` to download and install the `TensorFlow-experimental` pod. - -- Open `YourProjectName.xcworkspace` and add your code. - -- In your app's **Build Settings**, make sure to add `$(inherited)` to the - **Other Linker Flags**, and **Header Search Paths** sections. - -## Running the Samples - -You'll need Xcode 7.3 or later to run our iOS samples. - -There are currently three examples: simple, benchmark, and camera. For now, you -can download the sample code by cloning the main tensorflow repository (we are -planning to make the samples available as a separate repository later). 
- -From the root of the tensorflow folder, download [Inception -v1](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip), -and extract the label and graph files into the data folders inside both the -simple and camera examples using these steps: - - mkdir -p ~/graphs - curl -o ~/graphs/inception5h.zip \ - https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip \ - && unzip ~/graphs/inception5h.zip -d ~/graphs/inception5h - cp ~/graphs/inception5h/* tensorflow/examples/ios/benchmark/data/ - cp ~/graphs/inception5h/* tensorflow/examples/ios/camera/data/ - cp ~/graphs/inception5h/* tensorflow/examples/ios/simple/data/ - -Change into one of the sample directories, download the -[Tensorflow-experimental](https://cocoapods.org/pods/TensorFlow-experimental) -pod, and open the Xcode workspace. Note that installing the pod can take a long -time since it is big (~450MB). If you want to run the simple example, then: - - cd tensorflow/examples/ios/simple - pod install - open tf_simple_example.xcworkspace # note .xcworkspace, not .xcodeproj - # this is created by pod install - -Run the simple app in the XCode simulator. You should see a single-screen app -with a **Run Model** button. Tap that, and you should see some debug output -appear below indicating that the example Grace Hopper image in directory data -has been analyzed, with a military uniform recognized. - -Run the other samples using the same process. The camera example requires a real -device connected. Once you build and run that, you should get a live camera view -that you can point at objects to get real-time recognition results. - -### iOS Example details - -There are three demo applications for iOS, all defined in Xcode projects inside -[tensorflow/examples/ios](https://www.tensorflow.org/code/tensorflow/examples/ios/). - -- **Simple**: This is a minimal example showing how to load and run a TensorFlow - model in as few lines as possible. 
It just consists of a single view with a - button that executes the model loading and inference when it’s pressed. - -- **Camera**: This is very similar to the Android TF Classify demo. It loads - Inception v3 and outputs its best label estimate for what’s in the live camera - view. As with the Android version, you can train your own custom model using - TensorFlow for Poets and drop it into this example with minimal code changes. - -- **Benchmark**: This is quite close to Simple, but it runs the graph repeatedly and - outputs similar statistics to the benchmark tool on Android. - - -### Troubleshooting - -- Make sure you use the TensorFlow-experimental pod (and not TensorFlow). - -- The TensorFlow-experimental pod is currently about 450MB. It is so - big because we are bundling multiple platforms, and the pod includes all - TensorFlow functionality (e.g. operations). The final app size after build is - substantially smaller though (~25MB). Working with the complete pod is - convenient during development, but see the section below on how you can build your - own custom TensorFlow library to reduce the size. - -## Building the TensorFlow iOS libraries from source - -While CocoaPods is the quickest and easiest way of getting started, you sometimes -need more flexibility to determine which parts of TensorFlow your app should be -shipped with. For such cases, you can build the iOS libraries from the -sources. [This -guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/ios#building-the-tensorflow-ios-libraries-from-source) -contains detailed instructions on how to do that.
- diff --git a/tensorflow/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/lite/g3doc/tfmobile/linking_libs.md deleted file mode 100644 index fd0e322c93493ed835ae7ec9766a708885c6ac88..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/tfmobile/linking_libs.md +++ /dev/null @@ -1,270 +0,0 @@ -# Integrating TensorFlow libraries - -Warning: We expect to deprecate TensorFlow Mobile in early 2019 - -
-

- TensorFlow Lite is our main mobile and embedded offering. We are - working hard to close the feature gap between TensorFlow Mobile and - TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We - will give ample notice to our users when we get to that point and will - provide help and support to ensure easy migrations. -

-

- In the meantime, please use TensorFlow Lite. If you have a feature request, - such as a missing op, please post to our GitHub. -

-
- -Once you have made some progress on a model that addresses the problem you’re -trying to solve, it’s important to test it out inside your application -immediately. There are often unexpected differences between your training data -and what users actually encounter in the real world, and getting a clear picture -of the gap as soon as possible improves the product experience. - -This page talks about how to integrate the TensorFlow libraries into your own -mobile applications, once you have already successfully built and deployed the -TensorFlow mobile demo apps. - -## Linking the library - -After you've managed to build the examples, you'll probably want to call -TensorFlow from one of your existing applications. The very easiest way to do -this is to use the Pod installation steps described in -Building TensorFlow on iOS, but if you want to build -TensorFlow from source (for example to customize which operators are included) -you'll need to break out TensorFlow as a framework, include the right header -files, and link against the built libraries and dependencies. - -### Android - -For Android, you just need to link in a Java library contained in a JAR file -called `libandroid_tensorflow_inference_java.jar`. There are three ways to -include this functionality in your program: - -1. Include the jcenter AAR which contains it, as in this - [example app](https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/tfmobile/build.gradle#L59-L65) - -2. Download the nightly precompiled version from -[ci.tensorflow.org](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/). - -3. Build the JAR file yourself using the instructions [in our Android GitHub repo](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android) - -### iOS - -Pulling in the TensorFlow libraries on iOS is a little more complicated. 
Here is -a checklist of what you’ll need to do to your iOS app: - -- Link against tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a, usually - by adding `-L/your/path/tensorflow/contrib/makefile/gen/lib/` and - `-ltensorflow-core` to your linker flags. - -- Link against the generated protobuf libraries by adding - `-L/your/path/tensorflow/contrib/makefile/gen/protobuf_ios/lib` and - `-lprotobuf` and `-lprotobuf-lite` to your command line. - -- For the include paths, you need the root of your TensorFlow source folder as - the first entry, followed by - `tensorflow/contrib/makefile/downloads/protobuf/src`, - `tensorflow/contrib/makefile/downloads`, - `tensorflow/contrib/makefile/downloads/eigen`, and - `tensorflow/contrib/makefile/gen/proto`. - -- Make sure your binary is built with `-force_load` (or the equivalent on your - platform), aimed at the TensorFlow library to ensure that it’s linked - correctly. More detail on why this is necessary can be found in the next - section, [Global constructor magic](#global_constructor_magic). On Linux-like - platforms, you’ll need different flags, more like - `-Wl,--allow-multiple-definition -Wl,--whole-archive`. - -You’ll also need to link in the Accelerate framework, since this is used to -speed up some of the operations. - -## Global constructor magic - -One of the subtlest problems you may run up against is the “No session factory -registered for the given session options” error when trying to call TensorFlow -from your own application. To understand why this is happening and how to fix -it, you need to know a bit about the architecture of TensorFlow. - -The framework is designed to be very modular, with a thin core and a large -number of specific objects that are independent and can be mixed and matched as -needed.
To enable this, the coding pattern in C++ had to let modules easily -notify the framework about the services they offer, without requiring a central -list that has to be updated separately from each implementation. It also had to -allow separate libraries to add their own implementations without needing a -recompile of the core. - -To achieve this capability, TensorFlow uses a registration pattern in a lot of -places. In the code, it looks like this: - -``` -class MulKernel : OpKernel { - Status Compute(OpKernelContext* context) { … } -}; -REGISTER_KERNEL(MulKernel, "Mul"); -``` - -This would be in a standalone `.cc` file linked into your application, either -as part of the main set of kernels or as a separate custom library. The magic -part is that the `REGISTER_KERNEL()` macro is able to inform the core of -TensorFlow that it has an implementation of the Mul operation, so that it can be -called in any graphs that require it. - -From a programming point of view, this setup is very convenient. The -implementation and registration code live in the same file, and adding new -implementations is as simple as compiling and linking them in. The difficult part -comes from the way that the `REGISTER_KERNEL()` macro is implemented. C++ -doesn’t offer a good mechanism for doing this sort of registration, so we have -to resort to some tricky code. Under the hood, the macro is implemented so that -it produces something like this: - -``` -class RegisterMul { - public: - RegisterMul() { - global_kernel_registry()->Register("Mul", [](){ - return new MulKernel(); - }); - } -}; -RegisterMul g_register_mul; -``` - -This sets up a class `RegisterMul` with a constructor that tells the global -kernel registry what function to call when somebody asks it how to create a -“Mul” kernel. Then there’s a global object of that class, and so the constructor -should be called at the start of any program.
- -While this may sound sensible, the unfortunate part is that the global object -that’s defined is not used by any other code, so linkers not designed with this -in mind will decide that it can be deleted. As a result, the constructor is -never called, and the class is never registered. All sorts of modules use this -pattern in TensorFlow, and it happens that `Session` implementations are the -first to be looked for when the code is run, which is why it shows up as the -characteristic error when this problem occurs. - -The solution is to force the linker to not strip any code from the library, even -if it believes it’s unused. On iOS, this step can be accomplished with the -`-force_load` flag, specifying a library path, and on Linux you need -`--whole-archive`. These persuade the linker to not be as aggressive about -stripping, and should retain the globals. - -The actual implementation of the various `REGISTER_*` macros is a bit more -complicated in practice, but they all suffer the same underlying problem. If -you’re interested in how they work, [op_kernel.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op_kernel.h#L1091) -is a good place to start investigating. - -## Protobuf problems - -TensorFlow relies on -the [Protocol Buffer](https://developers.google.com/protocol-buffers/) library, -commonly known as protobuf. This library takes definitions of data structures -and produces serialization and access code for them in a variety of -languages. The tricky part is that this generated code needs to be linked -against shared libraries for the exact same version of the framework that was -used for the generator. This can be an issue when `protoc`, the tool used to -generate the code, is from a different version of protobuf than the libraries in -the standard linking and include paths. 
For example, you might be using a copy -of `protoc` that was built locally in `~/projects/protobuf-3.0.1.a`, but you have -libraries installed at `/usr/local/lib` and `/usr/local/include` that are from -3.0.0. - -The symptoms of this issue are errors during the compilation or linking phases -with protobufs. Usually, the build tools take care of this, but if you’re using -the makefile, make sure you’re building the protobuf library locally and using -it, as shown in [this Makefile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/makefile/Makefile#L18). - -Another situation that can cause problems is when protobuf headers and source -files need to be generated as part of the build process. This process makes -building more complex, since the first phase has to be a pass over the protobuf -definitions to create all the needed code files, and only after that can you go -ahead and do a build of the library code. - -### Multiple versions of protobufs in the same app - -Protobufs generate headers that are needed as part of the C++ interface to the -overall TensorFlow library. This complicates using the library as a standalone -framework. - -If your application is already using version 1 of the protocol buffers library, -you may have trouble integrating TensorFlow because it requires version 2. If -you just try to link both versions into the same binary, you’ll see linking -errors because some of the symbols clash. To solve this particular problem, we -have an experimental script at [rename_protobuf.sh](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/makefile/rename_protobuf.sh). - -You need to run this as part of the makefile build, after you’ve downloaded all -the dependencies: - -``` -tensorflow/contrib/makefile/download_dependencies.sh -tensorflow/contrib/makefile/rename_protobuf.sh -``` - -## Calling the TensorFlow API - -Once you have the framework available, you then need to call into it. 
The usual -pattern is that you first load your model, which represents a predefined set of -numeric computations, and then you run inputs through that model (for example, -images from a camera) and receive outputs (for example, predicted labels). - -On Android, we provide the Java Inference Library that is focused on just this -use case, while on iOS and Raspberry Pi you call directly into the C++ API. - -### Android - -Here’s what a typical Inference Library sequence looks like on Android: - -``` -// Load the model from disk. -TensorFlowInferenceInterface inferenceInterface = -new TensorFlowInferenceInterface(assetManager, modelFilename); - -// Copy the input data into TensorFlow. -inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); - -// Run the inference call. -inferenceInterface.run(outputNames, logStats); - -// Copy the output Tensor back into the output array. -inferenceInterface.fetch(outputName, outputs); -``` - -You can find the source of this code in the [Android examples](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java#L107). - -### iOS and Raspberry Pi - -Here’s the equivalent code for iOS and Raspberry Pi: - -``` -// Load the model. -PortableReadFileToProto(file_path, &tensorflow_graph); - -// Create a session from the model. -tensorflow::Status s = session->Create(tensorflow_graph); -if (!s.ok()) { - LOG(FATAL) << "Could not create TensorFlow Graph: " << s; -} - -// Run the model. -std::string input_layer = "input"; -std::string output_layer = "output"; -std::vector<tensorflow::Tensor> outputs; -tensorflow::Status run_status = session->Run({\{input_layer, image_tensor}}, - {output_layer}, {}, &outputs); -if (!run_status.ok()) { - LOG(FATAL) << "Running model failed: " << run_status; -} - -// Access the output data.
-tensorflow::Tensor* output = &outputs[0]; -``` - -This is all based on the -[iOS sample code](https://www.tensorflow.org/code/tensorflow/examples/ios/simple/RunModelViewController.mm), -but there’s nothing iOS-specific; the same code should be usable on any platform -that supports C++. - -You can also find specific examples for Raspberry Pi -[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/pi_examples/label_image/label_image.cc). diff --git a/tensorflow/lite/g3doc/tfmobile/optimizing.md b/tensorflow/lite/g3doc/tfmobile/optimizing.md deleted file mode 100644 index 59ff8e774c6c63a01668aee7d6caeea01171468d..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/tfmobile/optimizing.md +++ /dev/null @@ -1,518 +0,0 @@ -# Optimizing for mobile - -Warning: We expect to deprecate TensorFlow Mobile in early 2019 - -
-

- TensorFlow Lite is our main mobile and embedded offering. We are - working hard to close the feature gap between TensorFlow Mobile and - TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We - will give ample notice to our users when we get to that point and will - provide help and support to ensure easy migrations. -

-

- In the meantime, please use TensorFlow Lite. If you have a feature request, - such as a missing op, please post to our GitHub. -

-
- -There are some special issues that you have to deal with when you’re trying to -ship on mobile or embedded devices, and you’ll need to think about these as -you’re developing your model. - -These issues are: - -- Model and Binary Size -- App speed and model loading speed -- Performance and threading - -We'll discuss a few of these below. - -## What are the minimum device requirements for TensorFlow? - -You need at least one megabyte of program memory and several megabytes of RAM to -run the base TensorFlow runtime, so it’s not suitable for DSPs or -microcontrollers. Other than those, the biggest constraint is usually the -calculation speed of the device, and whether you can run the model you need for -your application with a low enough latency. You can use the benchmarking tools -in [How to Profile your Model](#how_to_profile_your_model) to get an idea of how -many FLOPs are required for a model, and then use that to make rule-of-thumb -estimates of how fast they will run on different devices. For example, a modern -smartphone might be able to run 10 GFLOPs per second, so the best you could hope -for from a 5 GFLOP model is two frames per second, though you may do worse -depending on what the exact computation patterns are. - -This model dependence means that it’s possible to run TensorFlow even on very -old or constrained phones, as long as you optimize your network to fit within -the latency budget and possibly within limited RAM too. For memory usage, you -mostly need to make sure that the intermediate buffers that TensorFlow creates -aren’t too large, which you can examine in the benchmark output too. - -## Speed - -One of the highest priorities of most model deployments is figuring out how to -run the inference fast enough to give a good user experience. The first place to -start is by looking at the total number of floating point operations that are -required to execute the graph. 
You can get a very rough estimate of this by -using the `benchmark_model` tool: - - bazel build -c opt tensorflow/tools/benchmark:benchmark_model && \ - bazel-bin/tensorflow/tools/benchmark/benchmark_model \ - --graph=/tmp/inception_graph.pb --input_layer="Mul:0" \ - --input_layer_shape="1,299,299,3" --input_layer_type="float" \ - --output_layer="softmax:0" --show_run_order=false --show_time=false \ - --show_memory=false --show_summary=true --show_flops=true --logtostderr - -This should show you an estimate of how many operations are needed to run the -graph. You can then use that information to figure out how feasible your model -is to run on the devices you’re targeting. For example, a high-end phone from -2016 might be able to do 20 billion FLOPs per second, so the best speed you -could hope for from a model that requires 10 billion FLOPs is around 500ms. On a -device like the Raspberry Pi 3 that can do about 5 billion FLOPs per second, you may only -get one inference every two seconds. - -Having this estimate helps you plan for what you’ll be able to realistically -achieve on a device. If the model is using too many ops, then there are a lot of -opportunities to optimize the architecture to reduce that number. - -Advanced techniques include [SqueezeNet](https://arxiv.org/abs/1602.07360) -and [MobileNet](https://arxiv.org/abs/1704.04861), which are architectures -designed to produce models for mobile -- lean and fast but with a small accuracy -cost. You can also just look at alternative models, even older ones, which may -be smaller. For example, Inception v1 only has around 7 million parameters, -compared to Inception v3’s 24 million, and requires only 3 billion FLOPs rather -than 9 billion for v3. - -## Model Size - -Models that run on a device need to be stored somewhere on the device, and very -large neural networks can be hundreds of megabytes.
Most users are reluctant to -download very large app bundles from app stores, so you want to make your model -as small as possible. Furthermore, smaller neural networks can persist in and -out of a mobile device's memory faster. - -To understand how large your network will be on disk, start by looking at the -size on disk of your `GraphDef` file after you’ve run `freeze_graph` and -`strip_unused_nodes` on it (see Preparing models for -more details on these tools), since then it should only contain -inference-related nodes. To double-check that your results are as expected, run -the `summarize_graph` tool to see how many parameters are in constants: - - bazel build tensorflow/tools/graph_transforms:summarize_graph && \ - bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ - --in_graph=/tmp/tensorflow_inception_graph.pb - -That command should give you output that looks something like this: - - No inputs spotted. - Found 1 possible outputs: (name=softmax, op=Softmax) - Found 23885411 (23.89M) const parameters, 0 (0) variable parameters, - and 99 control_edges - Op types used: 489 Const, 99 CheckNumerics, 99 Identity, 94 - BatchNormWithGlobalNormalization, 94 Conv2D, 94 Relu, 11 Concat, 9 AvgPool, - 5 MaxPool, 1 Sub, 1 Softmax, 1 ResizeBilinear, 1 Reshape, 1 Mul, 1 MatMul, - 1 ExpandDims, 1 DecodeJpeg, 1 Cast, 1 BiasAdd - -The important part for our current purposes is the number of const -parameters. In most models these will be stored as 32-bit floats to start, so if -you multiply the number of const parameters by four, you should get something -that’s close to the size of the file on disk. You can often get away with only -eight-bits per parameter with very little loss of accuracy in the final result, -so if your file size is too large you can try using -quantize_weights -to transform the parameters down. 
- - bazel build tensorflow/tools/graph_transforms:transform_graph && \ - bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ - --in_graph=/tmp/tensorflow_inception_optimized.pb \ - --out_graph=/tmp/tensorflow_inception_quantized.pb \ - --inputs='Mul:0' --outputs='softmax:0' --transforms='quantize_weights' - -If you look at the resulting file size, you should see that it’s about a quarter -of the original at 23MB. - -Another transform is `round_weights`, which doesn't make the file smaller, but it -makes the file compressible to about the same size as when `quantize_weights` is -used. This is particularly useful for mobile development, taking advantage of -the fact that app bundles are compressed before they’re downloaded by consumers. - -The original file does not compress well with standard algorithms, because the -bit patterns of even very similar numbers can be very different. The -`round_weights` transform keeps the weight parameters stored as floats, but -rounds them to a set number of step values. This means there are a lot more -repeated byte patterns in the stored model, and so compression can often bring -the size down dramatically, in many cases to near the size it would be if they -were stored as eight bit. - -Another advantage of `round_weights` is that the framework doesn’t have to -allocate a temporary buffer to unpack the parameters into, as we have to when -we just use `quantize_weights`. This saves a little bit of latency (though the -results should be cached so it’s only costly on the first run) and makes it -possible to use memory mapping, as described later. - -## Binary Size - -One of the biggest differences between mobile and server development is the -importance of binary size. On desktop machines it’s not unusual to have -executables that are hundreds of megabytes on disk, but for mobile and embedded -apps it’s vital to keep the binary as small as possible so that user downloads -are easy. 
As mentioned above, TensorFlow only includes a subset of op -implementations by default, but this still results in a 12 MB final -executable. To reduce this, you can set up the library to only include the -implementations of the ops that you actually need, based on automatically -analyzing your model. To use it: - -- Run `tools/print_required_ops/print_selective_registration_header.py` on your - model to produce a header file that only enables the ops it uses. - -- Place the `ops_to_register.h` file somewhere that the compiler can find - it. This can be in the root of your TensorFlow source folder. - -- Build TensorFlow with `SELECTIVE_REGISTRATION` defined, for example by passing - in `--copt="-DSELECTIVE_REGISTRATION"` to your Bazel build command. - -This process recompiles the library so that only the needed ops and types are -included, which can dramatically reduce the executable size. For example, with -Inception v3, the new size is only 1.5MB. - -## How to Profile your Model - -Once you have an idea of what your device's peak performance range is, it’s -worth looking at its actual current performance. Using a standalone TensorFlow -benchmark, rather than running it inside a larger app, helps isolate just the -TensorFlow contribution to the -latency. The -[tensorflow/tools/benchmark](https://www.tensorflow.org/code/tensorflow/tools/benchmark/) tool -is designed to help you do this. To run it on Inception v3 on your desktop -machine, build this benchmark model: - - bazel build -c opt tensorflow/tools/benchmark:benchmark_model && \ - bazel-bin/tensorflow/tools/benchmark/benchmark_model \ - --graph=/tmp/tensorflow_inception_graph.pb --input_layer="Mul" \ - --input_layer_shape="1,299,299,3" --input_layer_type="float" \ - --output_layer="softmax:0" --show_run_order=false --show_time=false \ - --show_memory=false --show_summary=true --show_flops=true --logtostderr - -You should see output that looks something like this: - -
-============================== Top by Computation Time ==============================
-[node
- type]  [start]  [first] [avg ms]     [%]  [cdf%]  [mem KB]  [Name]
-Conv2D   22.859   14.212   13.700  4.972%  4.972%  3871.488  conv_4/Conv2D
-Conv2D    8.116    8.964   11.315  4.106%  9.078%  5531.904  conv_2/Conv2D
-Conv2D   62.066   16.504    7.274  2.640% 11.717%   443.904  mixed_3/conv/Conv2D
-Conv2D    2.530    6.226    4.939  1.792% 13.510%  2765.952  conv_1/Conv2D
-Conv2D   55.585    4.605    4.665  1.693% 15.203%   313.600  mixed_2/tower/conv_1/Conv2D
-Conv2D  127.114    5.469    4.630  1.680% 16.883%    81.920  mixed_10/conv/Conv2D
-Conv2D   47.391    6.994    4.588  1.665% 18.548%   313.600  mixed_1/tower/conv_1/Conv2D
-Conv2D   39.463    7.878    4.336  1.574% 20.122%   313.600  mixed/tower/conv_1/Conv2D
-Conv2D  127.113    4.192    3.894  1.413% 21.535%   114.688  mixed_10/tower_1/conv/Conv2D
-Conv2D   70.188    5.205    3.626  1.316% 22.850%   221.952  mixed_4/conv/Conv2D
-
-============================== Summary by node type ==============================
-[Node type]  [count]  [avg ms]    [avg %]    [cdf %]  [mem KB]
-Conv2D            94   244.899    88.952%    88.952% 35869.953
-BiasAdd           95     9.664     3.510%    92.462% 35873.984
-AvgPool            9     7.990     2.902%    95.364%  7493.504
-Relu              94     5.727     2.080%    97.444% 35869.953
-MaxPool            5     3.485     1.266%    98.710%  3358.848
-Const            192     1.727     0.627%    99.337%     0.000
-Concat            11     1.081     0.393%    99.730%  9892.096
-MatMul             1     0.665     0.242%    99.971%     4.032
-Softmax            1     0.040     0.015%    99.986%     4.032
-<>                 1     0.032     0.012%    99.997%     0.000
-Reshape            1     0.007     0.003%   100.000%     0.000
-
-Timings (microseconds): count=50 first=330849 curr=274803 min=232354 max=415352 avg=275563 std=44193
-Memory (bytes): count=50 curr=128366400(all same)
-514 nodes defined 504 nodes observed
-
- -This is the summary view, which is enabled by the show_summary flag. To -interpret it, the first table is a list of the nodes that took the most time, in -order by how long they took. From left to right, the columns are: - -- Node type, what kind of operation this was. - -- Start time of the op, showing where it falls in the sequence of operations. - -- First time in milliseconds. This is how long the operation took on the first - run of the benchmark, since by default 20 runs are executed to get more - reliable statistics. The first time is useful to spot which ops are doing - expensive calculations on the first run, and then caching the results. - -- Average time for the operation across all runs, in milliseconds. - -- What percentage of the total time for one run the op took. This is useful to - understand where the hotspots are. - -- The cumulative total time of this and the previous ops in the table. This is - handy for understanding what the distribution of work is across the layers, to - see if just a few of the nodes are taking up most of the time. - -- The amount of memory consumed by outputs of this type of op. - -- Name of the node. - -The second table is similar, but instead of breaking down the timings by -particular named nodes, it groups them by the kind of op. This is very useful to -understand which op implementations you might want to optimize or eliminate from -your graph. The table is arranged with the most costly operations at the start, -and only shows the top ten entries, with a placeholder for other nodes. The -columns from left to right are: - -- Type of the nodes being analyzed. - -- Accumulated average time taken by all nodes of this type, in milliseconds. - -- What percentage of the total time was taken by this type of operation. - -- Cumulative time taken by this and op types higher in the table, so you can - understand the distribution of the workload. - -- How much memory the outputs of this op type took up. 
- -Both of these tables are set up so that you can easily copy and paste their -results into spreadsheet documents, since they are output with tabs as -separators between the columns. The summary by node type can be the most useful -when looking for optimization opportunities, since it’s a pointer to the code -that’s taking the most time. In this case, you can see that the Conv2D ops are -almost 90% of the execution time. This is a sign that the graph is pretty -optimal, since convolutions and matrix multiplies are expected to be the bulk of -a neural network’s computing workload. - -As a rule of thumb, it’s more worrying if you see a lot of other operations -taking up more than a small fraction of the time. For neural networks, the ops -that don’t involve large matrix multiplications should usually be dwarfed by the -ones that do, so if you see a lot of time going into those it’s a sign that -either your network is non-optimally constructed, or the code implementing those -ops is not as optimized as it could -be. [Performance bugs](https://github.com/tensorflow/tensorflow/issues) or -patches are always welcome if you do encounter this situation, especially if -they include an attached model exhibiting this behavior and the command line -used to run the benchmark tool on it. - -The run above was on your desktop, but the tool also works on Android, which is -where it’s most useful for mobile development. 
Here’s an example command line to -run it on a 64-bit ARM device: - - bazel build -c opt --config=android_arm64 \ - tensorflow/tools/benchmark:benchmark_model - adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp - adb push /tmp/tensorflow_inception_graph.pb /data/local/tmp/ - adb shell '/data/local/tmp/benchmark_model \ - --graph=/data/local/tmp/tensorflow_inception_graph.pb --input_layer="Mul" \ - --input_layer_shape="1,299,299,3" --input_layer_type="float" \ - --output_layer="softmax:0" --show_run_order=false --show_time=false \ - --show_memory=false --show_summary=true' - -You can interpret the results in exactly the same way as the desktop version -above. If you have any trouble figuring out what the right input and output -names and types are, take a look at the -Preparing models -page for details about detecting these for your model, and look at the -`summarize_graph` tool which may give you -helpful information. - -There isn’t good support for command line tools on iOS, so instead there’s a -separate example -at -[tensorflow/examples/ios/benchmark](https://www.tensorflow.org/code/tensorflow/examples/ios/benchmark) that -packages the same functionality inside a standalone app. This outputs the -statistics to both the screen of the device and the debug log. If you want -on-screen statistics for the Android example apps, you can turn them on by -pressing the volume-up button. - -## Profiling within your own app - -The output you see from the benchmark tool is generated from modules that are -included as part of the standard TensorFlow runtime, which means you have access -to them within your own applications too. You can see an example of how to do -that [here](https://www.tensorflow.org/code/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm?l=139). - -The basic steps are: - -1. Create a StatSummarizer object: - - tensorflow::StatSummarizer stat_summarizer(tensorflow_graph); - -2. 
Set up the options: - - tensorflow::RunOptions run_options; - run_options.set_trace_level(tensorflow::RunOptions::FULL_TRACE); - tensorflow::RunMetadata run_metadata; - -3. Run the graph: - - run_status = session->Run(run_options, inputs, output_layer_names, {}, - output_layers, &run_metadata); - -4. Calculate the results and print them out: - - assert(run_metadata.has_step_stats()); - const tensorflow::StepStats& step_stats = run_metadata.step_stats(); - stat_summarizer.ProcessStepStats(step_stats); - stat_summarizer.PrintStepStats(); - -## Visualizing Models - -The most effective way to speed up your code is by altering your model so it -does less work. To do that, you need to understand what your model is doing, and -visualizing it is a good first step. To get a high-level overview of your graph, -use [TensorBoard](https://github.com/tensorflow/tensorboard). - -## Threading - -The desktop version of TensorFlow has a sophisticated threading model, and will -try to run multiple operations in parallel if it can. In our terminology this is -called “inter-op parallelism” (though to avoid confusion with “intra-op”, you -could think of it as “between-op” instead), and can be set by specifying -`inter_op_parallelism_threads` in the session options. - -By default, mobile devices run operations serially; that is, -`inter_op_parallelism_threads` is set to 1. Mobile processors usually have few -cores and a small cache, so running multiple operations accessing disjoint parts -of memory usually doesn’t help performance. “Intra-op parallelism” (or -“within-op”) can be very helpful though, especially for computation-bound -operations like convolutions where different threads can feed off the same small -set of memory. - -On mobile, how many threads an op will use is set to the number of cores by -default, or 2 when the number of cores can't be determined.
You can override the -default number of threads that ops are using by setting -`intra_op_parallelism_threads` in the session options. It’s a good idea to -reduce the default if your app has its own threads doing heavy processing, so -that they don’t interfere with each other. - -To see more details on session options, look at [ConfigProto](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto). - -## Retrain with mobile data - -The biggest cause of accuracy problems when running models on mobile apps is -unrepresentative training data. For example, most of the Imagenet photos are -well-framed so that the object is in the center of the picture, well-lit, and -shot with a normal lens. Photos from mobile devices are often poorly framed, -badly lit, and can have fisheye distortions, especially selfies. - -The solution is to expand your training set with data actually captured from -your application. This step can involve extra work, since you’ll have to label -the examples yourself, but even if you just use it to expand your original -training data, it can help the training set dramatically. Improving the training -set by doing this, and by fixing other quality issues like duplicates or badly -labeled examples is the single best way to improve accuracy. It’s usually a -bigger help than altering your model architecture or using different techniques. - -## Reducing model loading time and/or memory footprint - -Most operating systems allow you to load a file using memory mapping, rather -than going through the usual I/O APIs. Instead of allocating an area of memory -on the heap and then copying bytes from disk into it, you simply tell the -operating system to make the entire contents of a file appear directly in -memory. This has several advantages: - -* Speeds loading -* Reduces paging (increases performance) -* Does not count towards RAM budget for your app - -TensorFlow has support for memory mapping the weights that form the bulk of most -model files. 
Because of limitations in the `ProtoBuf` serialization format, we -have to make a few changes to our model loading and processing code. The -way memory mapping works is that we have a single file where the first part is a -normal `GraphDef` serialized into the protocol buffer wire format, but then the -weights are appended in a form that can be directly mapped. - -To create this file, run the -`tensorflow/contrib/util:convert_graphdef_memmapped_format` tool. This takes in -a `GraphDef` file that’s been run through `freeze_graph` and converts it to the -format that has the weights appended at the end. Since that file’s no longer a -standard `GraphDef` protobuf, you then need to make some changes to the loading -code. You can see an example of this in -the -[iOS Camera demo app](https://www.tensorflow.org/code/tensorflow/examples/ios/camera/tensorflow_utils.mm?l=147), -in the `LoadMemoryMappedModel()` function. - -The same code (with the Objective C calls for getting the filenames substituted) -can be used on other platforms too. 
Because we’re using memory mapping, we need -to start by creating a special TensorFlow environment object that’s set up with -the file we’ll be using: - - std::unique_ptr<tensorflow::MemmappedEnv> memmapped_env; - memmapped_env.reset( - new tensorflow::MemmappedEnv(tensorflow::Env::Default())); - tensorflow::Status mmap_status = - memmapped_env->InitializeFromFile(file_path); - -You then need to pass in this environment to subsequent calls, like this one for -loading the graph: - - tensorflow::GraphDef tensorflow_graph; - tensorflow::Status load_graph_status = ReadBinaryProto( - memmapped_env.get(), - tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, - &tensorflow_graph); - -You also need to create the session with a pointer to the environment you’ve -created: - - tensorflow::SessionOptions options; - options.config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_opt_level(::tensorflow::OptimizerOptions::L0); - options.env = memmapped_env.get(); - - tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = - tensorflow::NewSession(options, &session_pointer); - -One thing to notice here is that we’re also disabling automatic optimizations, -since in some cases these will fold constant sub-trees, and so create copies of -tensor values that we don’t want and use up more RAM. - -Once you’ve gone through these steps, you can use the session and graph as -normal, and you should see a reduction in loading time and memory usage. - -## Protecting model files from easy copying - -By default, your models will be stored in the standard serialized protobuf -format on disk. In theory this means that anybody can copy your model, which you -may not want.
However, in practice, most models are so application-specific and -obfuscated by optimizations that the risk is similar to that of competitors -disassembling and reusing your code, but if you do want to make it tougher for -casual users to access your files it is possible to take some basic steps. - -Most of our examples use -the -[ReadBinaryProto()](https://www.tensorflow.org/code/tensorflow/core/platform/env.cc?q=core/platform/env.cc&l=409) convenience -call to load a `GraphDef` from disk. This does require an unencrypted protobuf on -disk. Luckily though, the implementation of the call is pretty straightforward -and it should be easy to write an equivalent that can decrypt in memory. Here's -some code that shows how you can read and decrypt a protobuf using your own -decryption routine: - - Status ReadEncryptedProto(Env* env, const string& fname, - ::tensorflow::protobuf::MessageLite* proto) { - string data; - TF_RETURN_IF_ERROR(ReadFileToString(env, fname, &data)); - - DecryptData(&data); // Your own function here. - - if (!proto->ParseFromString(data)) { - return errors::DataLoss("Can't parse ", fname, " as binary proto"); - } - return Status::OK(); - } - -To use this you’d need to define the DecryptData() function yourself. It could -be as simple as something like: - - void DecryptData(string* data) { - for (size_t i = 0; i < data->size(); ++i) { - (*data)[i] = (*data)[i] ^ 0x23; - } - } - -You may want something more complex, but exactly what you’ll need is outside the -current scope here. diff --git a/tensorflow/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/lite/g3doc/tfmobile/prepare_models.md deleted file mode 100644 index 1d373251ddf3ba6a0119bd57bf14caf100ef371a..0000000000000000000000000000000000000000 --- a/tensorflow/lite/g3doc/tfmobile/prepare_models.md +++ /dev/null @@ -1,318 +0,0 @@ -# Preparing models for mobile deployment - -Warning: We expect to deprecate TensorFlow Mobile in early 2019 - -
-

- TensorFlow Lite is our main mobile and embedded offering. We are - working hard to close the feature gap between TensorFlow Mobile and - TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We - will give ample notice to our users when we get to that point and will - provide help and support to ensure easy migrations. -

-

- In the meantime, please use TensorFlow Lite. If you have a feature request, - such as a missing op, please post to our GitHub. -

-
- -The requirements for storing model information during training are very -different from when you want to release it as part of a mobile app. This section -covers the tools involved in converting from a training model to something -releasable in production. - -## What is up with all the different saved file formats? - -You may find yourself getting very confused by all the different ways that -TensorFlow can save out graphs. To help, here’s a rundown of some of the -different components, and what they are used for. The objects are mostly defined -and serialized as protocol buffers: - -- [NodeDef](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto): - Defines a single operation in a model. It has a unique name, a list of the - names of other nodes it pulls inputs from, the operation type it implements - (for example `Add`, or `Mul`), and any attributes that are needed to control - that operation. This is the basic unit of computation for TensorFlow, and all - work is done by iterating through a network of these nodes, applying each one - in turn. One particular operation type that’s worth knowing about is `Const`, - since this holds information about a constant. This may be a single, scalar - number or string, but it can also hold an entire multi-dimensional tensor - array. The values for a `Const` are stored inside the `NodeDef`, and so large - constants can take up a lot of room when serialized. - -- [Checkpoint](https://www.tensorflow.org/code/tensorflow/core/util/tensor_bundle/tensor_bundle.h). Another - way of storing values for a model is by using `Variable` ops. Unlike `Const` - ops, these don’t store their content as part of the `NodeDef`, so they take up - very little space within the `GraphDef` file. Instead their values are held in - RAM while a computation is running, and then saved out to disk as checkpoint - files periodically. 
This typically happens as a neural network is being - trained and weights are updated, so it’s a time-critical operation, and it may - happen in a distributed fashion across many workers, so the file format has to - be both fast and flexible. They are stored as multiple checkpoint files, - together with metadata files that describe what’s contained within the - checkpoints. When you’re referring to a checkpoint in the API (for example - when passing a filename in as a command line argument), you’ll use the common - prefix for a set of related files. If you had these files: - - /tmp/model/model-chkpt-1000.data-00000-of-00002 - /tmp/model/model-chkpt-1000.data-00001-of-00002 - /tmp/model/model-chkpt-1000.index - /tmp/model/model-chkpt-1000.meta - - You would refer to them as `/tmp/model/model-chkpt-1000`. - -- [GraphDef](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto): - Has a list of `NodeDefs`, which together define the computational graph to - execute. During training, some of these nodes will be `Variables`, and so if - you want to have a complete graph you can run, including the weights, you’ll - need to call a restore operation to pull those values from - checkpoints. Because checkpoint loading has to be flexible to deal with all of - the training requirements, this can be tricky to implement on mobile and - embedded devices, especially those with no proper file system available like - iOS. This is where - the - [`freeze_graph.py`](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py) script - comes in handy. As mentioned above, `Const` ops store their values as part of - the `NodeDef`, so if all the `Variable` weights are converted to `Const` nodes, - then we only need a single `GraphDef` file to hold the model architecture and - the weights. Freezing the graph handles the process of loading the - checkpoints, and then converts all Variables to Consts.
You can then load the - resulting file in a single call, without having to restore variable values - from checkpoints. One thing to watch out for with `GraphDef` files is that - sometimes they’re stored in text format for easy inspection. These versions - usually have a ‘.pbtxt’ filename suffix, whereas the binary files end with - ‘.pb’. - -- [FunctionDefLibrary](https://www.tensorflow.org/code/tensorflow/core/framework/function.proto): - This appears in `GraphDef`, and is effectively a set of sub-graphs, each with - information about their input and output nodes. Each sub-graph can then be - used as an op in the main graph, allowing easy instantiation of different - nodes, in a similar way to how functions encapsulate code in other languages. - -- [MetaGraphDef](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto): - A plain `GraphDef` only has information about the network of computations, but - doesn’t have any extra information about the model or how it can be - used. `MetaGraphDef` contains a `GraphDef` defining the computation part of - the model, but also includes information like ‘signatures’, which are - suggestions about which inputs and outputs you may want to call the model - with, data on how and where any checkpoint files are saved, and convenience - tags for grouping ops together for ease of use. - -- [SavedModel](https://www.tensorflow.org/code/tensorflow/core/protobuf/saved_model.proto): - It’s common to want to have different versions of a graph that rely on a - common set of variable checkpoints. For example, you might need a GPU and a - CPU version of the same graph, but keep the same weights for both. You might - also need some extra files (like label names) as part of your - model. 
The - [SavedModel](https://www.tensorflow.org/code/tensorflow/python/saved_model/README.md) format - addresses these needs by letting you save multiple versions of the same graph - without duplicating variables, and also storing asset files in the same - bundle. Under the hood, it uses `MetaGraphDef` and checkpoint files, along - with extra metadata files. It’s the format that you’ll want to use if you’re - deploying a web API using TensorFlow Serving, for example. - -## How do you get a model you can use on mobile? - -In most situations, training a model with TensorFlow will give you a folder -containing a `GraphDef` file (usually ending with the `.pb` or `.pbtxt` extension) and -a set of checkpoint files. What you need for mobile or embedded deployment is a -single `GraphDef` file that’s been ‘frozen’, or had its variables converted into -inline constants so everything’s in one file. To handle the conversion, you’ll -need the `freeze_graph.py` script, that’s held in -[`tensorflow/python/tools/freeze_graph.py`](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py). You’ll run it like this: - - bazel build tensorflow/python/tools:freeze_graph - bazel-bin/tensorflow/python/tools/freeze_graph \ - --input_graph=/tmp/model/my_graph.pb \ - --input_checkpoint=/tmp/model/model.ckpt-1000 \ - --output_graph=/tmp/frozen_graph.pb \ - --output_node_names=output_node \ - -The `input_graph` argument should point to the `GraphDef` file that holds your -model architecture. It’s possible that your `GraphDef` has been stored in a text -format on disk, in which case it’s likely to end in `.pbtxt` instead of `.pb`, -and you should add an extra `--input_binary=false` flag to the command. - -The `input_checkpoint` should be the most recent saved checkpoint. As mentioned -in the checkpoint section, you need to give the common prefix to the set of -checkpoints here, rather than a full filename. 
- -`output_graph` defines where the resulting frozen `GraphDef` will be -saved. Because it’s likely to contain a lot of weight values that take up a -large amount of space in text format, it’s always saved as a binary protobuf. - -`output_node_names` is a list of the names of the nodes that you want to extract -the results of your graph from. This is needed because the freezing process -needs to understand which parts of the graph are actually needed, and which are -artifacts of the training process, like summarization ops. Only ops that -contribute to calculating the given output nodes will be kept. If you know how -your graph is going to be used, these should just be the names of the nodes you -pass into `Session::Run()` as your fetch targets. The easiest way to find the -node names is to inspect the Node objects while building your graph in python. -Inspecting your graph in TensorBoard is another simple way. You can get some -suggestions on likely outputs by running the [`summarize_graph` tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/README.md#inspecting-graphs). - -Because the output format for TensorFlow has changed over time, there are a -variety of other less commonly used flags available too, like `input_saver`, but -hopefully you shouldn’t need these on graphs trained with modern versions of the -framework. - -## Using the Graph Transform Tool - -A lot of the things you need to do to efficiently run a model on device are -available through the [Graph Transform -Tool](https://www.tensorflow.org/code/tensorflow/tools/graph_transforms/README.md). This -command-line tool takes an input `GraphDef` file, applies the set of rewriting -rules you request, and then writes out the result as a `GraphDef`. See the -documentation for more information on how to build and run this tool. 
- -### Removing training-only nodes - -TensorFlow `GraphDefs` produced by the training code contain all of the -computation that’s needed for back-propagation and updates of weights, as well -as the queuing and decoding of inputs, and the saving out of checkpoints. All of -these nodes are no longer needed during inference, and some of the operations -like checkpoint saving aren’t even supported on mobile platforms. To create a -model file that you can load on devices you need to delete those unneeded -operations by running the `strip_unused_nodes` rule in the Graph Transform Tool. - -The trickiest part of this process is figuring out the names of the nodes you -want to use as inputs and outputs during inference. You'll need these anyway -once you start to run inference, but you also need them here so that the -transform can calculate which nodes are not needed on the inference-only -path. These may not be obvious from the training code. The easiest way to -determine the node name is to explore the graph with TensorBoard. - -Remember that mobile applications typically gather their data from sensors and -have it as arrays in memory, whereas training typically involves loading and -decoding representations of the data stored on disk. In the case of Inception v3 -for example, there’s a `DecodeJpeg` op at the start of the graph that’s designed -to take JPEG-encoded data from a file retrieved from disk and turn it into an -arbitrary-sized image. After that there’s a `ResizeBilinear` op to scale it to -the expected size, followed by a couple of other ops that convert the byte data -into float and scale the value magnitudes in the way the rest of the graph -expects. A typical mobile app will skip most of these steps because it’s getting -its input directly from a live camera, so the input node you will actually -supply will be the output of the `Mul` node in this case. - - - -You’ll need to do a similar process of inspection to figure out the correct -output nodes.
- -If you’ve just been given a frozen `GraphDef` file, and are not sure about the -contents, try using the `summarize_graph` tool to print out information -about the inputs and outputs it finds from the graph structure. Here’s an -example with the original Inception v3 file: - - bazel run tensorflow/tools/graph_transforms:summarize_graph -- - --in_graph=tensorflow_inception_graph.pb - -Once you have an idea of what the input and output nodes are, you can feed them -into the graph transform tool as the `--inputs` and `--outputs` -arguments, and call the `strip_unused_nodes` transform, like this: - - bazel run tensorflow/tools/graph_transforms:transform_graph -- - --in_graph=tensorflow_inception_graph.pb - --out_graph=optimized_inception_graph.pb --inputs='Mul' --outputs='softmax' - --transforms=' - strip_unused_nodes(type=float, shape="1,299,299,3") - fold_constants(ignore_errors=true) - fold_batch_norms - fold_old_batch_norms' - -One thing to look out for here is that you need to specify the size and type -that you want your inputs to be. This is because any values that you’re going to -be passing in as inputs to inference need to be fed to special `Placeholder` op -nodes, and the transform may need to create them if they don’t already exist. In -the case of Inception v3 for example, a `Placeholder` node replaces the old -`Mul` node that used to output the resized and rescaled image array, since we’re -going to be doing that processing ourselves before we call TensorFlow. It keeps -the original name though, which is why we always feed in inputs to `Mul` when we -run a session with our modified Inception graph. - -After you’ve run this process, you’ll have a graph that only contains the actual -nodes you need to run your prediction process. This is the point where it -becomes useful to run metrics on the graph, so it’s worth running -`summarize_graph` again to understand what’s in your model. - -## What ops should you include on mobile?
- -There are hundreds of operations available in TensorFlow, and each one has -multiple implementations for different data types. On mobile platforms, the size -of the executable binary that’s produced after compilation is important, because -app download bundles need to be as small as possible for the best user -experience. If all of the ops and data types are compiled into the TensorFlow -library then the total size of the compiled library can be tens of megabytes, so -by default only a subset of ops and data types are included. - -That means that if you load a model file that’s been trained on a desktop -machine, you may see the error “No OpKernel was registered to support Op” when -you load it on mobile. The first thing to try is to make sure you’ve stripped -out any training-only nodes, since the error will occur at load time even if the -op is never executed. If you’re still hitting the same problem once that’s done, -you’ll need to look at adding the op to your built library. - -The criteria for including ops and types fall into several categories: - -- Are they only useful in back-propagation, for gradients? Since mobile is - focused on inference, we don’t include these. - -- Are they useful mainly for other training needs, such as checkpoint saving? - These we leave out. - -- Do they rely on frameworks that aren’t always available on mobile, such as - libjpeg? To avoid extra dependencies we don’t include ops like `DecodeJpeg`. - -- Are there types that aren’t commonly used? We don’t include boolean variants - of ops for example, since we don’t see much use of them in typical inference - graphs. - -These ops are trimmed by default to optimize for inference on mobile, but it is -possible to alter some build files to change the default. After altering the -build files, you will need to recompile TensorFlow. See below for more details -on how to do this, and also see optimizing binary size -for more on reducing your binary size.
- -### Locate the implementation - -Operations are broken into two parts. The first is the op definition, which -declares the signature of the operation, which inputs, outputs, and attributes -it has. These take up very little space, and so all are included by default. The -implementations of the op computations are done in kernels, which live in the -`tensorflow/core/kernels` folder. You need to compile the C++ file containing -the kernel implementation of the op you need into the library. To figure out -which file that is, you can search for the operation name in the source -files. - -[Here’s an example search in github](https://github.com/search?utf8=%E2%9C%93&q=repo%3Atensorflow%2Ftensorflow+extension%3Acc+path%3Atensorflow%2Fcore%2Fkernels+REGISTER+Mul&type=Code&ref=searchresults). - -You’ll see that this search is looking for the `Mul` op implementation, and it -finds it in `tensorflow/core/kernels/cwise_op_mul_1.cc`. You need to look for -macros beginning with `REGISTER`, with the op name you care about as one of the -string arguments. - -In this case, the implementations are actually broken up across multiple `.cc` -files, so you’d need to include all of them in your build. If you’re more -comfortable using the command line for code search, here’s a grep command that -also locates the right files if you run it from the root of your TensorFlow -repository: - -`grep 'REGISTER.*"Mul"' tensorflow/core/kernels/*.cc` - -### Add the implementation to the build - -If you’re using Bazel, and building for Android, you’ll want to add the files -you’ve found to -the -[`android_extended_ops_group1`](https://www.tensorflow.org/code/tensorflow/core/kernels/BUILD#L3565) or -[`android_extended_ops_group2`](https://www.tensorflow.org/code/tensorflow/core/kernels/BUILD#L3632) targets. You -may also need to include any .cc files they depend on in there. 
If the build -complains about missing header files, add the .h’s that are needed into -the -[`android_extended_ops`](https://www.tensorflow.org/code/tensorflow/core/kernels/BUILD#L3525) target. - -If you’re using a makefile targeting iOS, Raspberry Pi, etc, go to -[`tensorflow/contrib/makefile/tf_op_files.txt`](https://www.tensorflow.org/code/tensorflow/contrib/makefile/tf_op_files.txt) and -add the right implementation files there. diff --git a/tensorflow/lite/graph_info.cc b/tensorflow/lite/graph_info.cc index 1cec0d0c290679c7755cbf84858317489c0ba159..a9091924c064341316d788704daa643bc8e247b5 100644 --- a/tensorflow/lite/graph_info.cc +++ b/tensorflow/lite/graph_info.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/graph_info.h" #include +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { @@ -94,6 +95,10 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl { // been identified. for (int output_index : info_->outputs()) { int output_epoch = tensor_epochs_[output_index]; + if (output_epoch == kEpochAlwaysReady) { + // This happens when an input of subgraph is also an output of subgraph. + continue; + } NodeSubset& output_subset = (*node_subsets_)[output_epoch]; output_subset.output_tensors.push_back(output_index); } @@ -138,7 +143,8 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl { // See if all dependencies of this node are already assigned to a // node sub set. 
for (int input_tensor_index : TfLiteIntArrayView(node.inputs)) { - if (tensor_epochs_[input_tensor_index] == kEpochNotReady) { + if (input_tensor_index != kOptionalTensor && + tensor_epochs_[input_tensor_index] == kEpochNotReady) { return false; } } @@ -162,6 +168,9 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl { // Look at our inputs one more time to update that tensor's // epochs' outputs for (int input_tensor_index : TfLiteIntArrayView(node.inputs)) { + if (input_tensor_index == kOptionalTensor) { + continue; + } int input_epoch = tensor_epochs_[input_tensor_index]; int node_epoch = current_epoch; if (input_epoch != node_epoch) { diff --git a/tensorflow/lite/graph_info_test.cc b/tensorflow/lite/graph_info_test.cc index 4d8bbdc0eef49b3f79b3c74c1d07fd86467e1d65..b72728a9a9c94c4ee7312e5ff6f17e7b4d3b8a95 100644 --- a/tensorflow/lite/graph_info_test.cc +++ b/tensorflow/lite/graph_info_test.cc @@ -101,7 +101,7 @@ void CheckPartitionSubgraphs( } // Test an empty trivial graph with no partitions. -TEST(PartitionTest, Nodes0_PartitionNodes0) { +TEST(PartitionTest, Nodes0PartitionNodes0) { SimpleTestGraph graph; std::vector nodes_to_partition = {}; std::vector generated_subgraphs; @@ -109,6 +109,20 @@ TEST(PartitionTest, Nodes0_PartitionNodes0) { CheckPartitionSubgraphs(generated_subgraphs, {}); } +// Test a trivial graph with no node and only 1 tensor. +// The tensor is input & output of the graph at the same time. +// Note: This is a regression test to ensure the partitioning logic +// handles this case without crashing. +TEST(PartitionTest, Nodes0PartitionNodes0Tensors1) { + SimpleTestGraph graph; + graph.AddTensors(1); + graph.SetInputsAndOutputs({0}, {0}); + std::vector nodes_to_partition = {}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + CheckPartitionSubgraphs(generated_subgraphs, {}); +} + // Test a 1 node graph with no partitions. 
// Input: tensor(0) -> node(0) -> tensor(1), nodes_to_partition=[] // Output: [kTfNoPartition, tensor(0) -> node(0) -> tensor(1)] diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index 60fa2130fabaa692d23c109f42fa8883f6e8de19..6888183d14ce3375443cc4b50ba00f339da6530e 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -19,12 +19,14 @@ limitations under the License. #include #include #include +#include // NOLINT(build/c++11): only using std::call_once, not mutex. #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/graph_info.h" #include "tensorflow/lite/memory_planner.h" +#include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/nnapi_delegate.h" #include "tensorflow/lite/profiling/profiler.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -32,9 +34,36 @@ limitations under the License. namespace tflite { +namespace { + +// Gets the current TfLiteQuantization from the legacy fLiteQuantizationParams. +TfLiteQuantization GetQuantizationFromLegacy( + const TfLiteQuantizationParams& legacy_quantization) { + TfLiteQuantization quantization; + quantization.type = kTfLiteAffineQuantization; + auto* affine_quantization = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + affine_quantization->scale = TfLiteFloatArrayCreate(1); + affine_quantization->zero_point = TfLiteIntArrayCreate(1); + affine_quantization->scale->data[0] = legacy_quantization.scale; + affine_quantization->zero_point->data[0] = legacy_quantization.zero_point; + quantization.params = affine_quantization; + + return quantization; +} + +} // namespace + Interpreter::Interpreter(ErrorReporter* error_reporter) : error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()) { + // Only log initialization once per-process to avoid log spam. 
+ static std::once_flag init_log_once_flag; + std::call_once(init_log_once_flag, []() { + // TODO(b/128420794): Include the TFLite runtime version in the log. + TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Initialized TensorFlow Lite runtime."); + }); + // There's always at least 1 subgraph which is the primary subgraph. AddSubgraphs(1); context_ = primary_subgraph().context(); @@ -71,7 +100,7 @@ TfLiteStatus Interpreter::AllocateTensors() { } void Interpreter::ReserveNodes(int count) { - primary_subgraph().nodes_and_registration().reserve(count); + primary_subgraph().ReserveNodes(count); } void Interpreter::AddSubgraphs(int subgraphs_to_add, @@ -123,24 +152,49 @@ TfLiteStatus Interpreter::ResetVariableTensors() { return primary_subgraph().ResetVariableTensors(); } +TfLiteStatus Interpreter::SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantization quantization, + const char* buffer, size_t bytes, const Allocation* allocation) { + return primary_subgraph().SetTensorParametersReadOnly( + tensor_index, type, name, dims.size(), dims.data(), quantization, buffer, + bytes, allocation); +} + +TfLiteStatus Interpreter::SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantization quantization, + bool is_variable) { + return primary_subgraph().SetTensorParametersReadWrite( + tensor_index, type, name, dims.size(), dims.data(), quantization, + is_variable); +} + TfLiteStatus Interpreter::SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const size_t rank, const int* dims, TfLiteQuantizationParams quantization, const char* buffer, size_t bytes, const Allocation* allocation) { - return primary_subgraph().SetTensorParametersReadOnly( - tensor_index, type, name, rank, dims, quantization, buffer, bytes, - allocation); + TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization); + if 
(primary_subgraph().SetTensorParametersReadOnly( + tensor_index, type, name, rank, dims, new_quantization, buffer, bytes, + allocation) != kTfLiteOk) { + TfLiteQuantizationFree(&new_quantization); + return kTfLiteError; + } + return kTfLiteOk; } -// Set description of inputs/outputs/data/fptrs for node `node_index`. -// This variant assumes an external buffer has been allocated of size -// bytes. The lifetime of buffer must be ensured to be greater or equal -// to Interpreter. TfLiteStatus Interpreter::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, const int* dims, TfLiteQuantizationParams quantization, bool is_variable) { - return primary_subgraph().SetTensorParametersReadWrite( - tensor_index, type, name, rank, dims, quantization, is_variable); + TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization); + if (primary_subgraph().SetTensorParametersReadWrite( + tensor_index, type, name, rank, dims, new_quantization, + is_variable) != kTfLiteOk) { + TfLiteQuantizationFree(&new_quantization); + return kTfLiteError; + } + return kTfLiteOk; } TfLiteStatus Interpreter::SetExecutionPlan(const std::vector& new_plan) { @@ -178,7 +232,10 @@ void Interpreter::SetCancellationFunction(void* data, } TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { - return primary_subgraph().ModifyGraphWithDelegate(delegate); + for (auto& subgraph : subgraphs_) { + TF_LITE_ENSURE_OK(context_, subgraph->ModifyGraphWithDelegate(delegate)); + } + return kTfLiteOk; } TfLiteStatus Interpreter::SetBufferHandle(int tensor_index, diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index 754439c9667980454d5ee4ef61892a4869cd95be..806b66c12a0bf119985927e4e937c71fc6fed487 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -160,6 +160,12 @@ class Interpreter { // This variant assumes an external buffer has been allocated of size // bytes. 
The lifetime of buffer must be ensured to be greater or equal // to Interpreter. + TfLiteStatus SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantization quantization, + const char* buffer, size_t bytes, const Allocation* allocation = nullptr); + + // Legacy. Deprecated in favor of above. inline TfLiteStatus SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantizationParams quantization, @@ -179,6 +185,13 @@ class Interpreter { // This variant assumes an external buffer has been allocated of size // bytes. The lifetime of buffer must be ensured to be greater or equal // to Interpreter. + TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type, + const char* name, + const std::vector& dims, + TfLiteQuantization quantization, + bool is_variable = false); + + // Legacy. Deprecated in favor of above. inline TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantizationParams quantization, diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc index 1ee993c6f9c65289c18f73e7974c5ff7df713d5b..f67733f89f258d931dfae68c9d5e9e83a4c522ce 100644 --- a/tensorflow/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/interpreter.h" +#include #include #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/kernels/internal/compatibility.h" @@ -51,13 +52,25 @@ TfLiteRegistration* Register_NEG(); } // namespace ops namespace { +using ::testing::IsEmpty; + // Make an interpreter that has no tensors and no nodes TEST(BasicInterpreter, ZeroInterpreter) { + testing::internal::CaptureStderr(); + Interpreter interpreter; + EXPECT_THAT(testing::internal::GetCapturedStderr(), + testing::HasSubstr("INFO: Initialized TensorFlow Lite runtime")); + interpreter.SetInputs({}); interpreter.SetOutputs({}); ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Creating a new interpreter should not redundantly log runtime init. + testing::internal::CaptureStderr(); + Interpreter interpreter2; + EXPECT_THAT(testing::internal::GetCapturedStderr(), IsEmpty()); } // Test various error conditions. 
@@ -73,8 +86,9 @@ TEST(BasicInterpreter, TestAllocateTensorsResetVariableTensors) { int tensor_index; ASSERT_EQ(interpreter.AddTensors(1, &tensor_index), kTfLiteOk); constexpr int kTensorSize = 16; + TfLiteQuantizationParams quant; interpreter.SetTensorParametersReadWrite(tensor_index, kTfLiteFloat32, "", - {kTensorSize}, {}, true); + {kTensorSize}, quant, true); interpreter.SetVariables({tensor_index}); ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); TfLiteTensor* tensor = interpreter.tensor(tensor_index); @@ -170,6 +184,53 @@ TEST(BasicInterpreter, CheckAllocate) { } } +TEST(BasicInterpreter, CheckQuantization) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({}); + TfLiteType tensor_type = kTfLiteInt8; + const uint8_t int8s[] = {3, 4}; + float scale = 0.5f; + int32_t zero_point = 12; + + TfLiteQuantization rw_quantization; + rw_quantization.type = kTfLiteAffineQuantization; + auto* rw_affine_quantization = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + rw_affine_quantization->scale = TfLiteFloatArrayCreate(1); + rw_affine_quantization->zero_point = TfLiteIntArrayCreate(1); + rw_affine_quantization->scale->data[0] = scale; + rw_affine_quantization->zero_point->data[0] = zero_point; + rw_quantization.params = rw_affine_quantization; + + TfLiteQuantization ro_quantization; + ro_quantization.type = kTfLiteAffineQuantization; + auto* ro_affine_quantization = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + ro_affine_quantization->scale = TfLiteFloatArrayCreate(1); + ro_affine_quantization->zero_point = TfLiteIntArrayCreate(1); + ro_affine_quantization->scale->data[0] = scale; + ro_affine_quantization->zero_point->data[0] = zero_point; + ro_quantization.params = ro_affine_quantization; + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, tensor_type, "", {3}, + rw_quantization), + kTfLiteOk); + 
ASSERT_EQ(interpreter.SetTensorParametersReadOnly( + 1, tensor_type, "", {2}, ro_quantization, + reinterpret_cast(int8s), 2), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + // Check that the legacy scale and zero_point are set correctly. + ASSERT_EQ(interpreter.tensor(0)->params.scale, scale); + ASSERT_EQ(interpreter.tensor(0)->params.zero_point, zero_point); + ASSERT_EQ(interpreter.tensor(0)->quantization.type, rw_quantization.type); + ASSERT_EQ(interpreter.tensor(1)->params.scale, scale); + ASSERT_EQ(interpreter.tensor(1)->params.zero_point, zero_point); + ASSERT_EQ(interpreter.tensor(1)->quantization.type, ro_quantization.type); +} + TEST(BasicInterpreter, CheckResize) { const float floats[] = {-3., -4.}; const int32_t int32s[] = {-3, -4}; diff --git a/tensorflow/lite/java/AndroidManifest.xml b/tensorflow/lite/java/AndroidManifest.xml index b91c6d149a213926be90b9b131bd632d4f79a0fc..a76a727ec75d231a506b4ef693b3dcd681515b1a 100644 --- a/tensorflow/lite/java/AndroidManifest.xml +++ b/tensorflow/lite/java/AndroidManifest.xml @@ -3,7 +3,6 @@ package="org.tensorflow.lite"> diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index a539a0cf771a898ec8d9eefb8d58590ccbeb7e22..8983079a31d7d99dbd666387c0a2c0ded63747e8 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -90,6 +90,9 @@ java_test( size = "small", srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"], javacopts = JAVACOPTS, + tags = [ + "no_mac", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac. + ], test_class = "org.tensorflow.lite.TensorFlowLiteTest", deps = [ ":tensorflowlitelib", @@ -103,6 +106,9 @@ java_test( size = "small", srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"], javacopts = JAVACOPTS, + tags = [ + "no_mac", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac. 
+ ], test_class = "org.tensorflow.lite.DataTypeTest", deps = [ ":tensorflowlitelib", @@ -126,6 +132,9 @@ java_test( "src/testdata/with_custom_op.lite", ], javacopts = JAVACOPTS, + tags = [ + "no_mac", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac. + ], test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", deps = [ ":tensorflowlitelib", @@ -186,6 +195,9 @@ java_test( "src/testdata/add.bin", ], javacopts = JAVACOPTS, + tags = [ + "no_mac", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac. + ], test_class = "org.tensorflow.lite.TensorTest", deps = [ ":tensorflowlitelib", @@ -199,6 +211,9 @@ filegroup( srcs = select({ "//conditions:default": [":libtensorflowlite_jni.so"], }), + tags = [ + "no_mac", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac. + ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java index 814d236872caff05e9fbd4dc5aa4a9a995eb586b..c6f315b545bbe8196999df07c6a4bcdfdaafa2d5 100644 --- a/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java +++ b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -45,7 +45,6 @@ import android.os.Bundle; import android.os.Handler; import android.os.HandlerThread; import android.support.annotation.NonNull; -import android.support.v13.app.FragmentCompat; import android.support.v4.content.ContextCompat; import android.text.SpannableString; import android.text.SpannableStringBuilder; @@ -62,6 +61,7 @@ import android.widget.ListView; import android.widget.NumberPicker; import android.widget.TextView; import android.widget.Toast; +import android.support.v13.app.FragmentCompat; import java.io.IOException; import java.util.ArrayList; import 
java.util.Arrays; @@ -71,7 +71,6 @@ import java.util.List; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; - /** Basic fragments for the Camera. */ public class Camera2BasicFragment extends Fragment implements FragmentCompat.OnRequestPermissionsResultCallback { diff --git a/tensorflow/lite/java/jni/BUILD b/tensorflow/lite/java/jni/BUILD index ce17ac4fa0d37cb0b790617c4258ea469d14a664..3121cda7fe65a245a544fc8ec74c617f91166177 100644 --- a/tensorflow/lite/java/jni/BUILD +++ b/tensorflow/lite/java/jni/BUILD @@ -39,7 +39,7 @@ genrule( genrule( name = "copy_jni_md_h", srcs = select({ - "//tensorflow:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], + "//tensorflow:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], }), outs = ["jni_md.h"], diff --git a/tensorflow/lite/java/ovic/BUILD b/tensorflow/lite/java/ovic/BUILD index 774320871eec9afb2fae31824dc021fb7d338e1e..b00c9cd05809c9a694f32a25ae4fde3c33d40a88 100644 --- a/tensorflow/lite/java/ovic/BUILD +++ b/tensorflow/lite/java/ovic/BUILD @@ -19,7 +19,10 @@ java_test( "//tensorflow/lite/java/ovic/src/testdata:ovic_testdata", ], javacopts = JAVACOPTS, - tags = ["no_oss"], + tags = [ + "no_mac", + "no_oss", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac. + ], test_class = "org.tensorflow.ovic.OvicClassifierTest", visibility = ["//visibility:public"], deps = [ @@ -87,7 +90,10 @@ java_test( "//tensorflow/lite/java/ovic/src/testdata:ovic_testdata", ], javacopts = JAVACOPTS, - tags = ["no_oss"], + tags = [ + "no_mac", + "no_oss", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac. 
+ ], test_class = "org.tensorflow.ovic.OvicDetectorTest", visibility = ["//visibility:public"], deps = [ diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 725bb326ba1d6a9d9c206cd4fb01bdf687b0a79c..16cca45f388953e0616dd0b1b4c24114e5a6108b 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -107,7 +107,7 @@ public final class Tensor { throw new IllegalArgumentException( "Null inputs are allowed only if the Tensor is bound to a buffer handle."); } - throwExceptionIfTypeIsIncompatible(src); + throwIfDataIsIncompatible(src); if (isByteBuffer(src)) { ByteBuffer srcBuffer = (ByteBuffer) src; // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller @@ -138,7 +138,7 @@ public final class Tensor { throw new IllegalArgumentException( "Null outputs are allowed only if the Tensor is bound to a buffer handle."); } - throwExceptionIfTypeIsIncompatible(dst); + throwIfDataIsIncompatible(dst); if (dst instanceof ByteBuffer) { ByteBuffer dstByteBuffer = (ByteBuffer) dst; dstByteBuffer.put(buffer()); @@ -159,6 +159,7 @@ public final class Tensor { if (isByteBuffer(input)) { return null; } + throwIfTypeIsIncompatible(input); int[] inputShape = computeShapeOf(input); if (Arrays.equals(shapeCopy, inputShape)) { return null; @@ -243,16 +244,14 @@ public final class Tensor { } } - private void throwExceptionIfTypeIsIncompatible(Object o) { + private void throwIfDataIsIncompatible(Object o) { + throwIfTypeIsIncompatible(o); + throwIfShapeIsIncompatible(o); + } + + private void throwIfTypeIsIncompatible(Object o) { + // ByteBuffer payloads can map to any type, so exempt it from the check. 
if (isByteBuffer(o)) { - ByteBuffer oBuffer = (ByteBuffer) o; - if (oBuffer.capacity() != numBytes()) { - throw new IllegalArgumentException( - String.format( - "Cannot convert between a TensorFlowLite buffer with %d bytes and a " - + "ByteBuffer with %d bytes.", - numBytes(), oBuffer.capacity())); - } return; } DataType oType = dataTypeOf(o); @@ -263,7 +262,20 @@ public final class Tensor { + "object of type %s (which is compatible with the TensorFlowLite type %s).", dtype, o.getClass().getName(), oType)); } + } + private void throwIfShapeIsIncompatible(Object o) { + if (isByteBuffer(o)) { + ByteBuffer oBuffer = (ByteBuffer) o; + if (oBuffer.capacity() != numBytes()) { + throw new IllegalArgumentException( + String.format( + "Cannot convert between a TensorFlowLite buffer with %d bytes and a " + + "ByteBuffer with %d bytes.", + numBytes(), oBuffer.capacity())); + } + return; + } int[] oShape = computeShapeOf(o); if (!Arrays.equals(oShape, shapeCopy)) { throw new IllegalArgumentException( diff --git a/tensorflow/lite/java/src/main/native/exception_jni.cc b/tensorflow/lite/java/src/main/native/exception_jni.cc index 5406c7197f0c6ba6fd17c3472a365ef2d56d07a4..74217d6b682c72917ed00d4ef8173ab0c0af0605 100644 --- a/tensorflow/lite/java/src/main/native/exception_jni.cc +++ b/tensorflow/lite/java/src/main/native/exception_jni.cc @@ -31,12 +31,14 @@ void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) 
{ va_start(args, fmt); const size_t max_msg_len = 512; auto* message = static_cast(malloc(max_msg_len)); - if (vsnprintf(message, max_msg_len, fmt, args) >= 0) { + if (message && (vsnprintf(message, max_msg_len, fmt, args) >= 0)) { env->ThrowNew(env->FindClass(clazz), message); } else { env->ThrowNew(env->FindClass(clazz), ""); } - free(message); + if (message) { + free(message); + } va_end(args); } diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index c5496e3a21e7f5d27c36d92e49dd6c8e622b0070..ff3325633c1d71a950682764b6d1576e3b75ed6a 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.fail; import java.io.File; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.FloatBuffer; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.file.Files; @@ -247,6 +248,18 @@ public final class InterpreterTest { interpreter.close(); } + @Test + public void testRunWithUnsupportedInputType() { + FloatBuffer floatBuffer = FloatBuffer.allocate(10); + float[][][][] parsedOutputs = new float[2][8][8][3]; + try (Interpreter interpreter = new Interpreter(MODEL_FILE)) { + interpreter.run(floatBuffer, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("DataType error: cannot resolve DataType of"); + } + } + @Test public void testRunWithWrongOutputType() { Interpreter interpreter = new Interpreter(MODEL_FILE); diff --git a/tensorflow/lite/kernels/Android.bp b/tensorflow/lite/kernels/Android.bp index f0c0d12834934ba1d81e3f854f21f5178e33a0dc..b9da5fb594ce543978fa5f10f2532d71ee00da6f 100644 --- a/tensorflow/lite/kernels/Android.bp +++ b/tensorflow/lite/kernels/Android.bp @@ -28,6 +28,8 @@ 
cc_library_static { ], cflags: [ "-Wno-extern-c-compat", + "-Wno-sign-compare", + "-Wno-unused-function", ] } @@ -37,12 +39,14 @@ cc_library_static { srcs: [ "activations.cc", "add.cc", + "add_n.cc", "arg_min_max.cc", "basic_rnn.cc", "batch_to_space_nd.cc", "bidirectional_sequence_lstm.cc", "bidirectional_sequence_rnn.cc", "cast.cc", + "ceil.cc", "comparisons.cc", "concatenation.cc", "conv.cc", @@ -62,10 +66,11 @@ cc_library_static { "floor_mod.cc", "fully_connected.cc", "gather.cc", + "gather_nd.cc", "hashtable_lookup.cc", + "if.cc", "kernel_util.cc", "l2norm.cc", - "layer_norm_lstm.cc", "local_response_norm.cc", "logical.cc", "lsh_projection.cc", @@ -81,9 +86,11 @@ cc_library_static { "pooling.cc", "pow.cc", "range.cc", + "rank.cc", "reduce.cc", - "relu1.cc", "register.cc", + "reverse.cc", + "reverse_sequence.cc", "reshape.cc", "resize_bilinear.cc", "resize_nearest_neighbor.cc", @@ -109,6 +116,8 @@ cc_library_static { "unidirectional_sequence_rnn.cc", "unique.cc", "unpack.cc", + "where.cc", + "while.cc", "zeros_like.cc", "internal/kernel_utils.cc", "internal/tensor_utils.cc", @@ -127,10 +136,13 @@ cc_library_static { "-Wno-array-bounds", "-Wno-extern-c-compat", "-Wno-invalid-partial-specialization", + "-Wno-ignored-attributes", "-Wno-missing-field-initializers", "-Wno-sign-compare", "-Wno-unused-local-typedef", + "-Wno-unused-function", "-Wno-unused-variable", + "-Wno-unused-private-field", "-Wno-mismatched-tags", ], } diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 7a4b6b8644be52274f298f6a23c55d677fcfdd35..4f04445d972eb6b1025e49d3c3148be7b736fa74 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -6,7 +6,7 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_android") +load("//tensorflow:tensorflow.bzl", 
"tf_opts_nortti_if_android") # Suppress warnings that are introduced by Eigen Tensor. EXTRA_EIGEN_COPTS = select({ @@ -21,7 +21,7 @@ EXTRA_EIGEN_COPTS = select({ "//conditions:default": ["-Wno-error=reorder"], }) -tf_cc_test( +cc_test( name = "optional_tensor_test", size = "small", srcs = ["optional_tensor_test.cc"], @@ -46,6 +46,7 @@ cc_library( "//tensorflow/lite:string_util", "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/testing:util", + "//tensorflow/lite/tools/optimize:quantization_utils", "@com_google_googletest//:gtest", ], ) @@ -110,12 +111,12 @@ cc_library( ], deps = [ "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/lite/kernels/internal:round", - "//tensorflow/lite/kernels/internal:types", ], ) -tf_cc_test( +cc_test( name = "kernel_util_test", size = "small", srcs = ["kernel_util_test.cc"], @@ -126,7 +127,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "test_util_test", size = "small", srcs = ["test_util_test.cc"], @@ -151,6 +152,7 @@ cc_library( srcs = [ "activations.cc", "add.cc", + "add_n.cc", "arg_min_max.cc", "audio_spectrogram.cc", "basic_rnn.cc", @@ -158,6 +160,7 @@ cc_library( "bidirectional_sequence_lstm.cc", "bidirectional_sequence_rnn.cc", "cast.cc", + "ceil.cc", "comparisons.cc", "concatenation.cc", "conv.cc", @@ -177,9 +180,10 @@ cc_library( "floor_mod.cc", "fully_connected.cc", "gather.cc", + "gather_nd.cc", "hashtable_lookup.cc", + "if.cc", "l2norm.cc", - "layer_norm_lstm.cc", "local_response_norm.cc", "logical.cc", "lsh_projection.cc", @@ -195,18 +199,19 @@ cc_library( "pooling.cc", "pow.cc", "range.cc", + "rank.cc", "reduce.cc", - "relu1.cc", "reshape.cc", "resize_bilinear.cc", "resize_nearest_neighbor.cc", + "reverse.cc", + "reverse_sequence.cc", "select.cc", "shape.cc", "skip_gram.cc", "slice.cc", "space_to_batch_nd.cc", "space_to_depth.cc", - "sparse_output_fully_connected.cc", "sparse_to_dense.cc", "split.cc", "split_v.cc", @@ -223,6 +228,8 @@ 
cc_library( "unidirectional_sequence_rnn.cc", "unique.cc", "unpack.cc", + "where.cc", + "while.cc", "zeros_like.cc", ], hdrs = [ @@ -285,14 +292,13 @@ cc_library( srcs = ["register_ref.cc"], hdrs = ["register_ref.h"], deps = [ - ":builtin_op_kernels", "//tensorflow/lite:framework", "//tensorflow/lite:util", "//tensorflow/lite/c:c_api_internal", ], ) -tf_cc_test( +cc_test( name = "audio_spectrogram_test", size = "small", srcs = ["audio_spectrogram_test.cc"], @@ -305,7 +311,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "mfcc_test", size = "small", srcs = ["mfcc_test.cc"], @@ -318,7 +324,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "detection_postprocess_test", size = "small", srcs = ["detection_postprocess_test.cc"], @@ -331,37 +337,22 @@ tf_cc_test( ], ) -tf_cc_test( - name = "relu1_test", - size = "small", - srcs = ["relu1_test.cc"], - deps = [ - ":builtin_ops", - "//tensorflow/lite:framework", - "//tensorflow/lite/kernels:test_util", - "@com_google_googletest//:gtest", - "@flatbuffers", - ], -) - -tf_cc_test( - name = "sparse_output_fully_connected_test", +cc_test( + name = "activations_test", size = "small", - srcs = ["sparse_output_fully_connected_test.cc"], + srcs = ["activations_test.cc"], deps = [ ":builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite/kernels:test_util", - "//tensorflow/lite/kernels/internal:types", "@com_google_googletest//:gtest", - "@flatbuffers", ], ) -tf_cc_test( - name = "activations_test", +cc_test( + name = "add_test", size = "small", - srcs = ["activations_test.cc"], + srcs = ["add_test.cc"], deps = [ ":builtin_ops", "//tensorflow/lite:framework", @@ -370,19 +361,19 @@ tf_cc_test( ], ) -tf_cc_test( - name = "add_test", +cc_test( + name = "add_n_test", size = "small", - srcs = ["add_test.cc"], + srcs = ["add_n_test.cc"], deps = [ ":builtin_ops", + ":test_util", "//tensorflow/lite:framework", - "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) -tf_cc_test( +cc_test( name = 
"arg_min_max_test", size = "small", srcs = ["arg_min_max_test.cc"], @@ -394,7 +385,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "div_test", size = "small", srcs = ["div_test.cc"], @@ -406,7 +397,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "sub_test", size = "small", srcs = ["sub_test.cc"], @@ -418,7 +409,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "transpose_test", size = "small", srcs = ["transpose_test.cc"], @@ -432,7 +423,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "space_to_batch_nd_test", size = "small", srcs = ["space_to_batch_nd_test.cc"], @@ -444,7 +435,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "batch_to_space_nd_test", size = "small", srcs = ["batch_to_space_nd_test.cc"], @@ -456,7 +447,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "cast_test", size = "small", srcs = ["cast_test.cc"], @@ -468,7 +459,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "concatenation_test", size = "small", srcs = ["concatenation_test.cc"], @@ -480,7 +471,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "conv_test", size = "small", srcs = ["conv_test.cc"], @@ -493,7 +484,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "depthwise_conv_test", size = "small", srcs = ["depthwise_conv_test.cc"], @@ -506,7 +497,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "dequantize_test", size = "small", srcs = ["dequantize_test.cc"], @@ -520,7 +511,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "basic_rnn_test", size = "small", srcs = ["basic_rnn_test.cc"], @@ -532,7 +523,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "bidirectional_sequence_lstm_test", size = "small", srcs = ["bidirectional_sequence_lstm_test.cc"], @@ -545,7 +536,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "floor_test", size = "small", srcs = ["floor_test.cc"], @@ -557,7 +548,22 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( + name = "ceil_test", + size = "small", + srcs = ["ceil_test.cc"], + tags = [ + "tflite_not_portable_ios", + ], + deps = [ + 
":builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( name = "elementwise_test", size = "small", srcs = ["elementwise_test.cc"], @@ -569,7 +575,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "unidirectional_sequence_lstm_test", size = "small", srcs = ["unidirectional_sequence_lstm_test.cc"], @@ -581,7 +587,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "bidirectional_sequence_rnn_test", size = "small", srcs = ["bidirectional_sequence_rnn_test.cc"], @@ -593,7 +599,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "unidirectional_sequence_rnn_test", size = "small", srcs = ["unidirectional_sequence_rnn_test.cc"], @@ -605,7 +611,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "l2norm_test", size = "small", srcs = ["l2norm_test.cc"], @@ -617,7 +623,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "exp_test", size = "small", srcs = ["exp_test.cc"], @@ -629,7 +635,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "fake_quant_test", size = "small", srcs = ["fake_quant_test.cc"], @@ -641,7 +647,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "maximum_minimum_test", size = "small", srcs = ["maximum_minimum_test.cc"], @@ -653,7 +659,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "reduce_test", size = "small", srcs = ["reduce_test.cc"], @@ -665,7 +671,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "mul_test", size = "small", srcs = ["mul_test.cc"], @@ -677,7 +683,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "pad_test", size = "small", srcs = ["pad_test.cc"], @@ -689,7 +695,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "reshape_test", size = "small", srcs = ["reshape_test.cc"], @@ -701,7 +707,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "gather_test", size = "small", srcs = ["gather_test.cc"], @@ -714,7 +720,20 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( + name = "gather_nd_test", + size = "small", + srcs = ["gather_nd_test.cc"], + 
deps = [ + ":builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( name = "topk_v2_test", size = "small", srcs = ["topk_v2_test.cc"], @@ -727,7 +746,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "resize_bilinear_test", size = "small", srcs = ["resize_bilinear_test.cc"], @@ -739,7 +758,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "resize_nearest_neighbor_test", size = "small", srcs = ["resize_nearest_neighbor_test.cc"], @@ -751,7 +770,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "svdf_test", size = "small", srcs = ["svdf_test.cc"], @@ -763,7 +782,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "embedding_lookup_test", size = "small", srcs = ["embedding_lookup_test.cc"], @@ -775,7 +794,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "embedding_lookup_sparse_test", size = "small", srcs = ["embedding_lookup_sparse_test.cc"], @@ -787,7 +806,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "fully_connected_test", size = "small", srcs = ["fully_connected_test.cc"], @@ -801,7 +820,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "local_response_norm_test", size = "small", srcs = ["local_response_norm_test.cc"], @@ -813,7 +832,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "pooling_test", size = "small", srcs = ["pooling_test.cc"], @@ -825,7 +844,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "softmax_test", size = "small", srcs = ["softmax_test.cc"], @@ -838,7 +857,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "log_softmax_test", size = "small", srcs = ["log_softmax_test.cc"], @@ -851,7 +870,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "lsh_projection_test", size = "small", srcs = ["lsh_projection_test.cc"], @@ -863,7 +882,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "hashtable_lookup_test", size = "small", srcs = ["hashtable_lookup_test.cc"], @@ -876,20 +895,7 @@ tf_cc_test( ], ) 
-tf_cc_test( - name = "layer_norm_lstm_test", - size = "small", - srcs = ["layer_norm_lstm_test.cc"], - deps = [ - ":builtin_ops", - "//tensorflow/lite:framework", - "//tensorflow/lite/kernels:test_util", - "@com_google_googletest//:gtest", - "@flatbuffers", - ], -) - -tf_cc_test( +cc_test( name = "lstm_test", size = "small", srcs = ["lstm_test.cc"], @@ -901,7 +907,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "skip_gram_test", size = "small", srcs = ["skip_gram_test.cc"], @@ -914,7 +920,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "space_to_depth_test", size = "small", srcs = ["space_to_depth_test.cc"], @@ -926,7 +932,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "split_test", size = "small", srcs = ["split_test.cc"], @@ -938,7 +944,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "split_v_test", size = "small", srcs = ["split_v_test.cc"], @@ -950,7 +956,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "squeeze_test", size = "small", srcs = ["squeeze_test.cc"], @@ -962,7 +968,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "strided_slice_test", size = "small", srcs = ["strided_slice_test.cc"], @@ -974,7 +980,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "tile_test", size = "small", srcs = ["tile_test.cc"], @@ -987,7 +993,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "comparisons_test", size = "small", srcs = [ @@ -1001,7 +1007,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "neg_test", size = "small", srcs = ["neg_test.cc"], @@ -1013,7 +1019,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "select_test", size = "small", srcs = [ @@ -1027,7 +1033,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "slice_test", size = "small", srcs = [ @@ -1041,7 +1047,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "transpose_conv_test", size = "small", srcs = ["transpose_conv_test.cc"], @@ -1054,7 +1060,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "expand_dims_test", size = "small", srcs = ["expand_dims_test.cc"], @@ 
-1067,7 +1073,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "sparse_to_dense_test", size = "small", srcs = ["sparse_to_dense_test.cc"], @@ -1080,7 +1086,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "shape_test", size = "small", srcs = ["shape_test.cc"], @@ -1093,7 +1099,20 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( + name = "rank_test", + size = "small", + srcs = ["rank_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( name = "pow_test", size = "small", srcs = ["pow_test.cc"], @@ -1106,7 +1125,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "pack_test", size = "small", srcs = ["pack_test.cc"], @@ -1119,7 +1138,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "one_hot_test", size = "small", srcs = ["one_hot_test.cc"], @@ -1131,7 +1150,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "logical_test", size = "small", srcs = ["logical_test.cc"], @@ -1144,7 +1163,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "unpack_test", size = "small", srcs = ["unpack_test.cc"], @@ -1157,7 +1176,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "floor_div_test", size = "small", srcs = ["floor_div_test.cc"], @@ -1170,7 +1189,20 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( + name = "where_test", + size = "small", + srcs = ["where_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( name = "zeros_like_test", size = "small", srcs = ["zeros_like_test.cc"], @@ -1183,7 +1215,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "floor_mod_test", size = "small", srcs = ["floor_mod_test.cc"], @@ -1196,7 +1228,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "range_test", size = "small", srcs = ["range_test.cc"], @@ -1209,19 +1241,54 @@ tf_cc_test( ], ) 
-tf_cc_test( +cc_test( name = "squared_difference_test", size = "small", srcs = ["squared_difference_test.cc"], deps = [ ":builtin_ops", + "//tensorflow/lite:builtin_op_data", "//tensorflow/lite:framework", "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) -tf_cc_test( +cc_test( + name = "if_test", + size = "small", + srcs = ["if_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + ":kernel_util", + ":subgraph_test_util", + ":test_util", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + +cc_test( + name = "while_test", + size = "small", + srcs = ["while_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + ":kernel_util", + ":subgraph_test_util", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + +cc_test( name = "fill_test", size = "small", srcs = ["fill_test.cc"], @@ -1233,7 +1300,7 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "unique_test", srcs = ["unique_test.cc"], deps = [ @@ -1244,6 +1311,18 @@ tf_cc_test( ], ) +cc_test( + name = "reverse_test", + size = "small", + srcs = ["reverse_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( @@ -1258,7 +1337,7 @@ filegroup( tflite_portable_test_suite() -tf_cc_test( +cc_test( name = "mirror_pad_test", srcs = ["mirror_pad_test.cc"], deps = [ @@ -1268,3 +1347,43 @@ tf_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "subgraph_test_util", + testonly = 1, + srcs = ["subgraph_test_util.cc"], + hdrs = ["subgraph_test_util.h"], + deps = [ + ":builtin_ops", + ":kernel_util", + ":test_util", + "//tensorflow/lite:builtin_op_data", + 
"//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + +cc_test( + name = "subgraph_test_util_test", + size = "small", + srcs = ["subgraph_test_util_test.cc"], + deps = [ + ":subgraph_test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "reverse_sequence_test", + size = "small", + srcs = ["reverse_sequence_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index 4463a6c5a65bf848ad68635717750d3a214dd0a0..930eabaeccfde5c9fce824a58d28d14783dde419 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -23,7 +23,10 @@ limitations under the License. #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -60,9 +63,9 @@ namespace { TfLiteStatus CheckOutputQuantParams(TfLiteContext* context, const TfLiteTensor* input, const TfLiteTensor* output) { + TF_LITE_ENSURE(context, output->params.scale == 1. / 256); if (input->type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); - TF_LITE_ENSURE(context, output->params.scale == 1. 
/ 256); } else { TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); } @@ -118,7 +121,7 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); - if (input->type == kTfLiteUInt8) { + if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { static constexpr int kInputIntegerBits = 4; const double input_real_multiplier = @@ -177,8 +180,15 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); - if (input->type == kTfLiteUInt8) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, + std::numeric_limits::min()); + } + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, + std::numeric_limits::min()); + } TF_LITE_ENSURE(context, output->params.scale == 1. 
/ 256); static constexpr int kInputIntegerBits = 4; @@ -261,8 +271,13 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); - if (input->type == kTfLiteUInt8) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255); + if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255); + } + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127); + } TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256); static const double kBeta = 1.0; @@ -353,6 +368,24 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) { } } +namespace { +template +void QuantizedRelu6(const TfLiteTensor* input, TfLiteTensor* output) { + ActivationParams params; + params.activation_type = FusedActivationFunctionType::kRelu6; + params.quantized_activation_min = + std::max(static_cast(std::numeric_limits::min()), + output->params.zero_point + + static_cast(roundf(0.f / output->params.scale))); + params.quantized_activation_max = + std::min(static_cast(std::numeric_limits::max()), + output->params.zero_point + + static_cast(roundf(6.f / output->params.scale))); + optimized_ops::ReluX(params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); +} +} // namespace + TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); @@ -365,23 +398,16 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f); return kTfLiteOk; } break; - case kTfLiteUInt8: { - ActivationParams params; - params.activation_type = FusedActivationFunctionType::kRelu6; - params.quantized_activation_min 
= std::max( - 0, output->params.zero_point + - static_cast(roundf(0.f / output->params.scale))); - params.quantized_activation_max = std::min( - 255, output->params.zero_point + - static_cast(roundf(6.f / output->params.scale))); - optimized_ops::ReluX(params, GetTensorShape(input), - GetTensorData(input), GetTensorShape(output), - GetTensorData(output)); + case kTfLiteUInt8: + QuantizedRelu6(input, output); + return kTfLiteOk; + case kTfLiteInt8: { + QuantizedRelu6(input, output); return kTfLiteOk; } break; default: context->ReportError( - context, "Only float32 and uint8 supported currently, got %s.", + context, "Only float32, uint8 and int8 supported currently, got %s.", TfLiteTypeGetName(input->type)); return kTfLiteError; } @@ -436,6 +462,16 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { } return kTfLiteOk; } break; + case kTfLiteInt8: { + const auto input_shape = GetTensorShape(input); + const auto output_shape = GetTensorShape(output); + const int size = MatchingFlatSize(input_shape, output_shape); + reference_integer_ops::Tanh( + input->params.zero_point, data->input_range_radius, + data->input_multiplier, data->input_left_shift, size, + GetTensorData(input), GetTensorData(output)); + return kTfLiteOk; + } break; default: context->ReportError(context, "Only float32 supported currently, got %s.", TfLiteTypeGetName(input->type)); @@ -493,6 +529,15 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { } break; } + case kTfLiteInt8: { + const int input_size = + MatchingFlatSize(GetTensorShape(input), GetTensorShape(output)); + reference_integer_ops::Logistic( + input->params.zero_point, data->input_range_radius, + data->input_multiplier, data->input_left_shift, input_size, + GetTensorData(input), GetTensorData(output)); + break; + } default: context->ReportError(context, "Only float32 supported currently, got %s.", TfLiteTypeGetName(input->type)); @@ -815,6 +860,21 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, 
TfLiteNode* node) { } return kTfLiteOk; } + case kTfLiteInt8: { + const auto input_shape = GetTensorShape(input); + const auto output_shape = GetTensorShape(output); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + reference_integer_ops::LogSoftmax( + data->input_multiplier, data->input_left_shift, + data->reverse_scaling_divisor, data->reverse_scaling_right_shift, + data->diff_min, outer_size, depth, GetTensorData(input), + GetTensorData(output)); + return kTfLiteOk; + } default: context->ReportError(context, "Only float32 supported currently., got %s", TfLiteTypeGetName(input->type)); @@ -884,8 +944,31 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) { } } +TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + optimized_ops::Elu(GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; + } break; + default: + context->ReportError(context, "Only float32 supported currently, got %s.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } +} + } // namespace activations +TfLiteRegistration* Register_ELU() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::EluEval}; + return &r; +} + TfLiteRegistration* Register_RELU() { static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, activations::GenericPrepare, diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc index 5e3c56ed5bf7092581fbbced6d3735958c19580c..25b17a9678728f0ee82ccf22e2a5b63eee2c3537 100644 --- 
a/tensorflow/lite/kernels/activations_test.cc +++ b/tensorflow/lite/kernels/activations_test.cc @@ -32,6 +32,8 @@ class BaseActivationsOpModel : public SingleOpModel { input_ = AddInput(input); if (input.type == TensorType_UINT8) { output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else if (input.type == TensorType_INT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256, -128}); } else { output_ = AddOutput({input.type, {}}); } @@ -116,6 +118,20 @@ class QuantizedActivationsOpModel : public BaseActivationsOpModel { } }; +TEST(FloatActivationsOpTest, Elu) { + FloatActivationsOpModel m(BuiltinOperator_ELU, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, -4, // + 3, -2, 10, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.0, -0.997521, 2.0, -0.981684, // + 3.0, -0.864665, 10.0, -0.0951626, // + }))); +} + TEST(FloatActivationsOpTest, Relu) { FloatActivationsOpModel m(BuiltinOperator_RELU, /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); @@ -172,7 +188,7 @@ TEST(FloatActivationsOpTest, Tanh) { }))); } -TEST(QuantizedActivationsOpTest, Relu6) { +TEST(QuantizedActivationsOpTest, Relu6Uint8) { const float kMin = -1; const float kMax = 127.f / 128.f; QuantizedActivationsOpModel m( @@ -195,7 +211,29 @@ TEST(QuantizedActivationsOpTest, Relu6) { ElementsAreArray({128, 128, 160, 192, 176, 128, 224, 144})); } -TEST(QuantizedActivationsOpTest, Tanh) { +TEST(QuantizedActivationsOpTest, Relu6Int8) { + const float kMin = -1; + const float kMax = 127.f / 128.f; + QuantizedActivationsOpModel m( + BuiltinOperator_RELU6, + /*input=*/{TensorType_INT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0, 0, 2, 4, // + 3, 0, 6, 1, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + 
ElementsAreArray({0, 0, 32, 64, 48, 0, 96, 16})); +} + +TEST(QuantizedActivationsOpTest, TanhUint8) { const float kMin = -1; const float kMax = 127.f / 128.f; QuantizedActivationsOpModel m( @@ -218,6 +256,29 @@ TEST(QuantizedActivationsOpTest, Tanh) { ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 225})); } +TEST(QuantizedActivationsOpTest, TanhInt8) { + const float kMin = -1; + const float kMax = 127.f / 128.f; + QuantizedActivationsOpModel m( + BuiltinOperator_TANH, + /*input=*/{TensorType_INT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT8, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ + 0, -6, 2, 4, // + -4, -2, 8, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.0, -0.999987, 0.964027, 0.999329, // + -0.999329, -0.96402, 0.99999, 0.76159, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, -128, 123, 127, -128, -123, 127, 97})); +} + TEST(QuantizedActivationsOpTest, TanhInt16) { const float kMin = -1; const float kMax = 32767.f / 32768.f; @@ -253,7 +314,7 @@ TEST(FloatActivationsOpTest, Sigmoid) { }))); } -TEST(QuantizedActivationsOpTest, Sigmoid) { +TEST(QuantizedActivationsOpTest, SigmoidUint8) { QuantizedActivationsOpModel m( BuiltinOperator_LOGISTIC, /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); @@ -273,6 +334,26 @@ TEST(QuantizedActivationsOpTest, Sigmoid) { ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); } +TEST(QuantizedActivationsOpTest, SigmoidInt8) { + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_INT8, {1, 2, 4, 1}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, -127, 99, 123, 116, -99, 127, 60})); +} + 
TEST(QuantizedActivationsOpTest, SigmoidInt16) { const float kMin = -1; const float kMax = 32767.f / 32768.f; @@ -705,7 +786,7 @@ TEST(FloatActivationsOpTest, LogSoftmax) { }))); } -TEST(QuantizedActivationsOpTest, LogSoftmax) { +TEST(QuantizedActivationsOpTest, LogSoftmaxUint8) { const float kLogSoftmaxQuantizedTolerance = 16 / 256.0; QuantizedActivationsOpModel m( BuiltinOperator_LOG_SOFTMAX, @@ -727,6 +808,30 @@ TEST(QuantizedActivationsOpTest, LogSoftmax) { ElementsAreArray({189, 93, 221, 253, 142, 63, 255, 111})); } +TEST(QuantizedActivationsOpTest, LogSoftmaxInt8) { + const float kLogSoftmaxQuantizedTolerance = 0.06355; + QuantizedActivationsOpModel m( + BuiltinOperator_LOG_SOFTMAX, + /*input=*/{TensorType_INT8, {2, 4}, -10, 10}, + /*output=*/{TensorType_INT8, {}, 0, 0, 16. / 256, 127}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -4.14297, -10.14297, -2.14297, -.142971, // + -7.00104, -12.00104, -.00104087, -9.00104, // + }, + kLogSoftmaxQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 61, -36, 93, 125, // + 15, -65, 127, -16, // + })); +} + // A base class of PRelu op model. It provides the constructor for // FloatPReluOpModel and QuantizedPReluOpModel. class BasePReluOpModel : public SingleOpModel { diff --git a/tensorflow/lite/kernels/add.cc b/tensorflow/lite/kernels/add.cc index 9867cc53b342d9fddda81db270c223de2ecda14f..4cfe435e9e2c4dab7253c4be4fffcb991cea4abd 100644 --- a/tensorflow/lite/kernels/add.cc +++ b/tensorflow/lite/kernels/add.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" @@ -92,7 +93,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } - if (output->type == kTfLiteUInt8) { + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { // 8bit -> 8bit general quantized path, with general rescalings data->input1_offset = -input1->params.zero_point; data->input2_offset = -input2->params.zero_point; @@ -117,10 +118,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { QuantizeMultiplierSmallerThanOneExp( real_output_multiplier, &data->output_multiplier, &data->output_shift); - CalculateActivationRangeUint8(params->activation, output, - &data->output_activation_min, - &data->output_activation_max); - + if (output->type == kTfLiteUInt8) { + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } else { + CalculateActivationRangeInt8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } } else if (output->type == kTfLiteInt16) { // 16bit -> 16bit special quantized path, supporting only a rather // narrow case of quantization parameters: zero_points must all be 0 @@ -219,7 +225,7 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { - if (output->type == kTfLiteUInt8) { + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { tflite::ArithmeticParams op_params; op_params.left_shift = data->left_shift; op_params.input1_offset = data->input1_offset; @@ -235,25 +241,33 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* 
context, TfLiteNode* node, data->output_activation_max, &op_params); bool need_broadcast = optimized_ops::ProcessBroadcastShapes( GetTensorShape(input1), GetTensorShape(input2), &op_params); -#define TF_LITE_ADD(type, opname) \ - type::opname(op_params, GetTensorShape(input1), \ - GetTensorData(input1), GetTensorShape(input2), \ - GetTensorData(input2), GetTensorShape(output), \ - GetTensorData(output)); - if (kernel_type == kReference) { +#define TF_LITE_ADD(type, opname, dtype) \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)); + if (output->type == kTfLiteInt8) { if (need_broadcast) { - TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow); + TF_LITE_ADD(reference_integer_ops, BroadcastAdd4DSlow, int8_t); } else { - TF_LITE_ADD(reference_ops, Add); + TF_LITE_ADD(reference_integer_ops, Add, int8_t); } } else { - if (op_params.broadcast_category == - BroadcastableOpCategory::kGenericBroadcast) { - TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow); - } else if (need_broadcast) { - TF_LITE_ADD(optimized_ops, BroadcastAddFivefold); + if (kernel_type == kReference) { + if (need_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, uint8_t); + } else { + TF_LITE_ADD(reference_ops, Add, uint8_t); + } } else { - TF_LITE_ADD(optimized_ops, Add); + if (op_params.broadcast_category == + BroadcastableOpCategory::kGenericBroadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, uint8_t); + } else if (need_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAddFivefold, uint8_t); + } else { + TF_LITE_ADD(optimized_ops, Add, uint8_t); + } } } #undef TF_LITE_ADD @@ -292,7 +306,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { EvalAdd(context, node, params, data, input1, input2, output); - } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + 
} else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 || + output->type == kTfLiteInt16) { TF_LITE_ENSURE_OK(context, EvalAddQuantized(context, node, params, data, input1, input2, output)); diff --git a/tensorflow/lite/kernels/add_n.cc b/tensorflow/lite/kernels/add_n.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e9b2ea24afcd3eff107b110e7d5bb6226d95d3a --- /dev/null +++ b/tensorflow/lite/kernels/add_n.cc @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace add_n { + +constexpr int kInputTensor1 = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + int num_inputs = NumInputs(node); + TF_LITE_ENSURE(context, num_inputs >= 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + output->type = input1->type; + + // Check that all input tensors have the same shape and type. + for (int i = kInputTensor1 + 1; i < num_inputs; ++i) { + const TfLiteTensor* input = GetInput(context, node, i); + TF_LITE_ENSURE(context, HaveSameShapes(input1, input)); + TF_LITE_ENSURE_EQ(context, input1->type, input->type); + } + + // Use the first input node's dimension to be the dimension of the output + // node. + TfLiteIntArray* input1_dims = input1->dims; + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input1_dims); + return context->ResizeTensor(context, output, output_dims); +} + +template +void EvalAddN(TfLiteContext* context, TfLiteNode* node) { + // TODO(haoliang): Initialize all_inputs only once during init. 
+ VectorOfTensors all_inputs(*context, *node->inputs); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int num_inputs = NumInputs(node); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + reference_ops::AddN(GetTensorShape(input1), num_inputs, all_inputs.data(), + GetTensorData(output)); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + if (output->type == kTfLiteFloat32) { + EvalAddN(context, node); + } else if (output->type == kTfLiteInt32) { + EvalAddN(context, node); + } else { + context->ReportError(context, + "AddN only supports FLOAT32|INT32 now, got %s.", + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace add_n + +TfLiteRegistration* Register_ADD_N() { + static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr, + add_n::Prepare, add_n::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/add_n_test.cc b/tensorflow/lite/kernels/add_n_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee9477d2ff13c4e4f4e2da815d8f5660ab5b6c4e --- /dev/null +++ b/tensorflow/lite/kernels/add_n_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseAddNOpModel : public SingleOpModel { + public: + BaseAddNOpModel(const std::vector& inputs, + const TensorData& output) { + int num_inputs = inputs.size(); + std::vector> input_shapes; + + for (int i = 0; i < num_inputs; ++i) { + inputs_.push_back(AddInput(inputs[i])); + input_shapes.push_back(GetShape(inputs_[i])); + } + + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD_N, BuiltinOptions_AddNOptions, + CreateAddNOptions(builder_).Union()); + BuildInterpreter(input_shapes); + } + + int input(int i) { return inputs_[i]; } + + protected: + std::vector inputs_; + int output_; +}; + +class FloatAddNOpModel : public BaseAddNOpModel { + public: + using BaseAddNOpModel::BaseAddNOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class IntegerAddNOpModel : public BaseAddNOpModel { + public: + using BaseAddNOpModel::BaseAddNOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(FloatAddNOpModel, AddMultipleTensors) { + FloatAddNOpModel m({{TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}}, + {TensorType_FLOAT32, {}}); + m.PopulateTensor(m.input(0), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input(1), {0.1, 0.2, 0.3, 0.5}); + m.PopulateTensor(m.input(2), {0.5, 0.1, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.4, 0.5, 1.1, 1.5})); +} + +TEST(IntegerAddNOpModel, AddMultipleTensors) { + IntegerAddNOpModel m({{TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}}, + {TensorType_INT32, {}}); + 
m.PopulateTensor(m.input(0), {-20, 2, 7, 8}); + m.PopulateTensor(m.input(1), {1, 2, 3, 5}); + m.PopulateTensor(m.input(2), {10, -5, 1, -2}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-9, -1, 11, 11})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/add_test.cc b/tensorflow/lite/kernels/add_test.cc index 16045d457238e482bd7aad1077d0344632a7550b..2904f4a11a947264cb12fc2e8c0a7822df24c678 100644 --- a/tensorflow/lite/kernels/add_test.cc +++ b/tensorflow/lite/kernels/add_test.cc @@ -63,9 +63,10 @@ class QuantizedAddOpModel : public BaseAddOpModel { public: using BaseAddOpModel::BaseAddOpModel; + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); } std::vector GetDequantizedOutputInt16() { @@ -74,17 +75,15 @@ class QuantizedAddOpModel : public BaseAddOpModel { } }; -// for quantized Add, the error shouldn't exceed 2*step +// for quantized Add, the error shouldn't exceed step float GetTolerance(float min, float max) { float kQuantizedStep = (max - min) / 255.0; - float kQuantizedTolerance = 2.0 * kQuantizedStep; - return kQuantizedTolerance; + return kQuantizedStep; } float GetToleranceInt16(float min, float max) { float kQuantizedStep = (max - min) / 32767.f; - float kQuantizedTolerance = 2.0 * kQuantizedStep; - return kQuantizedTolerance; + return kQuantizedStep; } TEST(FloatAddOpModel, NoActivation) { @@ -191,7 +190,8 @@ TEST(IntegerAddOpModel, WithBroadcast) { } } -TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { +template +void QuantizedTestsNoActivation() { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; @@ 
-200,19 +200,28 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { std::vector> results = { {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; for (int i = 0; i < inputs1.size(); ++i) { - QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}, + QuantizedAddOpModel m({tensor_type, {1, 2, 2, 1}, -1.0, 1.0}, + {tensor_type, {1, 2, 2, 1}, -1.0, 1.0}, + {tensor_type, {}, -1.0, 1.0}, ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), inputs1[i]); - m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - results[i], kQuantizedTolerance))) + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) << "With test number " << i; } } +TEST(QuantizedAddOpModel, QuantizedTestsNoActivationUInt8) { + QuantizedTestsNoActivation(); +} + +TEST(QuantizedAddOpModel, QuantizedTestsNoActivationInt8) { + QuantizedTestsNoActivation(); +} + TEST(QuantizedAddOpModel, QuantizedTestsNoActivationInt16) { const float kMin = -1.f; const float kMax = 32767.f / 32768.f; @@ -238,7 +247,8 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivationInt16) { } } -TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { +template +void QuantizedTestsActivationRELU_N1_TO_1() { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; @@ -247,53 +257,74 @@ TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { std::vector> results = {{-0.2, 0.6, 1.0, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; for (int i = 0; i < inputs1.size(); ++i) { - QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {}, 
-1.0, 1.0}, + QuantizedAddOpModel m({tensor_type, {1, 2, 2, 1}, -1.0, 1.0}, + {tensor_type, {1, 2, 2, 1}, -1.0, 1.0}, + {tensor_type, {}, -1.0, 1.0}, ActivationFunctionType_RELU_N1_TO_1); - m.QuantizeAndPopulate(m.input1(), inputs1[i]); - m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - results[i], kQuantizedTolerance))) + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) << "With test number " << i; } } -TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { +TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1UInt8) { + QuantizedTestsActivationRELU_N1_TO_1(); +} + +TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1Int8) { + QuantizedTestsActivationRELU_N1_TO_1(); +} + +template +void QuantizedVariousInputShapes() { float kQuantizedTolerance = GetTolerance(-3.0, 3.0); std::vector> test_shapes = { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { - QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, - {TensorType_UINT8, test_shapes[i], -3.0, 3.0}, - {TensorType_UINT8, {}, -3.0, 3.0}, + QuantizedAddOpModel m({tensor_type, test_shapes[i], -3.0, 3.0}, + {tensor_type, test_shapes[i], -3.0, 3.0}, + {tensor_type, {}, -3.0, 3.0}, ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); - m.QuantizeAndPopulate(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); + m.QuantizeAndPopulate(m.input1(), + {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), + {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({-1.9, 0.5, 1.0, 1.3, 2.2, 2.1}, kQuantizedTolerance))) << "With shape number " << i; } } 
-TEST(QuantizedAddOpModel, QuantizedWithScalarBroadcast) { +TEST(QuantizedAddOpModel, QuantizedVariousInputShapesUInt8) { + QuantizedVariousInputShapes(); +} + +TEST(QuantizedAddOpModel, QuantizedVariousInputShapesInt8) { + QuantizedVariousInputShapes(); +} + +template +void QuantizedWithScalarBroadcast() { float kQuantizedTolerance = GetTolerance(-3.f, 3.f); std::vector> test_shapes = { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { QuantizedAddOpModel model_fixture( - {TensorType_UINT8, test_shapes[i], -3.f, 3.f}, - {TensorType_UINT8, {}, -3.f, 3.f}, {TensorType_UINT8, {}, -3.f, 3.f}, - ActivationFunctionType_NONE); - model_fixture.QuantizeAndPopulate( + {tensor_type, test_shapes[i], -3.f, 3.f}, {tensor_type, {}, -3.f, 3.f}, + {tensor_type, {}, -3.f, 3.f}, ActivationFunctionType_NONE); + model_fixture.QuantizeAndPopulate( model_fixture.input1(), {-2.0f, 0.2f, 0.7f, 0.8f, 1.1f, 2.0f}); - model_fixture.QuantizeAndPopulate(model_fixture.input2(), {0.1f}); + model_fixture.QuantizeAndPopulate(model_fixture.input2(), + {0.1f}); model_fixture.Invoke(); EXPECT_THAT( - model_fixture.GetDequantizedOutput(), + model_fixture.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({-1.9f, 0.3f, 0.8f, 0.9f, 1.2f, 2.1f}, kQuantizedTolerance))) << "With shape number " << i; @@ -301,22 +332,31 @@ TEST(QuantizedAddOpModel, QuantizedWithScalarBroadcast) { // Re-run with exchanged inputs. 
for (int i = 0; i < test_shapes.size(); ++i) { QuantizedAddOpModel model_fixture( - {TensorType_UINT8, {}, -3.f, 3.f}, - {TensorType_UINT8, test_shapes[i], -3.f, 3.f}, - {TensorType_UINT8, {}, -3.f, 3.f}, ActivationFunctionType_NONE); - model_fixture.QuantizeAndPopulate(model_fixture.input1(), {0.1f}); - model_fixture.QuantizeAndPopulate( + {tensor_type, {}, -3.f, 3.f}, {tensor_type, test_shapes[i], -3.f, 3.f}, + {tensor_type, {}, -3.f, 3.f}, ActivationFunctionType_NONE); + model_fixture.QuantizeAndPopulate(model_fixture.input1(), + {0.1f}); + model_fixture.QuantizeAndPopulate( model_fixture.input2(), {-2.0f, 0.2f, 0.7f, 0.8f, 1.1f, 2.0f}); model_fixture.Invoke(); EXPECT_THAT( - model_fixture.GetDequantizedOutput(), + model_fixture.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({-1.9f, 0.3f, 0.8f, 0.9f, 1.2f, 2.1f}, kQuantizedTolerance))) << "With shape number " << i; } } -TEST(QuantizedAddOpModel, QuantizedWithMixedBroadcast) { +TEST(QuantizedAddOpModel, QuantizedWithScalarBroadcastUInt8) { + QuantizedWithScalarBroadcast(); +} + +TEST(QuantizedAddOpModel, QuantizedWithScalarBroadcastInt8) { + QuantizedWithScalarBroadcast(); +} + +template +void QuantizedWithMixedBroadcast() { float kQuantizedTolerance = GetTolerance(-3.f, 3.f); const std::vector base_shape = {2, 3, 1, 2}; std::vector> test_shapes = { @@ -335,40 +375,48 @@ TEST(QuantizedAddOpModel, QuantizedWithMixedBroadcast) { {-0.1f, 2.5f, 1.2f, 0.8f, 0.4f, -1.5f, 1.7f, 3.0f, -0.6f, 1.0f, 1.6f, -1.3f}}; for (int i = 0; i < test_shapes.size(); ++i) { - QuantizedAddOpModel model_fixture( - {TensorType_UINT8, base_shape, -3.f, 3.f}, - {TensorType_UINT8, test_shapes[i], -3.f, 3.f}, - {TensorType_UINT8, {}, -3.f, 3.f}, ActivationFunctionType_NONE); - model_fixture.QuantizeAndPopulate( + QuantizedAddOpModel model_fixture({tensor_type, base_shape, -3.f, 3.f}, + {tensor_type, test_shapes[i], -3.f, 3.f}, + {tensor_type, {}, -3.f, 3.f}, + ActivationFunctionType_NONE); + model_fixture.QuantizeAndPopulate( 
model_fixture.input1(), {-0.3f, 2.3f, 0.9f, 0.5f, 0.8f, -1.1f, 1.2f, 2.8f, -1.6f, 0.0f, 0.7f, -2.2f}); - model_fixture.QuantizeAndPopulate( + model_fixture.QuantizeAndPopulate( model_fixture.input2(), {0.2f, 0.3f, -0.4f, 0.5f, 1.0f, 0.9f}); model_fixture.Invoke(); EXPECT_THAT( - model_fixture.GetDequantizedOutput(), + model_fixture.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(test_outputs[i], kQuantizedTolerance))) << "With shape number " << i; } // Re-run with exchanged inputs. for (int i = 0; i < test_shapes.size(); ++i) { - QuantizedAddOpModel model_fixture( - {TensorType_UINT8, test_shapes[i], -3.f, 3.f}, - {TensorType_UINT8, base_shape, -3.f, 3.f}, - {TensorType_UINT8, {}, -3.f, 3.f}, ActivationFunctionType_NONE); - model_fixture.QuantizeAndPopulate( + QuantizedAddOpModel model_fixture({tensor_type, test_shapes[i], -3.f, 3.f}, + {tensor_type, base_shape, -3.f, 3.f}, + {tensor_type, {}, -3.f, 3.f}, + ActivationFunctionType_NONE); + model_fixture.QuantizeAndPopulate( model_fixture.input1(), {0.2f, 0.3f, -0.4f, 0.5f, 1.0f, 0.9f}); - model_fixture.QuantizeAndPopulate( + model_fixture.QuantizeAndPopulate( model_fixture.input2(), {-0.3f, 2.3f, 0.9f, 0.5f, 0.8f, -1.1f, 1.2f, 2.8f, -1.6f, 0.0f, 0.7f, -2.2f}); model_fixture.Invoke(); EXPECT_THAT( - model_fixture.GetDequantizedOutput(), + model_fixture.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(test_outputs[i], kQuantizedTolerance))) << "With shape number " << i; } } +TEST(QuantizedAddOpModel, QuantizedWithMixedBroadcastUInt8) { + QuantizedWithMixedBroadcast(); +} + +TEST(QuantizedAddOpModel, QuantizedWithMixedBroadcastInt8) { + QuantizedWithMixedBroadcast(); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/lite/kernels/arg_min_max.cc b/tensorflow/lite/kernels/arg_min_max.cc index eea2de27f74af8bf73df92c28ed6042e4d8fa4ff..e5223badc407059511f06cd538b6057c1e276966 100644 --- a/tensorflow/lite/kernels/arg_min_max.cc +++ 
b/tensorflow/lite/kernels/arg_min_max.cc @@ -36,9 +36,15 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* input, axis_value += NumDimensions(input); } - // Copy the input dimensions to output except make the axis dimension 1. - TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims); - output_dims->data[axis_value] = 1; + // Copy the input dimensions to output except the axis dimension. + TfLiteIntArray* output_dims = TfLiteIntArrayCreate(NumDimensions(input) - 1); + int j = 0; + for (int i = 0; i < NumDimensions(input); ++i) { + if (i != axis_value) { + output_dims->data[j] = SizeOfDimension(input, i); + ++j; + } + } return context->ResizeTensor(context, output, output_dims); } @@ -74,13 +80,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { switch (input->type) { case kTfLiteFloat32: case kTfLiteUInt8: + case kTfLiteInt8: case kTfLiteInt32: break; default: context->ReportError( context, - "Unkonwn input type: %d, only float32 and int types are supported", + "Unknown input type: %d, only float32 and int types are supported", input->type); return kTfLiteError; } @@ -129,6 +136,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) { case kTfLiteUInt8: TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t); break; + case kTfLiteInt8: + TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t); + break; case kTfLiteInt32: TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t); break; @@ -144,6 +154,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) { case kTfLiteUInt8: TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t); break; + case kTfLiteInt8: + TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int64_t); + break; case kTfLiteInt32: TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t); break; diff --git a/tensorflow/lite/kernels/arg_min_max_test.cc b/tensorflow/lite/kernels/arg_min_max_test.cc index dcdff74cc6f376b3418b64c025e8eb4a36c429a0..01ea923f26d3ca32ec109a61d0484b0ecbd30c93 100644 --- 
a/tensorflow/lite/kernels/arg_min_max_test.cc +++ b/tensorflow/lite/kernels/arg_min_max_test.cc @@ -83,7 +83,29 @@ TEST(ArgMaxOpTest, GetMaxArgFloat) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({1})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1})); +} + +TEST(ArgMaxOpTest, GetMaxArgUInt8) { + ArgMaxOpModel model({1, 1, 1, 4}, TensorType_UINT8, TensorType_INT32, + TensorType_INT32); + model.PopulateTensor(model.input(), {1, 9, 7, 3}); + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1})); +} + +TEST(ArgMaxOpTest, GetMaxArgInt8) { + ArgMaxOpModel model({1, 1, 1, 4}, TensorType_INT8, TensorType_INT32, + TensorType_INT32); + model.PopulateTensor(model.input(), {-1, -9, 7, 3}); + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({2})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1})); } TEST(ArgMaxOpTest, GetMaxArgInt) { @@ -94,7 +116,7 @@ TEST(ArgMaxOpTest, GetMaxArgInt) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({1})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1})); } TEST(ArgMaxOpTest, GetMaxArgMulDimensions) { @@ -105,7 +127,7 @@ TEST(ArgMaxOpTest, GetMaxArgMulDimensions) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 1})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2})); } TEST(ArgMaxOpTest, GetMaxArgNegativeAxis) { @@ -116,7 +138,7 @@ TEST(ArgMaxOpTest, GetMaxArgNegativeAxis) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 0, 0})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 
4})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4})); } TEST(ArgMaxOpTest, GetMaxArgOutput64) { @@ -127,7 +149,7 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2})); } TEST(ArgMinOpTest, GetMinArgFloat) { @@ -138,7 +160,7 @@ TEST(ArgMinOpTest, GetMinArgFloat) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1})); } TEST(ArgMinOpTest, GetMinArgInt) { @@ -149,7 +171,7 @@ TEST(ArgMinOpTest, GetMinArgInt) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1})); } TEST(ArgMinOpTest, GetMinArgMulDimensions) { @@ -160,7 +182,7 @@ TEST(ArgMinOpTest, GetMinArgMulDimensions) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2})); } TEST(ArgMinOpTest, GetMinArgNegativeAxis) { @@ -171,7 +193,7 @@ TEST(ArgMinOpTest, GetMinArgNegativeAxis) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0, 0, 1})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 4})); } TEST(ArgMinOpTest, GetMinArgOutput64) { @@ -182,7 +204,7 @@ TEST(ArgMinOpTest, GetMinArgOutput64) { model.Invoke(); EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2})); } } // namespace diff --git 
a/tensorflow/lite/kernels/batch_to_space_nd.cc b/tensorflow/lite/kernels/batch_to_space_nd.cc index 34fdf34f70c9660266e23260bd5a6b645a3c5ccb..ce85aeddedcebdecf4d2944bade2ed5f823b0592 100644 --- a/tensorflow/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/lite/kernels/batch_to_space_nd.cc @@ -148,6 +148,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t); } break; + case kTfLiteInt8: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int8_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int8_t); + } + break; case kTfLiteInt32: if (kernel_type == kReference) { TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t); diff --git a/tensorflow/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/lite/kernels/batch_to_space_nd_test.cc index f33089559992c1a6a6fa34161122c43b7954fbdb..bd806b55ca48424e143a77d1f95640365af5fe77 100644 --- a/tensorflow/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/lite/kernels/batch_to_space_nd_test.cc @@ -26,8 +26,9 @@ using ::testing::ElementsAreArray; class BatchToSpaceNDOpModel : public SingleOpModel { public: - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); } void SetBlockShape(std::initializer_list data) { @@ -38,7 +39,10 @@ class BatchToSpaceNDOpModel : public SingleOpModel { PopulateTensor(crops_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } std::vector GetOutputShape() { return GetTensorShape(output_); } protected: @@ -58,11 +62,12 @@ class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel { public: BatchToSpaceNDOpConstModel(std::initializer_list input_shape, std::initializer_list block_shape, - std::initializer_list crops) { - input_ = AddInput(TensorType_FLOAT32); + std::initializer_list crops, + const 
TensorType& type = TensorType_FLOAT32) { + input_ = AddInput(type); block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); crops_ = AddConstInput(TensorType_INT32, crops, {2, 2}); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(type); SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, BuiltinOptions_BatchToSpaceNDOptions, @@ -81,11 +86,12 @@ class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel { // m.Invoke(); class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel { public: - BatchToSpaceNDOpDynamicModel(std::initializer_list input_shape) { - input_ = AddInput(TensorType_FLOAT32); + BatchToSpaceNDOpDynamicModel(std::initializer_list input_shape, + const TensorType& type = TensorType_FLOAT32) { + input_ = AddInput(type); block_shape_ = AddInput(TensorType_INT32); crops_ = AddInput(TensorType_INT32); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(type); SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, BuiltinOptions_BatchToSpaceNDOptions, @@ -96,22 +102,47 @@ class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel { TEST(BatchToSpaceNDOpTest, SimpleConstTest) { BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}); - m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + {1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16})); +} + +TEST(BatchToSpaceNDOpTest, SimpleConstTestInt8) { + BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}, + TensorType_INT8); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, - 4, 8, 11, 15, 12, 16})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + 
{1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16})); } TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) { BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); - m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetCrops({0, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + {1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16})); +} + +TEST(BatchToSpaceNDOpTest, SimpleDynamicTestInt8) { + BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}, TensorType_INT8); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); m.SetCrops({0, 0, 0, 0}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, - 4, 8, 11, 15, 12, 16})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + {1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16})); } #ifdef GTEST_HAS_DEATH_TEST @@ -127,7 +158,7 @@ TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) { TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) { BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); - m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); m.SetCrops({0, 0, -1, 0}); EXPECT_DEATH(m.Invoke(), "crops.2. 
>= 0 was not true."); diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc index e807626108c2d45071adab62416c9c31f04d5a9a..0adf574bb0641b2ddd2774f1563a92a66023f7a2 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc @@ -53,7 +53,10 @@ constexpr int kBwWeightsTensor = 5; constexpr int kBwRecurrentWeightsTensor = 6; constexpr int kBwBiasTensor = 7; constexpr int kBwHiddenStateTensor = 8; -// Auxiliary inputs. +// Used as auxiliary input and weights when stacking for +// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as input +// to the backward cell when stacking for tf.nn.static_bidirectional_rnn case +// (without cross links). constexpr int kAuxInputTensor = 9; // Optional. constexpr int kFwAuxWeightsTensor = 10; // Optional. constexpr int kBwAuxWeightsTensor = 11; // Optional. @@ -113,13 +116,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bw_aux_input_weights = GetOptionalInputTensor(context, node, kBwAuxWeightsTensor); - const bool aux_inputs_all_or_none = - ((aux_input != nullptr) && (fw_aux_input_weights != nullptr) && + const bool aux_inputs_weights_or_none = + ((fw_aux_input_weights != nullptr) && (bw_aux_input_weights != nullptr)) || - ((aux_input == nullptr) && (fw_aux_input_weights == nullptr) && - (bw_aux_input_weights == nullptr)); - TF_LITE_ENSURE(context, aux_inputs_all_or_none); - const bool has_aux_input = (aux_input != nullptr); + ((fw_aux_input_weights == nullptr) && (bw_aux_input_weights == nullptr)); + TF_LITE_ENSURE(context, aux_inputs_weights_or_none); + const bool has_aux_input = (fw_aux_input_weights != nullptr); // Check all the parameters of tensor match within themselves and match the // input configuration. 
@@ -277,16 +279,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -TfLiteStatus EvalFloat( - const TfLiteTensor* input, const TfLiteTensor* fw_input_weights, - const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias, - const TfLiteTensor* bw_input_weights, - const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias, - const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights, - const TfLiteTensor* bw_aux_input_weights, - const TfLiteBidirectionalSequenceRNNParams* params, - TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, - TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { +TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* bw_input, + const TfLiteTensor* fw_input_weights, + const TfLiteTensor* fw_recurrent_weights, + const TfLiteTensor* fw_bias, + const TfLiteTensor* bw_input_weights, + const TfLiteTensor* bw_recurrent_weights, + const TfLiteTensor* bw_bias, + const TfLiteTensor* aux_input, + const TfLiteTensor* fw_aux_input_weights, + const TfLiteTensor* bw_aux_input_weights, + const TfLiteBidirectionalSequenceRNNParams* params, + TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, + TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { const bool time_major = params->time_major; const int batch_size = (time_major) ? input->dims->data[1] : input->dims->data[0]; @@ -339,7 +344,7 @@ TfLiteStatus EvalFloat( float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f; for (int s = max_time - 1; s >= 0; s--) { const float* input_ptr_batch = - input->data.f + s * input_size * batch_size; + bw_input->data.f + s * input_size * batch_size; const float* aux_input_ptr_batch = (aux_input != nullptr) ? 
aux_input->data.f + s * input_size * batch_size @@ -407,7 +412,8 @@ TfLiteStatus EvalFloat( } TfLiteStatus EvalHybrid( - const TfLiteTensor* input, const TfLiteTensor* fw_input_weights, + const TfLiteTensor* input, const TfLiteTensor* bw_input, + const TfLiteTensor* fw_input_weights, const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias, const TfLiteTensor* bw_input_weights, const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias, @@ -504,7 +510,7 @@ TfLiteStatus EvalHybrid( float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f; for (int s = max_time - 1; s >= 0; s--) { const float* input_ptr_batch = - input->data.f + s * input_size * batch_size; + bw_input->data.f + s * input_size * batch_size; const float* aux_input_ptr_batch = (aux_input != nullptr) ? aux_input->data.f + s * input_size * batch_size @@ -616,13 +622,35 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { ? nullptr : GetOutput(context, node, kBwOutputTensor); + const bool has_previous_bw_output = (aux_input != nullptr); + const bool use_aux_input = (fw_aux_input_weights != nullptr); + + // We want to cover the following cases: + // + // If not stacking (not connected after other bidi lstms): + // both fw & bw will just use `input`; aux_input will be null. + // + // If stacking with cross_links, TensorFlow equivalent + // (tf.contrib.rnn.stack_bidirectional_rnn): + // both fw & bw will use `input`, but aux_input will be none null. + // Note, this time, whether connected after other bidi lstms both works. + // + // If stacking without cross_links, but connected after other bidi lstms, + // TensorFlow equivalent (tf.nn.static_bidirectional_rnn): + // fw will use `input`, bw will use aux_input, and the `real aux_input` + // will be null. + + const bool non_stacking_mode = !use_aux_input && has_previous_bw_output; + const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input; + const TfLiteTensor* real_aux_input = non_stacking_mode ? 
nullptr : aux_input; + switch (fw_input_weights->type) { case kTfLiteFloat32: - return EvalFloat(input, fw_input_weights, fw_recurrent_weights, fw_bias, - bw_input_weights, bw_recurrent_weights, bw_bias, - aux_input, fw_aux_input_weights, bw_aux_input_weights, - params, fw_hidden_state, fw_output, bw_hidden_state, - bw_output); + return EvalFloat(input, bw_input, fw_input_weights, fw_recurrent_weights, + fw_bias, bw_input_weights, bw_recurrent_weights, bw_bias, + real_aux_input, fw_aux_input_weights, + bw_aux_input_weights, params, fw_hidden_state, fw_output, + bw_hidden_state, bw_output); case kTfLiteUInt8: case kTfLiteInt8: { TfLiteTensor* input_quantized = @@ -634,17 +662,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* scaling_factors = GetTemporary(context, node, kScalingFactors); TfLiteTensor* aux_input_quantized = - (aux_input != nullptr) - ? GetTemporary(context, node, kAuxInputQuantized) - : nullptr; - - return EvalHybrid(input, fw_input_weights, fw_recurrent_weights, fw_bias, - bw_input_weights, bw_recurrent_weights, bw_bias, - aux_input, fw_aux_input_weights, bw_aux_input_weights, - params, scaling_factors, input_quantized, - aux_input_quantized, fw_hidden_state_quantized, - fw_hidden_state, fw_output, bw_hidden_state_quantized, - bw_hidden_state, bw_output); + use_aux_input ? 
GetTemporary(context, node, kAuxInputQuantized) + : nullptr; + + return EvalHybrid(input, bw_input, fw_input_weights, fw_recurrent_weights, + fw_bias, bw_input_weights, bw_recurrent_weights, + bw_bias, real_aux_input, fw_aux_input_weights, + bw_aux_input_weights, params, scaling_factors, + input_quantized, aux_input_quantized, + fw_hidden_state_quantized, fw_hidden_state, fw_output, + bw_hidden_state_quantized, bw_hidden_state, bw_output); } default: context->ReportError(context, "Type not currently supported."); diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc index 5bad8e02c29608fa058d0d1104acbf09626f1b66..9b61f8238b558042e7a957d09dac162d8ea6450b 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -654,8 +654,8 @@ const std::initializer_list recurrent_weights = { class BidirectionalRNNOpModel : public SingleOpModel { public: BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, - int bw_units, int input_size, bool time_major, - bool merge_outputs) + int bw_units, int input_size, bool use_aux_input, + bool time_major, bool merge_outputs) : batches_(batches), sequence_len_(sequence_len), fw_units_(fw_units), @@ -671,7 +671,13 @@ class BidirectionalRNNOpModel : public SingleOpModel { bw_bias_ = AddInput(TensorType_FLOAT32); bw_hidden_state_ = AddInput(TensorType_FLOAT32, true); - aux_input_ = AddNullInput(); + int aux_input_size = 0; + if (use_aux_input) { + aux_input_ = AddInput(TensorType_FLOAT32); + aux_input_size = input_size_; + } else { + aux_input_ = AddNullInput(); + } aux_fw_weights_ = AddNullInput(); aux_bw_weights_ = AddNullInput(); @@ -691,18 +697,18 @@ class BidirectionalRNNOpModel : public SingleOpModel { : std::vector({batches_, sequence_len_, input_size_}); BuildInterpreter({ - input_shape, // input - {fw_units_, input_size_}, // fw_weights - {fw_units_, fw_units_}, 
// fw_recurrent_weights - {fw_units_}, // fw_bias - {batches_, fw_units_}, // fw_hidden_state - {bw_units_, input_size_}, // bw_weights - {bw_units_, bw_units_}, // bw_recurrent_weights - {bw_units_}, // bw_bias - {batches_, bw_units_}, // bw_hidden_state - {batches_, sequence_len_, 0}, // aux_input - {fw_units_, 0}, // aux_fw_weights - {bw_units_, 0}, // aux_bw_weights + input_shape, // input + {fw_units_, input_size_}, // fw_weights + {fw_units_, fw_units_}, // fw_recurrent_weights + {fw_units_}, // fw_bias + {batches_, fw_units_}, // fw_hidden_state + {bw_units_, input_size_}, // bw_weights + {bw_units_, bw_units_}, // bw_recurrent_weights + {bw_units_}, // bw_bias + {batches_, bw_units_}, // bw_hidden_state + {batches_, sequence_len_, aux_input_size}, // aux_input + {fw_units_, 0}, // aux_fw_weights + {bw_units_, 0}, // aux_bw_weights }); } @@ -738,6 +744,10 @@ class BidirectionalRNNOpModel : public SingleOpModel { PopulateTensor(input_, offset, begin, end); } + void SetAuxInput(int offset, float* begin, float* end) { + PopulateTensor(aux_input_, offset, begin, end); + } + std::vector GetFwOutput() { return ExtractVector(fw_output_); } std::vector GetBwOutput() { return ExtractVector(bw_output_); } @@ -775,7 +785,8 @@ class BidirectionalRNNOpModel : public SingleOpModel { TEST(BidirectionalRNNOpTest, BlackBoxTest) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8, /*time_major=*/false, + /*input_size=*/8, /*use_aux_input=*/false, + /*time_major=*/false, /*merge_outputs=*/false); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); @@ -813,7 +824,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8, /*time_major=*/true, + /*input_size=*/8, /*use_aux_input=*/false, + /*time_major=*/true, /*merge_outputs=*/false); 
rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); @@ -822,7 +834,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { rnn.SetFwRecurrentWeights(recurrent_weights); rnn.SetBwRecurrentWeights(recurrent_weights); - // const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); // Insert the inputs in time_major format. The batch_major format is: // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as: // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15]. @@ -850,7 +861,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8, /*time_major=*/false, + /*input_size=*/8, /*use_aux_input=*/false, + /*time_major=*/false, /*merge_outputs=*/true); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); @@ -888,7 +900,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajorMergeOutputs) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8, /*time_major=*/true, + /*input_size=*/8, /*use_aux_input=*/false, + /*time_major=*/true, /*merge_outputs=*/true); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); @@ -932,7 +945,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajorMergeOutputs) { TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8, /*time_major=*/false, + /*input_size=*/8, /*use_aux_input=*/false, + /*time_major=*/false, /*merge_outputs=*/false); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); @@ -979,7 +993,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { TEST(BidirectionalRNNOpTest, EndToEndTest) { BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4, /*fw_units=*/16, /*bw_units=*/16, - 
/*input_size=*/8, /*time_major=*/false, + /*input_size=*/8, /*use_aux_input=*/false, + /*time_major=*/false, /*merge_outputs=*/false); const int output_size = 4; float dnn_weights[] = { @@ -1046,6 +1061,137 @@ TEST(BidirectionalRNNOpTest, EndToEndTest) { } } +// Same as BlackBox test, but has aux input. +TEST(BidirectionalRNNOpTest, BlackBoxTestAuxInput) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8, /*use_aux_input=*/true, + /*time_major=*/true, + /*merge_outputs=*/false); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + // Insert the inputs in time_major format. The batch_major format is: + // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as: + // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15]. + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + // Also make aux input the same as input. 
+ rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + } + + rnn.Invoke(); + + std::vector fw_expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units(); + float* golden_fw_end = golden_fw_start + rnn.num_fw_units(); + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + } + EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); +} + +// Same as previous test, but has aux input is all zeros. +TEST(BidirectionalRNNOpTest, BlackBoxTestAuxInputZeros) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8, /*use_aux_input=*/true, + /*time_major=*/true, + /*merge_outputs=*/false); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + // Initialize bw inputs with zeros. + std::vector bw_inputs(rnn.sequence_len(), 0); + + // Insert the inputs in time_major format. The batch_major format is: + // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as: + // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15]. + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + // Also make aux input the same as input. 
+ rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetAuxInput(2 * i * rnn.input_size(), &bw_inputs[0], + &bw_inputs[bw_inputs.size() - 1]); + rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), &bw_inputs[0], + &bw_inputs[bw_inputs.size() - 1]); + } + + rnn.Invoke(); + + std::vector fw_expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units(); + float* golden_fw_end = golden_fw_start + rnn.num_fw_units(); + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + } + EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); +} + +// Same as previous test, but has input is all zeros, and aux input is the real +// input. This is testing the bw path is functional. +TEST(BidirectionalRNNOpTest, BlackBoxTestAuxInputInputZeros) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8, /*use_aux_input=*/true, + /*time_major=*/true, + /*merge_outputs=*/false); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + // Initialize bw inputs with zeros. + std::vector fw_inputs(rnn.sequence_len(), 0); + + // Insert the inputs in time_major format. The batch_major format is: + // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as: + // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15]. + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + // Also make aux input the same as input. 
+ rnn.SetAuxInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetInput(2 * i * rnn.input_size(), &fw_inputs[0], + &fw_inputs[fw_inputs.size() - 1]); + rnn.SetAuxInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((2 * i + 1) * rnn.input_size(), &fw_inputs[0], + &fw_inputs[fw_inputs.size() - 1]); + } + + rnn.Invoke(); + + std::vector bw_expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_bw_start = rnn_golden_bw_output + i * rnn.num_fw_units(); + float* golden_bw_end = golden_bw_start + rnn.num_fw_units(); + bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); + bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); + } + EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/relu1.cc b/tensorflow/lite/kernels/ceil.cc similarity index 55% rename from tensorflow/lite/kernels/relu1.cc rename to tensorflow/lite/kernels/ceil.cc index 5a55631405b6b32a602cfe21ba863d0dc92213ea..6bb763255b136f1d5103dd2e72ce6aebf38f06d3 100644 --- a/tensorflow/lite/kernels/relu1.cc +++ b/tensorflow/lite/kernels/ceil.cc @@ -12,48 +12,48 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/context.h" + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { -namespace custom { -namespace relu1 { +namespace builtin { +namespace ceil { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - const TfLiteTensor* input = GetInput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); - TfLiteTensor* output = GetOutput(context, node, 0); output->type = input->type; - return context->ResizeTensor(context, output, - TfLiteIntArrayCopy(input->dims)); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims); + return context->ResizeTensor(context, output, output_size); } -// This is derived from lite/kernels/activations.cc. 
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor* input = GetInput(context, node, 0); - TfLiteTensor* output = GetOutput(context, node, 0); - const int elements = NumElements(input); - const float* in = input->data.f; - const float* in_end = in + elements; - float* out = output->data.f; - for (; in < in_end; ++in, ++out) { - *out = std::min(std::max(0.f, *in), 1.f); - } + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + optimized_ops::Ceil(GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; } +} // namespace ceil -} // namespace relu1 - -TfLiteRegistration* Register_RELU_1() { - static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, - relu1::Prepare, relu1::Eval}; +TfLiteRegistration* Register_CEIL() { + static TfLiteRegistration r = {/*init=*/nullptr, + /*free=*/nullptr, ceil::Prepare, ceil::Eval}; return &r; } -} // namespace custom +} // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/lite/kernels/ceil_test.cc b/tensorflow/lite/kernels/ceil_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e120105082751a732bb8812944c318ad9e5ecff5 --- /dev/null +++ b/tensorflow/lite/kernels/ceil_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class CeilOpModel : public SingleOpModel { + public: + CeilOpModel(std::initializer_list input_shape, TensorType input_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_CEIL, BuiltinOptions_NONE, 0); + BuildInterpreter({ + input_shape, + }); + } + + int input() { return input_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(CeilOpTest, SingleDim) { + CeilOpModel model({2}, TensorType_FLOAT32); + model.PopulateTensor(model.input(), {8.5, 0.0}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({9, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(CeilOpTest, MultiDims) { + CeilOpModel model({2, 1, 1, 5}, TensorType_FLOAT32); + model.PopulateTensor(model.input(), { + 0.0001, + 8.0001, + 0.9999, + 9.9999, + 0.5, + -0.0001, + -8.0001, + -0.9999, + -9.9999, + -0.5, + }); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1, 9, 1, 10, 1, 0, -8, 0, -9, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/comparisons.cc b/tensorflow/lite/kernels/comparisons.cc index a914449ae552e37249f2cecb5c88f3b49e83f133..e49348a5462d7efcaffad888b8714cadbfb64a10 100644 --- a/tensorflow/lite/kernels/comparisons.cc +++ 
b/tensorflow/lite/kernels/comparisons.cc @@ -59,11 +59,12 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { // TODO(ruic): optimize macros below to using template functions. #define TF_LITE_QUANTIZE_COMPARISON(opname) \ + template \ void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node, \ const TfLiteTensor* input1, \ const TfLiteTensor* input2, TfLiteTensor* output, \ bool requires_broadcast) { \ - if (input1->type == kTfLiteUInt8) { \ + if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) { \ auto input1_offset = -input1->params.zero_point; \ auto input2_offset = -input2->params.zero_point; \ const int left_shift = 8; \ @@ -87,14 +88,16 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { op_params.input2_shift = input2_shift; \ if (requires_broadcast) { \ reference_ops::Broadcast4DSlow##opname##WithScaling( \ - op_params, GetTensorShape(input1), GetTensorData(input1), \ - GetTensorShape(input2), GetTensorData(input2), \ - GetTensorShape(output), GetTensorData(output)); \ + op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)); \ } else { \ reference_ops::opname##WithScaling( \ - op_params, GetTensorShape(input1), GetTensorData(input1), \ - GetTensorShape(input2), GetTensorData(input2), \ - GetTensorShape(output), GetTensorData(output)); \ + op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)); \ } \ } \ } @@ -126,6 +129,9 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); switch (input1->type) { + case kTfLiteBool: + TF_LITE_COMPARISON(bool, Equal, requires_broadcast); + break; case kTfLiteFloat32: TF_LITE_COMPARISON(float, 
Equal, requires_broadcast); break; @@ -136,13 +142,17 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); break; case kTfLiteUInt8: - EvalQuantizedEqual(context, node, input1, input2, output, - requires_broadcast); + EvalQuantizedEqual(context, node, input1, input2, output, + requires_broadcast); + break; + case kTfLiteInt8: + EvalQuantizedEqual(context, node, input1, input2, output, + requires_broadcast); break; default: - context->ReportError(context, - "Does not support type %d, requires float|int|uint8", - input1->type); + context->ReportError( + context, "Does not support type %d, requires bool|float|int|uint8", + input1->type); return kTfLiteError; } return kTfLiteOk; @@ -155,6 +165,9 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); switch (input1->type) { + case kTfLiteBool: + TF_LITE_COMPARISON(bool, NotEqual, requires_broadcast); + break; case kTfLiteFloat32: TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); break; @@ -165,13 +178,17 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); break; case kTfLiteUInt8: - EvalQuantizedNotEqual(context, node, input1, input2, output, - requires_broadcast); + EvalQuantizedNotEqual(context, node, input1, input2, output, + requires_broadcast); + break; + case kTfLiteInt8: + EvalQuantizedNotEqual(context, node, input1, input2, output, + requires_broadcast); break; default: - context->ReportError(context, - "Does not support type %d, requires float|int|uint8", - input1->type); + context->ReportError( + context, "Does not support type %d, requires bool|float|int|uint8", + input1->type); return kTfLiteError; } return kTfLiteOk; @@ -193,8 +210,12 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { 
TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast); break; case kTfLiteUInt8: - EvalQuantizedGreater(context, node, input1, input2, output, - requires_broadcast); + EvalQuantizedGreater(context, node, input1, input2, output, + requires_broadcast); + break; + case kTfLiteInt8: + EvalQuantizedGreater(context, node, input1, input2, output, + requires_broadcast); break; default: context->ReportError(context, @@ -221,8 +242,12 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast); break; case kTfLiteUInt8: - EvalQuantizedGreaterEqual(context, node, input1, input2, output, - requires_broadcast); + EvalQuantizedGreaterEqual(context, node, input1, input2, output, + requires_broadcast); + break; + case kTfLiteInt8: + EvalQuantizedGreaterEqual(context, node, input1, input2, output, + requires_broadcast); break; default: context->ReportError(context, @@ -249,8 +274,12 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_COMPARISON(int64_t, Less, requires_broadcast); break; case kTfLiteUInt8: - EvalQuantizedLess(context, node, input1, input2, output, - requires_broadcast); + EvalQuantizedLess(context, node, input1, input2, output, + requires_broadcast); + break; + case kTfLiteInt8: + EvalQuantizedLess(context, node, input1, input2, output, + requires_broadcast); break; default: context->ReportError(context, @@ -277,8 +306,12 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast); break; case kTfLiteUInt8: - EvalQuantizedLessEqual(context, node, input1, input2, output, - requires_broadcast); + EvalQuantizedLessEqual(context, node, input1, input2, output, + requires_broadcast); + break; + case kTfLiteInt8: + EvalQuantizedLessEqual(context, node, input1, input2, output, + requires_broadcast); break; default: context->ReportError(context, diff --git 
a/tensorflow/lite/kernels/comparisons_test.cc b/tensorflow/lite/kernels/comparisons_test.cc index ab10c959a4d6b234cb6ae0810174e8f1c48898d1..3f950a322059f6e22dd95740606098c32c0bd310 100644 --- a/tensorflow/lite/kernels/comparisons_test.cc +++ b/tensorflow/lite/kernels/comparisons_test.cc @@ -92,6 +92,17 @@ class ComparisonOpModel : public SingleOpModel { } }; +TEST(ComparisonsTest, EqualBool) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_BOOL, + BuiltinOperator_EQUAL); + model.PopulateTensor(model.input1(), {true, false, true, false}); + model.PopulateTensor(model.input2(), {true, true, false, false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + TEST(ComparisonsTest, EqualFloat) { ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, BuiltinOperator_EQUAL); @@ -137,6 +148,17 @@ TEST(ComparisonsTest, EqualBroadcastTwoD) { EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } +TEST(ComparisonsTest, NotEqualBool) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_BOOL, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {true, false, true, false}); + model.PopulateTensor(model.input2(), {true, true, false, false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + TEST(ComparisonsTest, NotEqualFloat) { ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, BuiltinOperator_NOT_EQUAL); @@ -363,7 +385,7 @@ TEST(ComparisonsTest, LessEqualBroadcastTwoD) { EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -TEST(QuantizedComparisonsTest, EqualQuantized) { +TEST(QuantizedComparisonsTest, EqualUInt8Quantized) { const float kMin = -1.f; const float kMax = 128.f; ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, @@ -376,7 
+398,20 @@ TEST(QuantizedComparisonsTest, EqualQuantized) { EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, false)); } -TEST(QuantizedComparisonsTest, NotEqualQuantized) { +TEST(QuantizedComparisonsTest, EqualInt8Quantized) { + const float kMin = -127.f; + const float kMax = 127.f; + ComparisonOpModel model({TensorType_INT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_INT8, BuiltinOperator_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, -9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {-1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); +} + +TEST(QuantizedComparisonsTest, NotEqualUInt8Quantized) { const float kMin = -1.f; const float kMax = 128.f; ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, @@ -389,6 +424,19 @@ TEST(QuantizedComparisonsTest, NotEqualQuantized) { EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true)); } +TEST(QuantizedComparisonsTest, NotEqualInt8Quantized) { + const float kMin = -127.f; + const float kMax = 127.f; + ComparisonOpModel model({TensorType_INT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_INT8, BuiltinOperator_NOT_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, -9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true)); +} + TEST(ComparisonsTest, GreaterQuantized) { const float kMin = -1.f; const float kMax = 128.f; @@ -470,7 +518,7 @@ TEST(ComparisonsTest, QuantizedEqualWithBroadcast) { } } -TEST(ComparisonsTest, QuantizedNotEqualWithBroadcast) { +TEST(ComparisonsTest, QuantizedUInt8NotEqualWithBroadcast) { const float kMin = -1.f; const float kMax = 128.f; std::vector> test_shapes = { @@ -488,7 +536,25 @@ TEST(ComparisonsTest, QuantizedNotEqualWithBroadcast) { } } -TEST(ComparisonsTest, 
QuantizedGreaterWithBroadcast) { +TEST(ComparisonsTest, QuantizedInt8NotEqualWithBroadcast) { + const float kMin = -127.f; + const float kMax = 127.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_INT8, test_shapes[i], kMin, kMax}, + {TensorType_INT8, {}, kMin, kMax}, TensorType_INT8, + BuiltinOperator_NOT_EQUAL); + model.QuantizeAndPopulate(model.input1(), {-20, 2, 7, -8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {2}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, true, true, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedUInt8GreaterWithBroadcast) { const float kMin = -1.f; const float kMax = 128.f; std::vector> test_shapes = { @@ -506,7 +572,25 @@ TEST(ComparisonsTest, QuantizedGreaterWithBroadcast) { } } -TEST(ComparisonsTest, QuantizedGreaterEqualWithBroadcast) { +TEST(ComparisonsTest, QuantizedInt8GreaterWithBroadcast) { + const float kMin = -127.f; + const float kMax = 127.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_INT8, test_shapes[i], kMin, kMax}, + {TensorType_INT8, {}, kMin, kMax}, TensorType_INT8, + BuiltinOperator_GREATER); + model.QuantizeAndPopulate(model.input1(), {20, -2, -71, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, false, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedUInt8GreaterEqualWithBroadcast) { const float kMin = -1.f; const float kMax = 128.f; std::vector> test_shapes = { @@ -524,7 +608,25 @@ TEST(ComparisonsTest, QuantizedGreaterEqualWithBroadcast) { } } -TEST(ComparisonsTest, QuantizedLessWithBroadcast) { +TEST(ComparisonsTest, QuantizedInt8GreaterEqualWithBroadcast) { + const float 
kMin = -127.f; + const float kMax = 127.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_INT8, test_shapes[i], kMin, kMax}, + {TensorType_INT8, {}, kMin, kMax}, TensorType_INT8, + BuiltinOperator_GREATER_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, -2, -71, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedUInt8LessWithBroadcast) { const float kMin = -1.f; const float kMax = 128.f; std::vector> test_shapes = { @@ -542,7 +644,25 @@ TEST(ComparisonsTest, QuantizedLessWithBroadcast) { } } -TEST(ComparisonsTest, QuantizedLessEqualWithBroadcast) { +TEST(ComparisonsTest, QuantizedInt8LessWithBroadcast) { + const float kMin = -127.f; + const float kMax = 127.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_INT8, test_shapes[i], kMin, kMax}, + {TensorType_INT8, {}, kMin, kMax}, TensorType_INT8, + BuiltinOperator_LESS); + model.QuantizeAndPopulate(model.input1(), {20, -2, -71, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, false)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedUInt8LessEqualWithBroadcast) { const float kMin = -1.f; const float kMax = 128.f; std::vector> test_shapes = { @@ -560,6 +680,24 @@ TEST(ComparisonsTest, QuantizedLessEqualWithBroadcast) { } } +TEST(ComparisonsTest, QuantizedInt8LessEqualWithBroadcast) { + const float kMin = -127.f; + const float kMax = 127.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + 
ComparisonOpModel model({TensorType_INT8, test_shapes[i], kMin, kMax}, + {TensorType_INT8, {}, kMin, kMax}, TensorType_INT8, + BuiltinOperator_LESS_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, -2, -71, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, true, false, false)) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index a8dd160c8dbb42ba2c2363af55b30eb0b79f86af..76d906fa6deb0f7c1a3d67301f83ab4e02929dab 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -58,8 +58,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || - input_type == kTfLiteInt16 || input_type == kTfLiteInt32 || - input_type == kTfLiteInt64); + input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || + input_type == kTfLiteInt32 || input_type == kTfLiteInt64); // Output dimensions will match input dimensions, except 'axis', which // will be the sum of inputs @@ -85,6 +85,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; TF_LITE_ENSURE_EQ(context, output->type, input_type); + if (input_type == kTfLiteInt8) { + // Make sure there is no re-scaling needed for Int8 quantized kernel. This + // is a restriction we introduced to Int8 kernels. 
+ VectorOfTensors all_inputs(*context, *node->inputs); + for (int i = 0; i < node->inputs->size; ++i) { + TfLiteTensor* t = &context->tensors[node->inputs->data[i]]; + TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale); + TF_LITE_ENSURE_EQ(context, t->params.zero_point, + output->params.zero_point); + } + } + return context->ResizeTensor(context, output, output_size); } @@ -148,6 +160,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION_QUANTIZED(optimized_ops); } break; + case kTfLiteInt8: { + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, int8_t); + } else { + TF_LITE_CONCATENATION(optimized_ops, int8_t); + } + } break; case kTfLiteInt64: if (kernel_type == kReference) { TF_LITE_CONCATENATION(reference_ops, int64_t); diff --git a/tensorflow/lite/kernels/concatenation_test.cc b/tensorflow/lite/kernels/concatenation_test.cc index 422380a03eaf9073958d4984eb2234890d555780..dab77d612dc5ab328849892fbee1115bc5324f44 100644 --- a/tensorflow/lite/kernels/concatenation_test.cc +++ b/tensorflow/lite/kernels/concatenation_test.cc @@ -78,13 +78,18 @@ class QuantizedConcatenationOpModel : public BaseConcatenationOpModel { .Union()); BuildInterpreter(all_input_shapes); } + template void SetInput(int index, std::initializer_list data) { - QuantizeAndPopulate(index, data); + QuantizeAndPopulate(index, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } }; @@ -165,25 +170,47 @@ TEST(ConcatenationOpTest, FourInputs) { })); } -TEST(ConcatenationOpTest, FourInputsQuantized) { +TEST(ConcatenationOpTest, FourInputsQuantizedUint8) { QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, 
-12.7, 12.8}, /*axis=*/2, /*num_inputs=*/4); - m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); - m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); - m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); - m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); m0.Invoke(); - EXPECT_THAT(m0.GetDequantizedOutput(), + EXPECT_THAT(m0.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // }))); - EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ - 137, 157, 138, 158, 139, 159, 140, 160, // - 167, 197, 168, 198, 169, 199, 170, 200, // - })); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +TEST(ConcatenationOpTest, FourInputsQuantizedInt8) { + QuantizedConcatenationOpModel m0({TensorType_INT8, {2, 1, 2}, -12.7, 12.8}, + /*axis=*/2, + /*num_inputs=*/4); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1, 3, 1.1, 3.1, 1.2, 3.2, 1.3, 3.3, // + 4, 7, 4.1, 7.1, 4.2, 7.2, 4.3, 7.3 // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 9, 29, 10, 30, 11, 31, 12, 32, // + 39, 69, 40, 70, 41, 71, 42, 72, // + })); } TEST(ConcatenationOpTest, FourInputsQuantizedMixedRange) { @@ -194,20 +221,21 @@ TEST(ConcatenationOpTest, FourInputsQuantizedMixedRange) { /*axis=*/2, /*num_inputs=*/4, {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}); - m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); - m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); - m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); - m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.SetInput(0, 
{1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); m0.Invoke(); - EXPECT_THAT(m0.GetDequantizedOutput(), + EXPECT_THAT(m0.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // }))); - EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ - 137, 157, 138, 158, 139, 159, 140, 160, // - 167, 197, 168, 198, 169, 199, 170, 200, // - })); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); } TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) { @@ -218,22 +246,23 @@ TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) { /*axis=*/2, /*num_inputs=*/4, {TensorType_UINT8, {2, 1, 2}, -1., 1.}); - m0.SetInput(0, {1.0f, -3.0f, -4.0f, -7.0f}); - m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); - m0.SetInput(2, {1.2f, -3.2f, -4.2f, 7.2f}); - m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.SetInput(0, {1.0f, -3.0f, -4.0f, -7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, -3.2f, -4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); m0.Invoke(); - EXPECT_THAT(m0.GetDequantizedOutput(), + EXPECT_THAT(m0.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, // -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, // }, 4e-3))); - EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ - 255, 0, 255, 255, 255, 0, 255, 255, // - 0, 0, 255, 255, 0, 255, 255, 255, // - })); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 255, 0, 255, 255, 255, 0, 255, 255, // + 0, 0, 255, 255, 0, 255, 255, 255, // + })); } } // namespace diff --git a/tensorflow/lite/kernels/conv.cc b/tensorflow/lite/kernels/conv.cc index 05368aa5edda8073d3ee5dbf0352468c7f4f6133..75e75fae6b09e6f3653719e3cf69c8cc1b4956f6 
100644 --- a/tensorflow/lite/kernels/conv.cc +++ b/tensorflow/lite/kernels/conv.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -68,6 +69,11 @@ struct OpData { // be represented as a fixed point multiplier plus a left shift. int32_t output_multiplier; int output_shift; + + // Per channel output multiplier and shift. + std::vector per_channel_output_multiplier; + std::vector per_channel_output_shift; + // The range of the fused activation layer. For example for kNone and // uint8_t these would be 0 and 255. int32_t output_activation_min; @@ -226,8 +232,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check types. 
(We assume that UINT8 refers to quantized tensors) TfLiteType input_type = input->type; - TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || + input_type == kTfLiteUInt8 || + input_type == kTfLiteInt8); TF_LITE_ENSURE_EQ(context, output->type, input_type); TfLiteTensor* bias = nullptr; @@ -238,7 +245,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (has_bias) { bias = &context->tensors[node->inputs->data[2]]; - if (input_type == kTfLiteUInt8) { + if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) { TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); } else { @@ -296,18 +303,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, has_bias); // Note that full fixed-point inference requires that all tensors have their - // parameters set. This is usually done during quantized training. + // parameters set. This is usually done during quantized training or + // calibration. 
if (input_type != kTfLiteFloat32) { - double real_multiplier = 0.0; - TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( - context, input, filter, bias, output, &real_multiplier)); - - int exponent; - QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); - data->output_shift = -exponent; - CalculateActivationRangeUint8(params->activation, output, - &data->output_activation_min, - &data->output_activation_max); + TF_LITE_ENSURE_EQ(context, filter->quantization.type, + kTfLiteAffineQuantization); + const auto* affine_quantization = + reinterpret_cast( + filter->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + const int number_channel = affine_quantization->scale->size; + data->per_channel_output_multiplier.resize(number_channel); + data->per_channel_output_shift.resize(number_channel); + TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( + context, input, filter, bias, output, params->activation, + &data->output_multiplier, &data->output_shift, + &data->output_activation_min, &data->output_activation_max, + data->per_channel_output_multiplier.data(), + data->per_channel_output_shift.data())); } TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); @@ -481,6 +495,29 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, } } +void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + ConvParams op_params; + op_params.input_offset = input->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.padding_values.height = 
data->padding.height; + op_params.padding_values.width = data->padding.width; + + reference_integer_ops::ConvPerChannel( + op_params, data->per_channel_output_multiplier.data(), + data->per_channel_output_shift.data(), GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); +} + template void EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params, OpData* data, TfLiteTensor* input, @@ -665,6 +702,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { EvalQuantized(context, node, params, data, input, filter, bias, im2col, hwcn_weights, output); break; + case kTfLiteInt8: + EvalQuantizedPerChannel(context, node, params, data, input, filter, bias, + output); + break; default: context->ReportError(context, "Type %d not currently supported.", input->type); diff --git a/tensorflow/lite/kernels/conv_test.cc b/tensorflow/lite/kernels/conv_test.cc index d0350b2fa7f7bad804d4b1348f4d389cb102f68e..7c562a530abf53a174e7f55053d01e6ae2fa212f 100644 --- a/tensorflow/lite/kernels/conv_test.cc +++ b/tensorflow/lite/kernels/conv_test.cc @@ -58,9 +58,35 @@ class BaseConvolutionOpModel : public SingleOpModel { // This is a quantized version. The scale of 'bias' depends on the scales // of input and filter. Supposedly this is correctly set during quantized // training. - auto bias_scale = GetScale(input_) * GetScale(filter_); - TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; - bias_ = AddInput(bias); + if (filter.per_channel_quantization) { + // per channel quantization. 
+ std::vector bias_scale( + filter.per_channel_quantization_scales.size()); + std::vector bias_zero_points( + filter.per_channel_quantization_scales.size()); + for (int i = 0; i < filter.per_channel_quantization_scales.size(); + ++i) { + bias_scale[i] = + input.scale * filter.per_channel_quantization_scales[i]; + bias_zero_points[i] = 0; + } + TensorData bias{TensorType_INT32, + {bias_size}, + /*min=*/0, + /*max=*/0, + /*scale=*/0, + /*zero_point=*/0, + true, + /*per_channel_scale=*/bias_scale, + /*per_channel_zero_point=*/bias_zero_points, + /*channel_index==*/0}; + bias_ = AddInput(bias); + } else { + // per tensor quantization. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } } output_ = AddOutput(output); @@ -338,6 +364,18 @@ TEST_P(ConvolutionOpTest, HandCalculatedFloat32) { // | 187 | 234 | 261 | 121 | EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121})); + + // Add an additional test for the multi-threaded case, ensuring stability + // under different thread counts. 
+ if (GetParam() == "MultithreadedOptimized") { + for (int i = 1; i < 4; ++i) { + m.SetNumThreads(i); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({105, 150, 183, 95, 235, 312, 357, 178, 187, + 234, 261, 121})); + } + } } TEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { @@ -1069,6 +1107,76 @@ TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterHybrid) { 0.0474))); } +class PerChannelQuantizedConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + PerChannelSymmetricQuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + PerChannelQuantizeBias(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); + } +}; + +TEST_P(ConvolutionOpTest, SimpleTest) { + PerChannelQuantizedConvolutionOpModel m( + GetRegistration(), {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1}, + {TensorType_INT8, + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + {2, 2, 2, 2}, + 0, + 0, + 0, + 0, + /*per_channel=*/true, + /*per_channel_scales=*/{1, 2}, + /*per_channel_zeros=*/{0, 0}, + /*channel_index=*/0}, + {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, + /*stride_width=*/1, /*stride_height=*/1); + m.SetInput({ + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 + }); + m.SetFilter( + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + { + 1, 2, // out channel = 0, y = 0, x = 0 + 3, 4, // out channel = 0, y = 0, x = 1 + 3, 4, // out channel = 0, y = 1, 
x = 0 + 5, 6, // out channel = 0, y = 1, x = 1 + 7, 8, // out channel = 1, y = 0, x = 0 + 5, 6, // out channel = 1, y = 0, x = 1 + 3, 4, // out channel = 1, y = 1, x = 0 + 1, 2, // out channel = 1, y = 1, x = 1 + }); + m.SetBias({3, -2}); + + // Invoke and verify output. + // output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel] + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({28.5, 64, -59.5, -46}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({56, 127, -120, -93})); +} + INSTANTIATE_TEST_SUITE_P( ConvolutionOpTest, ConvolutionOpTest, ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); diff --git a/tensorflow/lite/kernels/depthwise_conv.cc b/tensorflow/lite/kernels/depthwise_conv.cc index 3f4ae5087b267a62d4d4237a8f5f534ff346a493..a349b2790531a674be1faa40d928677a9144e265 100644 --- a/tensorflow/lite/kernels/depthwise_conv.cc +++ b/tensorflow/lite/kernels/depthwise_conv.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" @@ -58,6 +59,10 @@ struct OpData { // uint8_t these would be 0 and 255. int32_t output_activation_min; int32_t output_activation_max; + + // Per channel output multiplier and shift. 
+ std::vector per_channel_output_multiplier; + std::vector per_channel_output_shift; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -99,14 +104,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { SizeOfDimension(filter, 3)); const TfLiteType data_type = input->type; - TF_LITE_ENSURE(context, - data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); + TF_LITE_ENSURE(context, data_type == kTfLiteFloat32 || + data_type == kTfLiteUInt8 || + data_type == kTfLiteInt8); TF_LITE_ENSURE_EQ(context, output->type, data_type); TF_LITE_ENSURE_EQ(context, filter->type, data_type); if (hasBias) { bias = GetInput(context, node, kBiasTensor); - if (data_type == kTfLiteUInt8) { + if (data_type == kTfLiteUInt8 || data_type == kTfLiteInt8) { TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); } else { @@ -150,17 +156,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { filter_width, out_width); // Note that quantized inference requires that all tensors have their - // parameters set. This is usually done during quantized training. + // parameters set. This is usually done during quantized training or + // calibration. 
if (data_type != kTfLiteFloat32) { - double real_multiplier = 0.0; - TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( - context, input, filter, bias, output, &real_multiplier)); - int exponent; - QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); - data->output_shift = -exponent; - CalculateActivationRangeUint8(params->activation, output, - &data->output_activation_min, - &data->output_activation_max); + TF_LITE_ENSURE_EQ(context, filter->quantization.type, + kTfLiteAffineQuantization); + const auto* affine_quantization = + reinterpret_cast( + filter->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + const int number_channel = affine_quantization->scale->size; + data->per_channel_output_multiplier.resize(number_channel); + data->per_channel_output_shift.resize(number_channel); + TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( + context, input, filter, bias, output, params->activation, + &data->output_multiplier, &data->output_shift, + &data->output_activation_min, &data->output_activation_max, + data->per_channel_output_multiplier.data(), + data->per_channel_output_shift.data())); } TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); @@ -250,6 +264,33 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, GetTensorData(output)); } +void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + DepthwiseParams op_params; + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + 
op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.input_offset = input->params.zero_point; + op_params.weights_offset = 0; + op_params.output_offset = output->params.zero_point; + + reference_integer_ops::DepthwiseConvPerChannel( + op_params, data->per_channel_output_multiplier.data(), + data->per_channel_output_shift.data(), GetTensorShape(input), + GetTensorData(input), GetTensorShape(filter), + GetTensorData(filter), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); +} + template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = @@ -273,6 +314,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { EvalQuantized(context, node, params, data, input, filter, bias, output); break; + case kTfLiteInt8: { + EvalQuantizedPerChannel(context, node, params, data, input, filter, bias, + output); + break; + } default: context->ReportError(context, "Type %d not currently supported.", input->type); diff --git a/tensorflow/lite/kernels/depthwise_conv_test.cc b/tensorflow/lite/kernels/depthwise_conv_test.cc index 75aed4cc4a96e76f35499d3c26cf0fc25f463160..5dc513262b2d71aaab458a815aacba473d788859 100644 --- a/tensorflow/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/kernels/depthwise_conv_test.cc @@ -56,9 +56,35 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel { // This is a quantized version. The scale of 'bias' depends on the scales // of input and filter. Supposedly this is correctly set during quantized // training. - auto bias_scale = GetScale(input_) * GetScale(filter_); - TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; - bias_ = AddInput(bias); + if (filter.per_channel_quantization) { + // per channel quantization. 
+ std::vector bias_scale( + filter.per_channel_quantization_scales.size()); + std::vector bias_zero_points( + filter.per_channel_quantization_scales.size()); + for (int i = 0; i < filter.per_channel_quantization_scales.size(); + ++i) { + bias_scale[i] = + input.scale * filter.per_channel_quantization_scales[i]; + bias_zero_points[i] = 0; + } + TensorData bias{TensorType_INT32, + {bias_size}, + /*min=*/0, + /*max=*/0, + /*scale=*/0, + /*zero_point=*/0, + true, + /*per_channel_scale=*/bias_scale, + /*per_channel_zero_point=*/bias_zero_points, + /*channel_index==*/0}; + bias_ = AddInput(bias); + } else { + // per tensor quantization. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } } output_ = AddOutput(output); @@ -437,6 +463,76 @@ TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) { ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1})); } +class PerChannelQuantizedDepthwiseConvolutionOpModel + : public BaseDepthwiseConvolutionOpModel { + public: + using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + PerChannelSymmetricQuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + PerChannelQuantizeBias(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); + } +}; + +TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleTest) { + PerChannelQuantizedDepthwiseConvolutionOpModel m( + GetRegistration(), {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1}, + {TensorType_INT8, + // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel] + {1, 2, 2, 4}, + 0, + 0, + 0, + 0, + /*per_channel=*/true, + 
/*per_channel_scales=*/{1, 2, 3, 4}, + /*per_channel_zeros=*/{0, 0, 0, 0}, + /*channel_index=*/3}, + {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_VALID); + m.SetInput({ + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 + }); + m.SetFilter( + /*filter data*/ + { + // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel] + // depth multiplier = 2 + 1, 2, 3, 4, // y = 0, x = 0 + 3, 4, 5, 6, // y = 0, x = 1 + 7, 8, 5, 6, // y = 1, x = 0 + 3, 4, 1, 2, // y = 1, x = 1 + }); + m.SetBias({3, -2, 4, 6}); + + // Invoke and verify output. + // output has dimension [1 * 1 * 2 * 4] as [batch, y, x, output_channel] + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({40.5, 48, 27, 40, 0.5, -4, -24, -36}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({80, 95, 53, 79, 0, -9, -49, -73})); +} + INSTANTIATE_TEST_SUITE_P( DepthwiseConvolutionOpTest, DepthwiseConvolutionOpTest, ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); diff --git a/tensorflow/lite/kernels/dequantize_test.cc b/tensorflow/lite/kernels/dequantize_test.cc index be7caa31892a9dbb41eef2f88479c9f0051e2339..77254335fbde0ff4246af00291ccfba9ec8b0acf 100644 --- a/tensorflow/lite/kernels/dequantize_test.cc +++ b/tensorflow/lite/kernels/dequantize_test.cc @@ -30,13 +30,7 @@ class DequantizeOpModel : public SingleOpModel { public: DequantizeOpModel(TensorType type, std::initializer_list shape, float scale, int32_t zero_point) { - TensorData input_tensor_data; - input_tensor_data.type = type; - input_tensor_data.shape = shape; - input_tensor_data.min = 0; - input_tensor_data.max = 0; - input_tensor_data.scale = scale; - input_tensor_data.zero_point = zero_point; + const TensorData input_tensor_data = {type, shape, 0, 0, scale, zero_point}; input_ 
= AddInput(input_tensor_data); output_ = AddOutput({TensorType_FLOAT32, shape}); SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions, diff --git a/tensorflow/lite/kernels/eigen_support.cc b/tensorflow/lite/kernels/eigen_support.cc index e2a2c4aac9456dfae2e26d75d903c300e382b1d0..6543cc7162d66eca2cd6da6a120f4feab19531a3 100644 --- a/tensorflow/lite/kernels/eigen_support.cc +++ b/tensorflow/lite/kernels/eigen_support.cc @@ -24,8 +24,12 @@ namespace tflite { namespace eigen_support { namespace { +// For legacy reasons, we use 4 threads by default unless the thread count is +// explicitly specified by the context. +const int kDefaultNumThreadpoolThreads = 4; + #ifndef EIGEN_DONT_ALIGN -// Eigen may require buffers to be algiend to 16, 32 or 64 bytes depending on +// Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on // hardware architecture and build configurations. // If the static assertion fails, try to increase `kDefaultTensorAlignment` to // in `arena_planner.h` to 32 or 64. @@ -63,9 +67,45 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { std::unique_ptr pool_; }; +// Utility class for lazily creating an Eigen thread pool/device only when used. +class LazyEigenThreadPoolHolder { + public: + explicit LazyEigenThreadPoolHolder(int num_threads) { + SetNumThreads(num_threads); + } + + // Gets the ThreadPoolDevice, creating if necessary. + const Eigen::ThreadPoolDevice* GetThreadPoolDevice() { + if (!device_) { + thread_pool_wrapper_.reset(new EigenThreadPoolWrapper( + new Eigen::ThreadPool(target_num_threads_))); + device_.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper_.get(), + target_num_threads_)); + } + return device_.get(); + } + + // Updates the thread count, invalidating the ThreadPoolDevice if necessary. + void SetNumThreads(int num_threads) { + const int target_num_threads = + num_threads != -1 ? 
num_threads : kDefaultNumThreadpoolThreads; + if (target_num_threads_ != target_num_threads) { + target_num_threads_ = target_num_threads; + // As the device references the thread pool wrapper, destroy it first. + device_.reset(); + thread_pool_wrapper_.reset(); + } + } + + private: + int target_num_threads_ = kDefaultNumThreadpoolThreads; + // Both device_ and thread_pool_wrapper_ are lazily created. + std::unique_ptr device_; + std::unique_ptr thread_pool_wrapper_; +}; + struct RefCountedEigenContext : public TfLiteExternalContext { - std::unique_ptr thread_pool_wrapper; - std::unique_ptr device; + std::unique_ptr thread_pool_holder; int num_references = 0; }; @@ -74,24 +114,12 @@ RefCountedEigenContext* GetEigenContext(TfLiteContext* context) { context->GetExternalContext(context, kTfLiteEigenContext)); } -void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) { - int num_threads = 4; - if (context->recommended_num_threads != -1) { - num_threads = context->recommended_num_threads; - } - ptr->device.reset(); // destroy before we invalidate the thread pool - ptr->thread_pool_wrapper.reset( - new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads))); - ptr->device.reset( - new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads)); -} - TfLiteStatus Refresh(TfLiteContext* context) { SetEigenNbThreads(context->recommended_num_threads); auto* ptr = GetEigenContext(context); if (ptr != nullptr) { - InitDevice(context, ptr); + ptr->thread_pool_holder->SetNumThreads(context->recommended_num_threads); } return kTfLiteOk; @@ -108,8 +136,9 @@ void IncrementUsageCounter(TfLiteContext* context) { ptr = new RefCountedEigenContext; ptr->type = kTfLiteEigenContext; ptr->Refresh = Refresh; + ptr->thread_pool_holder.reset( + new LazyEigenThreadPoolHolder(context->recommended_num_threads)); ptr->num_references = 0; - InitDevice(context, ptr); context->SetExternalContext(context, kTfLiteEigenContext, ptr); } ptr->num_references++; @@ -134,7 +163,7 
@@ const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) { TF_LITE_FATAL( "Call to GetFromContext() not preceded by IncrementUsageCounter()"); } - return ptr->device.get(); + return ptr->thread_pool_holder->GetThreadPoolDevice(); } } // namespace eigen_support diff --git a/tensorflow/lite/kernels/eigen_support.h b/tensorflow/lite/kernels/eigen_support.h index c24ae6896a7e9783ddd32bc510881ccc1a5d27bf..7e052e1f83cc9ddabac84d318d1639f478b6fb01 100644 --- a/tensorflow/lite/kernels/eigen_support.h +++ b/tensorflow/lite/kernels/eigen_support.h @@ -32,6 +32,11 @@ void IncrementUsageCounter(TfLiteContext* context); // usages all temporary Eigen objects will be deleted. void DecrementUsageCounter(TfLiteContext* context); +// Fetch the ThreadPoolDevice associated with the provided context. +// +// Note: The caller must ensure that |IncrementUsageCounter()| has already been +// called. Moreover, it is *not* safe to cache the returned device; it may be +// invalidated if the context thread count changes. 
const EigenForTFLite::ThreadPoolDevice* GetThreadPoolDevice( TfLiteContext* context); diff --git a/tensorflow/lite/kernels/elementwise.cc b/tensorflow/lite/kernels/elementwise.cc index a79388b900eb89b56a4d18f887dbe52e84fb123f..1cc188ae5f7bfe91bee48c60b692d9dca2b7cf0e 100644 --- a/tensorflow/lite/kernels/elementwise.cc +++ b/tensorflow/lite/kernels/elementwise.cc @@ -83,6 +83,10 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { return EvalNumeric(context, node, std::sin); } +TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) { + return EvalNumeric(context, node, std::cos); +} + TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { return EvalNumeric(context, node, std::log); } @@ -122,6 +126,14 @@ TfLiteRegistration* Register_SIN() { return &r; } +TfLiteRegistration* Register_COS() { + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::CosEval}; + return &r; +} + TfLiteRegistration* Register_LOG() { static TfLiteRegistration r = { /*init=*/nullptr, /*free=*/nullptr, diff --git a/tensorflow/lite/kernels/elementwise_test.cc b/tensorflow/lite/kernels/elementwise_test.cc index 7d24320081257925508b2aa53503c1cf71d0e913..89f2a506f0cc00df021d8b5113174833df7e33cb 100644 --- a/tensorflow/lite/kernels/elementwise_test.cc +++ b/tensorflow/lite/kernels/elementwise_test.cc @@ -65,6 +65,15 @@ TEST(ElementWise, Sin) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, Cos) { + ElementWiseOpFloatModel m(BuiltinOperator_COS, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {0, 3.1415926, -3.1415926, 1}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({1, -1, -1, 0.54030}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + TEST(ElementWise, Log) { ElementWiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {1, 3.1415926, 1, 1}); diff 
--git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index dfc9550ed600ac08407c4f07b6ad7d0be26bfe10..55cde983abccdd4c20c9c1cd24d44883d5b883cd 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/kernels/gemm_support.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -40,9 +41,8 @@ namespace fully_connected { // This file has four implementations of FullyConnected enum KernelType { kReference, - kGenericOptimized, // Neon-free - kNeonOptimized, - kPie, // Used by the PIE team + kGenericOptimized, + kLegacyPie, // Legacy path used by the PIE team and related clients. 
}; struct OpData { @@ -212,7 +212,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE(context, filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8); - TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32); + if (bias) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32); + } TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); int total_input_size = 1; @@ -286,6 +288,27 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, macro_name(target_namespace, kRelu6); \ } +namespace { +void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* output, + gemmlowp::GemmContext* gemm_context) { + FullyConnectedParams op_params; + op_params.input_offset = -input->params.zero_point; + op_params.weights_offset = -filter->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + reference_integer_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), gemm_context); +} +} // namespace + template TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, @@ -314,11 +337,20 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, GetTensorShape(output), GetTensorData(output), \ gemm_context); \ } - if (kernel_type == kReference) { + // Only the Pie path supports quantized models and float inputs/outputs. 
+ if (input->type == kTfLiteFloat32) { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1); + return EvalHybrid(context, node, params, data, input, filter, bias, + input_quantized, scaling_factors, output); + } else if (kernel_type == kReference) { switch (output->type) { case kTfLiteUInt8: TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t); break; + case kTfLiteInt8: + FullyConnectedInt8(data, input, filter, bias, output, gemm_context); + break; case kTfLiteInt16: TF_LITE_FULLY_CONNECTED(reference_ops, int16_t); break; @@ -328,17 +360,14 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, "Quantized FullyConnected expects output data type uint8 or int16"); return kTfLiteError; } - } else if (kernel_type == kPie && input->type == kTfLiteFloat32) { - // Pie currently only supports quantized models and float inputs/outputs. - TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); - TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1); - return EvalHybrid(context, node, params, data, input, filter, bias, - input_quantized, scaling_factors, output); } else { switch (output->type) { case kTfLiteUInt8: TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t); break; + case kTfLiteInt8: + FullyConnectedInt8(data, input, filter, bias, output, gemm_context); + break; case kTfLiteInt16: TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t); break; @@ -418,7 +447,7 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, } if (kernel_type == kReference) { TF_LITE_FULLY_CONNECTED(reference_ops); - } else if (kernel_type == kPie) { + } else if (kernel_type == kLegacyPie) { return EvalPie(context, node, params, data, input, filter, bias, output); } else { TF_LITE_FULLY_CONNECTED(optimized_ops); @@ -488,13 +517,6 @@ TfLiteRegistration* Register_FULLY_CONNECTED_REF() { return &r; } -TfLiteRegistration* 
Register_FULLY_CONNECTED_NEON_OPT() { - static TfLiteRegistration r = { - fully_connected::Init, fully_connected::Free, fully_connected::Prepare, - fully_connected::Eval}; - return &r; -} - TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() { static TfLiteRegistration r = { fully_connected::Init, fully_connected::Free, fully_connected::Prepare, @@ -502,24 +524,16 @@ TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() { return &r; } +// Legacy path for PIE clients. TfLiteRegistration* Register_FULLY_CONNECTED_PIE() { - static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free, - fully_connected::Prepare, - fully_connected::Eval}; + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; return &r; } TfLiteRegistration* Register_FULLY_CONNECTED() { - // TODO(ahentz): We don't have a dedicated quantized version of the PIE - // kernel. For now, the quantized version just defer to the corresponding - // optimized MINI kernel. At some point we will allow different libraries to - // be built with different kernels, but for now we have to pick one here. 
- return Register_FULLY_CONNECTED_PIE(); -#ifdef USE_NEON - return Register_FULLY_CONNECTED_NEON_OPT(); -#else return Register_FULLY_CONNECTED_GENERIC_OPT(); -#endif } } // namespace builtin diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index 03f4ea71430f5d578288d913e8ba1d0222467882..7a69b999f1f2207dad865616f33f07e71da8b54f 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -33,7 +33,6 @@ namespace ops { namespace builtin { TfLiteRegistration* Register_FULLY_CONNECTED_REF(); -TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT(); TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT(); TfLiteRegistration* Register_FULLY_CONNECTED_PIE(); @@ -137,6 +136,7 @@ class BaseFullyConnectedOpModel : public SingleOpModel { BaseFullyConnectedOpModel( TfLiteRegistration* registration, int units, int batches, const TensorData& input, const TensorData& output = {TensorType_FLOAT32}, + bool bias_tensor_optional = false, ActivationFunctionType activation_func = ActivationFunctionType_RELU, FullyConnectedOptionsWeightsFormat weights_format = FullyConnectedOptionsWeightsFormat_DEFAULT) @@ -151,7 +151,9 @@ class BaseFullyConnectedOpModel : public SingleOpModel { weights_ = AddInput({input.type, {units_, input_size_}, input.min, input.max}); - if (input.type == TensorType_FLOAT32) { + if (bias_tensor_optional) { + bias_ = AddNullInput(); + } else if (input.type == TensorType_FLOAT32) { bias_ = AddInput({TensorType_FLOAT32, {units_}}); } else { // This is a quantized version. The scale of 'bias' depends on the scales @@ -173,7 +175,9 @@ class BaseFullyConnectedOpModel : public SingleOpModel { .Union()); resolver_ = absl::make_unique( BuiltinOperator_FULLY_CONNECTED, registration); - BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + BuildInterpreter( + {GetShape(input_), GetShape(weights_), + (bias_ == kOptionalTensor) ? 
std::vector() : GetShape(bias_)}); } int input_size() { return input_size_; } @@ -216,9 +220,12 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { void SetBias(const std::vector& data) { QuantizeAndPopulate(bias_, data); } + template void SetWeights(const std::vector& data) { - QuantizeAndPopulate(weights_, data); + QuantizeAndPopulate(weights_, data); } + + template void ShuffleAndSetWeights(const std::vector& data, int input_depth, int output_depth) { std::vector shuffled_data(data.size()); @@ -237,15 +244,17 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { } TfLiteTensor* t = interpreter_->tensor(weights_); auto quantized_data = - Quantize(shuffled_data, t->params.scale, t->params.zero_point); - for (uint8_t& q : quantized_data) { + Quantize(shuffled_data, t->params.scale, t->params.zero_point); + for (T& q : quantized_data) { q ^= 0x80; } PopulateTensor(weights_, 0, quantized_data.data(), quantized_data.data() + quantized_data.size()); } + + template void SetInput(const std::vector& data) { - QuantizeAndPopulate(input_, data); + QuantizeAndPopulate(input_, data); } template @@ -320,7 +329,6 @@ class HybridFullyConnectedOpModel : public SingleOpModel { const auto kKernelMap = new std::map({ {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()}, - {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()}, {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()}, {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()}, }); @@ -334,7 +342,6 @@ class FloatFullyConnectedOpTest : public SingleOpTest { const auto kKernelMapNoPie = new std::map({ {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()}, - {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()}, {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()}, }); @@ -345,15 +352,18 @@ class QuantizedFullyConnectedOpTest : public SingleOpTest { } }; -const auto kKernelMapPie = new std::map({ 
+const auto kKernelMapHybrid = new std::map({ {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()}, + // Only Pie supports the hybrid path, so the optimized kernel should fall + // back to the Pie path in such cases. + {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()}, }); // Hybrid mode is used by the Pie quantized kernel. class HybridFullyConnectedOpTest : public SingleOpTest { protected: const std::map& GetKernelMap() override { - return *kKernelMapPie; + return *kKernelMapHybrid; } }; @@ -397,21 +407,42 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest2) { EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9)); } -TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { +TEST(FloatFullyConnectedOpTest, SimpleTestNoBias) { + // The optimized kernel assumes that the bias is specified. + FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(), + /*units=*/1, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 2}}, + /*output=*/{TensorType_FLOAT32}, + /*bias_tensor_optional=*/true); + m.SetWeights({ + 2, 4, // u = 0 + }); + + m.SetInput({ + 1, 2, // b = 0 + 2, 1, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(10, 8)); +} + +TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) { QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2, /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, /*output=*/{TensorType_UINT8, {}, -127, 128}); // input_product_scale < output_scale was not true. 
- m.SetWeights({ + m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 }); m.SetBias({1, 2, 3}); - m.SetInput({ + m.SetInput({ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 }); @@ -427,22 +458,48 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { ElementsAre(151, 152, 153, 185, 186, 187)); } +TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8) { + QuantizedFullyConnectedOpModel m( + ops::builtin::Register_FULLY_CONNECTED_REF(), /*units=*/3, /*batches*/ 2, + /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_INT8, {}, -127, 128}); + + // input_product_scale < output_scale was not true. + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60}))); + EXPECT_THAT(m.GetOutput(), ElementsAre(23, 24, 25, 57, 58, 59)); +} + TEST_P(QuantizedFullyConnectedOpTest, - SimpleTestQuantizedOutputMultiplierGreaterThan1) { + SimpleTestQuantizedOutputMultiplierGreaterThan1Uint8) { // real_multiplier = 2. 
QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2, /*input=*/{TensorType_UINT8, {2, 10}, -127, 128}, /*output=*/{TensorType_UINT8, {}, -63.5, 64}); - m.SetWeights({ + m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 }); m.SetBias({1, 2, 3}); - m.SetInput({ + m.SetInput({ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 }); @@ -458,6 +515,36 @@ TEST_P(QuantizedFullyConnectedOpTest, ElementsAre(175, 177, 179, 243, 245, 247)); } +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTestQuantizedOutputMultiplierGreaterThan1Int8) { + // real_multiplier = 2. + QuantizedFullyConnectedOpModel m( + ops::builtin::Register_FULLY_CONNECTED_REF(), /*units=*/3, /*batches*/ 2, + /*input=*/{TensorType_INT8, {2, 10}, -127, 128}, + /*output=*/{TensorType_INT8, {}, -63.5, 64}); + + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // first batch + 58, 59, 60, // second batch + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(47, 49, 51, 115, 117, 119)); +} + void SimpleTestQuantizedInt16OutputCase( TfLiteRegistration* registration, int input_depth, int output_depth, int batches, FullyConnectedOptionsWeightsFormat weights_format) { @@ -477,6 +564,7 @@ void SimpleTestQuantizedInt16OutputCase( /*input=*/ {TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax}, /*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax}, + /*bias_tensor_optional=*/false, /*activation_func=*/ActivationFunctionType_NONE, weights_format); std::mt19937 random_engine; @@ -492,7 +580,7 @@ void SimpleTestQuantizedInt16OutputCase( 
// and set the (possibly shuffled) weights. switch (weights_format) { case FullyConnectedOptionsWeightsFormat_DEFAULT: - m.SetWeights(weights_data); + m.SetWeights(weights_data); break; case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: // The shuffled path currently supports only a restrictive subset of @@ -500,7 +588,7 @@ void SimpleTestQuantizedInt16OutputCase( CHECK_EQ(input_depth % 16, 0); CHECK_EQ(output_depth % 4, 0); CHECK(batches == 1 || batches == 4); - m.ShuffleAndSetWeights(weights_data, input_depth, output_depth); + m.ShuffleAndSetWeights(weights_data, input_depth, output_depth); break; default: LOG(FATAL) << "Unhandled weights format"; @@ -522,7 +610,7 @@ void SimpleTestQuantizedInt16OutputCase( } m.SetBias(bias_data); - m.SetInput(input_data); + m.SetInput(input_data); m.Invoke(); @@ -664,21 +752,21 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) { })); } -TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) { +TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantizedUint8) { QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches=*/2, /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64}, /*output=*/{TensorType_UINT8, {}, -127, 128}); // input_product_scale < output_scale was not true. - m.SetWeights({ + m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 }); m.SetBias({1, 2, 3}); - m.SetInput({ + m.SetInput({ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 }); @@ -695,21 +783,21 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) { } TEST_P(QuantizedFullyConnectedOpTest, - SimpleTest4dInputQuantizedOutputMultiplierGreaterThan1) { + SimpleTest4dInputQuantizedOutputMultiplierGreaterThan1Uint8) { // real_multiplier = 2. 
QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches=*/2, /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -127, 128}, /*output=*/{TensorType_UINT8, {}, -63.5, 64}); - m.SetWeights({ + m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 }); m.SetBias({1, 2, 3}); - m.SetInput({ + m.SetInput({ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 }); diff --git a/tensorflow/lite/kernels/gather.cc b/tensorflow/lite/kernels/gather.cc index f205daae1343cb0abecc95e7d1b280c10f55d897..54d05adbcf161a2af88bea4a0de1eec06e70c09a 100644 --- a/tensorflow/lite/kernels/gather.cc +++ b/tensorflow/lite/kernels/gather.cc @@ -57,6 +57,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { switch (input->type) { case kTfLiteFloat32: case kTfLiteUInt8: + case kTfLiteInt8: case kTfLiteInt64: case kTfLiteInt32: break; @@ -135,6 +136,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return Gather(*params, input, positions, output); case kTfLiteUInt8: return Gather(*params, input, positions, output); + case kTfLiteInt8: + return Gather(*params, input, positions, output); case kTfLiteInt32: return Gather(*params, input, positions, output); case kTfLiteInt64: @@ -153,6 +156,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return Gather(*params, input, positions, output); case kTfLiteUInt8: return Gather(*params, input, positions, output); + case kTfLiteInt8: + return Gather(*params, input, positions, output); case kTfLiteInt32: return Gather(*params, input, positions, output); case kTfLiteInt64: diff --git a/tensorflow/lite/kernels/gather_nd.cc b/tensorflow/lite/kernels/gather_nd.cc new file mode 100644 index 0000000000000000000000000000000000000000..20e98652ee57ec7b6b86a20cbc474b4b9c29b2aa --- /dev/null +++ b/tensorflow/lite/kernels/gather_nd.cc @@ -0,0 +1,154 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace gather_nd { +constexpr int kParams = 0; +constexpr int kIndices = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* params = GetInput(context, node, kParams); + const TfLiteTensor* indices = GetInput(context, node, kIndices); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (params->type) { + case kTfLiteFloat32: + case kTfLiteUInt8: + case kTfLiteInt8: + case kTfLiteInt64: + case kTfLiteInt32: + break; + default: + context->ReportError( + context, "Params of type '%s' are not supported by gather_nd.", + TfLiteTypeGetName(params->type)); + return kTfLiteError; + } + switch (indices->type) { + case kTfLiteInt64: + case kTfLiteInt32: + break; + default: + context->ReportError( + context, "Indices of type '%s' are not supported by gather_nd.", + 
TfLiteTypeGetName(indices->type)); + return kTfLiteError; + } + + const int params_rank = NumDimensions(params); + const int indices_rank = NumDimensions(indices); + const int indices_nd = SizeOfDimension(indices, indices_rank - 1); + if (params_rank < 1) { + context->ReportError(context, "Params must be at least a vector."); + return kTfLiteError; + } + if (indices_rank < 1) { + context->ReportError(context, "Indices must be at least a vector."); + return kTfLiteError; + } + if (indices_nd > params_rank) { + context->ReportError( + context, "Index innermost dimension length must be <= params rank."); + return kTfLiteError; + } + + // Assign to output the input type. + output->type = params->type; + + // The result shape is + // indices.shape[:-1] + params.shape[indices.shape[-1]:] + const int output_rank = indices_rank + params_rank - indices_nd - 1; + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank); + int output_index = 0; + for (int i = 0; i < indices_rank - 1; ++i) { + output_shape->data[output_index++] = indices->dims->data[i]; + } + for (int i = indices_nd; i < params_rank; ++i) { + output_shape->data[output_index++] = params->dims->data[i]; + } + return context->ResizeTensor(context, output, output_shape); +} + +template +TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices, + TfLiteTensor* output) { + reference_ops::GatherNd( + GetTensorShape(params), GetTensorData(params), + GetTensorShape(indices), GetTensorData(indices), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; +} + +template +TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params, + const TfLiteTensor* indices, TfLiteTensor* output) { + switch (params->type) { + case kTfLiteFloat32: + return GatherNd(params, indices, output); + case kTfLiteUInt8: + return GatherNd(params, indices, output); + case kTfLiteInt8: + return GatherNd(params, indices, output); + case kTfLiteInt32: + return GatherNd(params, indices, 
output); + case kTfLiteInt64: + return GatherNd(params, indices, output); + default: + context->ReportError(context, + "Params type '%s' are not supported by gather_nd.", + TfLiteTypeGetName(params->type)); + return kTfLiteError; + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* params = GetInput(context, node, kParams); + const TfLiteTensor* indices = GetInput(context, node, kIndices); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (indices->type) { + case kTfLiteInt32: + return EvalGatherNd(context, params, indices, output); + case kTfLiteInt64: + return EvalGatherNd(context, params, indices, output); + default: + context->ReportError( + context, "Indices of type '%s' are not supported by gather_nd.", + TfLiteTypeGetName(indices->type)); + return kTfLiteError; + } +} +} // namespace gather_nd + +TfLiteRegistration* Register_GATHER_ND() { + static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr, + gather_nd::Prepare, gather_nd::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/gather_nd_test.cc b/tensorflow/lite/kernels/gather_nd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5e93efb8ff468f9e1cd6d2cd8c4343c0fe62e79 --- /dev/null +++ b/tensorflow/lite/kernels/gather_nd_test.cc @@ -0,0 +1,323 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class GatherNdOpModel : public SingleOpModel { + public: + GatherNdOpModel(const TensorData& params, const TensorData& indices) { + params_ = AddInput(params); + indices_ = AddInput(indices); + output_ = AddOutput(params.type); + SetBuiltinOp(BuiltinOperator_GATHER_ND, BuiltinOptions_GatherNdOptions, + CreateGatherNdOptions(builder_).Union()); + BuildInterpreter({GetShape(params_), GetShape(indices_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(params_, data); + } + + template + void SetPositions(std::initializer_list data) { + PopulateTensor(indices_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int params_; + int indices_; + int output_; +}; + +TEST(GatherNdOpTest, ElementIndexingIntoMatrix) { + GatherNdOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2, 2}}); + m.SetInput({1.1, 1.2, 2.1, 2.2}); + m.SetPositions({0, 0, 1, 1}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1.1, 2.2})); +} + +TEST(GatherNdOpTest, SliceIndexingIntoMatrix) { + GatherNdOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2, 1}}); + m.SetInput({1.1, 1.2, 2.1, 2.2}); + m.SetPositions({1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.1, 2.2, 1.1, 1.2})); +} + +TEST(GatherNdOpTest, BatchedIndexingIntoMatrix1) { + GatherNdOpModel m({TensorType_FLOAT32, {2, 2}}, + 
{TensorType_INT32, {2, 1, 1}}); + m.SetInput({1.1, 1.2, 2.1, 2.2}); + m.SetPositions({1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.1, 2.2, 1.1, 1.2})); +} + +TEST(GatherNdOpTest, BatchedIndexingIntoMatrix2) { + GatherNdOpModel m({TensorType_FLOAT32, {2, 2}}, + {TensorType_INT32, {2, 1, 2}}); + m.SetInput({1.1, 1.2, 2.1, 2.2}); + m.SetPositions({0, 0, 1, 1}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1.1, 2.2})); +} + +TEST(GatherNdOpTest, DuplicateIndexingIntoMatrix) { + GatherNdOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2, 2}}); + m.SetInput({1.1, 1.2, 2.1, 2.2}); + m.SetPositions({0, 0, 0, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1.1, 1.1})); +} + +TEST(GatherNdOpTest, ElementIndexingIntoRank3Tensor) { + GatherNdOpModel m({TensorType_FLOAT32, {3, 2, 3}}, + {TensorType_INT32, {1, 2, 3}}); + m.SetInput({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, // + 3.1, 3.2, -3.3, -4.1, -4.2, 4.3, // + 5.1, -5.2, 5.3, 6.1, -6.2, 6.3}); + m.SetPositions({0, 0, 1, 1, 1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.2, -4.1})); +} + +TEST(GatherNdOpTest, SliceIndexingIntoRank3Tensor) { + GatherNdOpModel m({TensorType_FLOAT32, {3, 2, 3}}, + {TensorType_INT32, {2, 1}}); + m.SetInput({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, // + 3.1, 3.2, -3.3, -4.1, -4.2, 4.3, // + 5.1, -5.2, 5.3, 6.1, -6.2, 6.3}); + m.SetPositions({0, 2}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, 5.1, -5.2, 5.3, + 6.1, -6.2, 6.3})); +} + +TEST(GatherNdOpTest, BatchedIndexingIntoRank3Tensor1) { + GatherNdOpModel m({TensorType_FLOAT32, {3, 2, 3}}, + {TensorType_INT32, {2, 1, 3}}); + m.SetInput({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, // + 3.1, 3.2, -3.3, -4.1, -4.2, 4.3, // + 5.1, -5.2, 5.3, 6.1, -6.2, 6.3}); + m.SetPositions({0, 0, 1, 1, 1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.2, -4.1})); +} + +TEST(GatherNdOpTest, 
BatchedIndexingIntoRank3Tensor2) { + GatherNdOpModel m({TensorType_FLOAT32, {3, 2, 3}}, + {TensorType_INT32, {2, 1, 1}}); + m.SetInput({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, // + 3.1, 3.2, -3.3, -4.1, -4.2, 4.3, // + 5.1, -5.2, 5.3, 6.1, -6.2, 6.3}); + m.SetPositions({1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({3.1, 3.2, -3.3, -4.1, -4.2, 4.3, 1.1, -1.2, 1.3, + -2.1, 2.2, 2.3})); +} + +TEST(GatherNdOpTest, BatchedIndexingIntoRank3Tensor3) { + GatherNdOpModel m({TensorType_FLOAT32, {3, 2, 3}}, + {TensorType_INT32, {2, 2, 2}}); + m.SetInput({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, // + 3.1, 3.2, -3.3, -4.1, -4.2, 4.3, // + 5.1, -5.2, 5.3, 6.1, -6.2, 6.3}); + m.SetPositions({0, 1, 1, 0, 0, 0, 2, 1}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-2.1, 2.2, 2.3, 3.1, 3.2, -3.3, 1.1, -1.2, 1.3, + 6.1, -6.2, 6.3})); +} + +TEST(GatherNdOpTest, BatchedIndexingIntoRank3Tensor4) { + GatherNdOpModel m({TensorType_FLOAT32, {3, 2, 3}}, + {TensorType_INT32, {2, 2, 3}}); + m.SetInput({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, // + 3.1, 3.2, -3.3, -4.1, -4.2, 4.3, // + 5.1, -5.2, 5.3, 6.1, -6.2, 6.3}); + m.SetPositions({0, 0, 1, 1, 0, 1, 1, 1, 2, 2, 1, 2}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.2, 3.2, 4.3, 6.3})); +} + +TEST(GatherNdOpTest, DuplicateIndexingIntoRank3Tensor) { + GatherNdOpModel m({TensorType_FLOAT32, {3, 2, 3}}, + {TensorType_INT32, {2, 2}}); + m.SetInput({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, // + 3.1, 3.2, -3.3, -4.1, -4.2, 4.3, // + 5.1, -5.2, 5.3, 6.1, -6.2, 6.3}); + m.SetPositions({0, 1, 0, 1}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-2.1, 2.2, 2.3, -2.1, 2.2, 2.3})); +} + +TEST(GatherNdOpTest, Float32Int32) { + GatherNdOpModel m({TensorType_FLOAT32, {3, 2, 3}}, + {TensorType_INT32, {2, 2}}); + m.SetInput({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, // + 3.1, 3.2, -3.3, -4.1, -4.2, 4.3, // + 5.1, -5.2, 5.3, 6.1, -6.2, 6.3}); + m.SetPositions({0, 1, 1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), 
+ ElementsAreArray({-2.1, 2.2, 2.3, 3.1, 3.2, -3.3})); +} + +TEST(GatherNdOpTest, Float32Int64) { + GatherNdOpModel m({TensorType_FLOAT32, {3, 2, 3}}, + {TensorType_INT64, {2, 2}}); + m.SetInput({1.1, -1.2, 1.3, -2.1, 2.2, 2.3, // + 3.1, 3.2, -3.3, -4.1, -4.2, 4.3, // + 5.1, -5.2, 5.3, 6.1, -6.2, 6.3}); + m.SetPositions({0LL, 1LL, 1LL, 0LL}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-2.1, 2.2, 2.3, 3.1, 3.2, -3.3})); +} + +TEST(GatherNdOpTest, Int32Int32) { + GatherNdOpModel m({TensorType_INT32, {3, 2, 3}}, {TensorType_INT32, {2, 2}}); + m.SetInput({1, -1, 1, -2, 2, 2, // + 3, 3, -3, -4, -4, 4, // + 5, -5, 5, 6, -6, 6}); + m.SetPositions({0, 1, 1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-2, 2, 2, 3, 3, -3})); +} + +TEST(GatherNdOpTest, Int32Int64) { + GatherNdOpModel m({TensorType_INT32, {3, 2, 3}}, {TensorType_INT64, {2, 2}}); + m.SetInput({1, -1, 1, -2, 2, 2, // + 3, 3, -3, -4, -4, 4, // + 5, -5, 5, 6, -6, 6}); + m.SetPositions({0LL, 1LL, 1LL, 0LL}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-2, 2, 2, 3, 3, -3})); +} + +TEST(GatherNdOpTest, Uint8Int32) { + GatherNdOpModel m({TensorType_UINT8, {3, 2, 3}}, {TensorType_INT32, {2, 2}}); + m.SetInput({1, 1, 1, 2, 2, 2, // + 3, 3, 3, 4, 4, 4, // + 5, 5, 5, 6, 6, 6}); + m.SetPositions({0, 1, 1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 2, 2, 3, 3, 3})); +} + +TEST(GatherNdOpTest, Uint8Int64) { + GatherNdOpModel m({TensorType_UINT8, {3, 2, 3}}, {TensorType_INT64, {2, 2}}); + m.SetInput({1, 1, 1, 2, 2, 2, // + 3, 3, 3, 4, 4, 4, // + 5, 5, 5, 6, 6, 6}); + m.SetPositions({0, 1, 1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 2, 2, 3, 3, 3})); +} + +TEST(GatherNdOpTest, Int8Int32) { + GatherNdOpModel m({TensorType_INT8, {3, 2, 3}}, {TensorType_INT32, {2, 2}}); + m.SetInput({1, -1, 1, -2, 2, 2, // + 3, 3, -3, -4, -4, 4, // + 5, -5, 5, 6, -6, 6}); + m.SetPositions({0, 1, 1, 0}); + m.Invoke(); + + 
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-2, 2, 2, 3, 3, -3})); +} + +TEST(GatherNdOpTest, Int8Int64) { + GatherNdOpModel m({TensorType_INT8, {3, 2, 3}}, {TensorType_INT64, {2, 2}}); + m.SetInput({1, -1, 1, -2, 2, 2, // + 3, 3, -3, -4, -4, 4, // + 5, -5, 5, 6, -6, 6}); + m.SetPositions({0LL, 1LL, 1LL, 0LL}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-2, 2, 2, 3, 3, -3})); +} + +TEST(GatherNdOpTest, Int64Int32) { + GatherNdOpModel m({TensorType_INT64, {3, 2, 3}}, {TensorType_INT32, {2, 2}}); + m.SetInput({1LL, -1LL, 1LL, -2LL, 2LL, 2LL, // + 3LL, 3LL, -3LL, -4LL, -4LL, 4LL, // + 5LL, -5LL, 5LL, 6LL, -6LL, 6LL}); + m.SetPositions({0, 1, 1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-2LL, 2LL, 2LL, 3LL, 3LL, -3LL})); +} + +TEST(GatherNdOpTest, Int64Int64) { + GatherNdOpModel m({TensorType_INT64, {3, 2, 3}}, {TensorType_INT64, {2, 2}}); + m.SetInput({1LL, -1LL, 1LL, -2LL, 2LL, 2LL, // + 3LL, 3LL, -3LL, -4LL, -4LL, 4LL, // + 5LL, -5LL, 5LL, 6LL, -6LL, 6LL}); + m.SetPositions({0LL, 1LL, 1LL, 0LL}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-2LL, 2LL, 2LL, 3LL, 3LL, -3LL})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/gather_test.cc b/tensorflow/lite/kernels/gather_test.cc index 7b5f84348903a3cc436f1bd6cf32b3175b2f5815..8fc6bd173da831d63ee9eea364b7f352ea679af3 100644 --- a/tensorflow/lite/kernels/gather_test.cc +++ b/tensorflow/lite/kernels/gather_test.cc @@ -96,6 +96,15 @@ TEST(GatherOpTest, Test0DIndexWith0DResult) { EXPECT_TRUE(m.GetOutputShape().empty()); } +TEST(GatherOpTest, Test1DInput1DIndex) { + GatherOpModel m({TensorType_FLOAT32, {3}}, {TensorType_INT32, {1}}); + m.SetInput({1.0, 3.0, 5.0}); + m.SetPositions({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.0}))); + 
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); +} + TEST(GatherOpTest, Test2DIndexWith2DResult) { GatherOpModel m({TensorType_FLOAT32, {3}}, {TensorType_INT32, {1, 2}}); m.SetInput({1.0, 2.0, 3.0}); @@ -137,6 +146,29 @@ TEST(FloatGatherOpTest, Axis1) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3})); } +TEST(FloatGatherOpTest, Axis10DIndex) { + const int axis = 1; + GatherOpModel m({TensorType_FLOAT32, {1, 3, 2}}, {TensorType_INT32, {}}, + axis); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetPositions({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 4}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); +} + +TEST(FloatGatherOpTest, Axis1Slice) { + const int axis = 1; + GatherOpModel m({TensorType_FLOAT32, {1, 4, 2}}, {TensorType_INT32, {2}}, + axis); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.SetPositions({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({7, 8, 3, 4}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2})); +} + TEST(FloatGatherOpTest, LastAxis) { const int axis = -1; GatherOpModel m({TensorType_FLOAT32, {1, 2, 3}}, {TensorType_INT32, {2}}, @@ -149,6 +181,17 @@ TEST(FloatGatherOpTest, LastAxis) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2})); } +TEST(FloatGatherOpTest, LastAxis0DIndex) { + const int axis = -1; + GatherOpModel m({TensorType_FLOAT32, {1, 2, 3}}, {TensorType_INT32, {}}, + axis); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetPositions({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 6}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); +} + TEST(TypesGatherOpTest, Float32Int32) { GatherOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2}}); m.SetInput({13.3, -13.4, -1.4, 1.5}); @@ -205,6 +248,24 @@ TEST(TypesGatherOpTest, Uint8Int64) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({14, 15, 133, 134})); } +TEST(TypesGatherOpTest, Int8Int32) { + GatherOpModel 
m({TensorType_INT8, {2, 2}}, {TensorType_INT32, {2}}); + m.SetInput({-13, -120, 14, 15}); + m.SetPositions({1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({14, 15, -13, -120})); +} + +TEST(TypesGatherOpTest, Int8Int64) { + GatherOpModel m({TensorType_INT8, {2, 2}}, {TensorType_INT64, {2}}); + m.SetInput({-13, -120, 14, 15}); + m.SetPositions({1LL, 0LL}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({14, 15, -13, -120})); +} + TEST(TypesGatherOpTest, Int64Int32) { GatherOpModel m({TensorType_INT64, {2, 2}}, {TensorType_INT32, {2}}); m.SetInput({-(1LL << 34), 134LL, 14LL, 15LL}); diff --git a/tensorflow/lite/kernels/if.cc b/tensorflow/lite/kernels/if.cc new file mode 100644 index 0000000000000000000000000000000000000000..1bd394e980073b73674ca972d28fafe04f7b8adf --- /dev/null +++ b/tensorflow/lite/kernels/if.cc @@ -0,0 +1,200 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace if_kernel { + +struct OpData { + int then_subgraph_index; + int else_subgraph_index; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + op_data->then_subgraph_index = m["then_subgraph_index"].AsInt32(); + op_data->else_subgraph_index = m["else_subgraph_index"].AsInt32(); + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const OpData* op_data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE(context, node->inputs->size > 0); + + // The first input is the condition. + const TfLiteTensor* cond = GetInput(context, node, 0); + // Currently only bool is supported. + // TODO(ycling): Support other types since TensorFlow also support + // non-bool types as condition. + TF_LITE_ENSURE_EQ(context, cond->type, kTfLiteBool); + TF_LITE_ENSURE_EQ(context, NumElements(cond), 1); + + // The first input of the node is the condition. The rest of inputs are + // passed to the branch subgraphs. Therefore, the number of subgraph inputs + // will be the number of node inputs - 1. 
+ int num_inputs = node->inputs->size - 1; + int num_outputs = node->outputs->size; + + Subgraph* this_subgraph = reinterpret_cast(context->impl_); + auto* subgraphs = this_subgraph->GetSubgraphs(); + TF_LITE_ENSURE(context, op_data->then_subgraph_index < subgraphs->size()); + TF_LITE_ENSURE(context, op_data->else_subgraph_index < subgraphs->size()); + + Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get(); + Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get(); + + for (auto* subgraph : {then_subgraph, else_subgraph}) { + TF_LITE_ENSURE_EQ(context, num_inputs, subgraph->inputs().size()); + TF_LITE_ENSURE_EQ(context, num_outputs, subgraph->outputs().size()); + } + + bool has_dynamic_output_tensors = false; + for (auto* subgraph : {then_subgraph, else_subgraph}) { + for (int i = 0; i < num_inputs; ++i) { + // The first input of the node is the condition. The indices of the inputs + // passed to the subgraphs are offset by 1. + const TfLiteTensor* input = GetInput(context, node, i + 1); + std::vector dims(input->dims->data, + input->dims->data + input->dims->size); + subgraph->ResizeInputTensor(i, dims); + TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]); + TF_LITE_ENSURE_EQ(context, input->type, subgraph_input->type); + } + // Note: The `Prepare` function is responsible to run `AllocateTensors` on + // both subgraphs. It's intentionally not to break out of the loop when + // finding a dynamic output tensor. + TF_LITE_ENSURE_OK(context, subgraph->AllocateTensors()); + has_dynamic_output_tensors |= subgraph->HasDynamicTensors(); + } + + if (!has_dynamic_output_tensors) { + for (int i = 0; i < num_outputs; ++i) { + TfLiteTensor* then_output = + then_subgraph->tensor(then_subgraph->outputs()[i]); + TfLiteTensor* else_output = + else_subgraph->tensor(else_subgraph->outputs()[i]); + // If the 2 subgraphs have static but different output shapes, the output + // tensors of the IF op have dynamic sizes. 
+ if (!TfLiteIntArrayEqual(then_output->dims, else_output->dims)) { + has_dynamic_output_tensors = true; + break; + } + } + } + + for (int i = 0; i < num_outputs; ++i) { + TfLiteTensor* output = GetOutput(context, node, i); + if (has_dynamic_output_tensors) { + SetTensorToDynamic(output); + } else { + // When there's no dynamic output tensors, the 2 subgraph has exactly + // the same static sized outputs. + TfLiteTensor* then_output = + then_subgraph->tensor(then_subgraph->outputs()[i]); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(then_output->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + } + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const OpData* op_data = reinterpret_cast(node->user_data); + + const TfLiteTensor* cond = GetInput(context, node, 0); + bool cond_value = cond->data.b[0]; + + Subgraph* this_subgraph = reinterpret_cast(context->impl_); + auto* subgraphs = this_subgraph->GetSubgraphs(); + + // Currently we copy the input / output between the subgraphs. This isn't + // optimized yet. + // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs. + int active_branch_subgraph_index = + cond_value ? op_data->then_subgraph_index : op_data->else_subgraph_index; + Subgraph& active_branch_subgraph = + *(*subgraphs)[active_branch_subgraph_index]; + for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) { + const TfLiteTensor* input = GetInput(context, node, i + 1); + TfLiteTensor* subgraph_input = + active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]); + TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes); + memcpy(subgraph_input->data.raw, input->data.raw, input->bytes); + } + + // Note: It's guaranteed that the subgraphs' `AllocateTensors` are called + // in `Prepare`, so we don't need to do it here again. 
+ TF_LITE_ENSURE_OK(context, active_branch_subgraph.Invoke()); + + for (int tensor_index : active_branch_subgraph.outputs()) { + active_branch_subgraph.EnsureTensorDataIsReadable(tensor_index); + } + + bool has_dynamic_output_tensors = false; + for (int i = 0; i < node->outputs->size; ++i) { + TfLiteTensor* output = GetOutput(context, node, i); + if (IsDynamicTensor(output)) { + has_dynamic_output_tensors = true; + break; + } + } + + if (has_dynamic_output_tensors) { + for (int i = 0; i < node->outputs->size; ++i) { + TfLiteTensor* output = GetOutput(context, node, i); + TfLiteTensor* subgraph_output = + active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + } + } + + for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) { + const TfLiteTensor* subgraph_output = + active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]); + TfLiteTensor* output = GetOutput(context, node, i); + TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes); + memcpy(output->data.raw, subgraph_output->data.raw, output->bytes); + } + return kTfLiteOk; +} + +} // namespace if_kernel + +TfLiteRegistration* Register_IF() { + static TfLiteRegistration r = {if_kernel::Init, if_kernel::Free, + if_kernel::Prepare, if_kernel::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/if_test.cc b/tensorflow/lite/kernels/if_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f90db131b0bc335b54f4f8c24fa5d8dd02862f4 --- /dev/null +++ b/tensorflow/lite/kernels/if_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/subgraph_test_util.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { + +using subgraph_test_util::CheckIntTensor; +using subgraph_test_util::ControlFlowOpTest; +using subgraph_test_util::FillIntTensor; + +namespace { + +// A simple test that performs `ADD` if condition is true, and `MUL` otherwise. +// The computation is: `cond ? a + b : a * b`. 
+class SimpleIfTest : public ControlFlowOpTest { + protected: + void SetUp() override { + interpreter_->AddSubgraphs(2); + builder_->BuildAddSubgraph(interpreter_->subgraph(1)); + builder_->BuildMulSubgraph(interpreter_->subgraph(2)); + builder_->BuildIfSubgraph(&interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1, 2}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2}); + } +}; + +TEST_F(SimpleIfTest, TestIfTrue) { + interpreter_->typed_input_tensor(0)[0] = true; + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output, {1, 2}, {6, 9}); +} + +TEST_F(SimpleIfTest, TestIfFalse) { + interpreter_->typed_input_tensor(0)[0] = false; + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output, {1, 2}, {5, 14}); +} + +// Test IF op using subgraphs with dynamically sized outputs. +// The computation is: `cond ? a + b : pad(a, b)`. 
+class DynamicSubgraphIfTest : public ControlFlowOpTest { + protected: + void SetUp() override { + interpreter_->AddSubgraphs(2); + builder_->BuildAddSubgraph(interpreter_->subgraph(1)); + builder_->BuildPadSubgraph(interpreter_->subgraph(2)); + builder_->BuildIfSubgraph(&interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1, 2}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2}); + } +}; + +TEST_F(DynamicSubgraphIfTest, TestIfTrue) { + interpreter_->typed_input_tensor(0)[0] = true; + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + // Even if the true branch has a static type output, the output of the + // if op is dynamic because the other branch has dynamic output. + EXPECT_TRUE(IsDynamicTensor(output)); + CheckIntTensor(output, {1, 2}, {6, 9}); +} + +TEST_F(DynamicSubgraphIfTest, TestIfFalse) { + interpreter_->typed_input_tensor(0)[0] = false; + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + // The false branch has dynamic output. 
+ EXPECT_TRUE(IsDynamicTensor(output)); + CheckIntTensor(output, {5}, {0, 5, 7, 0, 0}); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index b734b2d6cc30bb84eaa424ffed71747136f57c4c..4a18ee3c09720f5e2b0b1f0ddc906ade97973adc 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "transitive_hdrs") load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") @@ -46,7 +46,6 @@ cc_library( ], deps = [ "//tensorflow/lite/kernels:op_macros", - "@com_google_absl//absl/base:core_headers", ], ) @@ -60,10 +59,16 @@ cc_library( ], deps = [ "//tensorflow/lite/kernels:op_macros", - "@com_google_absl//absl/base:core_headers", ], ) +config_setting( + name = "aarch64", + values = { + "cpu": "aarch64", + }, +) + config_setting( name = "arm", values = { @@ -246,6 +251,7 @@ cc_library( ":optimized_base", ":tensor", ":types", + "//tensorflow/core/kernels:eigen_spatial_convolutions-inl", "//tensorflow/lite/c:c_api_internal", "//third_party/eigen3", ], @@ -254,9 +260,6 @@ cc_library( cc_test( name = "tensor_test", srcs = ["tensor_test.cc"], - tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) - ], deps = [ ":tensor", "@com_google_googletest//:gtest", @@ -286,9 +289,6 @@ cc_library( cc_test( name = "quantization_util_test", srcs = ["quantization_util_test.cc"], - tags = [ - "tflite_not_portable_ios", # TODO(b/117786830) - ], deps = [ ":quantization_util", "@com_google_googletest//:gtest", @@ -314,9 +314,19 @@ cc_library( "reference/depthwiseconv_float.h", "reference/depthwiseconv_uint8.h", "reference/fully_connected.h", + 
"reference/integer_ops/add.h", + "reference/integer_ops/conv.h", + "reference/integer_ops/depthwise_conv.h", "reference/integer_ops/dequantize.h", + "reference/integer_ops/fully_connected.h", + "reference/integer_ops/l2normalization.h", + "reference/integer_ops/log_softmax.h", + "reference/integer_ops/logistic.h", + "reference/integer_ops/mean.h", + "reference/integer_ops/mul.h", "reference/integer_ops/pooling.h", "reference/integer_ops/softmax.h", + "reference/integer_ops/tanh.h", "reference/reference_ops.h", "reference/softmax.h", ], @@ -413,6 +423,7 @@ cc_library( "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/kernels:activation_functor", "//tensorflow/lite/kernels:op_macros", + "//tensorflow/lite/kernels/internal:types", ], ) @@ -497,6 +508,9 @@ cc_library( "//tensorflow/lite/kernels:op_macros", "@gemmlowp", ] + select({ + ":aarch64": [ + ":neon_tensor_utils", + ], ":arm": [ ":neon_tensor_utils", ], @@ -552,13 +566,10 @@ cc_library( }), deps = [ ":types", - "//tensorflow/lite:string", ], ) -# TODO(b/122597976): Eliminate TF dependency from lite/kernels:test_util, -# in turn eliminating the need to use tf_cc_test for any dependent tests. 
-tf_cc_test( +cc_test( name = "tensor_utils_test", srcs = ["tensor_utils_test.cc"], linkopts = select({ @@ -568,9 +579,6 @@ tf_cc_test( "//conditions:default": [], }), linkstatic = 1, - tags = [ - "tflite_not_portable_ios", - ], deps = [ ":tensor_utils", "//tensorflow/lite/c:c_api_internal", @@ -593,26 +601,25 @@ cc_test( cc_test( name = "depthwiseconv_quantized_test", - srcs = ["depthwiseconv_quantized_test.cc"], - shard_count = 2, - tags = [ - "tflite_not_portable_ios", + srcs = [ + "depthwiseconv_quantized_test.cc", + "optimized/depthwiseconv_uint8_transitional.h", ], + shard_count = 2, deps = [ ":optimized_base", ":reference_base", ":test_util", ":types", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", + "@gemmlowp", ], ) cc_test( name = "resize_bilinear_test", srcs = ["resize_bilinear_test.cc"], - tags = [ - "tflite_not_portable", - ], deps = [ ":optimized_base", ":reference_base", @@ -625,9 +632,6 @@ cc_test( cc_test( name = "resize_nearest_neighbor_test", srcs = ["resize_nearest_neighbor_test.cc"], - tags = [ - "tflite_not_portable", - ], deps = [ ":optimized_base", ":reference_base", @@ -660,7 +664,7 @@ cc_test( srcs = [ "logsoftmax_quantized_test.cc", ], - shard_count = 3, + shard_count = 4, tags = [ # TODO(b/122242739): Reenable after fixing the flakiness? 
"nomac", @@ -679,6 +683,10 @@ cc_test( cc_test( name = "log_quantized_test", srcs = ["log_quantized_test.cc"], + linkopts = select({ + "//tensorflow:windows": [], + "//conditions:default": ["-lm"], + }), deps = [ ":optimized_base", ":reference_base", @@ -714,4 +722,78 @@ cc_test( exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"]) +filegroup( + name = "optimized_op_headers", + srcs = glob([ + "optimized/*.h", + ]), + visibility = ["//tensorflow/lite:__subpackages__"], +) + +filegroup( + name = "reference_op_headers", + srcs = glob([ + "reference/*.h", + ]), + visibility = ["//tensorflow/lite:__subpackages__"], +) + +filegroup( + name = "headers", + srcs = glob([ + "*.h", + ]), + visibility = ["//tensorflow/lite:__subpackages__"], +) + +transitive_hdrs( + name = "nnapi_external_headers", + visibility = ["//tensorflow/lite:__subpackages__"], + deps = [ + "//third_party/eigen3", + "@gemmlowp", + ], +) + +# --------------------------------------------------------- +# The public target "install_nnapi_extra_headers" is only +# used for external targets that requires exporting optmized +# and reference op headers. 
+ +genrule( + name = "install_nnapi_extra_headers", + srcs = [ + ":nnapi_external_headers", + ":headers", + ":optimized_op_headers", + ":reference_op_headers", + ], + outs = ["include"], + cmd = """ + mkdir $@ + for f in $(SRCS); do + d="$${f%/*}" + d="$${d#bazel-out*genfiles/}" + d="$${d#*external/eigen_archive/}" + + if [[ $${d} == *local_config_* ]]; then + continue + fi + + if [[ $${d} == external* ]]; then + extname="$${d#*external/}" + extname="$${extname%%/*}" + if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then + continue + fi + fi + + mkdir -p "$@/$${d}" + cp "$${f}" "$@/$${d}/" + done + """, + tags = ["manual"], + visibility = ["//visibility:private"], +) + tflite_portable_test_suite() diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index bc30ac91220906588f204d6ff21c275faa2b6c25..2b8226c4977f489876f9073a4cd16c49f0673fa1 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -131,6 +131,223 @@ int CountLeadingZeros(T integer_input) { #endif } +// TODO(b/77858996): Add these to gemmlowp. 
+template +IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t sum = a64 + b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + sum))); +} + +template +gemmlowp::FixedPoint SaturatingAddNonGemmlowp( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingAddNonGemmlowp(a.raw(), b.raw())); +} + +template +IntegerType SaturatingSub(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) { + std::int32_t a32 = a; + std::int32_t b32 = b; + std::int32_t diff = a32 - b32; + return static_cast( + std::min(static_cast(32767), + std::max(static_cast(-32768), diff))); +} + +template <> +inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t diff = a64 - b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + diff))); +} + +template +gemmlowp::FixedPoint SaturatingSub( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingSub(a.raw(), b.raw())); +} +// End section to be moved to gemmlowp. 
+ +template +IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { + if (exponent == 0) { + return x; + } + using ScalarIntegerType = + typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; + const IntegerType min = + gemmlowp::Dup(std::numeric_limits::min()); + const IntegerType max = + gemmlowp::Dup(std::numeric_limits::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); + + const std::int32_t threshold = + ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); + const IntegerType positive_mask = + gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); + const IntegerType negative_mask = + gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); + + IntegerType result = gemmlowp::ShiftLeft(x, exponent); + result = gemmlowp::SelectUsingMask(positive_mask, max, result); + result = gemmlowp::SelectUsingMask(negative_mask, min, result); + return result; +} + +// If we want to leave IntegerBits fixed, then multiplication +// by a power of two has to be saturating/rounding, not exact anymore. +template +gemmlowp::FixedPoint +SaturatingRoundingMultiplyByPOTParam( + gemmlowp::FixedPoint a, int exponent) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); +} + +// Minimum output bits to accommodate log of maximum input range. It actually +// does not matter if one considers, say, [-64,64] or [-64,64). +// +// For example, run this through Octave: +// [0:127; ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] +constexpr int min_log_x_output_bits(int input_bits) { + return input_bits > 90 + ? 7 + : input_bits > 44 + ? 6 + : input_bits > 21 + ? 5 + : input_bits > 10 + ? 4 + : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; +} + +// Although currently the name of this function says that it cannot handle +// values less than 1, in practice it can handle as low as 1/x_max, where +// x_max is the largest representable input. 
In other words, the output range +// is symmetric. +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1_impl( + gemmlowp::FixedPoint input_val) { + // assert(__builtin_clz(0u) >= std::numeric_limits::digits - 1); + // assert(__builtin_clz(0u) <= std::numeric_limits::digits); + using FixedPoint0 = gemmlowp::FixedPoint; + // The reason for accumulating the result with an extra bit of headroom is + // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * + // recip_denom will otherwise introduce an error. + static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; + using FixedPointAccum = gemmlowp::FixedPoint; + + const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1488522236, std::log(2.0)); + const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); + const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1518500250, std::sqrt(0.5)); + const FixedPoint0 one_quarter = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); + + const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1057819769, + 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); + + const FixedPointAccum shifted_quarter = + gemmlowp::Rescale(one_quarter); + + // Reinterpret the input value as Q0.31, because we will figure out the + // required shift "ourselves" instead of using, say, Rescale. 
+ FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); + // z_a_pow_2 = input_integer_bits - z_a_headroom; + int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); + FixedPoint0 r_a_tmp = + SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); + const int32 r_a_raw = + SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); + // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); + // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, + // InputIntegerBits - z_b_headroom - 0.25); + const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), + shifted_quarter); + + // z_b is treated like z_a, but premultiplying by sqrt(0.5). + FixedPoint0 z_b = z_a * sqrt_half; + int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; + const int32 r_b_raw = + SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); + const FixedPointAccum z_b_pow_2_adj = SaturatingSub( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), + shifted_quarter); + + const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); + const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( + std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); + + const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); + FixedPoint0 q = r - sqrt_sqrt_half; + q = q + q; + + const FixedPoint0 common_sq = q * q; + const FixedPoint0 num = q * r + q * common_sq * alpha_n; + const FixedPoint0 denom_minus_one_0 = + p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; + const FixedPoint0 recip_denom = + one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); + + const FixedPointAccum num_scaled = gemmlowp::Rescale(num); + return gemmlowp::Rescale(z_pow_2_adj * log_2 + + num_scaled * recip_denom); +} + +template +inline 
gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1( + gemmlowp::FixedPoint input_val) { + static_assert( + OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), + "Output integer bits must be sufficent to accommodate logs of inputs."); + return log_x_for_x_greater_than_or_equal_to_1_impl( + input_val); +} + inline int32 GetReciprocal(int32 x, int x_integer_digits, int* num_bits_over_unit) { int headroom_plus_one = CountLeadingZeros(static_cast(x)); @@ -148,6 +365,55 @@ inline int32 GetReciprocal(int32 x, int x_integer_digits, return shifted_scale.raw(); } +inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int reverse_shift, + int32* output_inv_sqrt, + int* output_shift) { + *output_shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*output_shift; + } + TFLITE_DCHECK_GT(input, 0); + const unsigned max_left_shift_bits = + CountLeadingZeros(static_cast(input)) - 1; + const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; + const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; + *output_shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + TFLITE_DCHECK_GE(input, (1 << 27)); + TFLITE_DCHECK_LT(input, (1 << 29)); + using gemmlowp::FixedPoint; + using gemmlowp::Rescale; + using gemmlowp::SaturatingRoundingMultiplyByPOT; + // Using 3 integer bits gives us enough room for the internal arithmetic in + // this Newton-Raphson iteration. 
+ using F3 = FixedPoint; + using F0 = FixedPoint; + const F3 fixedpoint_input = F3::FromRaw(input >> 1); + const F3 fixedpoint_half_input = + SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); + const F3 fixedpoint_half_three = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); + // Newton-Raphson iteration + // Naive unoptimized starting guess: x = 1 + F3 x = F3::One(); + // Naive unoptimized number of iterations: 5 + for (int i = 0; i < 5; i++) { + const F3 x3 = Rescale<3>(x * x * x); + x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); + } + const F0 fixedpoint_half_sqrt_2 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } + // Convert right shift (right is positive) to left shift. + *output_shift *= reverse_shift; +} + // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING // BROADCASTING. // diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index 3682499d494cc4e63712b6c57d80482899b2185d..3e48d95a082ca285874c2dad01aaf3b845db9e88 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -19,40 +19,97 @@ limitations under the License. 
#include #include #include +#include #include #include +#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/test_util.h" #include "tensorflow/lite/kernels/internal/types.h" #define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#include "absl/strings/substitute.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" namespace tflite { namespace { -enum class ForceKernelInvocation { - // Run all tests against kUseStandardEntry even if also testing another - // kernel, since we need to be sure that the main DepthwiseConv() function in - // optimized_ops.h dispatches to a correctly-executing kernel. - kNone = 0, // The "default" option: use the normal DepthwiseConv - // kernel (entry) function. - kUseGenericKernel, - kUseNeon3x3, // 3x3 kernel that uses NEON when available. - kUseNeon3x3DotProduct, // 3x3 kernel that uses dot-product enabled NEON when - // available. +using optimized_ops::depthwise_conv::DotProduct3x3KernelType; +using ::testing::Bool; +using ::testing::Values; + +#if defined(__aarch64__) +static constexpr bool kLooseIntrinsicsTolerance = false; +#else +static constexpr bool kLooseIntrinsicsTolerance = true; +#endif + +// Currently, this is used in place of a Boolean "is symmetric?". +enum class ParamsSpecialization { + kNone = 0, + kSymmetric, // Symmetric quantization: zero represented by 128. +}; + +static constexpr int kSymmetricZeroPoint = 128; + +// Extend coverage distribution in a specific aspect, either explicitly chosen +// or randomly chosen as in a mixture distribution. 
+enum class CoverageExtension { + kNone = 0, + kLargeHeights = 1, + kLargeWidths = 2, + kNumOptions +}; + +// The TestParam structure below is the preferred parameterization of tests. A +// tuple version is defined in order to support value-parameterized tests. +typedef std::tuple + TestParamTuple; + +struct TestParam { + TestParam() = default; + + explicit TestParam(TestParamTuple param_tuple) + : forced_invocation(::testing::get<0>(param_tuple)), + tests_to_run(::testing::get<1>(param_tuple)), + test_stride(::testing::get<2>(param_tuple)), + test_pad(::testing::get<3>(param_tuple)), + test_depth_multiplier(::testing::get<4>(param_tuple)), + output_rounding(::testing::get<5>(param_tuple)), + loose_tolerance(::testing::get<6>(param_tuple)) {} + + static std::string TestNameSuffix( + const ::testing::TestParamInfo& info) { + const TestParam param(info.param); + return absl::Substitute("invocation_$0_stride_$1_pad_$2_depth_mult_$3", + static_cast(param.forced_invocation), + param.test_stride, param.test_pad, + param.test_depth_multiplier); + } + + DepthwiseConvImplementation forced_invocation = + DepthwiseConvImplementation::kNone; + int tests_to_run = 0; + bool test_stride = false; + bool test_pad = false; + bool test_depth_multiplier = false; + DepthwiseConvOutputRounding output_rounding = + DepthwiseConvOutputRounding::kNone; + bool loose_tolerance = false; }; inline void DispatchDepthwiseConv( - ForceKernelInvocation forced_invocation, const DepthwiseParams& params, + const TestParam& test_param, const DepthwiseParams& params, const RuntimeShape& input_shape, const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, uint8* output_data) { - switch (forced_invocation) { - case ForceKernelInvocation::kUseNeon3x3: { + switch (test_param.forced_invocation) { + case DepthwiseConvImplementation::kUseNeon3x3: { // Enable for arm64 except for the Nvidia 
Linux 4 Tegra (L4T) running on // Jetson TX-2. This compiler does not support the offsetof() macro. #if defined(__aarch64__) && !defined(GOOGLE_L4T) @@ -67,20 +124,20 @@ inline void DispatchDepthwiseConv( // Check that parameter combination is supported. const bool basic_3x3_kernel_supported = - optimized_ops::Fast3x3FilterKernelSupported( + optimized_ops::depthwise_conv::Fast3x3FilterKernelSupported( input_shape, filter_shape, stride_width, stride_height, dilation_width_factor, dilation_height_factor, pad_width, pad_height, depth_multiplier, output_shape, output_shift); ASSERT_TRUE(basic_3x3_kernel_supported) << "pad_width = " << params.padding_values.width << " pad_height = " << params.padding_values.height - << " input_width = " << input_shape.Dims(1) - << " input_height = " << input_shape.Dims(2) - << " output_width = " << output_shape.Dims(1) - << " output_height = " << output_shape.Dims(2); + << " input_width = " << input_shape.Dims(2) + << " input_height = " << input_shape.Dims(1) + << " output_width = " << output_shape.Dims(2) + << " output_height = " << output_shape.Dims(1); // Call kernel optimized for depthwise convolutions using 3x3 filters. - optimized_ops::DepthwiseConv3x3Filter( + optimized_ops::depthwise_conv::DepthwiseConv3x3Filter( params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data); return; @@ -88,64 +145,148 @@ inline void DispatchDepthwiseConv( break; #endif } - case ForceKernelInvocation::kUseNeon3x3DotProduct: { -// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on -// Jetson TX-2. This compiler does not support the offsetof() macro. 
-#if defined(__ARM_FEATURE_DOTPROD) && defined(__aarch64__) && \ - !defined(GOOGLE_L4T) - using optimized_ops::DotProduct3x3KernelType; + case DepthwiseConvImplementation::kUseNeon3x3DotProduct: { +#if defined(__ARM_FEATURE_DOTPROD) && defined(__aarch64__) DotProduct3x3KernelType kernel_type = - optimized_ops::CategorizeDotProductKernel(params); - switch (kernel_type) { - case DotProduct3x3KernelType::kPlain: - // TODO(b/118430534): Implement optimized kernel. - optimized_ops::DepthwiseConv3x3Filter( - params, input_shape, input_data, filter_shape, filter_data, - bias_shape, bias_data, output_shape, output_data); - return; - case DotProduct3x3KernelType::kWithDepthMultiplication: - // TODO(b/118430338): Implement optimized kernel. - optimized_ops::DepthwiseConvGeneral( - params, input_shape, input_data, filter_shape, filter_data, - bias_shape, bias_data, output_shape, output_data); - return; - case DotProduct3x3KernelType::kWithPad0Stride2: - // TODO(b/118430338): Implement optimized kernel. - optimized_ops::DepthwiseConv3x3Filter( - params, input_shape, input_data, filter_shape, filter_data, - bias_shape, bias_data, output_shape, output_data); - return; - case DotProduct3x3KernelType::kWithPad1Stride1: - // TODO(b/118430338): Implement optimized kernel. 
- optimized_ops::DepthwiseConvGeneral( - params, input_shape, input_data, filter_shape, filter_data, - bias_shape, bias_data, output_shape, output_data); - return; - case DotProduct3x3KernelType::kNone: - default: - break; - } + optimized_ops::depthwise_conv::CategorizeDotProductKernel( + input_shape, filter_shape, params); + + ASSERT_TRUE( + kernel_type == DotProduct3x3KernelType::kPlain || + kernel_type == DotProduct3x3KernelType::kStride2 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride1 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride2) + << "Kernel type = " << static_cast(kernel_type); + + optimized_ops::depthwise_conv::DepthwiseConvDotProduct3x3< + DepthwiseConvImplementation::kUseNeon3x3DotProduct>( + params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data); + return; #endif break; } - case ForceKernelInvocation::kUseGenericKernel: { - optimized_ops::DepthwiseConvGeneral(params, input_shape, input_data, - filter_shape, filter_data, bias_shape, - bias_data, output_shape, output_data); + case DepthwiseConvImplementation::kUseCModel3x3DotProduct: { + DotProduct3x3KernelType kernel_type = + optimized_ops::depthwise_conv::CategorizeDotProductKernel( + input_shape, filter_shape, params); + + ASSERT_TRUE( + kernel_type == DotProduct3x3KernelType::kPlain || + kernel_type == DotProduct3x3KernelType::kStride2 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride1 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride2) + << "Kernel type = " << static_cast(kernel_type) + << " depth_multiplier = " << params.depth_multiplier + << " pad_width = " << params.padding_values.width + << " pad_height = " << params.padding_values.height + << " stride_width = " << params.stride_width + << " stride_height = " << params.stride_height + << " input_width = " << input_shape.Dims(2) + << " input_height = " << input_shape.Dims(1) 
+ << " output_width = " << output_shape.Dims(2) + << " output_height = " << output_shape.Dims(1) + << " depth = " << input_shape.Dims(3) + << " buffer need = " << input_shape.Dims(3) * input_shape.Dims(2) * 6 + << " input_offset = " << params.input_offset; + + optimized_ops::depthwise_conv::DepthwiseConvDotProduct3x3< + DepthwiseConvImplementation::kUseCModel3x3DotProduct>( + params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data); return; } - case ForceKernelInvocation::kNone: + case DepthwiseConvImplementation::kUseUnwound3x3DotProduct: { + DotProduct3x3KernelType kernel_type = + optimized_ops::depthwise_conv::CategorizeDotProductKernel( + input_shape, filter_shape, params); + ASSERT_TRUE( + kernel_type == DotProduct3x3KernelType::kPlain || + kernel_type == DotProduct3x3KernelType::kStride2 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride1 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride2); + optimized_ops::depthwise_conv::DepthwiseConvDotProduct3x3< + DepthwiseConvImplementation::kUseUnwound3x3DotProduct>( + params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data); + return; + } + case DepthwiseConvImplementation::kUseIntrinsics3x3DotProduct: { +#if defined(USE_NEON) + DotProduct3x3KernelType kernel_type = + optimized_ops::depthwise_conv::CategorizeDotProductKernel( + input_shape, filter_shape, params); + + ASSERT_TRUE( + kernel_type == DotProduct3x3KernelType::kPlain || + kernel_type == DotProduct3x3KernelType::kStride2 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride1 || + kernel_type == + DotProduct3x3KernelType::kWithDepthMultiplicationStride2); + optimized_ops::depthwise_conv::DepthwiseConvDotProduct3x3< + DepthwiseConvImplementation::kUseIntrinsics3x3DotProduct>( + params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, 
output_data); + return; +#else + break; +#endif + } + case DepthwiseConvImplementation::kUseGenericKernel: { + optimized_ops::depthwise_conv::DepthwiseConvGeneral( + params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data); + return; + } + case DepthwiseConvImplementation::kNone: + default: + break; + } + + EXPECT_EQ(test_param.forced_invocation, DepthwiseConvImplementation::kNone) + << "TODO(b/118426582) requested kernel was not invoked / available yet: " + << " forced_invocation = " + << static_cast(test_param.forced_invocation) + << " depth_multiplier = " << params.depth_multiplier + << " pad_width = " << params.padding_values.width + << " pad_height = " << params.padding_values.height + << " stride_width = " << params.stride_width + << " stride_height = " << params.stride_height + << " input_width = " << input_shape.Dims(2) + << " input_height = " << input_shape.Dims(1) + << " output_width = " << output_shape.Dims(2) + << " output_height = " << output_shape.Dims(1) + << " depth = " << input_shape.Dims(3) + << " buffer need = " << input_shape.Dims(3) * input_shape.Dims(2) * 6 + << " input_offset = " << params.input_offset; + switch (test_param.output_rounding) { + case DepthwiseConvOutputRounding::kAwayFromZero: + optimized_ops::DepthwiseConvWithRounding< + DepthwiseConvOutputRounding::kAwayFromZero>( + params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data); + return; + case DepthwiseConvOutputRounding::kUpward: + optimized_ops::DepthwiseConvWithRounding< + DepthwiseConvOutputRounding::kUpward>( + params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data); + return; default: break; } - optimized_ops::DepthwiseConv(params, input_shape, input_data, filter_shape, - filter_data, bias_shape, bias_data, output_shape, - output_data); } // Runs the DepthwiseConv and compares against the reference 
implementation. int TestOneDepthwiseConvWithGivenOutputShift( - ForceKernelInvocation forced_invocation, const std::uint8_t* input_data, + const TestParam& test_param, const std::uint8_t* input_data, const RuntimeShape& input_shape, std::int32_t input_offset, const std::uint8_t* filter_data, const RuntimeShape& filter_shape, std::int32_t filter_offset, const std::int32_t* bias_data, @@ -174,10 +315,31 @@ int TestOneDepthwiseConvWithGivenOutputShift( op_params.output_offset = output_offset; op_params.output_multiplier = output_multiplier; op_params.output_shift = -output_shift; - reference_ops::DepthwiseConv(op_params, input_shape, input_data, filter_shape, - filter_data, bias_shape, bias_data, output_shape, - reference_output_data.data()); - DispatchDepthwiseConv(forced_invocation, op_params, input_shape, input_data, + switch (test_param.output_rounding) { + case DepthwiseConvOutputRounding::kUpward: + reference_ops::depthwise_conv::DepthwiseConvBasicKernel< + DepthwiseConvOutputRounding::kAwayFromZero>::Run(op_params, + input_shape, + input_data, + filter_shape, + filter_data, + bias_shape, + bias_data, + output_shape, + reference_output_data + .data()); + break; + case DepthwiseConvOutputRounding::kAwayFromZero: + reference_ops::DepthwiseConv( + op_params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, reference_output_data.data()); + break; + case DepthwiseConvOutputRounding::kNone: + default: + EXPECT_NE(test_param.output_rounding, DepthwiseConvOutputRounding::kNone); + break; + } + DispatchDepthwiseConv(test_param, op_params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data.data()); int saturated_min = 0; @@ -201,15 +363,46 @@ int TestOneDepthwiseConvWithGivenOutputShift( const float mean_diff = static_cast(sum_diff) / output_buffer_size; const float mean_abs_diff = static_cast(sum_abs_diff) / output_buffer_size; + + int diff_mean_tolerance = 1; + int 
diff_median_tolerance = 0; + // The tolerance that we apply to means is tight, but we allow for a rounding + // difference in one pixel, and loosen by another 1% for float comparison. + float mean_tolerance = std::max(2e-5f, 1.01f * 3.f / output_buffer_size * + std::sqrt(1.f * depth_multiplier)); + if (test_param.loose_tolerance) { + mean_tolerance = 500.f; + diff_mean_tolerance = 256; + diff_median_tolerance = 225; + } + // Normally we should require bit-for-bit exact results. Unfortunately a bug // in the Intel arm_neon_sse.h translation header that we use for x86 tests - // causes 1-bit inaccuracy in - // the vqrdmulh_n_s32 intrinsic, which causes off-by-1 errors in quantized - // DepthwiseConv ops. So we have to live with a few off-by-one errors for now, - // yet still ensure that no more than a small minority of values are wrong. - EXPECT_TRUE(std::abs(mean_diff) < 1e-5f && mean_abs_diff < 1e-5f && - std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && - std::abs(max_diff) <= 1); + // causes 1-bit inaccuracy in the vqrdmulh_n_s32 intrinsic, which causes + // off-by-1 errors in quantized DepthwiseConv ops. So we have to live with a + // few off-by-one errors for now, yet still ensure that no more than a small + // minority of values are wrong. 
+ EXPECT_LT(std::abs(mean_diff), mean_tolerance); + EXPECT_LT(mean_abs_diff, mean_tolerance); + EXPECT_LE(std::abs(median_diff), diff_median_tolerance); + EXPECT_LE(std::abs(min_diff), diff_mean_tolerance); + EXPECT_LE(std::abs(max_diff), diff_mean_tolerance); + EXPECT_TRUE(std::abs(mean_diff) < mean_tolerance && + mean_abs_diff < mean_tolerance && + std::abs(median_diff) <= diff_median_tolerance && + std::abs(min_diff) <= diff_mean_tolerance && + std::abs(max_diff) <= diff_mean_tolerance) + << "pad_width = " << op_params.padding_values.width + << " pad_height = " << op_params.padding_values.height + << " input_width = " << input_shape.Dims(2) + << " input_height = " << input_shape.Dims(1) + << " output_width = " << output_shape.Dims(2) + << " output_height = " << output_shape.Dims(1) + << " depth = " << input_shape.Dims(3) + << " output_offset = " << op_params.output_offset + << " output_multiplier = " << op_params.output_multiplier + << " output_shift = " << op_params.output_shift; + if (saturated_min > 2 * saturated_max) { return -1; } @@ -221,13 +414,12 @@ int TestOneDepthwiseConvWithGivenOutputShift( // The point of this function is that we can't practically know which // output_shift value to pass to test DepthwiseConv. It's not easy to guess (we -// could do some -// statistics for large size, but they would be fragile at smaller sizes), and -// guessing wrong would mean that all the values get saturated so the test -// becomes -// vacuous. So we just bisect our way to reasonable output_shift values. +// could do some statistics for large size, but they would be fragile at smaller +// sizes), and guessing wrong would mean that all the values get saturated so +// the test becomes vacuous. So we just bisect our way to reasonable +// output_shift values. 
void TestOneDepthwiseConvBisectOutputShift( - ForceKernelInvocation forced_invocation, const std::uint8_t* input_data, + const TestParam& test_param, const std::uint8_t* input_data, const RuntimeShape& input_shape, std::int32_t input_offset, const std::uint8_t* filter_data, const RuntimeShape& filter_shape, std::int32_t filter_offset, const std::int32_t* bias_data, @@ -242,7 +434,7 @@ void TestOneDepthwiseConvBisectOutputShift( int output_shift_bisect_midpoint = (output_activation_bisect_start + output_activation_bisect_end) / 2; int bisect_result = TestOneDepthwiseConvWithGivenOutputShift( - forced_invocation, input_data, input_shape, input_offset, filter_data, + test_param, input_data, input_shape, input_offset, filter_data, filter_shape, filter_offset, bias_data, bias_shape, stride, padding_type, pad_width, pad_height, depth_multiplier, output_offset, output_multiplier, output_shift_bisect_midpoint, output_activation_min, @@ -269,7 +461,7 @@ void TestOneDepthwiseConvBisectOutputShift( ? 
output_activation_bisect_end : output_shift_bisect_midpoint; TestOneDepthwiseConvBisectOutputShift( - forced_invocation, input_data, input_shape, input_offset, filter_data, + test_param, input_data, input_shape, input_offset, filter_data, filter_shape, filter_offset, bias_data, bias_shape, stride, padding_type, pad_width, pad_height, depth_multiplier, output_offset, output_multiplier, new_output_activation_bisect_start, new_output_activation_bisect_end, @@ -277,7 +469,7 @@ void TestOneDepthwiseConvBisectOutputShift( } void TestOneDepthwiseConv( - ForceKernelInvocation forced_invocation, const std::uint8_t* input_data, + const TestParam& test_param, const std::uint8_t* input_data, const RuntimeShape& input_shape, std::int32_t input_offset, const std::uint8_t* filter_data, const RuntimeShape& filter_shape, std::int32_t filter_offset, const std::int32_t* bias_data, @@ -287,13 +479,14 @@ void TestOneDepthwiseConv( std::int32_t output_activation_min, std::int32_t output_activation_max, const RuntimeShape& output_shape) { TestOneDepthwiseConvBisectOutputShift( - forced_invocation, input_data, input_shape, input_offset, filter_data, + test_param, input_data, input_shape, input_offset, filter_data, filter_shape, filter_offset, bias_data, bias_shape, stride, padding_type, pad_width, pad_height, depth_multiplier, output_offset, output_multiplier, 0, 32, output_activation_min, output_activation_max, output_shape); } -bool TryTestDepthwiseConv(ForceKernelInvocation forced_invocation, int batch, +bool TryTestDepthwiseConv(const TestParam& test_param, + ParamsSpecialization params_specialization, int batch, int input_depth, int input_width, int input_height, int filter_width, int filter_height, int depth_multiplier, int stride, @@ -318,9 +511,12 @@ bool TryTestDepthwiseConv(ForceKernelInvocation forced_invocation, int batch, } const std::int32_t output_multiplier = UniformRandomInt(1 << 29, std::numeric_limits::max()); - const std::int32_t input_offset = UniformRandomInt(-256, 
0); - const std::int32_t filter_offset = UniformRandomInt(-256, 0); - const std::int32_t output_offset = UniformRandomInt(-256, 0); + std::int32_t filter_offset = -kSymmetricZeroPoint; + if (params_specialization != ParamsSpecialization::kSymmetric) { + filter_offset = UniformRandomInt(-255, 0); + } + const std::int32_t input_offset = UniformRandomInt(-255, 0); + const std::int32_t output_offset = UniformRandomInt(0, 255); RuntimeShape input_shape_inference( {batch, input_height, input_width, input_depth}); RuntimeShape output_shape_inference; @@ -343,7 +539,7 @@ bool TryTestDepthwiseConv(ForceKernelInvocation forced_invocation, int batch, FillRandom(&filter_data); FillRandom(&bias_data, -10000, 10000); TestOneDepthwiseConv( - forced_invocation, input_data.data(), input_shape_inference, input_offset, + test_param, input_data.data(), input_shape_inference, input_offset, filter_data.data(), filter_shape_inference, filter_offset, bias_data.data(), bias_shape_inference, stride, padding_type, pad_width, pad_height, depth_multiplier, output_offset, output_multiplier, @@ -355,7 +551,8 @@ bool TryTestDepthwiseConv(ForceKernelInvocation forced_invocation, int batch, // be legal. If they're not legal, it returns false. If they're legal, // it runs the DepthwiseConv test and returns true. This allows the caller // to loop until a test has been run. -bool TryTestOneDepthwiseConv(ForceKernelInvocation forced_invocation) { +bool TryTestOneDepthwiseConv(const TestParam& test_param, + ParamsSpecialization params_specialization) { // We have to pick a lot of positive values, where we are particularly // interested in small values because they are most likely to be special // cases in optimized implementations, and secondarily because they allow @@ -375,13 +572,14 @@ bool TryTestOneDepthwiseConv(ForceKernelInvocation forced_invocation) { UniformRandomInt(0, 1) ? 
PaddingType::kSame : PaddingType::kValid; return TryTestDepthwiseConv( - forced_invocation, batch, input_depth, input_width, input_height, - filter_width, filter_height, depth_multiplier, stride, + test_param, params_specialization, batch, input_depth, input_width, + input_height, filter_width, filter_height, depth_multiplier, stride, dilation_width_factor, dilation_height_factor, padding_type); } // Tests parameters for the 3x3 filter kernel. -bool TryTestOneDepthwiseConv3x3Filter(ForceKernelInvocation forced_invocation) { +bool TryTestOneDepthwiseConv3x3Filter( + const TestParam& test_param, ParamsSpecialization params_specialization) { const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); const int input_depth = 8 * ExponentialRandomPositiveInt(0.9f, 10, 50); int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); @@ -397,7 +595,7 @@ bool TryTestOneDepthwiseConv3x3Filter(ForceKernelInvocation forced_invocation) { UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; // Adjust for, or reject, special cases. - if (forced_invocation != ForceKernelInvocation::kNone) { + if (test_param.forced_invocation != DepthwiseConvImplementation::kNone) { // With stride == 2 and SAME, padding width and height are the left and top // padding amounts. When there is an even input dimension, padding + 1 is // required on the right / bottom. This is not handled by these kernels, so @@ -416,59 +614,77 @@ bool TryTestOneDepthwiseConv3x3Filter(ForceKernelInvocation forced_invocation) { } return TryTestDepthwiseConv( - forced_invocation, batch, input_depth, input_width, input_height, - filter_width, filter_height, depth_multiplier, stride, + test_param, params_specialization, batch, input_depth, input_width, + input_height, filter_width, filter_height, depth_multiplier, stride, dilation_width_factor, dilation_height_factor, padding_type); } // Tests with parameters suited to dot-product-NEON 3x3 filter kernels. 
-bool TryTestOneNeonDot3x3(ForceKernelInvocation forced_invocation, - bool test_stride, bool test_pad, - bool test_depth_multiplier) { +bool TryTestOneNeonDot3x3(const TestParam& test_param, + ParamsSpecialization params_specialization) { + const CoverageExtension coverage_extension = static_cast( + UniformRandomInt(0, static_cast(CoverageExtension::kNumOptions))); + const int batch = 1; - const int input_depth = test_depth_multiplier + const int input_depth = test_param.test_depth_multiplier ? 1 - : 8 * ExponentialRandomPositiveInt(0.9f, 10, 50); - const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); - const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + : 8 * ExponentialRandomPositiveInt(0.9f, 3, 50); + const int input_width = coverage_extension == CoverageExtension::kLargeWidths + ? ExponentialRandomPositiveInt(0.9f, 50, 200) + : ExponentialRandomPositiveInt(0.9f, 20, 60); + const int input_height = + coverage_extension == CoverageExtension::kLargeHeights + ? ExponentialRandomPositiveInt(0.9f, 50, 200) + : ExponentialRandomPositiveInt(0.9f, 20, 60); const int filter_width = 3; const int filter_height = 3; const int depth_multiplier = - test_depth_multiplier ? 8 * ExponentialRandomPositiveInt(0.8f, 1, 6) : 1; - const int stride = test_stride ? 2 : 1; + test_param.test_depth_multiplier + ? 8 * ExponentialRandomPositiveInt(0.2f, 1, 9) + : 1; + const int stride = test_param.test_stride ? 2 : 1; // We don't support dilations in the 3x3 filter. const int dilation_width_factor = 1; const int dilation_height_factor = 1; - const auto padding_type = test_pad ? PaddingType::kSame : PaddingType::kValid; + const auto padding_type = + test_param.test_pad ? 
PaddingType::kSame : PaddingType::kValid; return TryTestDepthwiseConv( - forced_invocation, batch, input_depth, input_width, input_height, - filter_width, filter_height, depth_multiplier, stride, + test_param, params_specialization, batch, input_depth, input_width, + input_height, filter_width, filter_height, depth_multiplier, stride, dilation_width_factor, dilation_height_factor, padding_type); } -void TestOneDepthwiseConv(ForceKernelInvocation forced_invocation) { - while (!TryTestOneDepthwiseConv(forced_invocation)) { +void TestOneDepthwiseConv(DepthwiseConvImplementation forced_invocation, + DepthwiseConvOutputRounding output_rounding) { + TestParam test_param; + test_param.forced_invocation = forced_invocation; + test_param.output_rounding = output_rounding; + while (!TryTestOneDepthwiseConv(test_param, ParamsSpecialization::kNone)) { } } -void TestOneDepthwiseConv3x3Filter(ForceKernelInvocation forced_invocation) { - while (!TryTestOneDepthwiseConv3x3Filter(forced_invocation)) { +void TestOneDepthwiseConv3x3Filter( + DepthwiseConvImplementation forced_invocation, + DepthwiseConvOutputRounding output_rounding) { + TestParam test_param; + test_param.forced_invocation = forced_invocation; + test_param.output_rounding = output_rounding; + while (!TryTestOneDepthwiseConv3x3Filter(test_param, + ParamsSpecialization::kNone)) { } } -void TestOneNeonDot3x3(ForceKernelInvocation forced_invocation, - bool test_stride, bool test_pad, - bool test_depth_multiplier) { - while (!TryTestOneNeonDot3x3(forced_invocation, test_stride, test_pad, - test_depth_multiplier)) { +void TestOneNeonDot3x3(const TestParam& test_param) { + while (!TryTestOneNeonDot3x3(test_param, ParamsSpecialization::kSymmetric)) { } } TEST(TestDepthwiseConv, TestDepthwiseConv) { const int kTestsToRun = 10 * 1000; for (int i = 0; i < kTestsToRun; i++) { - TestOneDepthwiseConv(ForceKernelInvocation::kNone); + TestOneDepthwiseConv(DepthwiseConvImplementation::kNone, + 
DepthwiseConvOutputRounding::kAwayFromZero); } } @@ -476,69 +692,141 @@ TEST(TestDepthwiseConv, TestDepthwiseConv) { TEST(TestDepthwiseConv, TestGenericKernel) { const int kTestsToRun = 10 * 1000; for (int i = 0; i < kTestsToRun; i++) { - TestOneDepthwiseConv(ForceKernelInvocation::kUseGenericKernel); + TestOneDepthwiseConv(DepthwiseConvImplementation::kUseGenericKernel, + DepthwiseConvOutputRounding::kAwayFromZero); } } +#if defined(__aarch64__) && !defined(GOOGLE_L4T) TEST(TestDepthwiseConv, TestKernel3x3Filter) { const int kTestsToRun = 1000; for (int i = 0; i < kTestsToRun; i++) { - TestOneDepthwiseConv3x3Filter(ForceKernelInvocation::kNone); + TestOneDepthwiseConv3x3Filter(DepthwiseConvImplementation::kUseNeon3x3, + DepthwiseConvOutputRounding::kAwayFromZero); } } +#endif -// While the 3x3 coverage test is primarily targeted at specialized kernels, we -// also run it against the generic kernel, optionally with fewer invocations. +// While 3x3 coverage tests are primarily targeted at specialized kernels, we +// also run it against the generic kernel. TEST(TestDepthwiseConv, TestGenericKernel3x3Filter) { - const int kTestsToRun = 1000; + const int kTestsToRun = 100; for (int i = 0; i < kTestsToRun; i++) { - TestOneDepthwiseConv3x3Filter(ForceKernelInvocation::kUseGenericKernel); + TestOneDepthwiseConv3x3Filter( + DepthwiseConvImplementation::kUseGenericKernel, + DepthwiseConvOutputRounding::kAwayFromZero); } } +#if defined(__aarch64__) && !defined(GOOGLE_L4T) TEST(TestDepthwiseConv, TestNeon3x3Filter) { const int kTestsToRun = 3 * 1000; for (int i = 0; i < kTestsToRun; i++) { - TestOneDepthwiseConv3x3Filter(ForceKernelInvocation::kUseNeon3x3); + TestOneDepthwiseConv3x3Filter(DepthwiseConvImplementation::kUseNeon3x3, + DepthwiseConvOutputRounding::kAwayFromZero); } } +#endif -// No stride, no depth multiplier, no pad. 
-TEST(TestDepthwiseConv, TestNeonDot3x3Plain) { - const int kTestsToRun = 3 * 1000; - for (int i = 0; i < kTestsToRun; i++) { - TestOneNeonDot3x3(ForceKernelInvocation::kUseNeon3x3DotProduct, - /*test_stride=*/false, /*test_pad=*/false, - /*test_depth_multiplier=*/false); - } -} +class DepthwiseConvTest : public ::testing::TestWithParam {}; -TEST(TestDepthwiseConv, TestNeonDot3x3DepthMultiplier) { - const int kTestsToRun = 3 * 1000; - for (int i = 0; i < kTestsToRun; i++) { - TestOneNeonDot3x3(ForceKernelInvocation::kUseNeon3x3DotProduct, - /*test_stride=*/false, /*test_pad=*/false, - /*test_depth_multiplier=*/true); +TEST_P(DepthwiseConvTest, NeonDot3x3) { + const TestParam param(GetParam()); + for (int i = 0; i < param.tests_to_run; i++) { + TestOneNeonDot3x3(param); } } -TEST(TestDepthwiseConv, TestNeonDot3x3Stride2) { - const int kTestsToRun = 3 * 1000; - for (int i = 0; i < kTestsToRun; i++) { - TestOneNeonDot3x3(ForceKernelInvocation::kUseNeon3x3DotProduct, - /*test_stride=*/true, /*test_pad=*/false, - /*test_depth_multiplier=*/false); - } -} +#if defined(__aarch64__) && !defined(GOOGLE_L4T) +INSTANTIATE_TEST_SUITE_P( + Neon3x3Kernel, DepthwiseConvTest, + testing::Combine( + Values(DepthwiseConvImplementation::kUseNeon3x3), // forced_invocation + Values(1000), // tests_to_run + Bool(), // test_stride + Values(false), // test_pad + Values(false), // test_depth_multiplier + Values(DepthwiseConvOutputRounding::kAwayFromZero), // output_rounding + Values(false) // loose_tolerance + ), + TestParam::TestNameSuffix); +#endif -TEST(TestDepthwiseConv, TestNeonDot3x3Pad1) { - const int kTestsToRun = 3 * 1000; - for (int i = 0; i < kTestsToRun; i++) { - TestOneNeonDot3x3(ForceKernelInvocation::kUseNeon3x3DotProduct, - /*test_stride=*/false, /*test_pad=*/true, - /*test_depth_multiplier=*/false); - } -} +// While 3x3 coverage tests are primarily targeted at specialized kernels, we +// also run it against the generic kernel. 
+INSTANTIATE_TEST_SUITE_P( + GenericKernel, DepthwiseConvTest, + testing::Combine( + Values(DepthwiseConvImplementation:: + kUseGenericKernel), // forced_invocation + Values(100), // tests_to_run + Bool(), // test_stride + Bool(), // test_pad + Bool(), // test_depth_multiplier + Values(DepthwiseConvOutputRounding::kAwayFromZero), // output_rounding + Values(false) // loose_tolerance + ), + TestParam::TestNameSuffix); + +INSTANTIATE_TEST_SUITE_P( + CModel, DepthwiseConvTest, + testing::Combine( + Values(DepthwiseConvImplementation:: + kUseCModel3x3DotProduct), // forced_invocation + Values(1000), // tests_to_run + Bool(), // test_stride + Bool(), // test_pad + Bool(), // test_depth_multiplier + Values(DepthwiseConvOutputRounding::kUpward), // output_rounding + Values(false) // loose_tolerance + ), + TestParam::TestNameSuffix); + +INSTANTIATE_TEST_SUITE_P( + Unwound, DepthwiseConvTest, + testing::Combine( + Values(DepthwiseConvImplementation:: + kUseUnwound3x3DotProduct), // forced_invocation + Values(1000), // tests_to_run + Bool(), // test_stride + Bool(), // test_pad + Bool(), // test_depth_multiplier + Values(DepthwiseConvOutputRounding::kUpward), // output_rounding + Values(false) // loose_tolerance + ), + TestParam::TestNameSuffix); + +#if defined(USE_NEON) +INSTANTIATE_TEST_SUITE_P( + Intrinsics, DepthwiseConvTest, + testing::Combine( + Values(DepthwiseConvImplementation:: + kUseIntrinsics3x3DotProduct), // forced_invocation + Values(1000), // tests_to_run + Bool(), // test_stride + Bool(), // test_pad + Bool(), // test_depth_multiplier + Values(DepthwiseConvOutputRounding::kUpward), // output_rounding + Values(kLooseIntrinsicsTolerance) // loose_tolerance + ), + TestParam::TestNameSuffix); +#endif + +#if defined(__ARM_FEATURE_DOTPROD) && defined(__aarch64__) +INSTANTIATE_TEST_SUITE_P( + NeonAsm, DepthwiseConvTest, + testing::Combine( + Values(DepthwiseConvImplementation:: + kUseNeon3x3DotProduct), // forced_invocation + Values(1000), // tests_to_run + Bool(), 
// test_stride + Bool(), // test_pad + Bool(), // test_depth_multiplier + Values(DepthwiseConvOutputRounding::kUpward), // output_rounding + Values(false) // loose_tolerance + ), + TestParam::TestNameSuffix); +#endif } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/log_quantized_test.cc b/tensorflow/lite/kernels/internal/log_quantized_test.cc index 8c39350ab1dd8996799e6539755f040399974106..c31c8e307751bcf1030e121eec23ac6cb217f461 100644 --- a/tensorflow/lite/kernels/internal/log_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/log_quantized_test.cc @@ -121,8 +121,7 @@ void RunSingleTest(const std::vector& test_input, const string& check_label, int tolerance) { const int n = test_input.size(); std::vector float_gen_output(n, 0); - std::vector reference_output(n, 0); - std::vector optimized_output(n, 0); + std::vector quantized_output(n, 0); // Workaround the stupid things that intelligent humans do. // Consequence of __builtin_clz(0u) may equal 31 instead of 32. @@ -132,45 +131,21 @@ void RunSingleTest(const std::vector& test_input, } for (int i = 0; i < n; ++i) { - reference_output[i] = - tflite::reference_ops::log_x_for_x_greater_than_or_equal_to_1_impl< - OutputIntegerBits, InputIntegerBits>( - gemmlowp::FixedPoint::FromRaw( - fudged_input[i])) - .raw(); - optimized_output[i] = - tflite::optimized_ops::log_x_for_x_greater_than_or_equal_to_1_impl< - OutputIntegerBits, InputIntegerBits>( + quantized_output[i] = + tflite::log_x_for_x_greater_than_or_equal_to_1_impl( gemmlowp::FixedPoint::FromRaw( fudged_input[i])) .raw(); float_gen_output[i] = LogPositiveValuesViaFloat( fudged_input[i], InputIntegerBits, OutputIntegerBits); } - // Note that first check is intolerant. 
- { - std::ostringstream label; - label << check_label << " / optimized vs reference / InputIntegerBits=" - << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; - CheckOutputData( - optimized_output, reference_output, test_input, label.str(), - InputIntegerBits, OutputIntegerBits, 0); - } { std::ostringstream label; label << check_label << " / reference vs float-gen / InputIntegerBits=" << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; - CheckOutputData( - reference_output, float_gen_output, test_input, label.str(), - InputIntegerBits, OutputIntegerBits, tolerance); - } - { - std::ostringstream label; - label << check_label << " optimized vs float-gen / InputIntegerBits=" - << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; - CheckOutputData( - optimized_output, float_gen_output, test_input, label.str(), - InputIntegerBits, OutputIntegerBits, tolerance); + CheckOutputData(quantized_output, float_gen_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, tolerance); } } diff --git a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc index 945300dad1653257db69c3440f6db0589e0c1a7b..d0d2654d4123e5025d000a796907f675ca29b05c 100644 --- a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ limitations under the License. 
#include #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/test_util.h" #include "tensorflow/lite/string.h" @@ -61,7 +63,42 @@ void RunLogSoftmaxFloatReference(const uint8* input_data, } } -void CheckOutputData(const uint8* test_output, const uint8* reference_output, +// Same as above except for the following change: +// - input and output data type +// - Dequnatize function +// - clamping values +void RunLogSoftmaxFloatReference(const int8* input_data, + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, + int8* reference_output_data) { + const int ref_buffer_size = shape_common.FlatSize(); + std::vector reference_dequant_data(ref_buffer_size); + std::vector reference_output_float_data(ref_buffer_size); + + // Reference data generated via Dequant of input into float, and then applying + // float LogSoftmax. + DequantizationParams dq_params; + dq_params.zero_point = input_offset; + dq_params.scale = input_scale; + reference_integer_ops::Dequantize(dq_params, shape_common, input_data, + shape_common, + reference_dequant_data.data()); + SoftmaxParams sm_params; + optimized_ops::LogSoftmax(sm_params, shape_common, + reference_dequant_data.data(), shape_common, + reference_output_float_data.data()); + // Work with quantized scaling for LogSoftmax, under which 255 represents 0, + // and -16 gets nudged up to 0. 
+ for (int i = 0; i < ref_buffer_size; i++) { + reference_output_data[i] = std::max( + -128, static_cast( + 127 + std::round(16.0f * reference_output_float_data[i]))); + } +} + +template +void CheckOutputData(const T* test_output, const T* reference_output, const RuntimeShape& shape_common, const string& check_label, bool be_exacting) { const int buffer_size = shape_common.FlatSize(); @@ -144,15 +181,58 @@ void RunOneLogSoftmaxTest(const uint8* input_data, reference_ops::LogSoftmax(params, shape_common, input_data, shape_common, reference_quant_logsoftmax_output.data()); - CheckOutputData(optimized_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), shape_common, - "Optimized vs float reference", false); - CheckOutputData(optimized_logsoftmax_output.data(), - reference_quant_logsoftmax_output.data(), shape_common, - "Optimized vs quant reference", true); - CheckOutputData(reference_quant_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), shape_common, - "Quant reference vs float reference", false); + CheckOutputData(optimized_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), + shape_common, "Optimized vs float reference", false); + CheckOutputData(optimized_logsoftmax_output.data(), + reference_quant_logsoftmax_output.data(), + shape_common, "Optimized vs quant reference", true); + CheckOutputData(reference_quant_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), + shape_common, "Quant reference vs float reference", + false); +} + +// Runs the LogSoftmax and compares against the float reference implementation +// and the int8 quantized reference implementation. 
+void RunOneLogSoftmaxTest(const int8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); + std::vector quantized_logsoftmax_reference_implementation(buffer_size); + std::vector float_logsoftmax_optimized_implementation(buffer_size); + + RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, + input_scale, stride, beta, + float_logsoftmax_optimized_implementation.data()); + + int32 input_beta_multiplier; + int input_beta_left_shift; + int32 reverse_scaling_divisor; + int reverse_scaling_right_shift; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessLogSoftmaxScalingExp( + beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier, + &input_beta_left_shift, &reverse_scaling_divisor, + &reverse_scaling_right_shift); + reverse_scaling_right_shift *= -1; + // diff_min has a negative value, and is used to limit the maximum magnitude + // of the diffs, which are <= 0. + const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, + input_beta_left_shift); + + const int outer_size = + shape_common.Dims(0) * shape_common.Dims(1) * shape_common.Dims(2); + const int inner_size = shape_common.Dims(3); + reference_integer_ops::LogSoftmax( + input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, outer_size, inner_size, input_data, + quantized_logsoftmax_reference_implementation.data()); + + CheckOutputData(quantized_logsoftmax_reference_implementation.data(), + float_logsoftmax_optimized_implementation.data(), + shape_common, "Quant reference vs float reference", + false); } // This function picks some random LogSoftmax params, which are checked for @@ -161,6 +241,7 @@ void RunOneLogSoftmaxTest(const uint8* input_data, // to loop until a test has been run. // // Currently we do not reject for any reason. 
+template bool TryOneUniformLogSoftmax() { // We pick mostly positive values, on the whole emphasizing smaller values and // therefore faster tests. We test a wider range of depths. In the case of @@ -178,7 +259,7 @@ bool TryOneUniformLogSoftmax() { RuntimeShape({batch, input_height, input_width, input_depth}); const int buffer_size = shape_common.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandom(&input_data); RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); @@ -224,15 +305,23 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { return true; } -TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxTests) { +TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxUint8Tests) { + const int kTestsToRun = 100; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneUniformLogSoftmax()) { + } + } +} + +TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxUint8Int8Tests) { const int kTestsToRun = 100; for (int i = 0; i < kTestsToRun; i++) { - while (!TryOneUniformLogSoftmax()) { + while (!TryOneUniformLogSoftmax()) { } } } -TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) { +TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxUint8Tests) { const int kTestsToRun = 100; for (int i = 0; i < kTestsToRun; i++) { while (!TryOneSkyscraperLogSoftmax(false)) { @@ -240,7 +329,7 @@ TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) { } } -TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxTests) { +TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxUint8Tests) { const int kTestsToRun = 100; for (int i = 0; i < kTestsToRun; i++) { while (!TryOneSkyscraperLogSoftmax(true)) { diff --git a/tensorflow/lite/kernels/internal/mfcc_mel_filterbank.cc b/tensorflow/lite/kernels/internal/mfcc_mel_filterbank.cc index 9748da39862edd7565fdb2bcce2ce92b9d767429..4f22517866eba964e5c51406022c377951c22252 100644 --- a/tensorflow/lite/kernels/internal/mfcc_mel_filterbank.cc +++ 
b/tensorflow/lite/kernels/internal/mfcc_mel_filterbank.cc @@ -197,7 +197,7 @@ void MfccMelFilterbank::Compute(const std::vector &input, } double MfccMelFilterbank::FreqToMel(double freq) const { - return 1127.0 * log(1.0 + (freq / 700.0)); + return 1127.0 * log1p(freq / 700.0); } } // namespace internal diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h index d3dca799a7cca4a3048cd2d19477ba2b57fbcdac..d1a9d65aae825796ac390bd4682874ec23990a0c 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -19,10 +19,12 @@ limitations under the License. #include "public/gemmlowp.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace optimized_ops { +namespace depthwise_conv { // Implementation of quantized DepthwiseConv @@ -1945,7 +1947,10 @@ inline void DepthwiseConvGeneral( } } -inline void DepthwiseConv( +} // namespace depthwise_conv + +template +inline void DepthwiseConvWithRounding( const DepthwiseParams& params, const RuntimeShape& input_shape, const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, @@ -1979,20 +1984,34 @@ inline void DepthwiseConv( // Call kernel optimized for depthwise convolutions using 3x3 filters if // parameters are supported. 
- if (Fast3x3FilterKernelSupported( + if (depthwise_conv::Fast3x3FilterKernelSupported( input_shape, filter_shape, stride_width, stride_height, dilation_width_factor, dilation_height_factor, pad_width, pad_height, depth_multiplier, output_shape, output_shift)) { - DepthwiseConv3x3Filter(params, input_shape, input_data, filter_shape, - filter_data, bias_shape, bias_data, output_shape, - output_data); + gemmlowp::ScopedProfilingLabel specialized_label("DepthwiseConv/8bit/3x3"); + depthwise_conv::DepthwiseConv3x3Filter( + params, input_shape, input_data, filter_shape, filter_data, bias_shape, + bias_data, output_shape, output_data); return; } #endif - DepthwiseConvGeneral(params, input_shape, input_data, filter_shape, - filter_data, bias_shape, bias_data, output_shape, - output_data); + gemmlowp::ScopedProfilingLabel specialized_label( + "DepthwiseConv/8bit/General"); + depthwise_conv::DepthwiseConvGeneral(params, input_shape, input_data, + filter_shape, filter_data, bias_shape, + bias_data, output_shape, output_data); +} + +inline void DepthwiseConv( + const DepthwiseParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data) { + return DepthwiseConvWithRounding( + params, input_shape, input_data, filter_shape, filter_data, bias_shape, + bias_data, output_shape, output_data); } } // namespace optimized_ops diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 5859bcaed4ac2b991ca22e7d9c17d34d3267a120..50de905db1754db1c98b3dd721a427134dfaea3b 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow 
Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,58 +15,263 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ +#include + #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" #include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace optimized_ops { +namespace depthwise_conv { + +constexpr int kDepthwiseConvScratchWorkspaceSize = 10 * 10 * 64; +constexpr int kDepthwiseConvAdjustedBiasLimit = 64; +// In cases such as depth multiplication, we want to be able to load data from +// the workspace that is beyond the valid range. Macro-block sizes are adjusted +// to allow for this. +constexpr int kWorkspaceExtension = 16; + +#ifdef USE_NEON +// Lane operations are for clarity and convenience. We want to load and store +// 4 8-bit lanes together. So these are treated much like 32-bit loads and +// 32-bit stores. Stores require 32-bit alignment. 
+ +#define vst1_lane_8x4(dst, reg, lane_num) \ + TFLITE_DCHECK_EQ(reinterpret_cast(dst) % 4, 0); \ + vst1_lane_u32(reinterpret_cast(dst), reg, lane_num) +#define vst1q_lane_8x4(dst, reg, lane_num) \ + TFLITE_DCHECK_EQ(reinterpret_cast(dst) % 4, 0); \ + vst1q_lane_u32(reinterpret_cast(dst), reg, lane_num) + +#define vld1q_lane_s8x8(src, reg, lane_num) \ + vld1q_lane_u64(reinterpret_cast(src), reg, lane_num) +#define vld1_lane_8x4(src, reg, lane_num) \ + vld1_lane_s32(reinterpret_cast(src), reg, lane_num) +#define vld1q_lane_8x4(src, reg, lane_num) \ + vld1q_lane_s32(reinterpret_cast(src), reg, lane_num) +#define vld1q_dup_s8x4(src) vld1q_dup_s32(reinterpret_cast(src)) + +#ifndef __aarch64__ +inline int8x16_t vqtbl4q_s8(int8x16x4_t a, uint8x16_t b) { + const uint8x16_t mask = vtstq_u8(b, vdupq_n_u8(8)); + + // Delete bit 3 from the indices. + const uint8x16_t high_bits = vshrq_n_u8(b, 4); + uint8x16_t deleted_bit_3 = b; + deleted_bit_3 = vsliq_n_u8(deleted_bit_3, high_bits, 3); + + int8x8x4_t repacked_data; + + // Calculate for lower indices. + repacked_data.val[0] = vget_low_u8(a.val[0]); + repacked_data.val[1] = vget_low_u8(a.val[1]); + repacked_data.val[2] = vget_low_u8(a.val[2]); + repacked_data.val[3] = vget_low_u8(a.val[3]); + const int8x16_t output_for_lower = + vcombine_u8(vtbl4_s8(repacked_data, vget_low_u8(deleted_bit_3)), + vtbl4_s8(repacked_data, vget_high_u8(deleted_bit_3))); + + // Calculate for high indices. + repacked_data.val[0] = vget_high_u8(a.val[0]); + repacked_data.val[1] = vget_high_u8(a.val[1]); + repacked_data.val[2] = vget_high_u8(a.val[2]); + repacked_data.val[3] = vget_high_u8(a.val[3]); + const int8x16_t output_for_higher = + vcombine_u8(vtbl4_s8(repacked_data, vget_low_u8(deleted_bit_3)), + vtbl4_s8(repacked_data, vget_high_u8(deleted_bit_3))); + + // Merge. + int8x16_t output = mask; + output = vbslq_u8(output, output_for_higher, output_for_lower); + return output; +} +#endif // !__aarch64__ + +// Convenience-compatibility functions. 
+// Compatibility: Intrinsics reflect a mixture of older and newer ARM +// instructions. This actually results in ZIP1 / ZIP2 asm instructions, but +// one intrinsic is provided. Also older instructions operated in place, +// and it seems more defensive to assume that some versions of intrinsics +// might reflect this +// Convenience: Callers in these kernels want both ZIP1 and ZIP2, and we do not +// want the calling code to get cluttered with unpacking int8x16x2_t. +inline void vzipq_s8_in_place(int8x16_t* a, int8x16_t* b) { + int8x16x2_t r8x16; + r8x16 = vzipq_s8(*a, *b); + *a = r8x16.val[0]; + *b = r8x16.val[1]; +} + +inline void vzipq_s8x2_in_place(int8x16_t* a, int8x16_t* b) { + int16x8x2_t r16x8; + r16x8 = vzipq_s16(vreinterpretq_s16_s8(*a), vreinterpretq_s16_s8(*b)); + *a = vreinterpretq_s8_s16(r16x8.val[0]); + *b = vreinterpretq_s8_s16(r16x8.val[1]); +} + +// Similar rationale to the zip-in_place functions, but callers only actually +// need the TRN1 asm instruction result. +inline void vtrn1_s8x2_in_place(int8x16_t* a, int8x16_t* b) { + int16x8x2_t r16x8; + r16x8 = vtrnq_s16(vreinterpretq_s16_s8(*a), vreinterpretq_s16_s8(*b)); + *a = vreinterpretq_s8_s16(r16x8.val[0]); +} + +inline void biregister_rotate_8(int8x16_t* left, int8x16_t* right) { + *left = vreinterpretq_s8_u32(vshrq_n_u32(vreinterpretq_u32_s8(*left), 8)); + *left = vreinterpretq_s8_u32(vsliq_n_u32(vreinterpretq_u32_s8(*left), + vreinterpretq_u32_s8(*right), 24)); + *right = vreinterpretq_s8_u32(vshrq_n_u32(vreinterpretq_u32_s8(*right), 8)); +} + +#ifndef __aarch64__ +inline int32x4_t vpaddq_s32(int32x4_t a, int8x16_t b) { + int32x4x2_t deinterleaved = vuzpq_s32(a, b); + return vqaddq_s32(deinterleaved.val[0], deinterleaved.val[1]); +} +#endif // !__aarch64__ + +#ifdef __ARM_FEATURE_DOTPROD +// The vdotq_lane_s32 takes int8x8t for the rhs parameter, whereas the actual +// instruction selects from between 4 32-bit (4x8-bit packed) sub-registers, an +// unusual interpretation of "lane". 
+inline int32x4_t vdotq_four_lane_s32(int32x4_t acc, int8x16_t lhs, + int8x16_t rhs, const int lane) { + switch (lane) { + case 0: + return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_low_s8(rhs)), 0); + case 1: + return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_low_s8(rhs)), 1); + case 2: + return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_high_s8(rhs)), + 0); + case 3: + default: + return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_high_s8(rhs)), + 1); + } +} + +#else + +inline int32x4_t vdotq_s32(int32x4_t acc, int8x16_t lhs, int8x16_t rhs) { + int32x4_t sum0 = vpaddlq_s16(vmull_s8(vget_low_s8(lhs), vget_low_s8(rhs))); + int32x4_t sum1 = vpaddlq_s16(vmull_s8(vget_high_s8(lhs), vget_high_s8(rhs))); + int32x4_t sum = vpaddq_s32(sum0, sum1); + return vaddq_s32(acc, sum); +} + +inline int32x4_t vdotq_four_lane_s32(int32x4_t acc, int8x16_t lhs, + int8x16_t rhs, int lane) { + int8x8_t lane_rhs; + if (lane == 0) { + lane_rhs = vreinterpret_s8_s32( + vdup_lane_s32(vreinterpret_s32_s8(vget_low_s8(rhs)), 0)); + } else if (lane == 1) { + lane_rhs = vreinterpret_s8_s32( + vdup_lane_s32(vreinterpret_s32_s8(vget_low_s8(rhs)), 1)); + } else if (lane == 2) { + lane_rhs = vreinterpret_s8_s32( + vdup_lane_s32(vreinterpret_s32_s8(vget_high_s8(rhs)), 0)); + } else { + lane_rhs = vreinterpret_s8_s32( + vdup_lane_s32(vreinterpret_s32_s8(vget_high_s8(rhs)), 1)); + } + int32x4_t sum0 = vpaddlq_s16(vmull_s8(vget_low_s8(lhs), lane_rhs)); + int32x4_t sum1 = vpaddlq_s16(vmull_s8(vget_high_s8(lhs), lane_rhs)); + int32x4_t sum = vpaddq_s32(sum0, sum1); + return vaddq_s32(acc, sum); +} + +#endif // !DOTPROD +#endif // ARM NEON + +template +struct DivideByPOT {}; + +template <> +struct DivideByPOT { + template + static inline IntegerType Run(IntegerType x, int exponent) { + return RoundingDivideByPOT(x, exponent); + } +}; + +#ifdef USE_NEON +template <> +struct DivideByPOT { + template + static inline IntegerType Run(IntegerType x, int exponent) { + return 
vqrshlq_s32(x, vdupq_n_s32(static_cast(-exponent))); + } +}; +#endif // ARM NEON // See CategorizeDotProductKernel for definitive taxonomy. enum class DotProduct3x3KernelType { kNone = 0, // Parameter combination is not supported for dot product kernels. kPlain, - kWithDepthMultiplication, - kWithPad0Stride2, - kWithPad1Stride1, + kWithDepthMultiplicationStride1, + kWithDepthMultiplicationStride2, + kStride2, }; inline DotProduct3x3KernelType CategorizeDotProductKernel( + const RuntimeShape& input_shape, const RuntimeShape& filter_shape, const DepthwiseParams& params) { - const int padding = params.padding_values.width; + constexpr int kSymmetricZeroPoint = 128; + const int padding = + std::max(params.padding_values.width, params.padding_values.height); const int stride = params.stride_width; - if (padding != params.padding_values.height || - stride != params.stride_height) { + const int32 input_depth = input_shape.Dims(3); + const int32 depth_multiplier = params.depth_multiplier; + const int32 filter_height = filter_shape.Dims(1); + const int32 filter_width = filter_shape.Dims(2); + + bool supported = + params.weights_offset == -kSymmetricZeroPoint && + stride == params.stride_height && stride <= 2 && padding <= 1 && + filter_width == 3 && filter_height == 3 && params.output_shift <= 0 && + params.dilation_width_factor == 1 && params.dilation_height_factor == 1 && + (((input_depth % 8) == 0 && depth_multiplier == 1) || + (input_depth == 1 && depth_multiplier > 1)); + + if (!supported) { return DotProduct3x3KernelType::kNone; } if (params.depth_multiplier == 1) { - if (padding == 0 && stride == 1) { + if (stride == 1) { return DotProduct3x3KernelType::kPlain; - } else if (padding == 0 && stride == 2) { - return DotProduct3x3KernelType::kWithPad0Stride2; - } else if (padding == 1 && stride == 1) { - return DotProduct3x3KernelType::kWithPad1Stride1; + } else if (stride == 2) { + return DotProduct3x3KernelType::kStride2; } else { return DotProduct3x3KernelType::kNone; 
} } else { - if (padding == 0 && stride == 1) { - return DotProduct3x3KernelType::kWithDepthMultiplication; + if (stride == 1) { + return DotProduct3x3KernelType::kWithDepthMultiplicationStride1; + } else if (stride == 2) { + return DotProduct3x3KernelType::kWithDepthMultiplicationStride2; } else { return DotProduct3x3KernelType::kNone; } } } +#ifdef USE_NEON + +#define STR(s) STR_UNEXPANDED(s) +#define STR_UNEXPANDED(s) #s + // Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on // Jetson TX-2. This compiler does not support the offsetof() macro. #if defined(__aarch64__) && !defined(GOOGLE_L4T) #include -#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64 - // Encapsulates constant parameters used in DepthwiseConv. // 64-bit is used for types that will be added to 64-bit addresses in asm. struct DepthwiseConvParams { @@ -90,9 +295,6 @@ struct DepthwiseConvParams { int32 output_height; }; -#define STR(s) STR_UNEXPANDED(s) -#define STR_UNEXPANDED(s) #s - // Represents the number of bytes offset from the start of the // DepthwiseConvParams struct. This is used in the asm to load parameters. // Keep these values in sync with the static_asserts below. @@ -167,7 +369,49 @@ static_assert(offsetof(DepthwiseConvParams, output_width) == static_assert(offsetof(DepthwiseConvParams, output_height) == OFFSET_OUTPUT_HEIGHT, ""); +#endif // __aarch64__ +#endif // ARM NEON + +// Encapsulates constant parameters used in DepthwiseConv using dot-product ops. +// 64-bit is used for types that will be added to 64-bit addresses in asm. +// +// This structure is specifically designed for use in asm. 
+struct DepthwiseConvDotProdParams { + int64_t input_depth; + int64_t output_depth; + int32 workspace_height_stride; + int32 input_width_overall_micro_repeats; + int32 input_width_micro_repeats; + int32 depth_micro_repeats; + int32 inbound_block_height; + int32 residual_width; + int32 input_height_stride; + int32 stride; + int32 output_width_overall_micro_repeats; + int32 output_width_micro_repeats; + int32 output_residual_width; + int32 output_height_stride; + int32 bias_increment; + int32 padding_left; + int32 padding_right; + int32 padding_top; + int32 padding_bottom; + int32 height_macro_count; + int32 width_macro_count; + int32 outbound_block_height; + int32 workspace_width_micro_repeats; + int32 input_offset; + int32 output_offset; + int32 output_multiplier; + int32 output_shift; + int32 quantized_activation_min; + int32 quantized_activation_max; + int32 four_over_stride; +}; + +#ifdef USE_NEON +#if defined(__aarch64__) && !defined(GOOGLE_L4T) template struct DepthwiseConvWindow {}; @@ -229,7 +473,8 @@ struct DepthwiseConvWindow<8, 1, 1> { // Set "constant" registers. These registers may be replaced with temp // values from time to time when there are not enough NEON registers. // We use x9--x15 general purpose registers as they are caller-saved - // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT + // temporary registers (see + // http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" "ldr x3, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" "cmp %w[output_window_height], #2\n" @@ -2963,8 +3208,6 @@ struct DepthwiseConvPartial { #undef OFFSET_INPUT_HEIGHT #undef OFFSET_OUTPUT_WIDTH #undef OFFSET_OUTPUT_HEIGHT -#undef STR -#undef STR_UNEXPANDED // Copies a subset of the input designated by |input_ptr| into |output_ptr| // with the specified output dimensions. 
Supports output depths of 64 only as @@ -3045,9 +3288,9 @@ struct DepthwiseConvMultiRow { TFLITE_DCHECK( shuffle_params.input_width == get_shuffle_input_size(kStrideWidth, shuffle_params.output_width)); - TFLITE_DCHECK(64 * shuffle_params.input_width * - shuffle_params.input_height <= - DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE); + TFLITE_DCHECK_LE( + 64 * shuffle_params.input_width * shuffle_params.input_height, + kDepthwiseConvScratchWorkspaceSize); int32 out_x = start_x; @@ -3375,7 +3618,7 @@ inline void DepthwiseConv3x3Filter( // allocated on the stack. Eventually we will want to move it to the heap // and have it allocated outside of this function, like the im2col_array // used in gemmlowp. - uint8 shuffle_workspace[DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE]; + uint8 shuffle_workspace[kDepthwiseConvScratchWorkspaceSize]; for (int32 b = 0; b < batches; ++b) { const uint8* input_ptr = input_data + b * input_batch_size; @@ -3454,9 +3697,3840 @@ inline void DepthwiseConv3x3Filter( } } } - #endif // __aarch64__ +#endif + +// Permute filter data, and adjust bias data to account for symmetric input +// offset. Details are provided in the implementation of the +// kUseCModel3x3DotProduct version. +// +// See the comments preceding DepthwiseConvDotProduct3x3() for further notes. +template +struct ProcessPerDepth { + // Routine is contained in a static Run() method. No default template version + // is supplied, so that all implementations are deliberate choices of template + // specialization. + // + // Note that the signature of the Run() method will be designed for the asm + // implementation rather than conforming to style. +}; + +// Copy a macro block of data from the input buffer into the workspace, +// permuting data within each micro block. +// +// (a) Copy a macro block of data, padding as required along the width and +// height. +// (b) Transpose the data within each micro block. +// +// See the comments preceding DepthwiseConvDotProduct3x3() for further notes. 
+template +struct PackMacroBlock { + // Routine is contained in a static Run() method. No default template version + // is supplied, so that all implementations are deliberate choices of template + // specialization. + // + // Note that the signature of the Run() method will be designed for the asm + // implementation rather than conforming to style. +}; + +// Apply filter to macro block of input data and store results. Details are +// provided in the implementation of the kUseCModel3x3DotProduct version. +// +// Parameters for repeats and residual sizes are in terms of outputs. +// +// See the comments preceding DepthwiseConvDotProduct3x3() for further notes. +template +struct KernelMacroBlock { + // Routine is contained in a static Run() method. No default template version + // is supplied, so that all implementations are deliberate choices of template + // specialization. + // + // Note that the signature of the Run() method will be designed for the asm + // implementation rather than conforming to style. +}; + +#if defined(USE_NEON) && defined(__aarch64__) +// Experiments suggest that a modest performance improvement is seen, at least +// on 855 chipset big cores, with cache hints. +inline void PreloadInputBlock( + const uint8* input_block_data, + const DepthwiseConvDotProdParams* function_params) { + // Preload. + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + const int input_depth = function_params->input_depth; + + { + const int total_width = 4 * input_width_micro_repeats + residual_width; + const uint8* row_ptr = input_block_data; + for (int k_height = 0; k_height < block_height; ++k_height) { + const uint8* ptr = row_ptr; + for (int j = 0; j < total_width; ++j) { + // Input data is loaded once. 
+ asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :); + ptr += input_depth; + } + row_ptr += input_height_stride; + } + } +} + +template <> +struct ProcessPerDepth { + static void ProcessPerDepthIntrinsics( + const uint8* filter_data, const int32* bias_data, + int8* shuffled_filter_data, int32* adjusted_bias_data, + const DepthwiseConvDotProdParams* function_params) { + const int depth = function_params->output_depth; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int bias_increment = function_params->bias_increment; + + constexpr int kSymmetricZeroPoint = 128; + constexpr uint8 kSignBit = 0x80; + const int32 input_offset = function_params->input_offset; + TFLITE_DCHECK_GE(input_offset, -255); + TFLITE_DCHECK_LE(input_offset, 0); + const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + const int8x16_t ones_vector = vdupq_n_s8(1); + + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + + // Register pairs for each height. + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + + const uint8* filter_block = filter_data; + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + // Filter data is provided as filter_block[3][3][depth/8][2][4]. + // height 3, width 3, micro-blocks, sub-block 0 or 1, depth 4. + // filter_bank[3][2][4][4]; Sub-block, height 3, depth 4, width 4. + + // Load zero-point into effective position of zero-padding of filter + // (register B, upper part). 
+ filter_reg_0_b = vdupq_n_u8(kSignBit); + filter_reg_1_b = vdupq_n_u8(kSignBit); + filter_reg_2_b = vdupq_n_u8(kSignBit); + + const uint8* filter_block_ptr = filter_block; + filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 0); + filter_block_ptr += depth; + filter_reg_0_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_b, 0); + filter_block_ptr += depth; + filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 1); + filter_block_ptr += depth; + filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 0); + filter_block_ptr += depth; + filter_reg_1_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_b, 0); + filter_block_ptr += depth; + filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 1); + filter_block_ptr += depth; + filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 0); + filter_block_ptr += depth; + filter_reg_2_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_b, 0); + filter_block_ptr += depth; + filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 1); + + filter_reg_0_a = veorq_s8(filter_reg_0_a, sign_bit); + filter_reg_0_b = veorq_s8(filter_reg_0_b, sign_bit); + filter_reg_1_a = veorq_s8(filter_reg_1_a, sign_bit); + filter_reg_1_b = veorq_s8(filter_reg_1_b, sign_bit); + filter_reg_2_a = veorq_s8(filter_reg_2_a, sign_bit); + filter_reg_2_b = veorq_s8(filter_reg_2_b, sign_bit); + + vzipq_s8_in_place(&filter_reg_0_a, &filter_reg_0_b); + vzipq_s8_in_place(&filter_reg_1_a, &filter_reg_1_b); + vzipq_s8_in_place(&filter_reg_2_a, &filter_reg_2_b); + vzipq_s8x2_in_place(&filter_reg_0_a, &filter_reg_0_b); + vzipq_s8x2_in_place(&filter_reg_1_a, &filter_reg_1_b); + vzipq_s8x2_in_place(&filter_reg_2_a, &filter_reg_2_b); + + vst1q_s8(shuffled_filter_data, filter_reg_0_a); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_0_b); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_1_a); + shuffled_filter_data += 16; + 
vst1q_s8(shuffled_filter_data, filter_reg_1_b); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_2_a); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_2_b); + shuffled_filter_data += 16; + + int32x4_t adjusted_bias_data_a = vld1q_s32(bias_data); + bias_data += bias_increment; + int32x4_t adjusted_bias_data_b = vld1q_s32(bias_data); + bias_data += bias_increment; + // For instance, if input_offset == 128, no adjustment is needed. + + int32x4_t filter_sum_a = vdupq_n_s32(0); + filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_0_a, ones_vector); + filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_1_a, ones_vector); + filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_2_a, ones_vector); + int32x4_t filter_sum_b = vdupq_n_s32(0); + filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_0_b, ones_vector); + filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_1_b, ones_vector); + filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_2_b, ones_vector); + + adjusted_bias_data_a = vmlaq_n_s32(adjusted_bias_data_a, filter_sum_a, + input_offset_difference); + adjusted_bias_data_b = vmlaq_n_s32(adjusted_bias_data_b, filter_sum_b, + input_offset_difference); + + vst1q_s32(adjusted_bias_data, adjusted_bias_data_a); + adjusted_bias_data += 4; + vst1q_s32(adjusted_bias_data, adjusted_bias_data_b); + adjusted_bias_data += 4; + + filter_block += 8; + } + } + + static inline void Run(const uint8* filter_data, const int32* bias_data, + int8* shuffled_filter_data, int32* adjusted_bias_data, + const DepthwiseConvDotProdParams* function_params) { + ProcessPerDepthIntrinsics(filter_data, bias_data, shuffled_filter_data, + adjusted_bias_data, function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockNeon( + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + TFLITE_DCHECK_EQ(function_params->padding_bottom, 0); + 
TFLITE_DCHECK_EQ(function_params->padding_top, 0); + TFLITE_DCHECK_EQ(function_params->padding_left, 0); + TFLITE_DCHECK_EQ(function_params->padding_right, 0); + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + const int input_depth = function_params->input_depth; + + static const uint8 perm_data[64] = { + 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, // + 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55, + 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59, + 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63}; + + TFLITE_DCHECK_GE(depth_micro_repeats, 0); + constexpr uint8 kSignBit = 0x80; + const int micro_block_size = 4 * 8; + const int depth_advance = width_overall_micro_repeats * micro_block_size; + const int width_advance = + micro_block_size * + (1 - depth_micro_repeats * width_overall_micro_repeats); + const int height_advance = workspace_height_stride - + width_overall_micro_repeats * micro_block_size; + const int input_depth_skip = 4 * input_depth - 8 * depth_micro_repeats; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. 
+ int8x16_t work_reg_a; + int8x16_t work_reg_b; + const int8x16_t perm_data_0 = vld1q_u8(perm_data); + const int8x16_t perm_data_1 = vld1q_u8(perm_data + 16); + const int8x16_t perm_data_2 = vld1q_u8(perm_data + 32); + const int8x16_t perm_data_3 = vld1q_u8(perm_data + 48); + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + + // Work through one slice, by row, at a time. + int8* scratch_data_0 = scratch_block_data; + + for (int k_height = 0; k_height < block_height; ++k_height) { + const uint8* input_data_0 = input_block_data; + const uint8* input_data_1 = input_block_data + input_depth; + const uint8* input_data_2 = input_block_data + 2 * input_depth; + const uint8* input_data_3 = input_block_data + 3 * input_depth; + + // Traverse the width one point at a time, but the depth in (micro) blocks + // of size 8. + // + // The depth and width margins, which are filled with "zeros", may be + // larger than is strictly needed to calculate output. This is because the + // conv calculation is performed across complete micro blocks. 
+ for (int j_width = 0; j_width < input_width_micro_repeats; ++j_width) { + int i_depth = 0; + for (; i_depth < depth_micro_repeats - 1; i_depth += 2) { + int8x16x4_t input_data; + input_data.val[0] = vld1q_u8(input_data_0); + input_data.val[1] = vld1q_u8(input_data_1); + input_data.val[2] = vld1q_u8(input_data_2); + input_data.val[3] = vld1q_u8(input_data_3); + input_data_1 += 16; + input_data_0 += 16; + + int8x16_t tmp_0 = vqtbl4q_s8(input_data, perm_data_0); + int8x16_t tmp_1 = vqtbl4q_s8(input_data, perm_data_1); + work_reg_a = veorq_s8(tmp_0, sign_bit); + work_reg_b = veorq_s8(tmp_1, sign_bit); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_2 += 16; + input_data_3 += 16; + + tmp_0 = vqtbl4q_s8(input_data, perm_data_2); + tmp_1 = vqtbl4q_s8(input_data, perm_data_3); + work_reg_a = veorq_s8(tmp_0, sign_bit); + work_reg_b = veorq_s8(tmp_1, sign_bit); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + } + for (; i_depth < depth_micro_repeats; ++i_depth) { + int8x16x4_t input_data; + input_data.val[0] = + vld1q_lane_s8x8(input_data_0, input_data.val[0], 0); + input_data.val[1] = + vld1q_lane_s8x8(input_data_1, input_data.val[1], 0); + input_data.val[2] = + vld1q_lane_s8x8(input_data_2, input_data.val[2], 0); + input_data.val[3] = + vld1q_lane_s8x8(input_data_3, input_data.val[3], 0); + input_data_1 += 8; + input_data_0 += 8; + + int8x16_t tmp_0 = vqtbl4q_s8(input_data, perm_data_0); + int8x16_t tmp_1 = vqtbl4q_s8(input_data, perm_data_1); + work_reg_a = veorq_s8(tmp_0, sign_bit); + work_reg_b = veorq_s8(tmp_1, sign_bit); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + 
input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + if (width_overall_micro_repeats > input_width_micro_repeats) { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, + input_width_micro_repeats + 1); + TFLITE_DCHECK_GT(residual_width, 0); + TFLITE_DCHECK_LT(residual_width, 4); + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(kSignBit); + work_reg_a = vld1q_lane_s8x8(input_data_0, work_reg_a, 0); + work_reg_b = vdupq_n_u8(kSignBit); + if (residual_width > 1) { + work_reg_b = + vld1q_lane_s8x8(input_data_0 + input_depth, work_reg_b, 0); + if (residual_width == 3) { + work_reg_a = vld1q_lane_s8x8(input_data_0 + 2 * input_depth, + work_reg_a, 1); + } + } + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_0 += 8; + input_data_1 += 8; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + scratch_data_0 += height_advance; + input_block_data += input_height_stride; + } + TFLITE_DCHECK_EQ( + scratch_data_0, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + PreloadInputBlock(input_block_data, function_params); + PackMacroBlockNeon(input_block_data, scratch_block_data, function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockNeon( + int32 height_block_number, int32 width_block_number, + const uint8* 
input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + constexpr uint8 kSignBit = 0x80; + + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + const int input_depth = function_params->input_depth; + + const int padding_left = function_params->padding_left; + const int padding_right = function_params->padding_right; + const int padding_top = function_params->padding_top; + const int padding_bottom = function_params->padding_bottom; + + TFLITE_DCHECK_GT(depth_micro_repeats, 0); + constexpr int kSymmetricZeroPoint = 128; + + const int micro_block_size = 4 * 8; + const int depth_advance = width_overall_micro_repeats * micro_block_size; + const int width_advance = + micro_block_size * + (1 - depth_micro_repeats * width_overall_micro_repeats); + const int height_advance = workspace_height_stride - + width_overall_micro_repeats * micro_block_size; + const int input_depth_skip = 4 * input_depth - 8 * depth_micro_repeats; + + const bool leading_width_padding = + padding_left > 0 && width_block_number == 0; + const bool trailing_width_padding = + padding_right > 0 && + width_block_number == (function_params->width_macro_count - 1); + const bool leading_height_padding = + padding_top > 0 && height_block_number < 0; + const bool trailing_height_padding = + padding_bottom > 0 && + height_block_number == (function_params->height_macro_count - 1); + + const int32 input_offset = function_params->input_offset; + const int32 input_offset_difference = 
input_offset + kSymmetricZeroPoint; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg_a; + int8x16_t work_reg_b; + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + + // Work through one slice, by row, at a time. + int8* scratch_data_0 = scratch_block_data; + + int copy_block_height = block_height; + if (leading_height_padding) { + copy_block_height -= 1; + memset(scratch_data_0, -input_offset_difference, workspace_height_stride); + scratch_data_0 += workspace_height_stride; + input_block_data += input_height_stride; + } + if (trailing_height_padding) { + copy_block_height -= 1; + } + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + const uint8* input_data_0 = input_block_data; + const uint8* input_data_1 = input_block_data + input_depth; + const uint8* input_data_2 = input_block_data + 2 * input_depth; + const uint8* input_data_3 = input_block_data + 3 * input_depth; + + // Traverse the width one point at a time, but the depth in (micro) blocks + // of size 8. + // + // The depth and width margins, which are filled with "zeros", may be + // larger than is strictly needed to calculate output. This is because the + // conv calculation is performed across complete micro blocks. + for (int j_width = 0; j_width < width_overall_micro_repeats; ++j_width) { + // Figure out division of work (available input vs zero-ed). + int adjusted_residual_width = + j_width == (input_width_micro_repeats) ? residual_width : 4; + + if (trailing_width_padding && + j_width == (width_overall_micro_repeats - 1)) { + adjusted_residual_width -= 1; + } + int start_width = 0; + if (leading_width_padding && j_width == 0) { + start_width = 1; + } + if (start_width == 0) { + if (adjusted_residual_width == 4) { + // Load, then zero. 
+ for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vld1q_lane_s8x8(input_data_2, work_reg_a, 1); + work_reg_b = vld1q_lane_s8x8(input_data_3, work_reg_b, 1); + work_reg_b = vld1q_lane_s8x8(input_data_1, work_reg_b, 0); + input_data_1 += 8; + work_reg_a = vld1q_lane_s8x8(input_data_0, work_reg_a, 0); + input_data_0 += 8; + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + scratch_data_0 += 16; + vst1q_s8(scratch_data_0, work_reg_b); + + scratch_data_0 += depth_advance - 16; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } else { + TFLITE_DCHECK_LT(adjusted_residual_width, 4); + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(-input_offset); + work_reg_b = vdupq_n_u8(-input_offset); + if (adjusted_residual_width > 0) { + work_reg_a = vld1q_lane_s8x8(input_data_0, work_reg_a, 0); + if (adjusted_residual_width > 1) { + work_reg_b = vld1q_lane_s8x8(input_data_0 + input_depth, + work_reg_b, 0); + if (adjusted_residual_width == 3) { + work_reg_a = vld1q_lane_s8x8(input_data_0 + 2 * input_depth, + work_reg_a, 1); + } + } + } + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_0 += 8; + input_data_1 += 8; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 
+= input_depth_skip; + input_data_3 += input_depth_skip; + } + } else { + if (adjusted_residual_width == 4) { + // Load, then zero. + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(-input_offset); + work_reg_a = vld1q_lane_s8x8(input_data_2, work_reg_a, 1); + work_reg_b = vld1q_lane_s8x8(input_data_3, work_reg_b, 1); + work_reg_b = vld1q_lane_s8x8(input_data_1, work_reg_b, 0); + input_data_1 += 8; + // Skip loading first column. + input_data_0 += 8; + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + scratch_data_0 += 16; + vst1q_s8(scratch_data_0, work_reg_b); + + scratch_data_0 += depth_advance - 16; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } else { + TFLITE_DCHECK_LT(adjusted_residual_width, 4); + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(-input_offset); + // Skip loading first column. 
+ work_reg_b = vdupq_n_u8(-input_offset); + if (adjusted_residual_width > 1) { + work_reg_b = + vld1q_lane_s8x8(input_data_0 + input_depth, work_reg_b, 0); + if (adjusted_residual_width == 3) { + work_reg_a = vld1q_lane_s8x8(input_data_0 + 2 * input_depth, + work_reg_a, 1); + } + } + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_0 += 8; + input_data_1 += 8; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + } + } + scratch_data_0 += height_advance; + input_block_data += input_height_stride; + } + + if (trailing_height_padding) { + memset(scratch_data_0, -input_offset_difference, workspace_height_stride); + scratch_data_0 += workspace_height_stride; + } + + TFLITE_DCHECK_EQ( + scratch_data_0, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + PreloadInputBlock(input_block_data, function_params); + PackMacroBlockNeon(height_block_number, width_block_number, + input_block_data, scratch_block_data, function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockNeon( + int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + 
function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + + const int padding_left = function_params->padding_left; + const int padding_right = function_params->padding_right; + const int padding_top = function_params->padding_top; + const int padding_bottom = function_params->padding_bottom; + + constexpr int kSymmetricZeroPoint = 128; + + TFLITE_DCHECK_GE(workspace_height_stride, 4 * width_overall_micro_repeats); + + const bool leading_width_padding = + padding_left > 0 && width_block_number == 0; + const bool trailing_width_padding = + padding_right > 0 && + width_block_number == (function_params->width_macro_count - 1); + const bool leading_height_padding = + padding_top > 0 && height_block_number < 0; + const bool trailing_height_padding = + padding_bottom > 0 && + height_block_number == (function_params->height_macro_count - 1); + + const int32 input_offset = function_params->input_offset; + const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + + // Work through one slice, by row, at a time. + int8* scratch_data_base = scratch_block_data; + + int copy_block_height = block_height; + if (leading_height_padding) { + copy_block_height -= 1; + memset(scratch_data_base, -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + scratch_data_base += workspace_height_stride; + input_block_data += input_height_stride; + } + if (trailing_height_padding) { + copy_block_height -= 1; + } + + int adjusted_residual_width = + input_width_micro_repeats < width_overall_micro_repeats ? 
residual_width + : 4; + + if (trailing_width_padding) { + adjusted_residual_width -= 1; + } + int start_width = 0; + if (leading_width_padding) { + start_width = 1; + input_block_data += 1; + } + + const int copy_size = (width_overall_micro_repeats - 1) * 4 + + adjusted_residual_width - start_width; + // Adjusted so that later conditionals are simplified. + const int copy_size_adjusted = + trailing_width_padding ? copy_size + 1 : copy_size; + + TFLITE_DCHECK_LE( + copy_size, + input_height_stride - width_block_number * input_width_micro_repeats); + // We may drop up to stride-1 of trailing input. + TFLITE_DCHECK_GE(copy_size, input_height_stride - 1); + + int scratch_data_offset = 0; + int input_block_offset = 0; + + constexpr uint8 kSignBit = 0x80; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg; + int8x8_t half_work_reg; + int8x8_t padding_mask; + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + const uint8x16_t padding_reg = vdupq_n_u8(-input_offset); + padding_mask = vdup_n_s8(-1); + half_work_reg = vdup_n_s8(0); + + if (copy_size >= 16) { + const int copy_remaining = (copy_size + start_width) & 0x7; + padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining)); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // The surrounding condition ensures that we always need at least one + // iteration of the main copy loop. In the case of leading width + // padding, we unroll this specially. 
+ if (leading_width_padding) { + work_reg = vld1q_u8(input_block_data + input_block_offset); + work_reg = vextq_s8(padding_reg, work_reg, 15); + work_reg = veorq_s8(work_reg, sign_bit); + vst1q_s8(scratch_data, work_reg); + copy_done += 15; + } + + // Main copy loop. + for (; (copy_done + 16) <= copy_size; copy_done += 16) { + work_reg = + vld1q_u8(input_block_data + input_block_offset + copy_done); + work_reg = veorq_s8(work_reg, sign_bit); + TFLITE_DCHECK_EQ((start_width + copy_done) % 16, 0); + vst1q_s8(scratch_data + start_width + copy_done, work_reg); + } + + if (copy_done + 8 <= copy_size) { + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_done); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0); + vst1_s8(scratch_data + start_width + copy_done, half_work_reg); + copy_done += 8; + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size_adjusted) { + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_size - 8); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining))); + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0); + vst1_s8(scratch_data + start_width + copy_done, half_work_reg); + } + + // Trailing guard. 
+ vst1_s8(scratch_data + start_width + copy_done, half_work_reg); + vst1_s8(scratch_data + start_width + copy_done + 8, half_work_reg); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (copy_size >= 4) { + const int copy_remaining = (copy_size + start_width) & 0x3; + padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining)); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // The surrounding condition ensures that we always need at least one + // iteration of the main copy loop. In the case of leading width + // padding, we unroll this specially. + if (leading_width_padding) { + half_work_reg = vld1_lane_8x4(input_block_data + input_block_offset, + half_work_reg, 0); + half_work_reg = vext_s8(vget_low_s8(padding_reg), half_work_reg, 7); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + vst1_lane_8x4(scratch_data, half_work_reg, 0); + copy_done += 3; + } + + // Main copy loop. + for (; (copy_done + 4) <= copy_size; copy_done += 4) { + // Important! Most compilation configurations will compile and run + // without the reinterpret_cast. Sanitizers may fail silently on + // lane-loading, with a obscure bug or mis-feature probably in + // unhygienic macro expansion. + half_work_reg = + vld1_lane_8x4(input_block_data + input_block_offset + copy_done, + half_work_reg, 0); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 4, 0); + vst1_lane_8x4(scratch_data + start_width + copy_done, half_work_reg, + 0); + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. 
+ // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size_adjusted) { + TFLITE_DCHECK_LT(copy_remaining, 4); + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = vld1_lane_8x4( + input_block_data + input_block_offset + copy_size - 4, + half_work_reg, 0); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (4 - copy_remaining))); + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 4, 0); + vst1_lane_8x4(scratch_data + start_width + copy_done, half_work_reg, + 0); + copy_done += 4; + } + // Trailing guard. + vst1_lane_8x4(scratch_data + start_width + copy_done, half_work_reg, 0); + vst1_lane_8x4(scratch_data + start_width + copy_done + 4, half_work_reg, + 0); + vst1_lane_8x4(scratch_data + start_width + copy_done + 8, half_work_reg, + 0); + vst1_lane_8x4(scratch_data + start_width + copy_done + 12, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (width_overall_micro_repeats == 2) { + // Special case of 1 + 3 + 1, padding + copy + padding. + // This is rarely executed in practice. + TFLITE_DCHECK_EQ(copy_size, 3); + TFLITE_DCHECK_EQ(start_width, 1); + TFLITE_DCHECK(leading_width_padding); + TFLITE_DCHECK(trailing_width_padding); + // ASM should use MOVI 64-bit set. 
+ padding_mask = vcreate_u64(~0xffffff00L); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + half_work_reg = vld1_lane_s8(reinterpret_cast( + input_block_data + input_block_offset), + half_work_reg, 1); + half_work_reg = + vld1_lane_s8(reinterpret_cast(input_block_data + + input_block_offset + 1), + half_work_reg, 2); + half_work_reg = + vld1_lane_s8(reinterpret_cast(input_block_data + + input_block_offset + 2), + half_work_reg, 3); + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(scratch_data_offset % 8, 0); + vst1_s8(scratch_data_base + scratch_data_offset, half_work_reg); + + // Trailing guard. + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 4, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 8, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 12, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 16, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, 1); + const int copy_remaining = (copy_size + start_width) & 0x3; + padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining)); + if (leading_width_padding) { + padding_mask = vset_lane_u8(255, padding_mask, 0); + } + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + for (int i = 0; i < copy_size; ++i) { + half_work_reg = vshl_n_u64(half_work_reg, 8); + half_work_reg = vld1_lane_s8( + reinterpret_cast( + input_block_data + input_block_offset + copy_size - 1 - i), + half_work_reg, 0); + } + if (leading_width_padding) { + half_work_reg = vshl_n_s64(half_work_reg, 8); + } + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); 
+ TFLITE_DCHECK_EQ(scratch_data_offset % 4, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset, half_work_reg, + 0); + + // Trailing guard. + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 4, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 8, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 12, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 16, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } + + scratch_data_base += copy_block_height * workspace_height_stride; + + if (trailing_height_padding) { + memset(scratch_data_base, -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + scratch_data_base += workspace_height_stride; + } + + TFLITE_DCHECK_EQ( + scratch_data_base, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + PreloadInputBlock(input_block_data, function_params); + PackMacroBlockNeon(height_block_number, width_block_number, + input_block_data, scratch_block_data, function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockNeon( + int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int 
input_height_stride = function_params->input_height_stride; + + TFLITE_DCHECK_EQ(function_params->padding_left, 0); + TFLITE_DCHECK_EQ(function_params->padding_right, 0); + TFLITE_DCHECK_EQ(function_params->padding_top, 0); + TFLITE_DCHECK_EQ(function_params->padding_bottom, 0); + + TFLITE_DCHECK_GE(workspace_height_stride, 4 * width_overall_micro_repeats); + + // Work through one slice, by row, at a time. + int8* scratch_data_base = scratch_block_data; + + const int copy_block_height = block_height; + + int adjusted_residual_width = + input_width_micro_repeats < width_overall_micro_repeats ? residual_width + : 4; + + const int copy_size = + (width_overall_micro_repeats - 1) * 4 + adjusted_residual_width; + + TFLITE_DCHECK_LE( + copy_size, + input_height_stride - width_block_number * input_width_micro_repeats); + // We may drop up to stride-1 of trailing input. + TFLITE_DCHECK_GE(copy_size, input_height_stride - 1); + + int scratch_data_offset = 0; + int input_block_offset = 0; + + constexpr uint8 kSignBit = 0x80; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg; + int8x8_t half_work_reg; + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + half_work_reg = vdup_n_s8(0); + + if (copy_size >= 16) { + const int copy_remaining = copy_size & 0x7; + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // Main copy loop. 
+ for (; (copy_done + 16) <= copy_size; copy_done += 16) { + work_reg = + vld1q_u8(input_block_data + input_block_offset + copy_done); + work_reg = veorq_s8(work_reg, sign_bit); + TFLITE_DCHECK_EQ(copy_done % 16, 0); + vst1q_s8(scratch_data + copy_done, work_reg); + } + + if (copy_done + 8 <= copy_size) { + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_done); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 8, 0); + vst1_s8(scratch_data + copy_done, half_work_reg); + copy_done += 8; + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size) { + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_size - 8); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining))); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 8, 0); + vst1_s8(scratch_data + copy_done, half_work_reg); + copy_done += 8; + } + + // Trailing guard. + vst1_s8(scratch_data + copy_done, half_work_reg); + vst1_s8(scratch_data + copy_done + 8, half_work_reg); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (copy_size >= 4) { + const int copy_remaining = copy_size & 0x3; + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. 
+ int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // Main copy loop. + for (; (copy_done + 4) <= copy_size; copy_done += 4) { + // Important! Most compilation configurations will compile and run + // without the reinterpret_cast. Sanitizers may fail silently on + // lane-loading, with a obscure bug or mis-feature probably in + // unhygienic macro expansion. + half_work_reg = + vld1_lane_8x4(input_block_data + input_block_offset + copy_done, + half_work_reg, 0); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 4, 0); + vst1_lane_8x4(scratch_data + copy_done, half_work_reg, 0); + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size) { + TFLITE_DCHECK_LT(copy_remaining, 4); + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = vld1_lane_8x4( + input_block_data + input_block_offset + copy_size - 4, + half_work_reg, 0); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (4 - copy_remaining))); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 4, 0); + vst1_lane_8x4(scratch_data + copy_done, half_work_reg, 0); + copy_done += 4; + } + // Trailing guard. 
+ vst1_lane_8x4(scratch_data + copy_done, half_work_reg, 0); + vst1_lane_8x4(scratch_data + copy_done + 4, half_work_reg, 0); + vst1_lane_8x4(scratch_data + copy_done + 8, half_work_reg, 0); + vst1_lane_8x4(scratch_data + copy_done + 12, half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, 1); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + for (int i = 0; i < copy_size; ++i) { + half_work_reg = vshl_n_u64(half_work_reg, 8); + half_work_reg = vld1_lane_s8( + reinterpret_cast( + input_block_data + input_block_offset + copy_size - 1 - i), + half_work_reg, 0); + } + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(scratch_data_offset % 4, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset, half_work_reg, + 0); + + // Trailing guard. + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 4, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 8, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 12, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 16, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } + + scratch_data_base += copy_block_height * workspace_height_stride; + + TFLITE_DCHECK_EQ( + scratch_data_base, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + PreloadInputBlock(input_block_data, function_params); + PackMacroBlockNeon(height_block_number, width_block_number, + input_block_data, scratch_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock { + static inline void 
KernelMacroBlockNeon( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int input_width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int depth = function_params->input_depth; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + const int width_micro_stride = 4 * 8; + const int depth_micro_stride = + width_micro_stride * input_width_overall_micro_repeats; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32878); + TFLITE_DCHECK_LT(output_offset, 32768); + + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x16_t output_activation_min_vec = + vdupq_n_u8(static_cast(output_activation_min)); + const uint8x16_t output_activation_max_vec = 
+ vdupq_n_u8(static_cast(output_activation_max)); + + const int8* input_data_depthwise = scratch_block_data; + uint8* output_data_depthwise = output_block_data; + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + int8x16_t filter_reg_0_a_shifted; + int8x16_t filter_reg_1_a_shifted; + int8x16_t filter_reg_2_a_shifted; + + filter_reg_0_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_0_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + + if (block_height == 4) { + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + const int8* input_data_base = input_data_depthwise + 2 * 8 * s; + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int8* next_input_data = input_data_base; + uint8* output_data = output_data_base; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + // Load first sub-micro block of data into operational banks. 
+ int8x16_t left_bank_0_reg = vld1q_s8(next_input_data); + int8x16_t left_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + int8x16_t left_bank_3_reg = + vld1q_s8(next_input_data + 3 * workspace_height_stride); + int8x16_t left_bank_4_reg = + vld1q_s8(next_input_data + 4 * workspace_height_stride); + int8x16_t left_bank_5_reg = + vld1q_s8(next_input_data + 5 * workspace_height_stride); + + int32x4_t acc0; + int32x4_t acc1; + int32x4_t acc2; + int32x4_t acc3; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + + for (int i_width = 0; i_width < output_width_micro_repeats; + ++i_width) { + next_input_data += width_micro_stride; + + // Iterate over input width shifts within 4x4 blocks. + { + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a, left_bank_5_reg); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += depth; + } + + // Load next sub-micro block of data. + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + int8x16_t right_bank_3_reg; + int8x16_t right_bank_4_reg; + int8x16_t right_bank_5_reg; + // Logic: (i_width == output_width_micro_repeats) && + // ((residual_width - 1) * stride_val < 2) + const bool no_right_block = + i_width == output_width_micro_repeats && residual_width < 3; + + if (no_right_block) { + // Only needed for santizer checks. 
+ right_bank_0_reg = vdupq_n_s8(0); + right_bank_1_reg = vdupq_n_s8(0); + right_bank_2_reg = vdupq_n_s8(0); + right_bank_3_reg = vdupq_n_s8(0); + right_bank_4_reg = vdupq_n_s8(0); + right_bank_5_reg = vdupq_n_s8(0); + } else { + right_bank_0_reg = vld1q_s8(next_input_data); + right_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + right_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + right_bank_3_reg = + vld1q_s8(next_input_data + 3 * workspace_height_stride); + right_bank_4_reg = + vld1q_s8(next_input_data + 4 * workspace_height_stride); + right_bank_5_reg = + vld1q_s8(next_input_data + 5 * workspace_height_stride); + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a_shifted, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a_shifted, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a_shifted, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a_shifted, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a_shifted, left_bank_3_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a_shifted, left_bank_5_reg); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + left_bank_0_reg = vrev32q_u16(left_bank_0_reg); + left_bank_1_reg = vrev32q_u16(left_bank_1_reg); + left_bank_2_reg = vrev32q_u16(left_bank_2_reg); + left_bank_3_reg = vrev32q_u16(left_bank_3_reg); + left_bank_4_reg = vrev32q_u16(left_bank_4_reg); + left_bank_5_reg = vrev32q_u16(left_bank_5_reg); + vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg); + vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg); + vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg); + vtrn1_s8x2_in_place(&left_bank_3_reg, &right_bank_3_reg); + vtrn1_s8x2_in_place(&left_bank_4_reg, &right_bank_4_reg); + vtrn1_s8x2_in_place(&left_bank_5_reg, &right_bank_5_reg); + + output_data += depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 
= adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a_shifted, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a_shifted, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a_shifted, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a_shifted, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a_shifted, left_bank_3_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a_shifted, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. 
+ int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + left_bank_0_reg = right_bank_0_reg; + left_bank_1_reg = right_bank_1_reg; + left_bank_2_reg = right_bank_2_reg; + left_bank_3_reg = right_bank_3_reg; + left_bank_4_reg = right_bank_4_reg; + left_bank_5_reg = right_bank_5_reg; + + output_data += depth; + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + } + } + + if (residual_width > 0) { + next_input_data += width_micro_stride; + const int output_width = residual_width; + + // Load next sub-micro block of data. + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + int8x16_t right_bank_3_reg; + int8x16_t right_bank_4_reg; + int8x16_t right_bank_5_reg; + // Logic: (output_width - 1) * stride_val < 2. + const bool no_right_block = output_width < 3; + + if (no_right_block) { + // Only needed for santizer checks. 
+ right_bank_0_reg = vdupq_n_s8(0); + right_bank_1_reg = vdupq_n_s8(0); + right_bank_2_reg = vdupq_n_s8(0); + right_bank_3_reg = vdupq_n_s8(0); + right_bank_4_reg = vdupq_n_s8(0); + right_bank_5_reg = vdupq_n_s8(0); + } else { + right_bank_0_reg = vld1q_s8(next_input_data); + right_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + right_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + right_bank_3_reg = + vld1q_s8(next_input_data + 3 * workspace_height_stride); + right_bank_4_reg = + vld1q_s8(next_input_data + 4 * workspace_height_stride); + right_bank_5_reg = + vld1q_s8(next_input_data + 5 * workspace_height_stride); + } + + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. 
+ int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + biregister_rotate_8(&left_bank_0_reg, &right_bank_0_reg); + biregister_rotate_8(&left_bank_1_reg, &right_bank_1_reg); + biregister_rotate_8(&left_bank_2_reg, &right_bank_2_reg); + biregister_rotate_8(&left_bank_3_reg, &right_bank_3_reg); + biregister_rotate_8(&left_bank_4_reg, &right_bank_4_reg); + biregister_rotate_8(&left_bank_5_reg, &right_bank_5_reg); + + output_data += depth; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + } + } + input_data_base += 4 * workspace_height_stride; + output_data_base += 4 * output_height_stride; + + // Move to next sub-block: advance to second set of filters, to new + // bias. 
+ filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + } + } else { + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + const int8* input_data_base = input_data_depthwise + 2 * 8 * s; + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* next_input_data = input_data_base; + uint8* output_data = output_data_base; + + // Load first sub-micro block of data into operational banks. + int8x16_t left_bank_0_reg = vld1q_s8(next_input_data); + int8x16_t left_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + next_input_data += width_micro_stride; + const int output_width = + i_width == output_width_micro_repeats ? residual_width : 4; + + // Load next sub-micro block of data. + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + // Logic: (output_width - 1) * stride_val < 2. + const bool no_right_block = output_width < 3; + + if (no_right_block) { + // Only needed for santizer checks. + right_bank_0_reg = vdupq_n_s8(0); + right_bank_1_reg = vdupq_n_s8(0); + right_bank_2_reg = vdupq_n_s8(0); + } else { + right_bank_0_reg = vld1q_s8(next_input_data); + right_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + right_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + } + // Load next sub-micro block of data. 
+ + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + int32x4_t acc = adjusted_bias_data; + acc = vdotq_s32(acc, filter_reg_0_a, left_bank_0_reg); + acc = vdotq_s32(acc, filter_reg_1_a, left_bank_1_reg); + acc = vdotq_s32(acc, filter_reg_2_a, left_bank_2_reg); + + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + acc = DivideByPOT::Run( + acc, -output_shift); + // Add the output offset. + // Note that we need to fill the top half with vcombine, but can + // drop the instruction in ASM code. + int16x8_t acc_s16_0_0 = + vcombine_s16(vqmovn_s32(acc), vqmovn_s32(acc)); + acc_s16_0_0 = vqaddq_s16(acc_s16_0_0, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8_0_0 = vqmovun_s16(acc_s16_0_0); + acc_u8_0_0 = + vmax_u8(acc_u8_0_0, vget_low_u8(output_activation_min_vec)); + acc_u8_0_0 = + vmin_u8(acc_u8_0_0, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_0, 0); + + biregister_rotate_8(&left_bank_0_reg, &right_bank_0_reg); + biregister_rotate_8(&left_bank_1_reg, &right_bank_1_reg); + biregister_rotate_8(&left_bank_2_reg, &right_bank_2_reg); + + output_data += depth; + } + } + input_data_base += workspace_height_stride; + output_data_base += output_height_stride; + } + + // Move to next sub-block: advance to second set of filters. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + } + } + input_data_depthwise += depth_micro_stride; + output_data_depthwise += 8; + } + } // NOLINT(readability/fn_size) Manually unrolled. 
+ + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockNeon(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +// KernelMacroBlock specialization whose kernel requires stride == 2 and an +// outbound block height of at most 2 (both are DCHECKed below). For depth +// micro block j_depth the output starts at output_block_data + 8 * j_depth, +// and each of the two passes (s = 0, 1) writes at a further offset of 4 * s +// and consumes one group of 4 bias values. +template <> +struct KernelMacroBlock { + // Reads input from scratch_block_data, filter values from filter_workspace + // and bias values from bias_data, and writes uint8 results through + // output_block_data. All sizes, strides and quantization parameters are + // taken from function_params. Each accumulator is scaled by + // output_multiplier / output_shift, offset by output_offset and clamped to + // [quantized_activation_min, quantized_activation_max] before being stored. + static inline void KernelMacroBlockNeon( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int input_width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int depth = function_params->input_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int workspace_width_micro_repeats = + function_params->workspace_width_micro_repeats; + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + const int width_micro_stride = 4 * 8; + const int depth_micro_stride = + width_micro_stride * input_width_overall_micro_repeats; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift =
function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + // output_offset is broadcast into int16 lanes below (vdupq_n_s16), so it + // must fit the int16 range [-32768, 32767]. + TFLITE_DCHECK_GE(output_offset, -32768); + TFLITE_DCHECK_LT(output_offset, 32768); + + // This version only does min/max on 64 bits. + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x8_t output_activation_min_vec = + vdup_n_u8(static_cast(output_activation_min)); + const uint8x8_t output_activation_max_vec = + vdup_n_u8(static_cast(output_activation_max)); + + constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; + + TFLITE_DCHECK_EQ(stride_val, 2); + TFLITE_DCHECK_LE(block_height, 2); + + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + const int8* filter_block = + filter_workspace + shuffled_filter_increment * j_depth; + + if (block_height == 2) { + for (int s = 0; s < 2; ++s) { + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_2_a; + + filter_reg_0_a = vld1q_s8(filter_block + s * 16); + filter_reg_1_a = vld1q_s8(filter_block + s * 16 + 32); + filter_reg_2_a = vld1q_s8(filter_block + s * 16 + 64); + + const int8* scratch_data = + scratch_block_data + depth_micro_stride * j_depth; + uint8* output_data = output_block_data + 8 * j_depth; + const int8* input_data_0 = scratch_data + s * 2 * 8; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + + // Load first sub-micro block of data into operational banks.
+ int8x16_t left_bank_0_reg = vld1q_s8(input_data_0); + int8x16_t left_bank_1_reg = + vld1q_s8(input_data_0 + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(input_data_0 + 2 * workspace_height_stride); + int8x16_t left_bank_3_reg = + vld1q_s8(input_data_0 + 3 * workspace_height_stride); + int8x16_t left_bank_4_reg = + vld1q_s8(input_data_0 + 4 * workspace_height_stride); + + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + int8x16_t right_bank_3_reg; + int8x16_t right_bank_4_reg; + + int32x4_t acc0; + int32x4_t acc1; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = + input_data_0 + width_micro_stride * i_width; + const bool no_right_block = i_width == output_width_micro_repeats && + output_width_overall_micro_repeats == + workspace_width_micro_repeats; + + // NOTE(review): when no_right_block is true the right_bank registers + // are not loaded here (on the first iteration nothing has been + // assigned to them), yet they are still passed to + // vtrn1_s8x2_in_place below; confirm that is intended. + if (!no_right_block) { + // Load next sub-micro block of data. + right_bank_0_reg = vld1q_s8(input_data + width_micro_stride); + right_bank_1_reg = vld1q_s8(input_data + width_micro_stride + + workspace_height_stride); + right_bank_2_reg = vld1q_s8(input_data + width_micro_stride + + 2 * workspace_height_stride); + right_bank_3_reg = vld1q_s8(input_data + width_micro_stride + + 3 * workspace_height_stride); + right_bank_4_reg = vld1q_s8(input_data + width_micro_stride + + 4 * workspace_height_stride); + } + + uint8* output_data_base = output_data + depth * 2 * i_width + 4 * s; + + // Iterate over input width shifts within 4x4 blocks.
+ { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_3_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_4_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base, acc_u8, 0); + vst1_lane_8x4(output_data_base + output_height_stride, acc_u8, 1); + + left_bank_0_reg = vrev32q_u16(left_bank_0_reg); + left_bank_1_reg = vrev32q_u16(left_bank_1_reg); + left_bank_2_reg = vrev32q_u16(left_bank_2_reg); + left_bank_3_reg = vrev32q_u16(left_bank_3_reg); + left_bank_4_reg = vrev32q_u16(left_bank_4_reg); + vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg); + vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg); + vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg); + vtrn1_s8x2_in_place(&left_bank_3_reg, &right_bank_3_reg); + vtrn1_s8x2_in_place(&left_bank_4_reg, &right_bank_4_reg); + } + + if (output_width > 1) { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a,
left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_3_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_4_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base + depth, acc_u8, 0); + vst1_lane_8x4(output_data_base + depth + output_height_stride, + acc_u8, 1); + + left_bank_0_reg = right_bank_0_reg; + left_bank_1_reg = right_bank_1_reg; + left_bank_2_reg = right_bank_2_reg; + left_bank_3_reg = right_bank_3_reg; + left_bank_4_reg = right_bank_4_reg; + } + } + bias_data += bias_increment; + } + } else { + for (int s = 0; s < 2; ++s) { + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_2_a; + + filter_reg_0_a = vld1q_s8(filter_block + s * 16); + filter_reg_1_a = vld1q_s8(filter_block + s * 16 + 32); + filter_reg_2_a = vld1q_s8(filter_block + s * 16 + 64); + + const int8* scratch_data = + scratch_block_data + depth_micro_stride * j_depth; + uint8* output_data = output_block_data + 8 * j_depth; + const int8* input_data_0 = scratch_data + s * 2 * 8; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + + // Load first sub-micro block of data into operational banks.
+ int8x16_t left_bank_0_reg = vld1q_s8(input_data_0); + int8x16_t left_bank_1_reg = + vld1q_s8(input_data_0 + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(input_data_0 + 2 * workspace_height_stride); + + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + + int32x4_t acc0; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = + input_data_0 + width_micro_stride * i_width; + const bool no_right_block = i_width == output_width_micro_repeats && + output_width_overall_micro_repeats == + workspace_width_micro_repeats; + + // NOTE(review): when no_right_block is true the right_bank registers + // are not loaded here (on the first iteration nothing has been + // assigned to them), yet they are still passed to + // vtrn1_s8x2_in_place below; confirm that is intended. + if (!no_right_block) { + // Load next sub-micro block of data. + right_bank_0_reg = vld1q_s8(input_data + width_micro_stride); + right_bank_1_reg = vld1q_s8(input_data + width_micro_stride + + workspace_height_stride); + right_bank_2_reg = vld1q_s8(input_data + width_micro_stride + + 2 * workspace_height_stride); + } + + uint8* output_data_base = output_data + depth * 2 * i_width + 4 * s; + + // Iterate over input width shifts within 4x4 blocks. + { + acc0 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc0)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function.
+ uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base, acc_u8, 0); + + left_bank_0_reg = vrev32q_u16(left_bank_0_reg); + left_bank_1_reg = vrev32q_u16(left_bank_1_reg); + left_bank_2_reg = vrev32q_u16(left_bank_2_reg); + vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg); + vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg); + vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg); + } + + if (output_width > 1) { + acc0 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc0)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base + depth, acc_u8, 0); + + left_bank_0_reg = right_bank_0_reg; + left_bank_1_reg = right_bank_1_reg; + left_bank_2_reg = right_bank_2_reg; + } + } + bias_data += bias_increment; + } + } + } + } // NOLINT(readability/fn_size) Manually unrolled.
+ + // Forwards all arguments unchanged to KernelMacroBlockNeon. + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockNeon(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock { + static inline void KernelMacroBlockNeon( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + TFLITE_DCHECK_EQ(function_params->stride, 1); + const int workspace_height_stride = + function_params->workspace_height_stride; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int output_depth = function_params->output_depth; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + + TFLITE_DCHECK_EQ(bias_increment, 4); + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32768); +
TFLITE_DCHECK_LT(output_offset, 32768); + + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x16_t output_activation_min_vec = + vdupq_n_u8(static_cast(output_activation_min)); + const uint8x16_t output_activation_max_vec = + vdupq_n_u8(static_cast(output_activation_max)); + + uint8* output_data_depthwise = output_block_data; + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + int8x16_t filter_reg_0_a_shifted; + int8x16_t filter_reg_1_a_shifted; + int8x16_t filter_reg_2_a_shifted; + + filter_reg_0_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_0_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + + if (block_height == 4) { + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int8* next_input_data = scratch_block_data; + uint8* output_data = output_data_base; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, left 3, right 3. + int8x16_t input_bank_c_reg; // left 4, right 4, left 5, right 5. 
+ + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(next_input_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + next_input_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 3 * workspace_height_stride, + input_bank_b_reg, 2); + input_bank_c_reg = vld1q_dup_s8x4( + next_input_data + + 4 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 5 * workspace_height_stride, + input_bank_c_reg, 2); + + int32x4_t acc0; + int32x4_t acc1; + int32x4_t acc2; + int32x4_t acc3; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, 2); + + for (int i_width = 0; i_width < output_width_micro_repeats; + ++i_width) { + next_input_data += 4; + + // Iterate over input width shifts within 4x4 blocks. 
+ { + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, + 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a, input_bank_c_reg, + 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += output_depth; + } + // Load next sub-micro block of data. 
+ input_bank_a_reg = + vld1q_lane_8x4(next_input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 2 * workspace_height_stride, + input_bank_b_reg, 1); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 3 * workspace_height_stride, + input_bank_b_reg, 3); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 4 * workspace_height_stride, + input_bank_c_reg, 1); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 5 * workspace_height_stride, + input_bank_c_reg, 3); + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a_shifted, + input_bank_a_reg, 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a_shifted, + input_bank_a_reg, 2); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a_shifted, + input_bank_a_reg, 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a_shifted, + input_bank_b_reg, 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a_shifted, + input_bank_b_reg, 2); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a_shifted, + input_bank_c_reg, 2); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, + 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, + 2); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a, 
input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, + 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, + 2); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a, input_bank_c_reg, + 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += output_depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a_shifted, + input_bank_a_reg, 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a_shifted, + input_bank_a_reg, 2); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a_shifted, + input_bank_a_reg, 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a_shifted, + input_bank_b_reg, 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a_shifted, + input_bank_b_reg, 2); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a_shifted, + input_bank_c_reg, 2); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, + 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, + 2); + } + } + + if (residual_width > 0) { + next_input_data += 4; + const int output_width = residual_width; + + // Load next sub-micro block of 
data. + input_bank_a_reg = + vld1q_lane_8x4(next_input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 2 * workspace_height_stride, + input_bank_b_reg, 1); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 3 * workspace_height_stride, + input_bank_b_reg, 3); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 4 * workspace_height_stride, + input_bank_c_reg, 1); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 5 * workspace_height_stride, + input_bank_c_reg, 3); + + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, + 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a, input_bank_c_reg, + 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. 
+ int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 8); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 8); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 8); + + output_data += output_depth; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, + 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, + 2); + } + } + // scratch_block_data += 4 * workspace_height_stride; + output_data_base += 4 * output_height_stride; + + // Move to next sub-block: advance to second set of filters, to new + // bias. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + } + } else { + // Block height < 4. 
+ for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* next_input_data = + scratch_block_data + k_height * workspace_height_stride; + uint8* output_data = output_data_base; + + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, left 3, right 3. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(next_input_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + next_input_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + next_input_data += 4; + const int output_width = + i_width == output_width_micro_repeats ? residual_width : 4; + + // Load next sub-micro block of data. + input_bank_a_reg = + vld1q_lane_8x4(next_input_data, input_bank_a_reg, 1); + input_bank_a_reg = + vld1q_lane_8x4(next_input_data + workspace_height_stride, + input_bank_a_reg, 3); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 2 * workspace_height_stride, + input_bank_b_reg, 1); + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + int32x4_t acc = adjusted_bias_data; + acc = vdotq_four_lane_s32(acc, filter_reg_0_a, input_bank_a_reg, + 0); + acc = vdotq_four_lane_s32(acc, filter_reg_1_a, input_bank_a_reg, + 2); + acc = vdotq_four_lane_s32(acc, filter_reg_2_a, input_bank_b_reg, + 0); + + // Fixed-point multiplication. 
+ acc = vqrdmulhq_n_s32(acc, output_multiplier); + acc = DivideByPOT::Run( + acc, -output_shift); + // Add the output offset. + // Note that we need to fill the top half with vcombine, but can + // drop the instruction in ASM code. + int16x8_t acc_s16_0_0 = + vcombine_s16(vqmovn_s32(acc), vqmovn_s32(acc)); + acc_s16_0_0 = vqaddq_s16(acc_s16_0_0, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8_0_0 = vqmovun_s16(acc_s16_0_0); + acc_u8_0_0 = + vmax_u8(acc_u8_0_0, vget_low_u8(output_activation_min_vec)); + acc_u8_0_0 = + vmin_u8(acc_u8_0_0, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_0, 0); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 8); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 8); + + output_data += output_depth; + } + } + output_data_base += output_height_stride; + } + + // Move to next sub-block: advance to second set of filters. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + } + } + output_data_depthwise += 8; + } + } // NOLINT(readability/fn_size) Manually unrolled. 
+ + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockNeon(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock { + static inline void KernelMacroBlockNeon( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int output_depth = function_params->output_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32878); + TFLITE_DCHECK_LT(output_offset, 32768); 
+ + TFLITE_DCHECK_GE(depth_micro_repeats, 1); + TFLITE_DCHECK_EQ(bias_increment, 4); + + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x16_t output_activation_min_vec = + vdupq_n_u8(static_cast(output_activation_min)); + const uint8x16_t output_activation_max_vec = + vdupq_n_u8(static_cast(output_activation_max)); + + for (int j_depth = 0; j_depth < (depth_micro_repeats * 1 + 0); ++j_depth) { + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + + filter_reg_0_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_0_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + + TFLITE_DCHECK_EQ(bias_increment, 4); + const int32x4_t adjusted_bias_data_s_0 = vld1q_s32(bias_data); + bias_data += bias_increment; + const int32x4_t adjusted_bias_data_s_1 = vld1q_s32(bias_data); + bias_data += bias_increment; + + if (block_height == 2) { + const int8* scratch_data = scratch_block_data; + uint8* output_data = output_block_data + 8 * j_depth; + + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, left 3, right 3. + int8x16_t input_bank_c_reg; // left 4, right 4, xxx, xxx. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(scratch_data); // Load lane 0, avoiding + // uninitialized variable. 
+ input_bank_a_reg = vld1q_lane_8x4( + scratch_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + scratch_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + input_bank_b_reg = vld1q_lane_8x4( + scratch_data + 3 * workspace_height_stride, input_bank_b_reg, 2); + input_bank_c_reg = vld1q_dup_s8x4( + scratch_data + + 4 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + + int32x4_t acc0; + int32x4_t acc1; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + + TFLITE_DCHECK_LE(output_width, 2); + TFLITE_DCHECK_GE(output_width, 1); + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = scratch_data + 4 + 4 * i_width; + + // Load next sub-micro block of data. + input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = vld1q_lane_8x4( + input_data + 2 * workspace_height_stride, input_bank_b_reg, 1); + input_bank_b_reg = vld1q_lane_8x4( + input_data + 3 * workspace_height_stride, input_bank_b_reg, 3); + input_bank_c_reg = vld1q_lane_8x4( + input_data + 4 * workspace_height_stride, input_bank_c_reg, 1); + + int16x8_t acc_s16_0_1; + uint8x8_t acc_u8_0_1; + // Iterate over input width shifts within 4x4 blocks. 
+ { + acc0 = adjusted_bias_data_s_0; + acc1 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + output_height_stride, acc_u8_0_1, 1); + + acc0 = adjusted_bias_data_s_1; + acc1 = adjusted_bias_data_s_1; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_b, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_b, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_c_reg, 0); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data + 4, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + 4 + output_height_stride, acc_u8_0_1, + 1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + } + if (output_width == 2) { + acc0 = adjusted_bias_data_s_0; + acc1 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. 
+ acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + output_height_stride, acc_u8_0_1, 1); + + acc0 = adjusted_bias_data_s_1; + acc1 = adjusted_bias_data_s_1; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_b, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_b, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data + 4, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + 4 + output_height_stride, acc_u8_0_1, + 1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + } + } + } else { + TFLITE_DCHECK_EQ(block_height, 1); + // Work through one slice, by row, at a time. 
+ const int8* scratch_data = scratch_block_data; + uint8* output_data = output_block_data + 8 * j_depth; + + // + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, xxx, xxx. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(scratch_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + scratch_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + scratch_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + + int32x4_t acc0; + int32x4_t acc1; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = + i_width == output_width_micro_repeats ? residual_width : 2; + + TFLITE_DCHECK_LE(output_width, 2); + TFLITE_DCHECK_GE(output_width, 1); + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = scratch_data + 4 + 4 * i_width; + + // Load next sub-micro block of data. + input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = vld1q_lane_8x4( + input_data + 2 * workspace_height_stride, input_bank_b_reg, 1); + + int16x8_t acc_s16_0_1; + uint8x8_t acc_u8_0_1; + + // Iterate over input width shifts within 4x4 blocks. + { + acc0 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + + // Second sub-block accumulation. 
+ acc1 = adjusted_bias_data_s_1; + + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_a_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_a_reg, 2); + + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + // This stores the results for both sub-blocks together. + vst1_u8(output_data, acc_u8_0_1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + + output_data += output_depth; + } + if (output_width == 2) { + acc0 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + + // Second sub-block accumulation. + acc1 = adjusted_bias_data_s_1; + + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_a_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_a_reg, 2); + + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. 
+ acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + // This stores the results for both sub-blocks together. + vst1_u8(output_data, acc_u8_0_1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + + output_data += output_depth; + } + } + } + } + } + + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockNeon(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +#endif // USE_NEON && __aarch64__ + +// Top-level implementation function for 3x3 depthwise convolution using NEON +// dot-product instructions. +// +// MACRO & MICRO BLOCKS +// +// The task is divided into macro blocks. Data is copied first into a macro +// block in a workspace. This has two purposes: (a) bringing data into +// cache, and (b) permuting data so that it can be used much more easily in +// a dot-product filter. +// +// When there is no depth multiplication: +// +// The permutations required for dot-products are local, within 4 data points +// down the depth and 4 across the width. We want to pull in input data at least +// 8-bytes at a time, down the depth, and so we divide the macro blocks into +// 1x4x8 (height, width, depth) and further divide the micro blocks into +// sub-blocks with shape (1x4x4). +// +// Each macro-block is constructed from micro-blocks that are internally +// rearranged during loading into the macro-block workspace. 
+// +// In other words, the micro-block shape is +// {1, 1, 4, 8} +// Each macro block is typically shape +// {1, height_block_size, 4 * workspace_width_micro_repeats, 64} +// and workspace_width_micro_repeats is chosen so it fits into the workspace. +// +// However, if depth < 64, we decrease the macro block depth, enabling us to +// increase the macro-block width. +// +// When there is depth multiplication: +// +// We require input-depth = 1 and exploit that instead. Note that output data +// is still full-depth, *as is the filter and bias data after certain +// adjustments*, and so the filter stage in this case still proceeds in terms of +// sub-blocks. +// +// The Magic of these numbers: +// 4 is the number of input elements used in each dot-product. +// 8 is the number of inputs we load at a time into a register. +// 64 is min amount of data to be loaded in a stretch (when possible). +// +// FILTER DATA PREPARATION +// +// Filter data needs to be permuted in a fashion like that of input data, and +// this is done in a preprocessing stage. In addition, this stage extends the +// filter in the direction of width from 3 to 4. The extra filter taps are set +// to zero so that input data does not have to be zeroed before applying +// dot-products. +// +// OVERALL COUNTS: HANDLING TRAILING ITERATION +// +// Often it is necessary to handle the last iteration in a loop differently, +// generally because the final item is shorter. The logic to detect the +// special case can be a bit expensive. We use a scheme in which there are +// two counts, in a pattern like xxx_yyy_repeats and +// xxx_overall_yyy_repeats. The first gives the count of "normal" +// iterations. The loop iterates over the second count, and the induction +// variable is checked to see if it reaches xxx_yyy_repeats. If there is no +// special trailing iteration, xxx_yyy_repeats = xxx_overall_yyy_repeats, +// and the special code is not executed. 
+// +// Example: +// Suppose that we characterize a size s as +// f(s) -> (block-4-repetitions, remainder, overall_repetitions): +// f(11) -> (2, 3, 3) +// f(12) -> (3, 0, 3) +// f(13) -> (3, 1, 4) +// +// POINTING OUTSIDE OF INPUT ARRAY. +// +// When there is padding, the input data pointer passed to the fill routines +// points outside of the input array and into a kind-of virtual padded +// margin. It turns out that this simplifies the code and removes +// conditional statements. It is hard to explain why without comparing two +// versions of the code. In summary, this way the adjustment into the margin +// can be made unconditionally, and the correction back into the input array +// is done where there is a conditional already. +// +// OVERLAP +// +// Since this is *depthwise* conv, neither the batch nor the depth have overlap. +// The height and width overlap by (filter_size - 1). Thus some data is used +// twice on the borders of macro blocks. +// +template +inline void DepthwiseConvDotProduct3x3( + const DepthwiseParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data) { + // Check kernel restrictions.
+ constexpr int filter_size = 3; + constexpr int kMaxStride = 2; + constexpr int kMaxPadding = 1; + constexpr int kSymmetricZeroPoint = 128; + TFLITE_DCHECK_EQ(params.weights_offset, -kSymmetricZeroPoint); + TFLITE_DCHECK_LE(params.stride_width, kMaxStride); + TFLITE_DCHECK_EQ(params.stride_height, params.stride_width); + TFLITE_DCHECK_EQ(params.dilation_width_factor, 1); + TFLITE_DCHECK_EQ(params.dilation_height_factor, 1); + TFLITE_DCHECK_LE(params.padding_values.width, kMaxPadding); + TFLITE_DCHECK_LE(params.padding_values.height, kMaxPadding); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + + // Key kernel parameters (along with padding handled later). + const int stride = params.stride_width; + const int depth_multiplier = params.depth_multiplier; + const bool has_depth_multiplication = depth_multiplier > 1; + + // Extract task dimensions. + const int input_depth = input_shape.Dims(3); + const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + TFLITE_DCHECK(!has_depth_multiplication || input_depth == 1); + TFLITE_DCHECK(has_depth_multiplication || input_depth == output_depth); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + TFLITE_DCHECK_EQ(input_depth * depth_multiplier, output_depth); + TFLITE_DCHECK_EQ(MatchingDim(filter_shape, 1, filter_shape, 2), filter_size); + + // Return now if nothing to do. + if (output_width == 0 || output_height == 0) { + return; + } + + // Kernel parameter structure: set basic fields. 
+ // + // In asm it is easier to pass a structure than more than, say, 8 parameters. + DepthwiseConvDotProdParams function_params; + function_params.input_depth = input_depth; + function_params.output_depth = output_depth; + function_params.input_offset = params.input_offset; + function_params.output_offset = params.output_offset; + function_params.output_multiplier = params.output_multiplier; + function_params.output_shift = params.output_shift; + function_params.quantized_activation_min = params.quantized_activation_min; + function_params.quantized_activation_max = params.quantized_activation_max; + function_params.stride = stride; + + // Handle inbound bias data. + // + // Note that this data is adjusted in a per-depth process before the main + // filters. The adjustment accounts for a non-symmetric input offset. + // + // Kernel subroutines need to be able to operate consistently on a bias + // array. Where there is no bias, we provide one filled with zeros. + constexpr int kMinBiasLoad = 8; + int32 zero_bias_data[kMinBiasLoad]; + int32 bias_increment; + if (bias_data) { + bias_increment = 4; + } else { + memset(zero_bias_data, 0, sizeof(zero_bias_data)); + bias_data = &zero_bias_data[0]; + bias_increment = 0; + } + function_params.bias_increment = bias_increment; + TFLITE_DCHECK_LE(2 * function_params.bias_increment, kMinBiasLoad); + + // Process padding. + // + // Whether "correct" or not, this matches ComputeConvSizes. When there is + // stride > 1 there can be padding on the bottom or top, and therefore + // we need to consider padding. This is true even if one or other of the + // padding_values is 0. + const int padded_width = (output_width - 1) * stride + filter_size; + { + const int padding_left = params.padding_values.width; + // Right padding would be -1 if discarding input because of stride.
+ const int padding_right = + std::max(padded_width - input_width - padding_left, 0); + const int padding_top = params.padding_values.height; + const int padded_height = (output_height - 1) * stride + filter_size; + const int padding_bottom = + std::max(padded_height - input_height - padding_top, 0); + + function_params.padding_left = padding_left; + function_params.padding_right = padding_right; + function_params.padding_top = padding_top; + function_params.padding_bottom = padding_bottom; + + TFLITE_DCHECK_LE(padding_left, padding_right); + TFLITE_DCHECK_LE(padding_top, padding_bottom); + } + // When stride == 1 left or top padding may only be non-zero. + // This is when padding is specified but not needed on a trailing dimension. + // When stride == 2 right or bottom padding may only be non-zero. + // This is a result of the details of the padding calculations. + const bool padding_required = + function_params.padding_left > 0 || function_params.padding_top > 0 || + function_params.padding_right > 0 || function_params.padding_bottom > 0; + + // Choose parameter-specific kernel subroutines. + // + // The main part of the kernel has two stages. First, a temporary workspace is + // filled with padded and permuted data. Second, the filter is applied to the + // workspace data to generate output. + // + // The workspace fill stage handles padding so that the filter stage does not + // need to account for it. The workspace fill stage does not need to + // understand striding, and implicitly handles striding through the parameters + // that it is given. 
+ using pack_macro_block_func_t = decltype( + &PackMacroBlock::Run); + using kernel_macro_block_func_t = decltype( + &KernelMacroBlock::Run); + pack_macro_block_func_t pack_macro_block_func; + kernel_macro_block_func_t kernel_macro_block_func; + { + if (has_depth_multiplication) { + if (padding_required) { + pack_macro_block_func = + PackMacroBlock::Run; + } else { + pack_macro_block_func = + PackMacroBlock::Run; + } + if (stride == 1) { + kernel_macro_block_func = + KernelMacroBlock::Run; + } else { + kernel_macro_block_func = + KernelMacroBlock::Run; + } + } else { + if (padding_required) { + pack_macro_block_func = + PackMacroBlock::Run; + } else { + pack_macro_block_func = + PackMacroBlock::Run; + } + if (stride == 1) { + kernel_macro_block_func = KernelMacroBlock< + implementation, DepthwiseConvDepthMultiplication::kNoMultiplication, + /*stride=*/1>::Run; + } else { + kernel_macro_block_func = KernelMacroBlock< + implementation, DepthwiseConvDepthMultiplication::kNoMultiplication, + /*stride=*/2>::Run; + } + } + } + + // Stride-only variables. + // + // stride == 1 ? 4 : 2: + const int output_height_per_macro = 6 - 2 * stride; + // output_height_per_macro * stride: + constexpr int input_height_per_macro = 4; + // Number of rows per micro block (= rows per macro block) is + // (output_height_per_macro - 1) * stride + 1 + (filter_size - 1) + // = stride == 1 ? 3 + filter_size : 2 + filter_size: + const int height_block_size = 4 + filter_size - stride; + const int input_height_overlap = filter_size - stride; + // stride == 1 ? 4 : 2: + function_params.four_over_stride = output_height_per_macro; + + TFLITE_DCHECK_EQ(stride * function_params.four_over_stride, 4); + TFLITE_DCHECK_EQ(height_block_size, + input_height_per_macro + input_height_overlap); + + // Create workspaces. + // + // Filter workspace is for shuffle: only first depth/8 is used. + // indexed as [depth/8][sub-block][height][depth][width]. 
+ TFLITE_DCHECK_EQ(kDepthwiseConvAdjustedBiasLimit % 8, 0); + int8 macroblock_workspace[kDepthwiseConvScratchWorkspaceSize]; + int32 adjusted_bias_data[kDepthwiseConvAdjustedBiasLimit]; + int8 filter_workspace[kDepthwiseConvAdjustedBiasLimit >> 3][3][2][4][4]; + + // Output depth characterization. + // + const int depth_macro_count = output_depth / 64; + const int depth_overall_macro_count = (output_depth + 63) / 64; + // Number of micro blocks down the depth in a final incomplete macro block. + const int depth_trailing_micro_repeats = output_depth / 8 % 8; + // The output_depth may not have a remainder: it must be a multiple of 8. + TFLITE_DCHECK_EQ(output_depth, + 64 * depth_macro_count + 8 * depth_trailing_micro_repeats); + + // Characterize the first macro block depth, the largest. + // + // We base treatment of the width on the trailing macro block if there are + // no full blocks, in order to do more work together (that is, increase + // workspace_width_micro_repeats when largest_macro_depth < 64). + const int largest_macro_depth = + has_depth_multiplication + ? 1 + : (depth_macro_count > 0 ? 64 : 8 * depth_trailing_micro_repeats); + + // Characterize width, consumption of input and generation of output. + // + // In the case of depth multiplication, we ensure that some of the workspace + // at the end remains unused. This enables the filter routines to load the + // "next" data, of at least 16 bytes, even when at the end of the workspace. + // It is relatively expensive to detect the end micro block. It is also very + // difficult to test for (to trigger) erroneous reads (past end of array) in + // the depth multplication case. + int workspace_width_micro_repeats = + (has_depth_multiplication + ? 
kDepthwiseConvScratchWorkspaceSize - kWorkspaceExtension + : kDepthwiseConvScratchWorkspaceSize) / + (4 * largest_macro_depth * height_block_size); + // When there is no depth multiplication, the workspace depth is a multiple of + // 8, which ensures that workspace rows are 16-byte aligned. (Actually 32, + // because of the micro width of 4.) This is not necessarily the case under + // depth multiplication, so we adjust now to impose this restriction. + if (has_depth_multiplication) { + workspace_width_micro_repeats = (workspace_width_micro_repeats / 4) * 4; + } + TFLITE_DCHECK_EQ((workspace_width_micro_repeats * largest_macro_depth) % 4, + 0); + // Discount 1 of the micro-block repeats in each macro block to account for + // overlap. + const int consumed_width_per_macro_block = + 4 * (workspace_width_micro_repeats - 1); + const int output_width_per_macro_block = + function_params.four_over_stride * (workspace_width_micro_repeats - 1); + TFLITE_DCHECK_GT(workspace_width_micro_repeats, 1); + TFLITE_DCHECK_EQ(output_width_per_macro_block * stride, + consumed_width_per_macro_block); + + // Width repetitions and residuals. + // + // Use of the workspace is characterized primarily in terms of *padded input*. + // Striding only matters in a few places. + // + // Simplifications: We require that there always be at least one full + // micro-block across the width. Since the maximum padding is 1, the trailing + // padding cannot span two micro blocks. + const int residual_micro_width = padded_width % 4; + // We base the count of macro blocks on the amount of padded input data each + // one consumes. + int width_overall_macro_count = (padded_width - residual_micro_width + + consumed_width_per_macro_block - 1) / + consumed_width_per_macro_block; + // Recall that we left a micro block at the end of each macro block for use as + // overlap. There is a special case in which we can use one fewer macro + // blocks, with the last one consuming extra input. 
(But not if the + // calculation thinks that we can use zero blocks.) + if (padded_width <= + ((width_overall_macro_count - 1) * consumed_width_per_macro_block + 4)) { + width_overall_macro_count -= 1; + } + width_overall_macro_count = std::max(width_overall_macro_count, 1); + // We always have to treat the final macro block along width as trailing, + // because even if it is full in terms of padded input, it will be incomplete + // in terms of output. + const int width_macro_count = width_overall_macro_count - 1; + // Micro blocks are traversed in terms of input in fill routines. + const int width_trailing_micro_repeats = + (padded_width - consumed_width_per_macro_block * width_macro_count) / 4; + const int width_overall_trailing_micro_repeats = + (padded_width - consumed_width_per_macro_block * width_macro_count + 3) / + 4; + // Micro blocks are traversed in terms of output in filtering routines. + const int residual_output_micro_width = + (output_width - 1) % function_params.four_over_stride + 1; + const int output_width_trailing_micro_repeats = + residual_micro_width > (filter_size - 1) + ? width_trailing_micro_repeats + : width_trailing_micro_repeats - 1; + // Check results. + TFLITE_DCHECK_GT(width_overall_trailing_micro_repeats, 0); + TFLITE_DCHECK_EQ(padded_width, + residual_micro_width + + consumed_width_per_macro_block * width_macro_count + + 4 * width_trailing_micro_repeats); + TFLITE_DCHECK_LE(width_overall_macro_count, width_macro_count + 1); + TFLITE_DCHECK_GE(width_overall_macro_count, width_macro_count); + + // Height repetitions and residuals. 
+ // + const int height_macro_count = output_height / output_height_per_macro; + const int residual_output_height = output_height % output_height_per_macro; + const int height_overall_macro_count = + (output_height + output_height_per_macro - 1) / output_height_per_macro; + TFLITE_DCHECK_EQ( + output_height, + residual_output_height + output_height_per_macro * height_macro_count); + TFLITE_DCHECK_LE(height_overall_macro_count, height_macro_count + 1); + TFLITE_DCHECK_GE(height_overall_macro_count, height_macro_count); + + // Data strides. + // + const int input_height_stride = input_width * input_depth; + const int output_height_stride = output_width * output_depth; + const int input_batch_stride = input_height_stride * input_height; + const int output_batch_stride = output_height_stride * output_height; + const int input_depth_macro_stride = has_depth_multiplication ? 0 : 64; + const int input_width_macro_stride = + input_depth * consumed_width_per_macro_block; + const int output_width_macro_stride = + output_depth * output_width_per_macro_block; + + // Store parameters that do not vary across macro blocks. + // + function_params.workspace_width_micro_repeats = workspace_width_micro_repeats; + function_params.height_macro_count = height_overall_macro_count; + function_params.width_macro_count = width_overall_macro_count; + function_params.input_height_stride = input_height_stride; + function_params.output_height_stride = output_height_stride; + function_params.residual_width = residual_micro_width; + + // Main process. + // + // Most kernels are nested batch-height-width-depth. Here we proceed over + // macro blocks batch-width-depth-height. + // + // Example of handling of trailing iteration: when there is trailing depth, + // depth_overall_macro_count = depth_macro_count + 1, so we can adjust the + // dimensions for trailing macro blocks by looking for + // j_depth == depth_macro_count. 
+ for (int b = 0; b < batches; ++b) { + for (int k_width = 0; k_width < width_overall_macro_count; ++k_width) { + // Figure out the work to be done for this macro block. If it trails in + // any dimension, the work in that dimension is adjusted. + // The work to be done across widths has 3 cases: + // (a) A full macro block, + // (b) Partial terminal macro block, with input and output ending in + // same micro block, and + // (c) Partial terminal macro block, with output corresponding to one + // fewer micro blocks, because filter extends across micro-block + // boundary. + if (k_width != width_macro_count) { + function_params.output_residual_width = 0; + function_params.input_width_micro_repeats = + workspace_width_micro_repeats; + function_params.input_width_overall_micro_repeats = + workspace_width_micro_repeats; + function_params.output_width_micro_repeats = + workspace_width_micro_repeats - 1; + } else { + function_params.output_residual_width = residual_output_micro_width; + function_params.input_width_micro_repeats = + width_trailing_micro_repeats; + function_params.input_width_overall_micro_repeats = + width_overall_trailing_micro_repeats; + function_params.output_width_micro_repeats = + output_width_trailing_micro_repeats; + } + function_params.output_width_overall_micro_repeats = + function_params.output_residual_width == 0 + ? function_params.output_width_micro_repeats + : function_params.output_width_micro_repeats + 1; + + for (int j_depth = 0; j_depth < depth_overall_macro_count; ++j_depth) { + const uint8* input_data_block = + input_data + b * input_batch_stride + + j_depth * input_depth_macro_stride + + k_width * input_width_macro_stride - + function_params.padding_left * input_depth - + function_params.padding_top * input_height_stride; + uint8* output_data_block = output_data + b * output_batch_stride + + j_depth * 64 + + k_width * output_width_macro_stride; + + // Process filter and bias data. 
+ // + function_params.depth_micro_repeats = + j_depth == depth_macro_count ? depth_trailing_micro_repeats : 8; + ProcessPerDepth::Run( + filter_data + 64 * j_depth, + bias_data + 8 * 2 * bias_increment * j_depth, + filter_workspace[0][0][0][0], adjusted_bias_data, &function_params); + + // Under depth multiplication the workspace_height_stride does not have + // to depend on input_width_overall_micro_repeats, but this improves the + // compactness of workspace use. + const int workspace_height_stride = + has_depth_multiplication + ? 16 * ((function_params.input_width_overall_micro_repeats + + 3) >> + 2) + : 4 * function_params.input_width_overall_micro_repeats * 8 * + function_params.depth_micro_repeats; + TFLITE_DCHECK_EQ(workspace_height_stride % 16, 0); + function_params.workspace_height_stride = workspace_height_stride; + + // For the first macro block for output rows we fill in the first few + // rows. After this we will copy them (see below in loop.) + function_params.inbound_block_height = input_height_overlap; + pack_macro_block_func(-1, k_width, input_data_block, + macroblock_workspace, &function_params); + input_data_block += input_height_stride * input_height_overlap; + + for (int i_height = 0; i_height < height_overall_macro_count; + ++i_height) { + if (i_height != height_macro_count) { + function_params.inbound_block_height = input_height_per_macro; + function_params.outbound_block_height = output_height_per_macro; + } else { + function_params.inbound_block_height = + residual_output_height * stride; + function_params.outbound_block_height = residual_output_height; + } + TFLITE_DCHECK_LT(i_height * output_height_per_macro, output_height); + TFLITE_DCHECK_LT(i_height * input_height_per_macro, input_height); + TFLITE_DCHECK_LT(k_width * output_width_per_macro_block, + output_width); + TFLITE_DCHECK_LT(k_width * consumed_width_per_macro_block, + input_width); + + // Macro blocks overlap by input_height_overlap rows, so we copy + // those instead of 
filling in afresh. The first macro block across + // output rows was filled in outside of the loop (above). + if (i_height > 0) { + memcpy(macroblock_workspace, + macroblock_workspace + + input_height_per_macro * workspace_height_stride, + input_height_overlap * workspace_height_stride); + } + + pack_macro_block_func( + i_height, k_width, input_data_block, + macroblock_workspace + + input_height_overlap * workspace_height_stride, + &function_params); + + kernel_macro_block_func( + macroblock_workspace, filter_workspace[0][0][0][0], + adjusted_bias_data, output_data_block, &function_params); + + input_data_block += input_height_stride * input_height_per_macro; + output_data_block += output_height_stride * output_height_per_macro; + } + } + } + } +} + +#undef vst1_lane_8x4 +#undef vst1q_lane_8x4 +#undef vld1q_lane_s8x8 +#undef vld1_lane_8x4 +#undef vld1q_lane_8x4 +#undef vld1q_dup_s8x4 + +#undef STR +#undef STR_UNEXPANDED + +} // namespace depthwise_conv } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h new file mode 100644 index 0000000000000000000000000000000000000000..e7fafa0b1c89e1b83e351e8c1c19afd48d227e04 --- /dev/null +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h @@ -0,0 +1,5005 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_TRANSITIONAL_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_TRANSITIONAL_H_ + +// This file provides kernel implementations that are not used in shipped +// inference code, but rather (a) show how model C++ code is designed and then +// transformed into asm code, and (b) aid with maintenance and later development +// of variations. Many projects (even including, say, the classic NAG libraries) +// develop highly optimized code, but do not maintain intermediate versions. +// Often the result is incomprehensible final-version code. + +#include + +#include "fixedpoint/fixedpoint.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { +namespace depthwise_conv { + +#ifdef USE_NEON +// Lane operations are for clarity and convenience. We want to load and store +// 4 8-bit lanes together. So these are treated much like 32-bit loads and +// 32-bit stores. Stores require 32-bit alignment. 
+ +#define vst1_lane_8x4(dst, reg, lane_num) \ + TFLITE_DCHECK_EQ(reinterpret_cast(dst) % 4, 0); \ + vst1_lane_u32(reinterpret_cast(dst), reg, lane_num) +#define vst1q_lane_8x4(dst, reg, lane_num) \ + TFLITE_DCHECK_EQ(reinterpret_cast(dst) % 4, 0); \ + vst1q_lane_u32(reinterpret_cast(dst), reg, lane_num) + +#define vld1q_lane_s8x8(src, reg, lane_num) \ + vld1q_lane_u64(reinterpret_cast(src), reg, lane_num) +#define vld1_lane_8x4(src, reg, lane_num) \ + vld1_lane_s32(reinterpret_cast(src), reg, lane_num) +#define vld1q_lane_8x4(src, reg, lane_num) \ + vld1q_lane_s32(reinterpret_cast(src), reg, lane_num) +#define vld1q_dup_s8x4(src) vld1q_dup_s32(reinterpret_cast(src)) +#endif // USE_NEON + +template <> +struct ProcessPerDepth { + // Filter data is provided as filter_block[3][3][depth/8][2][4]: height 3, + // width 3, sub-block 0 or 1, depth 4. Filter data is written as + // filter_bank[3][2][4][4]; height 3, sub-block, depth 4, width 4. + // + // Note that this rearrangement is much like that performed on input data when + // filling the workspace, and optimized versions will be similar. + static inline void FillFilterBank(int depth, const uint8* filter_block, + int8 filter_bank[3][2][4][4]) { + constexpr int kSymmetricZeroPoint = 128; + // Load filter data in, 8-bytes down depth / sub-block at a time. + // + // loaded_filter has dimensions height 3, width 4, sub-block 0 or 1, + // depth 4. + uint8 loaded_filter[3][4][2][4]; + for (int y = 0; y < 3; ++y) { + for (int x = 0; x < 3; ++x) { + memcpy(loaded_filter[y][x][0], &filter_block[3 * y * depth + x * depth], + 8); + } + // Pad the filter with symmetric representation of 0, so that the values + // become 0 when the zero-point is subtracted below. Thus these filter taps + // are effectively disregarded in later filtering.
+ memset(loaded_filter[y][3][0], kSymmetricZeroPoint, 8); + } + for (int y = 0; y < 3; ++y) { + for (int z = 0; z < 4; ++z) { + for (int x = 0; x < 4; ++x) { + filter_bank[y][0][z][x] = + loaded_filter[y][x][0][z] - kSymmetricZeroPoint; + filter_bank[y][1][z][x] = + loaded_filter[y][x][1][z] - kSymmetricZeroPoint; + } + } + } + } + + // Adjust the bias (weights) data according to the input offset. + // + // The output calculation is + // out[h][w][d] = bias[d] + sum_ij (in[h+i][w+j][d] + in_offset) * + // (filter[i][j][d] + filter_offset) + // (where offsets are expressed as differences from 128). + // + // Since we cannot efficiently handle varying offsets / bias across the image, + // we insist on filter_offset = 0. + // + // This function calculates + // adjusted_bias[d] = bias[d] + sum_ij in_offset * filter[i][j][d] + // which accounts for input offset. If the bias is constant over the depth, + // the adjusted bias will vary. + static inline void AdjustBias(int32 input_offset, + const int8 filter_bank[3][2][4][4], + const int32* bias_data, + int32 adjusted_bias_block[2][4]) { + constexpr int kSymmetricZeroPoint = 128; + TFLITE_DCHECK_GE(input_offset, -255); + TFLITE_DCHECK_LE(input_offset, 0); + // For instance, if input_offset == -128, no adjustment is needed.
+ const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + + for (int s = 0; s < 2; ++s) { + for (int z = 0; z < 4; ++z) { + adjusted_bias_block[s][z] = bias_data[4 * s + z]; + for (int i = 0; i < 9; ++i) { + adjusted_bias_block[s][z] += + input_offset_difference * filter_bank[i % 3][s][z][i / 3]; + } + } + } + } + + static void Run(const uint8* filter_data, const int32* bias_data, + int8* shuffled_filter_data, int32* adjusted_bias_data, + const DepthwiseConvDotProdParams* function_params) { + constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; + const int depth = function_params->output_depth; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int bias_increment = function_params->bias_increment; + const int32 input_offset = function_params->input_offset; + + int8 filter_bank[3][2][4][4]; + int32 adjusted_bias_block[2][4]; + + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + FillFilterBank(depth, filter_data + 8 * j_depth, filter_bank); + AdjustBias(input_offset, filter_bank, + bias_data + 2 * bias_increment * j_depth, adjusted_bias_block); + + memcpy(shuffled_filter_data, filter_bank[0][0][0], + shuffled_filter_increment); + shuffled_filter_data += shuffled_filter_increment; + memcpy(adjusted_bias_data, adjusted_bias_block[0], + 8 * sizeof(adjusted_bias_block[0][0])); + adjusted_bias_data += 8; + } + } +}; + +template <> +struct ProcessPerDepth { + static inline void Run(const uint8* filter_data, const int32* bias_data, + int8* shuffled_filter_data, int32* adjusted_bias_data, + const DepthwiseConvDotProdParams* function_params) { + const int depth = function_params->output_depth; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int bias_increment = function_params->bias_increment; + + // Simulate NEON-register transposition of subset of filter. + int8 filter_bank_a_0[4][4]; // Depth 4, width 4. 
+ int8 filter_bank_a_1[4][4]; + int8 filter_bank_a_2[4][4]; + int8 filter_bank_b_0[4][4]; + int8 filter_bank_b_1[4][4]; + int8 filter_bank_b_2[4][4]; + + // Load filter data in, essentially dropping the [depth/8] dimension, which + // is equivalent to loading just the depth needed for one micro-block. + // + // Each of loaded_filter_0/1/2 (one per filter row) has dimensions width 4, + // sub-block 0 or 1, depth 4. + uint8 loaded_filter_0[4][2][4]; + uint8 loaded_filter_1[4][2][4]; + uint8 loaded_filter_2[4][2][4]; + + constexpr int kSymmetricZeroPoint = 128; + const int32 input_offset = function_params->input_offset; + TFLITE_DCHECK_GE(input_offset, -255); + TFLITE_DCHECK_LE(input_offset, 0); + const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + const uint8* filter_block = filter_data + 8 * j_depth; + + // Filter data is provided as filter_block[3][3][depth/8][2][4]. + // height 3, width 3, micro-blocks, sub-block 0 or 1, depth 4. + // Written as filter_bank[3][2][4][4]: height 3, sub-block, depth 4, width 4. + for (int x = 0; x < 3; ++x) { + memcpy(loaded_filter_0[x][0], &filter_block[3 * 0 * depth + x * depth], + 8); + memcpy(loaded_filter_1[x][0], &filter_block[3 * 1 * depth + x * depth], + 8); + memcpy(loaded_filter_2[x][0], &filter_block[3 * 2 * depth + x * depth], + 8); + } + // Pad the filter with -filter_offset, so that the values become 0 when + // the filter_offset is later added, and so the filter tap is effectively + // disregarded.
+ memset(loaded_filter_0[3][0], kSymmetricZeroPoint, 8); + memset(loaded_filter_1[3][0], kSymmetricZeroPoint, 8); + memset(loaded_filter_2[3][0], kSymmetricZeroPoint, 8); + + for (int z = 0; z < 4; ++z) { + for (int x = 0; x < 4; ++x) { + filter_bank_a_0[z][x] = + loaded_filter_0[x][0][z] - kSymmetricZeroPoint; + filter_bank_b_0[z][x] = + loaded_filter_0[x][1][z] - kSymmetricZeroPoint; + filter_bank_a_1[z][x] = + loaded_filter_1[x][0][z] - kSymmetricZeroPoint; + filter_bank_b_1[z][x] = + loaded_filter_1[x][1][z] - kSymmetricZeroPoint; + filter_bank_a_2[z][x] = + loaded_filter_2[x][0][z] - kSymmetricZeroPoint; + filter_bank_b_2[z][x] = + loaded_filter_2[x][1][z] - kSymmetricZeroPoint; + } + } + + memcpy(shuffled_filter_data, filter_bank_a_0, 16); + shuffled_filter_data += 16; + memcpy(shuffled_filter_data, filter_bank_b_0, 16); + shuffled_filter_data += 16; + memcpy(shuffled_filter_data, filter_bank_a_1, 16); + shuffled_filter_data += 16; + memcpy(shuffled_filter_data, filter_bank_b_1, 16); + shuffled_filter_data += 16; + memcpy(shuffled_filter_data, filter_bank_a_2, 16); + shuffled_filter_data += 16; + memcpy(shuffled_filter_data, filter_bank_b_2, 16); + shuffled_filter_data += 16; + + int32 adjusted_bias_data_0[4]; + int32 adjusted_bias_data_1[4]; + // For instance, if input_offset == -128, no adjustment is needed.
+ for (int z = 0; z < 4; ++z) { + adjusted_bias_data_0[z] = bias_data[z]; + adjusted_bias_data_1[z] = bias_data[4 + z]; + for (int x = 0; x < 4; ++x) { + adjusted_bias_data_0[z] += + input_offset_difference * filter_bank_a_0[z][x]; + adjusted_bias_data_0[z] += + input_offset_difference * filter_bank_a_1[z][x]; + adjusted_bias_data_0[z] += + input_offset_difference * filter_bank_a_2[z][x]; + adjusted_bias_data_1[z] += + input_offset_difference * filter_bank_b_0[z][x]; + adjusted_bias_data_1[z] += + input_offset_difference * filter_bank_b_1[z][x]; + adjusted_bias_data_1[z] += + input_offset_difference * filter_bank_b_2[z][x]; + + adjusted_bias_data[z] = adjusted_bias_data_0[z]; + adjusted_bias_data[4 + z] = adjusted_bias_data_1[z]; + } + } + bias_data += 2 * bias_increment; + adjusted_bias_data += 8; + } + } +}; + +#ifdef USE_NEON +template <> +struct ProcessPerDepth< + DepthwiseConvImplementation::kUseIntrinsics3x3DotProduct> { + static void ProcessPerDepthIntrinsics( + const uint8* filter_data, const int32* bias_data, + int8* shuffled_filter_data, int32* adjusted_bias_data, + const DepthwiseConvDotProdParams* function_params) { + const int depth = function_params->output_depth; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int bias_increment = function_params->bias_increment; + + constexpr int kSymmetricZeroPoint = 128; + constexpr uint8 kSignBit = 0x80; + const int32 input_offset = function_params->input_offset; + TFLITE_DCHECK_GE(input_offset, -255); + TFLITE_DCHECK_LE(input_offset, 0); + const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + const int8x16_t ones_vector = vdupq_n_s8(1); + + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + + // Register pairs for each height. 
+ // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + + const uint8* filter_block = filter_data; + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + // Filter data is provided as filter_block[3][3][depth/8][2][4]. + // height 3, width 3, micro-blocks, sub-block 0 or 1, depth 4. + // Written as filter_bank[3][2][4][4]: height 3, sub-block, depth 4, width 4. + + // Load zero-point into effective position of zero-padding of filter + // (register B, upper part). + filter_reg_0_b = vdupq_n_u8(kSignBit); + filter_reg_1_b = vdupq_n_u8(kSignBit); + filter_reg_2_b = vdupq_n_u8(kSignBit); + + const uint8* filter_block_ptr = filter_block; + filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 0); + filter_block_ptr += depth; + filter_reg_0_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_b, 0); + filter_block_ptr += depth; + filter_reg_0_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_0_a, 1); + filter_block_ptr += depth; + filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 0); + filter_block_ptr += depth; + filter_reg_1_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_b, 0); + filter_block_ptr += depth; + filter_reg_1_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_1_a, 1); + filter_block_ptr += depth; + filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 0); + filter_block_ptr += depth; + filter_reg_2_b = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_b, 0); + filter_block_ptr += depth; + filter_reg_2_a = vld1q_lane_s8x8(filter_block_ptr, filter_reg_2_a, 1); + + filter_reg_0_a = veorq_s8(filter_reg_0_a, sign_bit); + filter_reg_0_b = veorq_s8(filter_reg_0_b, sign_bit); + filter_reg_1_a = veorq_s8(filter_reg_1_a, sign_bit); + filter_reg_1_b = veorq_s8(filter_reg_1_b, sign_bit); + filter_reg_2_a = veorq_s8(filter_reg_2_a, sign_bit); + filter_reg_2_b = veorq_s8(filter_reg_2_b, sign_bit); + + vzipq_s8_in_place(&filter_reg_0_a, &filter_reg_0_b); +
vzipq_s8_in_place(&filter_reg_1_a, &filter_reg_1_b); + vzipq_s8_in_place(&filter_reg_2_a, &filter_reg_2_b); + vzipq_s8x2_in_place(&filter_reg_0_a, &filter_reg_0_b); + vzipq_s8x2_in_place(&filter_reg_1_a, &filter_reg_1_b); + vzipq_s8x2_in_place(&filter_reg_2_a, &filter_reg_2_b); + + vst1q_s8(shuffled_filter_data, filter_reg_0_a); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_0_b); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_1_a); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_1_b); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_2_a); + shuffled_filter_data += 16; + vst1q_s8(shuffled_filter_data, filter_reg_2_b); + shuffled_filter_data += 16; + + int32x4_t adjusted_bias_data_a = vld1q_s32(bias_data); + bias_data += bias_increment; + int32x4_t adjusted_bias_data_b = vld1q_s32(bias_data); + bias_data += bias_increment; + // For instance, if input_offset == -128, no adjustment is needed.
+ + int32x4_t filter_sum_a = vdupq_n_s32(0); + filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_0_a, ones_vector); + filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_1_a, ones_vector); + filter_sum_a = vdotq_s32(filter_sum_a, filter_reg_2_a, ones_vector); + int32x4_t filter_sum_b = vdupq_n_s32(0); + filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_0_b, ones_vector); + filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_1_b, ones_vector); + filter_sum_b = vdotq_s32(filter_sum_b, filter_reg_2_b, ones_vector); + + adjusted_bias_data_a = vmlaq_n_s32(adjusted_bias_data_a, filter_sum_a, + input_offset_difference); + adjusted_bias_data_b = vmlaq_n_s32(adjusted_bias_data_b, filter_sum_b, + input_offset_difference); + + vst1q_s32(adjusted_bias_data, adjusted_bias_data_a); + adjusted_bias_data += 4; + vst1q_s32(adjusted_bias_data, adjusted_bias_data_b); + adjusted_bias_data += 4; + + filter_block += 8; + } + } + + static inline void Run(const uint8* filter_data, const int32* bias_data, + int8* shuffled_filter_data, int32* adjusted_bias_data, + const DepthwiseConvDotProdParams* function_params) { + ProcessPerDepthIntrinsics(filter_data, bias_data, shuffled_filter_data, + adjusted_bias_data, function_params); + } +}; +#endif + +template +struct PackMacroBlock { + // A straight copy of a macro block of input data into a scratch buffer. + // + // Requirement: depth_micro_repeats > 0. + static inline void CopyMacroBlock( + int32 height_block_number, int32 width_block_number, + const DepthwiseConvDotProdParams& function_params, + const uint8* input_block_data, int8* scratch_block_data) { + TFLITE_DCHECK_LE(max_padding, 1); + + // Strides. + // The input depth and count of micro blocks provide the width strides. 
+ const int input_height_stride = function_params.input_height_stride; + const int workspace_height_stride = function_params.workspace_height_stride; + const int input_depth = function_params.input_depth; + const int depth_micro_repeats = function_params.depth_micro_repeats; + TFLITE_DCHECK_GT(depth_micro_repeats, 0); + + // Remaining iteration and dimension parameters. + // + // If width_overall_micro_repeats = input_width_micro_repeats + 1, then the + // final micro block is incomplete. + const int width_overall_micro_repeats = + function_params.input_width_overall_micro_repeats; + int input_width_micro_repeats = function_params.input_width_micro_repeats; + const int residual_width = function_params.residual_width; + const int block_height = function_params.inbound_block_height; + + const int padding_left = function_params.padding_left; + const int padding_right = function_params.padding_right; + const int padding_top = function_params.padding_top; + const int padding_bottom = function_params.padding_bottom; + + const bool leading_width_padding = + padding_left > 0 && width_block_number == 0; + const bool trailing_width_padding = + padding_right > 0 && + width_block_number == (function_params.width_macro_count - 1); + const bool leading_height_padding = + padding_top > 0 && height_block_number < 0; + const bool trailing_height_padding = + padding_bottom > 0 && + height_block_number == (function_params.height_macro_count - 1); + + // Modify the trailing case to reflect the input width. + int input_residual_width = + input_width_micro_repeats < width_overall_micro_repeats ? residual_width + : 4; + if (trailing_width_padding) { + input_residual_width -= 1; + input_width_micro_repeats = width_overall_micro_repeats - 1; + } + + constexpr int kSymmetricZeroPoint = 128; + const int32 input_offset_difference = + function_params.input_offset + kSymmetricZeroPoint; + + // We load data into a temporary buffer and then save, to match subsequent + // processing. 
This will make it easier to combine stages into one ASM + // routine. + int8 tmp_load[4][2][4]; + + int copy_block_height = block_height; + if (leading_height_padding) { + memset(scratch_block_data, -input_offset_difference, + workspace_height_stride); + scratch_block_data += workspace_height_stride; + input_block_data += input_height_stride; + copy_block_height -= 1; + } + if (trailing_height_padding) { + copy_block_height -= 1; + } + + // The outer 3 loops go through all the micro blocks in a macro block. + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + for (int j_width = 0; j_width < width_overall_micro_repeats; ++j_width) { + // Figure out division of work (available input vs trailing padding). + int adjusted_residual_width = + j_width == input_width_micro_repeats ? input_residual_width : 4; + + int start_width = 0; + if (leading_width_padding && j_width == 0) { + start_width = 1; + memset(tmp_load[0][0], -input_offset_difference, 8); + } + if (adjusted_residual_width < 4) { + for (int x = adjusted_residual_width; x < 4; ++x) { + memset(tmp_load[x][0], -input_offset_difference, 8); + } + } + + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + // The inner 3 loops go through the sub-block, depth and width within + // each micro block. + + // Load, and apply symmetric offset. + int8* scratch_data = + scratch_block_data + k_height * workspace_height_stride + + j_width * 4 * 8 + i_depth * 4 * 8 * width_overall_micro_repeats; + const uint8* input_data = input_block_data + + k_height * input_height_stride + + j_width * 4 * input_depth + i_depth * 8; + // Full-size micro blocks are 2*4*4 = 32 bytes. + for (int x = start_width; x < adjusted_residual_width; ++x) { + for (int s = 0; s < 2; ++s) { + for (int d = 0; d < 4; ++d) { + tmp_load[x][s][d] = input_data[x * input_depth + 4 * s + d] - + kSymmetricZeroPoint; + } + } + } + + // Save results.
+ memcpy(&scratch_data[0], tmp_load[0][0], 8); + memcpy(&scratch_data[8], tmp_load[1][0], 8); + memcpy(&scratch_data[16], tmp_load[2][0], 8); + memcpy(&scratch_data[24], tmp_load[3][0], 8); + } + } + } + + if (trailing_height_padding) { + memset(scratch_block_data + copy_block_height * workspace_height_stride, + -input_offset_difference, workspace_height_stride); + } + } + + // Transpose 4x4 blocks within each sub-micro-block. + // + // Implemented somewhat like NEON register manipulation, so that we can see + // equivalence of the two approaches. + static inline void MicroTransposeBlocks( + const DepthwiseConvDotProdParams& function_params, + int8* scratch_block_data) { + const int workspace_height_stride = function_params.workspace_height_stride; + const int width_overall_micro_repeats = + function_params.input_width_overall_micro_repeats; + const int depth_micro_repeats = function_params.depth_micro_repeats; + const int block_height = function_params.inbound_block_height; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in the + // NEON code we are simulating. + int8 tmp_load[4][2][4]; // [width][sub-block][depth] + int8 tmp_transposed[4][2][4]; // [depth][sub-block][width] + int8 tmp_interleaved[2][4][4]; // [sub-block][depth][width] + + // The outer 3 loops go through all the micro blocks in a macro block. + for (int k_height = 0; k_height < block_height; ++k_height) { + for (int j_width = 0; j_width < width_overall_micro_repeats; ++j_width) { + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + int8* scratch_data = + scratch_block_data + k_height * workspace_height_stride + + j_width * 4 * 8 + i_depth * 4 * 8 * width_overall_micro_repeats; + // A. Load data + memcpy(tmp_load[0][0], &scratch_data[0], 8); + memcpy(tmp_load[1][0], &scratch_data[8], 8); + memcpy(tmp_load[2][0], &scratch_data[16], 8); + memcpy(tmp_load[3][0], &scratch_data[24], 8); + + // B. Simulate between-register transposition. 
+ for (int x = 0; x < 4; ++x) { + for (int y = 0; y < 4; ++y) { + tmp_transposed[x][0][y] = tmp_load[y][0][x]; + tmp_transposed[x][1][y] = tmp_load[y][1][x]; + } + } + + // C. Simulate between-register interleaving. + for (int x = 0; x < 4; ++x) { + for (int y = 0; y < 4; ++y) { + tmp_interleaved[0][x][y] = tmp_transposed[x][0][y]; + tmp_interleaved[1][x][y] = tmp_transposed[x][1][y]; + } + } + // D. Simulate mangled storage arrangement. + memcpy(&scratch_data[0], tmp_interleaved[0][0], 16); + memcpy(&scratch_data[16], tmp_interleaved[1][0], 16); + } + } + } + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + CopyMacroBlock(height_block_number, width_block_number, *function_params, + input_block_data, scratch_block_data); + MicroTransposeBlocks(*function_params, scratch_block_data); + } +}; + +template +struct PackMacroBlock { + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + // Currently support for padding is limited to 1 on any side. + TFLITE_DCHECK_LE(max_padding, 1); + + // Strides. + // The count of micro blocks (below) provides the width strides. + const int input_height_stride = function_params->input_height_stride; + const int workspace_height_stride = + function_params->workspace_height_stride; + + // Remaining iteration and dimension parameters. + // + // If width_overall_micro_repeats = input_width_micro_repeats + 1, then the + // final micro block is incomplete. 
+ const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int residual_width = function_params->residual_width; + const int block_height = function_params->inbound_block_height; + TFLITE_DCHECK_GE(workspace_height_stride, 4 * width_overall_micro_repeats); + + const int padding_left = function_params->padding_left; + const int padding_right = function_params->padding_right; + const int padding_top = function_params->padding_top; + const int padding_bottom = function_params->padding_bottom; + + const bool leading_width_padding = + padding_left > 0 && width_block_number == 0; + const bool trailing_width_padding = + padding_right > 0 && + width_block_number == (function_params->width_macro_count - 1); + const bool leading_height_padding = + padding_top > 0 && height_block_number < 0; + const bool trailing_height_padding = + padding_bottom > 0 && + height_block_number == (function_params->height_macro_count - 1); + + constexpr int kSymmetricZeroPoint = 128; + const int32 input_offset_difference = + function_params->input_offset + kSymmetricZeroPoint; + + int copy_block_height = block_height; + if (leading_height_padding) { + memset(scratch_block_data, -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + scratch_block_data += workspace_height_stride; + input_block_data += input_height_stride; + copy_block_height -= 1; + } + if (trailing_height_padding) { + copy_block_height -= 1; + } + + int adjusted_residual_width = + input_width_micro_repeats < width_overall_micro_repeats ? 
residual_width + : 4; + + if (trailing_width_padding) { + adjusted_residual_width -= 1; + } + int start_width = 0; + if (leading_width_padding) { + start_width = 1; + input_block_data += 1; + } + + const int copy_size = (width_overall_micro_repeats - 1) * 4 + + adjusted_residual_width - start_width; + + TFLITE_DCHECK_LE( + copy_size, + input_height_stride - width_block_number * input_width_micro_repeats); + // We may drop up to stride-1 of trailing input. + TFLITE_DCHECK_GE(copy_size, input_height_stride - 1); + + // When there is unit input depth, the micro-block iteration need only be + // through the height. The micro blocks are contiguous across the width. + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + const uint8* input_data = + input_block_data + k_height * input_height_stride; + int8* scratch_data = + scratch_block_data + k_height * workspace_height_stride; + + // Handle leading padding. This is overwritten if there is no padding. + scratch_data[0] = -input_offset_difference; + + memcpy(&scratch_data[start_width], input_data, copy_size); + for (int i = 0; i < copy_size; ++i) { + scratch_data[start_width + i] += -kSymmetricZeroPoint; + } + + // Handle trailing padding, and fill in remainder of micro block. + memset(&scratch_data[start_width + copy_size], -input_offset_difference, + 4 - adjusted_residual_width + kWorkspaceExtension); + } + + if (trailing_height_padding) { + memset(scratch_block_data + copy_block_height * workspace_height_stride, + -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + } + } +}; + +// Beginning of code section containing intermediate code transformation. +// +// This section is only compiled when kUseUnwound3x3DotProduct versions of +// templated functions are selected. 
+template <> +struct PackMacroBlock { + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + const int input_depth = function_params->input_depth; + + TFLITE_DCHECK_GE(depth_micro_repeats, 0); + constexpr int kSymmetricZeroPoint = 128; + const int micro_block_size = 4 * 8; + const int depth_advance = width_overall_micro_repeats * micro_block_size; + const int width_advance = + micro_block_size * + (1 - depth_micro_repeats * width_overall_micro_repeats); + const int height_advance = workspace_height_stride - + width_overall_micro_repeats * micro_block_size; + const int input_depth_skip = 4 * input_depth - 8 * depth_micro_repeats; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in the + // NEON code we are simulating. Note the blocks of 4x4 are still interleaved + // down the depth. + int8 tmp_load[4][2][4]; + int8 tmp_transposed[4][2][4]; + int8 tmp_interleaved[2][4][4]; + + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_block_data; + for (int k_height = 0; k_height < block_height; ++k_height) { + const uint8* input_data = input_block_data; + input_block_data += input_height_stride; + + // Traverse the width one point at a time, but the depth in (micro) blocks + // of size 8. 
+ // + // The depth and width margins, which are filled with "zeros", may be + // larger than is strictly needed to calculate output. This is because the + // conv calculation is performed across complete micro blocks. + for (int j_width = 0; j_width < input_width_micro_repeats; ++j_width) { + // Load, then zero. + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + // A. Simulate register loading. + for (int x = 0; x < 4; ++x) { + for (int s = 0; s < 2; ++s) { + for (int d = 0; d < 4; ++d) { + tmp_load[x][s][d] = input_data[x * input_depth + 4 * s + d] - + kSymmetricZeroPoint; + } + } + } + // B. Simulate between-register transposition. + for (int x = 0; x < 4; ++x) { + for (int y = 0; y < 4; ++y) { + tmp_transposed[x][0][y] = tmp_load[y][0][x]; + tmp_transposed[x][1][y] = tmp_load[y][1][x]; + } + } + + // C and D are to be performed together as 4-byte stores in NEON code. + // C. Simulate between-register interleaving. + for (int x = 0; x < 4; ++x) { + for (int y = 0; y < 4; ++y) { + tmp_interleaved[0][x][y] = tmp_transposed[x][0][y]; + tmp_interleaved[1][x][y] = tmp_transposed[x][1][y]; + } + } + // D. Simulate mangled storage arrangement. + memcpy(&scratch_data[0], tmp_interleaved[0][0], 8); + memcpy(&scratch_data[8], tmp_interleaved[0][2], 8); + memcpy(&scratch_data[16], tmp_interleaved[1][0], 8); + memcpy(&scratch_data[24], tmp_interleaved[1][2], 8); + + scratch_data += depth_advance; + input_data += 8; + } + scratch_data += width_advance; + input_data += input_depth_skip; + } + if (width_overall_micro_repeats > input_width_micro_repeats) { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, + input_width_micro_repeats + 1); + TFLITE_DCHECK_GT(residual_width, 0); + // Figure out division of work (available input vs zero-ed). + const int adjusted_residual_width = residual_width; + // Load, then zero. + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + // A. Simulate register loading. 
+ for (int x = 0; x < adjusted_residual_width; ++x) { + for (int s = 0; s < 2; ++s) { + for (int d = 0; d < 4; ++d) { + tmp_load[x][s][d] = input_data[x * input_depth + 4 * s + d] - + kSymmetricZeroPoint; + } + } + } + for (int x = adjusted_residual_width; x < 4; ++x) { + for (int s = 0; s < 2; ++s) { + for (int d = 0; d < 4; ++d) { + tmp_load[x][s][d] = 0; + } + } + } + // B. Simulate between-register transposition. + for (int x = 0; x < 4; ++x) { + for (int y = 0; y < 4; ++y) { + tmp_transposed[x][0][y] = tmp_load[y][0][x]; + tmp_transposed[x][1][y] = tmp_load[y][1][x]; + } + } + + // C and D are to be performed together as 4-byte stores in NEON code. + // C. Simulate between-register interleaving. + for (int x = 0; x < 4; ++x) { + for (int y = 0; y < 4; ++y) { + tmp_interleaved[0][x][y] = tmp_transposed[x][0][y]; + tmp_interleaved[1][x][y] = tmp_transposed[x][1][y]; + } + } + // D. Simulate mangled storage arrangement. + memcpy(&scratch_data[0], tmp_interleaved[0][0], 8); + memcpy(&scratch_data[8], tmp_interleaved[0][2], 8); + memcpy(&scratch_data[16], tmp_interleaved[1][0], 8); + memcpy(&scratch_data[24], tmp_interleaved[1][2], 8); + + scratch_data += depth_advance; + input_data += 8; + } + scratch_data += width_advance; + input_data += input_depth_skip; + } + scratch_data += height_advance; + } + + TFLITE_DCHECK_EQ(scratch_data, scratch_block_data + + block_height * workspace_height_stride); + } +}; + +template <> +struct PackMacroBlock { + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + // Just use C model code for case of padding. Optimized versions merge the + // modifications therein to handle padding. 
+ PackMacroBlock::Run(height_block_number, + width_block_number, input_block_data, + scratch_block_data, function_params); + } +}; + +template +struct PackMacroBlock { + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + + const int padding_left = function_params->padding_left; + const int padding_right = function_params->padding_right; + const int padding_top = function_params->padding_top; + const int padding_bottom = function_params->padding_bottom; + + constexpr int kSymmetricZeroPoint = 128; + + TFLITE_DCHECK_GE(workspace_height_stride, 4 * width_overall_micro_repeats); + + const bool leading_width_padding = + padding_left > 0 && width_block_number == 0; + const bool trailing_width_padding = + padding_right > 0 && + width_block_number == (function_params->width_macro_count - 1); + const bool leading_height_padding = + padding_top > 0 && height_block_number < 0; + const bool trailing_height_padding = + padding_bottom > 0 && + height_block_number == (function_params->height_macro_count - 1); + + const int32 input_offset = function_params->input_offset; + const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + + // Work through one slice, by row, at a time. 
+ int8* scratch_data_base = scratch_block_data; + + int copy_block_height = block_height; + if (leading_height_padding) { + copy_block_height -= 1; + memset(scratch_data_base, -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + scratch_data_base += workspace_height_stride; + input_block_data += input_height_stride; + } + if (trailing_height_padding) { + copy_block_height -= 1; + } + + int adjusted_residual_width = + input_width_micro_repeats < width_overall_micro_repeats ? residual_width + : 4; + + if (trailing_width_padding) { + adjusted_residual_width -= 1; + } + int start_width = 0; + if (leading_width_padding) { + start_width = 1; + input_block_data += 1; + } + + const int copy_size = (width_overall_micro_repeats - 1) * 4 + + adjusted_residual_width - start_width; + // Adjusted so that later conditionals are simplified. + const int copy_size_adjusted = + trailing_width_padding ? copy_size + 1 : copy_size; + + TFLITE_DCHECK_LE( + copy_size, + input_height_stride - width_block_number * input_width_micro_repeats); + // We may drop up to stride-1 of trailing input. + TFLITE_DCHECK_GE(copy_size, input_height_stride - 1); + + // This is used to simulate what should happen in registers. + int8 tmp_data[16]; + + int scratch_data_offset = 0; + int input_block_offset = 0; + + if (copy_size >= 16) { + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // The surrounding condition ensures that we always need at least one + // iteration of the main copy loop. In the case of leading width + // padding, we unroll this specially. 
+ if (leading_width_padding) { + memcpy(tmp_data + 1, input_block_data + input_block_offset, 15); + for (int i = 0; i < 16; ++i) { + tmp_data[i] += -kSymmetricZeroPoint; + } + tmp_data[0] = -input_offset_difference; + memcpy(scratch_data, tmp_data, 16); + copy_done += 15; + } + + // Main copy loop. + for (; (copy_done + 16) <= copy_size; copy_done += 16) { + memcpy(tmp_data, input_block_data + input_block_offset + copy_done, + 16); + for (int i = 0; i < 16; ++i) { + tmp_data[i] += -kSymmetricZeroPoint; + } + TFLITE_DCHECK_EQ((start_width + copy_done) % 16, 0); + memcpy(&scratch_data[start_width + copy_done], tmp_data, 16); + } + + const int copy_remaining = copy_size - copy_done; + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size_adjusted) { + // Employ overlapping-load strategy in order to load full register, + // but use only part. + memcpy(tmp_data, + input_block_data + input_block_offset + copy_done - + (16 - copy_remaining), + 16); + // Shift to select the part that we need. + for (int i = 0; i < copy_remaining; ++i) { + tmp_data[i] = tmp_data[(16 - copy_remaining) + i]; + } + for (int i = 0; i < 16; ++i) { + tmp_data[i] += -kSymmetricZeroPoint; + } + // Apply padding to remainder, some unnecessary but costless in regs. 
+ for (int i = copy_remaining; i < 16; ++i) { + tmp_data[i] = -input_offset_difference; + } + const int final_repeats = + width_overall_micro_repeats - (start_width + copy_done) / 4; + for (int i = 0; i < final_repeats; ++i) { + memcpy(&scratch_data[start_width + copy_done], tmp_data + 4 * i, 4); + copy_done += 4; + } + } + memset(scratch_data + start_width + copy_done, -input_offset_difference, + kWorkspaceExtension); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (copy_size >= 4) { + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // The surrounding condition ensures that we always need at least one + // iteration of the main copy loop. In the case of leading width + // padding, we unroll this specially. + if (leading_width_padding) { + memcpy(tmp_data + 1, input_block_data + input_block_offset, 3); + for (int i = 0; i < 4; ++i) { + tmp_data[i] += -kSymmetricZeroPoint; + } + tmp_data[0] = -input_offset_difference; + memcpy(scratch_data, tmp_data, 4); + copy_done += 3; + } + + for (; (copy_done + 4) <= copy_size; copy_done += 4) { + memcpy(tmp_data, input_block_data + input_block_offset + copy_done, + 4); + for (int i = 0; i < 4; ++i) { + tmp_data[i] += -kSymmetricZeroPoint; + } + // Perform as 4 int32 stores, because that is our alignment. + memcpy(&scratch_data[start_width + copy_done], tmp_data, 4); + } + + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + const int copy_remaining = copy_size - copy_done; + // Conditional is (copy_remaining > 0 || trailing_width_padding). 
+ if (copy_done < copy_size_adjusted) { + TFLITE_DCHECK_LT(copy_remaining, 4); + // Employ overlapping-load strategy in order to load full register, + // but use only part. + memcpy(tmp_data, + input_block_data + input_block_offset + copy_done - + (4 - copy_remaining), + 4); + // Shift to select the part that we need. + for (int i = 0; i < copy_remaining; ++i) { + tmp_data[i] = tmp_data[(4 - copy_remaining) + i]; + } + for (int i = 0; i < 4; ++i) { + tmp_data[i] += -kSymmetricZeroPoint; + } + // Apply padding to remainder, some unnecessary but costless in regs. + for (int i = copy_remaining; i < 4; ++i) { + tmp_data[i] = -input_offset_difference; + } + memcpy(&scratch_data[start_width + copy_done], tmp_data, 4); + copy_done += 4; + } + memset(scratch_data + start_width + copy_done, -input_offset_difference, + kWorkspaceExtension); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (width_overall_micro_repeats == 2) { + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Apply padding by quick fill of whole reg. + for (int i = 0; i < 8; ++i) { + tmp_data[i] = -input_offset; + } + for (int i = 0; i < copy_size; ++i) { + // Apply shift-left insert, tmp_data as both operands. + // The zero-index byte is left unchanged. + for (int i = 7; i > 0; --i) { + tmp_data[i] = tmp_data[i - 1]; + } + tmp_data[1] = + input_block_data[input_block_offset + (copy_size - 1 - i)]; + } + if (!leading_width_padding) { + // Remove leading padding, junking trailing byte, OK because max size + // is less than 8. 
+ TFLITE_DCHECK_LT(copy_size_adjusted + start_width, 8); + for (int i = 0; i < 7; ++i) { + tmp_data[i] = tmp_data[i + 1]; + } + } + for (int i = 0; i < 8; ++i) { + tmp_data[i] += -kSymmetricZeroPoint; + } + memcpy(scratch_data_base + scratch_data_offset, tmp_data, 8); + memset(scratch_data_base + scratch_data_offset + 8, + -input_offset_difference, kWorkspaceExtension); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, 1); + // This path is basically the same as the preceding, 2-micro-block one, + // but here we simply store fewer bytes. + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Apply padding by quick fill of whole reg. + for (int i = 0; i < 8; ++i) { + tmp_data[i] = -input_offset; + } + for (int i = 0; i < copy_size; ++i) { + // Apply shift-left insert, tmp_data as both operands. + // The zero-index byte is left unchanged. + for (int i = 7; i > 0; --i) { + tmp_data[i] = tmp_data[i - 1]; + } + tmp_data[1] = + input_block_data[input_block_offset + (copy_size - 1 - i)]; + } + if (!leading_width_padding) { + // Remove leading padding, junking trailing byte, OK because max size + // is less than 8. 
+ TFLITE_DCHECK_LT(copy_size_adjusted + start_width, 8); + for (int i = 0; i < 7; ++i) { + tmp_data[i] = tmp_data[i + 1]; + } + } + for (int i = 0; i < 8; ++i) { + tmp_data[i] += -kSymmetricZeroPoint; + } + memcpy(scratch_data_base + scratch_data_offset, tmp_data, 4); + memset(scratch_data_base + scratch_data_offset + 4, + -input_offset_difference, kWorkspaceExtension); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } + + scratch_data_base += copy_block_height * workspace_height_stride; + + if (trailing_height_padding) { + memset(scratch_data_base, -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + scratch_data_base += workspace_height_stride; + } + + TFLITE_DCHECK_EQ( + scratch_data_base, + scratch_block_data + block_height * workspace_height_stride); + } +}; +// The preceding section is only compiled when kUseUnwound3x3DotProduct versions +// of templated functions are selected. +// +// End of code section containing intermediate code transformation. 
+ +#ifdef USE_NEON +template <> +struct PackMacroBlock { + static inline void PackMacroBlockIntrinsics( + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + TFLITE_DCHECK_EQ(function_params->padding_bottom, 0); + TFLITE_DCHECK_EQ(function_params->padding_top, 0); + TFLITE_DCHECK_EQ(function_params->padding_left, 0); + TFLITE_DCHECK_EQ(function_params->padding_right, 0); + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + const int input_depth = function_params->input_depth; + + static const uint8 perm_data[64] = { + 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, // + 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55, + 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59, + 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63}; + + TFLITE_DCHECK_GE(depth_micro_repeats, 0); + constexpr uint8 kSignBit = 0x80; + const int micro_block_size = 4 * 8; + const int depth_advance = width_overall_micro_repeats * micro_block_size; + const int width_advance = + micro_block_size * + (1 - depth_micro_repeats * width_overall_micro_repeats); + const int height_advance = workspace_height_stride - + width_overall_micro_repeats * micro_block_size; + const int input_depth_skip = 4 * input_depth - 8 * depth_micro_repeats; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. 
+ int8x16_t work_reg_a; + int8x16_t work_reg_b; + const int8x16_t perm_data_0 = vld1q_u8(perm_data); + const int8x16_t perm_data_1 = vld1q_u8(perm_data + 16); + const int8x16_t perm_data_2 = vld1q_u8(perm_data + 32); + const int8x16_t perm_data_3 = vld1q_u8(perm_data + 48); + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + + // Work through one slice, by row, at a time. + int8* scratch_data_0 = scratch_block_data; + + for (int k_height = 0; k_height < block_height; ++k_height) { + const uint8* input_data_0 = input_block_data; + const uint8* input_data_1 = input_block_data + input_depth; + const uint8* input_data_2 = input_block_data + 2 * input_depth; + const uint8* input_data_3 = input_block_data + 3 * input_depth; + + // Traverse the width one point at a time, but the depth in (micro) blocks + // of size 8. + // + // The depth and width margins, which are filled with "zeros", may be + // larger than is strictly needed to calculate output. This is because the + // conv calculation is performed across complete micro blocks. 
+ for (int j_width = 0; j_width < input_width_micro_repeats; ++j_width) { + int i_depth = 0; + for (; i_depth < depth_micro_repeats - 1; i_depth += 2) { + int8x16x4_t input_data; + input_data.val[0] = vld1q_u8(input_data_0); + input_data.val[1] = vld1q_u8(input_data_1); + input_data.val[2] = vld1q_u8(input_data_2); + input_data.val[3] = vld1q_u8(input_data_3); + input_data_1 += 16; + input_data_0 += 16; + + int8x16_t tmp_0 = vqtbl4q_s8(input_data, perm_data_0); + int8x16_t tmp_1 = vqtbl4q_s8(input_data, perm_data_1); + work_reg_a = veorq_s8(tmp_0, sign_bit); + work_reg_b = veorq_s8(tmp_1, sign_bit); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_2 += 16; + input_data_3 += 16; + + tmp_0 = vqtbl4q_s8(input_data, perm_data_2); + tmp_1 = vqtbl4q_s8(input_data, perm_data_3); + work_reg_a = veorq_s8(tmp_0, sign_bit); + work_reg_b = veorq_s8(tmp_1, sign_bit); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + } + for (; i_depth < depth_micro_repeats; ++i_depth) { + int8x16x4_t input_data; + input_data.val[0] = + vld1q_lane_s8x8(input_data_0, input_data.val[0], 0); + input_data.val[1] = + vld1q_lane_s8x8(input_data_1, input_data.val[1], 0); + input_data.val[2] = + vld1q_lane_s8x8(input_data_2, input_data.val[2], 0); + input_data.val[3] = + vld1q_lane_s8x8(input_data_3, input_data.val[3], 0); + input_data_1 += 8; + input_data_0 += 8; + + int8x16_t tmp_0 = vqtbl4q_s8(input_data, perm_data_0); + int8x16_t tmp_1 = vqtbl4q_s8(input_data, perm_data_1); + work_reg_a = veorq_s8(tmp_0, sign_bit); + work_reg_b = veorq_s8(tmp_1, sign_bit); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + 
input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + if (width_overall_micro_repeats > input_width_micro_repeats) { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, + input_width_micro_repeats + 1); + TFLITE_DCHECK_GT(residual_width, 0); + TFLITE_DCHECK_LT(residual_width, 4); + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(kSignBit); + work_reg_a = vld1q_lane_s8x8(input_data_0, work_reg_a, 0); + work_reg_b = vdupq_n_u8(kSignBit); + if (residual_width > 1) { + work_reg_b = + vld1q_lane_s8x8(input_data_0 + input_depth, work_reg_b, 0); + if (residual_width == 3) { + work_reg_a = vld1q_lane_s8x8(input_data_0 + 2 * input_depth, + work_reg_a, 1); + } + } + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_0 += 8; + input_data_1 += 8; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + scratch_data_0 += height_advance; + input_block_data += input_height_stride; + } + TFLITE_DCHECK_EQ( + scratch_data_0, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { +#ifdef __aarch64__ + PreloadInputBlock(input_block_data, function_params); +#endif + + PackMacroBlockIntrinsics(input_block_data, scratch_block_data, + function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockIntrinsics( + int32 height_block_number, int32 
width_block_number, + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + constexpr uint8 kSignBit = 0x80; + + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + const int input_depth = function_params->input_depth; + + const int padding_left = function_params->padding_left; + const int padding_right = function_params->padding_right; + const int padding_top = function_params->padding_top; + const int padding_bottom = function_params->padding_bottom; + + TFLITE_DCHECK_GT(depth_micro_repeats, 0); + constexpr int kSymmetricZeroPoint = 128; + + const int micro_block_size = 4 * 8; + const int depth_advance = width_overall_micro_repeats * micro_block_size; + const int width_advance = + micro_block_size * + (1 - depth_micro_repeats * width_overall_micro_repeats); + const int height_advance = workspace_height_stride - + width_overall_micro_repeats * micro_block_size; + const int input_depth_skip = 4 * input_depth - 8 * depth_micro_repeats; + + const bool leading_width_padding = + padding_left > 0 && width_block_number == 0; + const bool trailing_width_padding = + padding_right > 0 && + width_block_number == (function_params->width_macro_count - 1); + const bool leading_height_padding = + padding_top > 0 && height_block_number < 0; + const bool trailing_height_padding = + padding_bottom > 0 && + height_block_number == (function_params->height_macro_count - 1); + + const int32 input_offset = function_params->input_offset; + const int32 
input_offset_difference = input_offset + kSymmetricZeroPoint; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg_a; + int8x16_t work_reg_b; + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + + // Work through one slice, by row, at a time. + int8* scratch_data_0 = scratch_block_data; + + int copy_block_height = block_height; + if (leading_height_padding) { + copy_block_height -= 1; + memset(scratch_data_0, -input_offset_difference, workspace_height_stride); + scratch_data_0 += workspace_height_stride; + input_block_data += input_height_stride; + } + if (trailing_height_padding) { + copy_block_height -= 1; + } + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + const uint8* input_data_0 = input_block_data; + const uint8* input_data_1 = input_block_data + input_depth; + const uint8* input_data_2 = input_block_data + 2 * input_depth; + const uint8* input_data_3 = input_block_data + 3 * input_depth; + + // Traverse the width one point at a time, but the depth in (micro) blocks + // of size 8. + // + // The depth and width margins, which are filled with "zeros", may be + // larger than is strictly needed to calculate output. This is because the + // conv calculation is performed across complete micro blocks. + for (int j_width = 0; j_width < width_overall_micro_repeats; ++j_width) { + // Figure out division of work (available input vs zero-ed). + int adjusted_residual_width = + j_width == (input_width_micro_repeats) ? residual_width : 4; + + if (trailing_width_padding && + j_width == (width_overall_micro_repeats - 1)) { + adjusted_residual_width -= 1; + } + int start_width = 0; + if (leading_width_padding && j_width == 0) { + start_width = 1; + } + if (start_width == 0) { + if (adjusted_residual_width == 4) { + // Load, then zero. 
+ for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vld1q_lane_s8x8(input_data_2, work_reg_a, 1); + work_reg_b = vld1q_lane_s8x8(input_data_3, work_reg_b, 1); + work_reg_b = vld1q_lane_s8x8(input_data_1, work_reg_b, 0); + input_data_1 += 8; + work_reg_a = vld1q_lane_s8x8(input_data_0, work_reg_a, 0); + input_data_0 += 8; + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + scratch_data_0 += 16; + vst1q_s8(scratch_data_0, work_reg_b); + + scratch_data_0 += depth_advance - 16; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } else { + TFLITE_DCHECK_LT(adjusted_residual_width, 4); + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(-input_offset); + work_reg_b = vdupq_n_u8(-input_offset); + if (adjusted_residual_width > 0) { + work_reg_a = vld1q_lane_s8x8(input_data_0, work_reg_a, 0); + if (adjusted_residual_width > 1) { + work_reg_b = vld1q_lane_s8x8(input_data_0 + input_depth, + work_reg_b, 0); + if (adjusted_residual_width == 3) { + work_reg_a = vld1q_lane_s8x8(input_data_0 + 2 * input_depth, + work_reg_a, 1); + } + } + } + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_0 += 8; + input_data_1 += 8; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 
+= input_depth_skip; + input_data_3 += input_depth_skip; + } + } else { + if (adjusted_residual_width == 4) { + // Load, then zero. + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(-input_offset); + work_reg_a = vld1q_lane_s8x8(input_data_2, work_reg_a, 1); + work_reg_b = vld1q_lane_s8x8(input_data_3, work_reg_b, 1); + work_reg_b = vld1q_lane_s8x8(input_data_1, work_reg_b, 0); + input_data_1 += 8; + // Skip loading first column. + input_data_0 += 8; + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + scratch_data_0 += 16; + vst1q_s8(scratch_data_0, work_reg_b); + + scratch_data_0 += depth_advance - 16; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } else { + TFLITE_DCHECK_LT(adjusted_residual_width, 4); + for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { + work_reg_a = vdupq_n_u8(-input_offset); + // Skip loading first column. 
+ work_reg_b = vdupq_n_u8(-input_offset); + if (adjusted_residual_width > 1) { + work_reg_b = + vld1q_lane_s8x8(input_data_0 + input_depth, work_reg_b, 0); + if (adjusted_residual_width == 3) { + work_reg_a = vld1q_lane_s8x8(input_data_0 + 2 * input_depth, + work_reg_a, 1); + } + } + work_reg_a = veorq_s8(work_reg_a, sign_bit); + work_reg_b = veorq_s8(work_reg_b, sign_bit); + + vzipq_s8_in_place(&work_reg_a, &work_reg_b); + vzipq_s8x2_in_place(&work_reg_a, &work_reg_b); + + vst1q_s8(scratch_data_0, work_reg_a); + vst1q_s8(scratch_data_0 + 16, work_reg_b); + + scratch_data_0 += depth_advance; + input_data_0 += 8; + input_data_1 += 8; + input_data_2 += 8; + input_data_3 += 8; + } + scratch_data_0 += width_advance; + input_data_0 += input_depth_skip; + input_data_1 += input_depth_skip; + input_data_2 += input_depth_skip; + input_data_3 += input_depth_skip; + } + } + } + scratch_data_0 += height_advance; + input_block_data += input_height_stride; + } + + if (trailing_height_padding) { + memset(scratch_data_0, -input_offset_difference, workspace_height_stride); + scratch_data_0 += workspace_height_stride; + } + + TFLITE_DCHECK_EQ( + scratch_data_0, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { +#ifdef __aarch64__ + PreloadInputBlock(input_block_data, function_params); +#endif + + PackMacroBlockIntrinsics(height_block_number, width_block_number, + input_block_data, scratch_block_data, + function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockIntrinsics( + int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int 
width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + + const int padding_left = function_params->padding_left; + const int padding_right = function_params->padding_right; + const int padding_top = function_params->padding_top; + const int padding_bottom = function_params->padding_bottom; + + constexpr int kSymmetricZeroPoint = 128; + + TFLITE_DCHECK_GE(workspace_height_stride, 4 * width_overall_micro_repeats); + + const bool leading_width_padding = + padding_left > 0 && width_block_number == 0; + const bool trailing_width_padding = + padding_right > 0 && + width_block_number == (function_params->width_macro_count - 1); + const bool leading_height_padding = + padding_top > 0 && height_block_number < 0; + const bool trailing_height_padding = + padding_bottom > 0 && + height_block_number == (function_params->height_macro_count - 1); + + const int32 input_offset = function_params->input_offset; + const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + + // Work through one slice, by row, at a time. + int8* scratch_data_base = scratch_block_data; + + int copy_block_height = block_height; + if (leading_height_padding) { + copy_block_height -= 1; + memset(scratch_data_base, -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + scratch_data_base += workspace_height_stride; + input_block_data += input_height_stride; + } + if (trailing_height_padding) { + copy_block_height -= 1; + } + + int adjusted_residual_width = + input_width_micro_repeats < width_overall_micro_repeats ? 
residual_width + : 4; + + if (trailing_width_padding) { + adjusted_residual_width -= 1; + } + int start_width = 0; + if (leading_width_padding) { + start_width = 1; + input_block_data += 1; + } + + const int copy_size = (width_overall_micro_repeats - 1) * 4 + + adjusted_residual_width - start_width; + // Adjusted so that later conditionals are simplified. + const int copy_size_adjusted = + trailing_width_padding ? copy_size + 1 : copy_size; + + TFLITE_DCHECK_LE( + copy_size, + input_height_stride - width_block_number * input_width_micro_repeats); + // We may drop up to stride-1 of trailing input. + TFLITE_DCHECK_GE(copy_size, input_height_stride - 1); + + int scratch_data_offset = 0; + int input_block_offset = 0; + + constexpr uint8 kSignBit = 0x80; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg; + int8x8_t half_work_reg; + int8x8_t padding_mask; + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + const uint8x16_t padding_reg = vdupq_n_u8(-input_offset); + padding_mask = vdup_n_s8(-1); + half_work_reg = vdup_n_s8(0); + + if (copy_size >= 16) { + const int copy_remaining = (copy_size + start_width) & 0x7; + padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining)); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // The surrounding condition ensures that we always need at least one + // iteration of the main copy loop. In the case of leading width + // padding, we unroll this specially. 
+ if (leading_width_padding) { + work_reg = vld1q_u8(input_block_data + input_block_offset); + work_reg = vextq_s8(padding_reg, work_reg, 15); + work_reg = veorq_s8(work_reg, sign_bit); + vst1q_s8(scratch_data, work_reg); + copy_done += 15; + } + + // Main copy loop. + for (; (copy_done + 16) <= copy_size; copy_done += 16) { + work_reg = + vld1q_u8(input_block_data + input_block_offset + copy_done); + work_reg = veorq_s8(work_reg, sign_bit); + TFLITE_DCHECK_EQ((start_width + copy_done) % 16, 0); + vst1q_s8(scratch_data + start_width + copy_done, work_reg); + } + + if (copy_done + 8 <= copy_size) { + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_done); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0); + vst1_s8(scratch_data + start_width + copy_done, half_work_reg); + copy_done += 8; + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size_adjusted) { + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_size - 8); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining))); + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 8, 0); + vst1_s8(scratch_data + start_width + copy_done, half_work_reg); + } + + // Trailing guard. 
+ vst1_s8(scratch_data + start_width + copy_done, half_work_reg); + vst1_s8(scratch_data + start_width + copy_done + 8, half_work_reg); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (copy_size >= 4) { + const int copy_remaining = (copy_size + start_width) & 0x3; + padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining)); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // The surrounding condition ensures that we always need at least one + // iteration of the main copy loop. In the case of leading width + // padding, we unroll this specially. + if (leading_width_padding) { + half_work_reg = vld1_lane_8x4(input_block_data + input_block_offset, + half_work_reg, 0); + half_work_reg = vext_s8(vget_low_s8(padding_reg), half_work_reg, 7); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + vst1_lane_8x4(scratch_data, half_work_reg, 0); + copy_done += 3; + } + + // Main copy loop. + for (; (copy_done + 4) <= copy_size; copy_done += 4) { + // Important! Most compilation configurations will compile and run + // without the reinterpret_cast. Sanitizers may fail silently on + // lane-loading, with an obscure bug or mis-feature probably in + // unhygienic macro expansion. + half_work_reg = + vld1_lane_8x4(input_block_data + input_block_offset + copy_done, + half_work_reg, 0); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 4, 0); + vst1_lane_8x4(scratch_data + start_width + copy_done, half_work_reg, + 0); + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - start_width - copy_done.
+ // Undone micro blocks + // = width_overall_micro_repeats - (start_width + copy_done) / 4. + + // Conditional is (copy_remaining > 0 || trailing_width_padding). + if (copy_done < copy_size_adjusted) { + TFLITE_DCHECK_LT(copy_remaining, 4); + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = vld1_lane_8x4( + input_block_data + input_block_offset + copy_size - 4, + half_work_reg, 0); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (4 - copy_remaining))); + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ((start_width + copy_done) % 4, 0); + vst1_lane_8x4(scratch_data + start_width + copy_done, half_work_reg, + 0); + copy_done += 4; + } + // Trailing guard. + vst1_lane_8x4(scratch_data + start_width + copy_done, half_work_reg, 0); + vst1_lane_8x4(scratch_data + start_width + copy_done + 4, half_work_reg, + 0); + vst1_lane_8x4(scratch_data + start_width + copy_done + 8, half_work_reg, + 0); + vst1_lane_8x4(scratch_data + start_width + copy_done + 12, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (width_overall_micro_repeats == 2) { + // Special case of 1 + 3 + 1, padding + copy + padding. + // This is rarely executed in practice. + TFLITE_DCHECK_EQ(copy_size, 3); + TFLITE_DCHECK_EQ(start_width, 1); + TFLITE_DCHECK(leading_width_padding); + TFLITE_DCHECK(trailing_width_padding); + // ASM should use MOVI 64-bit set. 
+ padding_mask = vcreate_u64(~0xffffff00L); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + half_work_reg = vld1_lane_s8(reinterpret_cast<const int8*>( + input_block_data + input_block_offset), + half_work_reg, 1); + half_work_reg = + vld1_lane_s8(reinterpret_cast<const int8*>(input_block_data + + input_block_offset + 1), + half_work_reg, 2); + half_work_reg = + vld1_lane_s8(reinterpret_cast<const int8*>(input_block_data + + input_block_offset + 2), + half_work_reg, 3); + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(scratch_data_offset % 8, 0); + vst1_s8(scratch_data_base + scratch_data_offset, half_work_reg); + + // Trailing guard. + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 4, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 8, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 12, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 16, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, 1); + const int copy_remaining = (copy_size + start_width) & 0x3; + padding_mask = vshl_u64(padding_mask, vdup_n_s64(8 * copy_remaining)); + if (leading_width_padding) { + padding_mask = vset_lane_u8(255, padding_mask, 0); + } + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + for (int i = 0; i < copy_size; ++i) { + half_work_reg = vshl_n_u64(half_work_reg, 8); + half_work_reg = vld1_lane_s8( + reinterpret_cast<const int8*>( + input_block_data + input_block_offset + copy_size - 1 - i), + half_work_reg, 0); + } + if (leading_width_padding) { + half_work_reg = vshl_n_s64(half_work_reg, 8); + } + half_work_reg = + vbsl_s8(padding_mask, vget_low_s8(padding_reg), half_work_reg); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit));
+ TFLITE_DCHECK_EQ(scratch_data_offset % 4, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset, half_work_reg, + 0); + + // Trailing guard. + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 4, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 8, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 12, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 16, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } + + scratch_data_base += copy_block_height * workspace_height_stride; + + if (trailing_height_padding) { + memset(scratch_data_base, -input_offset_difference, + workspace_height_stride + kWorkspaceExtension); + scratch_data_base += workspace_height_stride; + } + + TFLITE_DCHECK_EQ( + scratch_data_base, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { +#ifdef __aarch64__ + PreloadInputBlock(input_block_data, function_params); +#endif + + PackMacroBlockIntrinsics(height_block_number, width_block_number, + input_block_data, scratch_block_data, + function_params); + } +}; + +template <> +struct PackMacroBlock { + static inline void PackMacroBlockIntrinsics( + int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int input_width_micro_repeats = + function_params->input_width_micro_repeats; + const int block_height = function_params->inbound_block_height; + const int residual_width = 
function_params->residual_width; + const int input_height_stride = function_params->input_height_stride; + + TFLITE_DCHECK_EQ(function_params->padding_left, 0); + TFLITE_DCHECK_EQ(function_params->padding_right, 0); + TFLITE_DCHECK_EQ(function_params->padding_top, 0); + TFLITE_DCHECK_EQ(function_params->padding_bottom, 0); + + TFLITE_DCHECK_GE(workspace_height_stride, 4 * width_overall_micro_repeats); + + // Work through one slice, by row, at a time. + int8* scratch_data_base = scratch_block_data; + + const int copy_block_height = block_height; + + int adjusted_residual_width = + input_width_micro_repeats < width_overall_micro_repeats ? residual_width + : 4; + + const int copy_size = + (width_overall_micro_repeats - 1) * 4 + adjusted_residual_width; + + TFLITE_DCHECK_LE( + copy_size, + input_height_stride - width_block_number * input_width_micro_repeats); + // We may drop up to stride-1 of trailing input. + TFLITE_DCHECK_GE(copy_size, input_height_stride - 1); + + int scratch_data_offset = 0; + int input_block_offset = 0; + + constexpr uint8 kSignBit = 0x80; + + // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON + // code. Note the blocks of 4x4 are still interleaved down the depth. + int8x16_t work_reg; + int8x8_t half_work_reg; + + // Effect subtraction of zero-point = 128 by XOR of sign bit. + const uint8x16_t sign_bit = vdupq_n_u8(kSignBit); + half_work_reg = vdup_n_s8(0); + + if (copy_size >= 16) { + const int copy_remaining = copy_size & 0x7; + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time. + int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // Main copy loop. 
+ for (; (copy_done + 16) <= copy_size; copy_done += 16) { + work_reg = + vld1q_u8(input_block_data + input_block_offset + copy_done); + work_reg = veorq_s8(work_reg, sign_bit); + TFLITE_DCHECK_EQ(copy_done % 16, 0); + vst1q_s8(scratch_data + copy_done, work_reg); + } + + if (copy_done + 8 <= copy_size) { + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_done); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 8, 0); + vst1_s8(scratch_data + copy_done, half_work_reg); + copy_done += 8; + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - copy_done / 4. + + // Conditional is (copy_remaining > 0); this variant has no width padding. + if (copy_done < copy_size) { + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = + vld1_u8(input_block_data + input_block_offset + copy_size - 8); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (8 - copy_remaining))); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 8, 0); + vst1_s8(scratch_data + copy_done, half_work_reg); + copy_done += 8; + } + + // Trailing guard. + vst1_s8(scratch_data + copy_done, half_work_reg); + vst1_s8(scratch_data + copy_done + 8, half_work_reg); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else if (copy_size >= 4) { + const int copy_remaining = copy_size & 0x3; + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + // Work through one slice, by row, at a time.
+ int8* scratch_data = scratch_data_base + scratch_data_offset; + + int copy_done = 0; + + // Main copy loop. + for (; (copy_done + 4) <= copy_size; copy_done += 4) { + // Important! Most compilation configurations will compile and run + // without the reinterpret_cast. Sanitizers may fail silently on + // lane-loading, with an obscure bug or mis-feature probably in + // unhygienic macro expansion. + half_work_reg = + vld1_lane_8x4(input_block_data + input_block_offset + copy_done, + half_work_reg, 0); + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 4, 0); + vst1_lane_8x4(scratch_data + copy_done, half_work_reg, 0); + } + + TFLITE_DCHECK_EQ(copy_remaining, copy_size - copy_done); + // Total amount + // = copy_size - copy_done + 4 - adjusted_residual_width + // = width_overall_micro_repeats * 4 - copy_done. + // Undone micro blocks + // = width_overall_micro_repeats - copy_done / 4. + + // Conditional is (copy_remaining > 0); this variant has no width padding. + if (copy_done < copy_size) { + TFLITE_DCHECK_LT(copy_remaining, 4); + // Employ overlapping-load strategy in order to load full register, + // but use only part. + // This has the advantage of resulting in zeros after shifting. + half_work_reg = vld1_lane_8x4( + input_block_data + input_block_offset + copy_size - 4, + half_work_reg, 0); + + half_work_reg = + vshl_u64(half_work_reg, vdup_n_s64(-8 * (4 - copy_remaining))); + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(copy_done % 4, 0); + vst1_lane_8x4(scratch_data + copy_done, half_work_reg, 0); + copy_done += 4; + } + // Trailing guard.
+ vst1_lane_8x4(scratch_data + copy_done, half_work_reg, 0); + vst1_lane_8x4(scratch_data + copy_done + 4, half_work_reg, 0); + vst1_lane_8x4(scratch_data + copy_done + 8, half_work_reg, 0); + vst1_lane_8x4(scratch_data + copy_done + 12, half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } else { + TFLITE_DCHECK_EQ(width_overall_micro_repeats, 1); + + for (int k_height = 0; k_height < copy_block_height; ++k_height) { + for (int i = 0; i < copy_size; ++i) { + half_work_reg = vshl_n_u64(half_work_reg, 8); + half_work_reg = vld1_lane_s8( + reinterpret_cast<const int8*>( + input_block_data + input_block_offset + copy_size - 1 - i), + half_work_reg, 0); + } + + half_work_reg = veor_s8(half_work_reg, vget_low_s8(sign_bit)); + TFLITE_DCHECK_EQ(scratch_data_offset % 4, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset, half_work_reg, + 0); + + // Trailing guard. + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 4, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 8, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 12, + half_work_reg, 0); + vst1_lane_8x4(scratch_data_base + scratch_data_offset + 16, + half_work_reg, 0); + + scratch_data_offset += workspace_height_stride; + input_block_offset += input_height_stride; + } + } + + scratch_data_base += copy_block_height * workspace_height_stride; + + TFLITE_DCHECK_EQ( + scratch_data_base, + scratch_block_data + block_height * workspace_height_stride); + } + + static inline void Run(int32 height_block_number, int32 width_block_number, + const uint8* input_block_data, + int8* scratch_block_data, + const DepthwiseConvDotProdParams* function_params) { +#ifdef __aarch64__ + PreloadInputBlock(input_block_data, function_params); +#endif + + PackMacroBlockIntrinsics(height_block_number, width_block_number, + input_block_data, scratch_block_data, + function_params); + } +}; + +#endif // ARM NEON + +//
Apply filter to macro block of input data and store results. +// +// Requirement: depth_micro_repeats > 0 || residual_depth > 0. +template +struct KernelMacroBlock { + // Construct a width-shifted combination of two input sub-blocks, effectively + // concatenating them. + // + // The filter is applied using sub-blocks. These are in the needed form for + // the first (width) offset. For subsequent offsets, the filter is applied to + // shifted and combined data. The concatenation and shifting herein are fairly + // straightforward, but in the optimized code this is an area of creativity in + // design because NEON instructions do not directly support the required + // between-register permutation. + // + // In NEON optimized code, input data is grouped in 4-byte blocks. In order to + // move along the width for each output point calculation, data is shifted, in + // essence between two such blocks. + // + // selected_data has format height 3, depth 4, width 4. + // + // When the micro block is trailing (the last across the macro-block width), + // it would be illegal to load the right (next) block, and the no_right_block + // flag indicates this scenario. + static inline void ConcatenateInputSubBlocks(int offset, int sub_block, + int workspace_height_stride, + int width_micro_stride, + bool no_right_block, + const int8* input_block, + int8 selected_data[3][4][4]) { + TFLITE_DCHECK_GE(offset, 0); + TFLITE_DCHECK_LT(offset, 4); + + // The input banks have the same format as selected_data. + int8 left_bank[3][4][4]; + int8 right_bank[3][4][4]; + + // Work through one slice, by row, at a time. + for (int k_height = 0; k_height < 3; ++k_height) { + // Simulate demangling of mangled storage arrangement.
+ const int8* left_input_block = + &input_block[k_height * workspace_height_stride + sub_block * 2 * 8]; + memcpy(left_bank[k_height][0], left_input_block, 16); + if (no_right_block) { + memset(right_bank[k_height][0], 0, 16); + } else { + const int8* right_input_block = + &input_block[k_height * workspace_height_stride + + sub_block * 2 * 8 + width_micro_stride]; + memcpy(right_bank[k_height][0], right_input_block, 16); + } + for (int depth_index = 0; depth_index < 4; ++depth_index) { + memcpy(selected_data[k_height][depth_index], + &left_bank[k_height][depth_index][offset], 4 - offset); + memcpy(&selected_data[k_height][depth_index][4 - offset], + right_bank[k_height][depth_index], offset); + } + } + } + + // Straight implementation of 3x3 filter within sub-micro block. + static inline void Calculate3x3FilterOutput( + const DepthwiseConvDotProdParams& params, int sub_block, + const int8 selected_data[3][4][4], const int8 filter_bank[3][2][4][4], + const int32* bias_data, uint8 output_values[4]) { + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + const int32 output_multiplier = params.output_multiplier; + const int32 output_shift = params.output_shift; + const int32 output_offset = params.output_offset; + for (int d = 0; d < 4; ++d) { + int32 acc = 0; + for (int y = 0; y < 3; ++y) { + for (int x = 0; x < 4; ++x) { + int32 input_val = selected_data[y][d][x]; + int32 filter_val = filter_bank[y][sub_block][d][x]; + acc += filter_val * input_val; + } + } + acc += bias_data[d]; + acc = reference_ops::depthwise_conv::DepthwiseConvRound< + DepthwiseConvOutputRounding::kUpward>(acc, output_multiplier, + output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_values[d] = static_cast<uint8>(acc); + } + } + + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32*
bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int input_width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int depth = function_params->input_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + constexpr int bias_increment = 4; + TFLITE_DCHECK_EQ(function_params->bias_increment, bias_increment); + + TFLITE_DCHECK(depth_micro_repeats > 0); + const int width_micro_stride = 4 * 8; + const int depth_micro_stride = + width_micro_stride * input_width_overall_micro_repeats; + + constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; + + // Simulate NEON-register transposition of subset of filter. + int8 filter_bank[3][2][4][4]; // Height 3, sub-block, depth 4, width 4. + // Simulate NEON-register input data concatenation + sub-selection. + int8 sub_selected_input_data[3][4][4]; // Height 3, depth 4, width 4. + uint8 output_values[4]; // Depth 4. + + // The outer 3 loops go through all the micro blocks in a macro block, and + // separately treat the two sub-blocks within each micro block. 
+ for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + memcpy(filter_bank[0][0][0], + filter_workspace + j_depth * shuffled_filter_increment, + shuffled_filter_increment); + + for (int s = 0; s < 2; ++s) { + for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* scratch_data = + scratch_block_data + + workspace_height_stride * k_height * stride_val + + depth_micro_stride * j_depth; + uint8* output_data = + output_block_data + output_height_stride * k_height + 8 * j_depth; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + const bool no_right_block = (output_width - 1) * stride_val < 2; + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = + scratch_data + width_micro_stride * i_width; + // Iterate over input width shifts within sub-micro blocks. + for (int x = 0; x < output_width; ++x) { + ConcatenateInputSubBlocks(x * stride_val, s, + workspace_height_stride, + width_micro_stride, no_right_block, + input_data, sub_selected_input_data); + Calculate3x3FilterOutput( + *function_params, s, sub_selected_input_data, filter_bank, + bias_data + (2 * j_depth + s) * bias_increment, + output_values); + for (int d = 0; d < 4; ++d) { + output_data[depth * (four_over_stride * i_width + x) + 4 * s + + d] = output_values[d]; + } + } + } + } + } + } + } +}; + +// Apply filter to macro block of input data and store results. +// +// Parameters for repeats and residual sizes are in terms of outputs. +// +// Requirement: depth_micro_repeats > 0 || residual_depth > 0. +template +struct KernelMacroBlock { + // Construct a width-shifted combination of two input sub-blocks, effectively + // concatenating them. + // + // The filter is applied using sub-blocks. These are in the needed form for + // the first (width) offset. 
For subsequent offsets, the filter is applied to + // shifted and combined data. The concatentation and shifting herein is fairly + // straightforward, but in the optimized code is an area of creativity in + // design because NEON instructions do not directly support the required + // between-register permutation. + // + // In NEON optimized code, input data is grouped in 4-byte blocks. In order to + // move along the width for each output point calculation, data is shifted, in + // essence between two such blocks. + // + // selected_data has format height 3, width 4. + // + // When the micro block is trailing (the last across the macro-block width), + // it would be illegal to load the right (next) block, and the no_right_block + // indicates this scenario. + static inline void ConcatenateInputSubBlocks(int offset, + int workspace_height_stride, + bool no_right_block, + const int8* input_block, + int8 selected_data[3][4]) { + TFLITE_DCHECK_GE(offset, 0); + TFLITE_DCHECK_LT(offset, 4); + if (no_right_block) { + for (int k_height = 0; k_height < 3; ++k_height) { + memcpy(selected_data[k_height], + &input_block[k_height * workspace_height_stride + offset], + 4 - offset); + } + } else { + for (int k_height = 0; k_height < 3; ++k_height) { + memcpy(selected_data[k_height], + &input_block[k_height * workspace_height_stride + offset], 4); + } + } + } + + // Straight implementation of 3x3 filter within sub-micro block. 
+ static inline void Calculate3x3FilterOutput( + const DepthwiseConvDotProdParams& function_params, int sub_block, + const int8 selected_data[3][4], const int8 filter_bank[3][2][4][4], + const int32* bias_data, uint8 output_values[4]) { + const int32 output_activation_min = + function_params.quantized_activation_min; + const int32 output_activation_max = + function_params.quantized_activation_max; + const int32 output_multiplier = function_params.output_multiplier; + const int32 output_shift = function_params.output_shift; + const int32 output_offset = function_params.output_offset; + for (int d = 0; d < 4; ++d) { + int32 acc = 0; + for (int y = 0; y < 3; ++y) { + for (int x = 0; x < 4; ++x) { + int32 input_val = selected_data[y][x]; + int32 filter_val = filter_bank[y][sub_block][d][x]; + acc += filter_val * input_val; + } + } + acc += bias_data[d]; + acc = reference_ops::depthwise_conv::DepthwiseConvRound< + DepthwiseConvOutputRounding::kUpward>(acc, output_multiplier, + output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_values[d] = static_cast(acc); + } + } + + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int depth = function_params->output_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int workspace_width_micro_repeats = + function_params->workspace_width_micro_repeats; + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = 
function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + constexpr int bias_increment = 4; + TFLITE_DCHECK_EQ(function_params->bias_increment, bias_increment); + + TFLITE_DCHECK(depth_micro_repeats > 0); + + constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; + + // Simulate NEON-register transposition of subset of filter. + int8 filter_bank[3][2][4][4]; // Height 3, sub-block, depth 4, width 4. + // Simulate NEON-register input data concatenation + sub-selection. + int8 sub_selected_input_data[3][4]; // Height 3, depth 4, width 4. + uint8 output_values[4]; // Depth 4. + + // The outer 3 loops go through all the micro blocks in a macro block, and + // separately treat the two sub-blocks within each micro block. + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + memcpy(filter_bank[0][0][0], + filter_workspace + j_depth * shuffled_filter_increment, + shuffled_filter_increment); + + for (int s = 0; s < 2; ++s) { + for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* scratch_data = + scratch_block_data + + workspace_height_stride * k_height * stride_val; + uint8* output_data = + output_block_data + output_height_stride * k_height + 8 * j_depth; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + const bool no_right_block = i_width == output_width_micro_repeats && + output_width_overall_micro_repeats == + workspace_width_micro_repeats; + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = scratch_data + 4 * i_width; + // Iterate over input width shifts within 4x4 blocks. 
+ for (int x = 0; x < output_width; ++x) { + ConcatenateInputSubBlocks(x * stride_val, workspace_height_stride, + no_right_block, input_data, + sub_selected_input_data); + Calculate3x3FilterOutput( + *function_params, s, sub_selected_input_data, filter_bank, + bias_data + (2 * j_depth + s) * bias_increment, + output_values); + for (int d = 0; d < 4; ++d) { + output_data[depth * (four_over_stride * i_width + x) + 4 * s + + d] = output_values[d]; + } + } + } + } + } + } + } +}; + +// Beginning of code section containing intermediate code transformation. +// +// This section is only compiled when kUseUnwound3x3DotProduct versions of +// templated functions are selected. +template +struct KernelMacroBlock { + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int input_width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int depth = function_params->input_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + const int width_micro_stride = 4 * 8; + const int depth_micro_stride = + width_micro_stride * input_width_overall_micro_repeats; + + const int32 
output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + + // Simulate NEON-register transposition of subset of filter. + int8 filter_bank_a_0[4][4]; // Depth 4, width 4. + int8 filter_bank_a_1[4][4]; + int8 filter_bank_a_2[4][4]; + int8 filter_bank_b_0[4][4]; + int8 filter_bank_b_1[4][4]; + int8 filter_bank_b_2[4][4]; + // Simulate NEON-register input data concatenation + sub-selection. + // Also sub-block, height 3, depth 4, width 4. + uint8 output_values[4]; // Sub-block, depth 4. + // selected_data has format Depth 4, width 4. + int8 left_bank_0[4][4]; + int8 left_bank_1[4][4]; + int8 left_bank_2[4][4]; + int8 right_bank_0[4][4]; + int8 right_bank_1[4][4]; + int8 right_bank_2[4][4]; + memset(right_bank_0[0], 0, 16); + memset(right_bank_1[0], 0, 16); + memset(right_bank_2[0], 0, 16); + + constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; + + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + const int8* filter_block = + filter_workspace + shuffled_filter_increment * j_depth; + + memcpy(filter_bank_a_0, filter_block, 16); + memcpy(filter_bank_b_0, filter_block + 16, 16); + memcpy(filter_bank_a_1, filter_block + 32, 16); + memcpy(filter_bank_b_1, filter_block + 48, 16); + memcpy(filter_bank_a_2, filter_block + 64, 16); + memcpy(filter_bank_b_2, filter_block + 80, 16); + + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. 
+ for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* scratch_data = + scratch_block_data + + workspace_height_stride * k_height * stride_val + + depth_micro_stride * j_depth; + uint8* output_data = + output_block_data + output_height_stride * k_height + 8 * j_depth; + const int8* input_data_0 = scratch_data + s * 2 * 8; + + // Load first sub-micro block of data into operational banks. + memcpy(left_bank_0[0], input_data_0, 16); + memcpy(left_bank_1[0], input_data_0 + workspace_height_stride, 16); + memcpy(left_bank_2[0], input_data_0 + 2 * workspace_height_stride, + 16); + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = + input_data_0 + width_micro_stride * i_width; + const bool no_right_block = (output_width - 1) * stride_val < 2; + + // Load next sub-micro block of data. + if (!no_right_block) { + memcpy(right_bank_0[0], input_data + width_micro_stride, 16); + memcpy(right_bank_1[0], + input_data + workspace_height_stride + width_micro_stride, + 16); + memcpy( + right_bank_2[0], + input_data + 2 * workspace_height_stride + width_micro_stride, + 16); + } + + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + // Operate on depth of 4 in batches. 
+ for (int d = 0; d < 4; ++d) { + int32 acc = 0; + for (int x = 0; x < 4; ++x) { + int32 input_val = left_bank_0[d][x]; + int32 filter_val = filter_bank_a_0[d][x]; + acc += filter_val * input_val; + } + for (int x = 0; x < 4; ++x) { + int32 input_val = left_bank_1[d][x]; + int32 filter_val = filter_bank_a_1[d][x]; + acc += filter_val * input_val; + } + for (int x = 0; x < 4; ++x) { + int32 input_val = left_bank_2[d][x]; + int32 filter_val = filter_bank_a_2[d][x]; + acc += filter_val * input_val; + } + acc += bias_data[d]; + acc = reference_ops::depthwise_conv::DepthwiseConvRound< + DepthwiseConvOutputRounding::kUpward>( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_values[d] = static_cast(acc); + } + + for (int d = 0; d < 4; ++d) { + output_data[depth * (four_over_stride * i_width + x) + 4 * s + + d] = output_values[d]; + } + + // Simulate shifting instructions. + if (stride_val == 1) { + for (int depth_index = 0; depth_index < 4; ++depth_index) { + for (int z = 0; z < 3; ++z) { + left_bank_0[depth_index][z] = + left_bank_0[depth_index][z + 1]; + left_bank_1[depth_index][z] = + left_bank_1[depth_index][z + 1]; + left_bank_2[depth_index][z] = + left_bank_2[depth_index][z + 1]; + } + left_bank_0[depth_index][3] = right_bank_0[depth_index][0]; + left_bank_1[depth_index][3] = right_bank_1[depth_index][0]; + left_bank_2[depth_index][3] = right_bank_2[depth_index][0]; + for (int z = 0; z < 3; ++z) { + right_bank_0[depth_index][z] = + right_bank_0[depth_index][z + 1]; + right_bank_1[depth_index][z] = + right_bank_1[depth_index][z + 1]; + right_bank_2[depth_index][z] = + right_bank_2[depth_index][z + 1]; + } + } + } else { + for (int depth_index = 0; depth_index < 4; ++depth_index) { + for (int z = 0; z < 2; ++z) { + left_bank_0[depth_index][z] = + left_bank_0[depth_index][z + 2]; + left_bank_1[depth_index][z] = + left_bank_1[depth_index][z + 2]; + 
left_bank_2[depth_index][z] = + left_bank_2[depth_index][z + 2]; + } + left_bank_0[depth_index][2] = right_bank_0[depth_index][0]; + left_bank_1[depth_index][2] = right_bank_1[depth_index][0]; + left_bank_2[depth_index][2] = right_bank_2[depth_index][0]; + left_bank_0[depth_index][3] = right_bank_0[depth_index][1]; + left_bank_1[depth_index][3] = right_bank_1[depth_index][1]; + left_bank_2[depth_index][3] = right_bank_2[depth_index][1]; + for (int z = 0; z < 2; ++z) { + right_bank_0[depth_index][z] = + right_bank_0[depth_index][z + 2]; + right_bank_1[depth_index][z] = + right_bank_1[depth_index][z + 2]; + right_bank_2[depth_index][z] = + right_bank_2[depth_index][z + 2]; + } + } + } + } + } + } + bias_data += bias_increment; + + // Move filter for second sub-block into operational filter. + for (int z = 0; z < 4; ++z) { + for (int x = 0; x < 4; ++x) { + filter_bank_a_0[z][x] = filter_bank_b_0[z][x]; + filter_bank_a_1[z][x] = filter_bank_b_1[z][x]; + filter_bank_a_2[z][x] = filter_bank_b_2[z][x]; + } + } + } + } + } +}; + +template +struct KernelMacroBlock { + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int output_depth = function_params->output_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = 
function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + + TFLITE_DCHECK(depth_micro_repeats > 0); + + TFLITE_DCHECK_EQ(bias_increment, 4); + + constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; + + // Simulate NEON-register transposition of subset of filter. + int8 filter_bank_a_0[4][4]; // Depth 4, width 4. + int8 filter_bank_a_1[4][4]; + int8 filter_bank_a_2[4][4]; + int8 filter_bank_b_0[4][4]; + int8 filter_bank_b_1[4][4]; + int8 filter_bank_b_2[4][4]; + // Simulate NEON-register input data concatenation + sub-selection. + // Also sub-block, height 3, depth 4, width 4. + + int8 input_bank_0[8]; + int8 input_bank_1[8]; + int8 input_bank_2[8]; + + TFLITE_DCHECK_GE(depth_micro_repeats, 1); + + uint8 output_values[2][4]; // Sub-block, depth 4. + + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + memcpy(filter_bank_a_0, filter_workspace, 16); + memcpy(filter_bank_b_0, filter_workspace + 16, 16); + memcpy(filter_bank_a_1, filter_workspace + 32, 16); + memcpy(filter_bank_b_1, filter_workspace + 48, 16); + memcpy(filter_bank_a_2, filter_workspace + 64, 16); + memcpy(filter_bank_b_2, filter_workspace + 80, 16); + + // Work through one slice, by row, at a time. 
+ for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* scratch_data = + scratch_block_data + + workspace_height_stride * k_height * stride_val; + uint8* output_data = + output_block_data + output_height_stride * k_height + 8 * j_depth; + + memcpy(input_bank_0, scratch_data, 4); + memcpy(input_bank_1, scratch_data + workspace_height_stride, 4); + memcpy(input_bank_2, scratch_data + 2 * workspace_height_stride, 4); + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = scratch_data + 4 * i_width; + + memcpy(input_bank_0 + 4, input_data + 4, 4); + memcpy(input_bank_1 + 4, input_data + workspace_height_stride + 4, 4); + memcpy(input_bank_2 + 4, input_data + 2 * workspace_height_stride + 4, + 4); + + // Iterate over input width shifts within 4x4 blocks. + for (int w = 0; w < output_width; ++w) { + constexpr int offset = + 0; // Shift input instead of offset in multiply-accumulate. 
+ + { + const int s = 0; + for (int d = 0; d < 4; ++d) { + int32 acc = bias_data[s * 4 + d]; + for (int x = 0; x < 4; ++x) { + int32 input_val_0 = input_bank_0[offset + x]; + int32 filter_val_0 = filter_bank_a_0[d][x]; + acc += filter_val_0 * input_val_0; + int32 input_val_1 = input_bank_1[offset + x]; + int32 filter_val_1 = filter_bank_a_1[d][x]; + acc += filter_val_1 * input_val_1; + int32 input_val_2 = input_bank_2[offset + x]; + int32 filter_val_2 = filter_bank_a_2[d][x]; + acc += filter_val_2 * input_val_2; + } + acc = reference_ops::depthwise_conv::DepthwiseConvRound< + DepthwiseConvOutputRounding::kUpward>( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_values[s][d] = static_cast(acc); + + output_data[s * 4 + d] = output_values[s][d]; + } + } + { + const int s = 1; + for (int d = 0; d < 4; ++d) { + int32 acc = bias_data[s * 4 + d]; + for (int x = 0; x < 4; ++x) { + int32 input_val_0 = input_bank_0[offset + x]; + int32 filter_val_0 = filter_bank_b_0[d][x]; + acc += filter_val_0 * input_val_0; + int32 input_val_1 = input_bank_1[offset + x]; + int32 filter_val_1 = filter_bank_b_1[d][x]; + acc += filter_val_1 * input_val_1; + int32 input_val_2 = input_bank_2[offset + x]; + int32 filter_val_2 = filter_bank_b_2[d][x]; + acc += filter_val_2 * input_val_2; + } + acc = reference_ops::depthwise_conv::DepthwiseConvRound< + DepthwiseConvOutputRounding::kUpward>( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_values[s][d] = static_cast(acc); + + output_data[s * 4 + d] = output_values[s][d]; + } + } + + // Simulate register shifts. 
+ for (int i = 0; i < (8 - stride_val); ++i) { + input_bank_0[i] = input_bank_0[i + stride_val]; + input_bank_1[i] = input_bank_1[i + stride_val]; + input_bank_2[i] = input_bank_2[i + stride_val]; + } + + output_data += output_depth; + } + } + } + bias_data += 2 * bias_increment; + filter_workspace += shuffled_filter_increment; + } + } +}; +// The preceding section is only compiled when kUseUnwound3x3DotProduct versions +// of templated functions are selected. +// +// End of code section containing intermediate code transformation. + +#ifdef USE_NEON +template <> +struct KernelMacroBlock< + DepthwiseConvImplementation::kUseIntrinsics3x3DotProduct, + DepthwiseConvDepthMultiplication::kNoMultiplication, + /*stride=*/1> { + static inline void KernelMacroBlockIntrinsics( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int input_width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int depth = function_params->input_depth; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + const int width_micro_stride = 4 * 8; + const int depth_micro_stride = + width_micro_stride * input_width_overall_micro_repeats; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 
output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32878); + TFLITE_DCHECK_LT(output_offset, 32768); + + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x16_t output_activation_min_vec = + vdupq_n_u8(static_cast(output_activation_min)); + const uint8x16_t output_activation_max_vec = + vdupq_n_u8(static_cast(output_activation_max)); + + const int8* input_data_depthwise = scratch_block_data; + uint8* output_data_depthwise = output_block_data; + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + // Simulate NEON-register transposition of subset of filter. 
+ int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + int8x16_t filter_reg_0_a_shifted; + int8x16_t filter_reg_1_a_shifted; + int8x16_t filter_reg_2_a_shifted; + + filter_reg_0_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_0_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + + if (block_height == 4) { + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + const int8* input_data_base = input_data_depthwise + 2 * 8 * s; + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int8* next_input_data = input_data_base; + uint8* output_data = output_data_base; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + // Load first sub-micro block of data into operational banks. 
+ int8x16_t left_bank_0_reg = vld1q_s8(next_input_data); + int8x16_t left_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + int8x16_t left_bank_3_reg = + vld1q_s8(next_input_data + 3 * workspace_height_stride); + int8x16_t left_bank_4_reg = + vld1q_s8(next_input_data + 4 * workspace_height_stride); + int8x16_t left_bank_5_reg = + vld1q_s8(next_input_data + 5 * workspace_height_stride); + + int32x4_t acc0; + int32x4_t acc1; + int32x4_t acc2; + int32x4_t acc3; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + + for (int i_width = 0; i_width < output_width_micro_repeats; + ++i_width) { + next_input_data += width_micro_stride; + + // Iterate over input width shifts within 4x4 blocks. + { + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a, left_bank_5_reg); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += depth; + } + + // Load next sub-micro block of data. + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + int8x16_t right_bank_3_reg; + int8x16_t right_bank_4_reg; + int8x16_t right_bank_5_reg; + // Logic: (i_width == output_width_micro_repeats) && + // ((residual_width - 1) * stride_val < 2) + const bool no_right_block = + i_width == output_width_micro_repeats && residual_width < 3; + + if (no_right_block) { + // Only needed for santizer checks. 
+ right_bank_0_reg = vdupq_n_s8(0); + right_bank_1_reg = vdupq_n_s8(0); + right_bank_2_reg = vdupq_n_s8(0); + right_bank_3_reg = vdupq_n_s8(0); + right_bank_4_reg = vdupq_n_s8(0); + right_bank_5_reg = vdupq_n_s8(0); + } else { + right_bank_0_reg = vld1q_s8(next_input_data); + right_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + right_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + right_bank_3_reg = + vld1q_s8(next_input_data + 3 * workspace_height_stride); + right_bank_4_reg = + vld1q_s8(next_input_data + 4 * workspace_height_stride); + right_bank_5_reg = + vld1q_s8(next_input_data + 5 * workspace_height_stride); + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a_shifted, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a_shifted, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a_shifted, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a_shifted, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a_shifted, left_bank_3_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a_shifted, left_bank_5_reg); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + left_bank_0_reg = vrev32q_u16(left_bank_0_reg); + left_bank_1_reg = vrev32q_u16(left_bank_1_reg); + left_bank_2_reg = vrev32q_u16(left_bank_2_reg); + left_bank_3_reg = vrev32q_u16(left_bank_3_reg); + left_bank_4_reg = vrev32q_u16(left_bank_4_reg); + left_bank_5_reg = vrev32q_u16(left_bank_5_reg); + vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg); + vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg); + vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg); + vtrn1_s8x2_in_place(&left_bank_3_reg, &right_bank_3_reg); + vtrn1_s8x2_in_place(&left_bank_4_reg, &right_bank_4_reg); + vtrn1_s8x2_in_place(&left_bank_5_reg, &right_bank_5_reg); + + output_data += depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 
= adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a_shifted, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a_shifted, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a_shifted, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a_shifted, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a_shifted, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a_shifted, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a_shifted, left_bank_3_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a_shifted, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a_shifted, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. 
+ int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + left_bank_0_reg = right_bank_0_reg; + left_bank_1_reg = right_bank_1_reg; + left_bank_2_reg = right_bank_2_reg; + left_bank_3_reg = right_bank_3_reg; + left_bank_4_reg = right_bank_4_reg; + left_bank_5_reg = right_bank_5_reg; + + output_data += depth; + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + } + } + + if (residual_width > 0) { + next_input_data += width_micro_stride; + const int output_width = residual_width; + + // Load next sub-micro block of data. + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + int8x16_t right_bank_3_reg; + int8x16_t right_bank_4_reg; + int8x16_t right_bank_5_reg; + // Logic: (output_width - 1) * stride_val < 2. + const bool no_right_block = output_width < 3; + + if (no_right_block) { + // Only needed for santizer checks. 
+ right_bank_0_reg = vdupq_n_s8(0); + right_bank_1_reg = vdupq_n_s8(0); + right_bank_2_reg = vdupq_n_s8(0); + right_bank_3_reg = vdupq_n_s8(0); + right_bank_4_reg = vdupq_n_s8(0); + right_bank_5_reg = vdupq_n_s8(0); + } else { + right_bank_0_reg = vld1q_s8(next_input_data); + right_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + right_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + right_bank_3_reg = + vld1q_s8(next_input_data + 3 * workspace_height_stride); + right_bank_4_reg = + vld1q_s8(next_input_data + 4 * workspace_height_stride); + right_bank_5_reg = + vld1q_s8(next_input_data + 5 * workspace_height_stride); + } + + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_1_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_1_a, left_bank_3_reg); + acc2 = vdotq_s32(acc2, filter_reg_2_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_1_a, left_bank_4_reg); + acc3 = vdotq_s32(acc3, filter_reg_2_a, left_bank_5_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. 
+ int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + biregister_rotate_8(&left_bank_0_reg, &right_bank_0_reg); + biregister_rotate_8(&left_bank_1_reg, &right_bank_1_reg); + biregister_rotate_8(&left_bank_2_reg, &right_bank_2_reg); + biregister_rotate_8(&left_bank_3_reg, &right_bank_3_reg); + biregister_rotate_8(&left_bank_4_reg, &right_bank_4_reg); + biregister_rotate_8(&left_bank_5_reg, &right_bank_5_reg); + + output_data += depth; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_2_reg); + acc2 = vdotq_s32(acc2, filter_reg_0_a, left_bank_2_reg); + acc3 = vdotq_s32(acc3, filter_reg_0_a, left_bank_3_reg); + } + } + input_data_base += 4 * workspace_height_stride; + output_data_base += 4 * output_height_stride; + + // Move to next sub-block: advance to second set of filters, to new + // bias. 
+ filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + } + } else { + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + const int8* input_data_base = input_data_depthwise + 2 * 8 * s; + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* next_input_data = input_data_base; + uint8* output_data = output_data_base; + + // Load first sub-micro block of data into operational banks. + int8x16_t left_bank_0_reg = vld1q_s8(next_input_data); + int8x16_t left_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + next_input_data += width_micro_stride; + const int output_width = + i_width == output_width_micro_repeats ? residual_width : 4; + + // Load next sub-micro block of data. + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + // Logic: (output_width - 1) * stride_val < 2. + const bool no_right_block = output_width < 3; + + if (no_right_block) { + // Only needed for santizer checks. + right_bank_0_reg = vdupq_n_s8(0); + right_bank_1_reg = vdupq_n_s8(0); + right_bank_2_reg = vdupq_n_s8(0); + } else { + right_bank_0_reg = vld1q_s8(next_input_data); + right_bank_1_reg = + vld1q_s8(next_input_data + workspace_height_stride); + right_bank_2_reg = + vld1q_s8(next_input_data + 2 * workspace_height_stride); + } + // Load next sub-micro block of data. 
+ + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + int32x4_t acc = adjusted_bias_data; + acc = vdotq_s32(acc, filter_reg_0_a, left_bank_0_reg); + acc = vdotq_s32(acc, filter_reg_1_a, left_bank_1_reg); + acc = vdotq_s32(acc, filter_reg_2_a, left_bank_2_reg); + + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + acc = DivideByPOT::Run( + acc, -output_shift); + // Add the output offset. + // Note that we need to fill the top half with vcombine, but can + // drop the instruction in ASM code. + int16x8_t acc_s16_0_0 = + vcombine_s16(vqmovn_s32(acc), vqmovn_s32(acc)); + acc_s16_0_0 = vqaddq_s16(acc_s16_0_0, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8_0_0 = vqmovun_s16(acc_s16_0_0); + acc_u8_0_0 = + vmax_u8(acc_u8_0_0, vget_low_u8(output_activation_min_vec)); + acc_u8_0_0 = + vmin_u8(acc_u8_0_0, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_0, 0); + + biregister_rotate_8(&left_bank_0_reg, &right_bank_0_reg); + biregister_rotate_8(&left_bank_1_reg, &right_bank_1_reg); + biregister_rotate_8(&left_bank_2_reg, &right_bank_2_reg); + + output_data += depth; + } + } + input_data_base += workspace_height_stride; + output_data_base += output_height_stride; + } + + // Move to next sub-block: advance to second set of filters. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + } + } + input_data_depthwise += depth_micro_stride; + output_data_depthwise += 8; + } + } // NOLINT(readability/fn_size) Manually unrolled. 
+ + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock< + DepthwiseConvImplementation::kUseIntrinsics3x3DotProduct, + DepthwiseConvDepthMultiplication::kNoMultiplication, + /*stride=*/2> { + static inline void KernelMacroBlockIntrinsics( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int input_width_overall_micro_repeats = + function_params->input_width_overall_micro_repeats; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int depth = function_params->input_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int workspace_width_micro_repeats = + function_params->workspace_width_micro_repeats; + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + const int width_micro_stride = 4 * 8; + const int depth_micro_stride = + width_micro_stride * input_width_overall_micro_repeats; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + 
function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32878); + TFLITE_DCHECK_LT(output_offset, 32768); + + // This version only does min/max on 64 bits. + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x8_t output_activation_min_vec = + vdup_n_u8(static_cast(output_activation_min)); + const uint8x8_t output_activation_max_vec = + vdup_n_u8(static_cast(output_activation_max)); + + constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; + + TFLITE_DCHECK_EQ(stride_val, 2); + TFLITE_DCHECK_LE(block_height, 2); + + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + const int8* filter_block = + filter_workspace + shuffled_filter_increment * j_depth; + + if (block_height == 2) { + for (int s = 0; s < 2; ++s) { + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_2_a; + + filter_reg_0_a = vld1q_s8(filter_block + s * 16); + filter_reg_1_a = vld1q_s8(filter_block + s * 16 + 32); + filter_reg_2_a = vld1q_s8(filter_block + s * 16 + 64); + + const int8* scratch_data = + scratch_block_data + depth_micro_stride * j_depth; + uint8* output_data = output_block_data + 8 * j_depth; + const int8* input_data_0 = scratch_data + s * 2 * 8; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + + // Load first sub-micro block of data into operational banks. 
+ int8x16_t left_bank_0_reg = vld1q_s8(input_data_0); + int8x16_t left_bank_1_reg = + vld1q_s8(input_data_0 + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(input_data_0 + 2 * workspace_height_stride); + int8x16_t left_bank_3_reg = + vld1q_s8(input_data_0 + 3 * workspace_height_stride); + int8x16_t left_bank_4_reg = + vld1q_s8(input_data_0 + 4 * workspace_height_stride); + + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + int8x16_t right_bank_3_reg; + int8x16_t right_bank_4_reg; + + int32x4_t acc0; + int32x4_t acc1; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = + input_data_0 + width_micro_stride * i_width; + const bool no_right_block = i_width == output_width_micro_repeats && + output_width_overall_micro_repeats == + workspace_width_micro_repeats; + + if (!no_right_block) { + // Load next sub-micro block of data. + right_bank_0_reg = vld1q_s8(input_data + width_micro_stride); + right_bank_1_reg = vld1q_s8(input_data + width_micro_stride + + workspace_height_stride); + right_bank_2_reg = vld1q_s8(input_data + width_micro_stride + + 2 * workspace_height_stride); + right_bank_3_reg = vld1q_s8(input_data + width_micro_stride + + 3 * workspace_height_stride); + right_bank_4_reg = vld1q_s8(input_data + width_micro_stride + + 4 * workspace_height_stride); + } + + uint8* output_data_base = output_data + depth * 2 * i_width + 4 * s; + + // Iterate over input width shifts within 4x4 blocks. 
+ { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_3_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_4_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base, acc_u8, 0); + vst1_lane_8x4(output_data_base + output_height_stride, acc_u8, 1); + + left_bank_0_reg = vrev32q_u16(left_bank_0_reg); + left_bank_1_reg = vrev32q_u16(left_bank_1_reg); + left_bank_2_reg = vrev32q_u16(left_bank_2_reg); + left_bank_3_reg = vrev32q_u16(left_bank_3_reg); + left_bank_4_reg = vrev32q_u16(left_bank_4_reg); + vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg); + vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg); + vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg); + vtrn1_s8x2_in_place(&left_bank_3_reg, &right_bank_3_reg); + vtrn1_s8x2_in_place(&left_bank_4_reg, &right_bank_4_reg); + } + + if (output_width > 1) { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_0_a, 
left_bank_2_reg); + acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_3_reg); + acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_4_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base + depth, acc_u8, 0); + vst1_lane_8x4(output_data_base + depth + output_height_stride, + acc_u8, 1); + + left_bank_0_reg = right_bank_0_reg; + left_bank_1_reg = right_bank_1_reg; + left_bank_2_reg = right_bank_2_reg; + left_bank_3_reg = right_bank_3_reg; + left_bank_4_reg = right_bank_4_reg; + } + } + bias_data += bias_increment; + } + } else { + for (int s = 0; s < 2; ++s) { + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_2_a; + + filter_reg_0_a = vld1q_s8(filter_block + s * 16); + filter_reg_1_a = vld1q_s8(filter_block + s * 16 + 32); + filter_reg_2_a = vld1q_s8(filter_block + s * 16 + 64); + + const int8* scratch_data = + scratch_block_data + depth_micro_stride * j_depth; + uint8* output_data = output_block_data + 8 * j_depth; + const int8* input_data_0 = scratch_data + s * 2 * 8; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + + // Load first sub-micro block of data into operational banks. 
+ int8x16_t left_bank_0_reg = vld1q_s8(input_data_0); + int8x16_t left_bank_1_reg = + vld1q_s8(input_data_0 + workspace_height_stride); + int8x16_t left_bank_2_reg = + vld1q_s8(input_data_0 + 2 * workspace_height_stride); + + int8x16_t right_bank_0_reg; + int8x16_t right_bank_1_reg; + int8x16_t right_bank_2_reg; + + int32x4_t acc0; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = + input_data_0 + width_micro_stride * i_width; + const bool no_right_block = i_width == output_width_micro_repeats && + output_width_overall_micro_repeats == + workspace_width_micro_repeats; + + if (!no_right_block) { + // Load next sub-micro block of data. + right_bank_0_reg = vld1q_s8(input_data + width_micro_stride); + right_bank_1_reg = vld1q_s8(input_data + width_micro_stride + + workspace_height_stride); + right_bank_2_reg = vld1q_s8(input_data + width_micro_stride + + 2 * workspace_height_stride); + } + + uint8* output_data_base = output_data + depth * 2 * i_width + 4 * s; + + // Iterate over input width shifts within 4x4 blocks. + { + acc0 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc0)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. 
+ uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base, acc_u8, 0); + + left_bank_0_reg = vrev32q_u16(left_bank_0_reg); + left_bank_1_reg = vrev32q_u16(left_bank_1_reg); + left_bank_2_reg = vrev32q_u16(left_bank_2_reg); + vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg); + vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg); + vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg); + } + + if (output_width > 1) { + acc0 = adjusted_bias_data; + + acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg); + acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg); + acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc0)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1); + acc_u8 = vmax_u8(acc_u8, output_activation_min_vec); + acc_u8 = vmin_u8(acc_u8, output_activation_max_vec); + + vst1_lane_8x4(output_data_base + depth, acc_u8, 0); + + left_bank_0_reg = right_bank_0_reg; + left_bank_1_reg = right_bank_1_reg; + left_bank_2_reg = right_bank_2_reg; + } + } + bias_data += bias_increment; + } + } + } + } // NOLINT(readability/fn_size) Manually unrolled. 
+ + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock< + DepthwiseConvImplementation::kUseIntrinsics3x3DotProduct, + DepthwiseConvDepthMultiplication::kUnitInputDepth, + /*stride=*/1> { + static inline void KernelMacroBlockIntrinsics( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + TFLITE_DCHECK_EQ(function_params->stride, 1); + const int workspace_height_stride = + function_params->workspace_height_stride; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int output_depth = function_params->output_depth; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + TFLITE_DCHECK(depth_micro_repeats > 0); + + TFLITE_DCHECK_EQ(bias_increment, 4); + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + 
TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32878); + TFLITE_DCHECK_LT(output_offset, 32768); + + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast(output_offset)); + const uint8x16_t output_activation_min_vec = + vdupq_n_u8(static_cast(output_activation_min)); + const uint8x16_t output_activation_max_vec = + vdupq_n_u8(static_cast(output_activation_max)); + + uint8* output_data_depthwise = output_block_data; + for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { + // Simulate NEON-register transposition of subset of filter. + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + int8x16_t filter_reg_0_a_shifted; + int8x16_t filter_reg_1_a_shifted; + int8x16_t filter_reg_2_a_shifted; + + filter_reg_0_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_0_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + + if (block_height == 4) { + for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int8* next_input_data = scratch_block_data; + uint8* output_data = output_data_base; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. 
+ int8x16_t input_bank_b_reg; // left 2, right 2, left 3, right 3. + int8x16_t input_bank_c_reg; // left 4, right 4, left 5, right 5. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(next_input_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + next_input_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 3 * workspace_height_stride, + input_bank_b_reg, 2); + input_bank_c_reg = vld1q_dup_s8x4( + next_input_data + + 4 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 5 * workspace_height_stride, + input_bank_c_reg, 2); + + int32x4_t acc0; + int32x4_t acc1; + int32x4_t acc2; + int32x4_t acc3; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, 2); + + for (int i_width = 0; i_width < output_width_micro_repeats; + ++i_width) { + next_input_data += 4; + + // Iterate over input width shifts within 4x4 blocks. 
+ { + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, + 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a, input_bank_c_reg, + 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += output_depth; + } + // Load next sub-micro block of data. 
+ input_bank_a_reg = + vld1q_lane_8x4(next_input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 2 * workspace_height_stride, + input_bank_b_reg, 1); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 3 * workspace_height_stride, + input_bank_b_reg, 3); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 4 * workspace_height_stride, + input_bank_c_reg, 1); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 5 * workspace_height_stride, + input_bank_c_reg, 3); + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a_shifted, + input_bank_a_reg, 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a_shifted, + input_bank_a_reg, 2); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a_shifted, + input_bank_a_reg, 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a_shifted, + input_bank_b_reg, 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a_shifted, + input_bank_b_reg, 2); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a_shifted, + input_bank_c_reg, 2); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, + 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, + 2); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a, 
input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, + 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, + 2); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a, input_bank_c_reg, + 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. 
+ uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + output_data += output_depth; + } + + { + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a_shifted, + input_bank_a_reg, 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a_shifted, + input_bank_a_reg, 2); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a_shifted, + input_bank_a_reg, 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a_shifted, + input_bank_b_reg, 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a_shifted, + input_bank_b_reg, 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a_shifted, + input_bank_b_reg, 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a_shifted, + input_bank_b_reg, 2); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a_shifted, + input_bank_c_reg, 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a_shifted, + input_bank_c_reg, 2); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. + int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, + 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, + 2); + } + } + + if (residual_width > 0) { + next_input_data += 4; + const int output_width = residual_width; + + // Load next sub-micro block of 
data. + input_bank_a_reg = + vld1q_lane_8x4(next_input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 2 * workspace_height_stride, + input_bank_b_reg, 1); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 3 * workspace_height_stride, + input_bank_b_reg, 3); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 4 * workspace_height_stride, + input_bank_c_reg, 1); + input_bank_c_reg = + vld1q_lane_8x4(next_input_data + 5 * workspace_height_stride, + input_bank_c_reg, 3); + + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + acc0 = vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, + 0); + acc0 = vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_a_reg, + 2); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_1_a, input_bank_b_reg, + 2); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_2_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_1_a, input_bank_c_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_2_a, input_bank_c_reg, + 2); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + acc2 = vqrdmulhq_n_s32(acc2, output_multiplier); + acc2 = DivideByPOT::Run( + acc2, -output_shift); + acc3 = vqrdmulhq_n_s32(acc3, output_multiplier); + acc3 = DivideByPOT::Run( + acc3, -output_shift); + // Add the output offset. 
+ int16x8_t acc_s16_0_1 = + vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + int16x8_t acc_s16_2_3 = + vcombine_s16(vqmovn_s32(acc2), vqmovn_s32(acc3)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + acc_s16_2_3 = vqaddq_s16(acc_s16_2_3, output_offset_vec); + // Apply the activation function. + uint8x16_t acc_u8_all = vcombine_u8(vqmovun_s16(acc_s16_0_1), + vqmovun_s16(acc_s16_2_3)); + acc_u8_all = vmaxq_u8(acc_u8_all, output_activation_min_vec); + acc_u8_all = vminq_u8(acc_u8_all, output_activation_max_vec); + + vst1q_lane_8x4(output_data, acc_u8_all, 0); + vst1q_lane_8x4(output_data + output_height_stride, acc_u8_all, 1); + vst1q_lane_8x4(output_data + 2 * output_height_stride, acc_u8_all, + 2); + vst1q_lane_8x4(output_data + 3 * output_height_stride, acc_u8_all, + 3); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 8); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 8); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 8); + + output_data += output_depth; + + acc0 = adjusted_bias_data; + acc1 = adjusted_bias_data; + acc2 = adjusted_bias_data; + acc3 = adjusted_bias_data; + + acc0 = vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, + 0); + acc1 = vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, + 0); + acc2 = vdotq_four_lane_s32(acc2, filter_reg_0_a, input_bank_b_reg, + 0); + acc3 = vdotq_four_lane_s32(acc3, filter_reg_0_a, input_bank_b_reg, + 2); + } + } + // scratch_block_data += 4 * workspace_height_stride; + output_data_base += 4 * output_height_stride; + + // Move to next sub-block: advance to second set of filters, to new + // bias. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + filter_reg_0_a_shifted = vshlq_n_u32(filter_reg_0_a, 8); + filter_reg_1_a_shifted = vshlq_n_u32(filter_reg_1_a, 8); + filter_reg_2_a_shifted = vshlq_n_u32(filter_reg_2_a, 8); + } + } else { + // Block height < 4. 
+ for (int s = 0; s < 2; ++s) { + // Work through one slice, by row, at a time. + uint8* output_data_base = output_data_depthwise + 4 * s; + + const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); + TFLITE_DCHECK_EQ(bias_increment, 4); + bias_data += bias_increment; + + for (int k_height = 0; k_height < block_height; ++k_height) { + const int8* next_input_data = + scratch_block_data + k_height * workspace_height_stride; + uint8* output_data = output_data_base; + + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, left 3, right 3. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(next_input_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + next_input_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + next_input_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + next_input_data += 4; + const int output_width = + i_width == output_width_micro_repeats ? residual_width : 4; + + // Load next sub-micro block of data. + input_bank_a_reg = + vld1q_lane_8x4(next_input_data, input_bank_a_reg, 1); + input_bank_a_reg = + vld1q_lane_8x4(next_input_data + workspace_height_stride, + input_bank_a_reg, 3); + input_bank_b_reg = + vld1q_lane_8x4(next_input_data + 2 * workspace_height_stride, + input_bank_b_reg, 1); + // Iterate over input width shifts within 4x4 blocks. + for (int x = 0; x < output_width; ++x) { + int32x4_t acc = adjusted_bias_data; + acc = vdotq_four_lane_s32(acc, filter_reg_0_a, input_bank_a_reg, + 0); + acc = vdotq_four_lane_s32(acc, filter_reg_1_a, input_bank_a_reg, + 2); + acc = vdotq_four_lane_s32(acc, filter_reg_2_a, input_bank_b_reg, + 0); + + // Fixed-point multiplication. 
+ acc = vqrdmulhq_n_s32(acc, output_multiplier); + acc = DivideByPOT::Run( + acc, -output_shift); + // Add the output offset. + // Note that we need to fill the top half with vcombine, but can + // drop the instruction in ASM code. + int16x8_t acc_s16_0_0 = + vcombine_s16(vqmovn_s32(acc), vqmovn_s32(acc)); + acc_s16_0_0 = vqaddq_s16(acc_s16_0_0, output_offset_vec); + // Apply the activation function. + uint8x8_t acc_u8_0_0 = vqmovun_s16(acc_s16_0_0); + acc_u8_0_0 = + vmax_u8(acc_u8_0_0, vget_low_u8(output_activation_min_vec)); + acc_u8_0_0 = + vmin_u8(acc_u8_0_0, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_0, 0); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 8); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 8); + + output_data += output_depth; + } + } + output_data_base += output_height_stride; + } + + // Move to next sub-block: advance to second set of filters. + filter_reg_0_a = filter_reg_0_b; + filter_reg_1_a = filter_reg_1_b; + filter_reg_2_a = filter_reg_2_b; + } + } + output_data_depthwise += 8; + } + } // NOLINT(readability/fn_size) Manually unrolled. 
+ + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +template <> +struct KernelMacroBlock< + DepthwiseConvImplementation::kUseIntrinsics3x3DotProduct, + DepthwiseConvDepthMultiplication::kUnitInputDepth, + /*stride=*/2> { + static inline void KernelMacroBlockIntrinsics( + const int8* scratch_block_data, const int8* filter_workspace, + const int32* bias_data, uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + const int workspace_height_stride = + function_params->workspace_height_stride; + const int output_width_micro_repeats = + function_params->output_width_micro_repeats; + const int depth_micro_repeats = function_params->depth_micro_repeats; + const int output_depth = function_params->output_depth; + const int stride_val = function_params->stride; + const int four_over_stride = function_params->four_over_stride; + + const int output_width_overall_micro_repeats = + function_params->output_width_overall_micro_repeats; + const int block_height = function_params->outbound_block_height; + const int residual_width = function_params->output_residual_width; + const int output_height_stride = function_params->output_height_stride; + const int bias_increment = function_params->bias_increment; + + const int32 output_activation_min = + function_params->quantized_activation_min; + const int32 output_activation_max = + function_params->quantized_activation_max; + const int32 output_multiplier = function_params->output_multiplier; + const int32 output_shift = function_params->output_shift; + const int32 output_offset = function_params->output_offset; + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LT(output_activation_min, 256); + 
TFLITE_DCHECK_GE(output_activation_max, 0); + TFLITE_DCHECK_LT(output_activation_max, 256); + TFLITE_DCHECK_GE(output_offset, -32768); // Must fit in int16. + TFLITE_DCHECK_LT(output_offset, 32768); + + TFLITE_DCHECK_GE(depth_micro_repeats, 1); + TFLITE_DCHECK_EQ(bias_increment, 4); + + const int16x8_t output_offset_vec = + vdupq_n_s16(static_cast<int16>(output_offset)); + const uint8x16_t output_activation_min_vec = + vdupq_n_u8(static_cast<uint8>(output_activation_min)); + const uint8x16_t output_activation_max_vec = + vdupq_n_u8(static_cast<uint8>(output_activation_max)); + + for (int j_depth = 0; j_depth < (depth_micro_repeats * 1 + 0); ++j_depth) { + int8x16_t filter_reg_0_a; + int8x16_t filter_reg_0_b; + int8x16_t filter_reg_1_a; + int8x16_t filter_reg_1_b; + int8x16_t filter_reg_2_a; + int8x16_t filter_reg_2_b; + + filter_reg_0_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_0_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_1_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_a = vld1q_s8(filter_workspace); + filter_workspace += 16; + filter_reg_2_b = vld1q_s8(filter_workspace); + filter_workspace += 16; + + TFLITE_DCHECK_EQ(bias_increment, 4); + const int32x4_t adjusted_bias_data_s_0 = vld1q_s32(bias_data); + bias_data += bias_increment; + const int32x4_t adjusted_bias_data_s_1 = vld1q_s32(bias_data); + bias_data += bias_increment; + + if (block_height == 2) { + const int8* scratch_data = scratch_block_data; + uint8* output_data = output_block_data + 8 * j_depth; + + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, left 3, right 3. + int8x16_t input_bank_c_reg; // left 4, right 4, xxx, xxx. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(scratch_data); // Load lane 0, avoiding + // uninitialized variable. 
+ input_bank_a_reg = vld1q_lane_8x4( + scratch_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + scratch_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + input_bank_b_reg = vld1q_lane_8x4( + scratch_data + 3 * workspace_height_stride, input_bank_b_reg, 2); + input_bank_c_reg = vld1q_dup_s8x4( + scratch_data + + 4 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + + int32x4_t acc0; + int32x4_t acc1; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = i_width == output_width_micro_repeats + ? residual_width + : four_over_stride; + + TFLITE_DCHECK_LE(output_width, 2); + TFLITE_DCHECK_GE(output_width, 1); + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = scratch_data + 4 + 4 * i_width; + + // Load next sub-micro block of data. + input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = vld1q_lane_8x4( + input_data + 2 * workspace_height_stride, input_bank_b_reg, 1); + input_bank_b_reg = vld1q_lane_8x4( + input_data + 3 * workspace_height_stride, input_bank_b_reg, 3); + input_bank_c_reg = vld1q_lane_8x4( + input_data + 4 * workspace_height_stride, input_bank_c_reg, 1); + + int16x8_t acc_s16_0_1; + uint8x8_t acc_u8_0_1; + // Iterate over input width shifts within 4x4 blocks. 
+ { + acc0 = adjusted_bias_data_s_0; + acc1 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + output_height_stride, acc_u8_0_1, 1); + + acc0 = adjusted_bias_data_s_1; + acc1 = adjusted_bias_data_s_1; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_b, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_b, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_c_reg, 0); + + // Fixed-point multiplication. 
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data + 4, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + 4 + output_height_stride, acc_u8_0_1, + 1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + } + if (output_width == 2) { + acc0 = adjusted_bias_data_s_0; + acc1 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_a, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_a, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_a, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. 
+ acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + output_height_stride, acc_u8_0_1, 1); + + acc0 = adjusted_bias_data_s_1; + acc1 = adjusted_bias_data_s_1; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_b, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_b, input_bank_a_reg, 2); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_b_reg, 2); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_c_reg, 0); + + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + vst1_lane_8x4(output_data + 4, acc_u8_0_1, 0); + vst1_lane_8x4(output_data + 4 + output_height_stride, acc_u8_0_1, + 1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + input_bank_c_reg = vshrq_n_u64(input_bank_c_reg, 16); + + output_data += output_depth; + } + } + } else { + TFLITE_DCHECK_EQ(block_height, 1); + // Work through one slice, by row, at a time. 
+ const int8* scratch_data = scratch_block_data; + uint8* output_data = output_block_data + 8 * j_depth; + + // + int8x16_t input_bank_a_reg; // left 0, right 0, left 1, right 1. + int8x16_t input_bank_b_reg; // left 2, right 2, xxx, xxx. + + // Load first sub-micro block of data into operational banks. + input_bank_a_reg = + vld1q_dup_s8x4(scratch_data); // Load lane 0, avoiding + // uninitialized variable. + input_bank_a_reg = vld1q_lane_8x4( + scratch_data + workspace_height_stride, input_bank_a_reg, 2); + input_bank_b_reg = vld1q_dup_s8x4( + scratch_data + + 2 * workspace_height_stride); // Load lane 0, avoiding + // uninitialized variable. + + int32x4_t acc0; + int32x4_t acc1; + + for (int i_width = 0; i_width < output_width_overall_micro_repeats; + ++i_width) { + const int output_width = + i_width == output_width_micro_repeats ? residual_width : 2; + + TFLITE_DCHECK_LE(output_width, 2); + TFLITE_DCHECK_GE(output_width, 1); + TFLITE_DCHECK_LE(output_width * stride_val, 4); + const int8* input_data = scratch_data + 4 + 4 * i_width; + + // Load next sub-micro block of data. + input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); + input_bank_a_reg = vld1q_lane_8x4( + input_data + workspace_height_stride, input_bank_a_reg, 3); + input_bank_b_reg = vld1q_lane_8x4( + input_data + 2 * workspace_height_stride, input_bank_b_reg, 1); + + int16x8_t acc_s16_0_1; + uint8x8_t acc_u8_0_1; + + // Iterate over input width shifts within 4x4 blocks. + { + acc0 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + + // Second sub-block accumulation. 
+ acc1 = adjusted_bias_data_s_1; + + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_a_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_a_reg, 2); + + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. + acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + // This stores the results for both sub-blocks together. + vst1_u8(output_data, acc_u8_0_1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + + output_data += output_depth; + } + if (output_width == 2) { + acc0 = adjusted_bias_data_s_0; + + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_2_a, input_bank_b_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_0_a, input_bank_a_reg, 0); + acc0 = + vdotq_four_lane_s32(acc0, filter_reg_1_a, input_bank_a_reg, 2); + + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc0 = DivideByPOT::Run( + acc0, -output_shift); + + // Second sub-block accumulation. + acc1 = adjusted_bias_data_s_1; + + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_2_b, input_bank_b_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_0_b, input_bank_a_reg, 0); + acc1 = + vdotq_four_lane_s32(acc1, filter_reg_1_b, input_bank_a_reg, 2); + + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + acc1 = DivideByPOT::Run( + acc1, -output_shift); + + // Add the output offset. + acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1)); + acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec); + // Apply the activation function. 
+ acc_u8_0_1 = vqmovun_s16(acc_s16_0_1); + acc_u8_0_1 = + vmax_u8(acc_u8_0_1, vget_low_u8(output_activation_min_vec)); + acc_u8_0_1 = + vmin_u8(acc_u8_0_1, vget_low_u8(output_activation_max_vec)); + + // This stores the results for both sub-blocks together. + vst1_u8(output_data, acc_u8_0_1); + + input_bank_a_reg = vshrq_n_u64(input_bank_a_reg, 16); + input_bank_b_reg = vshrq_n_u64(input_bank_b_reg, 16); + + output_data += output_depth; + } + } + } + } + } + + static inline void Run(const int8* scratch_block_data, + const int8* filter_workspace, const int32* bias_data, + uint8* output_block_data, + const DepthwiseConvDotProdParams* function_params) { + KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, + output_block_data, function_params); + } +}; + +#undef vst1_lane_8x4 +#undef vst1q_lane_8x4 +#undef vld1q_lane_s8x8 +#undef vld1_lane_8x4 +#undef vld1q_lane_8x4 +#undef vld1q_dup_s8x4 + +#endif // USE_NEON + +} // namespace depthwise_conv +} // namespace optimized_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_TRANSITIONAL_H_ diff --git a/tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h index a7abf3e370c466d12be99c9d3dc5d35eca2caf77..01103d727fc0a390e82d56310c9d7614fd864b35 100644 --- a/tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h +++ b/tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h @@ -13,9 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h. -// TODO(petewarden) - move this to a common location in Eigen itself. 
- #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 730d9b662a33eb3a2fe08fec887c9fb35671a116..ce5eb308e8744097c49987ce8a1e31f54dfff222 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -12,8 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include +#include +#include +#include + +#include #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/kernels/activation_functor.h" @@ -43,6 +49,49 @@ void* aligned_alloc(size_t alignment, size_t size, void** freeing_buffer) { : ((char*)*freeing_buffer + (alignment - offset)); // NOLINT } +// Use /proc/cpuinfo to test whether we have the right processor. +bool HasSdotInstruction() { + // TODO(strohman): Replace this with a proper API call once we are running + // on kernels that can tell us about this instruction: (b/119112014) + // Note that the C++ spec ensures that this variable will be initialized + // exactly once. 
+ static bool has_sdot = []() -> bool { + char text[1024]; + int fd = open("/proc/cpuinfo", O_RDONLY); + if (fd < 0) { + return false; + } + + bool found = false; + int buffer = 0; + const char kSM8150[] = "Qualcomm Technologies, Inc SM8150"; + while (true) { + int count = read(fd, text + buffer, sizeof(text) - buffer); + if (count <= 0) { + break; + } + int text_end = buffer + count; + + if (memmem(text, text_end, kSM8150, sizeof(kSM8150) - 1) != nullptr) { + found = true; + break; + } + + // Keep up to some bytes of the previous buffer state so that we + // can find a string match even if it occurs on a buffer boundary. + buffer = text_end; + if (text_end > sizeof(kSM8150)) { + buffer = sizeof(kSM8150); + } + + memmove(text, text + text_end - buffer, buffer); + } + close(fd); + return found; + }(); + return has_sdot; +} + } // namespace void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, @@ -84,10 +133,266 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, } } +#ifdef __aarch64__ + +// We interleave vector data to make the dot product logic more efficient. +// Suppose that vectors is: +// a0 a1 a2 a3 a4 a5 ... +// b0 b1 b2 b3 b4 b5 ... +// c0 c1 c2 c3 c4 c5 ... +// d0 d1 d2 d3 d4 d5 ... +// e0 e1 e2 e3 e4 e5 ... +// This code interleaves them like this: +// a0 a1 a2 a3 b0 b1 b2 b3 c0 c1 c2 c3 d0 d1 d2 d3 a4 a5 a6 a7 b4 ... +// e0 e1 e2 e3 f0 f1 f2 f3 ... +// Once the data is interleaved, each 16-byte read from the vectors pointer +// contains 4 bytes from each of 4 vectors. 
+const int8_t* ShuffleVectors(const int8_t* vectors, const int n_batch, + const int m_cols, void** shuffled_vectors_free) { + const int kWeightsPerUint32 = 4; + + int8* shuffled_vectors = reinterpret_cast(aligned_alloc( + kWeightsPerUint32, n_batch * m_cols, shuffled_vectors_free)); + + for (int i = 0; i < n_batch; i += 4) { + int8* shuffled_vectors_ptr = shuffled_vectors + (i * m_cols); + const int8* unshuffled_vec0_ptr = + reinterpret_cast(vectors) + (i * m_cols); + const int8* unshuffled_vec1_ptr = + reinterpret_cast(vectors) + ((i + 1) * m_cols); + const int8* unshuffled_vec2_ptr = + reinterpret_cast(vectors) + ((i + 2) * m_cols); + const int8* unshuffled_vec3_ptr = + reinterpret_cast(vectors) + ((i + 3) * m_cols); + const int8* const end_vec0_ptr = unshuffled_vec1_ptr; + + while (unshuffled_vec0_ptr != end_vec0_ptr) { + asm volatile( + // This code path requires that (m_cols % 16) == 0 so we can safely + // read in 16-byte chunks from each row. + "ld1 {v0.16b}, [%[unshuffled_vec0_ptr]], #16\n" + "ld1 {v1.16b}, [%[unshuffled_vec1_ptr]], #16\n" + "ld1 {v2.16b}, [%[unshuffled_vec2_ptr]], #16\n" + "ld1 {v3.16b}, [%[unshuffled_vec3_ptr]], #16\n" + + "st4 {v0.s, v1.s, v2.s, v3.s}[0], [%[shuffled_vectors_ptr]], #16\n" + "st4 {v0.s, v1.s, v2.s, v3.s}[1], [%[shuffled_vectors_ptr]], #16\n" + "st4 {v0.s, v1.s, v2.s, v3.s}[2], [%[shuffled_vectors_ptr]], #16\n" + "st4 {v0.s, v1.s, v2.s, v3.s}[3], [%[shuffled_vectors_ptr]], #16\n" + + : [ unshuffled_vec0_ptr ] "+r"(unshuffled_vec0_ptr), + [ unshuffled_vec1_ptr ] "+r"(unshuffled_vec1_ptr), + [ unshuffled_vec2_ptr ] "+r"(unshuffled_vec2_ptr), + [ unshuffled_vec3_ptr ] "+r"(unshuffled_vec3_ptr), + [ shuffled_vectors_ptr ] "+r"(shuffled_vectors_ptr) + : + : "v0", "v1", "v2", "v3", "cc", "memory"); + } + } + + return reinterpret_cast(shuffled_vectors); +} + +static void DotprodMatrixBatchFourVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* vectors, const float*
scaling_factors, int n_batch, + float* __restrict__ result) { + void* shuffled_vectors_free; + + const int8_t* shuffled_vectors = + ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free); + + for (int row = 0; row < m_rows; row += 2) { + for (int batch = 0; batch < n_batch; batch += 4) { + float* result_ptr = result + (batch * m_rows) + row; + const int8* mat_ptr0 = matrix + (row * m_cols); + const int8* mat_ptr1 = matrix + ((row + 1) * m_cols); + const int8* mat_ptr0_end = mat_ptr1; + const int8* vec_ptr = shuffled_vectors + (batch * m_cols); + const float* scaling_factors_ptr = scaling_factors + batch; + const uint64_t wide_rows = m_rows * sizeof(float); + + asm volatile( + // Zero out the accumulator registers. + "dup v0.4s, wzr\n" + "dup v1.4s, wzr\n" + "dup v2.4s, wzr\n" + "dup v3.4s, wzr\n" + + "1:\n" // batch_cols_loop + + // Read 16 more bytes from a pair of matrix rows. + "ld1 {v12.16b}, [%[mat_ptr0]], #16\n" + + // Read from input vectors 4 times; 64 bytes total. + // Each 16-byte register contains parts of 4 vectors; see the + // shuffle logic above. + + // From Benoit, places to look in the future: + // - Move load instructions further from sdot + // - Switch loop use-then-reload + // - Do partial unrolling to use register space better + "ld1 {v8.16b}, [%[vec_ptr]], #16\n" + ".word 0x4f8ce100 // sdot v0.4s, v8.16b, v12.4b[0]\n" + "ld1 {v9.16b}, [%[vec_ptr]], #16\n" + ".word 0x4face121 // sdot v1.4s, v9.16b, v12.4b[1]\n" + "ld1 {v10.16b}, [%[vec_ptr]], #16\n" + ".word 0x4f8ce940 // sdot v0.4s, v10.16b, v12.4b[2]\n" + "ld1 {v11.16b}, [%[vec_ptr]], #16\n" + ".word 0x4face961 // sdot v1.4s, v11.16b, v12.4b[3]\n" + + // Re-use those vectors for the next row as well. 
+ "ld1 {v13.16b}, [%[mat_ptr1]], #16\n" + ".word 0x4f8de102 // sdot v2.4s, v8.16b, v13.4b[0]\n" + ".word 0x4fade123 // sdot v3.4s, v9.16b, v13.4b[1]\n" + ".word 0x4f8de942 // sdot v2.4s, v10.16b, v13.4b[2]\n" + ".word 0x4fade963 // sdot v3.4s, v11.16b, v13.4b[3]\n" + + // If we're not done with these rows, continue. + "cmp %[mat_ptr0], %[mat_ptr0_end]\n" + "bne 1b\n" // batch_cols_loop + + // Done with the rows, sum the results. + "add v0.4s, v0.4s, v1.4s\n" + "add v2.4s, v2.4s, v3.4s\n" + + // Convert the per-vector sums to floating point. + "scvtf v0.4s, v0.4s\n" + "scvtf v1.4s, v2.4s\n" + + // Fetch scale factors. + "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n" + + // Multiply scale factors times sums. + "fmul v0.4s, v4.4s, v0.4s\n" + "fmul v1.4s, v4.4s, v1.4s\n" + + // Load previous result values. + // The result position is: + // result[batch * m_rows + row] + // Here that is factored into: + // result_ptr = result + row + // *result_ptr = res[0] + // (uint8*)result_ptr += (m_rows * sizeof(float)) + // *result_ptr = res[1] + // ... + // Since we're reading two rows at a time, though, we read both + // result[batch * m_rows + row] + // and + // result[batch * m_rows + row + 1] + "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n" + "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n" + "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n" + "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n" + + // Go back to the starting position (subtract wide_rows * 4). + "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n" + + // Add previous result values. + "fadd v9.4s, v9.4s, v0.4s\n" + "fadd v10.4s, v10.4s, v1.4s\n" + + // Store results. 
+ "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n" + "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n" + "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n" + "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n" + : [ mat_ptr0 ] "+r"(mat_ptr0), [ mat_ptr1 ] "+r"(mat_ptr1), + [ vec_ptr ] "+r"(vec_ptr), [ result_ptr ] "+r"(result_ptr) + : [ mat_ptr0_end ] "r"(mat_ptr0_end), + [ scaling_factors_ptr ] "r"(scaling_factors_ptr), + [ wide_rows ] "r"(wide_rows) + : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "cc", "memory"); + } + } + + free(shuffled_vectors_free); +} + +static void DotprodSparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result, + int result_stride) { + const uint8_t* ledger_ptr = ledger; + const int8* mat_ptr = matrix; + + for (int row = 0; row < m_rows; row++) { + int num_nonzero_chunks = *ledger_ptr; + ledger_ptr++; + const uint8* ledger_start = ledger_ptr; + const uint8* ledger_end = ledger_ptr + num_nonzero_chunks; + const int8* mat_start = mat_ptr; + + for (int batch = 0; batch < n_batch; batch++) { + const int8* vec_ptr = vectors + (batch * m_cols); + int64_t row_sum = 0; + + mat_ptr = mat_start; + ledger_ptr = ledger_start; + + if (ledger_ptr != ledger_end) { + asm volatile( + "dup v0.4s, wzr\n" + "dup v1.4s, wzr\n" + "dup v8.4s, wzr\n" + "mov x7, 0\n" + + "1:\n" // chunks_loop + + // Single matrix chunk, 16 bytes + "ld1 {v8.16b}, [%[mat_ptr]], #16\n" + + // Read the next ledger index and increment. + "ldrb w7, [%[ledger_ptr]], #1\n" + + // Read 16 bytes of vector data from (vec_ptr + (ledger_index * 16)) + "add x8, %[vec_ptr], x7, lsl #4\n" + "ld1 {v9.16b}, [x8]\n" + + // Dot product of matrix row and vector. 
+ ".word 0x4e889520 // sdot v0.4s, v9.16b, v8.16b\n" + + "cmp %[ledger_ptr], %[ledger_end]\n" + "blt 1b\n" // chunks_loop + + // Sum the 4 vector components into a 32-bit value. + "addv s1, v0.4s\n" + // row_sum is 64-bit, so we copy 64 bits of v1 into it. + // We have to be careful to cast this value to 32 bits in order + // to interpret the sign bit properly. + "mov %[row_sum], v1.d[0]\n" + : [ row_sum ] "=r"(row_sum), [ ledger_ptr ] "+r"(ledger_ptr), + [ mat_ptr ] "+r"(mat_ptr), [ vec_ptr ] "+r"(vec_ptr) + : [ ledger_end ] "r"(ledger_end) + : "x0", "x1", "x7", "x8", "v0", "v1", "v8", "v9", "cc", "memory"); + } + result[(batch * m_rows + row) * result_stride] += + static_cast(row_sum) * scaling_factors[batch]; + } + } +} + +#endif // __aarch64__ + void NeonMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, const int8_t* __restrict__ vectors, const float* scaling_factors, int n_batch, float* __restrict__ result, int result_stride) { +#ifdef __aarch64__ + if (HasSdotInstruction() && m_cols % 16 == 0 && m_rows % 2 == 0 && + m_rows >= n_batch) { + if (n_batch % 4 == 0 && result_stride == 1) { + // Benchmarks suggest that it's always better to use the batch code + // when we can, even on small matrices. 
+ DotprodMatrixBatchFourVectorMultiplyAccumulate( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result); + return; + } + } +#endif // __aarch64__ + const int kWeightsPerUint32 = 4; const int kWeightsPerNeonLane = 16; // Assuming *matrix is kWeightsPerUint32-byte aligned, @@ -202,6 +507,127 @@ void NeonMatrixBatchVectorMultiplyAccumulate( free(aligned_vec_free); } +void NeonSparseMatrixBatchVectorMultiplyAccumulate( + const float* matrix, const uint8_t* ledger, int m_rows, int m_cols, + const float* vector, int n_batch, float* result, int result_stride) { + const int kBlockSize = 16; + const int kNeonLanesPerBlock = 4; + TFLITE_DCHECK_EQ( // NOLINT + m_cols % kBlockSize, 0); + + float* result_in_batch = result; + for (int b = 0; b < n_batch; b++) { + const float* matrix_ptr = matrix; + const uint8_t* ledger_ptr = ledger; + for (int r = 0; r < m_rows; r++) { + int num_nonzero_blocks = *ledger_ptr++; + if (num_nonzero_blocks > 0) { + float32x4_t acc_32x4 = vmovq_n_f32(0.0); + const float* vector_in_batch = vector + b * m_cols; + + for (int i = 0; i < num_nonzero_blocks; i++) { + const int block_start_index = *ledger_ptr++ * kBlockSize; + const float* vector_block_in_batch_ptr = + vector_in_batch + block_start_index; + + for (int c = 0; c < kNeonLanesPerBlock; c++) { + // Load 4 float values from the vector and matrix row. + float32x4_t vector_f32x4 = vld1q_f32(vector_block_in_batch_ptr + + c * kFloatWeightsPerNeonLane); + float32x4_t matrix_f32x4 = + vld1q_f32(matrix_ptr + c * kFloatWeightsPerNeonLane); + // Multiply the vector and matrix row and add to accumulator. 
+ acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4); + } + matrix_ptr += kBlockSize; + } + *result_in_batch += + (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) + + vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3)); + } + result_in_batch += result_stride; + } + } +} + +void NeonSparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result, + int result_stride) { +#ifdef __aarch64__ + if (HasSdotInstruction() && m_cols % 16 == 0) { + DotprodSparseMatrixBatchVectorMultiplyAccumulate( + matrix, ledger, m_rows, m_cols, vectors, scaling_factors, n_batch, + result, result_stride); + return; + } +#endif // __aarch64__ + + const int kWeightsPerUint32 = 4; + const int kWeightsPerNeonLane = 16; + const int kBlockSize = kWeightsPerNeonLane; + TFLITE_DCHECK_EQ( // NOLINT + m_cols % kBlockSize, 0); + void* aligned_vec_free = nullptr; + int8_t* aligned_vec = + (int8_t*)aligned_alloc(kWeightsPerUint32, m_cols, // NOLINT + &aligned_vec_free); + + int batch, row; + for (batch = 0; batch < n_batch; ++batch) { + const float batch_scaling_factor = scaling_factors[batch]; + // Copy the vector data to an aligned vector. + memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols); + + const uint8_t* ledger_ptr = ledger; + const int8_t* row_ptr = matrix; + for (row = 0; row < m_rows; ++row, result += result_stride) { + // Initialize the dot product sum for the row to 0. + int32x4_t dotprod = vmovq_n_s32(0); + int num_nonzero_blocks = *ledger_ptr++; + if (num_nonzero_blocks > 0) { + // Prefetch the row to cache. + __builtin_prefetch(row_ptr, 0 /* prefetch for read */, + 3 /* temporal locality */); + for (int i = 0; i < num_nonzero_blocks; i++) { + const int col_index = *ledger_ptr++ * kBlockSize; + // Load 16 8-bit values from the row and vector, each, to operate on. 
+ // Here the assumption is that each buffer is 4-byte aligned. + // Otherwise, performance may suffer significantly. + TFLITE_DCHECK_EQ( // NOLINT + (uintptr_t)(row_ptr) & (kWeightsPerUint32 - 1), 0); + const int8x16_t s1_8x16 = + vld1q_s8((const int8_t*)(aligned_vec + col_index)); + const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr)); + // Multiply the low bits (i.e. the lower 8 8bit numbers in the + // registers). + int16x8_t prod_16x8 = + vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16)); + // Multiply the high bits (i.e. the higher 8 8bit numbers in the + // registers), and accumulate with the result of the low bits product. + // The assumption here is that overflow will not happen as we quantize + // our values to be in the range [-127, 127]. As such the sum of the 2 + // products is always strictly smaller than 15-bits (32767 in absolute + // value). + prod_16x8 = + vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16)); + + dotprod = vpadalq_s16(dotprod, prod_16x8); + row_ptr += kBlockSize; + } + // Add the 4 intermediate sum values to get the final dot-prod value for + // this row.
+ int64x2_t pairwiseAdded = vpaddlq_s32(dotprod); + int32 neon_sum = + vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1); + *result += neon_sum * batch_scaling_factor; + } + } // for row + } // for batch + free(aligned_vec_free); +} + void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2, int v_size, float* result) { // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h index 903f4c80139cd326b354ef6292a393c75af11608..a86457dba745dbe94ce3e1dc718012545f258804 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -40,6 +40,24 @@ void MatrixBatchVectorMultiplyAccumulate( vectors, scaling_factors, n_batch, result, result_stride); } +void SparseMatrixBatchVectorMultiplyAccumulate( + const float* matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const float* vector, int n_batch, float* result, + int result_stride) { + NeonSparseMatrixBatchVectorMultiplyAccumulate( + matrix, ledger, m_rows, m_cols, vector, n_batch, result, result_stride); +} + +void SparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result, + int result_stride) { + NeonSparseMatrixBatchVectorMultiplyAccumulate(matrix, ledger, m_rows, m_cols, + vectors, scaling_factors, + n_batch, result, result_stride); +} + void VectorVectorCwiseProduct(const float* vector1, const float* vector2, int v_size, float* result) { NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result); diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h 
b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 7009323a9b2ec7d1e5310f89ef6977a376f536eb..7d8ab2c6029b8e3b7fdbc38c591cf02f24019c83 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -63,6 +64,7 @@ using reference_ops::ConcatenationWithScaling; using reference_ops::DepthConcatenation; using reference_ops::Dequantize; using reference_ops::Div; +using reference_ops::Elu; using reference_ops::FakeQuant; using reference_ops::Fill; using reference_ops::Gather; @@ -85,6 +87,7 @@ using reference_ops::Select; using reference_ops::SpaceToBatchND; using reference_ops::Split; using reference_ops::StridedSlice; +using reference_ops::Sub16; using reference_ops::Transpose; // TODO(b/80247582) Remove this constant. @@ -181,45 +184,6 @@ MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, return MatrixMap(data, rows, cols); } -// This is like the template-parameter version, except that the power-of-two is -// passed as a function parameter. The template version is to be preferred, -// since some target hardware optimizations depend on the range of the exponent. 
-template -IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { - if (exponent == 0) { - return x; - } - using ScalarIntegerType = - typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; - const IntegerType min = - gemmlowp::Dup(std::numeric_limits::min()); - const IntegerType max = - gemmlowp::Dup(std::numeric_limits::max()); - const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); - - const std::int32_t threshold = - ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); - const IntegerType positive_mask = - gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); - const IntegerType negative_mask = - gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); - - IntegerType result = gemmlowp::ShiftLeft(x, exponent); - result = gemmlowp::SelectUsingMask(positive_mask, max, result); - result = gemmlowp::SelectUsingMask(negative_mask, min, result); - return result; -} - -// This is like the template-parameter version, except that the power-of-two is -// passed as a function parameter. See raw-integer version for further comments. 
-template -gemmlowp::FixedPoint -SaturatingRoundingMultiplyByPOTParam( - gemmlowp::FixedPoint a, int exponent) { - return gemmlowp::FixedPoint::FromRaw( - SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); -} - inline void AddBiasAndEvalActivationFunction(float output_activation_min, float output_activation_max, const RuntimeShape& bias_shape, @@ -810,7 +774,7 @@ inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& weights_shape, const float* weights_data, const RuntimeShape& bias_shape, - const float* bias_data, const RuntimeShape& output_shape, + const float* optional_bias_data, const RuntimeShape& output_shape, float* output_data) { gemmlowp::ScopedProfilingLabel label("FullyConnected"); const float output_activation_min = params.float_activation_min; @@ -834,30 +798,36 @@ inline void FullyConnected( MapAsMatrixWithLastDimAsRows(output_data, output_shape); Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); - AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max, - bias_shape, bias_data, output_shape, - output_data); + + if (optional_bias_data != nullptr) { + AddBiasAndEvalActivationFunction( + output_activation_min, output_activation_max, bias_shape, + optional_bias_data, output_shape, output_data); + } else { + const int flat_size = output_shape.FlatSize(); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + output_data[i], output_activation_min, output_activation_max); + } + } } #ifdef USE_NEON -inline void FullyConnectedAsGEMV( +inline void FullyConnectedAsGEMVWorkerImpl( const RuntimeShape& input_shape, const uint8* input_data, int32 input_offset, const RuntimeShape& filter_shape, const uint8* filter_data, int32 filter_offset, const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, int32 output_multiplier, int output_shift, int32 output_activation_min, int32 
output_activation_max, const RuntimeShape& output_shape, - uint8* output_data) { + uint8* output_data, int row_start, int row_end) { gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit"); TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1); TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); const int output_dim_count = output_shape.DimensionsCount(); - const int filter_dim_count = filter_shape.DimensionsCount(); TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1); const int input_size = FlatSizeSkipDim(input_shape, 0); - const int output_size = MatchingDim(filter_shape, filter_dim_count - 2, - output_shape, output_dim_count - 1); static constexpr int kPeel = 4; const bool shift_left = (output_shift > 0); for (int k = 0; k < input_size; k += 64) { @@ -866,81 +836,139 @@ inline void FullyConnectedAsGEMV( for (int k = 0; k < kPeel * input_size; k += 64) { optimized_ops_preload_l1_stream(filter_data + k); } - TFLITE_DCHECK(!(output_size % kPeel)); - const int32* bias_ptr = bias_data; - uint8* output_ptr = output_data; - for (int out = 0; out < output_size; out += kPeel) { - int32x4_t acc[kPeel]; - for (int k = 0; k < kPeel; k++) { - acc[k] = vdupq_n_s32(0); - } + + TFLITE_DCHECK_GE(row_end - row_start, kPeel); + + for (int out = row_start; out < row_end; out += kPeel) { + out = std::min(out, row_end - kPeel); + int32x4_t acc0 = vdupq_n_s32(0); + int32x4_t acc1 = acc0; + int32x4_t acc2 = acc0; + int32x4_t acc3 = acc0; const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset); int in = 0; for (; in <= input_size - 16; in += 16) { const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); - uint8x16_t filter_val_u8[kPeel]; - for (int k = 0; k < kPeel; k++) { - const uint8* filter_ptr = filter_data + in + (out + k) * input_size; - filter_val_u8[k] = vld1q_u8(filter_ptr); - optimized_ops_preload_l1_stream(filter_ptr + 64); 
- } - int16x8_t input_val[2]; - const uint8x8_t low = vget_low_u8(input_val_u8); - const uint8x8_t high = vget_high_u8(input_val_u8); - input_val[0] = vreinterpretq_s16_u16(vmovl_u8(low)); - input_val[1] = vreinterpretq_s16_u16(vmovl_u8(high)); - input_val[0] = vaddq_s16(input_val[0], input_offset_vec); - input_val[1] = vaddq_s16(input_val[1], input_offset_vec); - int16x8_t filter_val[kPeel][2]; - for (int k = 0; k < kPeel; k++) { - const uint8x8_t low = vget_low_u8(filter_val_u8[k]); - const uint8x8_t high = vget_high_u8(filter_val_u8[k]); - filter_val[k][0] = vreinterpretq_s16_u16(vmovl_u8(low)); - filter_val[k][1] = vreinterpretq_s16_u16(vmovl_u8(high)); - filter_val[k][0] = vaddq_s16(filter_val[k][0], filter_offset_vec); - filter_val[k][1] = vaddq_s16(filter_val[k][1], filter_offset_vec); - } - for (int p = 0; p < 2; p++) { - for (int k = 0; k < kPeel; k++) { - acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k][p]), - vget_low_s16(input_val[p])); - } - for (int k = 0; k < kPeel; k++) { - acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k][p]), - vget_high_s16(input_val[p])); - } - } + const uint8* filter_ptr = filter_data + in + out * input_size; + uint8x16_t filter_val_u8_0 = vld1q_u8(filter_ptr); + optimized_ops_preload_l1_stream(filter_ptr + 64); + filter_ptr += input_size; + uint8x16_t filter_val_u8_1 = vld1q_u8(filter_ptr); + optimized_ops_preload_l1_stream(filter_ptr + 64); + filter_ptr += input_size; + uint8x16_t filter_val_u8_2 = vld1q_u8(filter_ptr); + optimized_ops_preload_l1_stream(filter_ptr + 64); + filter_ptr += input_size; + uint8x16_t filter_val_u8_3 = vld1q_u8(filter_ptr); + optimized_ops_preload_l1_stream(filter_ptr + 64); + int16x8_t input_val_0, input_val_1; + uint8x8_t low = vget_low_u8(input_val_u8); + uint8x8_t high = vget_high_u8(input_val_u8); + input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low)); + input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high)); + input_val_0 = vaddq_s16(input_val_0, input_offset_vec); + input_val_1 = 
vaddq_s16(input_val_1, input_offset_vec); + low = vget_low_u8(filter_val_u8_0); + high = vget_high_u8(filter_val_u8_0); + int16x8_t filter_val_0_0 = vreinterpretq_s16_u16(vmovl_u8(low)); + int16x8_t filter_val_0_1 = vreinterpretq_s16_u16(vmovl_u8(high)); + filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec); + filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec); + low = vget_low_u8(filter_val_u8_1); + high = vget_high_u8(filter_val_u8_1); + int16x8_t filter_val_1_0 = vreinterpretq_s16_u16(vmovl_u8(low)); + int16x8_t filter_val_1_1 = vreinterpretq_s16_u16(vmovl_u8(high)); + filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec); + filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec); + low = vget_low_u8(filter_val_u8_2); + high = vget_high_u8(filter_val_u8_2); + int16x8_t filter_val_2_0 = vreinterpretq_s16_u16(vmovl_u8(low)); + int16x8_t filter_val_2_1 = vreinterpretq_s16_u16(vmovl_u8(high)); + filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec); + filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec); + low = vget_low_u8(filter_val_u8_3); + high = vget_high_u8(filter_val_u8_3); + int16x8_t filter_val_3_0 = vreinterpretq_s16_u16(vmovl_u8(low)); + int16x8_t filter_val_3_1 = vreinterpretq_s16_u16(vmovl_u8(high)); + filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec); + filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec); + acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0), + vget_low_s16(input_val_0)); + acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0), + vget_low_s16(input_val_0)); + acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0), + vget_low_s16(input_val_0)); + acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0), + vget_low_s16(input_val_0)); + acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1), + vget_low_s16(input_val_1)); + acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1), + vget_low_s16(input_val_1)); + acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1), + 
vget_low_s16(input_val_1)); + acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1), + vget_low_s16(input_val_1)); + acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0), + vget_high_s16(input_val_0)); + acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0), + vget_high_s16(input_val_0)); + acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0), + vget_high_s16(input_val_0)); + acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0), + vget_high_s16(input_val_0)); + acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1), + vget_high_s16(input_val_1)); + acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1), + vget_high_s16(input_val_1)); + acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1), + vget_high_s16(input_val_1)); + acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1), + vget_high_s16(input_val_1)); } for (; in <= input_size - 8; in += 8) { const uint8x8_t input_val_u8 = vld1_u8(input_data + in); - uint8x8_t filter_val_u8[kPeel]; - for (int k = 0; k < kPeel; k++) { - const uint8* filter_ptr = filter_data + in + (out + k) * input_size; - filter_val_u8[k] = vld1_u8(filter_ptr); - } - int16x8_t input_val; - input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); + const uint8* filter_ptr = filter_data + in + out * input_size; + uint8x8_t filter_val_u8_0 = vld1_u8(filter_ptr); + filter_ptr += input_size; + uint8x8_t filter_val_u8_1 = vld1_u8(filter_ptr); + filter_ptr += input_size; + uint8x8_t filter_val_u8_2 = vld1_u8(filter_ptr); + filter_ptr += input_size; + uint8x8_t filter_val_u8_3 = vld1_u8(filter_ptr); + int16x8_t input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); input_val = vaddq_s16(input_val, input_offset_vec); - int16x8_t filter_val[kPeel]; - for (int k = 0; k < kPeel; k++) { - filter_val[k] = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8[k])); - filter_val[k] = vaddq_s16(filter_val[k], filter_offset_vec); - } - for (int k = 0; k < kPeel; k++) { - acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k]), - vget_low_s16(input_val)); - } - for (int k = 
0; k < kPeel; k++) { - acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k]), - vget_high_s16(input_val)); - } + int16x8_t filter_val_0 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_0)); + filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec); + int16x8_t filter_val_1 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_1)); + filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec); + int16x8_t filter_val_2 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_2)); + filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec); + int16x8_t filter_val_3 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_3)); + filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec); + acc0 = + vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val)); + acc1 = + vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val)); + acc2 = + vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val)); + acc3 = + vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val)); + acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0), + vget_high_s16(input_val)); + acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1), + vget_high_s16(input_val)); + acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2), + vget_high_s16(input_val)); + acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3), + vget_high_s16(input_val)); } if (in < input_size) { - int32 buf[4 * kPeel]; - for (int k = 0; k < 4; k++) { - vst1q_s32(buf + 4 * k, acc[k]); - } + int32 buf[16]; + vst1q_s32(buf + 0, acc0); + vst1q_s32(buf + 4, acc1); + vst1q_s32(buf + 8, acc2); + vst1q_s32(buf + 12, acc3); for (; in < input_size; in++) { int lane = (in + 8 - input_size) % 4; const int32 input_val = input_data[in] + input_offset; @@ -950,26 +978,28 @@ inline void FullyConnectedAsGEMV( buf[lane + 4 * k] += filter_val * input_val; } } - for (int k = 0; k < 4; k++) { - acc[k] = vld1q_s32(buf + 4 * k); - } + acc0 = vld1q_s32(buf + 0); + acc1 = vld1q_s32(buf + 4); + acc2 = vld1q_s32(buf + 8); + acc3 = vld1q_s32(buf + 12); } // 
Horizontally reduce accumulators - int32x2_t pairwise_reduced_acc[kPeel]; - for (int k = 0; k < kPeel; k++) { - pairwise_reduced_acc[k] = - vpadd_s32(vget_low_s32(acc[k]), vget_high_s32(acc[k])); - } - static_assert(kPeel == 4, "the code below currently assumes kPeel = 4"); + int32x2_t pairwise_reduced_acc_0 = + vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0)); + int32x2_t pairwise_reduced_acc_1 = + vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1)); + int32x2_t pairwise_reduced_acc_2 = + vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2)); + int32x2_t pairwise_reduced_acc_3 = + vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3)); const int32x2_t reduced_lo = - vpadd_s32(pairwise_reduced_acc[0], pairwise_reduced_acc[1]); + vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); const int32x2_t reduced_hi = - vpadd_s32(pairwise_reduced_acc[2], pairwise_reduced_acc[3]); + vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); // Add bias values. - int32x4_t bias_vec = vld1q_s32(bias_ptr); - bias_ptr += 4; + int32x4_t bias_vec = vld1q_s32(bias_data + out); reduced = vaddq_s32(reduced, bias_vec); if (shift_left) { const int32 multiplier_power_of_two = 1 << output_shift; @@ -992,11 +1022,116 @@ inline void FullyConnectedAsGEMV( // Apply the clamping from the activation function res8 = vmax_u8(res8, vdup_n_u8(output_activation_min)); res8 = vmin_u8(res8, vdup_n_u8(output_activation_max)); - // Store results to destination. Assumes 32bit alignment. - vst1_lane_u32(reinterpret_cast(output_ptr), - vreinterpret_u32_u8(res8), 0); - output_ptr += kPeel; + // Store results to destination. 
+ vst1_lane_u8(output_data + out + 0, res8, 0); + vst1_lane_u8(output_data + out + 1, res8, 1); + vst1_lane_u8(output_data + out + 2, res8, 2); + vst1_lane_u8(output_data + out + 3, res8, 3); + } +} + +struct FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task { + FullyConnectedAsGEMVWorkerTask(const RuntimeShape& input_shape, + const uint8* input_data, int32 input_offset, + const RuntimeShape& filter_shape, + const uint8* filter_data, int32 filter_offset, + const RuntimeShape& bias_shape, + const int32* bias_data, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, + const RuntimeShape& output_shape, + uint8* output_data, int row_start, int row_end) + : input_shape_(input_shape), + input_data_(input_data), + input_offset_(input_offset), + filter_shape_(filter_shape), + filter_data_(filter_data), + filter_offset_(filter_offset), + bias_shape_(bias_shape), + bias_data_(bias_data), + output_offset_(output_offset), + output_multiplier_(output_multiplier), + output_shift_(output_shift), + output_activation_min_(output_activation_min), + output_activation_max_(output_activation_max), + output_shape_(output_shape), + output_data_(output_data), + row_start_(row_start), + row_end_(row_end) {} + + void Run() override { + FullyConnectedAsGEMVWorkerImpl( + input_shape_, input_data_, input_offset_, filter_shape_, filter_data_, + filter_offset_, bias_shape_, bias_data_, output_offset_, + output_multiplier_, output_shift_, output_activation_min_, + output_activation_max_, output_shape_, output_data_, row_start_, + row_end_); } + + const RuntimeShape& input_shape_; + const uint8* input_data_; + int32 input_offset_; + const RuntimeShape& filter_shape_; + const uint8* filter_data_; + int32 filter_offset_; + const RuntimeShape& bias_shape_; + const int32* bias_data_; + int32 output_offset_; + int32 output_multiplier_; + int output_shift_; + int32 output_activation_min_; + int32 output_activation_max_; + 
const RuntimeShape& output_shape_; + uint8* output_data_; + gemmlowp::GemmContext* gemm_context_; + int row_start_; + int row_end_; +}; + +inline void FullyConnectedAsGEMV( + const RuntimeShape& input_shape, const uint8* input_data, + int32 input_offset, const RuntimeShape& filter_shape, + const uint8* filter_data, int32 filter_offset, + const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, + int32 output_multiplier, int output_shift, int32 output_activation_min, + int32 output_activation_max, const RuntimeShape& output_shape, + uint8* output_data, gemmlowp::GemmContext* gemm_context) { + const int output_dim_count = output_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_rows = output_shape.Dims(output_dim_count - 1); + const int input_size = FlatSizeSkipDim(input_shape, 0); + static constexpr int kKernelRows = 4; + const int thread_count = gemmlowp::HowManyThreads( + gemm_context->max_num_threads(), output_rows, batches, input_size); + if (thread_count == 1) { + // Single-thread case: do the computation on the current thread, don't + // use a threadpool + FullyConnectedAsGEMVWorkerImpl( + input_shape, input_data, input_offset, filter_shape, filter_data, + filter_offset, bias_shape, bias_data, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_shape, output_data, 0, output_rows); + return; + } + + // Multi-threaded case: use the gemmlowp context's threadpool. 
+ TFLITE_DCHECK_GT(thread_count, 1); + std::vector tasks(thread_count); + const int kRowsPerWorker = + gemmlowp::RoundUp(output_rows / thread_count); + int row_start = 0; + for (int i = 0; i < thread_count; ++i) { + int row_end = std::min(output_rows, row_start + kRowsPerWorker); + tasks[i] = new FullyConnectedAsGEMVWorkerTask( + input_shape, input_data, input_offset, filter_shape, filter_data, + filter_offset, bias_shape, bias_data, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_shape, output_data, row_start, row_end); + row_start = row_end; + } + TFLITE_DCHECK_EQ(row_start, output_rows); + gemm_context->workers_pool()->Execute(tasks); } #endif // USE_NEON @@ -1053,14 +1188,16 @@ inline void FullyConnected( const int filter_dim_count = filter_shape.DimensionsCount(); const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); #ifdef USE_NEON - const int output_size = MatchingDim(filter_shape, filter_dim_count - 2, - output_shape, output_dim_count - 1); - if (batches == 1 && !(output_size % 4)) { - return FullyConnectedAsGEMV( - input_shape, input_data, input_offset, filter_shape, filter_data, - filter_offset, bias_shape, bias_data, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_shape, output_data); + if (batches == 1) { + const int output_size = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + if (output_size >= 4) { + return FullyConnectedAsGEMV( + input_shape, input_data, input_offset, filter_shape, filter_data, + filter_offset, bias_shape, bias_data, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_shape, output_data, gemm_context); + } } #endif // USE_NEON const int filter_rows = filter_shape.Dims(filter_dim_count - 2); @@ -1616,6 +1753,222 @@ inline void ShuffledFullyConnected( gemm_context->workers_pool()->Execute(tasks); } +inline void 
MeanImpl(const tflite::MeanParams& op_params, + const RuntimeShape& input_shape, const uint8_t* input_data, + int32 input_zero_point, float input_scale, + const RuntimeShape& output_shape, uint8_t* output_data, + int32 output_zero_point, float output_scale, + int start_depth, int end_depth) { + gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8/MeanImpl"); + + // Current implementation only supports dimension equals 4 and simultaneous + // reduction over width and height. + const int output_batch = output_shape.Dims(0); + const int output_height = output_shape.Dims(2); + const int output_width = output_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const float num_elements_in_axis = input_width * input_height; + + TFLITE_DCHECK_EQ(op_params.axis_count, 2); + TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + const bool ordinary_mean = + (input_zero_point == output_zero_point && input_scale == output_scale); + float scale, bias; + if (!ordinary_mean) { + scale = input_scale / output_scale; + bias = -input_zero_point * scale + 0.5; + } + +#ifdef USE_NEON + const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis); + // This is only an approximation as NEON does not offer division instruction. 
+ const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup); + const float32x4_t kRounding = vdupq_n_f32(0.5); + float32x4_t bias_dup; + float32x4_t output_zero_point_dup; + if (!ordinary_mean) { + bias_dup = vdupq_n_f32(bias); + output_zero_point_dup = vdupq_n_f32(output_zero_point); + } +#endif + + for (int out_b = 0; out_b < output_batch; ++out_b) { + int out_d = start_depth; +#ifdef USE_NEON + + for (; out_d < end_depth - 8; out_d += 8) { + float32x4_t temp_sum_1 = vdupq_n_f32(0); + float32x4_t temp_sum_2 = vdupq_n_f32(0); + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + const uint8_t* input_data_ptr = + input_data + Offset(input_shape, out_b, in_h, in_w, out_d); + uint8x8_t input_data_val = vld1_u8(input_data_ptr); + int16x8_t input_data_val_shift = + vreinterpretq_s16_u16(vmovl_u8(input_data_val)); + float32x4_t input_float_1 = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift))); + float32x4_t input_float_2 = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift))); + temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1); + temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2); + } + } + + float32x4_t mean_1 = vmulq_f32(temp_sum_1, num_elements_reverse); + float32x4_t mean_2 = vmulq_f32(temp_sum_2, num_elements_reverse); + + if (!ordinary_mean) { + // maq is not supported, break down into two ops. + mean_1 = vmulq_n_f32(mean_1, scale); + mean_1 = vaddq_f32(mean_1, bias_dup); + mean_2 = vmulq_n_f32(mean_2, scale); + mean_2 = vaddq_f32(mean_2, bias_dup); + } + + if (!ordinary_mean) { + mean_1 = vaddq_f32(mean_1, output_zero_point_dup); + mean_2 = vaddq_f32(mean_2, output_zero_point_dup); + } + + // Rounding. 
+ mean_1 = vaddq_f32(mean_1, kRounding); + mean_2 = vaddq_f32(mean_2, kRounding); + uint32x4_t casted_mean_1 = vcvtq_u32_f32(mean_1); + uint16x4_t narrow_range_mean_1 = vmovn_u32(casted_mean_1); + uint32x4_t casted_mean_2 = vcvtq_u32_f32(mean_2); + uint16x4_t narrow_range_mean_2 = vmovn_u32(casted_mean_2); + uint16x8_t combined_mean = + vcombine_u16(narrow_range_mean_2, narrow_range_mean_1); + uint8x8_t narrowed_combined_mean = vmovn_u16(combined_mean); + uint8_t* output_data_ptr = + output_data + Offset(output_shape, out_b, 0, 0, out_d); + vst1_u8(output_data_ptr, narrowed_combined_mean); + } +#endif + + for (; out_d < end_depth; ++out_d) { + float temp_value = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + temp_value += + input_data[Offset(input_shape, out_b, in_h, in_w, out_d)]; + } + } + + temp_value = temp_value / num_elements_in_axis; + if (ordinary_mean) { + output_data[Offset(output_shape, out_b, 0, 0, out_d)] = + static_cast(round(temp_value)); + } else { + output_data[Offset(output_shape, out_b, 0, 0, out_d)] = + static_cast(round(temp_value * scale + bias)) + + output_zero_point; + } + } + } +} + +struct MeanWorkerTask : public gemmlowp::Task { + MeanWorkerTask(const tflite::MeanParams& op_params, + const RuntimeShape& input_shape, const uint8_t* input_data, + int32 input_zero_point, float input_scale, + const RuntimeShape& output_shape, uint8_t* output_data, + int32 output_zero_point, float output_scale, int start_height, + int end_height) + : op_params_(op_params), + input_shape_(input_shape), + input_data_(input_data), + input_zero_point_(input_zero_point), + input_scale_(input_scale), + output_shape_(output_shape), + output_data_(output_data), + output_zero_point_(output_zero_point), + output_scale_(output_scale), + start_height_(start_height), + end_height_(end_height) {} + + void Run() override { + MeanImpl(op_params_, input_shape_, input_data_, input_zero_point_, + input_scale_, 
output_shape_, output_data_, output_zero_point_, + output_scale_, start_height_, end_height_); + } + + private: + const tflite::MeanParams& op_params_; + const RuntimeShape& input_shape_; + const uint8_t* input_data_; + int32 input_zero_point_; + float input_scale_; + const RuntimeShape& output_shape_; + uint8_t* output_data_; + int32 output_zero_point_; + float output_scale_; + int start_height_; + int end_height_; + gemmlowp::GemmContext* gemm_context_; +}; + +inline void Mean(const tflite::MeanParams& op_params, + const RuntimeShape& unextended_input_shape, + const uint8_t* input_data, int32 input_zero_point, + float input_scale, const RuntimeShape& unextended_output_shape, + uint8_t* output_data, int32 output_zero_point, + float output_scale, gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8"); + + // Current implementation only supports dimension equals 4 and simultaneous + // reduction over width and height. + TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4); + TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4); + const RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + const RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int output_depth = output_shape.Dims(3); + + TFLITE_DCHECK_EQ(op_params.axis_count, 2); + TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + constexpr int kMinDepthPerThread = 8; + int thread_count = output_depth / kMinDepthPerThread; + thread_count = thread_count > 0 ? 
thread_count : 1; + const int capped_thread_count = + std::min(thread_count, gemm_context->max_num_threads()); + + if (thread_count == 1) { + MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale, + output_shape, output_data, output_zero_point, output_scale, 0, + output_depth); + } else { + // Instead parrallel for batch, we loop for the output_depth since batch + // is typical 1. + std::vector tasks(capped_thread_count); + int depth_start = 0; + for (int i = 0; i < capped_thread_count; ++i) { + // Try to distribute the tasks as even as possible. + int depth_end = depth_start + + (output_depth - depth_start) / (capped_thread_count - i); + tasks[i] = new MeanWorkerTask(op_params, input_shape, input_data, + input_zero_point, input_scale, output_shape, + output_data, output_zero_point, + output_scale, depth_start, depth_end); + depth_start = depth_end; + } + gemm_context->workers_pool()->Execute(tasks); + } +} + template inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w, int h, int b, int kheight, int kwidth, @@ -2083,6 +2436,21 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, TFLITE_DCHECK_EQ(output_cols, gemm_input_cols); TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows); + +#ifdef USE_NEON + if (gemm_input_cols == 1 && output_rows >= 4) { + RuntimeShape fc_filter_shape{ + filter_shape.Dims(0), + filter_shape.Dims(filter_shape.DimensionsCount() - 1)}; + + return FullyConnectedAsGEMV( + *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape, + filter_data, filter_offset, bias_shape, bias_data, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_shape, output_data, gemm_context); + } +#endif + gemmlowp::MatrixMap filter_matrix( filter_data, filter_rows, filter_cols); gemmlowp::MatrixMap input_matrix( @@ -2215,55 +2583,6 @@ inline void L2Normalization(const 
tflite::L2NormalizationParams& op_params, } } -inline void GetInvSqrtQuantizedMultiplierExp(int32 input, - int32* output_inv_sqrt, - int* output_shift) { - *output_shift = 11; - while (input >= (1 << 29)) { - input /= 4; - ++*output_shift; - } - TFLITE_DCHECK_GT(input, 0); - const unsigned max_left_shift_bits = - CountLeadingZeros(static_cast(input)) - 1; - const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; - const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; - *output_shift -= left_shift_bit_pairs; - input <<= 2 * left_shift_bit_pairs; - TFLITE_DCHECK_GE(input, (1 << 27)); - TFLITE_DCHECK_LT(input, (1 << 29)); - using gemmlowp::FixedPoint; - using gemmlowp::Rescale; - using gemmlowp::SaturatingRoundingMultiplyByPOT; - // Using 3 integer bits gives us enough room for the internal arithmetic in - // this Newton-Raphson iteration. - using F3 = FixedPoint; - using F0 = FixedPoint; - const F3 fixedpoint_input = F3::FromRaw(input >> 1); - const F3 fixedpoint_half_input = - SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); - const F3 fixedpoint_half_three = - GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); - // Newton-Raphson iteration - // Naive unoptimized starting guess: x = 1 - F3 x = F3::One(); - // Naive unoptimized number of iterations: 5 - for (int i = 0; i < 5; i++) { - const F3 x3 = Rescale<3>(x * x * x); - x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); - } - const F0 fixedpoint_half_sqrt_2 = - GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); - x = x * fixedpoint_half_sqrt_2; - *output_inv_sqrt = x.raw(); - if (*output_shift < 0) { - *output_inv_sqrt <<= -*output_shift; - *output_shift = 0; - } - // Convert right shift (right is positive) to left shift. 
- *output_shift *= kReverseShift; -} - inline void L2Normalization(const tflite::L2NormalizationParams& op_params, const RuntimeShape& input_shape, const uint8* input_data, @@ -2285,8 +2604,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params, } int32 inv_l2norm_multiplier; int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift, + &inv_l2norm_multiplier, &inv_l2norm_shift); for (int c = 0; c < depth; c++) { int32 diff = *input_data - input_zero_point; @@ -2678,7 +2997,7 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner // dimension. for (int i0 = 0; i0 < y0; ++i0) { - const uint8* input2_data_ptr; + const uint8* input2_data_ptr = nullptr; for (int i1 = 0; i1 < y1; ++i1) { input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { @@ -2707,7 +3026,7 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, // for y4 == 1 and the loop over y3 is contained within the // AddScalarBroadcast function. 
for (int i0 = 0; i0 < y0; ++i0) { - const uint8* input2_data_ptr; + const uint8* input2_data_ptr = nullptr; for (int i1 = 0; i1 < y1; ++i1) { input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { @@ -3064,7 +3383,7 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, int y4 = params.broadcast_shape[4]; if (y4 > 1) { for (int i0 = 0; i0 < y0; ++i0) { - const uint8* input2_data_ptr; + const uint8* input2_data_ptr = nullptr; for (int i1 = 0; i1 < y1; ++i1) { input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { @@ -3081,7 +3400,7 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, } } else { for (int i0 = 0; i0 < y0; ++i0) { - const uint8* input2_data_ptr; + const uint8* input2_data_ptr = nullptr; for (int i1 = 0; i1 < y1; ++i1) { input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { @@ -3709,6 +4028,14 @@ inline void AveragePool(const PoolParams& params, const uint8* input_data, const RuntimeShape& output_shape, uint8* output_data) { gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); + + // Here, and in other pooling ops, in order to maintain locality of reference, + // to minimize some recalculations, and to load into NEON vector registers, we + // use an inner loop down the depth. Since depths can be large and hence we + // would need arbitrarily large temporary storage, we divide the work up into + // depth tranches just within the batch loop. 
+ static constexpr int kPoolingAccTrancheSize = 256; + TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); @@ -3721,69 +4048,76 @@ inline void AveragePool(const PoolParams& params, const int output_width = output_shape.Dims(2); const int stride_height = params.stride_height; const int stride_width = params.stride_width; + + uint16 acc[kPoolingAccTrancheSize]; for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - const int in_x_origin = - (out_x * stride_width) - params.padding_values.width; - const int in_y_origin = - (out_y * stride_height) - params.padding_values.height; - const int filter_x_start = std::max(0, -in_x_origin); - const int filter_x_end = - std::min(params.filter_width, input_width - in_x_origin); - const int filter_y_start = std::max(0, -in_y_origin); - const int filter_y_end = - std::min(params.filter_height, input_height - in_y_origin); - const int filter_count = - (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); - // 2560 is required by MobileNetV2 with depth multiplier 2. - static constexpr int kAccBufferMaxSize = 4096; - TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); - uint16 acc[kAccBufferMaxSize]; - memset(acc, 0, depth * sizeof(acc[0])); - const uint8* input_ptr = - input_data + - depth * (in_x_origin + - input_width * (in_y_origin + input_height * batch)); - for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = - input_ptr + depth * (fy * input_width + filter_x_start); - for (int fx = filter_x_start; fx < filter_x_end; fx++) { - int channel = 0; + // We proceed through the depth in tranches (see comment above). The + // depth_base is the depth at the beginning of the tranche. The + // tranche_depth is the depth dimension of the tranche. 
+ for (int depth_base = 0; depth_base < depth; + depth_base += kPoolingAccTrancheSize) { + const int tranche_depth = + std::min(depth - depth_base, kPoolingAccTrancheSize); + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(params.filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(params.filter_height, input_height - in_y_origin); + const int filter_count = + (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); + memset(acc, 0, tranche_depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + depth_base + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + const uint8* input_channel_ptr = input_row_ptr; + int channel = 0; #ifdef USE_NEON - for (; channel <= depth - 16; channel += 16) { - uint16x8_t acc_reg[2]; - for (int i = 0; i < 2; i++) { - acc_reg[i] = vld1q_u16(acc + channel + 8 * i); + for (; channel <= tranche_depth - 16; channel += 16) { + uint16x8_t acc_reg[2]; + for (int i = 0; i < 2; i++) { + acc_reg[i] = vld1q_u16(acc + channel + 8 * i); + } + uint8x16_t input_reg = vld1q_u8(input_channel_ptr); + input_channel_ptr += 16; + acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg)); + acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg)); + for (int i = 0; i < 2; i++) { + vst1q_u16(acc + channel + 8 * i, acc_reg[i]); + } } - uint8x16_t input_reg = vld1q_u8(input_row_ptr); - input_row_ptr += 16; - 
acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg)); - acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg)); - for (int i = 0; i < 2; i++) { - vst1q_u16(acc + channel + 8 * i, acc_reg[i]); + for (; channel <= tranche_depth - 8; channel += 8) { + uint16x8_t acc_reg = vld1q_u16(acc + channel); + uint8x8_t input_reg = vld1_u8(input_channel_ptr); + input_channel_ptr += 8; + acc_reg = vaddw_u8(acc_reg, input_reg); + vst1q_u16(acc + channel, acc_reg); } - } - for (; channel <= depth - 8; channel += 8) { - uint16x8_t acc_reg = vld1q_u16(acc + channel); - uint8x8_t input_reg = vld1_u8(input_row_ptr); - input_row_ptr += 8; - acc_reg = vaddw_u8(acc_reg, input_reg); - vst1q_u16(acc + channel, acc_reg); - } #endif - for (; channel < depth; ++channel) { - acc[channel] += *input_row_ptr++; + for (; channel < tranche_depth; ++channel) { + acc[channel] += *input_channel_ptr++; + } + input_row_ptr += depth; } } - } - uint8* output_ptr = - output_data + Offset(output_shape, batch, out_y, out_x, 0); - int channel = 0; + uint8* output_ptr = output_data + Offset(output_shape, batch, out_y, + out_x, depth_base); + int channel = 0; #ifdef USE_NEON #define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ if (filter_count == FILTER_COUNT) { \ - for (; channel <= depth - 8; channel += 8) { \ + for (; channel <= tranche_depth - 8; channel += 8) { \ uint16 buf[8]; \ for (int i = 0; i < 8; i++) { \ buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ @@ -3794,25 +4128,26 @@ inline void AveragePool(const PoolParams& params, vst1_u8(output_ptr + channel, buf8); \ } \ } - AVGPOOL_DIVIDING_BY(9) - AVGPOOL_DIVIDING_BY(15) + AVGPOOL_DIVIDING_BY(9) + AVGPOOL_DIVIDING_BY(15) #undef AVGPOOL_DIVIDING_BY - for (; channel <= depth - 8; channel += 8) { - uint16 buf[8]; - for (int i = 0; i < 8; i++) { - buf[i] = (acc[channel + i] + filter_count / 2) / filter_count; + for (; channel <= tranche_depth - 8; channel += 8) { + uint16 buf[8]; + for (int i = 0; i < 8; i++) { + buf[i] = (acc[channel + i] 
+ filter_count / 2) / filter_count; + } + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); + buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); + buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); + vst1_u8(output_ptr + channel, buf8); } - uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); - buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); - buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); - vst1_u8(output_ptr + channel, buf8); - } #endif - for (; channel < depth; ++channel) { - uint16 a = (acc[channel] + filter_count / 2) / filter_count; - a = std::max(a, params.quantized_activation_min); - a = std::min(a, params.quantized_activation_max); - output_ptr[channel] = static_cast(a); + for (; channel < tranche_depth; ++channel) { + uint16 a = (acc[channel] + filter_count / 2) / filter_count; + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); + output_ptr[channel] = static_cast(a); + } } } } @@ -3877,6 +4212,14 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, const uint8* input_data, const RuntimeShape& output_shape, uint8* output_data) { gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); + + // Here, and in other pooling ops, in order to maintain locality of reference, + // to minimize some recalculations, and to load into NEON vector registers, we + // use an inner loop down the depth. Since depths can be large and hence we + // would need arbitrarily large temporary storage, we divide the work up into + // depth tranches just within the batch loop. 
+ static constexpr int kPoolingAccTrancheSize = 256; + TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); @@ -3889,77 +4232,85 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, const int output_width = output_shape.Dims(2); const int stride_height = params.stride_height; const int stride_width = params.stride_width; + + uint8 acc[kPoolingAccTrancheSize]; for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - const int in_x_origin = - (out_x * stride_width) - params.padding_values.width; - const int in_y_origin = - (out_y * stride_height) - params.padding_values.height; - const int filter_x_start = std::max(0, -in_x_origin); - const int filter_x_end = - std::min(params.filter_width, input_width - in_x_origin); - const int filter_y_start = std::max(0, -in_y_origin); - const int filter_y_end = - std::min(params.filter_height, input_height - in_y_origin); - // 2560 is required by MobileNetV2 with depth multiplier 2. - static constexpr int kAccBufferMaxSize = 4096; - TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); - uint8 acc[kAccBufferMaxSize]; - memset(acc, 0, depth * sizeof(acc[0])); - const uint8* input_ptr = - input_data + - depth * (in_x_origin + - input_width * (in_y_origin + input_height * batch)); - for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = - input_ptr + depth * (fy * input_width + filter_x_start); - for (int fx = filter_x_start; fx < filter_x_end; fx++) { - int channel = 0; + // We proceed through the depth in tranches (see comment above). The + // depth_base is the depth at the beginning of the tranche. The + // tranche_depth is the depth dimension of the tranche. 
+ for (int depth_base = 0; depth_base < depth; + depth_base += kPoolingAccTrancheSize) { + const int tranche_depth = + std::min(depth - depth_base, kPoolingAccTrancheSize); + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(params.filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(params.filter_height, input_height - in_y_origin); + memset(acc, 0, tranche_depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + depth_base + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + const uint8* input_channel_ptr = input_row_ptr; + int channel = 0; #ifdef USE_NEON - for (; channel <= depth - 16; channel += 16) { - uint8x16_t acc_reg = vld1q_u8(acc + channel); - uint8x16_t input_reg = vld1q_u8(input_row_ptr); - input_row_ptr += 16; - acc_reg = vmaxq_u8(acc_reg, input_reg); - vst1q_u8(acc + channel, acc_reg); - } + for (; channel <= tranche_depth - 16; channel += 16) { + uint8x16_t acc_reg = vld1q_u8(acc + channel); + uint8x16_t input_reg = vld1q_u8(input_channel_ptr); + input_channel_ptr += 16; + acc_reg = vmaxq_u8(acc_reg, input_reg); + vst1q_u8(acc + channel, acc_reg); + } - for (; channel <= depth - 8; channel += 8) { - uint8x8_t acc_reg = vld1_u8(acc + channel); - uint8x8_t input_reg = vld1_u8(input_row_ptr); - input_row_ptr += 8; - acc_reg = vmax_u8(acc_reg, input_reg); - vst1_u8(acc + channel, acc_reg); - } + for (; channel <= 
tranche_depth - 8; channel += 8) { + uint8x8_t acc_reg = vld1_u8(acc + channel); + uint8x8_t input_reg = vld1_u8(input_channel_ptr); + input_channel_ptr += 8; + acc_reg = vmax_u8(acc_reg, input_reg); + vst1_u8(acc + channel, acc_reg); + } #endif - for (; channel < depth; ++channel) { - acc[channel] = std::max(acc[channel], *input_row_ptr++); + for (; channel < tranche_depth; ++channel) { + acc[channel] = std::max(acc[channel], *input_channel_ptr++); + } + input_row_ptr += depth; } } - } - uint8* output_ptr = - output_data + Offset(output_shape, batch, out_y, out_x, 0); - int channel = 0; + uint8* output_ptr = output_data + Offset(output_shape, batch, out_y, + out_x, depth_base); + int channel = 0; #ifdef USE_NEON - for (; channel <= depth - 16; channel += 16) { - uint8x16_t a = vld1q_u8(acc + channel); - a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max)); - a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min)); - vst1q_u8(output_ptr + channel, a); - } - for (; channel <= depth - 8; channel += 8) { - uint8x8_t a = vld1_u8(acc + channel); - a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max)); - a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min)); - vst1_u8(output_ptr + channel, a); - } + for (; channel <= tranche_depth - 16; channel += 16) { + uint8x16_t a = vld1q_u8(acc + channel); + a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max)); + a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min)); + vst1q_u8(output_ptr + channel, a); + } + for (; channel <= tranche_depth - 8; channel += 8) { + uint8x8_t a = vld1_u8(acc + channel); + a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max)); + a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min)); + vst1_u8(output_ptr + channel, a); + } #endif - for (; channel < depth; ++channel) { - uint8 a = acc[channel]; - a = std::max(a, params.quantized_activation_min); - a = std::min(a, params.quantized_activation_max); - output_ptr[channel] = static_cast(a); + for (; channel < 
tranche_depth; ++channel) { + uint8 a = acc[channel]; + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); + output_ptr[channel] = static_cast(a); + } } } } @@ -4345,119 +4696,6 @@ inline void LogSoftmax(const SoftmaxParams& params, } } -template -inline gemmlowp::FixedPoint -log_x_for_x_greater_than_or_equal_to_1_impl( - gemmlowp::FixedPoint input_val) { - // assert(__builtin_clz(0u) >= std::numeric_limits::digits - 1); - // assert(__builtin_clz(0u) <= std::numeric_limits::digits); - using FixedPoint0 = gemmlowp::FixedPoint; - // The reason for accumulating the result with an extra bit of headroom is - // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * - // recip_denom will otherwise introduce an error. - static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; - using FixedPointAccum = gemmlowp::FixedPoint; - - const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1488522236, std::log(2.0)); - const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); - const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1518500250, std::sqrt(0.5)); - const FixedPoint0 one_quarter = - GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); - - const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1057819769, - 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); - - const FixedPointAccum shifted_quarter = - 
gemmlowp::Rescale(one_quarter); - - // Reinterpret the input value as Q0.31, because we will figure out the - // required shift "ourselves" instead of using, say, Rescale. - FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); - // z_a_pow_2 = input_integer_bits - z_a_headroom; - int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); - FixedPoint0 r_a_tmp = - SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); - const int32 r_a_raw = - SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); - // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); - // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, - // InputIntegerBits - z_b_headroom - 0.25); - const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( - FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( - InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), - shifted_quarter); - - // z_b is treated like z_a, but premultiplying by sqrt(0.5). 
- FixedPoint0 z_b = z_a * sqrt_half; - int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; - const int32 r_b_raw = - SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); - const FixedPointAccum z_b_pow_2_adj = SaturatingSub( - FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( - InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), - shifted_quarter); - - const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); - const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( - std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); - - const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); - FixedPoint0 q = r - sqrt_sqrt_half; - q = q + q; - - const FixedPoint0 common_sq = q * q; - const FixedPoint0 num = q * r + q * common_sq * alpha_n; - const FixedPoint0 denom_minus_one_0 = - p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; - const FixedPoint0 recip_denom = - one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); - - const FixedPointAccum num_scaled = gemmlowp::Rescale(num); - return gemmlowp::Rescale(z_pow_2_adj * log_2 + - num_scaled * recip_denom); -} - -// Minimum output bits to accommodate log of maximum input range. It actually -// does not matter if one considers, say, [-64,64] or [-64,64). -// -// For example, run this through Octave: -// [0:127; ... -// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... -// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] -constexpr int min_log_x_output_bits(int input_bits) { - return input_bits > 90 - ? 7 - : input_bits > 44 - ? 6 - : input_bits > 21 - ? 5 - : input_bits > 10 - ? 4 - : input_bits > 4 ? 3 : input_bits > 1 ? 
2 : 1; -} - -template -inline gemmlowp::FixedPoint -log_x_for_x_greater_than_or_equal_to_1( - gemmlowp::FixedPoint input_val) { - static_assert( - OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), - "Output integer bits must be sufficent to accommodate logs of inputs."); - return log_x_for_x_greater_than_or_equal_to_1_impl( - input_val); -} - // Currently just a copy of the reference code. inline void LogSoftmax(const SoftmaxParams& params, const RuntimeShape& input_shape, const uint8* input_data, @@ -5052,6 +5290,14 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data, output_map.array() = Eigen::floor(input_map.array()); } +inline void Ceil(const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + gemmlowp::ScopedProfilingLabel label("Ceil"); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); + output_map.array() = Eigen::ceil(input_map.array()); +} + #ifdef USE_NEON inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, float scale, float* output_ptr) { diff --git a/tensorflow/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/lite/kernels/internal/optimized/tensor_utils_impl.h index 8f52ef131dedf4d0270c0346b1094add57f52dfc..00b2d7e063254e2941fd3453f15dbaf2dbd4451e 100644 --- a/tensorflow/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -54,6 +54,25 @@ void NeonMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ vectors, const float* scaling_factors, int n_batch, float* __restrict__ result, int result_stride); +void PortableSparseMatrixBatchVectorMultiplyAccumulate( + const float* matrix, const uint8_t* ledger, int m_rows, int m_cols, + const float* vector, int n_batch, float* result, int result_stride); +void NeonSparseMatrixBatchVectorMultiplyAccumulate( + const float* matrix, const 
uint8_t* ledger, int m_rows, int m_cols, + const float* vector, int n_batch, float* result, int result_stride); + +// Matrix multiplication for quantized values using symmetric quantization. +void PortableSparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result, + int result_stride); +void NeonSparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result, + int result_stride); + // Cwise product of two vectors. void PortableVectorVectorCwiseProduct(const float* vector1, const float* vector2, int v_size, diff --git a/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h index 002444b6810925910a651dd5c919a46ac8e5fb47..c38f37416dde30cf16a41d6cc6f08dc40f3dfe7d 100644 --- a/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -23,90 +23,173 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { + +// Used in tests and template parameters to control which version of depthwise +// convolution is called. Primarily for reference code, and specializations +// forced in tests. +enum class DepthwiseConvImplementation { + // Run all tests against kUseStandardEntry even if also testing another + // kernel, since we need to be sure that the main DepthwiseConv() function in + // optimized_ops.h dispatches to a correctly-executing kernel. + kNone = 0, // The "default" option: use the normal + // DepthwiseConv kernel (entry) function. + kUseGenericKernel, // Forced use of generic kernel. 
+ kUseNeon3x3, // 3x3 kernel that uses NEON when available. + kUseNeon3x3DotProduct, // 3x3 kernel that uses dot-product enabled NEON + // when available. + kUseCModel3x3DotProduct, // 3x3 kernel, reference C model that is intended + // to match overall design NEON code. + kUseUnwound3x3DotProduct, // 3x3 kernel, reference C model with unwound loops + // and some arrays. + kUseIntrinsics3x3DotProduct, // 3x3 kernel using NEON intrinsics. +}; + +// Category of depthwise convolution output rounding. +enum class DepthwiseConvOutputRounding { + kNone = 0, // Invalid: specific method must be specified. + kAwayFromZero, // Original method: exact halves rounded away from zero. + kUpward, // Halves towards +infinity: adds 0.5 before truncate. + // This is where a future kNearestEven would be placed. +}; + +// Category of depthwise convolution depth multiplication. +enum class DepthwiseConvDepthMultiplication { + kNoMultiplication = 0, // Depth multiplier = 1. + kUnitInputDepth, // Input depth = 1, output depth = depth multiplier. 
+}; + namespace reference_ops { +namespace depthwise_conv { -inline void DepthwiseConv( - const DepthwiseParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data) { - const int stride_width = params.stride_width; - const int stride_height = params.stride_height; - const int dilation_width_factor = params.dilation_width_factor; - const int dilation_height_factor = params.dilation_height_factor; - const int pad_width = params.padding_values.width; - const int pad_height = params.padding_values.height; - const int depth_multiplier = params.depth_multiplier; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; - const int output_shift = params.output_shift; - TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); - TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); - - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingDim(input_shape, 0, output_shape, 0); - const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); - const int input_height = input_shape.Dims(1); - const int input_width = input_shape.Dims(2); - const int input_depth = input_shape.Dims(3); - const int filter_height = filter_shape.Dims(1); - const int filter_width = filter_shape.Dims(2); - const int output_height = output_shape.Dims(1); - const int output_width = output_shape.Dims(2); - TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); - TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); - 
- for (int b = 0; b < batches; ++b) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - for (int ic = 0; ic < input_depth; ++ic) { - for (int m = 0; m < depth_multiplier; m++) { - const int oc = m + ic * depth_multiplier; - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; - int32 acc = 0; - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - const int in_x = in_x_origin + dilation_width_factor * filter_x; - const int in_y = - in_y_origin + dilation_height_factor * filter_y; - // If the location is outside the bounds of the input image, - // use zero as a default value. - if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && - (in_y < input_height)) { - int32 input_val = - input_data[Offset(input_shape, b, in_y, in_x, ic)]; - int32 filter_val = filter_data[Offset( - filter_shape, 0, filter_y, filter_x, oc)]; - acc += - (filter_val + filter_offset) * (input_val + input_offset); +template +inline int32 DepthwiseConvRound(int32 x, int32 quantized_multiplier, + int shift) { + TFLITE_DCHECK_NE(output_rounding, DepthwiseConvOutputRounding::kNone); + return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift); +} + +template <> +inline int32 DepthwiseConvRound( + int32 x, int32 quantized_multiplier, int shift) { + return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift); +} + +template <> +inline int32 DepthwiseConvRound( + int32 x, int32 quantized_multiplier, int shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + const int left_shift = shift > 0 ? shift : 0; + const int right_shift = shift > 0 ? 0 : -shift; + const int rounding_offset = right_shift > 0 ? 
1 << (right_shift - 1) : 0; + return (SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier) + + rounding_offset) >> + right_shift; +} + +template +struct DepthwiseConvBasicKernel { + static inline void Run(const DepthwiseParams& params, + const RuntimeShape& input_shape, + const uint8* input_data, + const RuntimeShape& filter_shape, + const uint8* filter_data, + const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& output_shape, uint8* output_data) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + const int depth_multiplier = params.depth_multiplier; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = input_shape.Dims(3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = 
output_shape.Dims(2); + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = + in_x_origin + dilation_width_factor * filter_x; + const int in_y = + in_y_origin + dilation_height_factor * filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + int32 input_val = + input_data[Offset(input_shape, b, in_y, in_x, ic)]; + int32 filter_val = filter_data[Offset( + filter_shape, 0, filter_y, filter_x, oc)]; + acc += (filter_val + filter_offset) * + (input_val + input_offset); + } } } + if (bias_data) { + acc += bias_data[oc]; + } + acc = DepthwiseConvRound(acc, output_multiplier, + output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_shape, b, out_y, out_x, oc)] = + static_cast(acc); } - if (bias_data) { - acc += bias_data[oc]; - } - acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, - output_shift); - acc += output_offset; - acc = std::max(acc, output_activation_min); - acc = std::min(acc, output_activation_max); - output_data[Offset(output_shape, b, out_y, out_x, oc)] = - static_cast(acc); } } } } } +}; + +} // namespace depthwise_conv + +inline void DepthwiseConv( + const DepthwiseParams& 
params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data) { + return depthwise_conv::DepthwiseConvBasicKernel< + DepthwiseConvOutputRounding::kAwayFromZero>::Run(params, input_shape, + input_data, filter_shape, + filter_data, bias_shape, + bias_data, output_shape, + output_data); } -} // end namespace reference_ops +} // namespace reference_ops } // end namespace tflite #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/add.h b/tensorflow/lite/kernels/internal/reference/integer_ops/add.h new file mode 100644 index 0000000000000000000000000000000000000000..a694ba2aaa993b0631958e0b338a7a62e154de75 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/add.h @@ -0,0 +1,144 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_ADD_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_ADD_H_ + +#include +#include "public/gemmlowp.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_integer_ops { + +// Element-wise add that can often be used for inner loop of broadcast add as +// well as the non-broadcast add. +inline void AddElementwise(int size, const ArithmeticParams& params, + const int8_t* input1_data, const int8_t* input2_data, + int8_t* output_data) { + const int32_t int8_max_value = std::numeric_limits::max(); + TFLITE_DCHECK_GE(params.input1_offset, -1 * int8_max_value); + TFLITE_DCHECK_GE(params.input2_offset, -1 * int8_max_value); + TFLITE_DCHECK_LE(params.input1_offset, int8_max_value); + TFLITE_DCHECK_LE(params.input2_offset, int8_max_value); + + for (int i = 0; i < size; ++i) { + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = input2_val * (1 << params.left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, params.input1_multiplier, params.input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, params.input2_multiplier, params.input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, params.output_multiplier, params.output_shift) + + params.output_offset; + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[i] 
= static_cast(clamped_output); + } +} + +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int8_t* input1_data, + const RuntimeShape& input2_shape, const int8_t* input2_data, + const RuntimeShape& output_shape, int8_t* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + const int32_t int8_max_value = std::numeric_limits::max(); + TFLITE_DCHECK_GE(params.input1_offset, -1 * int8_max_value); + TFLITE_DCHECK_GE(params.input2_offset, -1 * int8_max_value); + TFLITE_DCHECK_LE(params.input1_offset, int8_max_value); + TFLITE_DCHECK_LE(params.input2_offset, int8_max_value); + AddElementwise(flat_size, params, input1_data, input2_data, output_data); +} + +inline void BroadcastAdd4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int8_t* input1_data, + const RuntimeShape& input2_shape, + const int8_t* input2_data, + const RuntimeShape& output_shape, + int8_t* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int8"); + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + const RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. 
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + const int32_t input1_val = + params.input1_offset + + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; + const int32_t input2_val = + params.input2_offset + + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + const int32_t shifted_input1_val = + input1_val * (1 << params.left_shift); + const int32_t shifted_input2_val = + input2_val * (1 << params.left_shift); + const int32_t scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, params.input1_multiplier, + params.input1_shift); + const int32_t scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, params.input2_multiplier, + params.input2_shift); + const int32_t raw_sum = scaled_input1_val + scaled_input2_val; + const int32_t raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, params.output_multiplier, params.output_shift) + + params.output_offset; + const int32_t clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[Offset(extended_output_shape, b, y, x, c)] = + static_cast(clamped_output); + } + } + } + } +} + +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_ADD_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h b/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h new file mode 100644 index 0000000000000000000000000000000000000000..4ad2a70b31c23cb9e316b88d7b53f6c065675e9f --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h @@ -0,0 +1,128 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_CONV_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_CONV_H_ + +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { + +// Fixed-point per-channel-quantization convolution reference kernel. +inline void ConvPerChannel( + const ConvParams& params, const int32* output_multiplier, + const int32* output_shift, const RuntimeShape& input_shape, + const int8* input_data, const RuntimeShape& filter_shape, + const int8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + int8* output_data) { + // Get parameters. + const int32 input_offset = params.input_offset; // r = s(q - Z) + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + const int32 output_offset = params.output_offset; + + // Set min and max value of the output. + const int32 output_activation_min = std::numeric_limits::min(); + const int32 output_activation_max = std::numeric_limits::max(); + + // Sanity check. 
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + if (bias_data) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } + + // Check dimensions of the tensors. + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + dilation_width_factor * filter_x; + const int in_y = + in_y_origin + dilation_height_factor * filter_y; + // Zero padding by omitting the areas outside the image. 
+ const bool is_point_inside_image = + (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height); + if (is_point_inside_image) { + int32 input_val = input_data[Offset(input_shape, batch, in_y, + in_x, in_channel)]; + int32 filter_val = + filter_data[Offset(filter_shape, out_channel, filter_y, + filter_x, in_channel)]; + // Accumulate with 32 bits accumulator. + // In the nudging process during model quantization, we force + // real value of 0.0 be represented by a quantized value. This + // guarantees that the input_offset is a int8, even though it + // is represented using int32. + // int32 += int8 * (int8 - int8) so the highest value we can + // get from each accumulation is [-127, 127] * ([-128, 127] - + // [-128, 127]), which is [-32512, 32512]. log2(32512) + // = 14.98, which means we can accumulate at least 2^16 + // multiplications without overflow. The accumulator is + // applied to a filter so the accumulation logic will hold as + // long as the filter size (filter_y * filter_x * in_channel) + // does not exceed 2^16, which is the case in all the models + // we have seen so far. + // TODO(jianlijianli): Add a check to make sure the + // accumulator depth is smaller than 2^16. 
+ acc += filter_val * (input_val - input_offset); + } + } + } + } + + if (bias_data) { + acc += bias_data[out_channel]; + } + acc = MultiplyByQuantizedMultiplier( + acc, output_multiplier[out_channel], output_shift[out_channel]); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] = + static_cast(acc); + } + } + } + } +} + +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_CONV_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h b/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..b6195c0da93a47c3034a50a167761ec8b5850c44 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h @@ -0,0 +1,125 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_ + +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { +inline void DepthwiseConvPerChannel( + const DepthwiseParams& params, const int32* output_multiplier, + const int32* output_shift, const RuntimeShape& input_shape, + const int8* input_data, const RuntimeShape& filter_shape, + const int8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + int8* output_data) { + // Get parameters. + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + const int depth_multiplier = params.depth_multiplier; + const int32 input_offset = params.input_offset; + const int32 output_offset = params.output_offset; + + // Set min and max value of the output. + const int32 output_activation_min = std::numeric_limits::min(); + const int32 output_activation_max = std::numeric_limits::max(); + + // Check dimensions of the tensors. 
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = input_shape.Dims(3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + for (int m = 0; m < depth_multiplier; ++m) { + const int output_channel = m + in_channel * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + dilation_width_factor * filter_x; + const int in_y = + in_y_origin + dilation_height_factor * filter_y; + // Zero padding by omitting the areas outside the image. 
+ const bool is_point_inside_image = + (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height); + if (is_point_inside_image) { + int32 input_val = input_data[Offset(input_shape, batch, in_y, + in_x, in_channel)]; + int32 filter_val = filter_data[Offset( + filter_shape, 0, filter_y, filter_x, output_channel)]; + // Accumulate with 32 bits accumulator. + // In the nudging process during model quantization, we force + // real value of 0.0 be represented by a quantized value. This + // guarentees that the input_offset is a int8, even though it + // is represented using int32. + // int32 += int8 * (int8 - int8) so the highest value we can + // get from each accumulation is [-127, 127] * ([-128, 127] - + // [-128, 127]), which is [-32512, 32512]. log2(32512) + // = 14.98, which means we can accumulate at least 2^16 + // multiplications without overflow. The accumulator is + // applied to a filter so the accumulation logic will hold as + // long as the filter size (filter_y * filter_x * in_channel) + // does not exceed 2^16, which is the case in all the models + // we have seen so far. + // TODO(jianlijianli): Add a check to make sure the + // accumulator depth is smaller than 2^16. 
+ acc += filter_val * (input_val - input_offset); + } + } + } + if (bias_data) { + acc += bias_data[output_channel]; + } + acc = MultiplyByQuantizedMultiplier( + acc, output_multiplier[output_channel], + output_shift[output_channel]); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, + output_channel)] = static_cast(acc); + } + } + } + } + } +} +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h new file mode 100644 index 0000000000000000000000000000000000000000..36b349f4d49bb66d6f60ed92629945640e24cb9c --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h @@ -0,0 +1,69 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_FULLY_CONNECTED_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_FULLY_CONNECTED_H_ + +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { + +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + int8_t* output_data, void* gemm_context) { + (void)gemm_context; // only used in optimized code. + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = output_shape.Dims(0); + const int output_depth = output_shape.Dims(1); + TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2)); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + int32 acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32 input_val = input_data[b * accum_depth + d]; + int32 filter_val = filter_data[out_c * accum_depth + d]; + acc += (filter_val + filter_offset) * (input_val + input_offset); + } + if (bias_data) { + acc += 
bias_data[out_c]; + } + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[out_c + output_depth * b] = static_cast(acc); + } + } +} + +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_FULLY_CONNECTED_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h b/tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..8e7c7f317602d78a661500049eab736207aafcd9 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h @@ -0,0 +1,65 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_ + +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { + +inline void L2Normalization(int32_t input_zero_point, int32_t outer_size, + int32_t depth, const int8* input_data, + int8* output_data) { + static constexpr int8_t kMinInt8 = std::numeric_limits::min(); + static constexpr int8_t kMaxInt8 = std::numeric_limits::max(); + // The output scale must be in sync with Prepare(). + // Output is in 1/128 scale so the actual output range is nudged from [-1, 1] + // to [-1, 127/128]. + static constexpr int32_t kOutputScale = 7; + for (int outer_index = 0; outer_index < outer_size; ++outer_index) { + // int32 = (int8 - int8) ^ 2. + // ([-128, 127] - [-128, 127]) ^ 2 = [0, (2^8 - 1)^2] so the accumulator is + // safe from overflowing in at least 2^16 steps. + int32_t acc = 0; + for (int inner_index = 0; inner_index < depth; ++inner_index) { + int32_t input = + input_data[depth * outer_index + inner_index] - input_zero_point; + acc += input * input; + } + int32_t inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplierExp(acc, /*reverse_shift*/ -1, + &inv_l2norm_multiplier, &inv_l2norm_shift); + + for (int inner_index = 0; inner_index < depth; ++inner_index) { + int32_t input = + input_data[depth * outer_index + inner_index] - input_zero_point; + + // Rescale and downcast. Rescale is folded into the division. 
+ int32_t output_in_q24 = MultiplyByQuantizedMultiplier( + input, inv_l2norm_multiplier, inv_l2norm_shift + kOutputScale); + output_in_q24 = + std::min(static_cast(kMaxInt8), + std::max(static_cast(kMinInt8), output_in_q24)); + output_data[depth * outer_index + inner_index] = + static_cast(output_in_q24); + } + } +} +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_L2NORMALIZATION_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h b/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..f22bb4f13803cf4e14c8b4fd18b9c301fab07359 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h @@ -0,0 +1,111 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_ + +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { + +inline void LogSoftmax(int32_t input_multiplier, int32_t input_shift, + int32_t reverse_multiplier, int32_t reverse_shift, + int32_t diff_min, int32_t outer_size, int32_t depth, + const int8* input_data, int8* output_data) { + static constexpr int8_t kMinInt8 = std::numeric_limits::min(); + static constexpr int8_t kMaxInt8 = std::numeric_limits::max(); + static constexpr int32_t kMinInt32 = std::numeric_limits::min(); + + // [-16, 0] is mapped to [-128, 127] with 1/16 as scale and 127 as zero + // point. This nudges the output to [-255/16, 0]. + static constexpr int32_t kOutputZeroPoint = 127; + + // All IntegerBits must agree with Prepare function. + // Input is chosen as Q5.26 so exp(-1 * 2^5 * 2^-1) = exp(-16) is negligible. + static constexpr int kInputIntegerBits = 5; + static constexpr int kAccumulationIntegerBits = 12; + static constexpr int kOutputIntegerBits = 4; + using F5 = gemmlowp::FixedPoint; + using F12 = gemmlowp::FixedPoint; + + for (int outer_index = 0; outer_index < outer_size; ++outer_index) { + int8 max_in_row = kMinInt8; + for (int inner_index = 0; inner_index < depth; ++inner_index) { + max_in_row = + std::max(max_in_row, input_data[outer_index * depth + inner_index]); + } + + // Accumulator "sum_of_exps_in_q12" is safe from overflowing in 2^12 steps. 
+ F12 sum_of_exps_in_q12 = F12::FromRaw(0); + for (int inner_index = 0; inner_index < depth; ++inner_index) { + int32_t input_diff = + static_cast(input_data[outer_index * depth + inner_index]) - + max_in_row; + if (input_diff >= diff_min) { + const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier( + input_diff, input_multiplier, input_shift); + sum_of_exps_in_q12 = + sum_of_exps_in_q12 + + gemmlowp::Rescale( + exp_on_negative_values(F5::FromRaw(input_diff_in_q5))); + } + } + + const int32_t log_sum_of_exps_in_q5 = + log_x_for_x_greater_than_or_equal_to_1( + sum_of_exps_in_q12) + .raw(); + + // Potentially reduced the valid range. shifted_log_sum_of_exps_in_q5 is + // smallest representable in Q5.26 plus the log_sum_of_exps. + const int32_t shifted_log_sum_of_exps_in_q5 = + log_sum_of_exps_in_q5 + kMinInt32; + const int32_t adjusted_diff_min = std::max( + diff_min - 1, + MultiplyByQuantizedMultiplier(shifted_log_sum_of_exps_in_q5, + reverse_multiplier, -reverse_shift)); + + for (int inner_index = 0; inner_index < depth; ++inner_index) { + int32_t input_diff = + static_cast(input_data[outer_index * depth + inner_index]) - + max_in_row; + // Note use of > below instead of >= above. + if (input_diff > adjusted_diff_min) { + const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier( + input_diff, input_multiplier, input_shift); + + // Rescale and downcast. 
+ int32_t output_in_q27 = + gemmlowp::RoundingDivideByPOT( + (input_diff_in_q5 - log_sum_of_exps_in_q5), + 31 - kInputIntegerBits - kOutputIntegerBits) + + kOutputZeroPoint; + + output_in_q27 = + std::max(std::min(output_in_q27, static_cast(kMaxInt8)), + static_cast(kMinInt8)); + output_data[outer_index * depth + inner_index] = + static_cast(output_in_q27); + } else { + output_data[outer_index * depth + inner_index] = kMinInt8; + } + } + } +} + +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h b/tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h new file mode 100644 index 0000000000000000000000000000000000000000..8277c3b3d565d845da4cc8931a4256c1005db77c --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h @@ -0,0 +1,64 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOGISTIC_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOGISTIC_H_ + +#include +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { + +inline void Logistic(int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int32_t input_left_shift, + int32_t input_size, const int8_t* input_data, + int8_t* output_data) { + // Integer bits must be in sync with Prepare() function. + static constexpr int32_t kInputIntegerBits = 4; + static constexpr int32_t kOutputIntegerBits = 8; + static constexpr int8_t kMinInt8 = std::numeric_limits::min(); + static constexpr int8_t kMaxInt8 = std::numeric_limits::max(); + static constexpr int32_t kOutputZeroPoint = -128; + + for (int i = 0; i < input_size; ++i) { + const int32_t input = + static_cast(input_data[i]) - input_zero_point; + if (input <= -input_range_radius) { + output_data[i] = kMinInt8; + } else if (input >= input_range_radius) { + output_data[i] = kMaxInt8; + } else { + const int32_t input_in_q4 = MultiplyByQuantizedMultiplier( + input, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + const int32_t output_in_q0 = + gemmlowp::logistic(FixedPoint4::FromRaw(input_in_q4)).raw(); + + // Rescale and downcast. 
+ using gemmlowp::RoundingDivideByPOT; + int32_t output_in_q23 = + RoundingDivideByPOT(output_in_q0, 31 - kOutputIntegerBits); + output_in_q23 = std::min(std::max(output_in_q23 + kOutputZeroPoint, + static_cast(kMinInt8)), + static_cast(kMaxInt8)); + output_data[i] = static_cast(output_in_q23); + } + } +} + +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOGISTIC_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h new file mode 100644 index 0000000000000000000000000000000000000000..72885d1b467ba57ec163876b1a1f271d879d36c7 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h @@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_ + +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { + +inline void Mean(const tflite::MeanParams& op_params, int32_t multiplier, + int32_t shift, const RuntimeShape& unextended_input_shape, + const int8_t* input_data, int32 input_zero_point, + const RuntimeShape& unextended_output_shape, + int8_t* output_data, int32 output_zero_point) { + // Current implementation only supports dimension equals 4 and simultaneous + // reduction over width and height. + TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4); + TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4); + const RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + const RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + const int output_batch = output_shape.Dims(0); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int output_depth = output_shape.Dims(3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int num_elements_in_axis = input_width * input_height; + + TFLITE_DCHECK_EQ(op_params.axis_count, 2); + TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + static constexpr int32_t kMinInt8 = std::numeric_limits::min(); + static constexpr int32_t kMaxInt8 = std::numeric_limits::max(); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + int32 acc = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; 
in_w < input_width; ++in_w) { + acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)] - + input_zero_point; + } + } + acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift); + acc = acc > 0 ? (acc + num_elements_in_axis / 2) / num_elements_in_axis + : (acc - num_elements_in_axis / 2) / num_elements_in_axis; + acc += output_zero_point; + acc = std::min(std::max(acc, kMinInt8), kMaxInt8); + output_data[Offset(output_shape, out_b, 0, 0, out_d)] = + static_cast(acc); + } + } +} + +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h new file mode 100644 index 0000000000000000000000000000000000000000..5e33d089945a2907e489c51c117eec77b194ed7e --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h @@ -0,0 +1,130 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_ + +#include "public/gemmlowp.h" +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { + +inline void MulElementwise(int size, const ArithmeticParams& params, + const int8_t* input1_data, const int8_t* input2_data, + int8_t* output_data) { + for (int i = 0; i < size; ++i) { + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val, + params.output_multiplier, + params.output_shift); + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[i] = static_cast(clamped_output); + } +} + +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int8_t* input1_data, + const RuntimeShape& input2_shape, const int8_t* input2_data, + const RuntimeShape& output_shape, int8_t* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + gemmlowp::ScopedProfilingLabel label("Mul/8bit"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + MulElementwise(flat_size, params, input1_data, input2_data, output_data); +} + +// Mul with 16 bit inputs and int8_t outputs. 
+inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16* input1_data, + const RuntimeShape& input2_shape, const int16* input2_data, + const RuntimeShape& output_shape, int8_t* output_data) { + gemmlowp::ScopedProfilingLabel label("Mul/Int16Int8"); + int32 output_offset = params.output_offset; + int32 output_activation_min = params.quantized_activation_min; + int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + for (int i = 0; i < flat_size; i++) { + // F0 uses 0 integer bits, range [-1, 1]. + using F0 = gemmlowp::FixedPoint; + + F0 unclamped_result = + F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]); + int16 rescaled_result = + gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8); + int16 clamped_result = + std::min(output_activation_max - output_offset, rescaled_result); + clamped_result = + std::max(output_activation_min - output_offset, clamped_result); + output_data[i] = output_offset + clamped_result; + } +} + +inline void BroadcastMul4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int8_t* input1_data, + const RuntimeShape& input2_shape, + const int8_t* input2_data, + const RuntimeShape& output_shape, + int8_t* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + // The input shapes are extended as part of NdArrayDesc initialization. 
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + const RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + const int32 input1_val = + params.input1_offset + + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; + const int32 input2_val = + params.input2_offset + + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, params.output_multiplier, + params.output_shift); + const int32 clamped_output = std::min( + params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[Offset(extended_output_shape, b, y, x, c)] = + static_cast(clamped_output); + } + } + } + } +} + +} // namespace reference_integer_ops +} // namespace tflite +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_ diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h b/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h index 22750bc91a856b360459fbf9b5ed0519e4ac6c88..2762bec8e6c3c8d69198456cbd16b04dc45ef2ab 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h @@ -15,6 +15,7 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_POOLING_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_POOLING_H_ +#include #include "tensorflow/lite/kernels/internal/common.h" namespace tflite { @@ -77,6 +78,63 @@ inline void AveragePool(const PoolParams& params, } } +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const int8* input_data, const RuntimeShape& output_shape, + int8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + TFLITE_DCHECK_GE(params.quantized_activation_min, + std::numeric_limits::min()); + TFLITE_DCHECK_LE(params.quantized_activation_max, + std::numeric_limits::max()); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(params.filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(params.filter_height, input_height - in_y_origin); + int8_t max = std::numeric_limits::lowest(); + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); + } + } + max = std::max(max, params.quantized_activation_min); + max = std::min(max, params.quantized_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = + static_cast(max); + } + } + } + } +} + } // namespace reference_integer_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h b/tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h index 3f6bf1cb73e40b2bc396a59f5b47cefaea071d02..892b38630a64e6c56b3216b74e9210b66fc8fd24 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h @@ -85,9 +85,8 @@ inline void Softmax(const SoftmaxParams& params, (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); const int32 shifted_output = unsat_output - 128; - output_data[i * depth + c] = static_cast( - std::max(std::min(shifted_output, static_cast(127)), - static_cast(-128))); + output_data[i * depth + c] = + static_cast(std::max(std::min(shifted_output, 127), -128)); } else { output_data[i * depth + c] = -128; diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h b/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h new file mode 100644 index 
0000000000000000000000000000000000000000..081928bc88d9c59e15b5ed857daf4a144abe2ad7 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TANH_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TANH_H_ + +#include +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_integer_ops { + +inline void Tanh(int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int32_t input_shift, + int32_t input_size, const int8_t* input_data, + int8_t* output_data) { + // Integer bits must be in sync with Prepare() function. 
+ static constexpr int32_t kInputIntegerBits = 4; + static constexpr int32_t kOutputScale = 7; + static constexpr int8_t kMinInt8 = std::numeric_limits::min(); + static constexpr int8_t kMaxInt8 = std::numeric_limits::max(); + using F4 = gemmlowp::FixedPoint; + + for (int i = 0; i < input_size; ++i) { + const int32_t input = + static_cast(input_data[i]) - input_zero_point; + if (input <= -input_range_radius) { + output_data[i] = kMinInt8; + } else if (input >= input_range_radius) { + output_data[i] = kMaxInt8; + } else { + const int32_t input_in_q4 = + MultiplyByQuantizedMultiplier(input, input_multiplier, input_shift); + const int32_t output_in_q0 = + gemmlowp::tanh(F4::FromRaw(input_in_q4)).raw(); + + // Rescale and downcast. + using gemmlowp::RoundingDivideByPOT; + int32_t output_in_q24 = + RoundingDivideByPOT(output_in_q0, 31 - kOutputScale); + output_in_q24 = + std::min(std::max(output_in_q24, static_cast(kMinInt8)), + static_cast(kMaxInt8)); + output_data[i] = static_cast(output_in_q24); + } + } +} + +} // namespace reference_integer_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TANH_H_ diff --git a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h index 380fc8f98ebbdd90bb68144a46903640734bff08..390bf08e30300625471f8fe0bfceac21fc43756d 100644 --- a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -2033,7 +2034,16 @@ template void ArgMax(const T3* axis, const T1* input_data, const tflite::Dims<4>& input_dims, T2* output_data, const tflite::Dims<4>& output_dims) { - ArgMinMax(DimsToShape(input_dims), input_data, axis, DimsToShape(output_dims), + // Assumes the input always has 4 dimensions, and therefore, + // output always has three dimensions. + auto output_shape = RuntimeShape( + {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]}); + // Another way to interpret this is that output_dims.sizes[4] is always 1. + TFLITE_DCHECK_EQ(output_shape.FlatSize(), + DimsToShape(output_dims).FlatSize()); + // Legacy path only supported this. + TFLITE_DCHECK_EQ(axis[0], 3); + ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape, output_data, std::greater()); } diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index 1acf0caad0db8481965fcba0bc1fafb41bd23f47..f5c4b78dc1429f45e477ecc9528e976aeda2ab1f 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/op_macros.h" @@ -109,6 +110,73 @@ void PortableMatrixBatchVectorMultiplyAccumulate( } // for batch } +void PortableSparseMatrixBatchVectorMultiplyAccumulate( + const float* matrix, const uint8_t* ledger, int m_rows, int m_cols, + const float* vector, int n_batch, float* result, int result_stride) { + const int kBlockSize = 16; + TFLITE_DCHECK_EQ( // NOLINT + m_cols % kBlockSize, 0); + float* result_in_batch = result; + for (int b = 0; b < n_batch; b++) { + const float* matrix_ptr = matrix; + const uint8_t* ledger_ptr = ledger; + for (int r = 0; r < m_rows; r++) { + float dot_prod = 0.0f; + int num_nonzero_blocks = *ledger_ptr++; + if (num_nonzero_blocks > 0) { + const float* vector_in_batch = vector + b * m_cols; + for (int i = 0; i < num_nonzero_blocks; i++) { + const int block_start_index = *ledger_ptr++ * kBlockSize; + const float* vector_block_in_batch_ptr = + vector_in_batch + block_start_index; + for (int c = 0; c < kBlockSize; c++) { + dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++; + } + } + } + *result_in_batch += dot_prod; + result_in_batch += result_stride; + } + } +} + +void PortableSparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result, + int result_stride) { + const int kBlockSize = 16; + TFLITE_DCHECK_EQ( // NOLINT + m_cols % kBlockSize, 0); + int batch, row; + for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) { + const float batch_scaling_factor = scaling_factors[batch]; + // Get the address of the first row. 
+ const int8_t* row_ptr = matrix; + const uint8_t* ledger_ptr = ledger; + for (row = 0; row < m_rows; ++row, result += result_stride) { + // Initialize the dot product sum for the row to 0. + int32_t dotprod = 0; +#if defined(__GNUC__) + // Prefetch the row to cache. + __builtin_prefetch(row_ptr, 0 /* prefetch for read */, + 3 /* temporal locality */); +#endif + int num_nonzero_blocks = *ledger_ptr++; + if (num_nonzero_blocks > 0) { + for (int i = 0; i < num_nonzero_blocks; i++) { + const int block_start_index = *ledger_ptr++ * kBlockSize; + const int8_t* vector_block_ptr = vectors + block_start_index; + for (int c = 0; c < kBlockSize; c++) { + dotprod += (*row_ptr++) * (*vector_block_ptr++); + } // for block + } + } + *result += (dotprod * batch_scaling_factor); + } // for row + } // for batch +} + void PortableVectorVectorCwiseProduct(const float* vector1, const float* vector2, int v_size, float* result) { diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h index a06ebc1600d4fe47cf054b4e157bc21a5f70ddfc..49b59da0bbaf7aec6ba1b66b499df8d5426f5951 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h @@ -48,6 +48,16 @@ void PortableMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ vectors, const float* scaling_factors, int n_batch, float* __restrict__ result, int result_stride); +void PortableSparseMatrixBatchVectorMultiplyAccumulate( + const float* matrix, const uint8_t* ledger, int m_rows, int m_cols, + const float* vector, int n_batch, float* result, int result_stride); + +void PortableSparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result, + int result_stride); + // 
Cwise product of two vectors. void PortableVectorVectorCwiseProduct(const float* vector1, const float* vector2, int v_size, @@ -165,6 +175,23 @@ void MatrixBatchVectorMultiplyAccumulate( result_stride); } +void SparseMatrixBatchVectorMultiplyAccumulate( + const float* matrix, const uint8_t* ledger, int m_rows, int m_cols, + const float* vector, int n_batch, float* result, int result_stride) { + PortableSparseMatrixBatchVectorMultiplyAccumulate( + matrix, ledger, m_rows, m_cols, vector, n_batch, result, result_stride); +} + +void SparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result, + int result_stride) { + PortableSparseMatrixBatchVectorMultiplyAccumulate( + matrix, ledger, m_rows, m_cols, vectors, scaling_factors, n_batch, result, + result_stride); +} + void VectorVectorCwiseProduct(const float* vector1, const float* vector2, int v_size, float* result) { PortableVectorVectorCwiseProduct(vector1, vector2, v_size, result); diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index b6a8f3859f0aa36184304ee9d3af32d9f77d6f57..a028ab105807a4043b0b6f1d6f1d43720d5d4116 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -36,68 +36,6 @@ limitations under the License. namespace tflite { -// TODO(b/77858996): Add these to gemmlowp. 
-template -IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) { - static_assert(std::is_same::value, "unimplemented"); - return a; -} - -template <> -inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) { - std::int64_t a64 = a; - std::int64_t b64 = b; - std::int64_t sum = a64 + b64; - return static_cast(std::min( - static_cast(std::numeric_limits::max()), - std::max( - static_cast(std::numeric_limits::min()), - sum))); -} - -template -gemmlowp::FixedPoint SaturatingAddNonGemmlowp( - gemmlowp::FixedPoint a, - gemmlowp::FixedPoint b) { - return gemmlowp::FixedPoint::FromRaw( - SaturatingAddNonGemmlowp(a.raw(), b.raw())); -} - -template -IntegerType SaturatingSub(IntegerType a, IntegerType b) { - static_assert(std::is_same::value, "unimplemented"); - return a; -} - -template <> -inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) { - std::int32_t a32 = a; - std::int32_t b32 = b; - std::int32_t diff = a32 - b32; - return static_cast(std::min(32767, std::max(-32768, diff))); -} - -template <> -inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) { - std::int64_t a64 = a; - std::int64_t b64 = b; - std::int64_t diff = a64 - b64; - return static_cast(std::min( - static_cast(std::numeric_limits::max()), - std::max( - static_cast(std::numeric_limits::min()), - diff))); -} - -template -gemmlowp::FixedPoint SaturatingSub( - gemmlowp::FixedPoint a, - gemmlowp::FixedPoint b) { - return gemmlowp::FixedPoint::FromRaw( - SaturatingSub(a.raw(), b.raw())); -} -// End section to be moved to gemmlowp. - namespace reference_ops { // Return true for broadcast case, false otherwise. 
@@ -192,59 +130,6 @@ inline bool ProcessBroadcastShapes(const RuntimeShape& shape0, return true; } -template -int CountLeadingZeros(T integer_input) { - static_assert(std::is_unsigned::value, - "Only unsigned integer types handled."); - if (integer_input == 0) { - return std::numeric_limits::digits; - } - const T one_in_leading_positive = static_cast(1) - << (std::numeric_limits::digits - 1); - int leading_zeros = 0; - while (integer_input < one_in_leading_positive) { - integer_input <<= 1; - ++leading_zeros; - } - return leading_zeros; -} - -template -IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { - if (exponent == 0) { - return x; - } - using ScalarIntegerType = - typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; - const IntegerType min = - gemmlowp::Dup(std::numeric_limits::min()); - const IntegerType max = - gemmlowp::Dup(std::numeric_limits::max()); - const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); - - const std::int32_t threshold = - ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); - const IntegerType positive_mask = - gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); - const IntegerType negative_mask = - gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); - - IntegerType result = gemmlowp::ShiftLeft(x, exponent); - result = gemmlowp::SelectUsingMask(positive_mask, max, result); - result = gemmlowp::SelectUsingMask(negative_mask, min, result); - return result; -} - -// If we want to leave IntegerBits fixed, then multiplication -// by a power of two has to be saturating/rounding, not exact anymore. 
-template -gemmlowp::FixedPoint -SaturatingRoundingMultiplyByPOTParam( - gemmlowp::FixedPoint a, int exponent) { - return gemmlowp::FixedPoint::FromRaw( - SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); -} - inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& filter_shape, const float* filter_data, const RuntimeShape& bias_shape, @@ -506,6 +391,15 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params, } } +inline void Elu(const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const float val = input_data[i]; + output_data[i] = val < 0.0 ? std::exp(val) - 1 : val; + } +} + inline void Relu(const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { const int flat_size = MatchingFlatSize(input_shape, output_shape); @@ -543,16 +437,17 @@ inline void Relu6(const RuntimeShape& input_shape, const float* input_data, } } +template inline void ReluX(const tflite::ActivationParams& params, - const RuntimeShape& input_shape, const uint8* input_data, - const RuntimeShape& output_shape, uint8* output_data) { + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); - const uint8 max_value = params.quantized_activation_max; - const uint8 min_value = params.quantized_activation_min; + const T max_value = params.quantized_activation_max; + const T min_value = params.quantized_activation_min; for (int i = 0; i < flat_size; ++i) { - const uint8 val = input_data[i]; - const uint8 clamped = + const T val = input_data[i]; + const T clamped = val > max_value ? 
max_value : val < min_value ? min_value : val; output_data[i] = clamped; } @@ -594,55 +489,6 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params, } } -inline void GetInvSqrtQuantizedMultiplierExp(int32 input, - int32* output_inv_sqrt, - int* output_shift) { - *output_shift = 11; - while (input >= (1 << 29)) { - input /= 4; - ++*output_shift; - } - TFLITE_DCHECK_GT(input, 0); - const unsigned max_left_shift_bits = - CountLeadingZeros(static_cast(input)) - 1; - const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; - const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; - *output_shift -= left_shift_bit_pairs; - input <<= 2 * left_shift_bit_pairs; - TFLITE_DCHECK_GE(input, (1 << 27)); - TFLITE_DCHECK_LT(input, (1 << 29)); - using gemmlowp::FixedPoint; - using gemmlowp::Rescale; - using gemmlowp::SaturatingRoundingMultiplyByPOT; - // Using 3 integer bits gives us enough room for the internal arithmetic in - // this Newton-Raphson iteration. - using F3 = FixedPoint; - using F0 = FixedPoint; - const F3 fixedpoint_input = F3::FromRaw(input >> 1); - const F3 fixedpoint_half_input = - SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); - const F3 fixedpoint_half_three = - GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); - // Newton-Raphson iteration - // Naive unoptimized starting guess: x = 1 - F3 x = F3::One(); - // Naive unoptimized number of iterations: 5 - for (int i = 0; i < 5; i++) { - const F3 x3 = Rescale<3>(x * x * x); - x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); - } - const F0 fixedpoint_half_sqrt_2 = - GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); - x = x * fixedpoint_half_sqrt_2; - *output_inv_sqrt = x.raw(); - if (*output_shift < 0) { - *output_inv_sqrt <<= -*output_shift; - *output_shift = 0; - } - // Convert right shift (right is positive) to left shift. 
- *output_shift *= kReverseShift; -} - inline void L2Normalization(const tflite::L2NormalizationParams& op_params, const RuntimeShape& input_shape, const uint8* input_data, @@ -662,9 +508,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params, } int32 inv_l2norm_multiplier; int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); - + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift, + &inv_l2norm_multiplier, &inv_l2norm_shift); for (int c = 0; c < depth; c++) { int32 diff = input_data[depth * i + c] - input_zero_point; int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( @@ -702,6 +547,22 @@ inline void Add(const ArithmeticParams& params, } } +// T is expected to be either float or int. +template +inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs, + T* const* input_data, T* output_data) { + // All inputs and output should have the same shape, this is checked during + // Prepare stage. + const size_t size = input_shape.FlatSize(); + for (int i = 0; i < size; ++i) { + T x = 0; + for (int j = 0; j < num_inputs; ++j) { + x += input_data[j][i]; + } + output_data[i] = x; + } +} + // Element-wise add that can often be used for inner loop of broadcast add as // well as the non-broadcast add. 
inline void AddElementwise(int size, const ArithmeticParams& params, @@ -1688,6 +1549,54 @@ inline void SubWithActivation(const ArithmeticParams& params, } } +inline void Sub16(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16_t* input1_data, + const RuntimeShape& input2_shape, const int16_t* input2_data, + const RuntimeShape& output_shape, int16_t* output_data) { + gemmlowp::ScopedProfilingLabel label("Sub/Int16"); + const int input1_shift = params.input1_shift; + const int flat_size = + MatchingFlatSize(output_shape, input1_shape, input2_shape); + const int16 output_activation_min = params.quantized_activation_min; + const int16 output_activation_max = params.quantized_activation_max; + + TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0); + TFLITE_DCHECK_LE(input1_shift, 0); + TFLITE_DCHECK_LE(params.input2_shift, 0); + const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data; + const int16* shift_input = input1_shift == 0 ? input2_data : input1_data; + const int input_right_shift = + input1_shift == 0 ? -params.input2_shift : -input1_shift; + + if (input1_shift == 0) { + // F0 uses 0 integer bits, range [-1, 1]. + using F0 = gemmlowp::FixedPoint; + for (int i = 0; i < flat_size; ++i) { + F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]); + F0 scaled_input = F0::FromRaw( + gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift)); + F0 result = SaturatingSub(input_ready_scaled, scaled_input); + const int16 raw_output = result.raw(); + const int16 clamped_output = std::min( + output_activation_max, std::max(output_activation_min, raw_output)); + output_data[i] = clamped_output; + } + } else { + // F0 uses 0 integer bits, range [-1, 1]. 
+ using F0 = gemmlowp::FixedPoint; + for (int i = 0; i < flat_size; ++i) { + F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]); + F0 scaled_input = F0::FromRaw( + gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift)); + F0 result = SaturatingSub(scaled_input, input_ready_scaled); + const int16 raw_output = result.raw(); + const int16 clamped_output = std::min( + output_activation_max, std::max(output_activation_min, raw_output)); + output_data[i] = clamped_output; + } + } +} + template inline void Concatenation(const ConcatenationParams& params, const RuntimeShape* const* input_shapes, @@ -1837,11 +1746,17 @@ void Unpack(const UnpackParams& params, const RuntimeShape& input_shape, const int outputs_count = params.num_split; int outer_size = 1; - for (int i = 0; i < params.axis; i++) { + int axis = params.axis; + if (axis < 0) { + axis += dimensions; + } + TFLITE_DCHECK_GE(axis, 0); + TFLITE_DCHECK_LT(axis, dimensions); + for (int i = 0; i < axis; ++i) { outer_size *= input_shape.Dims(i); } int copy_size = 1; - for (int i = params.axis + 1; i < dimensions; i++) { + for (int i = axis + 1; i < dimensions; ++i) { copy_size *= input_shape.Dims(i); } TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size); @@ -2681,121 +2596,6 @@ inline void LogSoftmax(const SoftmaxParams& params, } } -// Although currently the name of this function says that it cannot handle -// values less than 1, in practice it can handle as low as 1/x_max, where -// x_max is the largest representable input. In other words, the output range -// is symmetric. -template -inline gemmlowp::FixedPoint -log_x_for_x_greater_than_or_equal_to_1_impl( - gemmlowp::FixedPoint input_val) { - using FixedPoint0 = gemmlowp::FixedPoint; - // The reason for accumulating the result with an extra bit of headroom is - // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * - // recip_denom will otherwise introduce an error. 
- static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; - using FixedPointAccum = gemmlowp::FixedPoint; - - const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1488522236, std::log(2.0)); - const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); - const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1518500250, std::sqrt(0.5)); - const FixedPoint0 one_quarter = - GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); - - const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 1057819769, - 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); - const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( - FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); - - const FixedPointAccum shifted_quarter = - gemmlowp::Rescale(one_quarter); - - // Reinterpret the input value as Q0.31, because we will figure out the - // required shift "ourselves" instead of using, say, Rescale. 
- FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); - // z_a_pow_2 = input_integer_bits - z_a_headroom; - int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); - FixedPoint0 r_a_tmp = - SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); - const int32 r_a_raw = - SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); - // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); - // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, - // InputIntegerBits - z_b_headroom - 0.25); - const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( - FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( - InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), - shifted_quarter); - - // z_b is treated like z_a, but premultiplying by sqrt(0.5). - FixedPoint0 z_b = z_a * sqrt_half; - int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; - const int32 r_b_raw = - SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); - const FixedPointAccum z_b_pow_2_adj = SaturatingSub( - FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( - InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), - shifted_quarter); - - const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); - const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( - std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); - - const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); - FixedPoint0 q = r - sqrt_sqrt_half; - q = q + q; - - const FixedPoint0 common_sq = q * q; - const FixedPoint0 num = q * r + q * common_sq * alpha_n; - const FixedPoint0 denom_minus_one_0 = - p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; - const FixedPoint0 recip_denom = - one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); - - const FixedPointAccum num_scaled = gemmlowp::Rescale(num); - return gemmlowp::Rescale(z_pow_2_adj * log_2 + - num_scaled * recip_denom); -} - -// Minimum output 
bits to accommodate log of maximum input range. It actually -// does not matter if one considers, say, [-64,64] or [-64,64). -// -// For example, run this through Octave: -// [0:127; ... -// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... -// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] -constexpr int min_log_x_output_bits(int input_bits) { - return input_bits > 90 - ? 7 - : input_bits > 44 - ? 6 - : input_bits > 21 - ? 5 - : input_bits > 10 - ? 4 - : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; -} - -template -inline gemmlowp::FixedPoint -log_x_for_x_greater_than_or_equal_to_1( - gemmlowp::FixedPoint input_val) { - static_assert( - OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), - "Output integer bits must be sufficent to accommodate logs of inputs."); - return log_x_for_x_greater_than_or_equal_to_1_impl( - input_val); -} - inline void LogSoftmax(const SoftmaxParams& params, const RuntimeShape& input_shape, const uint8* input_data, const RuntimeShape& output_shape, uint8* output_data) { @@ -3122,6 +2922,16 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data, } } +inline void Ceil(const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + for (int i = 0; i < flat_size; i++) { + int offset = i; + output_data[offset] = std::ceil(input_data[offset]); + } +} + template inline void Gather(const tflite::GatherParams& op_params, const RuntimeShape& input_shape, const T* input_data, @@ -3159,6 +2969,43 @@ inline void Gather(const tflite::GatherParams& op_params, } } +template +inline void GatherNd(const RuntimeShape& params_shape, + const ParamsT* params_data, + const RuntimeShape& indices_shape, + const IndicesT* indices_data, + const RuntimeShape& output_shape, ParamsT* output_data) { + gemmlowp::ScopedProfilingLabel label("GatherNd"); + + int n_slices = 1; + int slice_size = 1; + const int indices_dims = 
indices_shape.DimensionsCount(); + const int indices_nd = indices_shape.Dims(indices_dims - 1); + const int params_dims = params_shape.DimensionsCount(); + for (int i = 0; i < indices_dims - 1; ++i) { + n_slices *= indices_shape.Dims(i); + } + for (int i = indices_nd; i < params_dims; ++i) { + slice_size *= params_shape.Dims(i); + } + + int remain_flat_size = params_shape.FlatSize(); + std::vector dims_to_count(indices_nd, 0); + for (int i = 0; i < indices_nd; ++i) { + dims_to_count[i] = remain_flat_size / params_shape.Dims(i); + remain_flat_size = dims_to_count[i]; + } + + for (int i = 0; i < n_slices; ++i) { + int from_pos = 0; + for (int j = 0; j < indices_nd; ++j) { + from_pos += indices_data[i * indices_nd + j] * dims_to_count[j]; + } + std::memcpy(output_data + i * slice_size, params_data + from_pos, + sizeof(ParamsT) * slice_size); + } +} + template inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, const RuntimeShape& unextended_input_shape, @@ -3458,6 +3305,16 @@ inline void PadImageStyle(const tflite::PadParams& op_params, output_data); } +template +inline void PadImageStyle(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, + const int8_t* input_data, const P* pad_value_ptr, + const RuntimeShape& output_shape, + int8_t* output_data) { + Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape, + output_data); +} + template inline void PadImageStyle(const tflite::PadParams& op_params, const RuntimeShape& input_shape, @@ -3787,6 +3644,65 @@ inline void Mean(const tflite::MeanParams& op_params, } } +inline void Mean(const tflite::MeanParams& op_params, + const RuntimeShape& unextended_input_shape, + const uint8_t* input_data, int32 input_zero_point, + float input_scale, const RuntimeShape& unextended_output_shape, + uint8_t* output_data, int32 output_zero_point, + float output_scale) { + gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8"); + + // Current implementation only supports dimension equals 4 
and simultaneous + // reduction over width and height. + TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4); + TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4); + const RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + const RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + const int output_batch = output_shape.Dims(0); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int output_depth = output_shape.Dims(3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const float num_elements_in_axis = input_width * input_height; + + TFLITE_DCHECK_EQ(op_params.axis_count, 2); + TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + const bool ordinary_mean = + (input_zero_point == output_zero_point && input_scale == output_scale); + float scale, bias; + if (!ordinary_mean) { + scale = input_scale / output_scale; + bias = -input_zero_point * scale + 0.5; + } + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + float temp_value = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + temp_value += + input_data[Offset(input_shape, out_b, in_h, in_w, out_d)]; + } + } + temp_value = temp_value / num_elements_in_axis; + if (ordinary_mean) { + output_data[Offset(output_shape, out_b, 0, 0, out_d)] = + static_cast(round(temp_value)); + } else { + output_data[Offset(output_shape, out_b, 0, 0, out_d)] = + static_cast(round(temp_value * scale + bias)) + + output_zero_point; + } + } + } +} + // Computes the mean of elements across dimensions given in axis. 
// It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis for quantized values. @@ -3950,11 +3866,8 @@ void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data, const T3* input2_data, const RuntimeShape& output_shape, T2* output_data, const Cmp& cmp) { gemmlowp::ScopedProfilingLabel label("ArgMinMax"); - // For ArgMax, the number of output dimensions = (number of input dimensions - - // 1). For the sake of simplicity, the output dimensions are equal to the - // input dimensions here. We enforce the constraint that the axis dimension - // must always be 1. - TFLITE_DCHECK_EQ(input1_shape.DimensionsCount(), + TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0); + TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1, output_shape.DimensionsCount()); int axis = input2_data[0]; @@ -3963,7 +3876,6 @@ void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data, } const int axis_size = input1_shape.Dims(axis); - TFLITE_DCHECK_EQ(output_shape.Dims(axis), 1); int outer_size = 1; for (int i = 0; i < axis; ++i) { @@ -3974,7 +3886,7 @@ void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data, int inner_size = 1; const int dims_count = input1_shape.DimensionsCount(); for (int i = axis + 1; i < dims_count; ++i) { - TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i)); + TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1)); inner_size *= input1_shape.Dims(i); } @@ -4419,6 +4331,34 @@ void RankOneSelect(const RuntimeShape& input_condition_shape, } } +template +void SelectTrueCoords(const RuntimeShape& input_condition_shape, + const D* input_condition_data, T* output_data) { + const size_t size = input_condition_shape.FlatSize(); + const size_t cond_rank = input_condition_shape.DimensionsCount(); + + std::vector dims_to_count(cond_rank, 0); + int cur_flat_size = size; + for (int i = 0; i < cond_rank; ++i) { + dims_to_count[i] = cur_flat_size / 
input_condition_shape.Dims(i); + cur_flat_size = dims_to_count[i]; + } + + int output_index = 0; + for (int i = 0; i < size; ++i) { + if (input_condition_data[i]) { + // Insert the coordinate of the current item (row major) into output. + int flat_index = i; + for (int j = 0; j < cond_rank; ++j) { + int coord_j = flat_index / dims_to_count[j]; + output_data[output_index * cond_rank + j] = coord_j; + flat_index %= dims_to_count[j]; + } + output_index++; + } + } +} + // For easy implementation, the indices is always a vector of size-4 vectors. template inline void SparseToDense(const std::vector>& indices, @@ -4714,6 +4654,112 @@ void Fill(const RuntimeShape& value_shape, const T* value_data, } } +template +void Reverse(int axis, const RuntimeShape& input_shape, + const Scalar* input_data, const RuntimeShape& output_shape, + Scalar* output_data) { + gemmlowp::ScopedProfilingLabel label("Reverse"); + + int outer_size = 1; + for (int i = 0; i < axis; ++i) { + outer_size *= input_shape.Dims(i); + } + + int copy_size = 1; + for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) { + copy_size *= input_shape.Dims(i); + } + + const int dims_at_axis = input_shape.Dims(axis); + for (int i = 0; i < outer_size; ++i) { + for (int j = 0; j < dims_at_axis; ++j) { + const int start_pos = (i * dims_at_axis + j) * copy_size; + Scalar* output_ptr = output_data + start_pos; + int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size; + memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar)); + } + } +} + +template +void ReverseSequence(const TS* seq_lengths, const int seq_dim, + const int batch_dim, const RuntimeShape& input_shape, + const Scalar* input_data, const RuntimeShape& output_shape, + Scalar* output_data) { + gemmlowp::ScopedProfilingLabel label("ReverseSequence"); + + int outer_size = 1; + int outer_dim = std::min(batch_dim, seq_dim); + int medium_dim = std::max(batch_dim, seq_dim); + for (int i = 0; i < outer_dim; ++i) { + outer_size *= 
input_shape.Dims(i); + } + + int medium_size = 1; + for (int i = outer_dim + 1; i < medium_dim; ++i) { + medium_size *= input_shape.Dims(i); + } + + int copy_size = 1; + for (int i = medium_dim + 1; i < input_shape.DimensionsCount(); ++i) { + copy_size *= input_shape.Dims(i); + } + + const int dims_at_outer_dim = input_shape.Dims(outer_dim); + const int dims_at_medium_dim = input_shape.Dims(medium_dim); + + Scalar* output_ptr; + if (batch_dim > seq_dim) { + for (int i = 0; i < outer_size; ++i) { + for (int j = 0; j < dims_at_outer_dim; ++j) { + const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size; + for (int p = 0; p < medium_size; ++p) { + for (int q = 0; q < dims_at_medium_dim; ++q) { + const int in_pos = + ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size; + const Scalar* in_ptr = input_data + in_pos; + int sl = seq_lengths[q] - 1; + if (j > sl) { + output_ptr = output_data + in_pos; + } else { + const int out_pos_base = + (i * dims_at_outer_dim + sl - j) * medium_size; + const int out_pos = + ((out_pos_base + p) * dims_at_medium_dim + q) * copy_size; + output_ptr = output_data + out_pos; + } + memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar)); + } + } + } + } + } else if (batch_dim < seq_dim) { + for (int i = 0; i < outer_size; ++i) { + for (int j = 0; j < dims_at_outer_dim; ++j) { + const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size; + int sl = seq_lengths[j] - 1; + const int out_pos_base = (i * dims_at_outer_dim + j) * medium_size; + for (int p = 0; p < medium_size; ++p) { + for (int q = 0; q < dims_at_medium_dim; ++q) { + const int in_pos = + ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size; + const Scalar* in_ptr = input_data + in_pos; + if (q > sl) { + output_ptr = output_data + in_pos; + } else { + const int out_pos = + ((out_pos_base + p) * dims_at_medium_dim + sl - q) * + copy_size; + output_ptr = output_data + out_pos; + } + memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar)); + } + } + } + } + } +} + 
} // namespace reference_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h index 71ae69522f9a45745a9ed9eae211db3d048ba43d..4f18f283b6094c66fb89080115d359ffce776dd8 100644 --- a/tensorflow/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/lite/kernels/internal/tensor_utils.h @@ -55,6 +55,21 @@ void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, int n_batch, float* result, int result_stride); +// Same as the function above, but the matrix is stored in block compressed +// sparse row format with block pattern 1x16 which consists of two arrays: +// 1. A matrix array stores non-zero blocks of the matrix in row major. +// 2. A ledger array stores nrows groups, one group per row. Each group starts +// with +// an integer representing the number of non-zero blocks for the +// corresponding row and follows with column indexes of the first element +// of each non-zero block. +// This function assumes that +// 1. m_cols is a multiple of 16 so that all blocks are full blocks. +// 2. m_cols < 254 * 16 so that block index can be represented by uint8. +void SparseMatrixBatchVectorMultiplyAccumulate( + const float* matrix, const uint8_t* ledger, int m_rows, int m_cols, + const float* vector, int n_batch, float* result, int result_stride); + // Same as the function above, but for values quantized using symmetric // quantization (e.g. by calling SymmetricQuantizeFloats). // The passed scaling factors is a buffer of the quantization scaling factors @@ -67,6 +82,23 @@ void MatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ vectors, const float* scaling_factors, int n_batch, float* __restrict__ result, int result_stride); +// Same as the function above, but the matrix is stored in block compressed +// sparse row format with block pattern 1x16 which consists of two arrays: +// 1. A matrix array stores non-zero blocks of the matrix in row major. +// 2. 
A ledger array stores nrows groups, one group per row. Each group starts +// with +// an integer representing the number of non-zero blocks for the +// corresponding row followed by column index of the first element of +// each non-zero block. +// This function assumes that +// 1. m_cols is a multiple of 16 so that all blocks are full blocks. +// 2. m_cols < 254 * 16 so that block index can be represented by uint8. +void SparseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, + const int m_cols, const int8_t* __restrict__ vectors, + const float* scaling_factors, int n_batch, float* __restrict__ result, + int result_stride); + // Cwise product of two vectors. void VectorVectorCwiseProduct(const float* vector1, const float* vector2, int v_size, float* result); diff --git a/tensorflow/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc index 29866d066406e58e06e6caa2e5b410460564c966..0918c8d27727408899f13ed866dfe737202f527b 100644 --- a/tensorflow/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc @@ -17,6 +17,10 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/kernels/test_util.h" +#ifdef DOTPROD_BENCHMARKS +#include "testing/base/public/benchmark.h" +#endif // DOTPROD_BENCHMARKS + namespace tflite { namespace tensor_utils { @@ -143,12 +147,228 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) { -1., 3., 7., 3., 23., 3.}))); } +struct MatrixVectorData { + // Contains dense parameters. + std::vector matrix; + + // Like matrix, but with about half of the parameters set to zero. + // Use this to create golden output for sparse matrix tests. + std::vector zeroed_matrix; + + // zeroed_matrix described in sparse form. 
+ std::vector sparse_matrix; + std::vector ledger; + + std::vector vectors; + std::vector scale_factors; + std::vector results; + + int rows; + int cols; + int batch; +}; + +MatrixVectorData SetupMatrixVectorData(int rows, int cols, int batch, + bool negative = false) { + MatrixVectorData data; + data.rows = rows; + data.cols = cols; + data.batch = batch; + + for (int i = 0; i < rows * cols; i++) { + int sign = 1; + if ((i % 3) == 0 && negative) sign = -1; + data.matrix.push_back(sign * (i % 70)); + } + for (int i = 0; i < cols * batch; i++) { + int sign = 1; + if ((i % 5) == 0 && negative) sign = -1; + data.vectors.push_back(sign * (i % 50)); + } + data.scale_factors = {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}; + data.results.resize(rows * batch, 0); + + data.zeroed_matrix = data.matrix; + + // Make a sparsification ledger. + for (int i = 0; i < rows; i++) { + int max_chunks = cols / 16; + int selected_chunks = (max_chunks / 2); + bool row_is_odd = (i % 2) > 0; + bool max_chunks_is_odd = (max_chunks % 2) > 0; + + data.ledger.push_back(selected_chunks); + if (max_chunks_is_odd && row_is_odd) { + selected_chunks++; + } + + // In odd rows, use odd chunk indexes. + // In even rows, use even chunk indexes. + for (int j = 0; j < max_chunks; j++) { + const int chunk_start = i * cols + (j * 16); + const int chunk_end = i * cols + (j * 16) + 16; + if ((j % 2) == (i % 2)) { + // Copy this chunk into the sparse matrix. + data.ledger.push_back(j); + for (int k = chunk_start; k < chunk_end; k++) { + data.sparse_matrix.push_back(data.matrix[k]); + } + } else { + // Zero this part out of zeroed_matrix. 
+ for (int k = chunk_start; k < chunk_end; k++) { + data.zeroed_matrix[k] = 0; + } + } + } + } + return data; +} + +std::vector TestDotprodMatrixBatchVectorMultiply(int rows, int cols, + int batch, + bool negative = false) { + MatrixVectorData data = SetupMatrixVectorData(rows, cols, batch, negative); + + // All partial sums in this computation are small enough to fit in the + // mantissa of a float, and the scale factors are all integers, so we expect + // an exact result. + MatrixBatchVectorMultiplyAccumulate( + data.matrix.data(), rows, cols, data.vectors.data(), + data.scale_factors.data(), batch, &data.results[0], 1); + return data.results; +} + +std::vector TestSparseDotprodMatrixBatchVectorMultiply( + int rows, int cols, int batch, bool negative = false) { + MatrixVectorData data = SetupMatrixVectorData(rows, cols, batch, negative); + SparseMatrixBatchVectorMultiplyAccumulate( + data.sparse_matrix.data(), data.ledger.data(), rows, cols, + data.vectors.data(), data.scale_factors.data(), batch, &data.results[0], + 1); + return data.results; +} + +TEST(uKernels, DotprodMatrixBatchVectorMultiplyAccumulateTest) { + ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 16, 1), + testing::ElementsAre(1240, 3160, 5080, 7000)); + + ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 32, 2), + testing::ElementsAre(10416, 26288, 8490, 23312, 18276, 70756, + 37416, 60916)); + + ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 32, 3), + testing::ElementsAre(10416, 26288, 8490, 23312, 18276, 70756, + 37416, 60916, 52080, 142704, 55878, 125712)); + + ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(8, 1024, 3), + testing::ElementsAreArray( + {841094, 853168, 866642, 840286, 860760, 862754, + 843678, 872552, 1724476, 1769072, 1747588, 1738844, + 1758240, 1742916, 1761612, 1755808, 2506896, 2564262, + 2629188, 2515824, 2598390, 2569236, 2537352, 2645118})); + + const bool kNegative = true; + ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(4, 64, 1, kNegative), + 
testing::ElementsAre(13696, 6904, 7764, 11806)); + ASSERT_THAT( + TestDotprodMatrixBatchVectorMultiply(4, 32, 2, kNegative), + testing::ElementsAre(3436, 3522, 1590, 6972, 2516, 20520, 456, 10628)); +} + +TEST(uKernels, DotprodMatrixBatchFourVectorMultiplyAccumulateDotprodTest) { + ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(2, 16, 4), + testing::ElementsAreArray( + {1240, 3160, 6320, 18352, 15240, 45576, 4200, 16232})); + ASSERT_THAT(TestDotprodMatrixBatchVectorMultiply(2, 64, 4), + testing::ElementsAreArray({45794, 38948, 88536, 84252, 157626, + 165312, 209864, 246128})); + ASSERT_THAT( + TestDotprodMatrixBatchVectorMultiply(2, 64, 8), + testing::ElementsAreArray({45794, 38948, 88536, 84252, 157626, 165312, + 209864, 246128, 219700, 195550, 279684, 278928, + 413616, 445662, 374896, 365952})); + + ASSERT_THAT( + TestDotprodMatrixBatchVectorMultiply(4, 64, 8), + testing::ElementsAreArray( + {45794, 38948, 34622, 32816, 88536, 84252, 85008, 90804, + 157626, 165312, 180558, 203364, 209864, 246128, 236472, 208896, + 219700, 195550, 184000, 185050, 279684, 278928, 293292, 322776, + 413616, 445662, 495348, 513674, 374896, 365952, 321168, 296544})); + + ASSERT_THAT( + TestDotprodMatrixBatchVectorMultiply(16, 1024, 4), + testing::ElementsAreArray( + {841094, 853168, 866642, 840286, 860760, 862754, 843678, + 872552, 837586, 851270, 877414, 834188, 863062, 857846, + 841780, 879054, 1724476, 1769072, 1747588, 1738844, 1758240, + 1742916, 1761612, 1755808, 1737684, 1750780, 1747356, 1754152, + 1748348, 1753324, 1743320, 1754316, 2506896, 2564262, 2629188, + 2515824, 2598390, 2569236, 2537352, 2645118, 2508444, 2571480, + 2610576, 2510442, 2618208, 2566584, 2544570, 2614536, 3458904, + 3502688, 3474792, 3505976, 3499360, 3488264, 3485848, 3512832, + 3500616, 3482520, 3489624, 3469008, 3495992, 3524376, 3465680, + 3526264})); + + ASSERT_THAT( + TestDotprodMatrixBatchVectorMultiply(4, 128, 4), + testing::ElementsAreArray({87920, 80024, 92288, 103712, 228148, 224820, + 
233812, 213124, 271284, 271788, 332772, 328236, + 419328, 431328, 411968, 417248})); + + ASSERT_THAT( + TestDotprodMatrixBatchVectorMultiply(4, 128, 8), + testing::ElementsAreArray( + {87920, 80024, 92288, 103712, 228148, 224820, 233812, 213124, + 271284, 271788, 332772, 328236, 419328, 431328, 411968, 417248, + 482680, 523840, 560800, 593560, 563940, 609924, 566868, 644772, + 743708, 857780, 818972, 823284, 708384, 695008, 730912, 872096})); + + const bool kNegative = true; + EXPECT_THAT(TestDotprodMatrixBatchVectorMultiply(1, 16, 1, kNegative), + testing::ElementsAre(450)); + EXPECT_THAT(TestDotprodMatrixBatchVectorMultiply(2, 64, 8, kNegative), + testing::ElementsAreArray({13696, 6904, 9952, 12368, 22848, 61632, + 40424, 46776, 57630, 38670, 62976, + 49824, 39032, 71988, 60128, 148992})); + + std::vector results = + TestDotprodMatrixBatchVectorMultiply(256, 1024, 8); + int64_t sum = 0; + for (int i = 0; i < results.size(); i++) { + sum += static_cast(results[i]); + } + EXPECT_EQ(7980076336, sum); +} + +TEST(uKernels, DotprodSparseMatrixBatchVectorMultiplyAccumulate) { + EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 16, 1), + testing::ElementsAre(0)); + EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 32, 1), + testing::ElementsAre(1240)); + EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 64, 1), + testing::ElementsAre(26544)); + EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 64, 2), + testing::ElementsAre(26544, 24344)); + EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(4, 64, 4), + testing::ElementsAreArray( + {26544, 15866, 22140, 11408, 24344, 53248, 42704, 39900, + 48000, 94146, 101892, 81876, 87712, 105160, 148304, 75936})); + + const bool kNegative = true; + EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(1, 64, 1, kNegative), + testing::ElementsAre(8764)); + EXPECT_THAT(TestSparseDotprodMatrixBatchVectorMultiply(2, 64, 2, kNegative), + testing::ElementsAre(8764, 5196, 7204, 11148)); +} + #ifdef 
__ANDROID__ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { // Note we use 29 columns as this exercises all the neon kernel: the // 16-block SIMD code, the 8-block postamble, and the leftover postamble. const int a_rows = 4, a_cols = 29; const int kWeightsPerUint32 = 4; + /* clang-format off */ const float a_float_data[] = { /* 1st row */ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13, @@ -174,126 +394,18 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { SymmetricQuantizeFloats(a_float_data, a_rows * a_cols, a_int8_data, &a_min, &a_max, &scaling_factor_a); const int8_t expected_a_int8_data[] = { - /* 1st row */ - 5, - 10, - 15, - 20, - 25, - 30, - 35, - 40, - 44, - 45, - 50, - 54, - 59, - 64, - 68, - 73, - 77, - 82, - 86, - 91, - 95, - 100, - 104, - 109, - 113, - 118, - 122, - 127, - 0, - /* 2nd row */ - -5, - -10, - -15, - -20, - -25, - -30, - -35, - -40, - -44, - -45, - -50, - -54, - -59, - -64, - -68, - -73, - -77, - -82, - -86, - -91, - -95, - -100, - -104, - -109, - -113, - -118, - -122, - -127, - 0, - /* 3rd row */ - 5, - -10, - 15, - -20, - 25, - -30, - 35, - -40, - 44, - -45, - 50, - -54, - 59, - -64, - 68, - -73, - 77, - -82, - 86, - -91, - 95, - -100, - 104, - -109, - 113, - -118, - 122, - -127, - 0, - /* 4th row */ - -5, - 10, - -15, - 20, - -25, - 30, - -35, - 40, - -44, - 45, - -50, - 54, - -59, - 64, - -68, - 73, - -77, - 82, - -86, - 91, - -95, - 100, - -104, - 109, - -113, - 118, - -122, - 127, - 0, + /* 1st row */ + 5, 10, 15, 20, 25, 30, 35, 40, 44, 45, 50, 54, 59, 64, 68, 73, 77, 82, 86, + 91, 95, 100, 104, 109, 113, 118, 122, 127, 0, + /* 2nd row */ + -5, -10, -15, -20, -25, -30, -35, -40, -44, -45, -50, -54, -59, -64, -68, + -73, -77, -82, -86, -91, -95, -100, -104, -109, -113, -118, -122, -127, 0, + /* 3rd row */ + 5, -10, 15, -20, 25, -30, 35, -40, 44, -45, 50, -54, 59, -64, 68, -73, 77, + -82, 86, -91, 95, -100, 104, -109, 113, -118, 122, -127, 0, + /* 4th row 
*/ + -5, 10, -15, 20, -25, 30, -35, 40, -44, 45, -50, 54, -59, 64, -68, 73, -77, + 82, -86, 91, -95, 100, -104, 109, -113, 118, -122, 127, 0, }; for (int i = 0; i < a_rows * a_cols; ++i) { EXPECT_EQ(expected_a_int8_data[i], a_int8_data[i]); @@ -301,66 +413,14 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { const int b_rows = 29, b_cols = 1, batches = 2; const float b_float_data[] = { - /* batch 1 */ - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - /* batch 2 */ - 2.5, - -2.1, - 3.0, - -1.3, - 1.3, - -1.1, - 2.0, - -1.7, - 1.9, - -1.5, - 0.5, - -0.7, - 0.8, - -0.3, - 2.8, - -2.8, - 1.1, - -2.3, - 1.9, - -1.9, - 2.1, - -0.5, - 2.4, - -0.1, - 1.0, - -2.5, - 0.7, - -1.9, - 0.2, + /* batch 1 */ + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, + 1.0, + /* batch 2 */ + 2.5, -2.1, 3.0, -1.3, 1.3, -1.1, 2.0, -1.7, 1.9, -1.5, 0.5, -0.7, 0.8, -0.3, + 2.8, -2.8, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1, 1.0, -2.5, 0.7, -1.9, + 0.2, }; // Quantized values of B: @@ -374,67 +434,15 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { &scaling_factor_b[1]); const int8_t expected_b_int8_data[] = { - /* batch 1 */ - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - -127, - 127, - /* batch 2 */ - 106, - -89, - 127, - -55, - 55, - -47, - 85, - -72, - 80, - -64, - 21, - -30, - 34, - -13, - 119, - -119, - 47, - -97, - 80, - -80, - 89, - -21, - 102, - -4, - 42, - -106, - 30, - -80, - 8, + /* batch 1 */ + 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, + 127, -127, 127, -127, 
127, -127, 127, -127, 127, -127, 127, -127, 127, -127, + 127, + /* batch 2 */ + 106, -89, 127, -55, 55, -47, 85, -72, 80, -64, 21, -30, 34, -13, 119, -119, + 47, -97, 80, -80, 89, -21, 102, -4, 42, -106, 30, -80, 8, }; + /* clang-format on */ for (int i = 0; i < b_rows * b_cols * batches; ++i) { EXPECT_EQ(expected_b_int8_data[i], b_int8_data[i]); } @@ -468,6 +476,176 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { } #endif // __ANDROID__ +TEST(uKernels, SparseMatrixBatchVectorMultiplyAccumulateTest) { + const int kRow = 4; + const int kCol = 48; + const int kBatch = 2; + /* clang-format off */ + float matrix[kRow * kCol] = { + /* 1st row */ + 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13, + 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38, + 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0, 0, 0, 0, + /* 2nd row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24, + -25.25, -26.26, -27.27, -28.28, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0, + /* 3rd row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25, + -26.26, 27.27, -28.28, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0, + /* 4th row */ + -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12, + -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37, + 38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0, 0, 0, 0}; + + // BCSR format of the above matrix. 
+ float matrix_values[] = { + /* 1st row */ + 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13, + 14.14, 15.15, 16.16, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38, 39.39, + 40.40, 41.41, 42.42, 43.43, 44.44, 0, 0, 0, 0, + /* 2nd row */ + -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24, -25.25, + -26.26, -27.27, -28.28, 0, 0.0, 0.0, 0.0, + /* 3rd row */ + 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25, -26.26, + 27.27, -28.28, 0, 0.0, 0.0, 0.0, + /* 4th row */ + -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12, + -13.13, 14.14, -15.15, 16.16, -33.33, 34.34, -35.35, 36.36, -37.37, 38.38, + -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0, 0, 0, 0}; + uint8_t ledger[] = { + 2, 0, 2, // 1st row + 1, 1, // 2nd row + 1, 1, // 3rd row + 2, 0, 2 // 4th row + }; + + float vector[kBatch * kCol] = { + /* 1st batch */ + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, + /* 2nd batch */ + 2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0, + -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0, + 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1, 1.0, -2.5, + 0.7, -1.9, 0.2, 0.0, 0.1, 0.2, + }; + /* clang-format on */ + + std::vector dense_output(kRow * kBatch, 0.0); + MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch, + dense_output.data(), /*result_stride=*/1); + + EXPECT_THAT(dense_output, ElementsAreArray(ArrayFloatNear( + {-13.69, 6.06001, 272.7, -608.03, -9.66602, + -10.201, 10.201, -713.897949}, + 1e-4))); + + std::vector sparse_output(kRow * kBatch, 0.0); + SparseMatrixBatchVectorMultiplyAccumulate( + matrix_values, ledger, kRow, kCol, vector, kBatch, sparse_output.data(), + /*result_stride=*/1); + + 
EXPECT_THAT(sparse_output, + ElementsAreArray(ArrayFloatNear(dense_output, 1e-4))); +} + +#ifdef __ANDROID__ +TEST(uKernels, + SparseMatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { + const int kRow = 4; + const int kCol = 48; + const int kBatch = 2; + /* clang-format off */ + const int8_t quantized_matrix[] = { + /* 1st row */ + 3, 6, 9, 13, 16, 19, 22, 25, 28, 29, 32, 35, 38, 40, 43, 46, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 95, 98, 101, 104, 107, 110, 113, 115, + 118, 121, 124, 127, 0, 0, 0, 0, + /* 2nd row */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -49, -52, -55, -58, -61, + -64, -66, -69, -72, -75, -78, -81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + /* 3rd row */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, -52, 55, -58, 61, -64, + 66, -69, 72, -75, 78, -81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + /* 4th row */ + -3, 6, -9, 13, -16, 19, -22, 25, -28, 29, -32, 35, -38, 40, -43, 46, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -95, 98, -101, 104, -107, 110, + -113, 115, -118, 121, -124, 127, 0, 0, 0, 0, + }; + const int8_t quantized_matrix_values[] = { + /* 1st row */ + 3, 6, 9, 13, 16, 19, 22, 25, 28, 29, 32, 35, 38, 40, 43, 46, 95, 98, 101, + 104, 107, 110, 113, 115, 118, 121, 124, 127, 0, 0, 0, 0, + /* 2nd row */ + -49, -52, -55, -58, -61, -64, -66, -69, -72, -75, -78, -81, 0, 0, 0, 0, + /* 3rd row */ + 49, -52, 55, -58, 61, -64, 66, -69, 72, -75, 78, -81, 0, 0, 0, 0, + /* 4th row */ + -3, 6, -9, 13, -16, 19, -22, 25, -28, 29, -32, 35, -38, 40, -43, 46, -95, + 98, -101, 104, -107, 110, -113, 115, -118, 121, -124, 127, 0, 0, 0, 0, + }; + uint8_t ledger[] = { + 2, 0, 2, // 1st row + 1, 1, // 2nd row + 1, 1, // 3rd row + 2, 0, 2 // 4th row + }; + + float matrix_scaling_factor = 0.349921; + + const int8_t quantized_vector[] = { + /* 1st batch */ + 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, + -127, 127, -127, 127, -127, 127, -127, 127, -127, 
127, -127, 127, -127, + 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, -127, 127, + -127, 127, -127, 127, -127, 127, -127, 127, -127, + /* 2nd batch */ + 106, 0, -89, 0, 127, 0, -55, 0, 55, 0, -47, 0, 85, 0, -72, 0, 80, 0, + -64, 0, 21, 0, -30, 0, 34, 0, -13, 0, 119, 0, -119, 0, 47, -97, 80, -80, + 89, -21, 102, -4, 42, -106, 30, -80, 8, 1, 2, 3, + }; + float vector_scaling_factor[2] = {0.00787402, 0.023622}; + + /* clang-format on */ + float result_scaling_factor[2] = { + matrix_scaling_factor * vector_scaling_factor[0], + matrix_scaling_factor * vector_scaling_factor[1], + }; + std::vector dense_output(kRow * kBatch, 0.0); + MatrixBatchVectorMultiplyAccumulate(quantized_matrix, kRow, kCol, + quantized_vector, result_scaling_factor, + kBatch, dense_output.data(), + /*result_stride=*/1); + + EXPECT_THAT(dense_output, + ElementsAreArray(ArrayFloatNear( + {-13.646927, 6.298582, 272.938538, -607.813110, -6.637464, + -9.381721, 9.381721, -713.845642}))); + + std::vector sparse_output(kRow * kBatch, 0.0); + SparseMatrixBatchVectorMultiplyAccumulate( + quantized_matrix_values, ledger, kRow, kCol, quantized_vector, + result_scaling_factor, kBatch, sparse_output.data(), + /*result_stride=*/1); + + EXPECT_THAT(sparse_output, + ElementsAreArray(ArrayFloatNear( + {-13.646927, 6.298582, 272.938538, -607.813110, -6.637464, + -9.381721, 9.381721, -713.845642}))); +} +#endif // __ANDROID__ + TEST(uKernels, VectorVectorCwiseProductTest) { constexpr int kVectorSize = 10; static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, @@ -804,3 +982,109 @@ TEST(uKernels, MeanStddevNormalizationSmallValue) { } // namespace tensor_utils } // namespace tflite + +#ifdef DOTPROD_BENCHMARKS + +// Compile with --copt="-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1" and +// --copt="-DDOTPROD_BENCHMARKS" +// Run with --benchmarks=all +void BM_DotprodBatchOneMultiply(benchmark::State& state) { + const int rows = state.range(0); + const int cols = state.range(1); + const int batch = 
state.range(2); + + tflite::tensor_utils::MatrixVectorData data = + tflite::tensor_utils::SetupMatrixVectorData(rows, cols, batch); + for (auto _ : state) { + for (int i = 0; i < batch; i++) { + tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( + data.matrix.data(), data.rows, data.cols, + data.vectors.data() + (data.cols * i), data.scale_factors.data(), 1, + &data.results[0], 1); + testing::DoNotOptimize(data.results[2]); + } + } +} +BENCHMARK(BM_DotprodBatchOneMultiply) + ->Args({16, 16, 1}) + ->Args({16, 16, 4}) + ->Args({32, 32, 1}) + ->Args({32, 32, 4}) + ->Args({64, 64, 1}) + ->Args({64, 64, 4}) + ->Args({128, 128, 1}) + ->Args({128, 128, 4}) + ->Args({992, 992, 1}) + ->Args({992, 992, 8}) + ->Args({1024, 1024, 1}) + ->Args({1024, 1024, 4}) + ->Args({1024, 1024, 8}) + ->Args({640, 2048, 1}) + ->Args({640, 2048, 4}) + ->Args({640, 2048, 8}) + ->Args({2048, 2048, 1}) + ->Args({2048, 2048, 8}); + +void BM_DotprodBatchFourMultiply(benchmark::State& state) { + const int rows = state.range(0); + const int cols = state.range(1); + const int batch = state.range(2); + + tflite::tensor_utils::MatrixVectorData data = + tflite::tensor_utils::SetupMatrixVectorData(rows, cols, batch); + for (auto _ : state) { + tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( + data.matrix.data(), data.rows, data.cols, data.vectors.data(), + data.scale_factors.data(), data.batch, &data.results[0], 1); + testing::DoNotOptimize(data.results[2]); + } +} +BENCHMARK(BM_DotprodBatchFourMultiply) + ->Args({16, 16, 4}) + ->Args({32, 32, 4}) + ->Args({64, 64, 4}) + ->Args({64, 256, 64}) + ->Args({64, 256, 256}) + ->Args({64, 256, 1024}) + ->Args({64, 256, 12544}) + ->Args({128, 128, 4}) + ->Args({640, 640, 4}) + ->Args({992, 992, 8}) + ->Args({1024, 1024, 4}) + ->Args({1024, 1024, 8}) + ->Args({1024, 1024, 256}) + ->Args({640, 2048, 4}) + ->Args({640, 2048, 8}) + ->Args({2048, 2048, 4}) + ->Args({2048, 2048, 8}); + +void BM_DotprodSparseMultiply(benchmark::State& state) { + const 
int rows = state.range(0); + const int cols = state.range(1); + const int batch = state.range(2); + + tflite::tensor_utils::MatrixVectorData data = + tflite::tensor_utils::SetupMatrixVectorData(rows, cols, batch); + for (auto _ : state) { + tflite::tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( + data.sparse_matrix.data(), data.ledger.data(), data.rows, data.cols, + data.vectors.data(), data.scale_factors.data(), data.batch, + &data.results[0], 1); + testing::DoNotOptimize(data.results[2]); + } +} +BENCHMARK(BM_DotprodSparseMultiply) + ->Args({128, 128, 1}) + ->Args({128, 128, 4}) + ->Args({640, 640, 4}) + ->Args({992, 992, 8}) + ->Args({1024, 1024, 1}) + ->Args({1024, 1024, 4}) + ->Args({1024, 1024, 8}) + ->Args({640, 2048, 1}) + ->Args({640, 2048, 4}) + ->Args({640, 2048, 8}) + ->Args({2048, 2048, 1}) + ->Args({2048, 2048, 8}); + +#endif // DOTPROD_BENCHMARKS diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc index 57f4bfa9fa29ca39aa2506a08870ef6b2d61ab09..deb484b70f029d86e85495e4f6d8ad62efcf3ed4 100644 --- a/tensorflow/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -18,10 +18,91 @@ limitations under the License. #include #include +#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/round.h" namespace tflite { +void GuardedQuantizeMultiplier(double effective_output_scale, + int32_t* significand, int* shift) { + QuantizeMultiplier(effective_output_scale, significand, shift); + // Additional guard to make sure RoundingDivideByPOT does not fail. + if (*shift < -31) { + // If shift is less than -31, RoundingDivideByPOT fails. This happens when + // min and max are close and small. For this particular case, both + // significand and shift are set to zero. 
+ *significand = 0; + *shift = 0; + } +} + +TfLiteStatus PopulateConvolutionQuantizationParams( + TfLiteContext* context, const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output, + const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift, + int32_t* output_activation_min, int32_t* output_activation_max, + int32_t* per_channel_multiplier, int* per_channel_shift) { + TF_LITE_ENSURE_EQ(context, input->quantization.type, + kTfLiteAffineQuantization); + TF_LITE_ENSURE_EQ(context, filter->quantization.type, + kTfLiteAffineQuantization); + // TODO(jianlijianli): Enable bias type check and bias scale == input scale + // * filter scale for each channel in affine quantization once bias + // quantization is properly populated. + // TF_LITE_ENSURE_EQ(context, bias->quantization.type, + // kTfLiteAffineQuantization); + + // Check data type. + const auto* affine_quantization = + reinterpret_cast(filter->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + const bool is_per_channel = affine_quantization->scale->size > 1; + if (is_per_channel) { + // Currently only Int8 is supported for per channel quantization. + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt8); + TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteInt8); + TF_LITE_ENSURE_EQ( + context, affine_quantization->scale->size, + filter->dims->data[affine_quantization->quantized_dimension]); + } + + // Populate multiplier and shift using affine quantization. 
+ const int num_channels = affine_quantization->scale->size; + const float input_scale = input->params.scale; + const float output_scale = output->params.scale; + const float* filter_scales = affine_quantization->scale->data; + for (int i = 0; i < num_channels; ++i) { + const double filter_scale = static_cast(filter_scales[i]); + const double effective_output_scale = static_cast(input_scale) * + filter_scale / + static_cast(output_scale); + int32_t significand; + int shift; + GuardedQuantizeMultiplier(effective_output_scale, &significand, &shift); + per_channel_multiplier[i] = significand; + per_channel_shift[i] = shift; + } + + // Populate scalar quantization parameters. + // This check on legacy quantization parameters is kept only for backward + // compatibility. + if (input->type == kTfLiteUInt8) { + // Check bias scale == input scale * filter scale. + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + + // Populate quantization parameters with multiplier and shift.
+ QuantizeMultiplier(real_multiplier, multiplier, &exponent); + *shift = -exponent; + CalculateActivationRangeUint8(activation, output, output_activation_min, + output_activation_max); + } + return kTfLiteOk; +} + TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, const TfLiteTensor* input, const TfLiteTensor* filter, @@ -81,6 +162,9 @@ TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, if (output->type == kTfLiteUInt8) { qmin = std::numeric_limits::min(); qmax = std::numeric_limits::max(); + } else if (output->type == kTfLiteInt8) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); } else if (output->type == kTfLiteInt16) { qmin = std::numeric_limits::min(); qmax = std::numeric_limits::max(); diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index 4cfc885f8939481f1515b445dfc9e261a4e79ed9..423832c047c6f1cf3b8427f2b4eb8fec3d70d2cb 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -84,6 +84,18 @@ inline void SetTensorToDynamic(TfLiteTensor* tensor) { } } +// Check dimensionality match and populate OpData for Conv and DepthwiseConv. +TfLiteStatus PopulateConvolutionQuantizationParams( + TfLiteContext* context, const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output, + const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift, + int32_t* output_activation_min, int32_t* output_activation_max, + int32_t* per_channel_multiplier, int* per_channel_shift); + +// QuantizedMultiplier with the guard that shift will not be smaller than -31. +void GuardedQuantizeMultiplier(double effective_output_scale, + int32_t* significand, int* shift); + // Calculates the multiplication factor for a quantized convolution (or // quantized depthwise convolution) involving the given tensors. Returns an // error if the scales of the tensors are not compatible. 
diff --git a/tensorflow/lite/kernels/kernel_util_test.cc b/tensorflow/lite/kernels/kernel_util_test.cc index 70eb18365891097686d579bde4a5457703e84aee..a31befbcd16a29cf4eb5d45602c3882138f54d09 100644 --- a/tensorflow/lite/kernels/kernel_util_test.cc +++ b/tensorflow/lite/kernels/kernel_util_test.cc @@ -28,6 +28,8 @@ class KernelUtilTest : public ::testing::Test { KernelUtilTest() { context_.ReportError = ReportError; + memset(&tensor1_, 0, sizeof(TfLiteTensor)); + memset(&tensor2_, 0, sizeof(TfLiteTensor)); tensor1_.dims = nullptr; tensor2_.dims = nullptr; tensor1_.allocation_type = kTfLiteMmapRo; @@ -142,6 +144,222 @@ TEST_F(KernelUtilTest, BroadcastShapeDifferentSizes) { TfLiteIntArrayFree(output); } +// TODO(jianlijianli): Add more test cases. +TEST_F(KernelUtilTest, CheckAndPopulate) { + // Create input. + TfLiteTensor input; + input.type = kTfLiteInt8; + input.allocation_type = kTfLiteArenaRw; + input.dims = TfLiteIntArrayCreate(1); + input.dims->data[0] = 2; + TfLiteQuantizationParams input_quant = {0.5, 5}; + input.params = input_quant; + input.quantization.type = kTfLiteAffineQuantization; + auto* input_params = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + input_params->scale = TfLiteFloatArrayCreate(1); + input_params->scale->data[0] = 0.5; + input_params->zero_point = TfLiteIntArrayCreate(1); + input_params->zero_point->data[0] = 5; + input.quantization.params = reinterpret_cast(input_params); + + // Create filter. 
+ TfLiteTensor filter; + filter.type = kTfLiteInt8; + filter.allocation_type = kTfLiteArenaRw; + filter.dims = TfLiteIntArrayCreate(4); + filter.dims->data[0] = 3; + filter.dims->data[1] = 4; + filter.dims->data[2] = 5; + filter.dims->data[3] = 6; + TfLiteQuantizationParams filter_quant = {0.25, 0}; + filter.params = filter_quant; + filter.quantization.type = kTfLiteAffineQuantization; + auto* filter_params = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + filter_params->scale = TfLiteFloatArrayCreate(3); + filter_params->scale->data[0] = 0.25; + filter_params->scale->data[1] = 0.125; + filter_params->scale->data[2] = 0.25; + filter_params->zero_point = TfLiteIntArrayCreate(3); + filter_params->zero_point->data[0] = 0; + filter_params->zero_point->data[1] = 0; + filter_params->zero_point->data[2] = 0; + filter_params->quantized_dimension = 0; + filter.quantization.params = reinterpret_cast(filter_params); + + // Create bias. + TfLiteTensor bias; + bias.type = kTfLiteInt32; + bias.allocation_type = kTfLiteArenaRw; + bias.dims = TfLiteIntArrayCreate(4); + TfLiteQuantizationParams bias_quant = {0.125, 9}; + bias.params = bias_quant; + bias.quantization.type = kTfLiteAffineQuantization; + auto* bias_params = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + bias_params->scale = TfLiteFloatArrayCreate(3); + bias_params->scale->data[0] = 0.125; + bias_params->scale->data[1] = 0.0625; + bias_params->scale->data[2] = 0.125; + bias_params->zero_point = TfLiteIntArrayCreate(3); + bias_params->zero_point->data[0] = 11; + bias_params->zero_point->data[1] = 12; + bias_params->zero_point->data[2] = 15; + bias.quantization.params = reinterpret_cast(bias_params); + + // Create output. 
+ TfLiteTensor output; + output.type = kTfLiteInt8; + output.allocation_type = kTfLiteArenaRw; + output.dims = nullptr; + TfLiteQuantizationParams output_quant = {0.5, -128}; + output.params = output_quant; + output.quantization.type = kTfLiteAffineQuantization; + auto* output_params = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + output_params->scale = TfLiteFloatArrayCreate(1); + output_params->scale->data[0] = 0.5; + output_params->zero_point = TfLiteIntArrayCreate(1); + output_params->zero_point->data[0] = -128; + output.quantization.params = reinterpret_cast(output_params); + + // Create call parameters. + TfLiteContext context; + int32_t multiplier; + int shift; + int32_t output_activation_min; + int32_t output_activation_max; + std::vector per_channel_multiplier(3); + std::vector per_channel_shift(3); + + // Call and verify results for per channel case. + EXPECT_EQ( + kTfLiteOk, + PopulateConvolutionQuantizationParams( + &context, &input, &filter, &bias, &output, kTfLiteActRelu, + &multiplier, &shift, &output_activation_min, &output_activation_max, + per_channel_multiplier.data(), per_channel_shift.data())); + EXPECT_THAT(per_channel_multiplier, + ::testing::ElementsAre(1073741824, 1073741824, 1073741824)); + EXPECT_THAT(per_channel_shift, ::testing::ElementsAre(-1, -2, -1)); + + // Release. + TfLiteTensorFree(&input); + TfLiteTensorFree(&filter); + TfLiteTensorFree(&bias); + TfLiteTensorFree(&output); +} + +TEST_F(KernelUtilTest, CheckAndPopulateZeroValue) { + // Create input. 
+ TfLiteTensor input; + input.type = kTfLiteInt8; + input.allocation_type = kTfLiteArenaRw; + input.dims = TfLiteIntArrayCreate(1); + input.dims->data[0] = 2; + TfLiteQuantizationParams input_quant = {1, 5}; + input.params = input_quant; + input.quantization.type = kTfLiteAffineQuantization; + auto* input_params = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + input_params->scale = TfLiteFloatArrayCreate(1); + input_params->scale->data[0] = 1; + input_params->zero_point = TfLiteIntArrayCreate(1); + input_params->zero_point->data[0] = 5; + input.quantization.params = reinterpret_cast(input_params); + + // Create filter. + TfLiteTensor filter; + filter.type = kTfLiteInt8; + filter.allocation_type = kTfLiteArenaRw; + filter.dims = TfLiteIntArrayCreate(4); + filter.dims->data[0] = 3; + filter.dims->data[1] = 4; + filter.dims->data[2] = 5; + filter.dims->data[3] = 6; + TfLiteQuantizationParams filter_quant = {4.6566129e-10, 0}; + filter.params = filter_quant; + filter.quantization.type = kTfLiteAffineQuantization; + auto* filter_params = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + filter_params->scale = TfLiteFloatArrayCreate(3); + int32_t two_pow_neg_31 = 0x30000000; // 2^-31 so shift = -30. + int32_t two_pow_neg_32 = 0x2F800000; // 2^-32 so shift = -31. + int32_t two_pow_neg_33 = 0x2F000000; // 2^-33 so shift = -32. + filter_params->scale->data[0] = *reinterpret_cast(&two_pow_neg_31); + filter_params->scale->data[1] = *reinterpret_cast(&two_pow_neg_32); + filter_params->scale->data[2] = *reinterpret_cast(&two_pow_neg_33); + filter_params->zero_point = TfLiteIntArrayCreate(3); + filter_params->zero_point->data[0] = 0; + filter_params->zero_point->data[1] = 0; + filter_params->zero_point->data[2] = 0; + filter_params->quantized_dimension = 0; + filter.quantization.params = reinterpret_cast(filter_params); + + // Create bias. 
+ TfLiteTensor bias; + bias.type = kTfLiteInt32; + bias.allocation_type = kTfLiteArenaRw; + bias.dims = TfLiteIntArrayCreate(4); + TfLiteQuantizationParams bias_quant = {4.6566129e-10, 9}; + bias.params = bias_quant; + bias.quantization.type = kTfLiteAffineQuantization; + auto* bias_params = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + bias_params->scale = TfLiteFloatArrayCreate(3); + bias_params->scale->data[0] = 4.6566129e-10; // 2^-31 + bias_params->scale->data[1] = 2.3283064e-10; // 2^-32 + bias_params->scale->data[2] = 1.1641532e-10; // 2^-33 + bias_params->zero_point = TfLiteIntArrayCreate(3); + bias_params->zero_point->data[0] = 11; + bias_params->zero_point->data[1] = 12; + bias_params->zero_point->data[2] = 15; + bias.quantization.params = reinterpret_cast(bias_params); + + // Create output. + TfLiteTensor output; + output.type = kTfLiteInt8; + output.allocation_type = kTfLiteArenaRw; + output.dims = nullptr; + TfLiteQuantizationParams output_quant = {1, -128}; + output.params = output_quant; + output.quantization.type = kTfLiteAffineQuantization; + auto* output_params = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + output_params->scale = TfLiteFloatArrayCreate(1); + output_params->scale->data[0] = 1; + output_params->zero_point = TfLiteIntArrayCreate(1); + output_params->zero_point->data[0] = -128; + output.quantization.params = reinterpret_cast(output_params); + + // Create call parameters. + TfLiteContext context; + int32_t multiplier; + int shift; + int32_t output_activation_min; + int32_t output_activation_max; + std::vector per_channel_multiplier(3); + std::vector per_channel_shift(3); + + // Call and verify results for per channel case. 
+ EXPECT_EQ( + kTfLiteOk, + PopulateConvolutionQuantizationParams( + &context, &input, &filter, &bias, &output, kTfLiteActRelu, + &multiplier, &shift, &output_activation_min, &output_activation_max, + per_channel_multiplier.data(), per_channel_shift.data())); + EXPECT_THAT(per_channel_multiplier, + ::testing::ElementsAre(1073741824, 1073741824, 0)); + EXPECT_THAT(per_channel_shift, ::testing::ElementsAre(-30, -31, 0)); + + // Release. + TfLiteTensorFree(&input); + TfLiteTensorFree(&filter); + TfLiteTensorFree(&bias); + TfLiteTensorFree(&output); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/l2norm.cc b/tensorflow/lite/kernels/l2norm.cc index 19a4824e9398decec862bb7f5d20ac05b2652226..5eeda0858f4781bf9c47e3eab80c093bae1bcf70 100644 --- a/tensorflow/lite/kernels/l2norm.cc +++ b/tensorflow/lite/kernels/l2norm.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -45,13 +46,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, NumDimensions(input) <= 4); - TF_LITE_ENSURE( - context, output->type == kTfLiteFloat32 || output->type == kTfLiteUInt8); + TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 || + output->type == kTfLiteUInt8 || + output->type == kTfLiteInt8); TF_LITE_ENSURE_EQ(context, input->type, output->type); - if (output->type == kTfLiteUInt8) { + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { TF_LITE_ENSURE_EQ(context, output->params.scale, (1. 
/ 128.)); - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128); + if (output->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128); + } + if (output->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + } } // TODO(ahentz): For some reason our implementations don't support @@ -97,6 +104,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_L2NORM(optimized_ops); } #undef TF_LITE_L2NORM + } else if (output->type == kTfLiteInt8) { + const auto input_shape = GetTensorShape(input); + const auto output_shape = GetTensorShape(output); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + reference_integer_ops::L2Normalization(input->params.zero_point, outer_size, + depth, GetTensorData(input), + GetTensorData(output)); } else { context->ReportError(context, "Output type is %d, requires float.", output->type); diff --git a/tensorflow/lite/kernels/l2norm_test.cc b/tensorflow/lite/kernels/l2norm_test.cc index 50108a5a264c3624bbd9c230f50c65f5897480bb..4cd63155b95b7dfa412a77cf85d7eb33b4a68e23 100644 --- a/tensorflow/lite/kernels/l2norm_test.cc +++ b/tensorflow/lite/kernels/l2norm_test.cc @@ -55,9 +55,10 @@ class L2NormOpModel : public SingleOpModel { return ExtractVector(output_); } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } int input() const { return input_; } @@ -100,7 +101,20 @@ TEST(L2NormOpTest, SimpleUint8Test) { m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray({58, 166, 173, 205, 83, 134})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + 
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1))); +} + +TEST(L2NormOpTest, SimpleInt8Test) { + L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE); + + m.QuantizeAndPopulate(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-70, 38, 45, 77, -45, 6})); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray( ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1))); } @@ -121,7 +135,32 @@ TEST(L2NormOpTest, MultipleBatchUint8Test) { 58, 166, 173, 205, 83, 134, // batch 2 58, 166, 173, 205, 83, 134, // batch 3 })); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + }, + 0.1))); +} + +TEST(L2NormOpTest, MultipleBatchInt8Test) { + L2NormOpModel m({3, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE); + + m.QuantizeAndPopulate(m.input(), + { + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + -70, 38, 45, 77, -45, 6, // batch 1 + -70, 38, 45, 77, -45, 6, // batch 2 + -70, 38, 45, 77, -45, 6, // batch 3 + })); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 diff --git a/tensorflow/lite/kernels/layer_norm_lstm.cc b/tensorflow/lite/kernels/layer_norm_lstm.cc deleted file mode 100644 index ce0c21dfcba770b72f144c272d7ab12b2e77e399..0000000000000000000000000000000000000000 --- a/tensorflow/lite/kernels/layer_norm_lstm.cc +++ /dev/null @@ -1,1324 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// DEPRECATED: Tensorflow Lite has implemented layer norm lstm as builtin Op and -// the implementation of layer norm lstm as custom Op in this file is -// deprecated. It is only kept for backward compatibility. -// -// Layer Normalization LSTM op that applies normalization by mean and standard -// deviation to the activation of the LSTM layers. Please see -// https://arxiv.org/abs/1607.06450 for details. -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/lite/context.h" -#include "tensorflow/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/lite/kernels/kernel_util.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace layer_norm_lstm { - -// Struct to hold Layer Norm LSTM option data. 
-struct OpData { - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - int scratch_tensor_index; -}; - -// Input Tensors of size {n_batch, n_input} -constexpr int kInputTensor = 0; - -// Input weight tensors of size: {n_cell, n_input} -constexpr int kInputToInputWeightsTensor = 1; // Optional -constexpr int kInputToForgetWeightsTensor = 2; -constexpr int kInputToCellWeightsTensor = 3; -constexpr int kInputToOutputWeightsTensor = 4; - -// Recurrent weight tensors of size {n_cell, n_output} -constexpr int kRecurrentToInputWeightsTensor = 5; // Optional -constexpr int kRecurrentToForgetWeightsTensor = 6; -constexpr int kRecurrentToCellWeightsTensor = 7; -constexpr int kRecurrentToOutputWeightsTensor = 8; - -// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. -constexpr int kCellToInputWeightsTensor = 9; // Optional -constexpr int kCellToForgetWeightsTensor = 10; // Optional -constexpr int kCellToOutputWeightsTensor = 11; // Optional - -// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix. -constexpr int kInputLayerNormWeightsTensor = 12; // Optional -constexpr int kForgetLayerNormWeightsTensor = 13; -constexpr int kCellLayerNormWeightsTensor = 14; -constexpr int kOutputLayerNormWeightsTensor = 15; - -// Gates bias tensors of size {n_cell} -constexpr int kInputGateBiasTensor = 16; // Optional -constexpr int kForgetGateBiasTensor = 17; -constexpr int kCellGateBiasTensor = 18; -constexpr int kOutputGateBiasTensor = 19; - -// Projection weight tensor of size {n_output, n_cell} -constexpr int kProjectionWeightsTensor = 20; // Optional -// Projection bias tensor of size {n_output} -constexpr int kProjectionBiasTensor = 21; // Optional - -// State tensors. -constexpr int kInputActivationStateTensor = 22; -constexpr int kInputCellStateTensor = 23; - -// Output tensor. -constexpr int kOutputTensor = 0; - -// Total number of scratch tensors for hybrid Op. 
-constexpr int kTensorsToAdd = 7; - -// Small float to avoid divergence during calculation of deviation. -const float kLayerNormEpsilon = 1e-8; - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* data = new OpData; - - // Turn custom option data into flexbuffer map format. - const uint8_t* buffer_t = reinterpret_cast(buffer); - const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); - - // Get activation function, cell_clip and proj_clip from the flexbuffer. - // TODO(b/113824099): make activation more generic. - assert(m["fused_activation_function"].ToString() == "TANH"); - data->activation = kTfLiteActTanh; - data->cell_clip = m["cell_clip"].AsFloat(); - data->proj_clip = m["proj_clip"].AsFloat(); - - // Populate scratch_tensor_index. - context->AddTensors(context, /*tensors_to_add=*/kTensorsToAdd, - &data->scratch_tensor_index); - return data; -} - -// Check that input tensor dimensions matches with each other. -TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, - TfLiteNode* node, int n_input, - int n_output, int n_cell) { - const OpData* op_data = reinterpret_cast(node->user_data); - - // Making sure clipping parameters have valid values. 
- // == 0 means no clipping - // > 0 means clipping - TF_LITE_ENSURE(context, op_data->cell_clip >= 0); - TF_LITE_ENSURE(context, op_data->proj_clip >= 0); - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const bool use_cifg = (input_to_input_weights == nullptr); - if (!use_cifg) { - TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); - TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); - } - - const TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); - TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); - - const TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); - TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); - - const TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - if (use_cifg) { - TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights, nullptr); - } else { - TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], - n_cell); - TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], - n_output); - } - - const TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], - n_cell); - 
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], - n_output); - - const TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); - TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], - n_output); - - const TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - if (cell_to_input_weights) { - TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); - TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); - } - - const TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - if (cell_to_forget_weights) { - TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); - TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); - } - - const TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - if (cell_to_output_weights) { - TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); - TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); - } - - // Making sure the peephole weights are there all or none. - const bool peephole_weights_all_or_none = - ((cell_to_input_weights != nullptr || use_cifg) && - (cell_to_forget_weights != nullptr) && - (cell_to_output_weights != nullptr)) || - ((cell_to_input_weights == nullptr) && - (cell_to_forget_weights == nullptr) && - (cell_to_output_weights == nullptr)); - TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); - - // Making sure layer norm weights are not null and have the right dimension. 
- const TfLiteTensor* input_layer_norm_weights = - GetOptionalInputTensor(context, node, kInputLayerNormWeightsTensor); - if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_layer_norm_weights, nullptr); - } else { - TF_LITE_ENSURE(context, input_layer_norm_weights != nullptr); - TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->size, 1); - TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->data[0], n_cell); - } - - const TfLiteTensor* forget_layer_norm_weights = - GetInput(context, node, kForgetLayerNormWeightsTensor); - TF_LITE_ENSURE(context, forget_layer_norm_weights != nullptr); - TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->size, 1); - TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->data[0], n_cell); - - const TfLiteTensor* cell_layer_norm_weights = - GetInput(context, node, kCellLayerNormWeightsTensor); - TF_LITE_ENSURE(context, cell_layer_norm_weights != nullptr); - TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->size, 1); - TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->data[0], n_cell); - - const TfLiteTensor* output_layer_norm_weights = - GetInput(context, node, kOutputLayerNormWeightsTensor); - TF_LITE_ENSURE(context, output_layer_norm_weights != nullptr); - TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->size, 1); - TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->data[0], n_cell); - - // Make sure the input gate bias is present only when not a CIFG-LSTM. 
- const TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); - } else { - TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); - TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); - } - - const TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); - TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); - - const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); - TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); - - const TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); - TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); - - const TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - if (projection_weights != nullptr) { - TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); - TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); - } - - const TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - if (projection_bias != nullptr) { - TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); - TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); - } - - // Making sure the projection tensors are consistent: - // 1) If projection weight is not present, then projection bias should not be - // present. - // 2) If projection weight is present, then projection bias is optional. 
- const bool projection_tensors_consistent = - ((projection_weights != nullptr) || (projection_bias == nullptr)); - TF_LITE_ENSURE(context, projection_tensors_consistent == true); - - return kTfLiteOk; -} - -// Resize the output, state tensors based on the sizes of the input tensors. -// Allocate a temporary scratch tensor. Also check that the sizes of the input -// tensors match each other. -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - OpData* op_data = reinterpret_cast(node->user_data); - TF_LITE_ENSURE_EQ(context, node->inputs->size, 24); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - - // Inferring batch size, number of outputs and number of cells from the - // input tensors. - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); - TF_LITE_ENSURE(context, input->dims->size > 1); - const int n_batch = input->dims->data[0]; - const int n_input = input->dims->data[1]; - - const TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - const int n_cell = input_to_output_weights->dims->data[0]; - TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); - - const TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], - n_cell); - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Check that input tensor dimensions matches with each other. - TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, - n_output, n_cell)); - - // Get the pointer to output, activation_state and cell_state tensors. 
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - const TfLiteTensor* activation_state = - GetInput(context, node, kInputActivationStateTensor); - const TfLiteTensor* cell_state = - GetInput(context, node, kInputCellStateTensor); - - // Check the shape of input state tensors. - // These tensor may be 1D or 2D. It's fine as long as the total size is - // correct. - TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output); - TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); - // Resize the output tensors. - TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); - output_size->data[0] = n_batch; - output_size->data[1] = n_output; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, output, output_size)); - - // The weights are of consistent type, so it suffices to check one. - const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && - input->type == kTfLiteFloat32); - - TfLiteIntArrayFree(node->temporaries); - if (is_hybrid_op) { - node->temporaries = TfLiteIntArrayCreate(7); - } else { - node->temporaries = TfLiteIntArrayCreate(1); - } - node->temporaries->data[0] = op_data->scratch_tensor_index; - - // Create a scratch buffer tensor. 
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); - scratch_buffer->type = input->type; - scratch_buffer->allocation_type = kTfLiteArenaRw; - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const bool use_cifg = (input_to_input_weights == nullptr); - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; - if (use_cifg) { - // Reserving space for Cell, Forget, Output gates - scratch_buffer_size->data[1] = n_cell * 3; - } else { - // Reserving space for Input, Cell, Forget, Output gates - scratch_buffer_size->data[1] = n_cell * 4; - } - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); - - if (is_hybrid_op) { - // Allocate temporary tensors to store quantized values of input, - // activation_state and cell_state tensors. - node->temporaries->data[1] = op_data->scratch_tensor_index + 1; - TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); - input_quantized->type = kTfLiteUInt8; - input_quantized->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { - TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, - input_quantized_size)); - } - node->temporaries->data[2] = op_data->scratch_tensor_index + 2; - TfLiteTensor* activation_state_quantized = - GetTemporary(context, node, /*index=*/2); - activation_state_quantized->type = kTfLiteUInt8; - activation_state_quantized->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqual(activation_state_quantized->dims, - activation_state->dims)) { - TfLiteIntArray* activation_state_quantized_size = - TfLiteIntArrayCopy(activation_state->dims); - TF_LITE_ENSURE_OK( - context, context->ResizeTensor(context, activation_state_quantized, - activation_state_quantized_size)); - } - 
node->temporaries->data[3] = op_data->scratch_tensor_index + 3; - TfLiteTensor* cell_state_quantized = - GetTemporary(context, node, /*index=*/3); - cell_state_quantized->type = kTfLiteUInt8; - cell_state_quantized->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { - TfLiteIntArray* cell_state_quantized_size = - TfLiteIntArrayCopy(cell_state->dims); - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, cell_state_quantized, - cell_state_quantized_size)); - } - - // Allocate temporary tensors to store scaling factors and product scaling - // factors. The latter is a convenience storage which allows to quantize - // a vector once (which produces the scaling factors) and multiply it with - // different matrices (which requires multiplying the scaling factors with - // the scaling factor of the matrix). - node->temporaries->data[4] = op_data->scratch_tensor_index + 4; - TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); - scaling_factors->type = kTfLiteFloat32; - scaling_factors->allocation_type = kTfLiteArenaRw; - int scaling_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { - TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); - scaling_factors_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, - scaling_factors_size)); - } - node->temporaries->data[5] = op_data->scratch_tensor_index + 5; - TfLiteTensor* prod_scaling_factors = - GetTemporary(context, node, /*index=*/5); - prod_scaling_factors->type = kTfLiteFloat32; - prod_scaling_factors->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1, - scaling_dims)) { - TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); - prod_scaling_factors_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, prod_scaling_factors, - 
prod_scaling_factors_size)); - } - - // Allocate a temporary tensor to store the recovered weights. Since - // this is used for diagonal matrices, only need to store n_cell values. - node->temporaries->data[6] = op_data->scratch_tensor_index + 6; - TfLiteTensor* recovered_weights = GetTemporary(context, node, /*index=*/6); - recovered_weights->type = kTfLiteFloat32; - recovered_weights->allocation_type = kTfLiteArenaRw; - int recovered_dims[1] = {n_cell}; - if (!TfLiteIntArrayEqualsArray(recovered_weights->dims, 1, - recovered_dims)) { - TfLiteIntArray* recovered_weights_size = TfLiteIntArrayCreate(1); - recovered_weights_size->data[0] = n_cell; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, recovered_weights, - recovered_weights_size)); - } - } - return kTfLiteOk; -} - -void LayerNormLstmStep( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, - const float* input_layer_norm_weight_ptr, - const float* forget_layer_norm_weight_ptr, - const float* cell_layer_norm_weight_ptr, - const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, float cell_clip, float proj_clip, - const TfLiteFusedActivation& activation, int n_batch, int n_cell, - int n_input, int n_output, float* output_state_ptr, float* cell_state_ptr, - float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, - float* 
output_gate_scratch, float* output_ptr_batch) { - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. - const bool use_cifg = (input_to_input_weights_ptr == nullptr); - const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - - // Initialize scratch buffers with 0. - if (!use_cifg) { - tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch); - } - tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch); - tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch); - tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch); - - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - output_gate_scratch, /*result_stride=*/1); - - // For each batch and cell: compute recurrent_weight * output_state. 
- if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, forget_gate_scratch, - /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, output_gate_scratch, - /*result_stride=*/1); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - tensor_utils::MeanStddevNormalization(input_gate_scratch, - input_gate_scratch, n_cell, n_batch, - kLayerNormEpsilon); - tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr, - n_cell, input_gate_scratch, - n_batch, input_gate_scratch); - tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. 
- if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - tensor_utils::MeanStddevNormalization(forget_gate_scratch, - forget_gate_scratch, n_cell, n_batch, - kLayerNormEpsilon); - tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr, - n_cell, forget_gate_scratch, - n_batch, forget_gate_scratch); - tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, - n_batch, kLayerNormEpsilon); - tensor_utils::VectorBatchVectorCwiseProduct( - cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch); - tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, - n_batch * n_cell, cell_state_ptr); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); - } - if (cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip, - cell_state_ptr); - } - - // For each batch and cell: update the output gate. 
- if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - tensor_utils::MeanStddevNormalization(output_gate_scratch, - output_gate_scratch, n_cell, n_batch, - kLayerNormEpsilon); - tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr, - n_cell, output_gate_scratch, - n_batch, output_gate_scratch); - tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, - activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, output_gate_scratch); - - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights_ptr != nullptr); - const bool use_projection_bias = (projection_bias_ptr != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, - n_batch, output_ptr_batch); - } else { - tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch, - output_ptr_batch, /*result_stride=*/1); - if (proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip, - output_ptr_batch); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_batch); - } - tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, - output_state_ptr); -} - -void LayerNormLstmStep( - const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, - float input_to_input_weights_scale, - const int8_t* input_to_forget_weights_ptr, - float 
input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, - const int8_t* input_to_output_weights_ptr, - float input_to_output_weights_scale, - const int8_t* recurrent_to_input_weights_ptr, - float recurrent_to_input_weights_scale, - const int8_t* recurrent_to_forget_weights_ptr, - float recurrent_to_forget_weights_scale, - const int8_t* recurrent_to_cell_weights_ptr, - float recurrent_to_cell_weights_scale, - const int8_t* recurrent_to_output_weights_ptr, - float recurrent_to_output_weights_scale, - const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, - const int8_t* cell_to_forget_weights_ptr, - float cell_to_forget_weights_scale, - const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, - const float* input_layer_norm_weight_ptr, - const float* forget_layer_norm_weight_ptr, - const float* cell_layer_norm_weight_ptr, - const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - float cell_clip, float proj_clip, const TfLiteFusedActivation& activation, - int n_batch, int n_cell, int n_input, int n_output, - float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, - float* output_gate_scratch, float* scaling_factors, - float* product_scaling_factors, float* recovered_weights, - int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, - int8_t* quantized_cell_state_ptr, float* output_state_ptr, - float* cell_state_ptr, float* output_ptr_batch) { - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. 
- const bool use_cifg = (input_to_input_weights_ptr == nullptr); - const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - - // Initialize scratch buffers with 0. - if (!use_cifg) { - tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch); - } - tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch); - tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch); - tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch); - - if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_input; - tensor_utils::SymmetricQuantizeFloats( - input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, - &unused_min, &unused_max, &scaling_factors[b]); - } - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_input_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, - product_scaling_factors, n_batch, forget_gate_scratch, - /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, - product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - 
product_scaling_factors[b] = - scaling_factors[b] * input_to_output_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, - product_scaling_factors, n_batch, output_gate_scratch, - /*result_stride=*/1); - } - - if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_output; - tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, - quantized_output_state_ptr + offset, - &unused_min, &unused_max, - &scaling_factors[b]); - } - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_input_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - forget_gate_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_output_weights_scale; - } - 
tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - output_gate_scratch, /*result_stride=*/1); - } - - // Save quantization and matmul computation for all zero input. - bool is_cell_state_all_zeros = - tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, - cell_to_input_weights_scale, - recovered_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_weights, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - tensor_utils::MeanStddevNormalization(input_gate_scratch, - input_gate_scratch, n_cell, n_batch, - kLayerNormEpsilon); - tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr, - n_cell, input_gate_scratch, - n_batch, input_gate_scratch); - tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. 
- if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, - cell_to_forget_weights_scale, - recovered_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_weights, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - tensor_utils::MeanStddevNormalization(forget_gate_scratch, - forget_gate_scratch, n_cell, n_batch, - kLayerNormEpsilon); - tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr, - n_cell, forget_gate_scratch, - n_batch, forget_gate_scratch); - tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, - n_batch, kLayerNormEpsilon); - tensor_utils::VectorBatchVectorCwiseProduct( - cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch); - tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, - n_batch * n_cell, cell_state_ptr); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); - } - if (cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip, - cell_state_ptr); - } - - is_cell_state_all_zeros = - tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); - // For each batch and cell: update the output gate. 
- if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, - cell_to_output_weights_scale, - recovered_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_weights, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - tensor_utils::MeanStddevNormalization(output_gate_scratch, - output_gate_scratch, n_cell, n_batch, - kLayerNormEpsilon); - tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr, - n_cell, output_gate_scratch, - n_batch, output_gate_scratch); - tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, - activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, output_gate_scratch); - - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights_ptr != nullptr); - const bool use_projection_bias = (projection_bias_ptr != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, - n_batch, output_ptr_batch); - } else { - tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); - } - if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { - // Save quantization and matmul computation for all zero input. 
- float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_cell; - tensor_utils::SymmetricQuantizeFloats( - output_gate_scratch + offset, n_cell, - quantized_cell_state_ptr + offset, &unused_min, &unused_max, - &scaling_factors[b]); - } - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * projection_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr, - product_scaling_factors, n_batch, output_ptr_batch, - /*result_stride=*/1); - } - if (proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip, - output_ptr_batch); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_batch); - } - tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, - output_state_ptr); -} - -// The LayerNormLSTM Op engine. -TfLiteStatus EvalFloat( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, - const TfLiteTensor* input_layer_norm_weights, - const TfLiteTensor* forget_layer_norm_weights, - const TfLiteTensor* cell_layer_norm_weights, - const TfLiteTensor* output_layer_norm_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - float cell_clip, float proj_clip, const 
TfLiteFusedActivation& activation, - TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state, - TfLiteTensor* cell_state, TfLiteTensor* output) { - const int n_batch = input->dims->data[0]; - const int n_input = input->dims->data[1]; - // n_cell and n_output will be the same size when there is no projection. - const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existence of only one to get the condition. - const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - const float* input_to_input_weights_ptr = - (use_cifg) ? nullptr : input_to_input_weights->data.f; - const float* recurrent_to_input_weights_ptr = - (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; - const float* input_gate_bias_ptr = - (use_cifg) ? nullptr : input_gate_bias->data.f; - const float* cell_to_input_weights_ptr = - (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; - const float* cell_to_forget_weights_ptr = - (use_peephole) ? cell_to_forget_weights->data.f : nullptr; - const float* cell_to_output_weights_ptr = - (use_peephole) ? 
cell_to_output_weights->data.f : nullptr; - const float* projection_weights_ptr = - (projection_weights == nullptr) ? nullptr : projection_weights->data.f; - const float* projection_bias_ptr = - (projection_bias == nullptr) ? nullptr : projection_bias->data.f; - const float* input_layer_norm_weight_ptr = - (input_layer_norm_weights == nullptr) ? nullptr - : input_layer_norm_weights->data.f; - - // Required tensors, pointers are non-null. - const float* input_ptr_batch = input->data.f; - const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f; - const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f; - const float* input_to_output_weights_ptr = input_to_output_weights->data.f; - const float* recurrent_to_forget_weights_ptr = - recurrent_to_forget_weights->data.f; - const float* recurrent_to_cell_weights_ptr = - recurrent_to_cell_weights->data.f; - const float* recurrent_to_output_weights_ptr = - recurrent_to_output_weights->data.f; - const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f; - const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f; - const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f; - const float* forget_gate_bias_ptr = forget_gate_bias->data.f; - const float* cell_bias_ptr = cell_bias->data.f; - const float* output_gate_bias_ptr = output_gate_bias->data.f; - - float* activation_state_ptr = activation_state->data.f; - float* cell_state_ptr = cell_state->data.f; - float* output_ptr_batch = output->data.f; - - LayerNormLstmStep( - input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr, - input_to_cell_weights_ptr, input_to_output_weights_ptr, - recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, - recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, - cell_to_input_weights_ptr, cell_to_forget_weights_ptr, - cell_to_output_weights_ptr, input_layer_norm_weight_ptr, - forget_layer_norm_weight_ptr, 
cell_layer_norm_weight_ptr, - output_layer_norm_weight_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, - cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, - projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell, - n_input, n_output, activation_state_ptr, cell_state_ptr, - input_gate_scratch, forget_gate_scratch, cell_scratch, - output_gate_scratch, output_ptr_batch); - - return kTfLiteOk; -} - -TfLiteStatus EvalHybrid( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, - const TfLiteTensor* input_layer_norm_weights, - const TfLiteTensor* forget_layer_norm_weights, - const TfLiteTensor* cell_layer_norm_weights, - const TfLiteTensor* output_layer_norm_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - float cell_clip, float proj_clip, const TfLiteFusedActivation& activation, - TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, - TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_weights, - TfLiteTensor* input_quantized, TfLiteTensor* activation_state_quantized, - TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state, - TfLiteTensor* cell_state, TfLiteTensor* output) { - const int n_batch = input->dims->data[0]; - const int n_input = input->dims->data[1]; - // n_cell and n_output will be the same size when there is no projection. 
- const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existence of only one to get the condition. - const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. 
- int8_t* input_to_input_weights_ptr = nullptr; - float input_to_input_weights_scale = 1.0f; - int8_t* recurrent_to_input_weights_ptr = nullptr; - float recurrent_to_input_weights_scale = 1.0f; - float* input_gate_bias_ptr = nullptr; - if (!use_cifg) { - input_to_input_weights_ptr = - reinterpret_cast(input_to_input_weights->data.uint8); - recurrent_to_input_weights_ptr = - reinterpret_cast(recurrent_to_input_weights->data.uint8); - input_gate_bias_ptr = input_gate_bias->data.f; - input_to_input_weights_scale = input_to_input_weights->params.scale; - recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; - } - - int8_t* cell_to_input_weights_ptr = nullptr; - int8_t* cell_to_forget_weights_ptr = nullptr; - int8_t* cell_to_output_weights_ptr = nullptr; - float cell_to_input_weights_scale = 1.0f; - float cell_to_forget_weights_scale = 1.0f; - float cell_to_output_weights_scale = 1.0f; - if (use_peephole) { - if (!use_cifg) { - cell_to_input_weights_ptr = - reinterpret_cast(cell_to_input_weights->data.uint8); - cell_to_input_weights_scale = cell_to_input_weights->params.scale; - } - cell_to_forget_weights_ptr = - reinterpret_cast(cell_to_forget_weights->data.uint8); - cell_to_output_weights_ptr = - reinterpret_cast(cell_to_output_weights->data.uint8); - cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; - cell_to_output_weights_scale = cell_to_output_weights->params.scale; - } - - const int8_t* projection_weights_ptr = - (projection_weights == nullptr) - ? nullptr - : reinterpret_cast(projection_weights->data.uint8); - const float projection_weights_scale = - (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; - const float* projection_bias_ptr = - (projection_bias == nullptr) ? nullptr : projection_bias->data.f; - const float* input_layer_norm_weight_ptr = - (input_layer_norm_weights == nullptr) ? nullptr - : input_layer_norm_weights->data.f; - - // Required tensors, pointers are non-null. 
- const float* input_ptr_batch = input->data.f; - const int8_t* input_to_forget_weights_ptr = - reinterpret_cast(input_to_forget_weights->data.uint8); - const float input_to_forget_weights_scale = - input_to_forget_weights->params.scale; - const int8_t* input_to_cell_weights_ptr = - reinterpret_cast(input_to_cell_weights->data.uint8); - const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; - const int8_t* input_to_output_weights_ptr = - reinterpret_cast(input_to_output_weights->data.uint8); - const float input_to_output_weights_scale = - input_to_output_weights->params.scale; - const int8_t* recurrent_to_forget_weights_ptr = - reinterpret_cast(recurrent_to_forget_weights->data.uint8); - const float recurrent_to_forget_weights_scale = - recurrent_to_forget_weights->params.scale; - const int8_t* recurrent_to_cell_weights_ptr = - reinterpret_cast(recurrent_to_cell_weights->data.uint8); - const float recurrent_to_cell_weights_scale = - recurrent_to_cell_weights->params.scale; - const int8_t* recurrent_to_output_weights_ptr = - reinterpret_cast(recurrent_to_output_weights->data.uint8); - const float recurrent_to_output_weights_scale = - recurrent_to_output_weights->params.scale; - const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f; - const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f; - const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f; - const float* forget_gate_bias_ptr = forget_gate_bias->data.f; - const float* cell_bias_ptr = cell_bias->data.f; - const float* output_gate_bias_ptr = output_gate_bias->data.f; - - float* activation_state_ptr = activation_state->data.f; - float* cell_state_ptr = cell_state->data.f; - float* output_ptr_batch = output->data.f; - - // Temporary storage for quantized values and scaling factors. 
- int8_t* quantized_input_ptr = - reinterpret_cast(input_quantized->data.uint8); - int8_t* quantized_activation_state_ptr = - reinterpret_cast(activation_state_quantized->data.uint8); - int8_t* quantized_cell_state_ptr = - reinterpret_cast(cell_state_quantized->data.uint8); - float* scaling_factors_ptr = scaling_factors->data.f; - float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; - float* recovered_weights_ptr = recovered_weights->data.f; - - LayerNormLstmStep( - input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, - input_to_forget_weights_ptr, input_to_forget_weights_scale, - input_to_cell_weights_ptr, input_to_cell_weights_scale, - input_to_output_weights_ptr, input_to_output_weights_scale, - recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, - recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, - recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, - recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, - cell_to_input_weights_ptr, cell_to_input_weights_scale, - cell_to_forget_weights_ptr, cell_to_forget_weights_scale, - cell_to_output_weights_ptr, cell_to_output_weights_scale, - input_layer_norm_weight_ptr, forget_layer_norm_weight_ptr, - cell_layer_norm_weight_ptr, output_layer_norm_weight_ptr, - input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, - output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, - projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell, - n_input, n_output, input_gate_scratch, forget_gate_scratch, cell_scratch, - output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, - recovered_weights_ptr, quantized_input_ptr, - quantized_activation_state_ptr, quantized_cell_state_ptr, - activation_state_ptr, cell_state_ptr, output_ptr_batch); - - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const OpData* op_data = reinterpret_cast(node->user_data); - - const 
TfLiteTensor* input = GetInput(context, node, kInputTensor); - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - const TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - const TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - - const TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - const TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - const TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - const TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - - const TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - const TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - const TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - - const TfLiteTensor* input_layer_norm_weights = - GetOptionalInputTensor(context, node, kInputLayerNormWeightsTensor); - const TfLiteTensor* forget_layer_norm_weights = - GetInput(context, node, kForgetLayerNormWeightsTensor); - const TfLiteTensor* cell_layer_norm_weights = - GetInput(context, node, kCellLayerNormWeightsTensor); - const TfLiteTensor* output_layer_norm_weights = - GetInput(context, node, kOutputLayerNormWeightsTensor); - - const TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - const TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - const TfLiteTensor* cell_bias = GetInput(context, node, 
kCellGateBiasTensor); - const TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - - const TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - const TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - - // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); - - TfLiteTensor* activation_state = - &context->tensors[node->inputs->data[kInputActivationStateTensor]]; - TfLiteTensor* cell_state = - &context->tensors[node->inputs->data[kInputCellStateTensor]]; - - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - switch (input_to_output_weights->type) { - case kTfLiteFloat32: { - return EvalFloat(input, input_to_input_weights, input_to_forget_weights, - input_to_cell_weights, input_to_output_weights, - recurrent_to_input_weights, recurrent_to_forget_weights, - recurrent_to_cell_weights, recurrent_to_output_weights, - cell_to_input_weights, cell_to_forget_weights, - cell_to_output_weights, input_layer_norm_weights, - forget_layer_norm_weights, cell_layer_norm_weights, - output_layer_norm_weights, input_gate_bias, - forget_gate_bias, cell_bias, output_gate_bias, - projection_weights, projection_bias, op_data->cell_clip, - op_data->proj_clip, op_data->activation, scratch_buffer, - activation_state, cell_state, output); - } - case kTfLiteUInt8: { - TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); - TfLiteTensor* activation_state_quantized = - GetTemporary(context, node, /*index=*/2); - TfLiteTensor* cell_state_quantized = - GetTemporary(context, node, /*index=*/3); - TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); - TfLiteTensor* prod_scaling_factors = - GetTemporary(context, node, /*index=*/5); - TfLiteTensor* recovered_weights = - GetTemporary(context, node, /*index=*/6); - return EvalHybrid( 
- input, input_to_input_weights, input_to_forget_weights, - input_to_cell_weights, input_to_output_weights, - recurrent_to_input_weights, recurrent_to_forget_weights, - recurrent_to_cell_weights, recurrent_to_output_weights, - cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, - input_layer_norm_weights, forget_layer_norm_weights, - cell_layer_norm_weights, output_layer_norm_weights, input_gate_bias, - forget_gate_bias, cell_bias, output_gate_bias, projection_weights, - projection_bias, op_data->cell_clip, op_data->proj_clip, - op_data->activation, scratch_buffer, scaling_factors, - prod_scaling_factors, recovered_weights, input_quantized, - activation_state_quantized, cell_state_quantized, activation_state, - cell_state, output); - } - default: - context->ReportError(context, "Type %d is not currently supported.", - input_to_output_weights->type); - return kTfLiteError; - } - return kTfLiteOk; -} - -void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); -} - -} // namespace layer_norm_lstm - -TfLiteRegistration* Register_LAYER_NORM_LSTM() { - static TfLiteRegistration r = {layer_norm_lstm::Init, layer_norm_lstm::Free, - layer_norm_lstm::Prepare, - layer_norm_lstm::Eval}; - return &r; -} - -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/lite/kernels/layer_norm_lstm_test.cc deleted file mode 100644 index 5aed818f2407a96acb8893654971fc5bb91a81ed..0000000000000000000000000000000000000000 --- a/tensorflow/lite/kernels/layer_norm_lstm_test.cc +++ /dev/null @@ -1,885 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Unit test for TFLite Layer Norm LSTM op. - -#include -#include - -#include -#include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/kernels/test_util.h" -#include "tensorflow/lite/model.h" - -namespace tflite { -namespace ops { -namespace custom { - -TfLiteRegistration* Register_LAYER_NORM_LSTM(); - -namespace { - -using ::testing::ElementsAreArray; - -class LayerNormLSTMOpModel : public SingleOpModel { - public: - LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, - bool use_cifg, bool use_peephole, - bool use_projection_weights, bool use_projection_bias, - float cell_clip, float proj_clip, - const std::vector>& input_shapes, - const TensorType& weight_type = TensorType_FLOAT32) - : n_batch_(n_batch), - n_input_(n_input), - n_cell_(n_cell), - n_output_(n_output) { - input_ = AddInput(TensorType_FLOAT32); - - if (use_cifg) { - input_to_input_weights_ = AddNullInput(); - } else { - input_to_input_weights_ = AddInput(weight_type); - } - - input_to_forget_weights_ = AddInput(weight_type); - input_to_cell_weights_ = AddInput(weight_type); - input_to_output_weights_ = AddInput(weight_type); - - if (use_cifg) { - recurrent_to_input_weights_ = AddNullInput(); - } else { - recurrent_to_input_weights_ = AddInput(weight_type); - } - - recurrent_to_forget_weights_ = AddInput(weight_type); - recurrent_to_cell_weights_ = AddInput(weight_type); - 
recurrent_to_output_weights_ = AddInput(weight_type); - - if (use_peephole) { - if (use_cifg) { - cell_to_input_weights_ = AddNullInput(); - } else { - cell_to_input_weights_ = AddInput(weight_type); - } - cell_to_forget_weights_ = AddInput(weight_type); - cell_to_output_weights_ = AddInput(weight_type); - } else { - cell_to_input_weights_ = AddNullInput(); - cell_to_forget_weights_ = AddNullInput(); - cell_to_output_weights_ = AddNullInput(); - } - - if (use_cifg) { - input_layer_norm_weights_ = AddNullInput(); - } else { - input_layer_norm_weights_ = AddInput(TensorType_FLOAT32); - } - forget_layer_norm_weights_ = AddInput(TensorType_FLOAT32); - cell_layer_norm_weights_ = AddInput(TensorType_FLOAT32); - output_layer_norm_weights_ = AddInput(TensorType_FLOAT32); - - if (use_cifg) { - input_gate_bias_ = AddNullInput(); - } else { - input_gate_bias_ = AddInput(TensorType_FLOAT32); - } - forget_gate_bias_ = AddInput(TensorType_FLOAT32); - cell_bias_ = AddInput(TensorType_FLOAT32); - output_gate_bias_ = AddInput(TensorType_FLOAT32); - - if (use_projection_weights) { - projection_weights_ = AddInput(weight_type); - if (use_projection_bias) { - projection_bias_ = AddInput(TensorType_FLOAT32); - } else { - projection_bias_ = AddNullInput(); - } - } else { - projection_weights_ = AddNullInput(); - projection_bias_ = AddNullInput(); - } - - // Adding the 2 state tensors. - output_state_ = - AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); - cell_state_ = - AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); - - output_ = AddOutput(TensorType_FLOAT32); - - // Set up and pass in custom options using flexbuffer. 
- flexbuffers::Builder fbb; - fbb.Map([&]() { - fbb.Int("cell_clip", cell_clip); - fbb.Int("proj_clip", proj_clip); - fbb.String("fused_activation_function", "TANH"); - }); - fbb.Finish(); - SetCustomOp("LAYER_NORM_LSTM", fbb.GetBuffer(), Register_LAYER_NORM_LSTM); - BuildInterpreter(input_shapes); - } - - void SetInputToInputWeights(const std::vector& f) { - PopulateTensor(input_to_input_weights_, f); - } - - void SetInputToForgetWeights(const std::vector& f) { - PopulateTensor(input_to_forget_weights_, f); - } - - void SetInputToCellWeights(const std::vector& f) { - PopulateTensor(input_to_cell_weights_, f); - } - - void SetInputToOutputWeights(const std::vector& f) { - PopulateTensor(input_to_output_weights_, f); - } - - void SetRecurrentToInputWeights(const std::vector& f) { - PopulateTensor(recurrent_to_input_weights_, f); - } - - void SetRecurrentToForgetWeights(const std::vector& f) { - PopulateTensor(recurrent_to_forget_weights_, f); - } - - void SetRecurrentToCellWeights(const std::vector& f) { - PopulateTensor(recurrent_to_cell_weights_, f); - } - - void SetRecurrentToOutputWeights(const std::vector& f) { - PopulateTensor(recurrent_to_output_weights_, f); - } - - void SetCellToInputWeights(const std::vector& f) { - PopulateTensor(cell_to_input_weights_, f); - } - - void SetCellToForgetWeights(const std::vector& f) { - PopulateTensor(cell_to_forget_weights_, f); - } - - void SetCellToOutputWeights(const std::vector& f) { - PopulateTensor(cell_to_output_weights_, f); - } - - void SetInputLayerNormWeights(const std::vector& f) { - PopulateTensor(input_layer_norm_weights_, f); - } - - void SetForgetLayerNormWeights(const std::vector& f) { - PopulateTensor(forget_layer_norm_weights_, f); - } - - void SetCellLayerNormWeights(const std::vector& f) { - PopulateTensor(cell_layer_norm_weights_, f); - } - - void SetOutputLayerNormWeights(const std::vector& f) { - PopulateTensor(output_layer_norm_weights_, f); - } - - void SetInputGateBias(const std::vector& f) { - 
PopulateTensor(input_gate_bias_, f); - } - - void SetForgetGateBias(const std::vector& f) { - PopulateTensor(forget_gate_bias_, f); - } - - void SetCellBias(const std::vector& f) { - PopulateTensor(cell_bias_, f); - } - - void SetOutputGateBias(const std::vector& f) { - PopulateTensor(output_gate_bias_, f); - } - - void SetProjectionWeights(const std::vector& f) { - PopulateTensor(projection_weights_, f); - } - - void SetProjectionBias(const std::vector& f) { - PopulateTensor(projection_bias_, f); - } - - void SetInput(int offset, const float* begin, const float* end) { - PopulateTensor(input_, offset, const_cast(begin), - const_cast(end)); - } - - std::vector GetOutput() { return ExtractVector(output_); } - - int num_inputs() { return n_input_; } - int num_outputs() { return n_output_; } - int num_cells() { return n_cell_; } - int num_batches() { return n_batch_; } - - protected: - int input_; - int input_to_input_weights_; - int input_to_forget_weights_; - int input_to_cell_weights_; - int input_to_output_weights_; - - int recurrent_to_input_weights_; - int recurrent_to_forget_weights_; - int recurrent_to_cell_weights_; - int recurrent_to_output_weights_; - - int cell_to_input_weights_; - int cell_to_forget_weights_; - int cell_to_output_weights_; - - int input_layer_norm_weights_; - int forget_layer_norm_weights_; - int cell_layer_norm_weights_; - int output_layer_norm_weights_; - - int input_gate_bias_; - int forget_gate_bias_; - int cell_bias_; - int output_gate_bias_; - - int projection_weights_; - int projection_bias_; - - int output_state_; - int cell_state_; - - int output_; - - int n_batch_; - int n_input_; - int n_cell_; - int n_output_; -}; - -class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel { - public: - HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, - bool use_cifg, bool use_peephole, - bool use_projection_weights, - bool use_projection_bias, float cell_clip, - float proj_clip, - const std::vector>& 
input_shapes) - : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, - use_peephole, use_projection_weights, - use_projection_bias, cell_clip, proj_clip, - input_shapes, TensorType_UINT8) {} - - void SetInputToInputWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(input_to_input_weights_, f); - } - - void SetInputToForgetWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); - } - - void SetInputToCellWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); - } - - void SetInputToOutputWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(input_to_output_weights_, f); - } - - void SetRecurrentToInputWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); - } - - void SetRecurrentToForgetWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); - } - - void SetRecurrentToCellWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); - } - - void SetRecurrentToOutputWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); - } - - void SetCellToInputWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); - } - - void SetCellToForgetWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); - } - - void SetCellToOutputWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); - } - - void SetInputLayerNormWeights(const std::vector& f) { - PopulateTensor(input_layer_norm_weights_, f); - } - - void SetForgetLayerNormWeights(const std::vector& f) { - PopulateTensor(forget_layer_norm_weights_, f); - } - - void SetCellLayerNormWeights(const std::vector& f) { - PopulateTensor(cell_layer_norm_weights_, f); - } - - void SetOutputLayerNormWeights(const std::vector& f) { - PopulateTensor(output_layer_norm_weights_, f); 
- } - - void SetProjectionWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(projection_weights_, f); - } -}; - -class BaseLayerNormLstmTest : public ::testing::Test { - protected: - // Weights of the Layer Norm LSTM model. Some are optional. - std::vector input_to_input_weights_; - std::vector input_to_cell_weights_; - std::vector input_to_forget_weights_; - std::vector input_to_output_weights_; - std::vector input_gate_bias_; - std::vector cell_gate_bias_; - std::vector forget_gate_bias_; - std::vector output_gate_bias_; - std::vector recurrent_to_input_weights_; - std::vector recurrent_to_cell_weights_; - std::vector recurrent_to_forget_weights_; - std::vector recurrent_to_output_weights_; - std::vector cell_to_input_weights_; - std::vector cell_to_forget_weights_; - std::vector cell_to_output_weights_; - std::vector input_layer_norm_weights_; - std::vector forget_layer_norm_weights_; - std::vector cell_layer_norm_weights_; - std::vector output_layer_norm_weights_; - std::vector projection_weights_; - - // Layer Norm LSTM input is stored as num_batch x num_inputs vector. - std::vector> layer_norm_lstm_input_; - - // Compares output up to tolerance to the result of the layer_norm_lstm given - // the input. 
- void VerifyGoldens(const std::vector>& input, - const std::vector>& output, - LayerNormLSTMOpModel* layer_norm_lstm, - float tolerance = 1e-5) { - const int num_batches = input.size(); - EXPECT_GT(num_batches, 0); - const int num_inputs = layer_norm_lstm->num_inputs(); - EXPECT_GT(num_inputs, 0); - const int input_sequence_size = input[0].size() / num_inputs; - EXPECT_GT(input_sequence_size, 0); - for (int i = 0; i < input_sequence_size; ++i) { - for (int b = 0; b < num_batches; ++b) { - const float* batch_start = input[b].data() + i * num_inputs; - const float* batch_end = batch_start + num_inputs; - - layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(), - batch_start, batch_end); - } - - layer_norm_lstm->Invoke(); - - const int num_outputs = layer_norm_lstm->num_outputs(); - std::vector expected; - for (int b = 0; b < num_batches; ++b) { - const float* golden_start_batch = output[b].data() + i * num_outputs; - const float* golden_end_batch = golden_start_batch + num_outputs; - expected.insert(expected.end(), golden_start_batch, golden_end_batch); - } - EXPECT_THAT(layer_norm_lstm->GetOutput(), - ElementsAreArray(ArrayFloatNear(expected, tolerance))); - } - } -}; - -class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest - : public BaseLayerNormLstmTest { - void SetUp() override { - input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, - 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5, - -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}; - - input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, - -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, - -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; - - input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, - -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, - -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; - - input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, - -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, - -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; - - input_gate_bias_ = {0.03, 0.15, 0.22, 0.38}; - - forget_gate_bias_ = {0.1, -0.3, 
-0.2, 0.1}; - - cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; - - output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; - - recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, - -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}; - - recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, - -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; - - recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, - 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; - - recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, - -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; - - cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15}; - - cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; - - cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; - - input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5}; - forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3}; - cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8}; - output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5}; - - projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, - 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; - - layer_norm_lstm_input_ = { - {// Batch0: 3 (input_sequence_size) * 5 (n_input) - 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0 - 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1 - 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2 - - {// Batch1: 3 (input_sequence_size) * 5 (n_input) - 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0 - 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1 - 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2 - }; - } -}; - -TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, - LayerNormLstmBlackBoxTest) { - const int n_batch = 2; - const int n_input = 5; - const int n_cell = 4; - const int n_output = 3; - const float ceil_clip = 0.0; - const float proj_clip = 0.0; - - LayerNormLSTMOpModel layer_norm_lstm( - n_batch, n_input, n_cell, n_output, - /*use_cifg=*/false, /*use_peephole=*/true, - /*use_projection_weights=*/true, - /*use_projection_bias=*/false, ceil_clip, proj_clip, - { - {n_batch, n_input}, // input tensor - - {n_cell, n_input}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor 
- {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor - - {n_cell}, // cell_to_input_weight tensor - {n_cell}, // cell_to_forget_weight tensor - {n_cell}, // cell_to_output_weight tensor - - {n_cell}, // input_layer_norm_weight tensor - {n_cell}, // forget_layer_norm_weight tensor - {n_cell}, // cell_layer_norm_weight tensor - {n_cell}, // output_layer_norm_weight tensor - - {n_cell}, // input_gate_bias tensor - {n_cell}, // forget_gate_bias tensor - {n_cell}, // cell_bias tensor - {n_cell}, // output_gate_bias tensor - - {n_output, n_cell}, // projection_weight tensor - {0}, // projection_bias tensor - }); - - layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_); - layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); - layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); - layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); - - layer_norm_lstm.SetInputGateBias(input_gate_bias_); - layer_norm_lstm.SetCellBias(cell_gate_bias_); - layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); - layer_norm_lstm.SetOutputGateBias(output_gate_bias_); - - layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); - layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); - layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); - layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - - layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); - layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); - layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); - - layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_); - 
layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_); - layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_); - layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_); - - layer_norm_lstm.SetProjectionWeights(projection_weights_); - - // Verify the final output. - const std::vector> layer_norm_lstm_golden_output = { - { - // Batch0: 3 (input_sequence_size) * 3 (n_output) - 0.0244077, 0.128027, -0.00170918, // seq 0 - 0.0137642, 0.140751, 0.0395835, // seq 1 - -0.00459231, 0.155278, 0.0837377, // seq 2 - }, - { - // Batch1: 3 (input_sequence_size) * 3 (n_output) - -0.00692428, 0.0848741, 0.063445, // seq 0 - -0.00403912, 0.139963, 0.072681, // seq 1 - 0.00752706, 0.161903, 0.0561371, // seq 2 - }}; - - VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, - &layer_norm_lstm); -} - -TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, - HybridLayerNormLstmBlackBoxTest) { - const int n_batch = 2; - const int n_input = 5; - const int n_cell = 4; - const int n_output = 3; - const float ceil_clip = 0.0; - const float proj_clip = 0.0; - - HybridLayerNormLSTMOpModel layer_norm_lstm( - n_batch, n_input, n_cell, n_output, - /*use_cifg=*/false, /*use_peephole=*/true, - /*use_projection_weights=*/true, - /*use_projection_bias=*/false, ceil_clip, proj_clip, - { - {n_batch, n_input}, // input tensor - - {n_cell, n_input}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor - - {n_cell}, // cell_to_input_weight tensor - {n_cell}, // cell_to_forget_weight tensor - {n_cell}, // cell_to_output_weight tensor - - {n_cell}, // input_layer_norm_weight 
tensor - {n_cell}, // forget_layer_norm_weight tensor - {n_cell}, // cell_layer_norm_weight tensor - {n_cell}, // output_layer_norm_weight tensor - - {n_cell}, // input_gate_bias tensor - {n_cell}, // forget_gate_bias tensor - {n_cell}, // cell_bias tensor - {n_cell}, // output_gate_bias tensor - - {n_output, n_cell}, // projection_weight tensor - {0}, // projection_bias tensor - }); - - layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_); - layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); - layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); - layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); - - layer_norm_lstm.SetInputGateBias(input_gate_bias_); - layer_norm_lstm.SetCellBias(cell_gate_bias_); - layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); - layer_norm_lstm.SetOutputGateBias(output_gate_bias_); - - layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); - layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); - layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); - layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - - layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); - layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); - layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); - - layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_); - layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_); - layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_); - layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_); - - layer_norm_lstm.SetProjectionWeights(projection_weights_); - - const std::vector> layer_norm_lstm_golden_output = { - { - // Batch0: 3 (input_sequence_size) * 3 (n_output) - 0.0244576, 0.127847, -0.00181765, // seq 0 - 0.0137518, 0.140892, 0.0402234, // seq 1 - -0.0048839, 0.155096, 0.0840309, // seq 2 - }, - { - // Batch1: 3 
(input_sequence_size) * 3 (n_output) - -0.00728636, 0.0843957, 0.0634786, // seq 0 - -0.00448382, 0.139278, 0.0737372, // seq 1 - 0.00734616, 0.161793, 0.0560238, // seq 2 - }}; - - VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, - &layer_norm_lstm); -} - -class CifgPeepholeProjectionNoClippingLayerNormLstmTest - : public BaseLayerNormLstmTest { - void SetUp() override { - input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, - -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, - -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; - input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, - -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, - -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; - input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, - -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, - -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; - - forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; - cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; - output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; - - recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, - -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; - recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, - 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; - recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, - -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; - - cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; - cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; - - forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3}; - cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8}; - output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5}; - projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, - 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; - - layer_norm_lstm_input_ = { - {// Batch0: 3 (input_sequence_size) * 5 (n_input) - 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0 - 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1 - 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2 - - {// Batch1: 3 (input_sequence_size) * 5 (n_input) - 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0 - 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1 - 0.6, 0.9, 0.2, 
0.5, 0.7}, // seq 2 - }; - } -}; - -TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, - LayerNormLstmBlackBoxTest) { - const int n_batch = 2; - const int n_input = 5; - const int n_cell = 4; - const int n_output = 3; - const float ceil_clip = 0.0; - const float proj_clip = 0.0; - - LayerNormLSTMOpModel layer_norm_lstm( - n_batch, n_input, n_cell, n_output, - /*use_cifg=*/true, /*use_peephole=*/true, - /*use_projection_weights=*/true, - /*use_projection_bias=*/false, ceil_clip, proj_clip, - { - {n_batch, n_input}, // input tensor - - {0, 0}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - - {0, 0}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor - - {0}, // cell_to_input_weight tensor - {n_cell}, // cell_to_forget_weight tensor - {n_cell}, // cell_to_output_weight tensor - - {0}, // input_layer_norm_weight tensor - {n_cell}, // forget_layer_norm_weight tensor - {n_cell}, // cell_layer_norm_weight tensor - {n_cell}, // output_layer_norm_weight tensor - - {0}, // input_gate_bias tensor - {n_cell}, // forget_gate_bias tensor - {n_cell}, // cell_bias tensor - {n_cell}, // output_gate_bias tensor - - {n_output, n_cell}, // projection_weight tensor - {0}, // projection_bias tensor - }); - - layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); - layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); - layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); - - layer_norm_lstm.SetCellBias(cell_gate_bias_); - layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); - layer_norm_lstm.SetOutputGateBias(output_gate_bias_); - - layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); - 
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); - layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - - layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); - layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); - - layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_); - layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_); - layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_); - - layer_norm_lstm.SetProjectionWeights(projection_weights_); - - // Verify the final output. - const std::vector> layer_norm_lstm_golden_output = { - { - // Batch0: 3 (input_sequence_size) * 3 (n_output) - 0.02129706, 0.140816242, 0.0112733059, // seq 0 - 0.0132302344, 0.152308047, 0.0346313119, // seq 1 - -0.0123688057, 0.165790111, 0.0893077999, // seq 2 - }, - { - // Batch1: 3 (input_sequence_size) * 3 (n_output) - -0.0226350538, 0.0916948169, 0.0769175813, // seq 0 - -0.0269966982, 0.149707705, 0.094149217, // seq 1 - -0.0103429332, 0.173016444, 0.0720508844, // seq 2 - }}; - - VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, - &layer_norm_lstm); -} - -TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, - HybridLayerNormLstmBlackBoxTest) { - const int n_batch = 2; - const int n_input = 5; - const int n_cell = 4; - const int n_output = 3; - const float ceil_clip = 0.0; - const float proj_clip = 0.0; - - HybridLayerNormLSTMOpModel layer_norm_lstm( - n_batch, n_input, n_cell, n_output, - /*use_cifg=*/true, /*use_peephole=*/true, - /*use_projection_weights=*/true, - /*use_projection_bias=*/false, ceil_clip, proj_clip, - { - {n_batch, n_input}, // input tensor - - {0, 0}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - - {0, 0}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // 
recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor - - {0}, // cell_to_input_weight tensor - {n_cell}, // cell_to_forget_weight tensor - {n_cell}, // cell_to_output_weight tensor - - {0}, // input_layer_norm_weight tensor - {n_cell}, // forget_layer_norm_weight tensor - {n_cell}, // cell_layer_norm_weight tensor - {n_cell}, // output_layer_norm_weight tensor - - {0}, // input_gate_bias tensor - {n_cell}, // forget_gate_bias tensor - {n_cell}, // cell_bias tensor - {n_cell}, // output_gate_bias tensor - - {n_output, n_cell}, // projection_weight tensor - {0}, // projection_bias tensor - }); - - layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); - layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); - layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); - - layer_norm_lstm.SetCellBias(cell_gate_bias_); - layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); - layer_norm_lstm.SetOutputGateBias(output_gate_bias_); - - layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); - layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); - layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - - layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); - layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); - - layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_); - layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_); - layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_); - - layer_norm_lstm.SetProjectionWeights(projection_weights_); - - // Verify the final output. 
- const std::vector> layer_norm_lstm_golden_output = { - { - // Batch0: 3 (input_sequence_size) * 3 (n_output) - 0.0212250091, 0.140474007, 0.0115012666, // seq 0 - 0.0130806509, 0.152660668, 0.0347516984, // seq 1 - -0.0124010444, 0.166042402, 0.0898982584, // seq 2 - }, - { - // Batch1: 3 (input_sequence_size) * 3 (n_output) - -0.0228835996, 0.0917588323, 0.0778886303, // seq 0 - -0.0275101066, 0.148769245, 0.0938384682, // seq 1 - -0.0103605557, 0.172605693, 0.0728750974, // seq 2 - }}; - - VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, - &layer_norm_lstm); -} - -} // namespace -} // namespace custom -} // namespace ops -} // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index 470c74d207d51688c3c48de0fc8bdecda43097a7..ea22ed56941cd4ca4e625e6a9e6be5f74f69c996 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -840,6 +840,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { return full::Init(context, buffer, length); case kTfLiteLSTMBasicKernel: return basic::Init(context, buffer, length); + default: + return nullptr; } } void Free(TfLiteContext* context, void* buffer) { @@ -855,6 +857,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return full::Prepare(context, node); case kTfLiteLSTMBasicKernel: return basic::Prepare(context, node); + default: + return kTfLiteError; } } @@ -865,6 +869,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return full::Eval(context, node); case kTfLiteLSTMBasicKernel: return basic::Eval(context, node); + default: + return kTfLiteError; } } diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 244cfae4a20b93b32022bee412f241397df53c49..045f41bd91eccc24c54726d67d295b85c90185ec 100644 --- 
a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -119,7 +119,7 @@ inline void LstmStepWithAuxInput( float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch) { // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. + // check the existence of only one to the get the condition. const bool use_cifg = (input_to_input_weights_ptr == nullptr); const bool use_peephole = (cell_to_output_weights_ptr != nullptr); const bool is_layer_norm_lstm = @@ -473,7 +473,7 @@ inline void LstmStepWithAuxInput( int8_t* quantized_cell_state_ptr, float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch) { // Since we have already checked that weights are all there or none, we - // can check the existense of only one to the get the condition. + // can check the existence of only one to the get the condition. const bool use_cifg = (input_to_input_weights_ptr == nullptr); const bool use_peephole = (cell_to_output_weights_ptr != nullptr); const bool is_layer_norm_lstm = @@ -913,7 +913,7 @@ TfLiteStatus EvalFloat( const int n_output = recurrent_to_output_weights->dims->data[1]; // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. + // check the existence of only one to the get the condition. 
const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr); diff --git a/tensorflow/lite/kernels/maximum_minimum.cc b/tensorflow/lite/kernels/maximum_minimum.cc index 3bcaabf675eba4f528fe73b01610d915e7780f85..6ae836bc3fba78f0a1941fe6f9df83fbd5e474c2 100644 --- a/tensorflow/lite/kernels/maximum_minimum.cc +++ b/tensorflow/lite/kernels/maximum_minimum.cc @@ -26,7 +26,7 @@ namespace ops { namespace builtin { namespace maximum_minimum { -// This file has a reference implemenation of TFMaximum/TFMinimum. +// This file has a reference implementation of TFMaximum/TFMinimum. enum KernelType { kReference, }; @@ -108,6 +108,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: TFLiteOperation(context, node, op_context); break; + case kTfLiteInt8: + TFLiteOperation(context, node, op_context); + break; case kTfLiteInt32: TFLiteOperation(context, node, op_context); break; diff --git a/tensorflow/lite/kernels/maximum_minimum_test.cc b/tensorflow/lite/kernels/maximum_minimum_test.cc index acb74e09d3fb47c33c6c146af4d0b1b1030491be..6567c8f3611204af3bdeecbdb11a07f6f16be908 100644 --- a/tensorflow/lite/kernels/maximum_minimum_test.cc +++ b/tensorflow/lite/kernels/maximum_minimum_test.cc @@ -112,6 +112,17 @@ TEST(MaxMinOpTest, Uint8Test) { {0, 0, 1, 11, 2, 1}); } +TEST(MaxMinOpTest, Int8Test) { + std::initializer_list data1 = {1, 0, 2, 11, 2, 23}; + std::initializer_list data2 = {0, 0, 1, 12, 123, 1}; + TestModel(BuiltinOperator_MAXIMUM, {TensorType_INT8, {3, 1, 2}}, + {TensorType_INT8, {3, 1, 2}}, {TensorType_INT8, {3, 1, 2}}, + data1, data2, {1, 0, 2, 12, 123, 23}); + TestModel(BuiltinOperator_MINIMUM, {TensorType_INT8, {3, 1, 2}}, + {TensorType_INT8, {3, 1, 2}}, {TensorType_INT8, {3, 1, 2}}, + data1, data2, {0, 0, 1, 11, 2, 1}); +} + TEST(MaximumOpTest, FloatWithBroadcastTest) { std::initializer_list data1 = {1.0, 
0.0, -1.0, -2.0, -1.44, 11.0}; std::initializer_list data2 = {0.5, 2.0}; diff --git a/tensorflow/lite/kernels/mirror_pad.cc b/tensorflow/lite/kernels/mirror_pad.cc index e74e47f7a37b0f449fb2a63237e95066bb452de6..f54d925d777380b154a0041452872778ba314999 100644 --- a/tensorflow/lite/kernels/mirror_pad.cc +++ b/tensorflow/lite/kernels/mirror_pad.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" @@ -21,6 +22,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" @@ -30,118 +32,37 @@ namespace builtin { namespace mirror_pad { namespace { -// Simple class that represents a mirror padded tensor - which is the output -// from the Op. -struct PaddedTensor { - // If not null that means this is a scalar value. - // Note: This is not owned by default. It will point to the value - // in the input tensor. - const void* value = nullptr; - // If this tensor is not one value, then this vector will have - // all the tensors that belongs to this tensor. - // Pointers are owned. - std::vector> values; - // Pointers to PaddedTensors that are padded on the left of the current - // tensor. - std::vector left_pad_ptrs; - // Pointers to PaddedTensors that are padded on the right of the current - // tensor. - std::vector right_pad_ptrs; - - // Returns mutable pointer to the tensor identified by 'indices'. 
- PaddedTensor* GetMutable(const std::vector& indices) { - auto* result = this; - for (int i = 0; i < indices.size(); ++i) { - if (indices[i] >= result->values.size()) { - return nullptr; - } - result = result->values[indices[i]].get(); - if (result == nullptr) break; - } - return result; - } -}; - -// Util method to initialize the memory of the padded tensor. -void InitializeTensorMemory(const TfLiteIntArray* const dims, int dim_index, - int dims_size, PaddedTensor* padded_tensor) { - if (dim_index >= dims_size) { - return; - } - padded_tensor->values.reserve(dims->data[dim_index]); - for (int i = 0; i < dims->data[dim_index]; ++i) { - padded_tensor->values.emplace_back(new PaddedTensor()); - InitializeTensorMemory(dims, dim_index + 1, dims_size, - padded_tensor->values.back().get()); - } -} - -// Returns pointer to the value at the specified index in 'data'. -inline const void* GetValuePointerAtIndex(const void* data, int index, - const TfLiteType data_type) { - switch (data_type) { - case kTfLiteFloat32: - return static_cast(data) + index; - case kTfLiteInt32: - return static_cast(data) + index; - case kTfLiteUInt8: - return static_cast(data) + index; - case kTfLiteInt64: - return static_cast(data) + index; - case kTfLiteBool: - return static_cast(data) + index; - case kTfLiteInt16: - return static_cast(data) + index; - case kTfLiteInt8: - return static_cast(data) + index; - // Unsupported types ? - default: - return nullptr; - } - return nullptr; -} +// Nil value for paddingMode/offset. +const int kUnsetOffset = -1; -// Util method that increment index in the N-d array. 
-void IncrementTensorIndex(const TfLiteIntArray* dims, - std::vector* tensor_index_ptr) { - int dimension_index = dims->size - 1; - auto& tensor_index = *tensor_index_ptr; - tensor_index[dimension_index]++; - while (dimension_index >= 0 && - tensor_index[dimension_index] == dims->data[dimension_index]) { - tensor_index[dimension_index] = 0; - dimension_index--; - if (dimension_index >= 0) tensor_index[dimension_index]++; - } -} - -// Fills the 'padded_tensor' with data from 'input_tensor'. -TfLiteStatus InitFromInputTensor(const TfLiteTensor* input_tensor, - PaddedTensor* padded_tensor) { - const auto* dims = input_tensor->dims; - const auto data_type = input_tensor->type; - const void* data = static_cast(input_tensor->data.raw_const); - // Either invalid input or unsupported type.+ - if (data == nullptr) { - return kTfLiteError; - } - // Index of current processing tensor. - std::vector tensor_index(dims->size, 0); - int flat_index = 0; - const int num_elements = NumElements(input_tensor); - while (flat_index < num_elements) { - auto* tensor = padded_tensor->GetMutable(tensor_index); - if (tensor == nullptr) { - return kTfLiteError; - } - tensor->value = GetValuePointerAtIndex(data, flat_index, data_type); - IncrementTensorIndex(dims, &tensor_index); - ++flat_index; - } +// Wrapper for data used by the op. +struct OpData { + // Holds computed value (memoized value) of an internal fill state of a + // subarray. + // State is (Dimension to fill, index in tensor as flattened array) + // The value is start and end in the output array which has the padded result. + std::vector> cache; +}; - return kTfLiteOk; -} +// Wrapper for params passed to the Eval function. +template +struct EvalData { + OpData* op_data = nullptr; + const TfLiteTensor* padding_matrix = nullptr; + const TfLiteIntArray* input_dims = nullptr; + // Holds number of elements at the nth dimension. + // value at last dimension = 1, at second to last = sizeof last dimension. 
+ const std::vector* dimension_num_elements = nullptr; + const T* input_data = nullptr; + + int offset = kUnsetOffset; + T* output_data = nullptr; + int input_size = 0; + int output_size = 0; + int num_dims = 0; +}; +// Helper method that fills the left and right pads. template inline void GetPadding(const T* data, int offset, int64_t* left_pad, int64_t* right_pad) { @@ -149,9 +70,8 @@ inline void GetPadding(const T* data, int offset, int64_t* left_pad, *right_pad = static_cast(*(data + offset * 2 + 1)); } -inline TfLiteStatus GetPadding(const TfLiteTensor* padding_matrix, - int dimension, int64_t* left_pad, - int64_t* right_pad) { +inline void GetPadding(const TfLiteTensor* padding_matrix, int dimension, + int64_t* left_pad, int64_t* right_pad) { switch (padding_matrix->type) { case kTfLiteInt32: GetPadding(padding_matrix->data.i32, dimension, left_pad, right_pad); @@ -160,93 +80,59 @@ inline TfLiteStatus GetPadding(const TfLiteTensor* padding_matrix, GetPadding(padding_matrix->data.i64, dimension, left_pad, right_pad); break; default: - return kTfLiteError; + return; } - return kTfLiteOk; } -TfLiteStatus ValidateTensor(const TfLiteTensor* padding_matrix, int offset, - int dimension_index, PaddedTensor* padded_tensor, - TfLiteContext* context) { - if (dimension_index >= padding_matrix->dims->data[0]) { - return kTfLiteOk; - } - - int64_t left_pad = 0, right_pad = 0; - TF_LITE_ENSURE_STATUS( - GetPadding(padding_matrix, dimension_index, &left_pad, &right_pad)); - // If we are not going to include border we must have enough values - // to use. 
- if (left_pad + offset > padded_tensor->values.size()) { - context->ReportError( - context, "Not enough values for Mirror Pad, required %d, available %d.", - left_pad + offset, padded_tensor->values.size()); - return kTfLiteError; - } - if (right_pad + offset > padded_tensor->values.size()) { - context->ReportError( - context, "Not enough values for Mirror Pad, required %d, available %d.", - right_pad + offset, padded_tensor->values.size()); - return kTfLiteError; +template +int Eval(EvalData* eval_data, int current_dim, int flat_index, + int output_index) { + if (current_dim == eval_data->num_dims) { + // Base case if we finished evaluating. + if (output_index >= eval_data->output_size) { + return output_index; + } + eval_data->output_data[output_index] = eval_data->input_data[flat_index]; + return output_index + 1; } - if (!padded_tensor->values.empty()) { - ValidateTensor(padding_matrix, offset, dimension_index + 1, - padded_tensor->values[0].get(), context); + // Check if the value is computed already. + const int cache_index = current_dim * eval_data->input_size + flat_index; + auto& cache_entry = eval_data->op_data->cache[cache_index]; + if (cache_entry.first != -1) { + // Cache value is (start, end) interval. We can just copy the interval + // directly. + const int count = cache_entry.second - cache_entry.first; + memcpy(eval_data->output_data + output_index, + eval_data->output_data + cache_entry.first, count * sizeof(T)); + return output_index + count; } - return kTfLiteOk; -} - -// Fills 'padded_tensor' with the padding information based on -// 'padding_matrix'. -// 'dimension_index' represents which dimension the function is operating on. 
-TfLiteStatus PadTensor(const TfLiteTensor* padding_matrix, int offset, - int dimension_index, PaddedTensor* padded_tensor, - TfLiteContext* context) { - if (dimension_index >= padding_matrix->dims->data[0]) return kTfLiteOk; - + cache_entry.first = output_index; int64_t left_pad = 0, right_pad = 0; - TF_LITE_ENSURE_STATUS( - GetPadding(padding_matrix, dimension_index, &left_pad, &right_pad)); + const int multiplier = (*eval_data->dimension_num_elements)[current_dim]; + const TfLiteTensor* padding_matrix = eval_data->padding_matrix; + const auto offset = eval_data->offset; + auto* dims = eval_data->input_dims; + GetPadding(padding_matrix, current_dim, &left_pad, &right_pad); + // Left padding for (int i = left_pad + offset - 1; i >= offset && left_pad > 0; --i, --left_pad) { - padded_tensor->left_pad_ptrs.push_back(padded_tensor->values[i].get()); + output_index = Eval(eval_data, current_dim + 1, flat_index + i * multiplier, + output_index); } - for (int i = padded_tensor->values.size() - (1 + offset); - i >= 0 && right_pad > 0; --i, --right_pad) { - padded_tensor->right_pad_ptrs.push_back(padded_tensor->values[i].get()); + // Original values. + for (int i = 0; i < dims->data[current_dim]; ++i) { + output_index = Eval(eval_data, current_dim + 1, flat_index + i * multiplier, + output_index); } - - for (auto& tensor : padded_tensor->values) { - TF_LITE_ENSURE_STATUS(PadTensor(padding_matrix, offset, dimension_index + 1, - tensor.get(), context)); + // Right padding. + for (int i = dims->data[current_dim] - (1 + offset); i >= 0 && right_pad > 0; + --i, --right_pad) { + output_index = Eval(eval_data, current_dim + 1, flat_index + i * multiplier, + output_index); } - return kTfLiteOk; -} - -// Fills 'output_data' with data from 'padded_tensor'. -// The function does this recursively by setting left padding first then -// original data, followed by the right padding. 
-template -int FillOutput(const PaddedTensor* padded_tensor, T* output_data, - int index_in_output) { - if (padded_tensor == nullptr || output_data == nullptr) { - return -1; - } - if (padded_tensor->value != nullptr) { - output_data[index_in_output] = *static_cast(padded_tensor->value); - return index_in_output + 1; - } - for (const auto* tensor : padded_tensor->left_pad_ptrs) { - index_in_output = FillOutput(tensor, output_data, index_in_output); - } - for (const auto& tensor : padded_tensor->values) { - index_in_output = FillOutput(tensor.get(), output_data, index_in_output); - } - for (const auto* tensor : padded_tensor->right_pad_ptrs) { - index_in_output = FillOutput(tensor, output_data, index_in_output); - } - return index_in_output; + cache_entry.second = output_index; + return output_index; } // Returns the shape of the final output after padding. @@ -271,6 +157,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* padding_matrix = GetInput(context, node, 1); auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); if (params == nullptr) { return kTfLiteError; @@ -287,27 +174,34 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, output_tensor, output_size.release())); } - PaddedTensor padded_tensor; - // Initialize memory. - InitializeTensorMemory(input_tensor->dims, 0, input_dims, &padded_tensor); - // Set the values from the input_tensor. - TF_LITE_ENSURE_STATUS(InitFromInputTensor(input_tensor, &padded_tensor)); + std::vector dimension_num_elements(input_dims, 1); + for (int i = input_dims - 2; i >= 0; i--) { + dimension_num_elements[i] = + dimension_num_elements[i + 1] * input_tensor->dims->data[i + 1]; + } + const int input_size = NumElements(input_tensor); const int offset = params->mode != TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect ? 0 : 1; - // Make sure padding values are sufficient and valid to use. 
- TF_LITE_ENSURE_STATUS( - ValidateTensor(padding_matrix, offset, 0, &padded_tensor, context)); - // Apply padding. - TF_LITE_ENSURE_STATUS( - PadTensor(padding_matrix, offset, 0, &padded_tensor, context)); - - // Fill the output tensor from the padded tensor. TfLiteStatus status = kTfLiteOk; - -#define TF_LITE_MIRROR_PAD(type) \ - FillOutput(&padded_tensor, GetTensorData(output_tensor), 0); + int output_index = 0; + // Reset cache array. + std::fill(op_data->cache.begin(), op_data->cache.end(), + std::make_pair(-1, -1)); +#define TF_LITE_MIRROR_PAD(type) \ + EvalData eval_data; \ + eval_data.input_data = GetTensorData(input_tensor); \ + eval_data.input_dims = input_tensor->dims; \ + eval_data.input_size = input_size; \ + eval_data.dimension_num_elements = &dimension_num_elements; \ + eval_data.num_dims = input_dims; \ + eval_data.offset = offset; \ + eval_data.op_data = op_data; \ + eval_data.output_data = GetTensorData(output_tensor); \ + eval_data.output_size = NumElements(output_tensor); \ + eval_data.padding_matrix = padding_matrix; \ + Eval(&eval_data, 0, 0, output_index); switch (output_tensor->type) { case kTfLiteFloat32: { @@ -335,20 +229,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } void* Init(TfLiteContext* context, const char* buffer, size_t length) { - return nullptr; + return new OpData; } -void Free(TfLiteContext* context, void* buffer) {} +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input_tensor = GetInput(context, node, 0); const TfLiteTensor* padding_matrix = GetInput(context, node, 1); TfLiteTensor* output_tensor = GetOutput(context, node, 0); + OpData* op_data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumDimensions(padding_matrix), 2); TF_LITE_ENSURE_EQ(context, SizeOfDimension(padding_matrix, 0), NumDimensions(input_tensor)); + int num_elements = NumElements(input_tensor) 
* NumDimensions(input_tensor); + op_data->cache.resize(num_elements + 1); + if (!IsConstantTensor(padding_matrix)) { SetTensorToDynamic(output_tensor); return kTfLiteOk; diff --git a/tensorflow/lite/kernels/mirror_pad_test.cc b/tensorflow/lite/kernels/mirror_pad_test.cc index fd09e6e4493d3a29bffecfcd4a4d1946840a4e5e..91e48fa68aa352d5f23857fe32b3feb9e3e1b73e 100644 --- a/tensorflow/lite/kernels/mirror_pad_test.cc +++ b/tensorflow/lite/kernels/mirror_pad_test.cc @@ -185,5 +185,18 @@ TEST(MirrorPadTest, Pad_1D_Symmetric) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 3, 2})); } +TEST(MirrorPadTest, Pad_1D_Symmetric_Multiple_Invoke) { + BaseMirrorPadOpModel model( + {TensorType_INT32, {3}}, {TensorType_INT32, {1, 2}}, + {TensorType_INT32, {}}, tflite::MirrorPadMode_SYMMETRIC); + model.PopulateTensor(model.input_tensor_id(), {1, 2, 3}); + model.PopulateTensor(model.padding_matrix_tensor_id(), {0, 2}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 3, 2})); + model.PopulateTensor(model.input_tensor_id(), {4, 5, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 5, 6, 6, 5})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/mul.cc b/tensorflow/lite/kernels/mul.cc index 01039a705438af2a92a68b01c2146daf69c46250..e0ff6724ea2f3ea0fd4693571d6c509f5385a5d4 100644 --- a/tensorflow/lite/kernels/mul.cc +++ b/tensorflow/lite/kernels/mul.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" @@ -87,8 +88,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { &data->output_activation_min, &data->output_activation_max); } + if (output->type == kTfLiteInt8) { + CalculateActivationRangeInt8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } - if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 || + output->type == kTfLiteInt16) { double real_multiplier = input1->params.scale * input2->params.scale / output->params.scale; QuantizeMultiplierSmallerThanOneExp( @@ -151,8 +158,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params, const OpData* data, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { - if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 && - output->type == kTfLiteUInt8) { + if (input1->type == input2->type && input1->type == output->type && + (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8)) { tflite::ArithmeticParams op_params; SetActivationParams(data->output_activation_min, data->output_activation_max, &op_params); @@ -163,23 +170,31 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, op_params.output_shift = data->output_shift; bool need_broadcast = optimized_ops::ProcessBroadcastShapes( GetTensorShape(input1), GetTensorShape(input2), &op_params); -#define TF_LITE_MUL(type, opname) \ - type::opname(op_params, GetTensorShape(input1), \ - GetTensorData(input1), GetTensorShape(input2), \ - GetTensorData(input2), GetTensorShape(output), \ - GetTensorData(output)) - - if 
(kernel_type == kReference) { +#define TF_LITE_MUL(type, opname, dtype) \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) + if (input1->type == kTfLiteInt8) { if (need_broadcast) { - TF_LITE_MUL(reference_ops, BroadcastMul4DSlow); + TF_LITE_MUL(reference_integer_ops, BroadcastMul4DSlow, int8_t); } else { - TF_LITE_MUL(reference_ops, Mul); + TF_LITE_MUL(reference_integer_ops, Mul, int8_t); } } else { - if (need_broadcast) { - TF_LITE_MUL(optimized_ops, BroadcastMulFivefold); + // type == kTfLiteUInt8 + if (kernel_type == kReference) { + if (need_broadcast) { + TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, uint8_t); + } else { + TF_LITE_MUL(reference_ops, Mul, uint8_t); + } } else { - TF_LITE_MUL(optimized_ops, Mul); + if (need_broadcast) { + TF_LITE_MUL(optimized_ops, BroadcastMulFivefold, uint8_t); + } else { + TF_LITE_MUL(optimized_ops, Mul, uint8_t); + } } } #undef TF_LITE_MUL @@ -198,8 +213,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, } #undef TF_LITE_MUL } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && - output->type == kTfLiteUInt8) { -#define TF_LITE_MUL(type, opname) \ + (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8)) { +#define TF_LITE_MUL(type, opname, output_dtype) \ tflite::ArithmeticParams op_params; \ SetActivationParams(data->output_activation_min, \ data->output_activation_max, &op_params); \ @@ -207,11 +222,15 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, type::opname(op_params, GetTensorShape(input1), \ GetTensorData(input1), GetTensorShape(input2), \ GetTensorData(input2), GetTensorShape(output), \ - GetTensorData(output)) - if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops, Mul); + GetTensorData(output)) + if (output->type == kTfLiteInt8) { + TF_LITE_MUL(reference_integer_ops, Mul, int8_t); } else { - 
TF_LITE_MUL(optimized_ops, Mul); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, Mul, uint8_t); + } else { + TF_LITE_MUL(optimized_ops, Mul, uint8_t); + } } #undef TF_LITE_MUL } else { @@ -233,14 +252,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { EvalMul(context, node, params, data, input1, input2, output); - } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 || + output->type == kTfLiteInt16) { TF_LITE_ENSURE_OK( context, EvalQuantized(context, node, params, data, input1, input2, output)); } else { context->ReportError(context, - "Mul only supports FLOAT32, INT32 and quantized UINT8 " - "and INT16 now, got %d.", + "Mul only supports FLOAT32, INT32 and quantized UINT8," + " INT8 and INT16 now, got %d.", output->type); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/mul_test.cc b/tensorflow/lite/kernels/mul_test.cc index 200cc26dadc3527813a7dabd3b9ca4811d4c8856..96f5a8a0e07e730394510f432b3313724e6c9172 100644 --- a/tensorflow/lite/kernels/mul_test.cc +++ b/tensorflow/lite/kernels/mul_test.cc @@ -73,9 +73,10 @@ class QuantizedMulOpModel : public BaseMulOpModel { public: using BaseMulOpModel::BaseMulOpModel; + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); } std::vector GetDequantizedOutputInt16() { @@ -191,19 +192,28 @@ TEST(IntegerMulOpTest, WithBroadcast) { } } -TEST(QuantizedMulOpTest, NoActivation) { - QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}, +template +void NoActivation() { + QuantizedMulOpModel m({tensor_type, {1, 2, 2, 1}, -1.0, 1.0}, + {tensor_type, {1, 2, 2, 1}, 
-1.0, 1.0}, + {tensor_type, {}, -1.0, 1.0}, ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); - m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, kQuantizedTolerance))); } +TEST(QuantizedMulOpTest, NoActivationUInt8) { + NoActivation(); +} + +TEST(QuantizedMulOpTest, NoActivationInt8) { + NoActivation(); +} + TEST(QuantizedMulOpTest, NoActivationInt16) { const float kMin = -1.f; const float kMax = 32767.f / 32768.f; @@ -219,23 +229,32 @@ TEST(QuantizedMulOpTest, NoActivationInt16) { kQuantizedToleranceInt16))); } -TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) { +template +void NoActivationInt16With8BitOutput() { const float kMinInt16 = -1.f; const float kMaxInt16 = 32767.f / 32768.f; const float kMinUint8 = -1.f; const float kMaxUint8 = 127.f / 128.f; QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, {TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, - {TensorType_UINT8, {}, kMinUint8, kMaxUint8}, + {tensor_type, {}, kMinUint8, kMaxUint8}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, kQuantizedTolerance))); } +TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) { + NoActivationInt16With8BitOutput(); +} + +TEST(QuantizedMulOpTest, NoActivationInt16Withint8Output) { + NoActivationInt16With8BitOutput(); +} + // for quantized Mul, the error shouldn't exceed 2*step float GetTolerance(int min, int max) { float kQuantizedStep = (max - min) / 255.0; @@ -243,25 
+262,35 @@ float GetTolerance(int min, int max) { return kQuantizedTolerance; } -TEST(QuantizedMulOpTest, WithBroadcast) { +template +void WithBroadcast() { float kQuantizedTolerance = GetTolerance(-3.0, 3.0); std::vector> test_shapes = { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { - QuantizedMulOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, - {TensorType_UINT8, {}, -3.0, 3.0}, // always a scalar - {TensorType_UINT8, {}, -3.0, 3.0}, + QuantizedMulOpModel m({tensor_type, test_shapes[i], -3.0, 3.0}, + {tensor_type, {}, -3.0, 3.0}, // always a scalar + {tensor_type, {}, -3.0, 3.0}, ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); - m.QuantizeAndPopulate(m.input2(), {0.1}); + m.QuantizeAndPopulate(m.input1(), + {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( {-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}, kQuantizedTolerance))) << "With shape number " << i; } } +TEST(QuantizedMulOpTest, WithBroadcastUInt8) { + WithBroadcast(); +} + +TEST(QuantizedMulOpTest, WithBroadcastInt8) { + WithBroadcast(); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/pack.cc b/tensorflow/lite/kernels/pack.cc index d15a5a08af38672cbdaef76ff8a37c42d6e6f226..e26abaaff1e5c9e460621048eb15d0549b81fb36 100644 --- a/tensorflow/lite/kernels/pack.cc +++ b/tensorflow/lite/kernels/pack.cc @@ -39,8 +39,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(renjieliu): Support negative axis. 
TF_LITE_ENSURE(context, data->axis >= 0); if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 && - input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt16 && - input0->type != kTfLiteInt64) { + input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 && + input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) { context->ReportError(context, "Type '%s' is not supported by pack.", TfLiteTypeGetName(input0->type)); return kTfLiteError; @@ -106,6 +106,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { PackImpl(context, node, output, data->values_count, data->axis); break; } + case kTfLiteInt8: { + PackImpl(context, node, output, data->values_count, data->axis); + break; + } case kTfLiteInt32: { PackImpl(context, node, output, data->values_count, data->axis); break; diff --git a/tensorflow/lite/kernels/pack_test.cc b/tensorflow/lite/kernels/pack_test.cc index 530cc2e50f0fe640cc5b120b8bbb1bade7e996fc..f44111567fc34f17912af7db352b47e57f8704f3 100644 --- a/tensorflow/lite/kernels/pack_test.cc +++ b/tensorflow/lite/kernels/pack_test.cc @@ -191,6 +191,37 @@ TEST(PackOpTest, Uint8MultilDimensions) { ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); } +// int8 +TEST(PackOpTest, Int8ThreeInputs) { + PackOpModel model({TensorType_INT8, {2}}, 0, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); +} + +TEST(PackOpTest, Int8ThreeInputsDifferentAxis) { + PackOpModel model({TensorType_INT8, {2}}, 1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(PackOpTest, Int8MultilDimensions) { + PackOpModel model({TensorType_INT8, {2, 3}}, 1, 2); + 
model.SetInput(0, {1, 2, 3, 4, 5, 6}); + model.SetInput(1, {7, 8, 9, 10, 11, 12}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/pad.cc b/tensorflow/lite/kernels/pad.cc index 8e6ed6e741f782f070714164a7af7b4f98a1558f..b60b3dd9c871bf864492505dd9fa4aabf496364c 100644 --- a/tensorflow/lite/kernels/pad.cc +++ b/tensorflow/lite/kernels/pad.cc @@ -214,6 +214,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } } } break; + case kTfLiteInt8: { + int8_t pad_value; + if (op_context.constant_values == nullptr) { + // Quantized Pad requires that 0 is represented in the quantized + // range. + TF_LITE_ENSURE(context, op_context.output->params.zero_point >= + std::numeric_limits::min()); + TF_LITE_ENSURE(context, op_context.output->params.zero_point <= + std::numeric_limits::max()); + pad_value = static_cast(op_context.output->params.zero_point); + } else { + // Quantized Pad requires that 'constant_values' is represented in the + // same quantized range as the input and output tensors. 
+ TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, + op_context.constant_values->params.zero_point); + TF_LITE_ENSURE_EQ(context, op_context.output->params.scale, + op_context.constant_values->params.scale); + pad_value = *GetTensorData(op_context.constant_values); + } + if (op_context.resizing_category == ResizingCategory::kImageStyle) { + TF_LITE_PAD(reference_ops, PadImageStyle, int8_t, pad_value); + } else { + TF_LITE_PAD(reference_ops, Pad, int8_t, pad_value); + } + } break; case kTfLiteInt32: { int32_t pad_value = op_context.constant_values == nullptr diff --git a/tensorflow/lite/kernels/pad_test.cc b/tensorflow/lite/kernels/pad_test.cc index 3caa4065dcbadd699ee9e61b8e97a42281d32309..ca246e9c3464664f492d85e8f41497da583a3989 100644 --- a/tensorflow/lite/kernels/pad_test.cc +++ b/tensorflow/lite/kernels/pad_test.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" @@ -24,31 +25,37 @@ namespace { using ::testing::ElementsAreArray; using ::testing::Matcher; -template +template class PadOpModel : public SingleOpModel { public: - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); } + template void SetQuantizedInput(std::initializer_list data) { - QuantizeAndPopulate(input_, data); + QuantizeAndPopulate(input_, data); } + template void SetQuantizedPadValue(float data) { - QuantizeAndPopulate(constant_values_, {data}); + QuantizeAndPopulate(constant_values_, {data}); } void SetPaddings(std::initializer_list paddings) { PopulateTensor(paddings_, paddings); } - std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutput() { + return ExtractVector(output_); + } std::vector GetOutputShape() { return GetTensorShape(output_); } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize( + ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } protected: @@ -59,18 +66,18 @@ class PadOpModel : public SingleOpModel { }; // Tests case where paddings is a const tensor. Type T is the dtype. 
-template -class PadV2OpConstModel : public PadOpModel { +template +class PadV2OpConstModel : public PadOpModel { public: PadV2OpConstModel(const TensorData& input, std::initializer_list paddings_shape, - std::initializer_list paddings, T constant_values, + std::initializer_list paddings, T1 constant_values, const TensorData& output) { this->input_ = this->AddInput(input); this->paddings_ = this->AddConstInput(TensorType_INT32, paddings, paddings_shape); this->constant_values_ = - this->AddConstInput(GetTensorType(), {constant_values}, {1}); + this->AddConstInput(GetTensorType(), {constant_values}, {1}); this->output_ = this->AddOutput(output); @@ -109,7 +116,7 @@ class PadOpConstModel : public PadOpModel { std::initializer_list paddings_shape, std::initializer_list paddings, const TensorData& output) { - input_ = AddInput(input); + this->input_ = AddInput(input); paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape); constant_values_ = AddNullInput(); output_ = AddOutput(output); @@ -121,16 +128,17 @@ class PadOpConstModel : public PadOpModel { }; // Test case where paddings is a non-const tensor. 
-template -class PadV2OpDynamicModel : public PadOpModel { +template +class PadV2OpDynamicModel : public PadOpModel { public: PadV2OpDynamicModel(const TensorData& input, std::initializer_list paddings_shape, - T constant_values, const TensorData& output) { + RegularInputOuput constant_values, + const TensorData& output) { this->input_ = this->AddInput(input); this->paddings_ = this->AddInput(TensorType_INT32); - this->constant_values_ = - this->AddConstInput(GetTensorType(), {constant_values}, {1}); + this->constant_values_ = this->AddConstInput( + GetTensorType(), {constant_values}, {1}); this->output_ = this->AddOutput(output); this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options, @@ -164,14 +172,14 @@ class PadOpDynamicModel : public PadOpModel { PadOpDynamicModel(const TensorData& input, std::initializer_list paddings_shape, const TensorData& output) { - input_ = AddInput(input); - paddings_ = AddInput(TensorType_INT32); - constant_values_ = AddNullInput(); - output_ = AddOutput(output); + this->input_ = this->AddInput(input); + this->paddings_ = this->AddInput(TensorType_INT32); + this->constant_values_ = this->AddNullInput(); + this->output_ = this->AddOutput(output); - SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, - CreatePadOptions(builder_).Union()); - BuildInterpreter({input.shape, paddings_shape}); + this->SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(this->builder_).Union()); + this->BuildInterpreter({input.shape, paddings_shape}); } }; @@ -299,61 +307,83 @@ TEST(PadOpTest, AdvancedDynamicTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } -class QuantizedPadOpTest : public ::testing::Test { - protected: - std::vector> DequantizedArrayNear( - const std::vector& values, const float min, const float max) { - const float quantization_tolerance = (max - min) / 255.0; - return ArrayFloatNear(values, quantization_tolerance); - } -}; +std::vector> DequantizedArrayNear( + 
const std::vector& values, const float min, const float max) { + const float quantization_tolerance = (max - min) / 255.0; + return ArrayFloatNear(values, quantization_tolerance); +} + +class QuantizedPadOpTest : public ::testing::Test {}; #ifdef GTEST_HAS_DEATH_TEST -TEST_F(QuantizedPadOpTest, ZeroNotInQuantizationRange) { +template +void ZeroNotInQuantizationRange() { // The test_util and actual quantization code currently ensure that the range // must include zero, but if that ever changes, this test will catch it. - EXPECT_DEATH(PadOpConstModel m({TensorType_UINT8, {1, 2, 2, 1}, 1.0, 2.0}, - {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, - {TensorType_UINT8, {}, 1.0, 2.0}), - ".*Check failed: f_min <= 0.*"); + EXPECT_DEATH( + PadOpConstModel m({tensor_dtype, {1, 2, 2, 1}, 1.0, 2.0}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, {tensor_dtype, {}, 1.0, 2.0}), + ".*Check failed: f_min <= 0.*"); +} + +TEST_F(QuantizedPadOpTest, UInt8ZeroNotInQuantizationRange) { + ZeroNotInQuantizationRange(); +} +TEST_F(QuantizedPadOpTest, Int8ZeroNotInQuantizationRange) { + ZeroNotInQuantizationRange(); } #endif -TEST_F(QuantizedPadOpTest, SimpleConstTest) { +template +void SimpleConstTest() { // Padding is represented as four 2-D lists representing above padding and // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). 
- PadOpConstModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, - {0, 0, 1, 1, 1, 1, 0, 0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); + PadOpConstModel m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0}, -1.0, 1.0))); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } -TEST_F(QuantizedPadOpTest, SimpleDynamicTest) { - PadOpDynamicModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); +TEST_F(QuantizedPadOpTest, UInt8SimpleConstTest) { + SimpleConstTest(); +} +TEST_F(QuantizedPadOpTest, Int8SimpleConstTest) { + SimpleConstTest(); +} + +template +void SimpleDynamicTest() { + PadOpDynamicModel m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, + {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0}, -1.0, 1.0))); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } -TEST_F(QuantizedPadOpTest, AdvancedConstTest) { - PadOpConstModel m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, - {0, 0, 0, 2, 1, 3, 0, 0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); +TEST_F(QuantizedPadOpTest, UInt8SimpleDynamicTest) { + SimpleDynamicTest(); +} +TEST_F(QuantizedPadOpTest, Int8SimpleDynamicTest) { + SimpleDynamicTest(); +} + +template +void AdvancedConstTest() { + PadOpConstModel 
m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, + {0, 0, 0, 2, 1, 3, 0, 0}, {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, @@ -361,13 +391,21 @@ TEST_F(QuantizedPadOpTest, AdvancedConstTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } -TEST_F(QuantizedPadOpTest, AdvancedDynamicTest) { - PadOpDynamicModel m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); +TEST_F(QuantizedPadOpTest, UInt8AdvancedConstTest) { + AdvancedConstTest(); +} +TEST_F(QuantizedPadOpTest, Int8AdvancedConstTest) { + AdvancedConstTest(); +} + +template +void AdvancedDynamicTest() { + PadOpDynamicModel m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, + {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, @@ -375,31 +413,38 @@ TEST_F(QuantizedPadOpTest, AdvancedDynamicTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } +TEST_F(QuantizedPadOpTest, UInt8AdvancedDynamicTest) { + AdvancedDynamicTest(); +} +TEST_F(QuantizedPadOpTest, Int8AdvancedDynamicTest) { + AdvancedDynamicTest(); +} + #ifdef GTEST_HAS_DEATH_TEST TEST(PadV2OpTest, TooManyDimensions) { - EXPECT_DEATH(PadV2OpConstModel( - {TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2}, - {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0, - {TensorType_FLOAT32}), + typedef 
PadV2OpConstModel f; + EXPECT_DEATH(f({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2}, + {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0, + {TensorType_FLOAT32}), "dims <= 4"); } TEST(PadV2OpTest, UnequalDimensions) { - EXPECT_DEATH( - PadV2OpConstModel({TensorType_FLOAT32, {1, 1, 2, 1}}, {3, 2}, - {1, 1, 2, 2, 3, 3}, 0.0, {TensorType_FLOAT32}), - "3 != 4"); + typedef PadV2OpConstModel f; + EXPECT_DEATH(f({TensorType_FLOAT32, {1, 1, 2, 1}}, {3, 2}, {1, 1, 2, 2, 3, 3}, + 0.0, {TensorType_FLOAT32}), + "3 != 4"); } TEST(PadV2OpTest, InvalidPadValue) { - EXPECT_DEATH(PadV2OpConstModel({TensorType_FLOAT32, {1, 1, 2, 1}}, - {4, 2}, {0, 0, 1, -1, 2, -1, 0, 0}, 0.0, - {TensorType_FLOAT32}), + typedef PadV2OpConstModel f; + EXPECT_DEATH(f({TensorType_FLOAT32, {1, 1, 2, 1}}, {4, 2}, + {0, 0, 1, -1, 2, -1, 0, 0}, 0.0, {TensorType_FLOAT32}), "Pad value has to be greater than equal to 0."); } #endif -TEST(PadV2OpTest, SimpleConstTest) { +TEST(PadV2OpTest, SimpleConstTestUint8) { // Padding is represented as four 2-D lists representing above padding and // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). PadV2OpConstModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, @@ -412,7 +457,32 @@ TEST(PadV2OpTest, SimpleConstTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } -TEST(PadV2OpTest, SimpleConstFloat32ValuedTest) { +TEST(PadV2OpTest, SimpleConstTestInt8) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). 
+ PadV2OpConstModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, 0.0, + {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, + 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadV2OpTest, SimpleConstFloat32ValuedTestUint8) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). + PadV2OpConstModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, 5, {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 1, 2, 5, 5, 3, 4, + 5, 5, 5, 5, 5})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadV2OpTest, SimpleConstFloat32ValuedTestInt8) { // Padding is represented as four 2-D lists representing above padding and // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). PadV2OpConstModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, @@ -502,58 +572,80 @@ class QuantizedPadV2OpTest : public ::testing::Test { }; #ifdef GTEST_HAS_DEATH_TEST -TEST_F(QuantizedPadV2OpTest, ZeroNotInQuantizationRange) { +template +void ZeroNotInQuantizationRangeV2() { // The test_util and actual quantization code currently ensure that the range // must include zero, but if that ever changes, this test will catch it. 
- EXPECT_DEATH( - PadV2OpConstModel m({TensorType_UINT8, {1, 2, 2, 1}, 1.0, 2.0}, - {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, 0, - {TensorType_UINT8, {}, 1.0, 2.0}), - ".*Check failed: f_min <= 0.*"); + typedef PadV2OpConstModel f; + EXPECT_DEATH(f({tensor_dtype, {1, 2, 2, 1}, 1.0, 2.0}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, 0, {tensor_dtype, {}, 1.0, 2.0}), + ".*Check failed: f_min <= 0.*"); +} + +TEST_F(QuantizedPadV2OpTest, UInt8ZeroNotInQuantizationRange) { + ZeroNotInQuantizationRangeV2(); +} +TEST_F(QuantizedPadV2OpTest, Int8ZeroNotInQuantizationRange) { + ZeroNotInQuantizationRangeV2(); } #endif -TEST_F(QuantizedPadV2OpTest, SimpleConstTest) { +template +void SimpleConstTestV2() { // Padding is represented as four 2-D lists representing above padding and // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). - PadV2OpConstModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, - {TensorType_UINT8, {1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); - m.SetQuantizedPadValue(0); + PadV2OpConstModel m( + {tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, + {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); + m.template SetQuantizedPadValue(0); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0}, -1.0, 1.0))); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } -TEST_F(QuantizedPadV2OpTest, SimpleDynamicTest) { - PadV2OpDynamicModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); - m.SetQuantizedPadValue(0); +TEST_F(QuantizedPadV2OpTest, UInt8SimpleConstTest) { + SimpleConstTestV2(); +} 
+TEST_F(QuantizedPadV2OpTest, Int8SimpleConstTest) { + SimpleConstTestV2(); +} + +template +void SimpleDynamicTestV2() { + PadV2OpDynamicModel m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, + {4, 2}, {tensor_dtype, {1}, -1.0, 1.0}, + {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); + m.template SetQuantizedPadValue(0); m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0}, -1.0, 1.0))); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } -TEST_F(QuantizedPadV2OpTest, AdvancedConstTest) { - PadV2OpConstModel m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, - {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0}, - {TensorType_UINT8, {1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); - m.SetQuantizedPadValue(0); +TEST_F(QuantizedPadV2OpTest, UInt8SimpleDynamicTest) { + SimpleDynamicTestV2(); +} +TEST_F(QuantizedPadV2OpTest, Int8SimpleDynamicTest) { + SimpleDynamicTestV2(); +} + +template +void AdvancedConstTestV2() { + PadV2OpConstModel m( + {tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0}, + {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); + m.template SetQuantizedPadValue(0); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, @@ -561,15 +653,23 @@ TEST_F(QuantizedPadV2OpTest, AdvancedConstTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } -TEST_F(QuantizedPadV2OpTest, AdvancedDynamicTest) { - PadV2OpDynamicModel m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, - {4, 2}, 
{TensorType_UINT8, {1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); - m.SetQuantizedPadValue(0); +TEST_F(QuantizedPadV2OpTest, UInt8AdvancedConstTest) { + AdvancedConstTestV2(); +} +TEST_F(QuantizedPadV2OpTest, Int8AdvancedConstTest) { + AdvancedConstTestV2(); +} + +template +void AdvancedDynamicTestV2() { + PadV2OpDynamicModel m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, + {4, 2}, {tensor_dtype, {1}, -1.0, 1.0}, + {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); + m.template SetQuantizedPadValue(0); m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, @@ -577,17 +677,24 @@ TEST_F(QuantizedPadV2OpTest, AdvancedDynamicTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } -TEST_F(QuantizedPadV2OpTest, SimpleConstValuedTest) { +TEST_F(QuantizedPadV2OpTest, UInt8AdvancedDynamicTest) { + AdvancedDynamicTestV2(); +} +TEST_F(QuantizedPadV2OpTest, Int8AdvancedDynamicTest) { + AdvancedDynamicTestV2(); +} + +template +void SimpleConstValuedTest() { // Padding is represented as four 2-D lists representing above padding and // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). 
- PadV2OpConstModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, - {TensorType_UINT8, {1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); - m.SetQuantizedPadValue(-0.5); + PadV2OpConstModel m( + {tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, + {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); + m.template SetQuantizedPadValue(-0.5); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9, 0.7, -0.5, -0.5, -0.5, -0.5, -0.5}, @@ -595,15 +702,23 @@ TEST_F(QuantizedPadV2OpTest, SimpleConstValuedTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } -TEST_F(QuantizedPadV2OpTest, SimpleDynamicValuedTest) { - PadV2OpDynamicModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); - m.SetQuantizedPadValue(-0.5); +TEST_F(QuantizedPadV2OpTest, UInt8SimpleConstValuedTest) { + SimpleConstValuedTest(); +} +TEST_F(QuantizedPadV2OpTest, Int8SimpleConstValuedTest) { + SimpleConstValuedTest(); +} + +template +void SimpleDynamicValuedTest() { + PadV2OpDynamicModel m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, + {4, 2}, {tensor_dtype, {1}, -1.0, 1.0}, + {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); + m.template SetQuantizedPadValue(-0.5); m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9, 0.7, -0.5, -0.5, -0.5, -0.5, -0.5}, @@ -611,15 +726,22 @@ TEST_F(QuantizedPadV2OpTest, SimpleDynamicValuedTest) 
{ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } -TEST_F(QuantizedPadV2OpTest, AdvancedConstValuedTest) { - PadV2OpConstModel m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, - {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0}, - {TensorType_UINT8, {1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); - m.SetQuantizedPadValue(-0.5); +TEST_F(QuantizedPadV2OpTest, UInt8SimpleDynamicValuedTest) { + SimpleDynamicValuedTest(); +} +TEST_F(QuantizedPadV2OpTest, Int8SimpleDynamicValuedTest) { + SimpleDynamicValuedTest(); +} + +template +void AdvancedConstValuedTest() { + PadV2OpConstModel m( + {tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0}, + {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0}); + m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); + m.template SetQuantizedPadValue(-0.5); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {-0.5, -0.8, 0.2, 0.9, -0.5, -0.5, -0.5, -0.5, 0.7, 0.1, -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, @@ -628,15 +750,23 @@ TEST_F(QuantizedPadV2OpTest, AdvancedConstValuedTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } -TEST_F(QuantizedPadV2OpTest, AdvancedDynamicValuedTest) { - PadV2OpDynamicModel m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, - {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); - m.SetQuantizedPadValue(-0.5); +TEST_F(QuantizedPadV2OpTest, UInt8AdvancedConstValuedTest) { + AdvancedConstValuedTest(); +} +TEST_F(QuantizedPadV2OpTest, Int8AdvancedConstValuedTest) { + AdvancedConstValuedTest(); +} + +template +void AdvancedDynamicValuedTest() { + PadV2OpDynamicModel m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, + {4, 2}, {tensor_dtype, {1}, -1.0, 1.0}, + {tensor_dtype, {}, -1.0, 1.0}); + m.template 
SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); + m.template SetQuantizedPadValue(-0.5); m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {-0.5, -0.8, 0.2, 0.9, -0.5, -0.5, -0.5, -0.5, 0.7, 0.1, -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, @@ -645,6 +775,13 @@ TEST_F(QuantizedPadV2OpTest, AdvancedDynamicValuedTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } +TEST_F(QuantizedPadV2OpTest, UInt8AdvancedDynamicValuedTest) { + AdvancedDynamicValuedTest(); +} +TEST_F(QuantizedPadV2OpTest, Int8AdvancedDynamicValuedTest) { + AdvancedDynamicValuedTest(); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/pooling.cc b/tensorflow/lite/kernels/pooling.cc index e6155fcb8c67ed3b5e676c2530ec7966d6cec56f..bdf736dcfb278ad93f43c25b9ae1c0b4038b695f 100644 --- a/tensorflow/lite/kernels/pooling.cc +++ b/tensorflow/lite/kernels/pooling.cc @@ -226,9 +226,9 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, } template -void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLitePoolParams* params, OpData* data, - const TfLiteTensor* input, TfLiteTensor* output) { +void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { int32_t activation_min; int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, @@ -254,6 +254,31 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, #undef TF_LITE_MAX_POOL } +template +void MaxEvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { + int32_t activation_min; + int32_t activation_max; + CalculateActivationRangeInt8(params->activation, output, &activation_min, + 
&activation_max); +#define TF_LITE_MAX_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.quantized_activation_min = activation_min; \ + op_params.quantized_activation_max = activation_max; \ + type::MaxPool(op_params, GetTensorShape(input), \ + GetTensorData(input), GetTensorShape(output), \ + GetTensorData(output)) + TF_LITE_MAX_POOL(reference_integer_ops); +#undef TF_LITE_MAX_POOL +} + template void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, @@ -321,7 +346,12 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { MaxEvalFloat(context, node, params, data, input, output); break; case kTfLiteUInt8: - MaxEvalQuantized(context, node, params, data, input, output); + MaxEvalQuantizedUInt8(context, node, params, data, input, + output); + break; + case kTfLiteInt8: + MaxEvalQuantizedInt8(context, node, params, data, input, + output); break; default: context->ReportError(context, "Type %d not currently supported.", diff --git a/tensorflow/lite/kernels/pooling_test.cc b/tensorflow/lite/kernels/pooling_test.cc index e1b79340115ad18e50ecdb6944904bf2ab7c9e84..4627d7a5f0c2803635b9df85dd9275cc7851e8fb 100644 --- a/tensorflow/lite/kernels/pooling_test.cc +++ b/tensorflow/lite/kernels/pooling_test.cc @@ -97,6 +97,24 @@ class SymmetricQuantizedPoolingOpModel : public BasePoolingOpModel { } }; +// Replicate each entry in a vector n times along depth (innermost dimension). +// The values are incremented by delta, creating ramps offset by each input +// value. This is used to create simple and predicatable variation. 
+std::vector ReplicateDepthRamp(const std::vector& image_plane, + int n, float delta) { + const int size = image_plane.size(); + std::vector ramped_data(n * size); + // The input is treated as a 1-D even if logically it is multi-dimensional. + for (int input_index = 0; input_index < size; ++input_index) { + for (int depth = 0; depth < n; ++depth) { + ramped_data[n * input_index + depth] = + image_plane[input_index] + depth * delta; + } + } + + return ramped_data; +} + TEST(FloatPoolingOpTest, AveragePool) { FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, @@ -147,6 +165,31 @@ TEST(QuantizedPoolingOpTest, AveragePoolImageSize16) { EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({16}))); } +TEST(QuantizedPoolingOpTest, AveragePoolLargeDepth) { + // Test with a larger depth that is not a multiple of the tranche size, or of + // any register-oriented multiples such as 8 and 16. + constexpr int depth = 1999; // Prime number. + QuantizedPoolingOpModel m( + BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, depth}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + + std::vector input_image_plane({ + 0.f, 6.f, 2.f, 4.f, // + 3.f, 2.f, 10.f, 7.f, // + }); + std::vector output_image_plane({2.75f, 5.75f}); + + m.SetInput(ReplicateDepthRamp(input_image_plane, depth, 1.f / 512.f)); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + ReplicateDepthRamp(output_image_plane, depth, 1.f / 512.f), + 1. / 32.f))); +} + // Test quantized AveragePool with int8 input and output. The input is the same // as the uint8 test QuantizedPoolingOpTest.AveragePool. 
The float output is // identical to uint8 test and quantized output is identical to uint8 test with @@ -204,7 +247,7 @@ TEST(FloatPoolingOpTest, MaxPool) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); } -TEST(QuantizedPoolingOpTest, MaxPool) { +TEST(QuantizedUInt8PoolingOpTest, MaxPool) { // Choose the input ranges carefully so that the dequantized output matches // the results of the float model above. QuantizedPoolingOpModel m( @@ -223,6 +266,50 @@ TEST(QuantizedPoolingOpTest, MaxPool) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({96, 160})); } +TEST(QuantizedPoolingOpTest, MaxPoolLargeDepth) { + // Test with a larger depth that is not a multiple of the tranche size, or of + // any register-oriented multiples such as 8 and 16. + constexpr int depth = 1999; // Prime number. + QuantizedPoolingOpModel m( + BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, depth}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + + std::vector input_image_plane({ + 0.f, 6.f, 2.f, 4.f, // + 3.f, 2.f, 10.f, 7.f, // + }); + std::vector output_image_plane({6.f, 10.f}); + + m.SetInput(ReplicateDepthRamp(input_image_plane, depth, 1.f / 512.f)); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + ReplicateDepthRamp(output_image_plane, depth, 1.f / 512.f), + 1. / 32.f))); +} + +TEST(QuantizedInt8PoolingOpTest, MaxPool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. 
+ SymmetricQuantizedPoolingOpModel m( + BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_INT8, {1, 2, 4, 1}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_INT8, {}, 0, 15.9375}); + m.SetInput({ + 0, -6, 2, 4, // + 3, 2, -10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({3, 7}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-80, -16})); +} + TEST(FloatPoolingOpTest, L2Pool) { FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, diff --git a/tensorflow/lite/kernels/rank.cc b/tensorflow/lite/kernels/rank.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cef1f53a024b833034deb497909beac4b4753e6 --- /dev/null +++ b/tensorflow/lite/kernels/rank.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace rank { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + output->type = kTfLiteInt32; + + // Rank produces a 0-D int32 Tensor representing the rank of input. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(0); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(output), 0); + + if (output->type == kTfLiteInt32) { + int32_t* output_data = GetTensorData(output); + *output_data = NumDimensions(input); + } else { + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace rank + +TfLiteRegistration* Register_RANK() { + static TfLiteRegistration r = {nullptr, nullptr, rank::Prepare, rank::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/rank_test.cc b/tensorflow/lite/kernels/rank_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c31fc5866931708eb8155c2dc88026b623039ed --- /dev/null +++ b/tensorflow/lite/kernels/rank_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class RankOpModel : public SingleOpModel { + public: + RankOpModel(std::initializer_list input_shape, TensorType input_type) { + TensorType output_type = TensorType_INT32; + input_ = AddInput(input_type); + output_ = AddOutput(output_type); + SetBuiltinOp(BuiltinOperator_RANK, BuiltinOptions_RankOptions, + CreateRankOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } + + TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); } + + int input() { return input_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(RankOpTest, InputTypeFloat) { + RankOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5})); + EXPECT_TRUE(model.GetOutputShape().empty()); +} + +TEST(RankOpTest, InputTypeInt) { + RankOpModel model({1, 3, 1, 3, 5}, TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5})); + EXPECT_TRUE(model.GetOutputShape().empty()); +} + 
+TEST(RankOpTest, ScalarTensor) { + RankOpModel model({}, TensorType_FLOAT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); + EXPECT_TRUE(model.GetOutputShape().empty()); +} + +TEST(RankOpTest, EmptyTensor) { + RankOpModel model({1, 0}, TensorType_FLOAT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({2})); + EXPECT_TRUE(model.GetOutputShape().empty()); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc index 336e827ca4c76abf3a08492249dfc0ce9cd81439..a0f1126048ea43458eae9ce327d31f6c1c9ead17 100644 --- a/tensorflow/lite/kernels/reduce.cc +++ b/tensorflow/lite/kernels/reduce.cc @@ -17,7 +17,10 @@ limitations under the License. #include #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/gemm_support.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -35,6 +38,13 @@ enum KernelType { kReference, }; +struct OpData { + int32_t multiplier; + int shift; + // The index of the temporary tensor where the quantized inputs are cached. 
+ int scratch_tensor_index; +}; + struct OpContext { OpContext(TfLiteContext* context, TfLiteNode* node) { params = reinterpret_cast(node->builtin_data); @@ -49,15 +59,17 @@ struct OpContext { }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { + gemm_support::IncrementUsageCounter(context); // Creates two temp tensors to store index and axis for internal // implementation only. - auto* scratch_tensor_index = new int; - context->AddTensors(context, 3, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, 3, &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); } // Resizes the temp tensor that stores resolved axis. @@ -148,10 +160,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) { TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, OpContext* op_context) { // Creates a temp index to iterate through input data. - int* scratch_tensor_index = reinterpret_cast(node->user_data); + OpData* op_data = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); node->temporaries = TfLiteIntArrayCreate(3); - node->temporaries->data[0] = *scratch_tensor_index; + node->temporaries->data[0] = op_data->scratch_tensor_index; TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0); scratch_tensor->type = kTfLiteInt32; scratch_tensor->allocation_type = kTfLiteArenaRw; @@ -161,11 +173,11 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, context->ResizeTensor(context, scratch_tensor, index_size)); // Creates a temp tensor to store resolved axis given input data. 
- node->temporaries->data[1] = *scratch_tensor_index + 1; + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); resolved_axis->type = kTfLiteInt32; // Creates a temp tensor to store temp sums when calculating mean. - node->temporaries->data[2] = *scratch_tensor_index + 2; + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); switch (op_context->input->type) { case kTfLiteFloat32: @@ -180,6 +192,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, case kTfLiteUInt8: temp_sum->type = kTfLiteInt32; break; + case kTfLiteInt8: + temp_sum->type = kTfLiteInt32; + break; case kTfLiteBool: temp_sum->type = kTfLiteBool; break; @@ -219,9 +234,18 @@ TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); + OpData* data = reinterpret_cast(node->user_data); // reduce_mean requires a buffer to store intermediate sum result. 
OpContext op_context(context, node); + if (op_context.input->type == kTfLiteInt8) { + const double real_multiplier = + static_cast(op_context.input->params.scale) / + static_cast(op_context.output->params.scale); + int exponent; + QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent); + data->shift = exponent; + } TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); if (!IsConstantTensor(op_context.axis)) { SetTensorToDynamic(temp_sum); @@ -245,6 +269,8 @@ void ResolveAxis(const int* axis_data, int axis_count, template TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { OpContext op_context(context, node); + OpData* data = reinterpret_cast(node->user_data); + int num_axis = static_cast(NumElements(op_context.axis)); TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); @@ -257,6 +283,51 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum)); } + // Defer to specialized implementation for 4D Mean across axes 1 & 2. 
+ if (op_context.input->type == kTfLiteFloat32 || + op_context.input->type == kTfLiteUInt8) { + tflite::MeanParams op_params; + op_params.axis_count = num_axis; + ResolveAxis(GetTensorData(op_context.axis), num_axis, &op_params); + const TfLiteTensor* input = op_context.input; + if (op_context.params->keep_dims && NumDimensions(input) == 4 && + op_params.axis_count == 2 && + ((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1))) { + if (op_context.input->type == kTfLiteUInt8) { + gemmlowp::GemmContext* gemm_context = + gemm_support::GetFromContext(context); + optimized_ops::Mean( + op_params, GetTensorShape(input), GetTensorData(input), + op_context.input->params.zero_point, op_context.input->params.scale, + GetTensorShape(op_context.output), + GetTensorData(op_context.output), + op_context.output->params.zero_point, + op_context.output->params.scale, gemm_context); + } else { + reference_ops::Mean(op_params, GetTensorShape(input), + GetTensorData(input), + GetTensorShape(op_context.output), + GetTensorData(op_context.output)); + } + return kTfLiteOk; + } + } + + if (op_context.input->type == kTfLiteInt8) { + tflite::MeanParams op_params; + op_params.axis_count = num_axis; + ResolveAxis(GetTensorData(op_context.axis), num_axis, &op_params); + const TfLiteTensor* input = op_context.input; + reference_integer_ops::Mean( + op_params, data->multiplier, data->shift, GetTensorShape(input), + GetTensorData(input), op_context.input->params.zero_point, + GetTensorShape(op_context.output), + GetTensorData(op_context.output), + op_context.output->params.zero_point); + return kTfLiteOk; + } + #define TF_LITE_MEAN(kernel_type, data_type, temp_data_type) \ kernel_type::Mean<>( \ GetTensorData(op_context.input), \ @@ -436,6 +507,9 @@ TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: return EvalType(context, node, &op_context, reduce_type); break; + case kTfLiteInt8: + return 
EvalType(context, node, &op_context, reduce_type); + break; case kTfLiteBool: return EvalType(context, node, &op_context, reduce_type); break; diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc index c1526bddb719e74a6396dc4aeac4b5827220a65a..dd852b90aef13a307649f8843ec1b3056860e030 100644 --- a/tensorflow/lite/kernels/reduce_test.cc +++ b/tensorflow/lite/kernels/reduce_test.cc @@ -38,9 +38,10 @@ class BaseOpModel : public SingleOpModel { return ExtractVector(output_); } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } std::vector GetOutputShape() { return GetTensorShape(output_); } @@ -256,6 +257,66 @@ TEST(ConstFloatMeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); } +// Uses a set of reduction conditions that trigger the specialized 4D version +// of Mean. 
+TEST(ConstFloatMeanOpTest, KeepDims4DMean) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {2, 2, 3, 2}}, + {TensorType_FLOAT32, {3}}, {2}, {1, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({6, 7, 18, 19}))); +} + +TEST(ConstFloatMeanOpTest, KeepDims4DMeanUInt8) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.1, 0.2, 0.3, 0.4, 0.1, 0.2, + 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}; + MeanOpConstModel m({TensorType_UINT8, {1, 2, 2, 3}, -1.0, 1.0}, + {TensorType_UINT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 3})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.25098, 0.25098, 0.25098}, + kQuantizedTolerance))); +} + +TEST(ConstFloatMeanOpTest, KeepDims4DMeanLargeDepthUInt8) { + float kQuantizedTolerance = GetTolerance(-5.0, 5.0); + std::vector data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, + 0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9, + 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3, + 0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}; + MeanOpConstModel m({TensorType_UINT8, {1, 2, 2, 9}, -1.0, 1.0}, + {TensorType_UINT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425}, + kQuantizedTolerance))); +} + +TEST(ConstFloatMeanOpTest, KeepDims4DMeanQuantized) { + float kQuantizedTolerance = GetTolerance(-5.0, 5.0); + std::vector data = {0.1, 0.2, 0.3, 0.4, 0.1, 0.2, + 0.3, 0.4, 0.1, 0.2, 
0.3, 0.4}; + MeanOpConstModel m({TensorType_UINT8, {1, 2, 3, 2}, 0.0, 1.0}, + {TensorType_UINT8, {3}, -5.0, 5.0}, {2}, {1, 2}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({0.235294, 0.313726}, kQuantizedTolerance))); +} + TEST(ConstFloatMeanOpTest, Scalar) { std::vector data = {3.27}; MeanOpConstModel m({TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {}, @@ -318,8 +379,9 @@ TEST(ConstUint8MeanOpTest, NotKeepDims) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - {0.4, 0.4}, kQuantizedTolerance))); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.4, 0.4}, kQuantizedTolerance))); } TEST(ConstUint8MeanOpTest, KeepDims) { @@ -331,10 +393,44 @@ TEST(ConstUint8MeanOpTest, KeepDims) { m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); EXPECT_THAT( - m.GetDequantizedOutput(), + m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance))); } +TEST(ConstInt8MeanOpTest, QuantizedSameScale) { + float kQuantizedTolerance = GetTolerance(-5.0, 5.0); + std::vector data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, + 0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9, + 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3, + 0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}; + MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 9}, -1.0, 1.0}, + {TensorType_INT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425}, + kQuantizedTolerance))); +} + +TEST(ConstInt8MeanOpTest, 
QuantizedDifferentScale) { + float kQuantizedTolerance = GetTolerance(-5.0, 5.0); + std::vector data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1, + 0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9, + 0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3, + 0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}; + MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 9}, -1.0, 1.0}, + {TensorType_INT8, {2}, -4.0, 4.0}, {2}, {1, 2}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425}, + kQuantizedTolerance))); +} + TEST(DynamicUint8MeanOpTest, NotKeepDims) { float kQuantizedTolerance = GetTolerance(-5.0, 2.0); std::vector data = {1.3, -4.8, -3.6, 0.24}; @@ -347,7 +443,7 @@ TEST(DynamicUint8MeanOpTest, NotKeepDims) { m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); EXPECT_THAT( - m.GetDequantizedOutput(), + m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({-1.75, -1.68}, kQuantizedTolerance))); } @@ -363,7 +459,7 @@ TEST(DynamicUint8MeanOpTest, KeepDims) { m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); EXPECT_THAT( - m.GetDequantizedOutput(), + m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance))); } @@ -377,7 +473,7 @@ TEST(DynamicUint8MeanOpTest, QuantizedScalar) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), IsEmpty()); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({0.643}, kQuantizedTolerance))); } @@ -390,7 +486,7 @@ TEST(ConstUint8MeanOpTest, QuantizedKeepDims) { m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); EXPECT_THAT( - m.GetDequantizedOutput(), + m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance))); } @@ 
-483,7 +579,7 @@ TEST(ConstUint8SumOpTest, NotKeepDims) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray( ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance))); } @@ -496,8 +592,9 @@ TEST(ConstUint8SumOpTest, NotKeepDimsRescaling) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - {1.2, 1.2}, kQuantizedTolerance))); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({1.2, 1.2}, kQuantizedTolerance))); } TEST(ConstUint8SumOpTest, KeepDims) { @@ -508,7 +605,7 @@ TEST(ConstUint8SumOpTest, KeepDims) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({-0.407843, -0.313726, 0.0941177}, kQuantizedTolerance))); } @@ -524,7 +621,7 @@ TEST(DynamicUint8SumOpTest, NotKeepDims) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray( ArrayFloatNear({1.48235, 1.64706}, kQuantizedTolerance))); } @@ -541,7 +638,7 @@ TEST(DynamicUint8SumOpTest, KeepDims) { m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); EXPECT_THAT( - m.GetDequantizedOutput(), + m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance))); } @@ -698,7 +795,20 @@ TEST(ConstUint8MaxOpTest, NotKeepDims) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( 
+ ArrayFloatNear({0.501961, 0.603922}, kQuantizedTolerance))); +} + +TEST(ConstInt8MaxOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MaxOpConstModel m({TensorType_INT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_INT8, {2}, -1.0, 1.0}, {1}, {1}, false); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray( ArrayFloatNear({0.501961, 0.603922}, kQuantizedTolerance))); } @@ -711,7 +821,20 @@ TEST(ConstUint8MaxOpTest, KeepDims) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({0.4, 0.4, 0.603922}, kQuantizedTolerance))); +} + +TEST(ConstInt8MaxOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MaxOpConstModel m({TensorType_INT8, {3, 2}, -1.0, 1.0}, + {TensorType_INT8, {3}, -1.0, 1.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray( ArrayFloatNear({0.4, 0.4, 0.603922}, kQuantizedTolerance))); } @@ -727,7 +850,23 @@ TEST(DynamicUint8MaxOpTest, NotKeepDims) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({1.2902, 0.247059}, kQuantizedTolerance))); +} + +TEST(DynamicInt8MaxOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 2.0); + std::vector data = {1.3, -4.8, -3.6, 0.24}; + MaxOpDynamicModel m({TensorType_INT8, {2, 2}, -5.0, 2.0}, + {TensorType_INT8, {2}, -5.0, 2.0}, + 
{TensorType_INT32, {1}}, false); + std::vector axis = {1}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray( ArrayFloatNear({1.2902, 0.247059}, kQuantizedTolerance))); } @@ -743,7 +882,23 @@ TEST(DynamicUint8MaxOpTest, KeepDims) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({11.1294, 0.862745}, kQuantizedTolerance))); +} + +TEST(DynamicInt8MaxOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14, -0.14, 7.423, 0.879}; + MaxOpDynamicModel m({TensorType_INT8, {2, 2}, -10.0, 12.0}, + {TensorType_INT8, {2}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray( ArrayFloatNear({11.1294, 0.862745}, kQuantizedTolerance))); } @@ -758,7 +913,21 @@ TEST(DynamicUint8MaxOpTest, Scalar) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), IsEmpty()); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance))); +} + +TEST(DynamicInt8MaxOpTest, Scalar) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14}; + MaxOpDynamicModel m({TensorType_INT8, {}, -10.0, 12.0}, + {TensorType_INT8, {}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({11.1294}, 
kQuantizedTolerance))); } @@ -840,7 +1009,20 @@ TEST(ConstUint8MinOpTest, NotKeepDims) { m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); EXPECT_THAT( - m.GetDequantizedOutput(), + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.294117, 0.2}, kQuantizedTolerance))); +} + +TEST(ConstInt8MinOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MinOpConstModel m({TensorType_INT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_INT8, {2}, -1.0, 1.0}, {1}, {1}, false); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT( + m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({0.294117, 0.2}, kQuantizedTolerance))); } @@ -853,7 +1035,20 @@ TEST(ConstUint8MinOpTest, KeepDims) { m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); EXPECT_THAT( - m.GetDequantizedOutput(), + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.2, 0.3, 0.5}, kQuantizedTolerance))); +} + +TEST(ConstInt8MinOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MinOpConstModel m({TensorType_INT8, {3, 2}, -1.0, 1.0}, + {TensorType_INT8, {3}, -1.0, 1.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT( + m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({0.2, 0.3, 0.5}, kQuantizedTolerance))); } @@ -869,7 +1064,23 @@ TEST(DynamicUint8MinOpTest, NotKeepDims) { m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); EXPECT_THAT( - m.GetDequantizedOutput(), + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-4.807843, -3.6}, kQuantizedTolerance))); +} + +TEST(DynamicInt8MinOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 2.0); + std::vector data = {1.3, -4.8, 
-3.6, 0.24}; + MinOpDynamicModel m({TensorType_INT8, {2, 2}, -5.0, 2.0}, + {TensorType_INT8, {2}, -5.0, 2.0}, + {TensorType_INT32, {1}}, false); + std::vector axis = {1}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT( + m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({-4.807843, -3.6}, kQuantizedTolerance))); } @@ -884,7 +1095,23 @@ TEST(DynamicUint8MinOpTest, KeepDims) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({7.427451, -0.164706}, kQuantizedTolerance))); +} + +TEST(DynamicInt8MinOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14, -0.14, 7.423, 0.879}; + MinOpDynamicModel m({TensorType_INT8, {2, 2}, -10.0, 12.0}, + {TensorType_INT8, {2}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray( ArrayFloatNear({7.427451, -0.164706}, kQuantizedTolerance))); } @@ -899,7 +1126,21 @@ TEST(DynamicUint8MinOpTest, Scalar) { m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), IsEmpty()); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance))); +} + +TEST(DynamicInt8MinOpTest, Scalar) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14}; + MinOpDynamicModel m({TensorType_INT8, {}, -10.0, 12.0}, + {TensorType_INT8, {}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + 
EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance))); } diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index df2b15fe83ad7e86ac87e64cf97953f3f77f2b59..b11cb9938c8d2410b6093d92ed6693f346d9c1d3 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -22,10 +22,10 @@ namespace ops { namespace custom { TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); -TfLiteRegistration* Register_LAYER_NORM_LSTM(); TfLiteRegistration* Register_MFCC(); TfLiteRegistration* Register_DETECTION_POSTPROCESS(); -TfLiteRegistration* Register_RELU_1(); +TfLiteRegistration* Register_IF(); +TfLiteRegistration* Register_WHILE(); } // namespace custom @@ -94,6 +94,7 @@ TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); +TfLiteRegistration* Register_CEIL(); TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); TfLiteRegistration* Register_SUM(); @@ -104,6 +105,7 @@ TfLiteRegistration* Register_REDUCE_ANY(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); +TfLiteRegistration* Register_COS(); TfLiteRegistration* Register_TRANSPOSE_CONV(); TfLiteRegistration* Register_EXPAND_DIMS(); TfLiteRegistration* Register_SPARSE_TO_DENSE(); @@ -112,6 +114,7 @@ TfLiteRegistration* Register_NOT_EQUAL(); TfLiteRegistration* Register_SQRT(); TfLiteRegistration* Register_RSQRT(); TfLiteRegistration* Register_SHAPE(); +TfLiteRegistration* Register_RANK(); TfLiteRegistration* Register_POW(); TfLiteRegistration* Register_FAKE_QUANT(); TfLiteRegistration* Register_PACK(); @@ -130,6 +133,12 @@ TfLiteRegistration* Register_SQUARED_DIFFERENCE(); TfLiteRegistration* Register_FILL(); TfLiteRegistration* Register_MIRROR_PAD(); TfLiteRegistration* 
Register_UNIQUE(); +TfLiteRegistration* Register_REVERSE_V2(); +TfLiteRegistration* Register_ADD_N(); +TfLiteRegistration* Register_GATHER_ND(); +TfLiteRegistration* Register_WHERE(); +TfLiteRegistration* Register_ELU(); +TfLiteRegistration* Register_REVERSE_SEQUENCE(); TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { context->ReportError( @@ -162,18 +171,26 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_ABS, Register_ABS()); AddBuiltin(BuiltinOperator_RELU, Register_RELU()); AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1()); - AddBuiltin(BuiltinOperator_RELU6, Register_RELU6()); - AddBuiltin(BuiltinOperator_TANH, Register_TANH()); - AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC()); - AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D()); - AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D()); + AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_TANH, Register_TANH(), /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D()); AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), /* min_version */ 1, - /* max_version */ 2); + /* max_version */ 3); AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(), /* min_version */ 1, - /* max_version */ 2); + /* max_version */ 3); AddBuiltin(BuiltinOperator_SVDF, Register_SVDF(), /* min_version */ 1, /* max_version */ 2); @@ -195,16 +212,28 @@ BuiltinOpResolver::BuiltinOpResolver() { Register_EMBEDDING_LOOKUP_SPARSE()); 
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(), /* min_version */ 1, - /* max_version */ 3); + /* max_version */ 4); AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); - AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); - AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION()); - AddBuiltin(BuiltinOperator_ADD, Register_ADD()); - AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND()); - AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND()); + AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_ADD, Register_ADD(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_MUL, Register_MUL()); - AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); + AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1, @@ -215,62 +244,115 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, Register_UNIDIRECTIONAL_SEQUENCE_LSTM(), /* min_version */ 1, /* max_version */ 2); - AddBuiltin(BuiltinOperator_PAD, Register_PAD()); - AddBuiltin(BuiltinOperator_PADV2, Register_PADV2()); + AddBuiltin(BuiltinOperator_PAD, Register_PAD(), /* min_version */ 1, + 
/* max_version */ 2); + AddBuiltin(BuiltinOperator_PADV2, Register_PADV2(), /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); - AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); + AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, - Register_RESIZE_NEAREST_NEIGHBOR()); + Register_RESIZE_NEAREST_NEIGHBOR(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); - AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); - AddBuiltin(BuiltinOperator_GATHER, Register_GATHER()); - AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE()); + AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_MEAN, Register_MEAN()); AddBuiltin(BuiltinOperator_DIV, Register_DIV()); - AddBuiltin(BuiltinOperator_SUB, Register_SUB()); - AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT()); + AddBuiltin(BuiltinOperator_SUB, Register_SUB(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(), /* min_version */ 1, + /* max_version */ 3); AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V()); AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); - AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); + AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); - AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); + AddBuiltin(BuiltinOperator_TOPK_V2, 
Register_TOPK_V2(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_LOG, Register_LOG()); - AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); + AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(), /* min_version */ 1, /* max_version */ 2); AddBuiltin(BuiltinOperator_PRELU, Register_PRELU()); - AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM()); - AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM()); - AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX()); - AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN()); - AddBuiltin(BuiltinOperator_GREATER, Register_GREATER()); - AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL()); - AddBuiltin(BuiltinOperator_LESS, Register_LESS()); - AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL()); + AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_GREATER, Register_GREATER(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_LESS, Register_LESS(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR()); + AddBuiltin(BuiltinOperator_CEIL, Register_CEIL()); AddBuiltin(BuiltinOperator_NEG, Register_NEG()); - 
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT()); - AddBuiltin(BuiltinOperator_SLICE, Register_SLICE()); + AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_SIN, Register_SIN()); + AddBuiltin(BuiltinOperator_COS, Register_COS()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); AddBuiltin(BuiltinOperator_TILE, Register_TILE()); AddBuiltin(BuiltinOperator_SUM, Register_SUM()); AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD()); - AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX()); - AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN()); + AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_REDUCE_ANY, Register_REDUCE_ANY()); AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); - AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); - AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); + AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); + AddBuiltin(BuiltinOperator_RANK, Register_RANK()); AddBuiltin(BuiltinOperator_POW, Register_POW()); AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2); - AddBuiltin(BuiltinOperator_PACK, Register_PACK()); + AddBuiltin(BuiltinOperator_PACK, Register_PACK(), + /* min_version */ 1, + /* 
max_version */ 2); AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT()); AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); @@ -286,16 +368,24 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_FILL, Register_FILL()); AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD()); AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE()); + AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2()); + AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N()); + AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND()); + AddBuiltin(BuiltinOperator_WHERE, Register_WHERE()); + AddBuiltin(BuiltinOperator_ELU, Register_ELU()); + AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. // AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); // AddCustom("AudioSpectrogram", // tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); - AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM()); - AddCustom("Relu1", tflite::ops::custom::Register_RELU_1()); AddCustom("TFLite_Detection_PostProcess", tflite::ops::custom::Register_DETECTION_POSTPROCESS()); + + // WARNING: Control flow ops are experimental and subject to change. 
+ AddCustom("Experimental_If", tflite::ops::custom::Register_IF()); + AddCustom("Experimental_While", tflite::ops::custom::Register_WHILE()); } } // namespace builtin diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 6840ea39bf243f476f7935ed85a53aacb044e498..faa864b0e236e4a61453c6fcecafd2ca09f41ce1 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -22,10 +22,8 @@ namespace ops { namespace custom { TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); -TfLiteRegistration* Register_LAYER_NORM_LSTM(); TfLiteRegistration* Register_MFCC(); TfLiteRegistration* Register_DETECTION_POSTPROCESS(); -TfLiteRegistration* Register_RELU_1(); } // namespace custom @@ -286,8 +284,6 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); AddCustom("AudioSpectrogram", tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); - AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM()); - AddCustom("Relu1", tflite::ops::custom::Register_RELU_1()); AddCustom("TFLite_Detection_PostProcess", tflite::ops::custom::Register_DETECTION_POSTPROCESS()); } diff --git a/tensorflow/lite/kernels/relu1_test.cc b/tensorflow/lite/kernels/relu1_test.cc deleted file mode 100644 index f52d10b0b7f32af3444c702835f0674d7181bb7a..0000000000000000000000000000000000000000 --- a/tensorflow/lite/kernels/relu1_test.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/kernels/test_util.h" - -namespace tflite { -namespace ops { -namespace custom { - -TfLiteRegistration* Register_RELU_1(); - -namespace { - -using ::testing::ElementsAreArray; - -class BaseActivationsOpModel : public SingleOpModel { - public: - explicit BaseActivationsOpModel(const TensorData& input) { - input_ = AddInput(input); - output_ = AddOutput({input.type, {}}); - flexbuffers::Builder fbb; - fbb.Map([&]() {}); - fbb.Finish(); - SetCustomOp("RELU_1", fbb.GetBuffer(), Register_RELU_1); - BuildInterpreter({GetShape(input_)}); - } - - protected: - int input_; - int output_; -}; - -class FloatActivationsOpModel : public BaseActivationsOpModel { - public: - using BaseActivationsOpModel::BaseActivationsOpModel; - - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); - } - std::vector GetOutput() { return ExtractVector(output_); } -}; - -TEST(FloatActivationsOpTest, Relu1) { - FloatActivationsOpModel m(/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); - m.SetInput({ - 0.0, -0.6, 0.2, -0.4, // - 0.3, -2.0, 1.1, -0.1, // - }); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({ - 0.0, 0.0, 0.2, 0.0, // - 0.3, 0.0, 1.0, 0.0, // - })); -} - -} // namespace -} // namespace custom -} // namespace ops -} // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/kernels/reshape_test.cc b/tensorflow/lite/kernels/reshape_test.cc index e2210aeaf09395fec8fa9462096f61c4fbf19804..eb05eb6d4e1fbb51a07645ecf4788188a8ad9f7c 100644 --- a/tensorflow/lite/kernels/reshape_test.cc +++ 
b/tensorflow/lite/kernels/reshape_test.cc @@ -33,7 +33,7 @@ enum ShapeSpecificationType { // Const node, which is guaranteed not to change once inference starts. The // shape is also hardcoded as in kAsReshapeOption. kAsConstantTensor, - // The output shape is specifed as an input tensor that can change based on + // The output shape is specified as an input tensor that can change based on // external input. That is, the shape is not know before the inference // starts. The shape is also hardcoded as in kAsReshapeOption. kAsTensor, @@ -137,21 +137,11 @@ TEST_P(ReshapeOpTest, MismatchedDimensions) { #endif TEST_P(ReshapeOpTest, TooManyDimensions) { - if (GetParam() == kAsReshapeOption) { #ifdef GTEST_HAS_DEATH_TEST EXPECT_DEATH(ReshapeOpModel({1, 1, 2, 1, 1, 1, 1, 1, 1}, {9}, {1, 1, 1, 1, 1, 1, 1, 1, 2}, GetParam()), "Found too many dimensions"); #endif - } else { - ReshapeOpModel m({1, 1, 2, 1, 1, 1, 1, 1, 1}, {9}, - {1, 1, 1, 1, 1, 1, 1, 1, 2}, GetParam()); - m.SetInput({3, 4}); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4})); - EXPECT_THAT(m.GetOutputShape(), - ElementsAreArray({1, 1, 1, 1, 1, 1, 1, 1, 2})); - } } #ifdef GTEST_HAS_DEATH_TEST diff --git a/tensorflow/lite/kernels/resize_bilinear.cc b/tensorflow/lite/kernels/resize_bilinear.cc index d42cb188669587a957dd085f9ecb123f44b59437..7383d03438c65a710efbfe30f3d3c0ce261f0ca8 100644 --- a/tensorflow/lite/kernels/resize_bilinear.cc +++ b/tensorflow/lite/kernels/resize_bilinear.cc @@ -109,6 +109,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t); } + } else if (output->type == kTfLiteInt8) { + TF_LITE_RESIZE_BILINEAR(reference_ops, int8_t); #undef TF_LITE_RESIZE_BILINEAR } else { context->ReportError(context, "Output type is %d, requires float.", diff --git a/tensorflow/lite/kernels/resize_bilinear_test.cc b/tensorflow/lite/kernels/resize_bilinear_test.cc 
index d3f4837a287accd93c23e17fa3a361efd4120101..b7b7fcd41578967e9b96797e6a3c1ed23f29d75a 100644 --- a/tensorflow/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/lite/kernels/resize_bilinear_test.cc @@ -64,7 +64,7 @@ class ResizeBilinearOpModel : public SingleOpModel { }; TEST(ResizeBilinearOpTest, HorizontalResize) { - ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {}); m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); @@ -78,8 +78,8 @@ TEST(ResizeBilinearOpTest, HorizontalResize) { ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } -TEST(ResizeBilinearOpTest, HorizontalResize8Bit) { - ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}); +TEST(ResizeBilinearOpTest, HorizontalResizeUInt8) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {}); m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); @@ -93,8 +93,23 @@ TEST(ResizeBilinearOpTest, HorizontalResize8Bit) { ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } +TEST(ResizeBilinearOpTest, HorizontalResizeInt8) { + ResizeBilinearOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {}); + m.SetInput({3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + TEST(ResizeBilinearOpTest, VerticalResize) { - ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {}); m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); @@ -108,8 +123,8 @@ TEST(ResizeBilinearOpTest, VerticalResize) { ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } -TEST(ResizeBilinearOpTest, VerticalResize8Bit) { - ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}); +TEST(ResizeBilinearOpTest, VerticalResizeUInt8) { + 
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {}); m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); @@ -123,8 +138,23 @@ TEST(ResizeBilinearOpTest, VerticalResize8Bit) { ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } +TEST(ResizeBilinearOpTest, VerticalResizeInt8) { + ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {}); + m.SetInput({3, 9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + TEST(ResizeBilinearOpTest, TwoDimensionalResize) { - ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {}); m.SetInput({ 3, 6, // 9, 12 // @@ -150,8 +180,8 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize) { }))); } -TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) { - ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}); +TEST(ResizeBilinearOpTest, TwoDimensionalResizeUInt8) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {}); m.SetInput({ 3, 6, // 9, 12 // @@ -177,8 +207,35 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) { }))); } +TEST(ResizeBilinearOpTest, TwoDimensionalResizeInt8) { + ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {}); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); + + ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { - 
ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {}); m.SetInput({ 3, 6, // 9, 12, // @@ -215,7 +272,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { - ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {}); m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // @@ -241,8 +298,8 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { }))); } -TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { - ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}); +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) { + ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {}); m.SetInput({ 3, 6, // 9, 12, // @@ -278,8 +335,45 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { }))); } -TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { - ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesInt8) { + ResizeBilinearOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {}); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 12, 16 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 9, 12, 13, // + 12, 14, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 12, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 9, 12, 13, // + 12, 14, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, ThreeDimensionalResizeUInt8) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {}); m.SetInput({ 3, 4, 6, 10, // 10, 12, 14, 16, // @@ 
-304,6 +398,33 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { 10, 12, 12, 14, 14, 16, // }))); } + +TEST(ResizeBilinearOpTest, ThreeDimensionalResizeInt8) { + ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {}); + m.SetInput({ + 3, 4, 6, 10, // + 10, 12, 14, 16, // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 9, 10, 12, 11, 13, // + 10, 12, 12, 14, 14, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 10, 12, 14, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 9, 10, 12, 11, 13, // + 10, 12, 12, 14, 14, 16, // + }))); +} } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/resize_nearest_neighbor.cc b/tensorflow/lite/kernels/resize_nearest_neighbor.cc index a48d8004f8b6cead177286328082310237af515a..3030a4f28e22396cbc51e55ff04562fa76a0264e 100644 --- a/tensorflow/lite/kernels/resize_nearest_neighbor.cc +++ b/tensorflow/lite/kernels/resize_nearest_neighbor.cc @@ -106,8 +106,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorShape(size), GetTensorData(size), GetTensorShape(output), GetTensorData(output)); } + } else if (output->type == kTfLiteInt8) { + reference_ops::ResizeNearestNeighbor( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(size), GetTensorData(size), + GetTensorShape(output), GetTensorData(output)); } else { - context->ReportError(context, "Output type is %d, requires float or uint8.", + context->ReportError(context, + "Output type is %d, requires float, uint8 or int8.", output->type); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc b/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc index 
03e2effd84c4adb13db1bb3ada4f5cfe1c0b12c9..4d4cec9101c3a2e0561e5b6b1313a5da96276ca8 100644 --- a/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc +++ b/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc @@ -64,7 +64,7 @@ class ResizeNearestNeighborOpModel : public SingleOpModel { }; TEST(ResizeNearestNeighborOpTest, HorizontalResize) { - ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); + ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {}); m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); @@ -79,8 +79,8 @@ TEST(ResizeNearestNeighborOpTest, HorizontalResize) { ElementsAreArray(ArrayFloatNear({3, 3, 6}))); } -TEST(ResizeNearestNeighborOpTest, HorizontalResize8Bit) { - ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 1, 2, 1}}); +TEST(ResizeNearestNeighborOpTest, HorizontalResizeUInt8) { + ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {}); m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); @@ -95,8 +95,23 @@ TEST(ResizeNearestNeighborOpTest, HorizontalResize8Bit) { ElementsAreArray(ArrayFloatNear({3, 3, 6}))); } +TEST(ResizeNearestNeighborOpTest, HorizontalResizeInt8) { + ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {}); + m.SetInput({-3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-3, -3, 6}))); + + ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({-3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-3, -3, 6}))); +} + TEST(ResizeNearestNeighborOpTest, VerticalResize) { - ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); + ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {}); m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); @@ -111,8 +126,8 @@ TEST(ResizeNearestNeighborOpTest, VerticalResize) { ElementsAreArray(ArrayFloatNear({3, 3, 9}))); } -TEST(ResizeNearestNeighborOpTest, 
VerticalResize8Bit) { - ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 1, 1}}); +TEST(ResizeNearestNeighborOpTest, VerticalResizeUInt8) { + ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {}); m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); @@ -127,8 +142,23 @@ TEST(ResizeNearestNeighborOpTest, VerticalResize8Bit) { ElementsAreArray(ArrayFloatNear({3, 3, 9}))); } +TEST(ResizeNearestNeighborOpTest, VerticalResizeInt8) { + ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {}); + m.SetInput({3, -9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 3, -9}))); + + ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, -9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 3, -9}))); +} + TEST(ResizeNearestNeighborOpTest, TwoDimensionalResize) { - ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); + ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {}); m.SetInput({ 3, 6, // 9, 12 // @@ -155,8 +185,8 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResize) { }))); } -TEST(ResizeNearestNeighborOpTest, TwoDimensionalResize8Bit) { - ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 1}}); +TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeUInt8) { + ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {}); m.SetInput({ 3, 6, // 9, 12 // @@ -183,8 +213,35 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResize8Bit) { }))); } +TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeInt8) { + ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {}); + m.SetInput({ + 3, -6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 3, -6, // + 3, 3, -6, // + 9, 9, 12, // + }))); + + ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 2, 2, 1}}, {3, 
3}); + const_m.SetInput({ + 3, -6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 3, -6, // + 3, 3, -6, // + 9, 9, 12, // + }))); +} + TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) { - ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); + ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {}); m.SetInput({ 3, 6, // 9, 12, // @@ -222,7 +279,7 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) { } TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResize) { - ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); + ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {}); m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // @@ -249,8 +306,8 @@ TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResize) { }))); } -TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { - ResizeNearestNeighborOpModel m({TensorType_UINT8, {2, 2, 2, 1}}); +TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) { + ResizeNearestNeighborOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {}); m.SetInput({ 3, 6, // 9, 12, // @@ -287,8 +344,45 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { }))); } -TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResize8Bit) { - ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); +TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt8) { + ResizeNearestNeighborOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {}); + m.SetInput({ + 3, 6, // + 9, -12, // + -4, 10, // + 12, 16 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 3, 6, // + 3, 3, 6, // + 9, 9, -12, // + -4, -4, 10, // + -4, -4, 10, // + 12, 12, 16, // + }))); + + ResizeNearestNeighborOpModel const_m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, -12, // + 
-4, 10, // + 12, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 3, 6, // + 3, 3, 6, // + 9, 9, -12, // + -4, -4, 10, // + -4, -4, 10, // + 12, 12, 16, // + }))); +} + +TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8) { + ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {}); m.SetInput({ 3, 4, 6, 10, // 10, 12, 14, 16, // @@ -315,6 +409,33 @@ TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResize8Bit) { }))); } +TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) { + ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {}); + m.SetInput({ + 3, 4, -6, 10, // + 10, 12, -14, 16, // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 3, 4, -6, 10, // + 3, 4, 3, 4, -6, 10, // + 10, 12, 10, 12, -14, 16, // + }))); + + ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, -6, 10, // + 10, 12, -14, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 3, 4, -6, 10, // + 3, 4, 3, 4, -6, 10, // + 10, 12, 10, 12, -14, 16, // + }))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/reverse.cc b/tensorflow/lite/kernels/reverse.cc new file mode 100644 index 0000000000000000000000000000000000000000..855aee8df1c0969bba9ec7d32bee78e04aeccbca --- /dev/null +++ b/tensorflow/lite/kernels/reverse.cc @@ -0,0 +1,127 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace reverse {
+namespace {
+
+// Tensor indices within the node's input and output lists.
+constexpr int kInputTensor = 0;
+constexpr int kAxisTensor = 1;
+constexpr int kOutputTensor = 0;
+
+// Validates a REVERSE_V2 node and sizes its output to the input's shape.
+//
+// Requires exactly two inputs and one output, a 1-D int32 axis tensor holding
+// a single axis, a supported element type, and an output whose type matches
+// the input. Every violation is reported through `context` and yields
+// kTfLiteError.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
+  TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
+
+  if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
+      input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
+      input->type != kTfLiteInt64) {
+    context->ReportError(context, "Type '%s' is not supported by reverse.",
+                         TfLiteTypeGetName(input->type));
+    return kTfLiteError;
+  }
+
+  if (axis->type != kTfLiteInt32) {
+    context->ReportError(context, "Axis Type '%s' is not supported by reverse.",
+                         TfLiteTypeGetName(axis->type));
+    return kTfLiteError;
+  }
+
+  // TODO(renjieliu): support multi-axis case.
+  if (NumElements(axis) > 1) {
+    context->ReportError(context, "Current does not support more than 1 axis.");
+    // Eval only reads axis[0]; fail here instead of silently ignoring the
+    // remaining axes.
+    return kTfLiteError;
+  }
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  // Check the type before copying the dims so that an early return cannot
+  // leak the freshly copied shape array.
+  TF_LITE_ENSURE_EQ(context, output->type, input->type);
+
+  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
+  return context->ResizeTensor(context, output, output_shape);
+}
+
+// Reverses `input` along the axis stored in the first element of the axis
+// tensor and writes the result to the output, dispatching on the output's
+// element type. Returns kTfLiteError when the axis is out of range for the
+// input or the element type is unsupported.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
+  // Prepare only accepts int32 axis tensors.
+  int axis = GetTensorData<int32_t>(axis_tensor)[0];
+
+  TF_LITE_ENSURE(context, axis >= 0 && axis < NumDimensions(input));
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (output->type) {
+    case kTfLiteFloat32: {
+      reference_ops::Reverse<float>(
+          axis, GetTensorShape(input), GetTensorData<float>(input),
+          GetTensorShape(output), GetTensorData<float>(output));
+      break;
+    }
+    case kTfLiteUInt8: {
+      reference_ops::Reverse<uint8_t>(
+          axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
+          GetTensorShape(output), GetTensorData<uint8_t>(output));
+      break;
+    }
+    case kTfLiteInt16: {
+      reference_ops::Reverse<int16_t>(
+          axis, GetTensorShape(input), GetTensorData<int16_t>(input),
+          GetTensorShape(output), GetTensorData<int16_t>(output));
+      break;
+    }
+    case kTfLiteInt32: {
+      reference_ops::Reverse<int32_t>(
+          axis, GetTensorShape(input), GetTensorData<int32_t>(input),
+          GetTensorShape(output), GetTensorData<int32_t>(output));
+      break;
+    }
+    case kTfLiteInt64: {
+      reference_ops::Reverse<int64_t>(
+          axis, GetTensorShape(input), GetTensorData<int64_t>(input),
+          GetTensorShape(output), GetTensorData<int64_t>(output));
+      break;
+    }
+    default: {
+      context->ReportError(context, "Type '%s' is not supported by reverse.",
+                           TfLiteTypeGetName(output->type));
+      return kTfLiteError;
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace
+}  // namespace reverse
+
+// Returns the process-wide registration for REVERSE_V2. The first two
+// registration slots are left null; only Prepare and Eval are supplied.
+TfLiteRegistration* Register_REVERSE_V2() {
+  static TfLiteRegistration r = {nullptr, nullptr, reverse::Prepare,
+                                 reverse::Eval};
+  return &r;
+}
+
+}  //
namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/reverse_sequence.cc b/tensorflow/lite/kernels/reverse_sequence.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2b698bdb16c9ccc5d00bc0c0c93311d8fd3b271d
--- /dev/null
+++ b/tensorflow/lite/kernels/reverse_sequence.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace reverse_sequence {
+namespace {
+
+// Tensor indices within the node's input and output lists.
+constexpr int kInputTensor = 0;
+constexpr int kSeqLengthsTensor = 1;
+constexpr int kOutputTensor = 0;
+
+// Validates a REVERSE_SEQUENCE node and sizes its output to the input's shape.
+//
+// Requires exactly two inputs and one output, a 1-D int32/int64 seq_lengths
+// tensor, a supported element type, and an output whose type matches the
+// input. Every violation is reported through `context` and yields
+// kTfLiteError.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* seq_lengths = GetInput(context, node, kSeqLengthsTensor);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(seq_lengths), 1);
+
+  if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
+      input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
+      input->type != kTfLiteInt64) {
+    context->ReportError(context,
+                         "Type '%s' is not supported by reverse_sequence.",
+                         TfLiteTypeGetName(input->type));
+    return kTfLiteError;
+  }
+
+  if (seq_lengths->type != kTfLiteInt32 && seq_lengths->type != kTfLiteInt64) {
+    context->ReportError(
+        context, "Seq_lengths type '%s' is not supported by reverse_sequence.",
+        TfLiteTypeGetName(seq_lengths->type));
+    return kTfLiteError;
+  }
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  // Check the type before copying the dims so that an early return cannot
+  // leak the freshly copied shape array.
+  TF_LITE_ENSURE_EQ(context, output->type, input->type);
+
+  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
+  return context->ResizeTensor(context, output, output_shape);
+}
+
+// Checks seq_dim, batch_dim and every seq_lengths entry against the input
+// shape, then runs the reference ReverseSequence kernel.
+//
+// T is the element type of the input/output tensors; TS is the element type
+// of the seq_lengths tensor.
+template <typename T, typename TS>
+TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* seq_lengths_tensor =
+      GetInput(context, node, kSeqLengthsTensor);
+  const TS* seq_lengths = GetTensorData<TS>(seq_lengths_tensor);
+
+  auto* params =
+      reinterpret_cast<TfLiteReverseSequenceParams*>(node->builtin_data);
+  int seq_dim = params->seq_dim;
+  int batch_dim = params->batch_dim;
+
+  TF_LITE_ENSURE(context, seq_dim >= 0);
+  TF_LITE_ENSURE(context, batch_dim >= 0);
+  TF_LITE_ENSURE(context, seq_dim != batch_dim);
+  TF_LITE_ENSURE(context, seq_dim < NumDimensions(input));
+  TF_LITE_ENSURE(context, batch_dim < NumDimensions(input));
+  TF_LITE_ENSURE_EQ(context, SizeOfDimension(seq_lengths_tensor, 0),
+                    SizeOfDimension(input, batch_dim));
+  // seq_lengths is 1-D (enforced in Prepare) with one entry per batch
+  // element, so iterate over its element count. Its rank is always 1 and
+  // would only ever validate the first entry.
+  const int num_seq_lengths = SizeOfDimension(seq_lengths_tensor, 0);
+  for (int i = 0; i < num_seq_lengths; ++i) {
+    TF_LITE_ENSURE(context, seq_lengths[i] <= SizeOfDimension(input, seq_dim));
+  }
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  reference_ops::ReverseSequence<T, TS>(
+      seq_lengths, seq_dim, batch_dim, GetTensorShape(input),
+      GetTensorData<T>(input), GetTensorShape(output),
+      GetTensorData<T>(output));
+
+  return kTfLiteOk;
+}
+
+// Dispatches on the seq_lengths element type (int32 or int64) for a fixed
+// data element type T.
+template <typename T>
+TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* seq_lengths_tensor =
+      GetInput(context, node, kSeqLengthsTensor);
+  switch (seq_lengths_tensor->type) {
+    case kTfLiteInt32: {
+      return ReverseSequenceImpl<T, int32_t>(context, node);
+    }
+    case kTfLiteInt64: {
+      return ReverseSequenceImpl<T, int64_t>(context, node);
+    }
+    default: {
+      context->ReportError(
+          context,
+          "Seq_lengths type '%s' is not supported by reverse_sequence.",
+          TfLiteTypeGetName(seq_lengths_tensor->type));
+      return kTfLiteError;
+    }
+  }
+  return kTfLiteOk;
+}
+
+// Dispatches on the output element type; each branch forwards to the helper
+// that in turn selects the seq_lengths type.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (output->type) {
+    case kTfLiteFloat32: {
+      return ReverseSequenceHelper<float>(context, node);
+    }
+    case kTfLiteUInt8: {
+      return ReverseSequenceHelper<uint8_t>(context, node);
+    }
+    case kTfLiteInt16: {
+      return ReverseSequenceHelper<int16_t>(context, node);
+    }
+    case kTfLiteInt32: {
+      return ReverseSequenceHelper<int32_t>(context, node);
+    }
+    case kTfLiteInt64: {
+      return ReverseSequenceHelper<int64_t>(context, node);
+    }
+    default: {
+      context->ReportError(context,
+                           "Type '%s' is not supported by reverse_sequence.",
+                           TfLiteTypeGetName(output->type));
+      return kTfLiteError;
+    }
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace
+}  // namespace reverse_sequence
+
+// Returns the process-wide registration for REVERSE_SEQUENCE. The first two
+// registration slots are left null; only Prepare and Eval are supplied.
+TfLiteRegistration* Register_REVERSE_SEQUENCE() {
+  static TfLiteRegistration r = {nullptr, nullptr, reverse_sequence::Prepare,
+                                 reverse_sequence::Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/reverse_sequence_test.cc b/tensorflow/lite/kernels/reverse_sequence_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e81f1380290de3f926ab8fae98e23c072935ca32
--- /dev/null
+++ b/tensorflow/lite/kernels/reverse_sequence_test.cc
@@ -0,0 +1,211 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +template +class ReverseSequenceOpModel : public SingleOpModel { + public: + ReverseSequenceOpModel(const TensorData& input, const TensorData& seq_lengths, + int seq_dim, int batch_dim) { + input_ = AddInput(input); + seq_lengths_ = AddInput(seq_lengths); + + output_ = AddOutput({input.type, {}}); + + SetBuiltinOp( + BuiltinOperator_REVERSE_SEQUENCE, BuiltinOptions_ReverseSequenceOptions, + CreateReverseSequenceOptions(builder_, seq_dim, batch_dim).Union()); + BuildInterpreter({GetShape(input_)}); + } + + int input() { return input_; } + int seq_lengths() { return seq_lengths_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int seq_lengths_; + int output_; +}; + +// float32 tests +TEST(ReverseSequenceOpTest, FloatSeqDimIsGreater) { + ReverseSequenceOpModel model({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_INT32, {4}}, 1, 0); + model.PopulateTensor(model.input(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + 
model.PopulateTensor(model.seq_lengths(), {3, 2, 3, 3}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 9, 10, 7, 8, 11, 12, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +TEST(ReverseSequenceOpTest, FloatBatchDimIsGreater) { + ReverseSequenceOpModel model({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_INT32, {2}}, 0, 2); + model.PopulateTensor(model.input(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.seq_lengths(), {3, 4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({13, 20, 15, 22, 17, 24, 7, 14, 9, 16, 11, 18, 1, + 8, 3, 10, 5, 12, 19, 2, 21, 4, 23, 6})); +} + +// int32 tests +TEST(ReverseSequenceOpTest, Int32SeqDimIsGreater) { + ReverseSequenceOpModel model({TensorType_INT32, {4, 3, 2}}, + {TensorType_INT32, {4}}, 1, 0); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.seq_lengths(), {3, 2, 3, 3}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 9, 10, 7, 8, 11, 12, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +TEST(ReverseSequenceOpTest, Int32BatchDimIsGreater) { + ReverseSequenceOpModel model({TensorType_INT32, {4, 3, 2}}, + {TensorType_INT32, {2}}, 0, 2); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.seq_lengths(), {3, 4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({13, 20, 15, 22, 17, 24, 7, 14, 9, 16, 11, 18, 1, + 8, 3, 10, 5, 12, 19, 2, 21, 4, 23, 
6})); +} + +// int64 tests +TEST(ReverseSequenceOpTest, Int64SeqDimIsGreater) { + ReverseSequenceOpModel model({TensorType_INT64, {4, 3, 2}}, + {TensorType_INT32, {4}}, 1, 0); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.seq_lengths(), {3, 2, 3, 3}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 9, 10, 7, 8, 11, 12, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +TEST(ReverseSequenceOpTest, Int64BatchDimIsGreater) { + ReverseSequenceOpModel model({TensorType_INT64, {4, 3, 2}}, + {TensorType_INT32, {2}}, 0, 2); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.seq_lengths(), {3, 4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({13, 20, 15, 22, 17, 24, 7, 14, 9, 16, 11, 18, 1, + 8, 3, 10, 5, 12, 19, 2, 21, 4, 23, 6})); +} + +// uint8 tests +TEST(ReverseSequenceOpTest, Uint8SeqDimIsGreater) { + ReverseSequenceOpModel model({TensorType_UINT8, {4, 3, 2}}, + {TensorType_INT32, {4}}, 1, 0); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.seq_lengths(), {3, 2, 3, 3}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 9, 10, 7, 8, 11, 12, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +TEST(ReverseSequenceOpTest, Uint8BatchDimIsGreater) { + ReverseSequenceOpModel model({TensorType_UINT8, {4, 3, 2}}, + {TensorType_INT32, {2}}, 0, 2); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.seq_lengths(), {3, 4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({13, 20, 15, 22, 17, 24, 7, 14, 9, 16, 11, 18, 1, + 8, 3, 10, 5, 12, 19, 2, 21, 4, 23, 6})); +} + +// int16 tests +TEST(ReverseSequenceOpTest, Int16SeqDimIsGreater) { + ReverseSequenceOpModel model({TensorType_INT16, {4, 3, 2}}, + {TensorType_INT32, {4}}, 1, 0); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.seq_lengths(), {3, 2, 3, 3}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 9, 10, 7, 8, 11, 12, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +TEST(ReverseSequenceOpTest, Int16BatchDimIsGreater) { + ReverseSequenceOpModel model({TensorType_INT16, {4, 3, 2}}, + {TensorType_INT32, {2}}, 0, 2); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.seq_lengths(), {3, 4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({13, 20, 15, 22, 17, 24, 7, 14, 9, 16, 11, 18, 1, + 8, 3, 10, 5, 12, 19, 2, 21, 4, 23, 6})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/reverse_test.cc b/tensorflow/lite/kernels/reverse_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9bc0c24b64c197d5c9a60ff74bdd53c5ae0352b9 --- /dev/null +++ b/tensorflow/lite/kernels/reverse_test.cc @@ -0,0 +1,199 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +template +class ReverseOpModel : public SingleOpModel { + public: + ReverseOpModel(const TensorData& input, const TensorData& axis) { + input_ = AddInput(input); + axis_ = AddInput(axis); + + output_ = AddOutput({input.type, {}}); + + SetBuiltinOp(BuiltinOperator_REVERSE_V2, BuiltinOptions_ReverseV2Options, + CreateReverseV2Options(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + int input() { return input_; } + int axis() { return axis_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int axis_; + int output_; +}; + +// float32 tests. 
+TEST(ReverseOpTest, FloatOneDimension) { + ReverseOpModel model({TensorType_FLOAT32, {4}}, + {TensorType_INT32, {1}}); + model.PopulateTensor(model.input(), {1, 2, 3, 4}); + model.PopulateTensor(model.axis(), {0}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1})); +} + +TEST(ReverseOpTest, FloatMultiDimensions) { + ReverseOpModel model({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor(model.input(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.axis(), {1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +// int32 tests +TEST(ReverseOpTest, Int32OneDimension) { + ReverseOpModel model({TensorType_INT32, {4}}, + {TensorType_INT32, {1}}); + model.PopulateTensor(model.input(), {1, 2, 3, 4}); + model.PopulateTensor(model.axis(), {0}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1})); +} + +TEST(ReverseOpTest, Int32MultiDimensions) { + ReverseOpModel model({TensorType_INT32, {4, 3, 2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.axis(), {1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +// int64 tests +TEST(ReverseOpTest, Int64OneDimension) { + ReverseOpModel model({TensorType_INT64, {4}}, + {TensorType_INT32, {1}}); + model.PopulateTensor(model.input(), {1, 2, 3, 4}); 
+ model.PopulateTensor(model.axis(), {0}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1})); +} + +TEST(ReverseOpTest, Int64MultiDimensions) { + ReverseOpModel model({TensorType_INT64, {4, 3, 2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.axis(), {1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +// uint8 tests +TEST(ReverseOpTest, Uint8OneDimension) { + ReverseOpModel model({TensorType_UINT8, {4}}, + {TensorType_INT32, {1}}); + model.PopulateTensor(model.input(), {1, 2, 3, 4}); + model.PopulateTensor(model.axis(), {0}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1})); +} + +TEST(ReverseOpTest, Uint8MultiDimensions) { + ReverseOpModel model({TensorType_UINT8, {4, 3, 2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.axis(), {1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +// int16 tests +TEST(ReverseOpTest, Int16OneDimension) { + ReverseOpModel model({TensorType_INT16, {4}}, + {TensorType_INT32, {1}}); + model.PopulateTensor(model.input(), {1, 2, 3, 4}); + model.PopulateTensor(model.axis(), {0}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 
3, 2, 1})); +} + +TEST(ReverseOpTest, Int16MultiDimensions) { + ReverseOpModel model({TensorType_INT16, {4, 3, 2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor( + model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + model.PopulateTensor(model.axis(), {1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8, + 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/select.cc b/tensorflow/lite/kernels/select.cc index 4687ab44171fab73ff1b4ef93592b25680f3a59f..d1c63d887db00143fb6b154306313411643cf2b8 100644 --- a/tensorflow/lite/kernels/select.cc +++ b/tensorflow/lite/kernels/select.cc @@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: \ TF_LITE_SELECT(uint8_t, op); \ break; \ + case kTfLiteInt8: \ + TF_LITE_SELECT(int8_t, op); \ + break; \ case kTfLiteInt16: \ TF_LITE_SELECT(int16_t, op); \ break; \ diff --git a/tensorflow/lite/kernels/select_test.cc b/tensorflow/lite/kernels/select_test.cc index 5111300e479a92ad9cbf00628750dc61effc50d3..d7cadeb51eb3ee0645eaccc1bbcea59bd279e0d7 100644 --- a/tensorflow/lite/kernels/select_test.cc +++ b/tensorflow/lite/kernels/select_test.cc @@ -96,6 +96,19 @@ TEST(SelectOpTest, SelectUInt8) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } +TEST(SelectOpTest, SelectInt8) { + SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, + TensorType_INT8); + + model.PopulateTensor(model.input1(), {false, true, false, false}); + model.PopulateTensor(model.input2(), {1, -2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, -8}); + model.Invoke(); + + 
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, -2, 7, -8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + TEST(SelectOpTest, SelectInt16) { SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT16); diff --git a/tensorflow/lite/kernels/slice.cc b/tensorflow/lite/kernels/slice.cc index 5fca7a3ea71aa41c6e466b7814921e2e1ac6293d..8472572d7e2a4ad1ab60b144f9d06fbfc7bc7387 100644 --- a/tensorflow/lite/kernels/slice.cc +++ b/tensorflow/lite/kernels/slice.cc @@ -117,7 +117,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { begin->type == kTfLiteInt32 || begin->type == kTfLiteInt64); TF_LITE_ENSURE(context, size->type == kTfLiteInt32 || size->type == kTfLiteInt64); - TF_LITE_ENSURE(context, NumDimensions(begin) == NumDimensions(size) == 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(begin), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); TF_LITE_ENSURE_MSG(context, NumDimensions(input) <= kMaxDim, "Slice op only supports 1D-4D input arrays."); @@ -204,6 +205,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_SLICE(int64_t, kernel_type); break; + case kTfLiteInt8: + TF_LITE_SLICE(int8_t, kernel_type); + break; case kTfLiteUInt8: TF_LITE_SLICE(uint8_t, kernel_type); break; diff --git a/tensorflow/lite/kernels/slice_test.cc b/tensorflow/lite/kernels/slice_test.cc index 563329ddb164d3aa5f13c8ee0d6482d79b84ed32..102218ba23c105014ee6d501d2941f8b4755a44e 100644 --- a/tensorflow/lite/kernels/slice_test.cc +++ b/tensorflow/lite/kernels/slice_test.cc @@ -163,6 +163,28 @@ TEST(SliceOpTest, SizeMinus1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); } +TEST(SliceOpTest, SliceUint8) { + SliceOpModel m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32, + TensorType_UINT8); + m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.SetBegin({1, 0, 0, 0}); + m.SetSize({2, 1, -1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), 
ElementsAreArray({2, 1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); +} + +TEST(SliceOpTest, SliceInt8) { + SliceOpModel m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32, + TensorType_INT8); + m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.SetBegin({1, 0, 0, 0}); + m.SetSize({2, 1, -1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/space_to_batch_nd.cc b/tensorflow/lite/kernels/space_to_batch_nd.cc index 1c61b2ef30379e808085f3b0d16a5b1157bea314..2fb7198cd67e8b9d13873d25a2eaa04fd2ff2ae0 100644 --- a/tensorflow/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/lite/kernels/space_to_batch_nd.cc @@ -141,6 +141,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { op_context.output->params.zero_point); } break; + case kTfLiteInt8: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int8_t, + op_context.output->params.zero_point); + } else { + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int8_t, + op_context.output->params.zero_point); + } + break; case kTfLiteInt32: if (kernel_type == kReference) { TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0); diff --git a/tensorflow/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/lite/kernels/space_to_batch_nd_test.cc index c5d6e9a53062d97801b518f15305e2052f861e7c..52a77984d935d04a79807707729754abd21d3be6 100644 --- a/tensorflow/lite/kernels/space_to_batch_nd_test.cc +++ b/tensorflow/lite/kernels/space_to_batch_nd_test.cc @@ -31,8 +31,9 @@ class SpaceToBatchNDOpModel : public SingleOpModel { PopulateTensor(input_, data); } + template void SetQuantizedInput(std::initializer_list data) { - QuantizeAndPopulate(input_, data); + QuantizeAndPopulate(input_, data); } void SetBlockShape(std::initializer_list data) { @@ -46,9 +47,10 @@ class SpaceToBatchNDOpModel 
: public SingleOpModel { std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } protected: @@ -233,29 +235,62 @@ TEST_F(QuantizedSpaceToBatchNDOpTest, ZeroNotInQuantizationRange) { } #endif -TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTest) { +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTestUint8) { SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0}, {3, 2}, {1, 0, 2, 0}, {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.SetQuantizedInput( + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, -1.0, 1.0))); } -TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTest) { +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTestInt8) { + SpaceToBatchNDOpConstModel m({TensorType_INT8, {1, 5, 2, 1}, -1.0, 1.0}, + {3, 2}, {1, 0, 2, 0}, + {TensorType_INT8, {}, -1.0, 1.0}); + m.SetQuantizedInput( + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, + 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTestUint8) { SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 5, 
2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.SetQuantizedInput( + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 0, 2, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, + 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTestInt8) { + SpaceToBatchNDOpDynamicModel m({TensorType_INT8, {1, 5, 2, 1}, -1.0, 1.0}, + {TensorType_INT8, {}, -1.0, 1.0}); + m.SetQuantizedInput( + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); m.SetBlockShape({3, 2}); m.SetPaddings({1, 0, 2, 0}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, @@ -266,10 +301,10 @@ TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingConstTest) { SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0}, {3, 2}, {1, 1, 2, 4}, {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( { 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0, @@ -282,12 +317,12 @@ TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingConstTest) { TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 4, 
2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}); - m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); m.SetBlockShape({3, 2}); m.SetPaddings({1, 1, 2, 4}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( { 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0, diff --git a/tensorflow/lite/kernels/space_to_depth.cc b/tensorflow/lite/kernels/space_to_depth.cc index 79e28bf47d98b64572d9e7404f8d69788cd30e08..cf6b0bd4d3d4b61b87a14d1090a7e89d9b77a0f2 100644 --- a/tensorflow/lite/kernels/space_to_depth.cc +++ b/tensorflow/lite/kernels/space_to_depth.cc @@ -50,7 +50,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto data_type = output->type; TF_LITE_ENSURE(context, data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 || - data_type == kTfLiteInt32 || data_type == kTfLiteInt64); + data_type == kTfLiteInt8 || data_type == kTfLiteInt32 || + data_type == kTfLiteInt64); TF_LITE_ENSURE_EQ(context, input->type, output->type); const int block_size = params->block_size; @@ -100,6 +101,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t); } break; + case kTfLiteInt8: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int8_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int8_t); + } + break; case kTfLiteInt32: if (kernel_type == kReference) { TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t); diff --git a/tensorflow/lite/kernels/space_to_depth_test.cc b/tensorflow/lite/kernels/space_to_depth_test.cc index 3fa8d86348ef899b9bd42c19f5b1510b4c4e33d3..58665fc9d83007d7bed638418cba058e4ff189c5 100644 --- a/tensorflow/lite/kernels/space_to_depth_test.cc +++ b/tensorflow/lite/kernels/space_to_depth_test.cc @@ -74,6 +74,14 @@ 
TEST(SpaceToDepthOpModel, Uint8) { EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } +TEST(SpaceToDepthOpModel, int8) { + SpaceToDepthOpModel m({TensorType_INT8, {1, 2, 2, 1}}, 2); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + TEST(SpaceToDepthOpModel, Int32) { SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); diff --git a/tensorflow/lite/kernels/sparse_output_fully_connected.cc b/tensorflow/lite/kernels/sparse_output_fully_connected.cc deleted file mode 100644 index 248969535c66a31be9e15ea366cf461e93206161..0000000000000000000000000000000000000000 --- a/tensorflow/lite/kernels/sparse_output_fully_connected.cc +++ /dev/null @@ -1,252 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// SparseOutputFullyConnected is a fully connected layer that uses a single -// row in the weights and bias via a lookup. 
-#include "tensorflow/lite/context.h" -#include "tensorflow/lite/kernels/internal/tensor.h" -#include "tensorflow/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/lite/kernels/kernel_util.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace sparse_output_fully_connected { - -// Input tensors of size {n_batch, n_input} -constexpr int kInputTensor = 0; -// Auxiliary input tensor of size { 1 } -constexpr int kInputLookupTensor = 1; - -// Weights tensor of size { n_embeddings , n_input } -constexpr int kWeightsTensor = 2; -// Bias tensor of size { n_embeddings } -constexpr int kBiasTensor = 3; - -// Output tensor. -constexpr int kOutputTensor = 0; - -// Temporary tensors. -enum TemporaryTensor { - kInputQuantized = 0, - kScalingFactors = 1, - kNumTemporaryTensors = 2 -}; - -// Struct to hold op data. -struct OpData { - int scratch_tensor_index; -}; - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* data = new OpData; - context->AddTensors(context, /*tensors_to_add=*/kNumTemporaryTensors, - &data->scratch_tensor_index); - return data; -} - -void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - OpData* op_data = reinterpret_cast(node->user_data); - - TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); - TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2); - const int n_batch = SizeOfDimension(input, 0); - const int n_input = SizeOfDimension(input, 1); - - const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor); - TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); - // Only support single lookup. 
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(lookup, 0), 1); - - const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); - TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 2); - TF_LITE_ENSURE_EQ(context, SizeOfDimension(weights, 1), n_input); - - const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); - TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(weights, 0)); - - const bool is_hybrid_op = - ((weights->type == kTfLiteUInt8 || weights->type == kTfLiteInt8) && - input->type == kTfLiteFloat32); - - // Resize output. - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(1); - output_size_array->data[0] = 1; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, output, output_size_array)); - - if (is_hybrid_op) { - TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); - - // Allocate temporary tensors to store quantized values of input. - node->temporaries->data[kInputQuantized] = op_data->scratch_tensor_index; - TfLiteTensor* input_quantized = - GetTemporary(context, node, /*index=*/kInputQuantized); - input_quantized->type = weights->type; - input_quantized->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { - TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, - input_quantized_size)); - } - - // Tell interpreter to allocate temporary tensors to store scaling factors. 
- node->temporaries->data[kScalingFactors] = - op_data->scratch_tensor_index + kScalingFactors; - TfLiteTensor* scaling_factors = - GetTemporary(context, node, /*index=*/kScalingFactors); - scaling_factors->type = kTfLiteFloat32; - scaling_factors->allocation_type = kTfLiteArenaRw; - int scaling_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { - TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); - scaling_factors_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, - scaling_factors_size)); - } - } - return kTfLiteOk; -} - -TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* lookup, - const TfLiteTensor* weights, const TfLiteTensor* bias, - TfLiteTensor* output) { - const int n_batch = SizeOfDimension(input, 0); - const int n_input = SizeOfDimension(input, 1); - - const float* input_ptr_batch = input->data.f; - - // Initialize pointer to right row according to lookup value. - int32 lookup_index = lookup->data.i32[0]; - const float* weights_ptr = weights->data.f + lookup_index * n_input; - - // Initialize output to bias. 
- if (bias) { - float* bias_ptr = bias->data.f + lookup_index; - tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f); - } else { - tensor_utils::ZeroVector(output->data.f, n_batch * 1); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - weights_ptr, /*m_rows=*/1, n_input, input_ptr_batch, n_batch, - output->data.f, /*result_stride=*/1); - - return kTfLiteOk; -} - -TfLiteStatus EvalHybrid(const TfLiteTensor* input, const TfLiteTensor* lookup, - const TfLiteTensor* weights, const TfLiteTensor* bias, - TfLiteTensor* scaling_factors, - TfLiteTensor* input_quantized, TfLiteTensor* output) { - const int n_batch = SizeOfDimension(input, 0); - const int n_input = SizeOfDimension(input, 1); - - const float* input_ptr_batch = input->data.f; - - // Initialize the pointer to storage for scaling factors. - float* scaling_factors_ptr = scaling_factors->data.f; - int32 lookup_index = lookup->data.i32[0]; - - // Initialize the pointer to storage for quantized values and a pointer to - // the row according to lookup value. - int8_t *quantized_input_ptr_batch, *weights_ptr; - if (weights->type == kTfLiteUInt8) { - quantized_input_ptr_batch = - reinterpret_cast(input_quantized->data.uint8); - weights_ptr = - reinterpret_cast(weights->data.uint8) + lookup_index * n_input; - } else { - quantized_input_ptr_batch = input_quantized->data.int8; - weights_ptr = weights->data.int8 + lookup_index * n_input; - } - - // Initialize output to bias. - if (bias) { - float* bias_ptr = bias->data.f + lookup_index; - tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f); - } else { - tensor_utils::ZeroVector(output->data.f, n_batch * 1); - } - - if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { - // Quantize input from float to int8. 
- float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_input; - tensor_utils::SymmetricQuantizeFloats( - input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, - &unused_min, &unused_max, &scaling_factors_ptr[b]); - scaling_factors_ptr[b] *= weights->params.scale; - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - weights_ptr, /*m_rows=*/1, n_input, quantized_input_ptr_batch, - scaling_factors_ptr, n_batch, output->data.f, /*result_stride=*/1); - } - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor); - const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); - const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - switch (weights->type) { - case kTfLiteFloat32: { - return EvalFloat(input, lookup, weights, bias, output); - } - case kTfLiteUInt8: - case kTfLiteInt8: { - TfLiteTensor* input_quantized = - GetTemporary(context, node, /*index=*/kInputQuantized); - TfLiteTensor* scaling_factors = - GetTemporary(context, node, /*index=*/kScalingFactors); - return EvalHybrid(input, lookup, weights, bias, scaling_factors, - input_quantized, output); - } - default: - context->ReportError(context, "Type %d is not currently supported.", - weights->type); - return kTfLiteError; - } - return kTfLiteOk; -} - -} // namespace sparse_output_fully_connected - -TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED() { - static TfLiteRegistration r = {sparse_output_fully_connected::Init, - sparse_output_fully_connected::Free, - sparse_output_fully_connected::Prepare, - sparse_output_fully_connected::Eval}; - return &r; -} - -} // namespace custom -} // namespace ops -} // namespace tflite diff --git 
a/tensorflow/lite/kernels/sparse_output_fully_connected_test.cc b/tensorflow/lite/kernels/sparse_output_fully_connected_test.cc deleted file mode 100644 index 7d5fec192ce6b103c41f47ed60eb1283f72da45f..0000000000000000000000000000000000000000 --- a/tensorflow/lite/kernels/sparse_output_fully_connected_test.cc +++ /dev/null @@ -1,192 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Unit test for TFLite sparse output fully connected op. 
-#include -#include -#include - -#include -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/lite/kernels/internal/types.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/kernels/test_util.h" - -namespace tflite { - -namespace ops { -namespace custom { - -TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED(); - -namespace { - -using ::testing::ElementsAreArray; - -class BaseSparseOutputFullyConnectedOpModel : public SingleOpModel { - public: - BaseSparseOutputFullyConnectedOpModel(const TensorData& input, - const TensorData& weights, - const TensorData& output = { - TensorType_FLOAT32}) { - input_ = AddInput(input); - lookup_ = AddInput({TensorType_INT32, {1}}); - weights_ = AddInput(weights); - int bias_size = GetShape(weights_)[0]; - bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); - output_ = AddOutput(output); - - // Create empty (required) options map. - flexbuffers::Builder fbb; - fbb.Map([&]() {}); - fbb.Finish(); - - SetCustomOp("SPARSE_OUTPUT_FULLY_CONNECTED", fbb.GetBuffer(), - Register_SPARSE_OUTPUT_FULLY_CONNECTED); - BuildInterpreter({GetShape(input_), GetShape(lookup_), GetShape(weights_), - GetShape(bias_)}); - } - - void SetInput(const std::vector& data) { - PopulateTensor(input_, data); - } - - void SetLookup(const std::vector& f) { PopulateTensor(lookup_, f); } - - void SetBias(const std::vector& f) { PopulateTensor(bias_, f); } - - std::vector GetOutput() { return ExtractVector(output_); } - - protected: - int input_; - int lookup_; - int weights_; - int bias_; - int output_; -}; - -class FloatSparseOutputFullyConnectedOpModel - : public BaseSparseOutputFullyConnectedOpModel { - public: - using BaseSparseOutputFullyConnectedOpModel:: - BaseSparseOutputFullyConnectedOpModel; - - void SetWeights(const std::vector& f) { PopulateTensor(weights_, f); } -}; - -class HybridSparseOutputFullyConnectedOpModel - : public BaseSparseOutputFullyConnectedOpModel { - public: - using 
BaseSparseOutputFullyConnectedOpModel:: - BaseSparseOutputFullyConnectedOpModel; - - void SetWeights(const std::vector& f) { - SymmetricQuantizeAndPopulate(weights_, f); - } - - void SetSignedWeights(const std::vector& f) { - SignedSymmetricQuantizeAndPopulate(weights_, f); - } -}; - -TEST(SparseOutputFullyConnectedOpTest, SimpleTestFloat) { - FloatSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}}, - {TensorType_FLOAT32, {3, 5}}, - {TensorType_FLOAT32, {}}); - - m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0}); - - m.SetLookup({2}); - - m.SetWeights({ - -1.0, 0.0, 1.0, 2.0, 3.0, // - 0.0, 1.0, 2.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 4.0, 5.0, // - }); - - m.SetBias({1.0, 2.0, 3.0}); - - m.Invoke(); - - EXPECT_THAT(m.GetOutput(), ElementsAreArray({28})); -} - -TEST(SparseOutputFullyConnectedOpTest, SimpleTestHybridUint8) { - HybridSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}}, - {TensorType_UINT8, {3, 5}}, - {TensorType_FLOAT32, {}}); - - m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0}); - - m.SetLookup({2}); - - m.SetWeights({ - -1.0, 0.0, 1.0, 2.0, 3.0, // - 0.0, 1.0, 2.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 4.0, 5.0, // - }); - - m.SetBias({1.0, 2.0, 3.0}); - - m.Invoke(); - - // We get 28.0552 instead of 28. - // - // Input -> -42, 0, 42, 85, 127 with scale factor of 127/3. - // Looked up weights -> 25, 51, 76, 102, 127 with scale factor of 127/5. - // - // (-42 * 25 + 0 * 51 + 42 * 76 + 85 * 102 + 127 * 127) * (3*5/127^2) + 3.0 - // gives us the expected result. 
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({28}, 0.0553))); -} - -TEST(SparseOutputFullyConnectedOpTest, SimpleTestHybridInt8) { - HybridSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}}, - {TensorType_INT8, {3, 5}}, - {TensorType_FLOAT32, {}}); - - m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0}); - - m.SetLookup({2}); - - m.SetSignedWeights({ - -1.0, 0.0, 1.0, 2.0, 3.0, // - 0.0, 1.0, 2.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 4.0, 5.0, // - }); - - m.SetBias({1.0, 2.0, 3.0}); - - m.Invoke(); - - // We get 28.0552 instead of 28. - // - // Input -> -42, 0, 42, 85, 127 with scale factor of 127/3. - // Looked up weights -> 25, 51, 76, 102, 127 with scale factor of 127/5. - // - // (-42 * 25 + 0 * 51 + 42 * 76 + 85 * 102 + 127 * 127) * (3*5/127^2) + 3.0 - // gives us the expected result. - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({28}, 0.0553))); -} - -} // namespace -} // namespace custom -} // namespace ops -} // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/lite/kernels/split.cc b/tensorflow/lite/kernels/split.cc index 7902ed2a46d297cca6f076bf1bb48580f3c4bf40..c0f701f55dd096279e1e9f1e54817490cb0c230b 100644 --- a/tensorflow/lite/kernels/split.cc +++ b/tensorflow/lite/kernels/split.cc @@ -76,9 +76,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits); auto input_type = op_context.input->type; - TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || - input_type == kTfLiteUInt8 || - input_type == kTfLiteInt16); + TF_LITE_ENSURE(context, + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || + input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || + input_type == kTfLiteInt32); for (int i = 0; i < NumOutputs(node); ++i) { GetOutput(context, node, i)->type = input_type; } @@ -137,15 +138,23 @@ 
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPLIT(uint8_t); break; } + case kTfLiteInt8: { + TF_LITE_SPLIT(int8_t); + break; + } case kTfLiteInt16: { TF_LITE_SPLIT(int16_t); break; } + case kTfLiteInt32: { + TF_LITE_SPLIT(int32_t); + break; + } default: - context->ReportError( - context, - "Only float32, uint8 and int16 are currently supported, got %d.", - op_context.input->type); + context->ReportError(context, + "Only float32, uint8, int8, int16 and int32 are " + "currently supported, got %d.", + op_context.input->type); return kTfLiteError; } #undef TF_LITE_SPLIT diff --git a/tensorflow/lite/kernels/split_test.cc b/tensorflow/lite/kernels/split_test.cc index f3d9ea3bf4158dd51b5102b942125b7561024c19..fa313d4b18f803dc5060425d8162af25129dd5d9 100644 --- a/tensorflow/lite/kernels/split_test.cc +++ b/tensorflow/lite/kernels/split_test.cc @@ -47,13 +47,15 @@ class SplitOpModel : public SingleOpModel { } } - void SetInput(std::initializer_list data) { + template + void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } - std::vector GetOutput(int i) { - return ExtractVector(outputs_[i]); + template + std::vector GetOutput(int i) { + return ExtractVector(outputs_[i]); } std::vector GetOutputShape(int i) { return GetTensorShape(outputs_[i]); } @@ -63,33 +65,34 @@ class SplitOpModel : public SingleOpModel { std::vector outputs_; }; -using TensorValues = std::initializer_list; - +template void Check(int axis, int num_splits, std::initializer_list input_shape, std::initializer_list output_shape, - const TensorValues& input_data, - const std::vector& output_data) { + const std::initializer_list& input_data, + const std::vector>& output_data, + const TensorType& type = TensorType_FLOAT32) { auto debug = [&](int i) { std::stringstream ss; ss << "for output tensor " << i << " axis=" << axis << " and num_splits=" << num_splits; return ss.str(); }; - SplitOpModel 
m({TensorType_FLOAT32, input_shape}, num_splits); + SplitOpModel m({type, input_shape}, num_splits); m.SetInput(input_data); m.SetAxis(axis); m.Invoke(); for (int i = 0; i < num_splits; ++i) { - EXPECT_THAT(m.GetOutput(i), ElementsAreArray(output_data[i])) << debug(i); + EXPECT_THAT(m.GetOutput(i), ElementsAreArray(output_data[i])) + << debug(i); EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shape)) << debug(i); } - SplitOpModel const_m({TensorType_FLOAT32, input_shape}, num_splits, axis); + SplitOpModel const_m({type, input_shape}, num_splits, axis); const_m.SetInput(input_data); const_m.Invoke(); for (int i = 0; i < num_splits; ++i) { - EXPECT_THAT(const_m.GetOutput(i), ElementsAreArray(output_data[i])) + EXPECT_THAT(const_m.GetOutput(i), ElementsAreArray(output_data[i])) << debug(i); EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shape)) << debug(i); @@ -97,44 +100,106 @@ void Check(int axis, int num_splits, std::initializer_list input_shape, } TEST(SplitOpTest, FourDimensional) { - Check(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - { - {1, 2, 3, 4, 5, 6, 7, 8}, - {9, 10, 11, 12, 13, 14, 15, 16}, - }); - Check(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - { - {1, 2, 3, 4, 9, 10, 11, 12}, - {5, 6, 7, 8, 13, 14, 15, 16}, - }); - Check(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - { - {1, 2, 5, 6, 9, 10, 13, 14}, - {3, 4, 7, 8, 11, 12, 15, 16}, - }); - Check(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - { - {1, 3, 5, 7, 9, 11, 13, 15}, - {2, 4, 6, 8, 10, 12, 14, 16}, - }); + Check(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 
14, 15, 16}, + }); + Check(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 9, 10, 11, 12}, + {5, 6, 7, 8, 13, 14, 15, 16}, + }); + Check(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 5, 6, 9, 10, 13, 14}, + {3, 4, 7, 8, 11, 12, 15, 16}, + }); + Check(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 3, 5, 7, 9, 11, 13, 15}, + {2, 4, 6, 8, 10, 12, 14, 16}, + }); +} + +TEST(SplitOpTest, FourDimensionalInt8) { + Check(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }, + TensorType_INT8); + Check(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 9, 10, 11, 12}, + {5, 6, 7, 8, 13, 14, 15, 16}, + }, + TensorType_INT8); + Check(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 5, 6, 9, 10, 13, 14}, + {3, 4, 7, 8, 11, 12, 15, 16}, + }, + TensorType_INT8); + Check(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 3, 5, 7, 9, 11, 13, 15}, + {2, 4, 6, 8, 10, 12, 14, 16}, + }, + TensorType_INT8); +} + +TEST(SplitOpTest, FourDimensionalInt32) { + Check(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }, + TensorType_INT32); + Check(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 9, 10, 11, 12}, + {5, 6, 7, 8, 13, 14, 15, 16}, + }, + 
TensorType_INT32); + Check(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 5, 6, 9, 10, 13, 14}, + {3, 4, 7, 8, 11, 12, 15, 16}, + }, + TensorType_INT32); + Check(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 3, 5, 7, 9, 11, 13, 15}, + {2, 4, 6, 8, 10, 12, 14, 16}, + }, + TensorType_INT32); } TEST(SplitOpTest, OneDimensional) { - Check(/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8}, - {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}); + Check(/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8}, + {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}); } TEST(SplitOpTest, NegativeAxis) { - Check(/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - { - {1, 2, 3, 4, 5, 6, 7, 8}, - {9, 10, 11, 12, 13, 14, 15, 16}, - }); + Check(/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }); } } // namespace diff --git a/tensorflow/lite/kernels/split_v.cc b/tensorflow/lite/kernels/split_v.cc index 060e3c5f79c808cd3c8d4b21efd7f2595a68b8e8..c95396c621b988930208ea62cca03ea9aa67a1f9 100644 --- a/tensorflow/lite/kernels/split_v.cc +++ b/tensorflow/lite/kernels/split_v.cc @@ -183,10 +183,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; } default: - context->ReportError( - context, - "Only float32, uint8 and int16 are currently supported, got %d.", - op_context.input->type); + context->ReportError(context, "Type %s currently not supported.", + TfLiteTypeGetName(op_context.input->type)); return kTfLiteError; } #undef TF_LITE_SPLIT_V diff --git a/tensorflow/lite/kernels/split_v_test.cc b/tensorflow/lite/kernels/split_v_test.cc index 
2d1d36d6851c12d1b05374cda5ef32255e162875..27fed63f0eea452104a26dd3d0527ae98ce81dc7 100644 --- a/tensorflow/lite/kernels/split_v_test.cc +++ b/tensorflow/lite/kernels/split_v_test.cc @@ -50,16 +50,18 @@ class SplitVOpModel : public SingleOpModel { } } - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); } void SetSizeSplits(std::initializer_list data) { PopulateTensor(size_splits_, data); } void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } - std::vector GetOutput(int i) { - return ExtractVector(outputs_[i]); + template + std::vector GetOutput(int i) { + return ExtractVector(outputs_[i]); } std::vector GetOutputShape(int i) { return GetTensorShape(outputs_[i]); } @@ -70,99 +72,132 @@ class SplitVOpModel : public SingleOpModel { std::vector outputs_; }; -// TODO(ruic): Add tests to test quantized values. b/119638735 -using TensorValues = std::initializer_list; - +template void Check(int axis, std::initializer_list input_shape, std::initializer_list size_splits_shape, std::vector> output_shapes, - const TensorValues& input_data, + const std::initializer_list& input_data, const std::initializer_list& size_splits_data, - const std::vector& output_data) { + const std::vector>& output_data) { int num_splits = size_splits_data.size(); - SplitVOpModel m({TensorType_FLOAT32, input_shape}, - {TensorType_INT32, size_splits_shape}, num_splits, - kAxisIsATensor); - m.SetInput(input_data); + SplitVOpModel m({T1, input_shape}, {TensorType_INT32, size_splits_shape}, + num_splits, kAxisIsATensor); + m.SetInput(input_data); m.SetSizeSplits(size_splits_data); m.SetAxis(axis); m.Invoke(); for (int i = 0; i < num_splits; ++i) { - EXPECT_THAT(m.GetOutput(i), ElementsAreArray(output_data[i])); + EXPECT_THAT(m.GetOutput(i), ElementsAreArray(output_data[i])); EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shapes[i])); } - SplitVOpModel 
const_m({TensorType_FLOAT32, input_shape}, + SplitVOpModel const_m({T1, input_shape}, {TensorType_INT32, size_splits_shape}, num_splits, axis); - const_m.SetInput(input_data); + const_m.SetInput(input_data); const_m.SetSizeSplits(size_splits_data); const_m.Invoke(); for (int i = 0; i < num_splits; ++i) { - EXPECT_THAT(const_m.GetOutput(i), ElementsAreArray(output_data[i])); + EXPECT_THAT(const_m.GetOutput(i), ElementsAreArray(output_data[i])); EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shapes[i])); } } TEST(SplitVOpTest, TwoDimensional) { // Input shape: {4, 3} - // size_splits: {1, 1, 3} + // size_splits: {1, 1, 2} // axis: 0 // We should have 3 outpus with shapes respectively: - // output 0 : {1, 3} // output 1 : {1, 3} - // output 1 : {2, 3} - Check(/*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, - {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); + // output 2 : {1, 3} + // output 3 : {2, 3} + Check( + /*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); } TEST(SplitVOpTest, FourDimensional) { - Check(/*axis=*/0, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, - { - {1, 2, 3, 4, 5, 6, 7, 8}, - {9, 10, 11, 12, 13, 14, 15, 16}, - }); - Check(/*axis=*/1, {2, 2, 2, 2}, {2}, {{2, 1, 2, 2}, {2, 1, 2, 2}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, -1}, - { - {1, 2, 3, 4, 9, 10, 11, 12}, - {5, 6, 7, 8, 13, 14, 15, 16}, - }); - Check(/*axis=*/2, {2, 2, 2, 2}, {2}, {{2, 2, 1, 2}, {2, 2, 1, 2}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, - { - {1, 2, 5, 6, 9, 10, 13, 14}, - {3, 4, 7, 8, 11, 12, 15, 16}, - }); - Check(/*axis=*/3, {2, 2, 2, 2}, {2}, {{2, 2, 2, 1}, {2, 2, 2, 1}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, - { - {1, 3, 5, 7, 9, 11, 13, 15}, - {2, 4, 
6, 8, 10, 12, 14, 16}, - }); + Check( + /*axis=*/0, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }); + Check( + /*axis=*/1, {2, 2, 2, 2}, {2}, {{2, 1, 2, 2}, {2, 1, 2, 2}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, -1}, + { + {1, 2, 3, 4, 9, 10, 11, 12}, + {5, 6, 7, 8, 13, 14, 15, 16}, + }); + Check( + /*axis=*/2, {2, 2, 2, 2}, {2}, {{2, 2, 1, 2}, {2, 2, 1, 2}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, + { + {1, 2, 5, 6, 9, 10, 13, 14}, + {3, 4, 7, 8, 11, 12, 15, 16}, + }); + Check( + /*axis=*/3, {2, 2, 2, 2}, {2}, {{2, 2, 2, 1}, {2, 2, 2, 1}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, + { + {1, 3, 5, 7, 9, 11, 13, 15}, + {2, 4, 6, 8, 10, 12, 14, 16}, + }); } TEST(SplitVOpTest, OneDimensional) { - Check(/*axis=*/0, {8}, {8}, {{1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}}, - {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 1, 1, 1, 1}, - {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}); + Check( + /*axis=*/0, {8}, {8}, {{1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}}, + {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 1, 1, 1, 1}, + {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}); } TEST(SplitVOpTest, OneDimensional2) { - Check(/*axis=*/0, {8}, {8}, {{1}, {1}, {1}, {1}, {1}, {1}, {2}, {0}}, - {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 1, 1, 2, -1}, - {{1}, {2}, {3}, {4}, {5}, {6}, {7, 8}, {}}); + Check( + /*axis=*/0, {8}, {8}, {{1}, {1}, {1}, {1}, {1}, {1}, {2}, {0}}, + {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 1, 1, 2, -1}, + {{1}, {2}, {3}, {4}, {5}, {6}, {7, 8}, {}}); } TEST(SplitVOpTest, NegativeAxis) { - Check(/*axis=*/-4, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, - { - {1, 2, 3, 4, 5, 6, 7, 8}, - {9, 10, 11, 12, 13, 14, 15, 16}, - }); + Check( + /*axis=*/-4, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}}, + {1, 2, 3, 4, 5, 6, 7, 8, 
9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }); +} + +TEST(SplitVOpTest, TwoDimensionalUint8) { + // Input shape: {4, 3} + // size_splits: {1, 1, 2} + // axis: 0 + // We should have 3 outpus with shapes respectively: + // output 1 : {1, 3} + // output 2 : {1, 3} + // output 3 : {2, 3} + Check( + /*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); +} + +TEST(SplitVOpTest, TwoDimensionalInt16) { + // Input shape: {4, 3} + // size_splits: {1, 1, 2} + // axis: 0 + // We should have 3 outpus with shapes respectively: + // output 1 : {1, 3} + // output 2 : {1, 3} + // output 3 : {2, 3} + Check( + /*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); } } // namespace diff --git a/tensorflow/lite/kernels/squared_difference.cc b/tensorflow/lite/kernels/squared_difference.cc index 59b53a6287dbbc863a61875be82090c1b9c6d442..3661cf9f98c5d0133090ae926f8d76e54f428eba 100644 --- a/tensorflow/lite/kernels/squared_difference.cc +++ b/tensorflow/lite/kernels/squared_difference.cc @@ -105,10 +105,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } else if (output->type == kTfLiteInt32) { EvalSquaredDifference(context, node, data, input1, input2, output); } else { - context->ReportError(context, - "SquaredDifference only supports FLOAT32, INT32 and " - "quantized UINT8 now, got %d.", - output->type); + context->ReportError( + context, + "SquaredDifference only supports FLOAT32 and INT32 now, got %d.", + output->type); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc index c797a98e9f1bda8595e6822638949bab48cb2eab..8c25ffa3a1a669684d9fb1b552893de3a450264f 100644 --- a/tensorflow/lite/kernels/strided_slice.cc +++ 
b/tensorflow/lite/kernels/strided_slice.cc @@ -234,6 +234,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_STRIDED_SLICE(reference_ops, uint8_t); } break; + case kTfLiteInt8: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int8_t); + } + break; default: context->ReportError(context, "Type %d is currently not supported " diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index 34875bf0497a000da02f3d0212b042399046a492..cac9e1672f871268d6d37b3488d00a0c1399aaa7 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -577,6 +577,18 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, + 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({1, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/sub.cc b/tensorflow/lite/kernels/sub.cc index 06a3b3499a005f19bfd1461dfe861835f8331b96..8bd6052307cc0e032a566e437923cac2f16be69e 100644 --- a/tensorflow/lite/kernels/sub.cc +++ b/tensorflow/lite/kernels/sub.cc @@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -39,6 +41,23 @@ constexpr int kOutputTensor = 0; struct OpData { bool requires_broadcast; + + // These fields are used in both the general 8-bit -> 8bit quantized path, + // and the special 16-bit -> 16bit quantized path + int input1_shift; + int input2_shift; + int32 output_activation_min; + int32 output_activation_max; + + // These fields are used only in the general 8-bit -> 8bit quantized path + int32 input1_multiplier; + int32 input2_multiplier; + int32 output_multiplier; + int output_shift; + int left_shift; + int32 input1_offset; + int32 input2_offset; + int32 output_offset; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -51,8 +70,126 @@ void Free(TfLiteContext* context, void* buffer) { delete reinterpret_cast(buffer); } +TfLiteStatus Prepare8BitSubOp(TfLiteContext* context, + const TfLiteTensor* input_1, + const TfLiteTensor* input_2, TfLiteTensor* output, + TfLiteSubParams* params, OpData* op_params, + int op_sign) { + TF_LITE_ENSURE(context, + output->type == kTfLiteUInt8 || output->type == kTfLiteInt8); + const auto& input1_quantization_params = input_1->params; + const auto& input2_quantization_params = input_2->params; + const auto& output_quantization_params = output->params; + int32_t integer_type_min = 0; + int32_t integer_type_max = 0; + if (output->type == kTfLiteUInt8) { + integer_type_min = std::numeric_limits::min(); + integer_type_max = 
std::numeric_limits::max(); + } else { + // output->type == kTfLiteInt8 + integer_type_min = std::numeric_limits::min(); + integer_type_max = std::numeric_limits::max(); + } + + TF_LITE_ENSURE(context, + input1_quantization_params.zero_point >= integer_type_min); + TF_LITE_ENSURE(context, + input1_quantization_params.zero_point <= integer_type_max); + TF_LITE_ENSURE(context, + input2_quantization_params.zero_point >= integer_type_min); + TF_LITE_ENSURE(context, + input2_quantization_params.zero_point <= integer_type_max); + TF_LITE_ENSURE(context, + output_quantization_params.zero_point >= integer_type_min); + TF_LITE_ENSURE(context, + output_quantization_params.zero_point <= integer_type_max); + + op_params->input1_offset = -input1_quantization_params.zero_point; + op_params->input2_offset = -input2_quantization_params.zero_point; + op_params->output_offset = output_quantization_params.zero_point; + op_params->left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1_quantization_params.scale, + input2_quantization_params.scale); + const double real_input1_multiplier = + input1_quantization_params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2_quantization_params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / + ((1 << op_params->left_shift) * output_quantization_params.scale); + + tflite::QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, + &op_params->input1_multiplier, + &op_params->input1_shift); + tflite::QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, + &op_params->input2_multiplier, + &op_params->input2_shift); + op_params->input2_multiplier *= op_sign; + tflite::QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, + &op_params->output_multiplier, + &op_params->output_shift); + if (output->type == kTfLiteUInt8) { + CalculateActivationRangeUint8(params->activation, output, + &op_params->output_activation_min, + 
&op_params->output_activation_max); + } else { + CalculateActivationRangeInt8(params->activation, output, + &op_params->output_activation_min, + &op_params->output_activation_max); + } + return kTfLiteOk; +} + +TfLiteStatus PrepareInt16SubOp(TfLiteContext* context, + const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output, + TfLiteSubParams* params, OpData* data) { + // 16bit -> 16bit special quantized path, supporting only a rather + // narrow case of quantization parameters: zero_points must all be 0 + // ("symmetric quantization") and scales must be power-of-two (which + // we abbreviate as "POT" below). The intended use case for this path + // is in LSTM cells, where, due to the constraints of implementing + // some of the math in these LSTM cells in fixed-point arithmetic, + // we need to have such symmetric, power-of-two quantization + // (Fixed-point formats are inherently symmetric, power-of-two). + TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input1_scale_log2_rounded; + bool input1_scale_is_pot = + CheckedLog2(input1->params.scale, &input1_scale_log2_rounded); + TF_LITE_ENSURE(context, input1_scale_is_pot); + + int input2_scale_log2_rounded; + bool input2_scale_is_pot = + CheckedLog2(input2->params.scale, &input2_scale_log2_rounded); + TF_LITE_ENSURE(context, input2_scale_is_pot); + + int output_scale_log2_rounded; + bool output_scale_is_pot = + CheckedLog2(output->params.scale, &output_scale_log2_rounded); + TF_LITE_ENSURE(context, output_scale_is_pot); + + data->input1_shift = input1_scale_log2_rounded - output_scale_log2_rounded; + data->input2_shift = input2_scale_log2_rounded - output_scale_log2_rounded; + + // Shifting of one input is supported. The graph quantization should ensure + // that the other input matches the output. 
+ TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0); + TF_LITE_ENSURE(context, data->input1_shift <= 0); + TF_LITE_ENSURE(context, data->input2_shift <= 0); + + CalculateActivationRangeQuantized(context, params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + return kTfLiteOk; +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast(node->user_data); + auto* params = reinterpret_cast(node->builtin_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -74,6 +211,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { + TF_LITE_ENSURE_OK(context, Prepare8BitSubOp(context, input1, input2, output, + params, data, -1)); + } else if (output->type == kTfLiteInt16) { + TF_LITE_ENSURE_OK(context, PrepareInt16SubOp(context, input1, input2, + output, params, data)); + } + return context->ResizeTensor(context, output, output_size); } @@ -129,60 +274,67 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, const OpData* data, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - const int left_shift = 20; - const double twice_max_input_scale = - 2 * std::max(input1->params.scale, input2->params.scale); - const double real_input1_multiplier = - input1->params.scale / twice_max_input_scale; - const double real_input2_multiplier = - input2->params.scale / twice_max_input_scale; - const double real_output_multiplier = - twice_max_input_scale / ((1 << left_shift) * output->params.scale); + tflite::ArithmeticParams op_params; + op_params.left_shift = data->left_shift; + 
op_params.input1_offset = data->input1_offset; + op_params.input1_multiplier = data->input1_multiplier; + op_params.input1_shift = data->input1_shift; + op_params.input2_offset = data->input2_offset; + op_params.input2_multiplier = data->input2_multiplier; + op_params.input2_shift = data->input2_shift; + op_params.output_offset = data->output_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = data->output_shift; + SetActivationParams(data->output_activation_min, data->output_activation_max, + &op_params); - int32 input1_multiplier; - int input1_shift; - QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, - &input1_multiplier, &input1_shift); - int32 input2_multiplier; - int input2_shift; - QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, - &input2_multiplier, &input2_shift); - int32 output_multiplier; - int output_shift; - QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, - &output_multiplier, &output_shift); - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_SUB(type, opname) \ - tflite::ArithmeticParams op_params; \ - op_params.left_shift = left_shift; \ - op_params.input1_offset = input1_offset; \ - op_params.input1_multiplier = input1_multiplier; \ - op_params.input1_shift = input1_shift; \ - op_params.input2_offset = input2_offset; \ - op_params.input2_multiplier = input2_multiplier; \ - op_params.input2_shift = input2_shift; \ - op_params.output_offset = output_offset; \ - op_params.output_multiplier = output_multiplier; \ - op_params.output_shift = output_shift; \ - SetActivationParams(output_activation_min, output_activation_max, \ - &op_params); \ - type::opname(op_params, GetTensorShape(input1), \ - GetTensorData(input1), GetTensorShape(input2), \ - GetTensorData(input2), GetTensorShape(output), \ - GetTensorData(output)) - // The quantized version of 
Sub doesn't support activations, so we - // always use BroadcastSub. - if (kernel_type == kReference) { - TF_LITE_SUB(reference_ops, BroadcastSub4DSlow); + const bool need_broadcast = optimized_ops::ProcessBroadcastShapes( + GetTensorShape(input1), GetTensorShape(input2), &op_params); + +#define TF_LITE_SUB(type, opname, data_type) \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) + // NOTE: We are using the add kernels. This is possible as the second values + // multiplier is negated before being passed down. + if (output->type == kTfLiteInt8) { + if (need_broadcast) { + TF_LITE_SUB(reference_integer_ops, BroadcastAdd4DSlow, int8_t); + } else { + TF_LITE_SUB(reference_integer_ops, Add, int8_t); + } + } else if (output->type == kTfLiteUInt8) { + if (kernel_type == kReference) { + if (need_broadcast) { + TF_LITE_SUB(reference_ops, BroadcastAdd4DSlow, uint8_t); + } else { + TF_LITE_SUB(reference_ops, Add, uint8_t); + } + } else { + if (op_params.broadcast_category == + BroadcastableOpCategory::kGenericBroadcast) { + TF_LITE_SUB(optimized_ops, BroadcastAdd4DSlow, uint8_t); + } else if (need_broadcast) { + TF_LITE_SUB(optimized_ops, BroadcastAddFivefold, uint8_t); + } else { + TF_LITE_SUB(optimized_ops, Add, uint8_t); + } + } } else { - TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow); + if (kernel_type == kReference) { + if (need_broadcast) { + TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, int16_t); + } else { + TF_LITE_SUB(reference_ops, Sub16, int16_t); + } + } else { + if (need_broadcast) { + TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, int16_t); + } else { + TF_LITE_SUB(optimized_ops, Sub16, int16_t); + } + } } #undef TF_LITE_SUB } @@ -198,7 +350,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { EvalSub(context, node, params, data, input1, input2, 
output); - } else if (output->type == kTfLiteUInt8) { + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 || + output->type == kTfLiteInt16) { EvalQuantized(context, node, params, data, input1, input2, output); } else { diff --git a/tensorflow/lite/kernels/sub_test.cc b/tensorflow/lite/kernels/sub_test.cc index 41503300ab599fbfcfee425c41033dd3bc10d2ea..3c19678b20f21894461f5ef79b1df6c45e1cac5a 100644 --- a/tensorflow/lite/kernels/sub_test.cc +++ b/tensorflow/lite/kernels/sub_test.cc @@ -63,17 +63,27 @@ class QuantizedSubOpModel : public BaseSubOpModel { public: using BaseSubOpModel::BaseSubOpModel; + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + std::vector GetDequantizedOutputInt16() { + return Dequantize(ExtractVector(output_), GetScale(output_), GetZeroPoint(output_)); } }; -// for quantized Sub, the error shouldn't exceed 2*step +// for quantized Sub, the error shouldn't exceed step float GetTolerance(int min, int max) { float kQuantizedStep = (max - min) / 255.0; - float kQuantizedTolerance = 2.0 * kQuantizedStep; - return kQuantizedTolerance; + return kQuantizedStep; +} + +float GetToleranceInt16(float min, float max) { + float kQuantizedStep = (max - min) / std::numeric_limits::max(); + return kQuantizedStep; } TEST(FloatSubOpModel, NoActivation) { @@ -183,7 +193,8 @@ TEST(IntegerSubOpModel, WithBroadcast) { } } -TEST(QuantizedSubOpModel, QuantizedTestsNoActivation) { +template +void QuantizedTestsNoActivation() { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { {0.1, 0.2, 0.3, 0.4}, {-0.2, 0.2, 0.4, 0.7}, {-0.01, 0.2, 0.7, 0.3}}; @@ -193,20 +204,30 @@ TEST(QuantizedSubOpModel, QuantizedTestsNoActivation) { {-0.8, -0.2, -0.1, 0.9}, {-0.61, -0.2, 0.88, -0.2}}; for (int i = 0; i < inputs1.size(); ++i) { - QuantizedSubOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - 
{TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}, + QuantizedSubOpModel m({tensor_type, {1, 2, 2, 1}, -1.0, 1.0}, + {tensor_type, {1, 2, 2, 1}, -1.0, 1.0}, + {tensor_type, {}, -1.0, 1.0}, ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), inputs1[i]); - m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - results[i], kQuantizedTolerance))) + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) << "With test number " << i; } } -TEST(QuantizedSubOpModel, QuantizedTestsActivationRELU_N1_TO_1) { +TEST(QuantizedSubOpModel, QuantizedTestsNoActivationUInt8) { + QuantizedTestsNoActivation(); +} + +TEST(QuantizedSubOpModel, QuantizedTestsNoActivationInt8) { + QuantizedTestsNoActivation(); +} + +template +void QuantizedTestsActivationRELU_N1_TO_1() { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, {-0.8, 0.2, 0.7, 0.5}}; @@ -215,57 +236,185 @@ TEST(QuantizedSubOpModel, QuantizedTestsActivationRELU_N1_TO_1) { std::vector> results = {{-1.0, -0.2, 0.0, 1.0}, {-1.0, -0.2, 1.0, 0.2}}; for (int i = 0; i < inputs1.size(); ++i) { - QuantizedSubOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}, + QuantizedSubOpModel m({tensor_type, {1, 2, 2, 1}, -1.0, 1.0}, + {tensor_type, {1, 2, 2, 1}, -1.0, 1.0}, + {tensor_type, {}, -1.0, 1.0}, ActivationFunctionType_RELU_N1_TO_1); - m.QuantizeAndPopulate(m.input1(), inputs1[i]); - m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - results[i], 
kQuantizedTolerance))) + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) << "With test number " << i; } } +TEST(QuantizedSubOpModel, QuantizedTestsActivationRELUN1TO1UInt8) { + QuantizedTestsActivationRELU_N1_TO_1(); +} + +TEST(QuantizedSubOpModel, QuantizedTestsActivationRELUN1TO1Int8) { + QuantizedTestsActivationRELU_N1_TO_1(); +} -TEST(QuantizedSubOpModel, QuantizedVariousInputShapes) { +template +void QuantizedVariousInputShapes() { float kQuantizedTolerance = GetTolerance(-3.0, 3.0); std::vector> test_shapes = { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { - QuantizedSubOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, - {TensorType_UINT8, test_shapes[i], -3.0, 3.0}, - {TensorType_UINT8, {}, -3.0, 3.0}, + QuantizedSubOpModel m({tensor_type, test_shapes[i], -3.0, 3.0}, + {tensor_type, test_shapes[i], -3.0, 3.0}, + {tensor_type, {}, -3.0, 3.0}, ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); - m.QuantizeAndPopulate(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); + m.QuantizeAndPopulate(m.input1(), + {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), + {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( {-2.1, -0.1, 0.4, 0.3, 0.0, 1.9}, kQuantizedTolerance))) << "With shape number " << i; } } -TEST(QuantizedSubOpModel, QuantizedWithBroadcast) { +TEST(QuantizedSubOpModel, QuantizedVariousInputShapesUInt8) { + QuantizedVariousInputShapes(); +} + +TEST(QuantizedSubOpModel, QuantizedVariousInputShapesInt8) { + QuantizedVariousInputShapes(); +} + +template +void QuantizedWithBroadcast() { float kQuantizedTolerance = GetTolerance(-3.0, 3.0); std::vector> test_shapes = { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { - QuantizedSubOpModel 
m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, - {TensorType_UINT8, {}, -3.0, 3.0}, - {TensorType_UINT8, {}, -3.0, 3.0}, - ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); - m.QuantizeAndPopulate(m.input2(), {0.7}); + QuantizedSubOpModel m( + {tensor_type, test_shapes[i], -3.0, 3.0}, {tensor_type, {}, -3.0, 3.0}, + {tensor_type, {}, -3.0, 3.0}, ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), + {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.7}); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( {-2.7, -0.5, 0.0, 0.1, 0.4, 1.3}, kQuantizedTolerance))) << "With shape number " << i; } } +TEST(QuantizedSubOpModel, QuantizedWithBroadcastUInt8) { + QuantizedWithBroadcast(); +} + +TEST(QuantizedSubOpModel, QuantizedWithBroadcastInt8) { + QuantizedWithBroadcast(); +} + +TEST(QuantizedSubOpModel, QuantizedTestsNoActivationInt16) { + const float kMin = -1.f; + const float kMax = + static_cast(std::numeric_limits::max() - 1) / + std::numeric_limits::max(); + float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); + std::vector> inputs1 = { + {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.3, 0.8}}; + std::vector> inputs2 = { + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, 0.8, 0.5}}; + std::vector> results = { + {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, -1.0, 0.3}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedSubOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + 
+TEST(QuantizedSubOpModel, QuantizedTestsReluActivationInt16) { + const float kMin = -2.f; + const float kMax = 2.0 * (std::numeric_limits::max() - 1) / + std::numeric_limits::max(); + float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); + std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, + {-0.8, 0.2, 0.7, 0.5}}; + std::vector> inputs2 = {{0.6, 0.4, 0.9, -0.8}, + {0.6, 0.4, -0.8, 0.3}}; + std::vector> results = {{-1.0, -0.2, 0.0, 1.0}, + {-1.0, -0.2, 1.0, 0.2}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedSubOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_RELU_N1_TO_1); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + +TEST(QuantizedSubOpModel, QuantizedTestsNoActivationBroadcastInt16) { + const float kMin = -1.f; + const float kMax = + static_cast(std::numeric_limits::max() - 1) / + std::numeric_limits::max(); + float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedSubOpModel m({TensorType_INT16, test_shapes[i], kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), + {-0.9, -0.7, -0.3, 0.0, 0.3, 0.5}); + m.QuantizeAndPopulate(m.input2(), {0.2}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear( + {-1.0, -0.9, -0.5, -0.2, 0.1, 0.3}, kQuantizedTolerance))) + << "With shape number " << i; + } +} + +TEST(QuantizedSubOpModel, QuantizedTestsReluActivationBroadcastInt16) { + const float kMin = -2.f; + const float kMax = 2.0 * 
(std::numeric_limits::max() - 1) / + std::numeric_limits::max(); + float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedSubOpModel m({TensorType_INT16, test_shapes[i], kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_RELU_N1_TO_1); + m.QuantizeAndPopulate(m.input1(), + {-0.9, -0.7, -0.3, 0.0, 0.3, 0.5}); + m.QuantizeAndPopulate(m.input2(), {0.2}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear( + {-1.0, -0.9, -0.5, -0.2, 0.1, 0.3}, kQuantizedTolerance))) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..e712be1b516ed0dca5097e66c2d1f20e63a78038 --- /dev/null +++ b/tensorflow/lite/kernels/subgraph_test_util.cc @@ -0,0 +1,409 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/kernels/subgraph_test_util.h" + +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { + +namespace ops { +namespace builtin { +// ADD and MUL are used to test simple branch. +TfLiteRegistration* Register_ADD(); +TfLiteRegistration* Register_MUL(); +// PAD and LESS_EQUAL are used to test dynamic sized subgraphs. +TfLiteRegistration* Register_PAD(); +TfLiteRegistration* Register_LESS_EQUAL(); +} // namespace builtin +namespace custom { +TfLiteRegistration* Register_IF(); +TfLiteRegistration* Register_WHILE(); +} // namespace custom +} // namespace ops + +namespace subgraph_test_util { + +namespace { + +void SetupTensor(Subgraph* subgraph, int tensor_index, TfLiteType type) { + ASSERT_EQ(subgraph->SetTensorParametersReadWrite(tensor_index, type, "", 0, + nullptr, {}, false), + kTfLiteOk); +} + +} // namespace + +SubgraphBuilder::~SubgraphBuilder() { + for (auto buffer : buffers_) { + free(buffer); + } +} + +void SubgraphBuilder::BuildAddSubgraph(Subgraph* subgraph) { + const int kInput1 = 0; + const int kInput2 = 1; + const int kOutput = 2; + const int kTensorCount = 3; + // kInput1(0) --> +---+ + // |ADD| --> kOutput(2) + // kInput2(1) --> +---+ + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk); + + SetupTensor(subgraph, kInput1, kTfLiteInt32); + SetupTensor(subgraph, kInput2, kTfLiteInt32); + SetupTensor(subgraph, kOutput, kTfLiteInt32); + + TfLiteAddParams* params = + 
reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + params->activation = kTfLiteActNone; + int node_index; + subgraph->AddNodeWithParameters( + {kInput1, kInput2}, {kOutput}, nullptr, 0, params, + ::tflite::ops::builtin::Register_ADD(), &node_index); +} + +// Build a subgraph with an mul op. Helper function for testing. +void SubgraphBuilder::BuildMulSubgraph(Subgraph* subgraph) { + const int kInput1 = 0; + const int kInput2 = 1; + const int kOutput = 2; + const int kTensorCount = 3; + // kInput1(0) --> +---+ + // |MUL| --> kOutput(2) + // kInput2(1) --> +---+ + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk); + + SetupTensor(subgraph, kInput1, kTfLiteInt32); + SetupTensor(subgraph, kInput2, kTfLiteInt32); + SetupTensor(subgraph, kOutput, kTfLiteInt32); + + TfLiteMulParams* params = + reinterpret_cast(malloc(sizeof(TfLiteMulParams))); + params->activation = kTfLiteActNone; + int node_index; + subgraph->AddNodeWithParameters( + {kInput1, kInput2}, {kOutput}, nullptr, 0, params, + ::tflite::ops::builtin::Register_MUL(), &node_index); +} + +// Build a subgraph with a pad op. Helper function for testing. 
+void SubgraphBuilder::BuildPadSubgraph(Subgraph* subgraph) { + const int kInput1 = 0; + const int kInput2 = 1; + const int kOutput = 2; + const int kTensorCount = 3; + // kInput1(0) --> +---+ + // |PAD| --> kOutput(2) + // kInput2(1) --> +---+ + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk); + + SetupTensor(subgraph, kInput1, kTfLiteInt32); + SetupTensor(subgraph, kInput2, kTfLiteInt32); + SetupTensor(subgraph, kOutput, kTfLiteInt32); + + TfLitePadParams* params = + reinterpret_cast(malloc(sizeof(TfLitePadParams))); + int node_index; + subgraph->AddNodeWithParameters( + {kInput1, kInput2}, {kOutput}, nullptr, 0, params, + ::tflite::ops::builtin::Register_PAD(), &node_index); +} + +void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) { + const int kCondInput = 0; + const int kInput1 = 1; + const int kInput2 = 2; + const int kOutput = 3; + const int kTensorCount = 4; + + // kCondInput(0) --> +----+ + // kInput1(1) ----> | IF | --> kOutput(3) + // kInput2(2) ----> +----+ + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + ASSERT_EQ(subgraph->SetInputs({kCondInput, kInput1, kInput2}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk); + + SetupTensor(subgraph, kCondInput, kTfLiteBool); + SetupTensor(subgraph, kInput1, kTfLiteInt32); + SetupTensor(subgraph, kInput2, kTfLiteInt32); + SetupTensor(subgraph, kOutput, kTfLiteInt32); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("then_subgraph_index", 1); + fbb.Int("else_subgraph_index", 2); + }); + fbb.Finish(); + const auto& buffer = fbb.GetBuffer(); + + int node_index; + subgraph->AddNodeWithParameters( + {kCondInput, kInput1, kInput2}, {kOutput}, + 
reinterpret_cast(buffer.data()), buffer.size(), nullptr, + ::tflite::ops::custom::Register_IF(), &node_index); +} + +void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) { + const int kInput1 = 0; + const int kInput2 = 1; + const int kOutput = 2; + const int kConstRhs = 3; + const int kTensorCount = 4; + + // kInput1(0) ----> +------------+ + // | LESS_EQUAL | --> kOutput(2) + // kConstRhs(3) --> +------------+ + // + // kInput2(1) --> (unused) + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk); + + SetupTensor(subgraph, kInput1, kTfLiteInt32); + SetupTensor(subgraph, kInput2, kTfLiteInt32); + SetupTensor(subgraph, kOutput, kTfLiteBool); + + CreateConstantInt32Tensor(subgraph, kConstRhs, {1}, {rhs}); + int node_index; + subgraph->AddNodeWithParameters( + {kInput1, kConstRhs}, {kOutput}, nullptr, 0, nullptr, + ::tflite::ops::builtin::Register_LESS_EQUAL(), &node_index); +} + +void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) { + const int kInputCounter = 0; + const int kInputValue = 1; + const int kOutputCounter = 2; + const int kOutputValue = 3; + const int kConstStep = 4; + const int kTensorCount = 5; + + // kInputCounter(0) --> +-----+ + // | ADD | --> kOutputCounter(2) + // kConstStep(4) -----> +-----+ | + // | + // v + // +-----+ + // | ADD | --> kOutputValue(3) + // kInputValue(1) ----------------------+-----+ + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kOutputValue}), kTfLiteOk); + + SetupTensor(subgraph, kInputCounter, kTfLiteInt32); + 
SetupTensor(subgraph, kInputValue, kTfLiteInt32); + SetupTensor(subgraph, kOutputCounter, kTfLiteInt32); + SetupTensor(subgraph, kOutputValue, kTfLiteInt32); + CreateConstantInt32Tensor(subgraph, kConstStep, {1}, {1}); + + int node_index; + TfLiteAddParams* params = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + params->activation = kTfLiteActNone; + subgraph->AddNodeWithParameters({0, 4}, {2}, nullptr, 0, params, + ::tflite::ops::builtin::Register_ADD(), + &node_index); + params = reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + params->activation = kTfLiteActNone; + subgraph->AddNodeWithParameters({2, 1}, {3}, nullptr, 0, params, + ::tflite::ops::builtin::Register_ADD(), + &node_index); +} + +void SubgraphBuilder::BuildPadLoopBodySubgraph(Subgraph* subgraph, + const std::vector padding) { + const int kInputCounter = 0; + const int kInputValue = 1; + const int kOutputCounter = 2; + const int kOutputValue = 3; + const int kConstStep = 4; + const int kConstPadding = 5; + const int kTensorCount = 6; + + // kInputCounter(0) --> +-----+ + // | ADD | --> kOutputCounter(2) + // kConstStep(4) -----> +-----+ + // + // kInputValue(1) ----> +-----+ + // | PAD | --> kOutputValue(3) + // kConstPadding(5) --> +-----+ + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kOutputValue}), kTfLiteOk); + + SetupTensor(subgraph, kInputCounter, kTfLiteInt32); + SetupTensor(subgraph, kInputValue, kTfLiteInt32); + SetupTensor(subgraph, kOutputCounter, kTfLiteInt32); + SetupTensor(subgraph, kOutputValue, kTfLiteInt32); + + CreateConstantInt32Tensor(subgraph, kConstStep, {1}, {1}); + ASSERT_EQ(padding.size() % 2, 0); + int padding_dims = padding.size(); + CreateConstantInt32Tensor(subgraph, kConstPadding, {1, padding_dims}, + padding); + + int 
node_index; + TfLiteAddParams* add_params = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + add_params->activation = kTfLiteActNone; + subgraph->AddNodeWithParameters( + {kInputCounter, kConstStep}, {kOutputCounter}, nullptr, 0, add_params, + ::tflite::ops::builtin::Register_ADD(), &node_index); + TfLitePadParams* pad_params = + reinterpret_cast(malloc(sizeof(TfLitePadParams))); + subgraph->AddNodeWithParameters( + {kInputValue, kConstPadding}, {kOutputValue}, nullptr, 0, pad_params, + ::tflite::ops::builtin::Register_PAD(), &node_index); +} + +void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) { + const int kInput1 = 0; + const int kInput2 = 1; + const int kOutput1 = 2; + const int kOutput2 = 3; + const int kTensorCount = 4; + + // kInput1(0) --> +-------+ --> kOutput1(2) + // | WHILE | + // kInput2(1) --> +-------+ --> kOutput2(3) + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutput1, kOutput2}), kTfLiteOk); + + SetupTensor(subgraph, kInput1, kTfLiteInt32); + SetupTensor(subgraph, kInput2, kTfLiteInt32); + SetupTensor(subgraph, kOutput1, kTfLiteInt32); + SetupTensor(subgraph, kOutput2, kTfLiteInt32); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("cond_subgraph_index", 1); + fbb.Int("body_subgraph_index", 2); + }); + fbb.Finish(); + const auto& buffer = fbb.GetBuffer(); + + int node_index; + subgraph->AddNodeWithParameters( + {0, 1}, {2, 3}, reinterpret_cast(buffer.data()), + buffer.size(), nullptr, ::tflite::ops::custom::Register_WHILE(), + &node_index); +} + +void SubgraphBuilder::CreateConstantInt32Tensor(Subgraph* subgraph, + int tensor_index, + const std::vector& shape, + const std::vector& data) { + ASSERT_GT(shape.size(), 0); + int num_elements = 1; + for (int dim : shape) { + num_elements *= dim; + } + 
ASSERT_EQ(data.size(), num_elements); + size_t size_in_bytes = sizeof(int32_t) * num_elements; + // Maybe aligned. + int32_t* buffer = reinterpret_cast(malloc(size_in_bytes)); + for (int i = 0; i < num_elements; ++i) { + buffer[i] = data[i]; + } + buffers_.push_back(buffer); + ASSERT_EQ(subgraph->SetTensorParametersReadOnly( + tensor_index, kTfLiteInt32, "", shape, {}, + reinterpret_cast(buffer), size_in_bytes), + kTfLiteOk); +} + +void FillIntTensor(TfLiteTensor* tensor, const std::vector& data) { + int count = NumElements(tensor); + ASSERT_EQ(count, data.size()); + for (int i = 0; i < count; ++i) { + tensor->data.i32[i] = data[i]; + } +} + +void CheckIntTensor(const TfLiteTensor* tensor, const std::vector& shape, + const std::vector& data) { + ASSERT_EQ(tensor->dims->size, shape.size()); + for (int i = 0; i < tensor->dims->size; ++i) { + ASSERT_EQ(tensor->dims->data[i], shape[i]); + } + ASSERT_EQ(tensor->type, kTfLiteInt32); + int count = NumElements(tensor); + ASSERT_EQ(count, data.size()); + for (int i = 0; i < count; ++i) { + EXPECT_EQ(tensor->data.i32[i], data[i]); + } +} + +void CheckBoolTensor(const TfLiteTensor* tensor, const std::vector& shape, + const std::vector& data) { + ASSERT_EQ(tensor->dims->size, shape.size()); + for (int i = 0; i < tensor->dims->size; ++i) { + ASSERT_EQ(tensor->dims->data[i], shape[i]); + } + ASSERT_EQ(tensor->type, kTfLiteBool); + int count = NumElements(tensor); + ASSERT_EQ(count, data.size()); + for (int i = 0; i < count; ++i) { + EXPECT_EQ(tensor->data.b[i], data[i]); + } +} + +} // namespace subgraph_test_util +} // namespace tflite diff --git a/tensorflow/lite/kernels/subgraph_test_util.h b/tensorflow/lite/kernels/subgraph_test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..972f1381af2804252461bf81dfbce3563be41c3b --- /dev/null +++ b/tensorflow/lite/kernels/subgraph_test_util.h @@ -0,0 +1,123 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This module provides helper functions for testing the interaction between +// control flow ops and subgraphs. +// For convenience, we mostly only use `kTfLiteInt32` in this module. + +#ifndef TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_ +#define TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_ + +#include +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/interpreter.h" + +namespace tflite { +namespace subgraph_test_util { + +// TODO(ycling): This file should be renamed as +// `control_flow_test_util` to avoid confusion. I'll do it immediately +// in a separated change. +class SubgraphBuilder { + public: + ~SubgraphBuilder(); + + // Build a subgraph with a single Add op. + // 2 inputs. 1 output. + void BuildAddSubgraph(Subgraph* subgraph); + + // Build a subgraph with a single Mul op. + // 2 inputs. 1 output. + void BuildMulSubgraph(Subgraph* subgraph); + + // Build a subgraph with a single Pad op. + // 2 inputs. 1 output. + void BuildPadSubgraph(Subgraph* subgraph); + + // Build a subgraph with a single If op. + // 3 inputs: + // The 1st input is condition with boolean type. + // The 2nd and 3rd inputs are feed input the branch subgraphs. + // 1 output. + void BuildIfSubgraph(Subgraph* subgraph); + + // Build a subgraph with a single Less op. + // The subgraph is used as the condition subgraph for testing `While` op. 
+ // 2 inputs: + // The 1st input is a counter with `kTfLiteInt32` type. + // The 2nd input is ignored in this subgraph. + // 1 output with `kTfLiteBool` type. + // Equivalent to (input <= rhs). + void BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs); + + // An accumulate loop body subgraph. Used to produce triangle number + // sequence. 2 inputs and 2 outputs. + // Equivalent to (counter, value) -> (counter + 1, counter + 1 + value) + void BuildAccumulateLoopBodySubgraph(Subgraph* subgraph); + + // A pad loop body subgraph. When used in a loop it will repeatedly enlarge + // the + // tensor. + // 2 inputs and 2 outputs. + // Equivalent to (counter, value) -> (counter + 1, tf.pad(value, padding)) + // Note the padding is created as a constant tensor. + void BuildPadLoopBodySubgraph(Subgraph* subgraph, + const std::vector padding); + + // Build a subgraph with a single While op. + // 2 inputs, 2 outputs. + void BuildWhileSubgraph(Subgraph* subgraph); + + private: + void CreateConstantInt32Tensor(Subgraph* subgraph, int tensor_index, + const std::vector& shape, + const std::vector& data); + std::vector buffers_; +}; + +class ControlFlowOpTest : public ::testing::Test { + public: + ControlFlowOpTest() + : interpreter_(new Interpreter), builder_(new SubgraphBuilder) {} + + ~ControlFlowOpTest() override { + interpreter_.reset(); + builder_.reset(); + } + + protected: + std::unique_ptr interpreter_; + std::unique_ptr builder_; +}; + +// Fill a `TfLiteTensor` with a 32-bit integer vector. +// Preconditions: +// * The tensor must have `kTfLiteInt32` type. +// * The tensor must be allocated. +// * The element count of the tensor must be equal to the length of +// the vector. +void FillIntTensor(TfLiteTensor* tensor, const std::vector& data); + +// Check if the shape and int32 data of a tensor is as expected. 
+void CheckIntTensor(const TfLiteTensor* tensor, const std::vector& shape, + const std::vector& data); +// Check if the shape and bool data of a tensor is as expected. +void CheckBoolTensor(const TfLiteTensor* tensor, const std::vector& shape, + const std::vector& data); + +} // namespace subgraph_test_util +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_ diff --git a/tensorflow/lite/kernels/subgraph_test_util_test.cc b/tensorflow/lite/kernels/subgraph_test_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..04e5118b543c1723e1de1875ffd9315991a4dd69 --- /dev/null +++ b/tensorflow/lite/kernels/subgraph_test_util_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/kernels/subgraph_test_util.h" +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/test_util.h" + +namespace tflite { + +namespace subgraph_test_util { + +namespace { + +class SubgraphBuilderTest : public ::testing::Test { + public: + SubgraphBuilderTest() + : interpreter_(new Interpreter), builder_(new SubgraphBuilder) {} + + ~SubgraphBuilderTest() override { + interpreter_.reset(); + builder_.reset(); + } + + protected: + void TestAccumelateLoopBody(int input1, int input2, int output1, + int output2) { + interpreter_.reset(new Interpreter); + builder_->BuildAccumulateLoopBodySubgraph( + &interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {input1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {input2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output_tensor1 = + interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output_tensor1, {1}, {output1}); + TfLiteTensor* output_tensor2 = + interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output_tensor2, {1}, {output2}); + } + + std::unique_ptr interpreter_; + std::unique_ptr builder_; +}; + +TEST_F(SubgraphBuilderTest, TestBuildAddSubgraph) { + builder_->BuildAddSubgraph(&interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2}); + 
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output, {1, 2}, {6, 9}); +} + +TEST_F(SubgraphBuilderTest, TestBuildMulSubgraph) { + builder_->BuildMulSubgraph(&interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2}); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output, {1, 2}, {5, 14}); +} + +TEST_F(SubgraphBuilderTest, TestBuildPadSubgraph) { + builder_->BuildPadSubgraph(&interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2}); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output, {5}, {0, 5, 7, 0, 0}); +} + +TEST_F(SubgraphBuilderTest, TestBuildLessEqualCondSubgraph) { + builder_->BuildLessEqualCondSubgraph(&interpreter_->primary_subgraph(), 3); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {5}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {10, 10}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + // Test [1, 2, 3, 4, 5] <= 3 == [true, true, true, false, false] + // (with broadcasting). 
+ FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), + {1, 2, 3, 4, 5}); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + CheckBoolTensor(output, {5}, {true, true, true, false, false}); +} + +TEST_F(SubgraphBuilderTest, TestBuildAccumulateLoopBodySubgraph) { + TestAccumelateLoopBody(1, 1, 2, 3); + TestAccumelateLoopBody(2, 3, 3, 6); + TestAccumelateLoopBody(3, 6, 4, 10); +} + +TEST_F(SubgraphBuilderTest, TestBuildPadLoopBodySubgraph) { + builder_->BuildPadLoopBodySubgraph(&interpreter_->primary_subgraph(), {1, 2}); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {5}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), + {0, 5, 7, 0, 0}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output1, {1}, {2}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output2, {8}, {0, 0, 5, 7, 0, 0, 0, 0}); +} + +} // namespace +} // namespace subgraph_test_util +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc index 19d7e37409cba2f4b018082d13a2d3e130a3c5c4..dd8d9ed21830f852c9a13f09626278f4ca4938e2 100644 --- a/tensorflow/lite/kernels/test_util.cc +++ b/tensorflow/lite/kernels/test_util.cc @@ -47,7 +47,12 @@ std::vector>> ArrayComplex64Near( } int SingleOpModel::AddInput(const TensorData& t, bool is_variable) { - int id = AddTensor(t, {}, is_variable); + int id = 0; + if (t.per_channel_quantization) { + id = AddTensorPerChannelQuant(t); + } 
else { + id = AddTensor(t, {}, is_variable); + } inputs_.push_back(id); return id; } @@ -119,7 +124,7 @@ void SingleOpModel::BuildInterpreter(std::vector> input_shapes, CHECK(interpreter_ != nullptr); - for (int i = 0; i < input_shapes.size(); ++i) { + for (size_t i = 0; i < input_shapes.size(); ++i) { const int input_idx = interpreter_->inputs()[i]; if (input_idx == kOptionalTensor) continue; const auto& shape = input_shapes[i]; diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 4a442f9fa7554fbc5c149e1dd20f82c162d392d4..08c027f9d9d4fb469abecd460a09cdd543214215 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -21,13 +21,14 @@ limitations under the License. #include #include +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/tools/optimize/quantization_utils.h" namespace tflite { @@ -82,14 +83,36 @@ inline std::vector Dequantize(const std::vector& data, float scale, // A helper struct to construct test tensors. This is particularly useful for // quantized tensor which must have their scale and zero_point defined before // the actual data is known. This mimics what happens in practice: quantization -// parameters are calculated during training. +// parameters are calculated during training or post training.. 
struct TensorData { + TensorData(TensorType type = TensorType_FLOAT32, std::vector shape = {}, + float min = 0.0f, float max = 0.0f, float scale = 0.0f, + int32_t zero_point = 0, bool per_channel_quantization = false, + std::vector per_channel_quantization_scales = {}, + std::vector per_channel_quantization_offsets = {}, + int32_t channel_index = 0) + : type(type), + shape(shape), + min(min), + max(max), + scale(scale), + zero_point(zero_point), + per_channel_quantization(per_channel_quantization), + per_channel_quantization_scales( + std::move(per_channel_quantization_scales)), + per_channel_quantization_offsets( + std::move(per_channel_quantization_offsets)), + channel_index(channel_index) {} TensorType type; std::vector shape; float min; float max; float scale; int32_t zero_point; + bool per_channel_quantization; + std::vector per_channel_quantization_scales; + std::vector per_channel_quantization_offsets; + int32_t channel_index; }; class SingleOpResolver : public OpResolver { @@ -172,6 +195,46 @@ class SingleOpModel { PopulateTensor(index, /*offset=*/0, q.data(), q.data() + q.size()); } + // Quantize and populate data for filter with per channel quantization. 
+ void PerChannelSymmetricQuantizeAndPopulate( + int index, const std::vector& input_data) { + TfLiteTensor* t = interpreter_->tensor(index); + auto* params = + reinterpret_cast(t->quantization.params); + const int channel_index = params->quantized_dimension; + + std::vector shape(t->dims->size); + for (int i = 0; i < shape.size(); ++i) { + shape[i] = t->dims->data[i]; + } + const int32_t num_inputs = input_data.size(); + const int32_t num_channel = shape[channel_index]; + std::vector quantized_output(num_inputs); + std::vector scales_inv(num_channel); + for (int i = 0; i < num_channel; ++i) { + scales_inv[i] = 1.0f / params->scale->data[i]; + } + optimize::utils::SymmetricPerChannelQuantizeValues( + input_data.data(), scales_inv, shape, channel_index, &quantized_output); + + PopulateTensor(index, /*offset=*/0, quantized_output.data(), + quantized_output.data() + quantized_output.size()); + } + + // Quantize and populate data for bias with per channel quantization. + void PerChannelQuantizeBias(int index, const std::vector& input_data) { + const int32_t num_inputs = input_data.size(); + std::vector quantized_output(num_inputs); + TfLiteTensor* t = interpreter_->tensor(index); + auto* params = + reinterpret_cast(t->quantization.params); + for (int i = 0; i < num_inputs; ++i) { + quantized_output[i] = input_data[i] * params->scale->data[i]; + } + PopulateTensor(index, /*offset=*/0, quantized_output.data(), + quantized_output.data() + quantized_output.size()); + } + const std::vector& GetShape(int id) { return tensor_data_.at(id).shape; } float GetScale(int id) { return tensor_data_.at(id).scale; } @@ -263,6 +326,10 @@ class SingleOpModel { return result; } + void SetNumThreads(int num_threads) { + interpreter_->SetNumThreads(num_threads); + } + void SetResolver(std::unique_ptr resolver) { resolver_ = std::move(resolver); } @@ -292,6 +359,24 @@ class SingleOpModel { return {scale, zero_point}; } + int AddTensorPerChannelQuant(TensorData t) { + const int id = 
tensors_.size(); + flatbuffers::Offset q_params = 0; + q_params = CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, + /*scale=*/ + builder_.CreateVector(t.per_channel_quantization_scales), + /*zero point=*/ + builder_.CreateVector(t.per_channel_quantization_offsets), + QuantizationDetails_NONE, 0, t.channel_index); + tensors_.push_back( + CreateTensor(builder_, builder_.CreateVector(t.shape), t.type, + /*buffer=*/0, + /*name=*/0, q_params, /*is_variable=*/false)); + tensor_data_[id] = t; + return id; + } + template int AddTensor(TensorData t, std::initializer_list data, bool is_variable = false) { @@ -367,6 +452,17 @@ class SingleOpModel { // Update quantization params. t->params.scale = scaling_factor; t->params.zero_point = 0; + // Populate the new quantization params. + TfLiteQuantizationFree(&t->quantization); + t->quantization.type = kTfLiteAffineQuantization; + auto* affine_quantization = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + affine_quantization->quantized_dimension = 0; + affine_quantization->scale = TfLiteFloatArrayCreate(1); + affine_quantization->zero_point = TfLiteIntArrayCreate(1); + affine_quantization->scale->data[0] = scaling_factor; + affine_quantization->zero_point->data[0] = 0; + t->quantization.params = affine_quantization; return q; } diff --git a/tensorflow/lite/kernels/topk_v2.cc b/tensorflow/lite/kernels/topk_v2.cc index 444b01e7b2e055ab4e26a2ea1dce28642dc430b7..64973d7b860fc3089850cc3648ee4fb9da11047e 100644 --- a/tensorflow/lite/kernels/topk_v2.cc +++ b/tensorflow/lite/kernels/topk_v2.cc @@ -207,6 +207,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32, output_values->data.uint8); break; + case kTfLiteInt8: + TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32, + output_values->data.int8); + break; case kTfLiteInt32: TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32, 
output_values->data.i32); diff --git a/tensorflow/lite/kernels/topk_v2_test.cc b/tensorflow/lite/kernels/topk_v2_test.cc index 108b8123666aaddcc8ba8438bac82c91ce98d50d..0097ae2f9aece116c963a4b460c2e3ff0fc127c4 100644 --- a/tensorflow/lite/kernels/topk_v2_test.cc +++ b/tensorflow/lite/kernels/topk_v2_test.cc @@ -46,6 +46,10 @@ class TopKV2OpModel : public SingleOpModel { PopulateTensor(input_, data); } + void SetInputInt8(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInputInt32(std::initializer_list data) { PopulateTensor(input_, data); } @@ -66,6 +70,10 @@ class TopKV2OpModel : public SingleOpModel { return ExtractVector(output_values_); } + std::vector GetValuesInt8() { + return ExtractVector(output_values_); + } + std::vector GetValuesInt32() { return ExtractVector(output_values_); } @@ -128,6 +136,14 @@ TEST(TopKV2OpTest, TypeUint8) { EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250})); } +TEST(TopKV2OpTest, TypeInt8) { + TopKV2OpModel m({2, 3}, TensorType_INT8, 2); + m.SetInputInt8({1, 2, 3, -126, 125, -24}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 1, 2})); + EXPECT_THAT(m.GetValuesInt8(), ElementsAreArray({3, 2, 125, -24})); +} + // Check that int32_t works. 
TEST(TopKV2OpTest, TypeInt32) { TopKV2OpModel m({2, 3}, TensorType_INT32, 2); diff --git a/tensorflow/lite/kernels/transpose.cc b/tensorflow/lite/kernels/transpose.cc index 7a6d320674ad1c8302f8bf3a9d1d5153223deed3..0ef4972d1a856f84d3511657ec9d9f2f3cc36182 100644 --- a/tensorflow/lite/kernels/transpose.cc +++ b/tensorflow/lite/kernels/transpose.cc @@ -117,6 +117,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_TRANSPOSE(reference_ops, uint8_t); } break; + case kTfLiteInt8: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, int8_t); + } + break; case kTfLiteInt32: if (kernel_type == kReference) { TF_LITE_TRANSPOSE(reference_ops, int32_t); diff --git a/tensorflow/lite/kernels/transpose_conv_test.cc b/tensorflow/lite/kernels/transpose_conv_test.cc index 44d1336b99fe03535451c7dbacfe77be58fd6fad..05d3451d0056649b8cc958974b105774a1bdcfc3 100644 --- a/tensorflow/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/lite/kernels/transpose_conv_test.cc @@ -57,7 +57,7 @@ class TransposeConvOpModel : public SingleOpModel { resolver_ = absl::make_unique( BuiltinOperator_TRANSPOSE_CONV, registration); BuildInterpreter( - {GetShape(output_shape_), GetShape(input_), GetShape(filter_)}); + {GetShape(output_shape_), GetShape(filter_), GetShape(input_)}); } void SetOutputShape(std::initializer_list i) { @@ -97,8 +97,8 @@ class TransposeConvOpTest : public SingleOpTest { // [1, 1, 1, 1 ], // "SAME") TEST_P(TransposeConvOpTest, SimpleTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 4, 4, 1}}, - {TensorType_FLOAT32, {1, 3, 3, 1}}, + TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 1}}, + {TensorType_FLOAT32, {1, 4, 4, 1}}, {TensorType_FLOAT32, {}}, Padding_SAME, 1, 1); m.SetOutputShape({1, 4, 4, 1}); m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); @@ -125,8 +125,8 @@ TEST_P(TransposeConvOpTest, SimpleTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 
1, 2]), shape=[18, 1]) TEST_P(TransposeConvOpTest, TwoFiltersTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 4, 4, 2}}, - {TensorType_FLOAT32, {1, 3, 3, 2}}, + TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 2}}, + {TensorType_FLOAT32, {1, 4, 4, 2}}, {TensorType_FLOAT32, {}}, Padding_SAME, 1, 1); m.SetOutputShape({1, 4, 4, 1}); m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); @@ -153,8 +153,8 @@ TEST_P(TransposeConvOpTest, TwoFiltersTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18]) TEST_P(TransposeConvOpTest, PaddingValidTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 4, 4, 2}}, - {TensorType_FLOAT32, {1, 3, 3, 2}}, + TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 2}}, + {TensorType_FLOAT32, {1, 4, 4, 2}}, {TensorType_FLOAT32, {}}, Padding_VALID, 1, 1); m.SetOutputShape({1, 6, 6, 1}); m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); @@ -182,8 +182,8 @@ TEST_P(TransposeConvOpTest, PaddingValidTest) { // [1, 2, 2, 1 ], // "VALID") TEST_P(TransposeConvOpTest, StrideValidTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {1, 3, 3, 1}}, + TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, Padding_VALID, 2, 2); m.SetOutputShape({1, 5, 5, 1}); m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); @@ -208,8 +208,8 @@ TEST_P(TransposeConvOpTest, StrideValidTest) { // [1, 2, 2, 1 ], // "VALID") TEST_P(TransposeConvOpTest, MultiChannelTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {2, 3, 3, 1}}, + TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 3, 3, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, 
Padding_VALID, 2, 2); m.SetOutputShape({1, 5, 5, 2}); m.SetFilter({1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18}); @@ -238,8 +238,8 @@ TEST_P(TransposeConvOpTest, MultiChannelTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1]) TEST_P(TransposeConvOpTest, AccuracyTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 1, 2, 1}}, - {TensorType_FLOAT32, {1, 3, 3, 1}}, + TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 1}}, + {TensorType_FLOAT32, {1, 1, 2, 1}}, {TensorType_FLOAT32, {}}, Padding_SAME, 3, 3); m.SetOutputShape({1, 3, 4, 1}); m.SetFilter({9, 5, 6, 9, 8, 5, 3, 1, 4}); diff --git a/tensorflow/lite/kernels/transpose_test.cc b/tensorflow/lite/kernels/transpose_test.cc index 93df2c81db8c17de7a36d155c7d26b826c859c99..71644159209cc289329f65d1cac929585f2f4200 100644 --- a/tensorflow/lite/kernels/transpose_test.cc +++ b/tensorflow/lite/kernels/transpose_test.cc @@ -25,16 +25,17 @@ namespace { using ::testing::ElementsAreArray; +template void RunTestPermutation(const std::vector& shape, const std::vector& perms, - std::vector* input_transposed) { + std::vector* input_transposed) { // Count elements and allocate output. int count = 1; for (auto factor : shape) count *= factor; input_transposed->resize(count); // Create the dummy data - std::vector input(count); + std::vector input(count); for (int i = 0; i < input.size(); i++) { input[i] = i; } @@ -64,8 +65,8 @@ void RunTestPermutation(const std::vector& shape, params.perm[i] = perms[i]; } - reference_ops::Transpose(params, input_shape, input.data(), - output_shape, input_transposed->data()); + reference_ops::Transpose(params, input_shape, input.data(), output_shape, + input_transposed->data()); } TEST(TransposeTest, TestRefOps1D) { @@ -125,6 +126,28 @@ TEST(TransposeTest, TestRefOps4D) { ASSERT_EQ(out, ref); } +TEST(TransposeTest, TestRefOps4DInt8) { + std::vector out; + // Basic 4d. 
+ RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out); + ASSERT_EQ( + out, + std::vector( + {0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44, + 60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104, + 5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49, + 65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109, + 10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54, + 70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114, + 15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59, + 75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119})); + RunTestPermutation({2, 3, 4, 5}, {0, 1, 2, 3}, &out); + // Basic identity. + std::vector ref(out.size()); + for (int k = 0; k < ref.size(); k++) ref[k] = k; + ASSERT_EQ(out, ref); +} + class TransposeOpModel : public SingleOpModel { public: void SetInput(std::initializer_list data) { diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index e2fc73ba29b5c96ad83536fb8752c11d70191d4d..8c2d0d57c7b794c74d5b48f8d902a69efba89645 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -246,7 +246,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, } // Resize the output and state tensors based on the sizes of the input tensors. -// Allocate a temprory scratch tensor. Also check that the sizes of the input +// Allocate a temporary scratch tensor. Also check that the sizes of the input // tensors match each other. 
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int* scratch_tensor_index = reinterpret_cast(node->user_data); diff --git a/tensorflow/lite/kernels/unpack.cc b/tensorflow/lite/kernels/unpack.cc index 1caffe14f90b8ce9d13d8c781e87bf918c02b9f4..99ad4bb4e817ed435043fb17469381192db843ff 100644 --- a/tensorflow/lite/kernels/unpack.cc +++ b/tensorflow/lite/kernels/unpack.cc @@ -52,9 +52,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); TF_LITE_ENSURE(context, NumDimensions(input) <= 4); TF_LITE_ENSURE(context, NumDimensions(input) > 1); - TF_LITE_ENSURE(context, NumDimensions(input) > data->axis); - // TODO(renjieliu): Support negative axis. - TF_LITE_ENSURE(context, data->axis >= 0); + int axis = data->axis; + if (axis < 0) { + axis += NumDimensions(input); + } + TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input)); if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) { context->ReportError(context, "Currently pack only supports int32 and float32."); @@ -67,12 +69,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1); int o = 0; for (int index = 0; index < NumDimensions(input); ++index) { - if (index != data->axis) { + if (index != axis) { output_shape->data[o++] = input_shape->data[index]; } } - TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[data->axis]); + TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[axis]); for (int i = 0; i < data->num; ++i) { TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape); TfLiteTensor* output = GetOutput(context, node, i); diff --git a/tensorflow/lite/kernels/unpack_test.cc b/tensorflow/lite/kernels/unpack_test.cc index 9b60cce549804a59e343f3e26f978679a1624c00..76f7dff93e395414f0e5a69fe4cef151a7517315 100644 --- a/tensorflow/lite/kernels/unpack_test.cc +++ 
b/tensorflow/lite/kernels/unpack_test.cc @@ -28,14 +28,16 @@ template class UnpackOpModel : public SingleOpModel { public: UnpackOpModel(const TensorData& input, int axis) { - CHECK_LE(axis, input.shape.size()); + if (axis < 0) { + axis += input.shape.size(); + } const int num_outputs = input.shape[axis]; input_ = AddInput(input); for (int i = 0; i < num_outputs; ++i) { outputs_.push_back(AddOutput(input.type)); } SetBuiltinOp(BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions, - CreatePackOptions(builder_, num_outputs, axis).Union()); + CreateUnpackOptions(builder_, num_outputs, axis).Union()); BuildInterpreter({GetShape(input_)}); } @@ -104,6 +106,44 @@ TEST(UnpackOpTest, FloatThreeOutputsAxisOne) { EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6)); } +TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisOne) { + UnpackOpModel model({TensorType_FLOAT32, {3, 2}}, -1); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 2); + EXPECT_THAT(output_shapes[0], ElementsAre(3)); + EXPECT_THAT(output_shapes[1], ElementsAre(3)); + + // Check outputs values. + const std::vector>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 2); + EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5)); + EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6)); +} + +TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisTwo) { + UnpackOpModel model({TensorType_FLOAT32, {3, 2}}, -2); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 3); + EXPECT_THAT(output_shapes[0], ElementsAre(2)); + EXPECT_THAT(output_shapes[1], ElementsAre(2)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + + // Check outputs values. 
+ const std::vector>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 3); + EXPECT_THAT(output_datas[0], ElementsAre(1, 2)); + EXPECT_THAT(output_datas[1], ElementsAre(3, 4)); + EXPECT_THAT(output_datas[2], ElementsAre(5, 6)); +} + TEST(UnpackOpTest, FloatOneOutput) { UnpackOpModel model({TensorType_FLOAT32, {1, 6}}, 0); model.SetInput({1, 2, 3, 4, 5, 6}); diff --git a/tensorflow/lite/kernels/where.cc b/tensorflow/lite/kernels/where.cc new file mode 100644 index 0000000000000000000000000000000000000000..96ee36f08bc0144ce0e4a66d3d7350a791d26d86 --- /dev/null +++ b/tensorflow/lite/kernels/where.cc @@ -0,0 +1,105 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace where { + +constexpr int kInputConditionTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + const TfLiteTensor* cond_tensor, + TfLiteTensor* output_tensor) { + // Output tensor should have shape: + // (num_true, cond_rank), where num_true denotes the number of true values + // in condition. + const RuntimeShape& cond_shape = GetTensorShape(cond_tensor); + const int size = cond_shape.FlatSize(); + const int cond_rank = cond_shape.DimensionsCount(); + const bool* cond_data = GetTensorData(cond_tensor); + + int true_count = 0; + for (int i = 0; i < size; ++i) { + if (cond_data[i]) { + true_count++; + } + } + TfLiteIntArray* output_dims = TfLiteIntArrayCreate(2); + output_dims->data[0] = true_count; + output_dims->data[1] = cond_rank; + return context->ResizeTensor(context, output_tensor, output_dims); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* cond_tensor = + GetInput(context, node, kInputConditionTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (cond_tensor->type != kTfLiteBool) { + context->ReportError(context, + "Condition tensor must be of type bool, but saw '%s'.", + TfLiteTypeGetName(cond_tensor->type)); + return kTfLiteError; + } + + // As output will be a 2D tensor of indices, we use int32 as data type. 
+ output->type = kTfLiteInt32; + + // Exit early if cond is a non-const tensor. Set output tensor to dynamic so + // output size can be determined in Eval. + if (!IsConstantTensor(cond_tensor)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, cond_tensor, output); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* cond_tensor = + GetInput(context, node, kInputConditionTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputTensor(context, cond_tensor, output)); + } + + reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor), + GetTensorData(cond_tensor), + GetTensorData(output)); + return kTfLiteOk; +} +} // namespace where + +TfLiteRegistration* Register_WHERE() { + static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr, + where::Prepare, where::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/where_test.cc b/tensorflow/lite/kernels/where_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..89bd7c43646f80f8b0adb4ef4026f1d9bc7b43c4 --- /dev/null +++ b/tensorflow/lite/kernels/where_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseWhereOpModel : public SingleOpModel { + public: + BaseWhereOpModel(const TensorData& input, const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_WHERE, BuiltinOptions_WhereOptions, + CreateWhereOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + int input() { return input_; } + + protected: + int input_; + int output_; +}; + +class IntegerWhereOpModel : public BaseWhereOpModel { + public: + using BaseWhereOpModel::BaseWhereOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(WhereOpTest, SelectFromVectorNoResult) { + IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {false, false, false}); + m.Invoke(); + EXPECT_THAT(m.GetOutput().size(), 0); +} + +TEST(WhereOpTest, SelectFromVector) { + IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {true, false, true}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2})); +} + +TEST(WhereOpTest, SelectFromMatrixNoResult) { + IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {false, false, false, // + false, false, false, // + false, false, false}); + m.Invoke(); + EXPECT_EQ(m.GetOutput().size(), 0); +} + +TEST(WhereOpTest, SelectFromMatrix1) { + IntegerWhereOpModel m({TensorType_BOOL, {3, 1}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {true, false, true}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, // + 2, 0})); +} + +TEST(WhereOpTest, 
SelectFromMatrix2) { + IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {true, true, false, // + true, false, false, // + true, false, true}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, // + 0, 1, // + 1, 0, // + 2, 0, // + 2, 2})); +} + +TEST(WhereOpTest, SelectFromMatrix3) { + IntegerWhereOpModel m({TensorType_BOOL, {3, 5}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {true, false, false, true, true, // + false, true, true, false, false, // + true, false, true, false, false}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, // + 0, 3, // + 0, 4, // + 1, 1, // + 1, 2, // + 2, 0, // + 2, 2})); +} + +TEST(WhereOpTest, SelectFromRank3TensorNoResult) { + IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {false, false, false, false, // + false, false, false, false}); + m.Invoke(); + EXPECT_EQ(m.GetOutput().size(), 0); +} + +TEST(WhereOpTest, SelectFromRank3Tensor1) { + IntegerWhereOpModel m({TensorType_BOOL, {2, 1, 3}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {true, false, true, // + false, false, true}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, // + 0, 0, 2, // + 1, 0, 2})); +} + +TEST(WhereOpTest, SelectFromRank3Tensor2) { + IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {true, true, false, true, // + false, false, true, true}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, // + 0, 0, 1, // + 0, 1, 1, // + 1, 1, 0, // + 1, 1, 1})); +} + +TEST(WhereOpTest, SelectFromRank3Tensor3) { + IntegerWhereOpModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_INT32, {}}); + m.PopulateTensor(m.input(), {true, true, false, true, false, false, // + false, false, true, false, true, true}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, // + 0, 0, 1, // + 0, 1, 1, // + 1, 1, 
0, // + 1, 2, 0, // + 1, 2, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/while.cc b/tensorflow/lite/kernels/while.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3f00d3fe13c40fb9caa2a168c2b56d9ef1a73c0 --- /dev/null +++ b/tensorflow/lite/kernels/while.cc @@ -0,0 +1,312 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace while_kernel { + +namespace { + +// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph` +// to `dst_tensor_indices` in `dst_subgraph`. 
+template +TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context, + Subgraph* src_subgraph, + const SrcVector& src_tensor_indices, + Subgraph* dst_subgraph, + const DstVector& dst_tensor_indices) { + TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), + dst_tensor_indices.size()); + for (int i = 0; i < src_tensor_indices.size(); ++i) { + const TfLiteTensor* src_tensor = + src_subgraph->tensor(src_tensor_indices[i]); + std::vector dims(src_tensor->dims->data, + src_tensor->dims->data + src_tensor->dims->size); + dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims); + TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); + dst_tensor->type = src_tensor->type; + } + return kTfLiteOk; +} + +// Copy the tensors data from tensors `src_tensor_indices` in `src_subgraph` +// to `dst_tensor_indices` in `dst_subgraph`. +template +TfLiteStatus CopyTensorsData(TfLiteContext* context, Subgraph* src_subgraph, + const SrcVector& src_tensor_indices, + Subgraph* dst_subgraph, + const DstVector& dst_tensor_indices) { + TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), + dst_tensor_indices.size()); + for (int i = 0; i < src_tensor_indices.size(); ++i) { + const TfLiteTensor* src_tensor = + src_subgraph->tensor(src_tensor_indices[i]); + TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); + TF_LITE_ENSURE_EQ(context, src_tensor->bytes, dst_tensor->bytes); + memcpy(dst_tensor->data.raw, src_tensor->data.raw, src_tensor->bytes); + } + return kTfLiteOk; +} + +TfLiteStatus CheckCondOutput(TfLiteContext* context, + const TfLiteTensor* cond_output) { + // The condition output must be a single boolean value. + TF_LITE_ENSURE_EQ(context, cond_output->type, kTfLiteBool); + if (cond_output->dims->size == 0) { + // It's okay if it's a 0D scalar. + return kTfLiteOk; + } + // Otherwise it must be 1D with shape [1]. 
+ TF_LITE_ENSURE_EQ(context, cond_output->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cond_output->dims->data[0], 1); + return kTfLiteOk; +} + +} // namespace + +struct OpData { + int cond_subgraph_index; + int body_subgraph_index; + bool cond_has_dynamic_output_tensors; + bool body_has_dynamic_output_tensors; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + op_data->cond_subgraph_index = m["cond_subgraph_index"].AsInt32(); + op_data->body_subgraph_index = m["body_subgraph_index"].AsInt32(); + op_data->cond_has_dynamic_output_tensors = false; + op_data->body_has_dynamic_output_tensors = false; + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* op_data = reinterpret_cast(node->user_data); + int num_inputs = node->inputs->size; + // The number of outputs should be the same as number of inputs. + TF_LITE_ENSURE_EQ(context, node->outputs->size, num_inputs); + + // Check subgraph indices and get subgraphs. + Subgraph* this_subgraph = reinterpret_cast(context->impl_); + auto* subgraphs = this_subgraph->GetSubgraphs(); + TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size()); + TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size()); + + Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get(); + Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get(); + + // Check input & output count of the condition subgraph. + TF_LITE_ENSURE_EQ(context, cond_subgraph->inputs().size(), num_inputs); + TF_LITE_ENSURE_EQ(context, cond_subgraph->outputs().size(), 1); + + // Check input & output count of the body subgraph. 
+ TF_LITE_ENSURE_EQ(context, body_subgraph->inputs().size(), num_inputs); + TF_LITE_ENSURE_EQ(context, body_subgraph->outputs().size(), num_inputs); + + // Prepare and check the condition subgraph. + TF_LITE_ENSURE_OK( + context, CopyTensorsShapeAndType(context, this_subgraph, + TfLiteIntArrayView(node->inputs), + cond_subgraph, cond_subgraph->inputs())); + TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors()); + TfLiteTensor* cond_output = + cond_subgraph->tensor(cond_subgraph->outputs()[0]); + // TODO(ycling): Handle the case the cond subgraph has dynamic tensor outputs. + // This should rarely happen. In most cases the output is static with shape + // [1]. However theoretically intermediate tensors in the cond subgraph + // can be dynamic. + if (IsDynamicTensor(cond_output)) { + op_data->cond_has_dynamic_output_tensors = true; + } else { + TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output)); + } + + // Prepare and check the body subgraph. + TF_LITE_ENSURE_OK( + context, CopyTensorsShapeAndType(context, this_subgraph, + TfLiteIntArrayView(node->inputs), + body_subgraph, body_subgraph->inputs())); + TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors()); + if (body_subgraph->HasDynamicTensors()) { + op_data->body_has_dynamic_output_tensors = true; + } else { + for (int i = 0; i < num_inputs; ++i) { + TfLiteTensor* body_input = + body_subgraph->tensor(body_subgraph->inputs()[i]); + TfLiteTensor* body_output = + body_subgraph->tensor(body_subgraph->outputs()[i]); + TF_LITE_ENSURE_EQ(context, body_input->type, body_output->type); + + // TODO(ycling): Support dynamic sized body subgraph. + TF_LITE_ENSURE(context, !IsDynamicTensor(body_output)); + if (!TfLiteIntArrayEqual(body_input->dims, body_output->dims)) { + // If the output shape of the body subgraph is static w.r.t. a fixed + // input size, but it's different from input size, it's still considered + // dynamic.
For example: If a subgraph keeps padding its input with a + // fixed padding, the output shape is static w.r.t the input shape and + // padding, but running it in a loop will keep bloating the tensor. + op_data->body_has_dynamic_output_tensors = true; + break; + } + } + } + for (int i = 0; i < num_inputs; ++i) { + TfLiteTensor* output = GetOutput(context, node, i); + if (op_data->body_has_dynamic_output_tensors) { + SetTensorToDynamic(output); + } else { + TfLiteTensor* body_output = + body_subgraph->tensor(body_subgraph->outputs()[i]); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(body_output->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + } + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const OpData* op_data = reinterpret_cast(node->user_data); + Subgraph* this_subgraph = reinterpret_cast(context->impl_); + auto* subgraphs = this_subgraph->GetSubgraphs(); + Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get(); + Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get(); + + // The following graph illustrates the current implementation. + // + // This Subgraph Cond Subgraph Body Subgraph + // +-----------+ (1) +------------+ (3) +------------+ + // | WHILE |-------->| SUBGRAPH |-------->| SUBGRAPH | + // | INPUT | /| INPUT |<----- | INPUT | + // +-----------+ / +------------+ \ +------------+ + // / | \ | + // (6) / | (2) (5) \ | (4) + // / v \ v + // +-----------+ / +------------+ +------------+ + // | WHILE |<-- | SUBGRAPH | | SUBGRAPH | + // | OUTPUT | | OUTPUT | | OUTPUT | + // +-----------+ +------------+ +------------+ + // + // (1) Copy the inputs of WHILE op to the inputs of condition subgraph. + // (2) Invoke condition subgraph. + // Jump to step 6 if result is false. + // (3) Copy the inputs of condition subgraph to the inputs of body subgraph. + // (4) Invoke body subgraph.
+ // (5) Copy the outputs of body subgraph to the inputs of condition subgraph. + // Jump back to step 2! + // (6) Copy the inputs of condition subgraph to the outputs of WHILE op. + // + // If the body subgraph has dynamic sized outputs, it's required to resize the + // tensor before copying in step 1, 3, 4 and 6. + // + // Note the flow is carefully designed to handle the dynamic sized output + // case. The loop invariant is: The newest value is in the inputs of condition + // subgraph. This is always true before step 2. + // + // This is the best we can do without sharing tensor buffer across subgraph + // boundary. Currently we copy the input / output between the subgraphs. This + // isn't optimized yet and a lot of redundant copies are made. + // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs. + TF_LITE_ENSURE_OK( + context, + CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs), + cond_subgraph, cond_subgraph->inputs())); + + while (true) { + TF_LITE_ENSURE_OK(context, cond_subgraph->Invoke()); + int cond_subgraph_output_index = cond_subgraph->outputs()[0]; + cond_subgraph->EnsureTensorDataIsReadable(cond_subgraph_output_index); + TfLiteTensor* cond_output = + cond_subgraph->tensor(cond_subgraph_output_index); + if (op_data->cond_has_dynamic_output_tensors) { + TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output)); + } + + if (!cond_output->data.b[0]) { + break; + } + if (op_data->body_has_dynamic_output_tensors) { + TF_LITE_ENSURE_OK(context, + CopyTensorsShapeAndType( + context, cond_subgraph, cond_subgraph->inputs(), + body_subgraph, body_subgraph->inputs())); + TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors()); + } + + TF_LITE_ENSURE_OK( + context, + CopyTensorsData(context, cond_subgraph, cond_subgraph->inputs(), + body_subgraph, body_subgraph->inputs())); + + TF_LITE_ENSURE_OK(context, body_subgraph->Invoke()); + + for (int tensor_index : body_subgraph->outputs()) { +
body_subgraph->EnsureTensorDataIsReadable(tensor_index); + } + + if (op_data->body_has_dynamic_output_tensors) { + TF_LITE_ENSURE_OK(context, + CopyTensorsShapeAndType( + context, body_subgraph, body_subgraph->outputs(), + cond_subgraph, cond_subgraph->inputs())); + TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors()); + } + + TF_LITE_ENSURE_OK( + context, + CopyTensorsData(context, body_subgraph, body_subgraph->outputs(), + cond_subgraph, cond_subgraph->inputs())); + } + + // Note that copying from body's output will fail if body is never invoked. + // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs. + if (op_data->body_has_dynamic_output_tensors) { + TF_LITE_ENSURE_OK( + context, CopyTensorsShapeAndType(context, cond_subgraph, + cond_subgraph->inputs(), this_subgraph, + TfLiteIntArrayView(node->outputs))); + } + + TF_LITE_ENSURE_OK( + context, + CopyTensorsData(context, cond_subgraph, cond_subgraph->inputs(), + this_subgraph, TfLiteIntArrayView(node->outputs))); + return kTfLiteOk; +} + +} // namespace while_kernel + +TfLiteRegistration* Register_WHILE() { + static TfLiteRegistration r = {while_kernel::Init, while_kernel::Free, + while_kernel::Prepare, while_kernel::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/while_test.cc b/tensorflow/lite/kernels/while_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9946b4a3280116d7cb176f54b94b73bb956a5f71 --- /dev/null +++ b/tensorflow/lite/kernels/while_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/subgraph_test_util.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { + +using subgraph_test_util::CheckIntTensor; +using subgraph_test_util::ControlFlowOpTest; +using subgraph_test_util::FillIntTensor; + +namespace { + +class WhileTest : public ControlFlowOpTest {}; + +// The test builds a model that produces the i-th number of +// triangular number sequence. +// +// TODO(ycling): Consider to improve this test case by adding a +// concat into the body subgraph. 
+TEST_F(WhileTest, TestTriangularNumberSequence) { + const std::vector expected = {1, 3, 6, 10, 15, 21, 28}; + for (int i = 0; i < expected.size(); ++i) { + interpreter_.reset(new Interpreter); + interpreter_->AddSubgraphs(2); + builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), i); + builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2)); + builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output1, {1}, {i + 1}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output2, {1}, {expected[i]}); + } +} + +// This requires dynamic sized subgraphs and it's not supported right now. +// TODO(ycling): Support dynamic sized subgraphs. +TEST_F(WhileTest, TestPadLoop) { + interpreter_.reset(new Interpreter); + interpreter_->AddSubgraphs(2); + builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), 3); + builder_->BuildPadLoopBodySubgraph(interpreter_->subgraph(2), {1, 2}); + builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2}); + // This is not supported yet. The test ensures that it doesn't crash and raises + // an error properly.
+ ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output1, {1}, {4}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output2, {11}, {0, 0, 0, 5, 7, 0, 0, 0, 0, 0, 0}); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/minimal_logging.cc b/tensorflow/lite/minimal_logging.cc new file mode 100644 index 0000000000000000000000000000000000000000..8768ef6e312ec80c9b3653983421e07c662f8e5e --- /dev/null +++ b/tensorflow/lite/minimal_logging.cc @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/minimal_logging.h" + +#include + +namespace tflite { +namespace logging_internal { + +void MinimalLogger::Log(LogSeverity severity, const char* format, ...) 
{ + va_list args; + va_start(args, format); + VLog(severity, format, args); + va_end(args); +} + +const char* MinimalLogger::GetSeverityName(LogSeverity severity) { + switch (severity) { + case TFLITE_LOG_INFO: + return "INFO"; + case TFLITE_LOG_WARNING: + return "WARNING"; + case TFLITE_LOG_ERROR: + return "ERROR"; + default: + return ""; + } +} + +} // namespace logging_internal +} // namespace tflite diff --git a/tensorflow/lite/minimal_logging.h b/tensorflow/lite/minimal_logging.h new file mode 100644 index 0000000000000000000000000000000000000000..7682ed8edc401762613a9cae582fdf1e16e61f51 --- /dev/null +++ b/tensorflow/lite/minimal_logging.h @@ -0,0 +1,56 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MINIMAL_LOGGING_H_ +#define TENSORFLOW_LITE_MINIMAL_LOGGING_H_ + +#include + +namespace tflite { + +enum LogSeverity { + TFLITE_LOG_INFO = 0, + TFLITE_LOG_WARNING = 1, + TFLITE_LOG_ERROR = 2, +}; + +namespace logging_internal { + +// Helper class for simple platform-specific console logging. Note that we +// explicitly avoid the convenience of ostream-style logging to minimize binary +// size impact. +class MinimalLogger { + public: + // Logging hook that takes variadic args. + static void Log(LogSeverity severity, const char* format, ...); + + // Logging hook that takes a formatted va_list. 
+ static void VLog(LogSeverity severity, const char* format, va_list args); + + private: + static const char* GetSeverityName(LogSeverity severity); +}; + +} // namespace logging_internal +} // namespace tflite + +// Convenience macro for basic internal logging in production builds. +// Note: This should never be used for debug-type logs, as it will *not* be +// stripped in release optimized builds. In general, prefer the error reporting +// APIs for developer-facing errors, and only use this for diagnostic output +// that should always be logged in user builds. +#define TFLITE_LOG_PROD(severity, format, ...) \ + tflite::logging_internal::MinimalLogger::Log(severity, format, ##__VA_ARGS__); + +#endif // TENSORFLOW_LITE_MINIMAL_LOGGING_H_ diff --git a/tensorflow/lite/minimal_logging_android.cc b/tensorflow/lite/minimal_logging_android.cc new file mode 100644 index 0000000000000000000000000000000000000000..f87e6fa18e1dac7e2b4e093f6d4e91a1e652ba3d --- /dev/null +++ b/tensorflow/lite/minimal_logging_android.cc @@ -0,0 +1,55 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/minimal_logging.h" + +#include +#include + +namespace tflite { +namespace logging_internal { +namespace { + +int GetPlatformSeverity(LogSeverity severity) { + switch (severity) { + case TFLITE_LOG_INFO: + return ANDROID_LOG_INFO; + case TFLITE_LOG_WARNING: + return ANDROID_LOG_WARN; + case TFLITE_LOG_ERROR: + return ANDROID_LOG_ERROR; + default: + return ANDROID_LOG_DEBUG; + } +} + +} // namespace + +void MinimalLogger::VLog(LogSeverity severity, const char* format, + va_list args) { + // First log to Android's log(cat) API via a copy; `args` is reused below. + va_list args_for_android_log; + va_copy(args_for_android_log, args); + __android_log_vprint(GetPlatformSeverity(severity), "tflite", format, args_for_android_log); + va_end(args_for_android_log); + + // Also print to stderr for standard console applications. + fprintf(stderr, "%s: ", GetSeverityName(severity)); + vfprintf(stderr, format, args); + fputc('\n', stderr); +} + +} // namespace logging_internal +} // namespace tflite diff --git a/tensorflow/lite/minimal_logging_default.cc b/tensorflow/lite/minimal_logging_default.cc new file mode 100644 index 0000000000000000000000000000000000000000..9fa13e47e63a01b5c15ada3b09498fdb755f6376 --- /dev/null +++ b/tensorflow/lite/minimal_logging_default.cc @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/lite/minimal_logging.h" + +#include + +namespace tflite { +namespace logging_internal { + +void MinimalLogger::VLog(LogSeverity severity, const char* format, + va_list args) { + fprintf(stderr, "%s: ", GetSeverityName(severity)); + vfprintf(stderr, format, args); + fputc('\n', stderr); +} + +} // namespace logging_internal +} // namespace tflite diff --git a/tensorflow/lite/minimal_logging_ios.cc b/tensorflow/lite/minimal_logging_ios.cc new file mode 100644 index 0000000000000000000000000000000000000000..a774682a5b42f71d0cc77c49bbcf9a4ec6ef21b7 --- /dev/null +++ b/tensorflow/lite/minimal_logging_ios.cc @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/minimal_logging.h" + +#include +#include + +namespace tflite { +namespace logging_internal { +namespace { + +int GetPlatformSeverity(LogSeverity severity) { + switch (severity) { + case TFLITE_LOG_INFO: + return LOG_INFO; + case TFLITE_LOG_WARNING: + return LOG_WARNING; + case TFLITE_LOG_ERROR: + return LOG_ERR; + default: + return LOG_DEBUG; + } +} + +} // namespace + +void MinimalLogger::VLog(LogSeverity severity, const char* format, + va_list args) { + // TODO(b/123704468): Use os_log when available. 
+ vsyslog(GetPlatformSeverity(severity), format, args); +} + +} // namespace logging_internal +} // namespace tflite diff --git a/tensorflow/lite/minimal_logging_test.cc b/tensorflow/lite/minimal_logging_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e59425a2b264b72d44477c6484fc0ffea014a750 --- /dev/null +++ b/tensorflow/lite/minimal_logging_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/minimal_logging.h" + +#include + +#include + +namespace tflite { + +TEST(MinimalLogging, Basic) { + testing::internal::CaptureStderr(); + TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Foo"); + EXPECT_EQ("INFO: Foo\n", testing::internal::GetCapturedStderr()); +} + +TEST(MinimalLogging, BasicFormatted) { + testing::internal::CaptureStderr(); + TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Foo %s %s", "Bar", "Baz"); + EXPECT_EQ("INFO: Foo Bar Baz\n", testing::internal::GetCapturedStderr()); +} + +TEST(MinimalLogging, Warn) { + testing::internal::CaptureStderr(); + TFLITE_LOG_PROD(TFLITE_LOG_WARNING, "One", ""); + EXPECT_EQ("WARNING: One\n", testing::internal::GetCapturedStderr()); +} + +TEST(MinimalLogging, Error) { + testing::internal::CaptureStderr(); + TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Two"); + EXPECT_EQ("ERROR: Two\n", testing::internal::GetCapturedStderr()); +} + +TEST(MinimalLogging, 
UnknownSeverity) { + testing::internal::CaptureStderr(); + TFLITE_LOG_PROD(static_cast(-1), "Three"); + EXPECT_EQ(": Three\n", + testing::internal::GetCapturedStderr()); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc index 663ee38280ed4d65d9dafb8353dd4746c6da6292..e333138fb78e81316e29a7c37b1fba2df7b1408a 100644 --- a/tensorflow/lite/model.cc +++ b/tensorflow/lite/model.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/lite/allocation.h" #include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/lite/model.h" @@ -245,11 +246,11 @@ class MallocDataAllocator : public BuiltinDataAllocator { TfLiteStatus InterpreterBuilder::ParseNodes( const flatbuffers::Vector>* operators, - Interpreter* interpreter) { + Subgraph* subgraph) { TfLiteStatus status = kTfLiteOk; // Reduce the number of redundant allocations - interpreter->ReserveNodes(operators->Length()); + subgraph->ReserveNodes(operators->Length()); for (int i = 0; i < operators->Length(); ++i) { const auto* op = operators->Get(i); @@ -279,7 +280,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes( } if (op->custom_options()) { - interpreter->AddNodeWithParameters( + subgraph->AddNodeWithParameters( FlatBufferIntArrayToVector(op->inputs()), FlatBufferIntArrayToVector(op->outputs()), reinterpret_cast(op->custom_options()->data()), @@ -289,24 +290,73 @@ TfLiteStatus InterpreterBuilder::ParseNodes( MallocDataAllocator malloc_allocator; TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_, &malloc_allocator, &builtin_data)); - interpreter->AddNodeWithParameters( - FlatBufferIntArrayToVector(op->inputs()), - FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data, - 
registration); + subgraph->AddNodeWithParameters(FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), + nullptr, 0, builtin_data, registration); } } return status; } +TfLiteStatus InterpreterBuilder::ParseQuantization( + const QuantizationParameters* src_quantization, + TfLiteQuantization* quantization) { + quantization->type = kTfLiteNoQuantization; + if (!src_quantization || !src_quantization->scale() || + src_quantization->scale()->size() == 0) { + return kTfLiteOk; + } + if (!src_quantization->zero_point()) { + error_reporter_->Report( + "Quantization parameters has non-null scale but null zero_point."); + return kTfLiteError; + } + + // Ensure that the number of scales matches the number of zero_points. + if (src_quantization->scale()->size() != + src_quantization->zero_point()->size()) { + error_reporter_->Report( + "QuantizationParam has %d zero_point values and %d scale values. Must " + "have same number.", + src_quantization->zero_point()->size(), + src_quantization->scale()->size()); + return kTfLiteError; + } + + // Affine-quantization. + quantization->type = kTfLiteAffineQuantization; + auto* affine_quantization = reinterpret_cast( + malloc(sizeof(TfLiteAffineQuantization))); + const size_t num_scales = src_quantization->scale()->size(); + affine_quantization->scale = TfLiteFloatArrayCreate(num_scales); + affine_quantization->zero_point = TfLiteIntArrayCreate(num_scales); + for (size_t i = 0; i < num_scales; ++i) { + affine_quantization->scale->data[i] = src_quantization->scale()->Get(i); + affine_quantization->zero_point->data[i] = + src_quantization->zero_point()->Get(i); + } + if (src_quantization->quantized_dimension() < 0 || + src_quantization->quantized_dimension() >= num_scales) { + error_reporter_->Report( + "quantized_dimension must be in range [0, %d). 
Was %d.", num_scales, + src_quantization->quantized_dimension()); + return kTfLiteError; + } + affine_quantization->quantized_dimension = + src_quantization->quantized_dimension(); + quantization->params = reinterpret_cast(affine_quantization); + return kTfLiteOk; +} + TfLiteStatus InterpreterBuilder::ParseTensors( const flatbuffers::Vector>* buffers, const flatbuffers::Vector>* tensors, - Interpreter* interpreter) { + Subgraph* subgraph) { TfLiteStatus status = kTfLiteOk; // A little helper to get the names of inputs and outputs. Note that they - // must outlive the interpreter. + // must outlive the subgraph. auto get_name = [](const tflite::Tensor* t) -> const char* { auto name = t->name(); if (name) return name->c_str(); @@ -317,36 +367,11 @@ TfLiteStatus InterpreterBuilder::ParseTensors( const auto* tensor = tensors->Get(i); std::vector dims = FlatBufferIntArrayToVector(tensor->shape()); - TfLiteQuantizationParams quantization; - quantization.scale = 0; - quantization.zero_point = 0; - auto* q_params = tensor->quantization(); - if (q_params) { - // Note that the schema could hold per-channel quantization parameters - // but we really only support one value for the whole tensor. - // TODO(aselle): This breaks as well if these are nullptr's. - // TODO(aselle): This assumes non per-channel quantization. 
- - if (q_params->scale()) { - if (q_params->scale()->size() != 1) { - error_reporter_->Report( - "QuantizationParam has %d scale values (only 1 is supported).", - q_params->scale()->size()); - return kTfLiteError; - } - quantization.scale = q_params->scale()->Get(0); - } - - if (q_params->zero_point()) { - if (q_params->zero_point()->size() != 1) { - error_reporter_->Report( - "QuantizationParam has %d zero_point values" - " (only 1 is supported).", - q_params->zero_point()->size()); - return kTfLiteError; - } - quantization.zero_point = q_params->zero_point()->Get(0); - } + const auto* src_quantization = tensor->quantization(); + TfLiteQuantization quantization; + if (ParseQuantization(src_quantization, &quantization) != kTfLiteOk) { + status = kTfLiteError; + continue; } TfLiteType type; @@ -392,7 +417,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; } - if (interpreter->SetTensorParametersReadOnly( + if (subgraph->SetTensorParametersReadOnly( i, type, get_name(tensor), dims, quantization, buffer_ptr, buffer_size, allocation_) != kTfLiteOk) { error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", @@ -400,9 +425,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; } } else { - if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor), - dims, quantization, - is_variable) != kTfLiteOk) { + if (subgraph->SetTensorParametersReadWrite(i, type, get_name(tensor), + dims, quantization, + is_variable) != kTfLiteOk) { error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", i); status = kTfLiteError; @@ -484,42 +509,56 @@ TfLiteStatus InterpreterBuilder::operator()( // Construct interpreter with correct number of tensors and operators. 
auto* subgraphs = model_->subgraphs(); auto* buffers = model_->buffers(); - if (subgraphs->size() != 1) { - error_reporter_->Report("Only 1 subgraph is currently supported.\n"); - return cleanup_and_error(); - } - const tflite::SubGraph* subgraph = (*subgraphs)[0]; - auto operators = subgraph->operators(); - auto tensors = subgraph->tensors(); - if (!operators || !tensors || !buffers) { - error_reporter_->Report( - "Did not get operators, tensors, or buffers in input flat buffer.\n"); + + if (subgraphs->size() == 0) { + error_reporter_->Report("No subgraph in the model.\n"); return cleanup_and_error(); } + interpreter->reset(new Interpreter(error_reporter_)); - if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) { - return cleanup_and_error(); + (*interpreter)->SetNumThreads(num_threads); + if (subgraphs->Length() > 1) { + (*interpreter)->AddSubgraphs(subgraphs->Length() - 1); } - // Set num threads - (**interpreter).SetNumThreads(num_threads); - // Parse inputs/outputs - (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); - (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs())); - - // Finally setup nodes and tensors - if (ParseNodes(operators, interpreter->get()) != kTfLiteOk) - return cleanup_and_error(); - if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) - return cleanup_and_error(); - std::vector variables; - for (int i = 0; i < (*interpreter)->tensors_size(); ++i) { - auto* tensor = (*interpreter)->tensor(i); - if (tensor->is_variable) { - variables.push_back(i); + for (int subgraph_index = 0; subgraph_index < subgraphs->Length(); + ++subgraph_index) { + const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index]; + tflite::Subgraph* modified_subgraph = + (*interpreter)->subgraph(subgraph_index); + auto operators = subgraph->operators(); + auto tensors = subgraph->tensors(); + if (!operators || !tensors || !buffers) { + error_reporter_->Report( + "Did not get operators, tensors, 
or buffers in subgraph %d.\n", + subgraph_index); + return cleanup_and_error(); + } + if (modified_subgraph->AddTensors(tensors->Length()) != kTfLiteOk) { + return cleanup_and_error(); + } + // Set num threads + // Parse inputs/outputs + modified_subgraph->SetInputs( + FlatBufferIntArrayToVector(subgraph->inputs())); + modified_subgraph->SetOutputs( + FlatBufferIntArrayToVector(subgraph->outputs())); + + // Finally setup nodes and tensors + if (ParseNodes(operators, modified_subgraph) != kTfLiteOk) + return cleanup_and_error(); + if (ParseTensors(buffers, tensors, modified_subgraph) != kTfLiteOk) + return cleanup_and_error(); + + std::vector variables; + for (int i = 0; i < modified_subgraph->tensors_size(); ++i) { + auto* tensor = modified_subgraph->tensor(i); + if (tensor->is_variable) { + variables.push_back(i); + } } + modified_subgraph->SetVariables(std::move(variables)); } - (**interpreter).SetVariables(std::move(variables)); if (ApplyDelegates(interpreter->get()) != kTfLiteOk) return cleanup_and_error(); diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h index 069cefabf91ceceaa6da79fdc8ebbdb31cf9a6d3..bae4229cbab672397392349437e3c1e7e871c5d9 100644 --- a/tensorflow/lite/model.h +++ b/tensorflow/lite/model.h @@ -35,6 +35,7 @@ limitations under the License. 
#define TENSORFLOW_LITE_MODEL_H_ #include +#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/interpreter.h" @@ -197,12 +198,14 @@ class InterpreterBuilder { TfLiteStatus BuildLocalIndexToRegistrationMapping(); TfLiteStatus ParseNodes( const flatbuffers::Vector>* operators, - Interpreter* interpreter); + Subgraph* subgraph); TfLiteStatus ParseTensors( const flatbuffers::Vector>* buffers, const flatbuffers::Vector>* tensors, - Interpreter* interpreter); + Subgraph* subgraph); TfLiteStatus ApplyDelegates(Interpreter* interpreter); + TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization, + TfLiteQuantization* quantization); const ::tflite::Model* model_; const OpResolver& op_resolver_; diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc index e677ea94a71b979a01fd4b56e331d592cef76cd5..d58dbf4d45fefcf787b3349426d794146ac84e52 100644 --- a/tensorflow/lite/model_test.cc +++ b/tensorflow/lite/model_test.cc @@ -87,20 +87,21 @@ TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) { // Make sure currently unsupported # of subgraphs are checked // TODO(aselle): Replace this test when multiple subgraphs are supported. 
-TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) { - auto m1 = FlatBufferModel::BuildFromFile( +TEST(BasicFlatBufferModel, TestZeroSubgraphs) { + auto m = FlatBufferModel::BuildFromFile( "tensorflow/lite/testdata/0_subgraphs.bin"); - ASSERT_TRUE(m1); - std::unique_ptr interpreter1; - ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1), - kTfLiteOk); + ASSERT_TRUE(m); + std::unique_ptr interpreter; + ASSERT_NE(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk); +} - auto m2 = FlatBufferModel::BuildFromFile( +TEST(BasicFlatBufferModel, TestMultipleSubgraphs) { + auto m = FlatBufferModel::BuildFromFile( "tensorflow/lite/testdata/2_subgraphs.bin"); - ASSERT_TRUE(m2); - std::unique_ptr interpreter2; - ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2), - kTfLiteOk); + ASSERT_TRUE(m); + std::unique_ptr interpreter; + ASSERT_EQ(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk); + EXPECT_EQ(interpreter->subgraphs_size(), 2); } // Test what happens if we cannot bind any of the ops. @@ -115,7 +116,7 @@ TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) { ASSERT_EQ(interpreter, nullptr); } -// Make sure model is read to interpreter propelrly +// Make sure model is read to interpreter properly TEST(BasicFlatBufferModel, TestModelInInterpreter) { auto model = FlatBufferModel::BuildFromFile( "tensorflow/lite/testdata/test_model.bin"); diff --git a/tensorflow/lite/models/smartreply/BUILD b/tensorflow/lite/models/smartreply/BUILD index 9b2f0da64c9c47d649216d64e13d99c6a7541aad..100e1d36b2bbf4679cacad3ef0e285d92bc5d467 100644 --- a/tensorflow/lite/models/smartreply/BUILD +++ b/tensorflow/lite/models/smartreply/BUILD @@ -50,6 +50,7 @@ cc_library( ], ) +# TODO(b/118895218): Make this test compatible with oss. 
tf_cc_test( name = "predictor_test", srcs = ["predictor_test.cc"], @@ -57,6 +58,7 @@ tf_cc_test( "//tensorflow/lite/models:testdata/smartreply_samples.tsv", "@tflite_smartreply//:smartreply.tflite", ], + tags = ["no_oss"], deps = [ ":predictor_lib", "//tensorflow/core:test", @@ -67,10 +69,11 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "extract_feature_op_test", size = "small", srcs = ["ops/extract_feature_test.cc"], + tags = ["no_oss"], deps = [ ":custom_ops", "//tensorflow/lite:framework", @@ -81,10 +84,11 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "normalize_op_test", size = "small", srcs = ["ops/normalize_test.cc"], + tags = ["no_oss"], deps = [ ":custom_ops", "//tensorflow/lite:framework", @@ -95,10 +99,11 @@ tf_cc_test( ], ) -tf_cc_test( +cc_test( name = "predict_op_test", size = "small", srcs = ["ops/predict_test.cc"], + tags = ["no_oss"], deps = [ ":custom_ops", "//tensorflow/lite:framework", diff --git a/tensorflow/lite/models/speech_test.cc b/tensorflow/lite/models/speech_test.cc index f3509d1ecedfca4eea58343abdef00188bca5cff..a3713c55312cb7cb6526b7e82606cb949e5c2af4 100644 --- a/tensorflow/lite/models/speech_test.cc +++ b/tensorflow/lite/models/speech_test.cc @@ -139,7 +139,7 @@ TEST_P(SpeechTest, DISABLED_SpeakerIdOkGoogleTest) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, DISABLED_AsrAmTest) { +TEST_P(SpeechTest, AsrAmTest) { std::stringstream os; ASSERT_TRUE( ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv", @@ -152,6 +152,19 @@ TEST_P(SpeechTest, DISABLED_AsrAmTest) { << test_driver.GetErrorMessage(); } +TEST_P(SpeechTest, AsrAmQuantizedTest) { + std::stringstream os; + ASSERT_TRUE(ConvertCsvData( + "speech_asr_am_model_int8.tflite", "speech_asr_am_model_in.csv", + "speech_asr_am_model_int8_out.csv", /*input_tensor=*/"0", + /*output_tensor=*/"104", + /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99", + /*sequence_size=*/320, &os)); + testing::TfLiteDriver test_driver(/*use_nnapi=*/false); + 
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations())) + << test_driver.GetErrorMessage(); +} + // The original version of speech_asr_lm_model_test.cc ran a few sequences // through the interpreter and stored the sum of all the output, which was them // compared for correctness. In this test we are comparing all the intermediate diff --git a/tensorflow/lite/models/testdata/g3doc/README.md b/tensorflow/lite/models/testdata/g3doc/README.md index 2a4f1c143a21722945e8e396b81bd23e3312e87e..afe5f16b383b26efd7aab866c3215a8d2a203f4c 100644 --- a/tensorflow/lite/models/testdata/g3doc/README.md +++ b/tensorflow/lite/models/testdata/g3doc/README.md @@ -3,6 +3,42 @@ Sample test data has been provided for speech related models in Tensorflow Lite to help users working with speech models to verify and test their models. +### Models and Inputs and Outputs: + +[ASR AM model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_asr_am_model.tflite) + +[ASR AM quantized model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_asr_am_model_int8.tflite) + +[ASR AM test inputs](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_asr_am_model_in.csv) + +[ASR AM test outputs](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_asr_am_model_out.csv) + +[ASR AM int8 test outputs](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_asr_am_model_int8_out.csv) + +The models below are not maintained. 
+ +[Speech hotword model (Svdf +rank=1)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank1_2017_11_14.tflite) + +[Speech hotword model (Svdf +rank=2)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank2_2017_11_14.tflite) + +[Speaker-id +model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_speakerid_model_2017_11_14.tflite) + +[TTS +model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_tts_model_2017_11_14.tflite) + +### Test Bench + +[Model tests](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_test.cc) + +Download the ASR AM test models and inputs and output files to the +models/testdata directory to run the tests. + + +## Speech Model Architectures + For the hotword, speaker-id and automatic speech recognition sample models, the architecture assumes that the models receive their input from a speech pre-processing module. The speech pre-processing module receives the audio @@ -87,57 +123,3 @@ The model consists of a convolutional layer, followed by a fully-connected layer, two LSTM layers, and two additional fully-connected layers. The corresponding parameters as shown in the figure. ![endpointer_model](endpointer.svg "Endpointer model") - - -## Speech models test input/output generation - -As mentioned above the input to models are generated from a pre-processing -module (output of a log-mel filterbank, or linguistic features), and the outputs -are generated by running the equivalent TensorFlow model by feeding them the -same input. 
- -## Link to the open source code - -### Models: - -[Speech hotword model (Svdf -rank=1)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank1_2017_11_14.tflite) - -[Speech hotword model (Svdf -rank=2)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank2_2017_11_14.tflite) - -[Speaker-id -model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_speakerid_model_2017_11_14.tflite) - -[TTS -model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_tts_model_2017_11_14.tflite) - -[ASR AM -model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_terse_am_model_2017_11_14.tflite) - -### Test benches - -[Speech hotword model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_hotword_model_test.cc) - -[Speaker-id model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_speakerid_model_test.cc) - -[TTS model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_tts_model_test.cc) - -[ASR AM model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_asr_am_model_test.cc) - -[ASR LM model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_asr_lm_model_test.cc) - -[Endpointer model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_endpointer_model_test.cc) - -## Android Support -The models have been tested on Android phones, using the following tests: - -[Hotword] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/android/BUILD?rcl=172930882&l=25) - -[Speaker-id] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/android/BUILD?rcl=172930882&l=36) diff --git a/tensorflow/lite/nnapi/BUILD b/tensorflow/lite/nnapi/BUILD index 
bd3a8a69af66a1cd2c043e76779db7d3d26cc17e..662754d59bfcd23e9164754c473cb3129a055852 100644 --- a/tensorflow/lite/nnapi/BUILD +++ b/tensorflow/lite/nnapi/BUILD @@ -12,3 +12,40 @@ cc_library( ], linkopts = ["-ldl"], ) + +cc_library( + name = "nnapi_implementation", + srcs = select({ + "//tensorflow:ios": [ + "nnapi_implementation_disabled.cc", + ], + "//tensorflow:windows": [ + "nnapi_implementation_disabled.cc", + ], + "//conditions:default": [ + "nnapi_implementation.cc", + ], + }), + hdrs = [ + "nnapi_implementation.h", + ], + linkopts = ["-ldl"] + select({ + "//tensorflow:android": [], + "//tensorflow:macos": [], + "//tensorflow:ios": [], + "//tensorflow:windows": [], + "//conditions:default": ["-lrt"], + }), + deps = [ + "//tensorflow/lite/nnapi:nnapi_lib", + ], +) + +cc_test( + name = "nnapi_implementation_test", + srcs = ["nnapi_implementation_test.cc"], + deps = [ + "//tensorflow/lite/nnapi:nnapi_implementation", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/nnapi/NeuralNetworksShim.h b/tensorflow/lite/nnapi/NeuralNetworksShim.h index 2ce6e50de6ef4c2b530ef6239f5dde94e68988cb..c48528fa2da5e7992beb9b029d2d112a8e48ba4c 100644 --- a/tensorflow/lite/nnapi/NeuralNetworksShim.h +++ b/tensorflow/lite/nnapi/NeuralNetworksShim.h @@ -22,6 +22,11 @@ limitations under the License. #include "tensorflow/lite/nnapi/NeuralNetworksTypes.h" +// This interface is now deprecated. You should use instead +// nnapi_implementation. + +// TODO(b/123017568): Update all current usages of this file. + // helpers #define NNAPI_LOG(format, ...) fprintf(stderr, format "\n", __VA_ARGS__); @@ -54,7 +59,8 @@ inline int ASharedMemory_create(const char* name, size_t size) { handle != nullptr ? reinterpret_cast( dlsym(handle, "ASharedMemory_create")) : nullptr; - return fn(name, size); + int fd = fn != nullptr ? 
fn(name, size) : -1; + return fd; } inline void* getLibraryHandle() { @@ -250,6 +256,32 @@ inline int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model, EXECUTE_FUNCTION_RETURN(model, index, buffer, length); } +/** + * Sets an operand's per channel quantization parameters. + * + * Sets parameters required by a tensor of type + * {@link ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL}. + * This function must be called for every tensor of type + * {@link ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL} before + * calling {@link ANeuralNetworksModel_finish}. + * + * Available since API level 29. + * + * @param model The model to be modified. + * @param index The index of the model operand we're setting. + * @param channelQuant The per channel quantization parameters for the operand. + * No memory in this struct needs to outlive the call to + * this function. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_setOperandSymmPerChannelQuantParams( + ANeuralNetworksModel* model, int32_t index, + const ANeuralNetworksSymmPerChannelQuantParams* channelQuant) { + LOAD_FUNCTION(ANeuralNetworksModel_setOperandSymmPerChannelQuantParams); + EXECUTE_FUNCTION_RETURN(model, index, channelQuant); +} + /** * Sets an operand to a value stored in a memory object. * @@ -681,6 +713,445 @@ inline void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) { EXECUTE_FUNCTION(event); } +/** + * Get the number of available devices. + * + * @param numDevices Used to return the number of devices. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ +inline int ANeuralNetworks_getDeviceCount(uint32_t* numDevices) { + LOAD_FUNCTION(ANeuralNetworks_getDeviceCount); + EXECUTE_FUNCTION_RETURN(numDevices); +} + +/** + * Get the representation of the specified device. + * + * @param devIndex The index of the specified device. Must be less than the + * number of available devices. 
+ * @param device The representation of the specified device. + * The same representation will always be returned for the + * specified device. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ + +inline int ANeuralNetworks_getDevice(uint32_t devIndex, + ANeuralNetworksDevice** device) { + LOAD_FUNCTION(ANeuralNetworks_getDevice); + EXECUTE_FUNCTION_RETURN(devIndex, device); +} + +/** + * Get the name of the specified device. + * + * @param device The representation of the specified device. + * @param name The returned name of the specified device. The name will be in + * UTF-8 and will be null-terminated. It will be recognizable as a + * known device name rather than a cryptic string. For devices + * with API level 29 and above, the format of the name is + * {VENDOR}-{DEVICE}, e.g. “google-ipu”. For devices with feature + * level 28 or lower, the name will always be “unknown-device”. + * The name will remain valid for the duration of the application. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ +inline int ANeuralNetworksDevice_getName(const ANeuralNetworksDevice* device, + const char** name) { + LOAD_FUNCTION(ANeuralNetworksDevice_getName); + EXECUTE_FUNCTION_RETURN(device, name); +} + +/** + * Get the version of the driver implementation of the specified device. + * + * It’s the responsibility of the driver implementor to ensure that this version + * string uniquely distinguishes this implementation from all previous + * implementations. + * + * This version string must not be confused with the feature level which is + * solely defined by {@link ANeuralNetworksDevice_getFeatureLevel}. There is no + * implicit ordering of the versions. For example, it is not possible to filter + * all drivers older than a certain version. + * + * Application developers may use this version string to avoid or prefer + * specific driver implementations.
For example, an application may want to do + * so because: + * - A specific version of the driver does not provide the required + * performance, perhaps because of a performance regression. + * - A specific version of the driver has a bug or returns results that + * don’t match the minimum precision requirement for the application. + * + * @param device The representation of the specified device. + * @param version The returned version string of the driver for the specified + * device. The string will be in UTF-8 and will be + * null-terminated. For devices with feature level 28 or lower, + * "UNKNOWN" will be returned. The version string will remain + * valid for the duration of the application. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ +inline int ANeuralNetworksDevice_getVersion(const ANeuralNetworksDevice* device, + const char** version) { + LOAD_FUNCTION(ANeuralNetworksDevice_getVersion); + EXECUTE_FUNCTION_RETURN(device, version); +} + +/** + * Get the supported NNAPI version of the specified device. + * + * Each device has a supported feature level, which is the most advanced feature + * this driver implements. For example, if the driver implements the features + * introduced in Android P, but does not implement the features introduced after + * Android P, the value would be 28. Developers could decide whether or not the + * specified device should be used for a Model that has certain feature + * requirements. + * + * @param device The representation of the specified device. + * @param featureLevel The API level of the most advanced feature this driver + * implements. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. 
+ */ +inline int ANeuralNetworksDevice_getFeatureLevel( + const ANeuralNetworksDevice* device, int64_t* featureLevel) { + LOAD_FUNCTION(ANeuralNetworksDevice_getFeatureLevel); + EXECUTE_FUNCTION_RETURN(device, featureLevel); +} + +/** + * Get the supported operations for a specified set of devices. If multiple + * devices are selected, the supported operation list is a union of supported + * operations of all selected devices. + * + * @param model The model to be queried. + * @param devices The set of devices. Must not contain duplicates. + * @param numDevices The number of devices in the set. + * @param supportedOps The boolean array to be filled. True means supported. The + * size of the boolean array must be at least as large as + * the number of operations in the model. The order of + * elements in the supportedOps array matches the order in + * which the corresponding operations were added to the + * model. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ +inline int ANeuralNetworksModel_getSupportedOperationsForDevices( + const ANeuralNetworksModel* model, + const ANeuralNetworksDevice* const* devices, uint32_t numDevices, + bool* supportedOps) { + LOAD_FUNCTION(ANeuralNetworksModel_getSupportedOperationsForDevices); + EXECUTE_FUNCTION_RETURN(model, devices, numDevices, supportedOps); +} + +/** + * Create a {@link ANeuralNetworksCompilation} to compile the given model for a + * specified set of devices. If more than one device is specified, the + * compilation will distribute the workload automatically across the devices. + * The model must be fully supported by the specified set of devices. This means + * that ANeuralNetworksModel_getSupportedOperationsForDevices() must have + * returned true for every operation for that model/devices pair. + * + * @param model The {@link ANeuralNetworksModel} to be compiled. + * @param devices The set of devices. Must not contain duplicates. 
+ * @param numDevices The number of devices in the set. + * @param compilation The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the model is invalid. + * + * Available since API level 29. + */ +inline int ANeuralNetworksCompilation_createForDevices( + ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, + uint32_t numDevices, ANeuralNetworksCompilation** compilation) { + LOAD_FUNCTION(ANeuralNetworksCompilation_createForDevices); + EXECUTE_FUNCTION_RETURN(model, devices, numDevices, compilation); +} + +/** + * Sets the compilation caching signature and the cache directory. + * + * Provides optional caching information to the runtime for faster repeated + * compilation. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param compilation The compilation to be modified. + * @param cacheDir The cache directory to store and retrieve caching data. It is + * recommended to use the code_cache provided by the Android + * runtime. If not using the code_cache, the user should choose + * a directory local to the application, and is responsible to + * manage and clean the cache entries. + * @param token The token provided by the user to specify a model, must be of + * length ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN. The user should + * ensure that the token is unique to a model within the + * application. The NNAPI runtime will not detect token + * collisions. If there is a collision, the compilation outcome may + * be incorrect without notifying with error. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29.
+ */ +inline int ANeuralNetworksCompilation_setCaching( + ANeuralNetworksCompilation* compilation, const char* cacheDir, + const uint8_t* token) { + LOAD_FUNCTION(ANeuralNetworksCompilation_setCaching); + EXECUTE_FUNCTION_RETURN(compilation, cacheDir, token); +} + +/** + * Schedule synchronous evaluation of the execution. + * + *

<p>Schedules synchronous evaluation of the execution. Returns once the + * execution has completed and the outputs are ready to be consumed. + * </p>

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * See {@link ANeuralNetworksExecution_startCompute} for asynchronous execution. + * Synchronous execution incurs lower overhead than asynchronous execution. + * + * Available since API level 29. + * + * @param execution The execution to be scheduled and executed. + * + * @return ANEURALNETWORKS_NO_ERROR if the execution completed normally. + * ANEURALNETWORKS_UNMAPPABLE if the execution input or output memory + * cannot be properly mapped. + */ +inline int ANeuralNetworksExecution_compute( + ANeuralNetworksExecution* execution) { + LOAD_FUNCTION(ANeuralNetworksExecution_compute); + EXECUTE_FUNCTION_RETURN(execution); +} + +/** + * Get the dimensional information of the specified output operand of the model + * of the + * {@link ANeuralNetworksExecution}. + * + * On asynchronous execution initiated by {@link + * ANeuralNetworksExecution_startCompute}, + * {@link ANeuralNetworksEvent_wait} must be called prior to this function to + * recuperate the resources used by the execution. + * + * @param execution The execution to be queried. + * @param index The index of the output argument we are querying. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param rank The rank of the output operand. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, + * ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE if the target output is provided an + * insufficient buffer at execution time, ANEURALNETWORKS_BAD_DATA if the index + * is invalid. + * + * Available since API level 29. 
+ */ +inline int ANeuralNetworksExecution_getOutputOperandRank( + ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank) { + LOAD_FUNCTION(ANeuralNetworksExecution_getOutputOperandRank); + EXECUTE_FUNCTION_RETURN(execution, index, rank); +} + +/** + * Get the dimensional information of the specified output operand of the model + * of the + * {@link ANeuralNetworksExecution}. The target output operand cannot be a + * scalar. + * + * On asynchronous execution initiated by + * {@link ANeuralNetworksExecution_startCompute}, + * {@link ANeuralNetworksEvent_wait} must be called prior to this function to + * recuperate the resources used by the execution. + * + * @param execution The execution to be queried. + * @param index The index of the output argument we are querying. It is an index + * into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with + * {@link ANeuralNetworksModel_addOperand}. + * @param dimensions The dimension array to be filled. The size of the array + * must be exactly as large as the rank of the output operand + * to be queried in the model. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, + * ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE if the target output is provided an + * insufficient buffer at execution time, ANEURALNETWORKS_BAD_DATA if the index + * is invalid or if the target is a scalar. + * + * Available since API level 29. + */ +inline int ANeuralNetworksExecution_getOutputOperandDimensions( + ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions) { + LOAD_FUNCTION(ANeuralNetworksExecution_getOutputOperandDimensions); + EXECUTE_FUNCTION_RETURN(execution, index, dimensions); +} + +/** + * Create a {@link ANeuralNetworksBurst} to apply the given compilation. + * This only creates the burst object. 
Computation is only performed once + * {@link ANeuralNetworksExecution_burstCompute} is invoked with a valid + * {@link ANeuralNetworksExecution} and {@link ANeuralNetworksBurst}. + * + *

<p>The provided compilation must outlive the burst object.</p>

+ * + * Available since API level 29. + * + * @param compilation The {@link ANeuralNetworksCompilation} to be evaluated. + * @param burst The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the compilation is invalid. + */ +inline int ANeuralNetworksBurst_create(ANeuralNetworksCompilation* compilation, + ANeuralNetworksBurst** burst) { + LOAD_FUNCTION(ANeuralNetworksBurst_create); + EXECUTE_FUNCTION_RETURN(compilation, burst); +} + +/** + * Destroys the burst object. + * + * Available since API level 29. + * + * @param burst The burst object to be destroyed. Passing NULL is acceptable and + * results in no operation. + */ +inline void ANeuralNetworksBurst_free(ANeuralNetworksBurst* burst) { + LOAD_FUNCTION(ANeuralNetworksBurst_free); + EXECUTE_FUNCTION(burst); +} + +/** + * Schedule synchronous evaluation of the execution on a burst object. + * + *

<p>Schedules synchronous evaluation of the execution. Returns once the + * execution has completed and the outputs are ready to be consumed.</p>

+ * + *

<p>There must be at most one {@link ANeuralNetworksExecution} processing at + * any given time for any given burst object. Any + * {@link ANeuralNetworksExecution} launched before the previous has finished + * will result in ANEURALNETWORKS_BAD_STATE.</p>

+ * + * Available since API level 29. + * + * @param burst The burst object to execute on. + * @param execution The execution to be scheduled and executed. The execution + * must be created from the same {@link + * ANeuralNetworksCompilation} as the burst object. + * + * @return ANEURALNETWORKS_NO_ERROR if the execution completed normally. + */ +inline int ANeuralNetworksExecution_burstCompute( + ANeuralNetworksExecution* execution, ANeuralNetworksBurst* burst) { + LOAD_FUNCTION(ANeuralNetworksExecution_burstCompute); + EXECUTE_FUNCTION_RETURN(execution, burst); +} + +/** + * Creates a shared memory object from an AHardwareBuffer handle. + * + * If the shared memory is backed by an AHardwareBuffer of + * AHARDWAREBUFFER_FORMAT_BLOB format, it can be used the same way as shared + * memory created from a file handle. See + * {@link ANeuralNetworksMemory} for a description on how to use this shared + * memory. + * + * If the shared memory is backed by an AHardwareBuffer of a format other than + * AHARDWAREBUFFER_FORMAT_BLOB, it can only be used for Model inputs and + * outputs. When calling {@link ANeuralNetworksExecution_setInputFromMemory} or + * {@link ANeuralNetworksExecution_setOutputFromMemory} with the shared memory, + * both offset and length must be set to zero and the entire memory region will + * be associated with the specified input or output operand. There is no + * guarantee that an arbitrary AHardwareBuffer_Format and + * AHardwareBuffer_UsageFlags combination can be used by arbitrary devices. The + * execution will fail if selected set of devices cannot consume the buffer. + * + * Calling {@link ANeuralNetworksModel_setOperandValueFromMemory} with shared + * memory backed by an AHardwareBuffer of a format other than + * AHARDWAREBUFFER_FORMAT_BLOB is disallowed. + * + * TODO(miaowang): add documentation about intended usage with introspection + * API. + * + * Available since API level 29. + * + * @param ahwb The AHardwareBuffer handle. 
+ * @param memory The memory object to be created. + * Set to NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if the request completed normally. + * + * @see AHardwareBuffer + */ +inline int ANeuralNetworksMemory_createFromAHardwareBuffer( + const AHardwareBuffer* ahwb, ANeuralNetworksMemory** memory) { + LOAD_FUNCTION(ANeuralNetworksMemory_createFromAHardwareBuffer); + EXECUTE_FUNCTION_RETURN(ahwb, memory); +} + +/** + * Specifies whether duration of the {@link ANeuralNetworksExecution} is to be + * measured. By default, duration is not measured. + * + * The {@link ANeuralNetworksExecution} must have been created with + * {@link ANeuralNetworksCompilation_createForDevices} with numDevices = 1. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * Available since API level 29. + * + * @param execution The execution to be modified. + * @param measure 'true' if duration is to be measured, 'false' if not. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksExecution_setMeasureTiming( + ANeuralNetworksExecution* execution, bool measure) { + LOAD_FUNCTION(ANeuralNetworksExecution_setMeasureTiming); + EXECUTE_FUNCTION_RETURN(execution, measure); +} + +/** + * Get the time spent in the specified {@link ANeuralNetworksExecution}, in + * nanoseconds. The execution must have completed. + * + * @param execution The execution to be queried. + * @param durationCode The measurement to be queried, specified by {@link + * DurationCode}. + * @param duration The returned duration. If no measurement was requested by + * {@link ANeuralNetworksExecution_setMeasureTiming}, or for + * some other reason the duration is not available, UINT64_MAX will be returned. + * A particular device need not support any given measurement. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. 
+ */ +inline int ANeuralNetworksExecution_getDuration( + const ANeuralNetworksExecution* execution, int32_t durationCode, + uint64_t* duration) { + LOAD_FUNCTION(ANeuralNetworksExecution_getDuration); + EXECUTE_FUNCTION_RETURN(execution, durationCode, duration); +} + /**/ #endif // TENSORFLOW_LITE_NNAPI_NEURALNETWORKSSHIM_H_ diff --git a/tensorflow/lite/nnapi/NeuralNetworksTypes.h b/tensorflow/lite/nnapi/NeuralNetworksTypes.h index de8b84a8234340cda3c1ae5942c1863a09e3c228..109c6b001438a7963d13acaaa6da455b02cdac27 100644 --- a/tensorflow/lite/nnapi/NeuralNetworksTypes.h +++ b/tensorflow/lite/nnapi/NeuralNetworksTypes.h @@ -18,6 +18,8 @@ limitations under the License. #include #include +typedef struct AHardwareBuffer AHardwareBuffer; + // NN api types based on NNAPI header file // https://developer.android.com/ndk/reference/group/neural-networks @@ -37,6 +39,7 @@ enum { ANEURALNETWORKS_TENSOR_FLOAT32 = 3, ANEURALNETWORKS_TENSOR_INT32 = 4, ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5, + ANEURALNETWORKS_TENSOR_QUANT8_SYMM = 13, }; /** @@ -115,8 +118,10 @@ enum { ANEURALNETWORKS_UNEXPECTED_NULL = 3, ANEURALNETWORKS_BAD_DATA = 4, ANEURALNETWORKS_OP_FAILED = 5, - ANEURALNETWORKS_UNMAPPABLE = 5, ANEURALNETWORKS_BAD_STATE = 6, + ANEURALNETWORKS_UNMAPPABLE = 7, + ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE = 8, + ANEURALNETWORKS_UNAVAILABLE_DEVICE = 9, }; /** @@ -127,6 +132,25 @@ enum { ANEURALNETWORKS_PADDING_VALID = 2, }; +/** + * Device types. + * + * The type of NNAPI device. + */ +enum { + /** The device type cannot be provided. */ + ANEURALNETWORKS_DEVICE_UNKNOWN = 0, + /** The device does not fall into any category below. */ + ANEURALNETWORKS_DEVICE_OTHER = 1, + /** The device runs NNAPI models on single or multi-core CPU. */ + ANEURALNETWORKS_DEVICE_CPU = 2, + /** The device can run NNAPI models and also accelerate graphics APIs such + * as OpenGL ES and Vulkan. */ + ANEURALNETWORKS_DEVICE_GPU = 3, + /** Dedicated accelerator for Machine Learning workloads. 
*/ + ANEURALNETWORKS_DEVICE_ACCELERATOR = 4, +}; + /** * ANeuralNetworksMemory is an opaque type that represents memory. * @@ -239,6 +263,53 @@ typedef struct ANeuralNetworksCompilation ANeuralNetworksCompilation; */ typedef struct ANeuralNetworksExecution ANeuralNetworksExecution; +/** + * Parameters for ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL operand. + */ +typedef struct ANeuralNetworksSymmPerChannelQuantParams { + /* The index of the channel dimension. */ + uint32_t channelDim; + /** The size of the scale array. Should be equal to dimension[channelDim] of + * the Operand. */ + uint32_t scaleCount; + /** The array of scaling values for each channel. Each value must be greater + * than zero. */ + const float* scales; +} ANeuralNetworksSymmPerChannelQuantParams; + +/** + * ANeuralNetworksBurst is an opaque type that can be used to reduce the latency + * of a rapid sequence of executions. It will likely cause overhead if only used + * for a single execution. + * + * ANeuralNetworksBurst serves as a context object for any number of inferences + * using {@link ANeuralNetworksExecution} objects. An ANeuralNetworksBurst + * object and the {@link ANeuralNetworksExecution} objects used with it must all + * have been created from the same {@link ANeuralNetworksCompilation} object. + * + * This object is also used as a hint to drivers, providing insight to the + * lifetime of a rapid sequence of executions. For example, a driver may choose + * to increase the clock frequency of its accelerator for the lifetime of a + * burst object. + * + *

To use:

    + *
  • Create a new burst object by calling the + * {@link ANeuralNetworksBurst_create} function.
  • + *
  • For each execution:
    • + *
    • Create {@link ANeuralNetworksExecution} and configure its + * properties (see {@link ANeuralNetworksExecution} for + * details).
    • Apply the model synchronously with + * {@link ANeuralNetworksExecution_burstCompute}, reusing the same + * {@link ANeuralNetworksBurst} with the new + * {@link ANeuralNetworksExecution}.
    • + *
    • Use and free the {@link ANeuralNetworksExecution}.
    + *
  • Destroy the burst with + * {@link ANeuralNetworksBurst_free}.

+ * + * Available since API level 29. + */ +typedef struct ANeuralNetworksBurst ANeuralNetworksBurst; + /** * ANeuralNetworksOperandType describes the type of an operand. * This structure is used to describe both scalars and tensors. @@ -266,6 +337,16 @@ typedef struct ANeuralNetworksEvent ANeuralNetworksEvent; typedef int32_t ANeuralNetworksOperationType; +/** + * ANeuralNetworksDevice is an opaque type that represents a device. + * + * This type is used to query basic properties and supported operations of the + * corresponding device, and control which device(s) a model is to be run on. + * + * Available since API level 29. + */ +typedef struct ANeuralNetworksDevice ANeuralNetworksDevice; + // nn api function types typedef int (*ANeuralNetworksMemory_createFromFd_fn)( @@ -299,6 +380,10 @@ typedef int (*ANeuralNetworksModel_setOperandValue_fn)( ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length); +typedef int (*ANeuralNetworksModel_setOperandSymmPerChannelQuantParams_fn)( + ANeuralNetworksModel* model, int32_t index, + const ANeuralNetworksSymmPerChannelQuantParams* channelQuant); + typedef int (*ANeuralNetworksModel_setOperandValueFromMemory_fn)( ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length); @@ -349,4 +434,70 @@ typedef void (*ANeuralNetworksEvent_free_fn)(ANeuralNetworksEvent* event); typedef int (*ASharedMemory_create_fn)(const char* name, size_t size); +typedef int (*ANeuralNetworks_getDeviceCount_fn)(uint32_t* numDevices); + +typedef int (*ANeuralNetworks_getDevice_fn)(uint32_t devIndex, + ANeuralNetworksDevice** device); + +typedef int (*ANeuralNetworksDevice_getName_fn)( + const ANeuralNetworksDevice* device, const char** name); + +typedef int (*ANeuralNetworksDevice_getType_fn)( + const ANeuralNetworksDevice* device, int32_t* type); + +typedef int (*ANeuralNetworksDevice_getVersion_fn)( + const ANeuralNetworksDevice* device, const char** version); + +typedef int 
(*ANeuralNetworksDevice_getFeatureLevel_fn)( + const ANeuralNetworksDevice* device, int64_t* featureLevel); + +typedef int (*ANeuralNetworksModel_getSupportedOperationsForDevices_fn)( + const ANeuralNetworksModel* model, + const ANeuralNetworksDevice* const* devices, uint32_t numDevices, + bool* supportedOps); + +typedef int (*ANeuralNetworksCompilation_createForDevices_fn)( + ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, + uint32_t numDevices, ANeuralNetworksCompilation** compilation); + +typedef int (*ANeuralNetworksCompilation_setCaching_fn)( + ANeuralNetworksCompilation* compilation, const char* cacheDir, + const uint8_t* token); + +typedef int (*ANeuralNetworksExecution_compute_fn)( + ANeuralNetworksExecution* execution); + +typedef int (*ANeuralNetworksExecution_getOutputOperandRank_fn)( + ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank); + +typedef int (*ANeuralNetworksExecution_getOutputOperandDimensions_fn)( + ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions); + +typedef int (*ANeuralNetworksBurst_create_fn)( + ANeuralNetworksCompilation* compilation, ANeuralNetworksBurst** burst); + +typedef void (*ANeuralNetworksBurst_free_fn)(ANeuralNetworksBurst* burst); + +typedef int (*ANeuralNetworksExecution_burstCompute_fn)( + ANeuralNetworksExecution* execution, ANeuralNetworksBurst* burst); + +typedef int (*ANeuralNetworksMemory_createFromAHardwareBuffer_fn)( + const AHardwareBuffer* ahwb, ANeuralNetworksMemory** memory); + +typedef int (*ANeuralNetworksExecution_setMeasureTiming_fn)( + ANeuralNetworksExecution* execution, bool measure); + +typedef enum { + // Execution time on hardware (not driver, which runs on host processor). + ANEURALNETWORKS_DURATION_ON_HARDWARE = 0, + // Execution time in driver (including time on hardware). Excludes overhead + // such as that of the runtime itself and the IPC needed for the runtime to + // communicate with the driver. 
+ ANEURALNETWORKS_DURATION_IN_DRIVER = 1, +} DurationCode; + +typedef int (*ANeuralNetworksExecution_getDuration_fn)( + const ANeuralNetworksExecution* execution, int32_t durationCode, + uint64_t* duration); + #endif // TENSORFLOW_LITE_NNAPI_NEURALNETWORKSTYPES_H_ diff --git a/tensorflow/lite/nnapi/nnapi_implementation.cc b/tensorflow/lite/nnapi/nnapi_implementation.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b3de3ca5946f0f96e8da06ada26b4e0c46d0dce --- /dev/null +++ b/tensorflow/lite/nnapi/nnapi_implementation.cc @@ -0,0 +1,203 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/nnapi/nnapi_implementation.h" + +#include +#include +#include +#include +#include + +#include + +#ifdef __ANDROID__ +#include +#endif // __ANDROID__ + +#define NNAPI_LOG(format, ...) 
fprintf(stderr, format "\n", __VA_ARGS__); + +namespace { + +#ifdef __ANDROID__ +int32_t GetAndroidSdkVersion() { + const char* sdkProp = "ro.build.version.sdk"; + char sdkVersion[PROP_VALUE_MAX]; + int length = __system_property_get(sdkProp, sdkVersion); + if (length != 0) { + int32_t result = 0; + for (int i = 0; i < length; ++i) { + int digit = sdkVersion[i] - '0'; + if (digit < 0 || digit > 9) { + // Non-numeric SDK version, assume it's higher than expected; + return 0xffff; + } + result = result * 10 + digit; + } + // TODO(levp): remove once SDK gets updated to 29th level + // Upgrade SDK version for pre-release Q to be able to test functionality + // available from SDK level 29. + if (result == 28) { + char versionCodename[PROP_VALUE_MAX]; + const char* versionCodenameProp = "ro.build.version.codename"; + length = __system_property_get(versionCodenameProp, versionCodename); + if (length != 0) { + if (versionCodename[0] == 'Q') { + return 29; + } + } + } + return result; + } + return 0; +} +#endif // __ANDROID__ + +void* LoadFunction(void* handle, const char* name, bool optional) { + if (handle == nullptr) { + return nullptr; + } + void* fn = dlsym(handle, name); + if (fn == nullptr && !optional) { + NNAPI_LOG("nnapi error: unable to open function %s", name); + } + return fn; +} + +#ifndef __ANDROID__ +// Add /dev/shm implementation of shared memory for non-Android platforms +int ASharedMemory_create(const char* name, size_t size) { + int fd = shm_open(name, O_RDWR | O_CREAT, 0644); + if (fd < 0) { + return fd; + } + int result = ftruncate(fd, size); + if (result < 0) { + close(fd); + return -1; + } + return fd; +} +#endif // __ANDROID__ + +#define LOAD_FUNCTION(handle, name) \ + nnapi.name = reinterpret_cast( \ + LoadFunction(handle, #name, /*optional*/ false)); + +#define LOAD_FUNCTION_OPTIONAL(handle, name) \ + nnapi.name = reinterpret_cast( \ + LoadFunction(handle, #name, /*optional*/ true)); + +const NnApi LoadNnApi() { + NnApi nnapi = {}; + 
nnapi.android_sdk_version = 0; + +#ifdef __ANDROID__ + void* libandroid = nullptr; + nnapi.android_sdk_version = GetAndroidSdkVersion(); + if (nnapi.android_sdk_version < 27) { + NNAPI_LOG("nnapi error: requires android sdk version to be at least %d", + 27); + nnapi.nnapi_exists = false; + return nnapi; + } + libandroid = dlopen("libandroid.so", RTLD_LAZY | RTLD_LOCAL); + if (libandroid == nullptr) { + NNAPI_LOG("nnapi error: unable to open library %s", "libandroid.so"); + } +#endif // __ANDROID__ + + void* libneuralnetworks = nullptr; + // TODO(b/123243014): change RTLD_LOCAL? Assumes there can be multiple + // instances of nn api RT + libneuralnetworks = dlopen("libneuralnetworks.so", RTLD_LAZY | RTLD_LOCAL); + if (libneuralnetworks == nullptr) { + NNAPI_LOG("nnapi error: unable to open library %s", "libneuralnetworks.so"); + } + + nnapi.nnapi_exists = libneuralnetworks != nullptr; + + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksMemory_createFromFd); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksMemory_free); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksModel_create); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksModel_free); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksModel_finish); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksModel_addOperand); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksModel_setOperandValue); + LOAD_FUNCTION_OPTIONAL( + libneuralnetworks, + ANeuralNetworksModel_setOperandSymmPerChannelQuantParams); + LOAD_FUNCTION(libneuralnetworks, + ANeuralNetworksModel_setOperandValueFromMemory); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksModel_addOperation); + LOAD_FUNCTION(libneuralnetworks, + ANeuralNetworksModel_identifyInputsAndOutputs); + LOAD_FUNCTION(libneuralnetworks, + ANeuralNetworksModel_relaxComputationFloat32toFloat16); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksCompilation_create); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksCompilation_free); + LOAD_FUNCTION(libneuralnetworks, 
ANeuralNetworksCompilation_setPreference); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksCompilation_finish); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksExecution_create); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksExecution_free); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksExecution_setInput); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksExecution_setInputFromMemory); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksExecution_setOutput); + LOAD_FUNCTION(libneuralnetworks, + ANeuralNetworksExecution_setOutputFromMemory); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksExecution_startCompute); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksEvent_wait); + LOAD_FUNCTION(libneuralnetworks, ANeuralNetworksEvent_free); +#ifdef __ANDROID__ + LOAD_FUNCTION(libandroid, ASharedMemory_create); +#else + nnapi.ASharedMemory_create = ASharedMemory_create; +#endif // __ANDROID__ + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, ANeuralNetworks_getDeviceCount); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, ANeuralNetworks_getDevice); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, ANeuralNetworksDevice_getName); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, ANeuralNetworksDevice_getVersion); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksDevice_getFeatureLevel); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, ANeuralNetworksDevice_getType); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksModel_getSupportedOperationsForDevices); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksCompilation_createForDevices); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksCompilation_setCaching); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, ANeuralNetworksExecution_compute); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksExecution_getOutputOperandRank); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksExecution_getOutputOperandDimensions); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, 
ANeuralNetworksBurst_create); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, ANeuralNetworksBurst_free); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksExecution_burstCompute); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksMemory_createFromAHardwareBuffer); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksExecution_setMeasureTiming); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksExecution_getDuration); + return nnapi; +} + +} // namespace + +const NnApi* NnApiImplementation() { + static const NnApi nnapi = LoadNnApi(); + return &nnapi; +} diff --git a/tensorflow/lite/nnapi/nnapi_implementation.h b/tensorflow/lite/nnapi/nnapi_implementation.h new file mode 100644 index 0000000000000000000000000000000000000000..b42c189d523dd9f5a5d014dc318d3e8b46936fde --- /dev/null +++ b/tensorflow/lite/nnapi/nnapi_implementation.h @@ -0,0 +1,1017 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_NNAPI_NNAPI_IMPLEMENTATION_H_ +#define TENSORFLOW_LITE_NNAPI_NNAPI_IMPLEMENTATION_H_ + +#include +#include +#include + +#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h" + +struct NnApi { + bool nnapi_exists; + int32_t android_sdk_version; + + /** + * Creates a shared memory object from a file descriptor. + * + * The shared memory is backed by a file descriptor via mmap. 
+ * See {@link ANeuralNetworksMemory} for a description on how to use + * this shared memory. + * + * @param size The requested size in bytes. + * Must not be larger than the file size. + * @param prot The desired memory protection for the mapping. + * It is either PROT_NONE or the bitwise OR of one or + * more of the following flags: PROT_READ, PROT_WRITE. + * @param fd The requested file descriptor. + * The file descriptor has to be mmap-able. The file + * descriptor will be duplicated. + * @param offset The offset to the beginning of the file of the area to map. + * The offset has to be aligned to a page size. + * @param memory The memory object to be created. + * Set to NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if the request completed normally. + */ + int (*ANeuralNetworksMemory_createFromFd)(size_t size, int protect, int fd, + size_t offset, + ANeuralNetworksMemory** memory); + + /** + * Delete a memory object. + * + * Destroys the object used by the run time to keep track of the memory. + * This will free the underlying actual memory if no other code has open + * handles to this memory. + * + * @param memory The memory object to be freed. + */ + void (*ANeuralNetworksMemory_free)(ANeuralNetworksMemory* memory); + + /** + * Create an empty {@link ANeuralNetworksModel}. + * + *

This only creates the object. Computation is performed once + * {@link ANeuralNetworksExecution_startCompute} is invoked. + * + * The model should be constructed with calls to + * {@link ANeuralNetworksModel_addOperation} and + * {@link ANeuralNetworksModel_addOperand} + * + *

{@link ANeuralNetworksModel_finish} should be called once the model + * has been fully constructed.

+ * + *

{@link ANeuralNetworksModel_free} should be called once the model + * is no longer needed.

+ * + * @param model The {@link ANeuralNetworksModel} to be created. + * Set to NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksModel_create)(ANeuralNetworksModel** model); + + /** + * Destroy a model. + * + * The model need not have been finished by a call to + * {@link ANeuralNetworksModel_finish}. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be destroyed. Passing NULL is acceptable and + * results in no operation. + */ + void (*ANeuralNetworksModel_free)(ANeuralNetworksModel* model); + + /** + * Indicate that we have finished modifying a model. Required before + * calling {@link ANeuralNetworksCompilation_create}. + * + * An application is responsible to make sure that no other thread uses + * the model at the same time. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be finished. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksModel_finish)(ANeuralNetworksModel* model); + + /** + * Add an operand to a model. + * + * The order in which the operands are added is important. The first one added + * to a model will have the index value 0, the second 1, etc. These indexes + * are used as operand identifiers in + * {@link ANeuralNetworksModel_addOperation}, + * {@link ANeuralNetworksExecution_setInput}, + * {@link ANeuralNetworksExecution_setInputFromMemory}, + * {@link ANeuralNetworksExecution_setOutput}, + * {@link ANeuralNetworksExecution_setOutputFromMemory} and + * {@link ANeuralNetworksModel_setOperandValue}. + * + * To build a model that can accommodate inputs of various sizes, as you may + * want to do for a CNN, set the size of the dimensions that will vary at run + * time to 0.
If you do so, provide the full dimensions when calling + * {@link ANeuralNetworksExecution_setInput} or {@link + * ANeuralNetworksExecution_setInputFromMemory}. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be modified. + * @param type The {@link ANeuralNetworksOperandType} that describes the shape + * of the operand. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksModel_addOperand)( + ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type); + + /** + * Sets an operand to a constant value. + * + * For scalar values, the content of buffer is copied into the model. + * + * For tensor values, a pointer to the buffer is stored within the model. + * The application is responsible for not changing the content of this region + * until all executions using this model have completed. As the data may + * be copied during processing, modifying the data after this call yields + * undefined results. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be modified. + * @param index The index of the model operand we're setting. + * @param buffer A pointer to the data to use. + * @param length The size in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksModel_setOperandValue)(ANeuralNetworksModel* model, + int32_t index, const void* buffer, + size_t length); + + /** + * Sets an operand's per channel quantization parameters. + * + * Sets parameters required by a tensor of type + * {@link ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL}. 
+ * This function must be called for every tensor of type + * {@link ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL} before + * calling {@link ANeuralNetworksModel_finish}. + * + * Available since API level 29. + * + * @param model The model to be modified. + * @param index The index of the model operand we're setting. + * @param channelQuant The per channel quantization parameters for the + * operand. No memory in this struct needs to outlive the + * call to this function. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksModel_setOperandSymmPerChannelQuantParams)( + ANeuralNetworksModel* model, int32_t index, + const ANeuralNetworksSymmPerChannelQuantParams* channelQuant); + + /** + * Sets an operand to a value stored in a memory object. + * + * The content of the memory is not copied. A reference to that memory is + * stored inside the model. The application is responsible for not changing + * the content of the memory region until all executions using this model have + * completed. + * As the data may be copied during processing, modifying the data after this + * call yields undefined results. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be modified. + * @param index The index of the model operand we're setting. + * @param buffer A pointer to the data to use. + * @param memory The memory containing the data. + * @param offset This specifies the location of the data within the memory. + * The offset is in bytes from the start of memory. + * @param length The size in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. 
+ */ + int (*ANeuralNetworksModel_setOperandValueFromMemory)( + ANeuralNetworksModel* model, int32_t index, + const ANeuralNetworksMemory* memory, size_t offset, size_t length); + + /** + * Add an operation to a model. + * + * @param model The model to be modified. + * @param type The type of the operation. + * @param inputCount The number of entries in the inputs array. + * @param inputs An array of indexes identifying each operand. + * @param outputCount The number of entries in the outputs array. + * @param outputs An array of indexes identifying each operand. + * + * The operands specified by inputs and outputs must have been + * previously added by calls to {@link ANeuralNetworksModel_addOperand}. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksModel_addOperation)(ANeuralNetworksModel* model, + ANeuralNetworksOperationType type, + uint32_t inputCount, + const uint32_t* inputs, + uint32_t outputCount, + const uint32_t* outputs); + + /** + * Specifies which operands will be the model's inputs and outputs. + * + * An operand cannot be used for both input and output. Doing so will + * return an error. + * + * @param model The model to be modified. + * @param inputCount The number of entries in the inputs array. + * @param inputs An array of indexes identifying the input operands. + * @param outputCount The number of entries in the outputs array. + * @param outputs An array of indexes identifying the output operands. + * + * The operands specified by inputs and outputs must have been + * previously added by calls to {@link ANeuralNetworksModel_addOperand}. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. 
+ * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + */ + int (*ANeuralNetworksModel_identifyInputsAndOutputs)( + ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, + uint32_t outputCount, const uint32_t* outputs); + + /** + * Specifies whether {@link ANEURALNETWORKS_TENSOR_FLOAT32} is allowed to be + * calculated with range and/or precision as low as that of the + * IEEE 754 16-bit floating-point format. By default, + * {@link ANEURALNETWORKS_TENSOR_FLOAT32} must be calculated using at least + * the range and precision of the IEEE 754 32-bit floating-point format. + * + * @param model The model to be modified. + * @param allow 'true' indicates {@link ANEURALNETWORKS_TENSOR_FLOAT32} may be + * calculated with range and/or precision as low as that of the + * IEEE 754 16-bit floating point format. 'false' indicates + * {@link ANEURALNETWORKS_TENSOR_FLOAT32} must be calculated + * using at least the range and precision of the IEEE 754 32-bit + * floating point format. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * Available since API level 28. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + */ + int (*ANeuralNetworksModel_relaxComputationFloat32toFloat16)( + ANeuralNetworksModel* model, bool allow); + + /** + * Create a {@link ANeuralNetworksCompilation} to compile the given model. + * This only creates the object. Compilation is only performed once + * {@link ANeuralNetworksCompilation_finish} is invoked. + * + *

The provided model must outlive the compilation.

+ * + * The model must already have been finished by a call to + * {@link ANeuralNetworksModel_finish}. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param model The {@link ANeuralNetworksModel} to be compiled. + * @param compilation The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the model is invalid. + */ + int (*ANeuralNetworksCompilation_create)( + ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation); + + /** + * Destroy a compilation. + * + *

If called on a compilation for which + * {@link ANeuralNetworksCompilation_start} has been called, the + * function will return immediately but will mark the compilation to be + * deleted once the compilation completes. The + * {@link ANeuralNetworksCompilation_wait} will return ERROR_DELETED. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param compilation The compilation to be destroyed. Passing NULL is + * acceptable and results in no operation. + */ + void (*ANeuralNetworksCompilation_free)( + ANeuralNetworksCompilation* compilation); + + /** + * Sets the execution preference. + * + *

Provides guidance to the runtime when trade-offs are possible.

+ * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param compilation The compilation to be modified. + * @param preference Either {@link PREFER_LOW_POWER}, + * {@link PREFER_SINGLE_FAST_ANSWER}, or + * {@link PREFER_SUSTAINED_SPEED}. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksCompilation_setPreference)( + ANeuralNetworksCompilation* compilation, int32_t preference); + + /** + * Waits until the compilation completes. + * + * More than one thread can wait on a compilation. When the compilation + * completes, all threads will be released. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @return ANEURALNETWORKS_NO_ERROR if the compilation completed normally. + */ + int (*ANeuralNetworksCompilation_finish)( + ANeuralNetworksCompilation* compilation); + + /** + * Create a {@link ANeuralNetworksExecution} to apply the given compilation. + * This only creates the object. Computation is only performed once + * {@link ANeuralNetworksExecution_startCompute} is invoked. + * + *

The provided compilation must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @param compilation The {@link ANeuralNetworksCompilation} to be evaluated. + * @param execution The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the compilation is invalid. + */ + int (*ANeuralNetworksExecution_create)( + ANeuralNetworksCompilation* compilation, + ANeuralNetworksExecution** execution); + + /** + * Destroy an execution. + * + *

If called on an execution for which + * {@link ANeuralNetworksExecution_startCompute} has been called, the + * function will return immediately but will mark the execution to be deleted + * once the computation completes. The {@link ANeuralNetworksExecution_wait} + * will return ANEURALNETWORKS_ERROR_DELETED. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @param execution The execution to be destroyed. Passing NULL is acceptable + * and results in no operation. + */ + void (*ANeuralNetworksExecution_free)(ANeuralNetworksExecution* execution); + + /** + * Associate a user buffer with an input of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided buffer must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @param execution The execution to be modified. + * @param index The index of the input argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is + * not the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This should be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other properties of the type must be the same as + * specified in the model. If the type is the same as specified + * when the model was built, NULL can be passed. + * @param buffer The buffer containing the data. + * @param length The length in bytes of the buffer. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the input. + */ + int (*ANeuralNetworksExecution_setInput)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const void* buffer, + size_t length); + + /** + * Associate part of a memory object with an input of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided memory must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @param execution The execution to be modified. + * @param index The index of the input argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is + * not the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This can be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other values must be the same as specified in the + * model. If the type is the same as specified when the model + * was built, NULL can be passed. + * @param memory The memory containing the data. + * @param offset This specifies the location of the data within the memory. + * The offset is in bytes from the start of memory. + * @param length The size in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the input. + */ + int (*ANeuralNetworksExecution_setInputFromMemory)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, + const ANeuralNetworksMemory* memory, size_t offset, size_t length); + + /** + * Associate a user buffer with an output of the model of the + * {@link ANeuralNetworksExecution}. + * + *

<p>The provided buffer must outlive the execution.</p>

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @param execution The execution to be modified. + * @param index The index of the output argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is + * not the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This can be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other values must be the same as specified in the + * model. If the type is the same as specified when the model + * was built, NULL can be passed. + * @param buffer The buffer where the data is to be written. + * @param length The length in bytes of the buffer. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the output. + */ + int (*ANeuralNetworksExecution_setOutput)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, void* buffer, size_t length); + + /** + * Associate part of a memory object with an output of the model of the + * {@link ANeuralNetworksExecution}. + * + *

<p>The provided memory must outlive the execution.</p>

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @param execution The execution to be modified. + * @param index The index of the output argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is + * not the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This can be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other values must be the same as specified in the + * model. If the type is the same as specified when the model + * was built, NULL can be passed. + * @param memory The memory where the data is to be stored. + * @param offset This specifies the location of the data within the memory. + * The offset is in bytes from the start of memory. + * @param length The length in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the output. + */ + int (*ANeuralNetworksExecution_setOutputFromMemory)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, + const ANeuralNetworksMemory* memory, size_t offset, size_t length); + + /** + * Schedule evaluation of the execution. + * + *

<p>Schedules evaluation of the execution. Once the model has been + * applied and the outputs are ready to be consumed, the returned event will be + * signaled. Use {@link ANeuralNetworksEvent_wait} to wait for that + * event. + * </p>

+ * + * Multiple executions can be scheduled and evaluated concurrently, and + * compilations can be performed concurrently with executions. The runtime + * makes no guarantee on the ordering of the completion of compilations and + * executions. If it's important to the application, the application should + * enforce the ordering by using {@link ANeuralNetworksCompilation_wait} and + * {@link ANeuralNetworksEvent_wait}. + * + * ANeuralNetworksEvent_wait must be called to recuperate the resources + * used by the execution. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @param execution The execution to be scheduled and executed. + * @param event Used to return the event that is signaled on completion; + * pass it to {@link ANeuralNetworksEvent_wait}. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksExecution_startCompute)( + ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event); + + /** + * Waits until the execution completes. + * + * More than one thread can wait on an event. When the execution completes, + * all threads will be released. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @return ANEURALNETWORKS_NO_ERROR if the execution completed normally. + */ + int (*ANeuralNetworksEvent_wait)(ANeuralNetworksEvent* event); + + /** + * Destroys the event. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + */ + void (*ANeuralNetworksEvent_free)(ANeuralNetworksEvent* event); + + // ASharedMemory_create was added in Android 8.0, so safe to use with NNAPI + // which was added in 8.1. + int (*ASharedMemory_create)(const char* name, size_t size); + + /** + * Get the number of available devices. + * + * @param numDevices Used to return the number of devices. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ + int (*ANeuralNetworks_getDeviceCount)(uint32_t* numDevices); + + /** + * Get the representation of the specified device. 
+ * + * @param devIndex The index of the specified device. Must be less than the + * number of available devices. + * @param device The representation of the specified device. + * The same representation will always be returned for the + * specified device. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ + + int (*ANeuralNetworks_getDevice)(uint32_t devIndex, + ANeuralNetworksDevice** device); + + /** + * Get the name of the specified device. + * + * @param device The representation of the specified device. + * @param name The returned name of the specified device. The name will be + * in UTF-8 and will be null-terminated. It will be recognizable + * as a known device name rather than a cryptic string. For + * devices with API level 29 and above, the format of the name is + * {VENDOR}-{DEVICE}, e.g. “google-ipu”. For devices with feature + * level 28 or lower, the name will always be “unknown-device”. + * The name will remain valid for the duration of the application. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ + int (*ANeuralNetworksDevice_getName)(const ANeuralNetworksDevice* device, + const char** name); + + /** + * Get the version of the driver implementation of the specified device. + * + * It’s the responsibility of the driver implementor to ensure that this + * version string uniquely distinguishes this implementation from all previous + * implementations. + * + * This version string must not be confused with the feature level which is + * solely defined by {@link ANeuralNetworksDevice_getFeatureLevel}. There is + * no implicit ordering of the versions. For example, it is not possible to + * filter all drivers older than a certain version. + * + * Application developers may use this version string to avoid or prefer + * specific driver implementations. 
For example, an application may want to do + * so because: + * - A specific version of the driver does not provide the required + * performance, perhaps because of a performance regression. + * - A specific version of the driver has a bug or returns results that + * don’t match the minimum precision requirement for the application. + * + * @param device The representation of the specified device. + * @param version The returned version string of the driver for the specified + * device. The string will be in UTF-8 and will be + * null-terminated. For devices with feature level 28 or lower, + * "UNKNOWN" will be returned. The version string will remain + * valid for the duration of the application. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ + int (*ANeuralNetworksDevice_getVersion)(const ANeuralNetworksDevice* device, + const char** version); + + /** + * Get the supported NNAPI version of the specified device. + * + * Each device has a supported feature level, which is the most advanced + * feature this driver implements. For example, if the driver implements the + * features introduced in Android P, but does not implement the features + * introduced after Android P, the value would be 28. Developers could decide + * whether or not the specified device should be used for a Model that has + * certain feature requirements. + * + * @param device The representation of the specified device. + * @param featureLevel The API level of the most advanced feature this driver + * implements. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ + int (*ANeuralNetworksDevice_getFeatureLevel)( + const ANeuralNetworksDevice* device, int64_t* featureLevel); + + /** + * Get the type of a given device. + * + * The device type can be used to help application developers to distribute + * Machine Learning workloads and other workloads such as graphical rendering. 
+ * E.g., for an app which renders AR scenes based on real time object + * detection results, the developer could choose an ACCELERATOR type device + * for ML workloads, and reserve GPU for graphical rendering. + * + * @param device The representation of the specified device. + * @param type The returned {@link DeviceTypeCode} of the specified device. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ + int (*ANeuralNetworksDevice_getType)(const ANeuralNetworksDevice* device, + int32_t* type); + + /** + * Get the supported operations for a specified set of devices. If multiple + * devices are selected, the supported operation list is a union of supported + * operations of all selected devices. + * + * @param model The model to be queried. + * @param devices The set of devices. Must not contain duplicates. + * @param numDevices The number of devices in the set. + * @param supportedOps The boolean array to be filled. True means supported. + * The size of the boolean array must be at least as large + * as the number of operations in the model. The order of + * elements in the supportedOps array matches the order in + * which the corresponding operations were added to the + * model. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. + */ + int (*ANeuralNetworksModel_getSupportedOperationsForDevices)( + const ANeuralNetworksModel* model, + const ANeuralNetworksDevice* const* devices, uint32_t numDevices, + bool* supportedOps); + + /** + * Create a {@link ANeuralNetworksCompilation} to compile the given model for + * a specified set of devices. If more than one device is specified, the + * compilation will distribute the workload automatically across the devices. + * The model must be fully supported by the specified set of devices. 
This + * means that ANeuralNetworksModel_getSupportedOperationsForDevices() must + * have returned true for every operation for that model/devices pair. + * + * @param model The {@link ANeuralNetworksModel} to be compiled. + * @param devices The set of devices. Must not contain duplicates. + * @param numDevices The number of devices in the set. + * @param compilation The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the model is invalid. + * + * Available since API level 29. + */ + int (*ANeuralNetworksCompilation_createForDevices)( + ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, + uint32_t numDevices, ANeuralNetworksCompilation** compilation); + + /** + * Sets the compilation caching signature and the cache directory. + * + * Provides optional caching information to the runtime for faster repeated + * compilation. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param compilation The compilation to be modified. + * @param cacheDir The cache directory to store and retrieve caching data. It + * is recommended to use the code_cache provided by the + * Android runtime. If not using the code_cache, the user + * should choose a directory local to the application, and is + * responsible for managing and cleaning the cache entries. + * @param token The token provided by the user to specify a model, must be of + * length ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN. The user + * should ensure that the token is unique to a model within the + * application. The NNAPI runtime will not detect token + * collisions. If there is a collision, the compilation outcome + * may be incorrect without any error being reported. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 29. 
+ */ + int (*ANeuralNetworksCompilation_setCaching)( + ANeuralNetworksCompilation* compilation, const char* cacheDir, + const uint8_t* token); + + /** + * Schedule synchronous evaluation of the execution. + * + *

<p>Schedules synchronous evaluation of the execution. Returns once the + * execution has completed and the outputs are ready to be consumed. + * </p>

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * See {@link ANeuralNetworksExecution_startCompute} for asynchronous + * execution. Synchronous execution incurs lower overhead than asynchronous + * execution. + * + * Available since API level 29. + * + * @param execution The execution to be scheduled and executed. + * + * @return ANEURALNETWORKS_NO_ERROR if the execution completed normally. + * ANEURALNETWORKS_UNMAPPABLE if the execution input or output memory + * cannot be properly mapped. + */ + int (*ANeuralNetworksExecution_compute)(ANeuralNetworksExecution* execution); + + /** + * Get the dimensional information of the specified output operand of the + * model of the + * {@link ANeuralNetworksExecution}. + * + * On asynchronous execution initiated by {@link + * ANeuralNetworksExecution_startCompute}, + * {@link ANeuralNetworksEvent_wait} must be called prior to this function to + * recuperate the resources used by the execution. + * + * @param execution The execution to be queried. + * @param index The index of the output argument we are querying. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is + * not the index associated with + * {@link ANeuralNetworksModel_addOperand}. + * @param rank The rank of the output operand. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, + * ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE if the target output is + * provided an insufficient buffer at execution time, + * ANEURALNETWORKS_BAD_DATA if the index is invalid. + * + * Available since API level 29. + */ + int (*ANeuralNetworksExecution_getOutputOperandRank)( + ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank); + + /** + * Get the dimensional information of the specified output operand of the + * model of the + * {@link ANeuralNetworksExecution}. The target output operand cannot be a + * scalar. 
+ * + * On asynchronous execution initiated by {@link + * ANeuralNetworksExecution_startCompute}, + * {@link ANeuralNetworksEvent_wait} must be called prior to this function to + * recuperate the resources used by the execution. + * + * @param execution The execution to be queried. + * @param index The index of the output argument we are querying. It is an + * index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is + * not the index associated with + * {@link ANeuralNetworksModel_addOperand}. + * @param dimensions The dimension array to be filled. The size of the array + * must be exactly as large as the rank of the output + * operand to be queried in the model. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, + * ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE if the target output is + * provided an insufficient buffer at execution time, + * ANEURALNETWORKS_BAD_DATA if the index is invalid or if the target + * is a scalar. + * + * Available since API level 29. + */ + int (*ANeuralNetworksExecution_getOutputOperandDimensions)( + ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions); + + /** + * Create a {@link ANeuralNetworksBurst} to apply the given compilation. + * This only creates the burst object. Computation is only performed once + * {@link ANeuralNetworksExecution_burstCompute} is invoked with a valid + * {@link ANeuralNetworksExecution} and {@link ANeuralNetworksBurst}. + * + *

<p>The provided compilation must outlive the burst object.</p>

+ * + * Available since API level 29. + * + * @param compilation The {@link ANeuralNetworksCompilation} to be evaluated. + * @param burst The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the compilation is invalid. + */ + int (*ANeuralNetworksBurst_create)(ANeuralNetworksCompilation* compilation, + ANeuralNetworksBurst** burst); + + /** + * Destroys the burst object. + * + * Available since API level 29. + * + * @param burst The burst object to be destroyed. Passing NULL is acceptable + * and results in no operation. + */ + void (*ANeuralNetworksBurst_free)(ANeuralNetworksBurst* burst); + + /** + * Schedule synchronous evaluation of the execution on a burst object. + * + *

<p>Schedules synchronous evaluation of the execution. Returns once the + * execution has completed and the outputs are ready to be consumed.</p>

+ * + *

<p>There must be at most one {@link ANeuralNetworksExecution} processing at + * any given time for any given burst object. Any + * {@link ANeuralNetworksExecution} launched before the previous has finished + * will result in ANEURALNETWORKS_BAD_STATE.</p>

+ * + * Available since API level 29. + * + * @param burst The burst object to execute on. + * @param execution The execution to be scheduled and executed. The execution + * must be created from the same {@link + * ANeuralNetworksCompilation} as the burst object. + * + * @return ANEURALNETWORKS_NO_ERROR if the execution completed normally. + */ + int (*ANeuralNetworksExecution_burstCompute)( + ANeuralNetworksExecution* execution, ANeuralNetworksBurst* burst); + + /** + * Creates a shared memory object from an AHardwareBuffer handle. + * + * If the shared memory is backed by an AHardwareBuffer of + * AHARDWAREBUFFER_FORMAT_BLOB format, it can be used the same way as + * shared memory created from a file handle. See + * {@link ANeuralNetworksMemory} for a description on how to use this + * shared memory. + * + * If the shared memory is backed by an AHardwareBuffer of a format other + * than AHARDWAREBUFFER_FORMAT_BLOB, it can only be used for Model inputs + * and outputs. When calling + * {@link ANeuralNetworksExecution_setInputFromMemory} or + * {@link ANeuralNetworksExecution_setOutputFromMemory} with the shared + * memory, both offset and length must be set to zero and the entire + * memory region will be associated with the specified input or output + * operand. There is no guarantee that an arbitrary AHardwareBuffer_Format + * and AHardwareBuffer_UsageFlags combination can be used by arbitrary + * devices. The execution will fail if selected set of devices cannot + * consume the buffer. + * + * Calling {@link ANeuralNetworksModel_setOperandValueFromMemory} with + * shared memory backed by an AHardwareBuffer of a format other than + * AHARDWAREBUFFER_FORMAT_BLOB is disallowed. + * + * TODO(miaowang): add documentation about intended usage with + * introspection API. + * + * Available since API level 29. + * + * @param ahwb The AHardwareBuffer handle. + * @param memory The memory object to be created. + * Set to NULL if unsuccessful. 
+ * + * @return ANEURALNETWORKS_NO_ERROR if the request completed normally. + * + * @see AHardwareBuffer + */ + int (*ANeuralNetworksMemory_createFromAHardwareBuffer)( + const AHardwareBuffer* ahwb, ANeuralNetworksMemory** memory); + + /** + * Specifies whether duration of the {@link ANeuralNetworksExecution} is to be + * measured. By default, duration is not measured. + * + * The {@link ANeuralNetworksExecution} must have been created with + * {@link ANeuralNetworksCompilation_createForDevices} with numDevices = 1. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * Available since API level 29. + * + * @param execution The execution to be modified. + * @param measure 'true' if duration is to be measured, 'false' if not. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksExecution_setMeasureTiming)( + ANeuralNetworksExecution* execution, bool measure); + + /** + * Get the time spent in the specified {@link ANeuralNetworksExecution}, in + * nanoseconds. The execution must have completed. + * + * @param execution The execution to be queried. + * @param durationCode The measurement to be queried, specified by {@link + * DurationCode}. + * @param duration The returned duration. If no measurement was requested by + * {@link ANeuralNetworksExecution_setMeasureTiming}, or for + * some other reason the duration is not available, UINT64_MAX will be + * returned. A particular device need not support any given measurement. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksExecution_getDuration)( + const ANeuralNetworksExecution* execution, int32_t durationCode, + uint64_t* duration); + + /**/ +}; + +/** + * Load the NNAPI implementation from the shared libraries. + * The NnApi structure is filled with all the pointers. If one function doesn't + * exist, a null pointer is stored. 
+ */ +const NnApi* NnApiImplementation(); + +#endif // TENSORFLOW_LITE_NNAPI_NNAPI_IMPLEMENTATION_H_ diff --git a/tensorflow/lite/nnapi/nnapi_implementation_disabled.cc b/tensorflow/lite/nnapi/nnapi_implementation_disabled.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bc78e53da64b209d53bfcfc97e194e7430f016c --- /dev/null +++ b/tensorflow/lite/nnapi/nnapi_implementation_disabled.cc @@ -0,0 +1,20 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/nnapi/nnapi_implementation.h" + +const NnApi* NnApiImplementation() { + static const NnApi nnapi = {}; + return &nnapi; +} diff --git a/tensorflow/lite/nnapi/nnapi_implementation_test.cc b/tensorflow/lite/nnapi/nnapi_implementation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9f30b95ec37e3c878d3bdbc1acc96026dfeef9e1 --- /dev/null +++ b/tensorflow/lite/nnapi/nnapi_implementation_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/nnapi/nnapi_implementation.h" +#include <gtest/gtest.h> + +namespace { + +TEST(NnapiLibTest, NnApiImplementation) { + const NnApi* nnapi = NnApiImplementation(); + EXPECT_NE(nnapi, nullptr); +#ifdef __ANDROID__ + EXPECT_GT(nnapi->android_sdk_version, 0); + if (nnapi->android_sdk_version < 27) { + EXPECT_FALSE(nnapi->nnapi_exists); + EXPECT_EQ(nnapi->ANeuralNetworksMemory_createFromFd, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksMemory_free, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_create, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_free, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_finish, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_addOperand, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_setOperandValue, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_setOperandValueFromMemory, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_addOperation, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_identifyInputsAndOutputs, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_relaxComputationFloat32toFloat16, + nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_create, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_free, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_setPreference, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_finish, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_create, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_free, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_setInput, nullptr); 
+ EXPECT_EQ(nnapi->ANeuralNetworksExecution_setInputFromMemory, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_setOutput, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_setOutputFromMemory, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_startCompute, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksEvent_wait, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksEvent_free, nullptr); + EXPECT_EQ(nnapi->ASharedMemory_create, nullptr); + } else { + EXPECT_TRUE(nnapi->nnapi_exists); + EXPECT_NE(nnapi->ANeuralNetworksMemory_createFromFd, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksMemory_free, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksModel_create, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksModel_free, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksModel_finish, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksModel_addOperand, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksModel_setOperandValue, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksModel_setOperandValueFromMemory, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksModel_addOperation, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksModel_identifyInputsAndOutputs, nullptr); + if (nnapi->android_sdk_version >= 28) { + // relaxComputationFloat32toFloat16 only available with Android 9.0 (P). 
+ EXPECT_NE(nnapi->ANeuralNetworksModel_relaxComputationFloat32toFloat16, + nullptr); + } else { + EXPECT_EQ(nnapi->ANeuralNetworksModel_relaxComputationFloat32toFloat16, + nullptr); + } + EXPECT_NE(nnapi->ANeuralNetworksCompilation_create, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksCompilation_free, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksCompilation_setPreference, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksCompilation_finish, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksExecution_create, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksExecution_free, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksExecution_setInput, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksExecution_setInputFromMemory, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksExecution_setOutput, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksExecution_setOutputFromMemory, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksExecution_startCompute, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksEvent_wait, nullptr); + EXPECT_NE(nnapi->ANeuralNetworksEvent_free, nullptr); + EXPECT_NE(nnapi->ASharedMemory_create, nullptr); + // TODO(b/123423795): Test Q-specific APIs after release. 
+ } +#else + EXPECT_FALSE(nnapi->nnapi_exists); + EXPECT_EQ(nnapi->android_sdk_version, 0); + EXPECT_EQ(nnapi->ANeuralNetworksMemory_createFromFd, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksMemory_free, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_create, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_free, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_finish, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_addOperand, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_setOperandValue, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_setOperandSymmPerChannelQuantParams, + nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_setOperandValueFromMemory, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_addOperation, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_identifyInputsAndOutputs, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_relaxComputationFloat32toFloat16, + nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_create, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_free, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_setPreference, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_finish, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_create, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_free, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_setInput, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_setInputFromMemory, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_setOutput, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_setOutputFromMemory, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_startCompute, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksEvent_wait, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksEvent_free, nullptr); + EXPECT_NE(nnapi->ASharedMemory_create, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworks_getDeviceCount, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworks_getDevice, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksDevice_getName, nullptr); + 
EXPECT_EQ(nnapi->ANeuralNetworksDevice_getVersion, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksDevice_getFeatureLevel, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksModel_getSupportedOperationsForDevices, + nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_createForDevices, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksCompilation_setCaching, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_compute, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_getOutputOperandRank, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_getOutputOperandDimensions, + nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksBurst_create, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksBurst_free, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_burstCompute, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksMemory_createFromAHardwareBuffer, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_setMeasureTiming, nullptr); + EXPECT_EQ(nnapi->ANeuralNetworksExecution_getDuration, nullptr); +#endif +} + +} // namespace diff --git a/tensorflow/lite/nnapi_delegate.cc b/tensorflow/lite/nnapi_delegate.cc index dc8e81cde758f6d187046d865d42141200f753bc..443651b9910fc0c4b5388409ce450f9638de4898 100644 --- a/tensorflow/lite/nnapi_delegate.cc +++ b/tensorflow/lite/nnapi_delegate.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/model.h" -#include "tensorflow/lite/nnapi/NeuralNetworksShim.h" +#include "tensorflow/lite/nnapi/nnapi_implementation.h" #ifdef __ANDROID__ #include @@ -84,56 +84,27 @@ void logError(const char* format, ...) 
{ static const int64_t kOperandIdNotSet = -1; static const int64_t kOperandNotNeeded = -2; -namespace { - -int32_t GetAndroidSdkVersion() { -#ifdef __ANDROID__ - const char* sdkProp = "ro.build.version.sdk"; - char sdkVersion[PROP_VALUE_MAX]; - int length = __system_property_get(sdkProp, sdkVersion); - if (length != 0) { - for (int i = 0; i < length; ++i) { - int digit = sdkVersion[i] - '0'; - if (digit < 0 || digit > 9) { - // Non-numeric SDK version, assume it's higher then expected; - return 0xFFFF; - } - } - return atoi(sdkVersion); - } - FATAL("No %s prop", sdkProp); -#endif // __ANDROID__ - return 0; -} - -int32_t GetAndroidSdkVersionCached() { - static int32_t androidSdkVersion = GetAndroidSdkVersion(); - return androidSdkVersion; -} - -} // namespace - NNAPIAllocation::NNAPIAllocation(const char* filename, ErrorReporter* error_reporter) : MMAPAllocation(filename, error_reporter) { if (mmapped_buffer_ != MAP_FAILED) - CHECK_NN(ANeuralNetworksMemory_createFromFd(buffer_size_bytes_, PROT_READ, - mmap_fd_, 0, &handle_)); + CHECK_NN(NnApiImplementation()->ANeuralNetworksMemory_createFromFd( + buffer_size_bytes_, PROT_READ, mmap_fd_, 0, &handle_)); } NNAPIAllocation::~NNAPIAllocation() { if (handle_) { - ANeuralNetworksMemory_free(handle_); + NnApiImplementation()->ANeuralNetworksMemory_free(handle_); } } NNAPIDelegate::~NNAPIDelegate() { if (nn_compiled_model_) { - ANeuralNetworksCompilation_free(nn_compiled_model_); + NnApiImplementation()->ANeuralNetworksCompilation_free(nn_compiled_model_); nn_compiled_model_ = nullptr; } if (nn_model_) { - ANeuralNetworksModel_free(nn_model_); + NnApiImplementation()->ANeuralNetworksModel_free(nn_model_); nn_model_ = nullptr; // TODO(aselle): Is this thread-safe and callable multiple times? 
} @@ -145,6 +116,7 @@ TfLiteStatus addTensorOperands(tflite::Subgraph* subgraph, ANeuralNetworksModel* nn_model, uint32_t* no_of_operands_added, std::vector* nnapi_ids) { + const NnApi* nnapi = NnApiImplementation(); uint32_t next_id = 0; for (size_t i = 0; i < subgraph->tensors_size(); i++) { // Skip temporaries and RNN back-edges. @@ -198,24 +170,24 @@ TfLiteStatus addTensorOperands(tflite::Subgraph* subgraph, nn_type, static_cast(tensor->dims->size), reinterpret_cast(tensor->dims->data), scale, zeroPoint}; RETURN_ERROR_IF_NN_FAILED( - ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + nnapi->ANeuralNetworksModel_addOperand(nn_model, &operand_type)); // TODO(aselle): Based on Michael's suggestion, limiting this to read // only memory if (tensor->allocation_type == kTfLiteMmapRo) { if (const NNAPIAllocation* alloc = dynamic_cast( static_cast(tensor->allocation))) { RETURN_ERROR_IF_NN_FAILED( - ANeuralNetworksModel_setOperandValueFromMemory( + nnapi->ANeuralNetworksModel_setOperandValueFromMemory( nn_model, next_id, alloc->memory(), alloc->offset(tensor->data.raw), tensor->bytes)); } else { - RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_setOperandValue( + RETURN_ERROR_IF_NN_FAILED(nnapi->ANeuralNetworksModel_setOperandValue( nn_model, next_id, tensor->data.raw, tensor->bytes)); } } else if (tensor->bytes == 0) { // These size 0 tensors are optional tensors reserved. 
- RETURN_ERROR_IF_NN_FAILED( - ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0)); + RETURN_ERROR_IF_NN_FAILED(nnapi->ANeuralNetworksModel_setOperandValue( + nn_model, next_id, nullptr, 0)); } ++next_id; @@ -244,6 +216,7 @@ TfLiteStatus AddOpsAndParams( uint32_t next_id, std::vector* model_state_inputs, std::vector* model_state_outputs, const std::vector& tensor_id_to_nnapi_id) { + const NnApi* nnapi = NnApiImplementation(); for (size_t i = 0; i < subgraph->nodes_size(); i++) { const auto* node_and_registration = subgraph->node_and_registration(i); const TfLiteNode& node = node_and_registration->first; @@ -258,21 +231,21 @@ TfLiteStatus AddOpsAndParams( MapAndAddTensorIds(node.outputs->data, node.outputs->size, &augmented_outputs, tensor_id_to_nnapi_id); - auto add_scalar_int32 = [&nn_model, &augmented_inputs, + auto add_scalar_int32 = [nnapi, &nn_model, &augmented_inputs, &next_id](int value) { ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) - CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value, - sizeof(int32_t))) + CHECK_NN(nnapi->ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(nnapi->ANeuralNetworksModel_setOperandValue( + nn_model, next_id, &value, sizeof(int32_t))) augmented_inputs.push_back(next_id++); }; - auto add_scalar_float32 = [&nn_model, &augmented_inputs, + auto add_scalar_float32 = [nnapi, &nn_model, &augmented_inputs, &next_id](float value) { ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_FLOAT32}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) - CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value, - sizeof(float))) + CHECK_NN(nnapi->ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(nnapi->ANeuralNetworksModel_setOperandValue( + nn_model, next_id, &value, sizeof(float))) augmented_inputs.push_back(next_id++); }; @@ -281,8 
+254,8 @@ TfLiteStatus AddOpsAndParams( .type = ANEURALNETWORKS_TENSOR_INT32, .dimensionCount = 1, .dimensions = &num_values}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) - CHECK_NN(ANeuralNetworksModel_setOperandValue( + CHECK_NN(nnapi->ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(nnapi->ANeuralNetworksModel_setOperandValue( nn_model, next_id, values, sizeof(int32_t) * num_values)); augmented_inputs.push_back(next_id++); }; @@ -291,15 +264,16 @@ TfLiteStatus AddOpsAndParams( // For each state_out tensor, a corresponding state_in operand needs to be // created for NNAPI. auto duplicate_state_tensor_float32 = - [subgraph, &nn_model, &next_id, &augmented_inputs, &model_state_inputs, - &model_state_outputs](int tensor_id) { + [nnapi, subgraph, &nn_model, &next_id, &augmented_inputs, + &model_state_inputs, &model_state_outputs](int tensor_id) { const TfLiteTensor* tensor = subgraph->tensor(tensor_id); ANeuralNetworksOperandType operand_type{ ANEURALNETWORKS_TENSOR_FLOAT32, static_cast(tensor->dims->size), reinterpret_cast(tensor->dims->data), tensor->params.scale, tensor->params.zero_point}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + CHECK_NN( + nnapi->ANeuralNetworksModel_addOperand(nn_model, &operand_type)); augmented_inputs.push_back(next_id); model_state_inputs->push_back(next_id); model_state_outputs->push_back(tensor_id); @@ -388,7 +362,7 @@ TfLiteStatus AddOpsAndParams( }; // LSTM in NNAPI requires scratch tensor as an output operand. 
- auto add_lstm_scratch_tensor_float32 = [subgraph, &node, &nn_model, + auto add_lstm_scratch_tensor_float32 = [nnapi, subgraph, &node, &nn_model, &next_id, &augmented_outputs]() { if (node.temporaries->size == 0) return; int scratch_buffer_index = node.temporaries->data[0]; @@ -398,7 +372,7 @@ TfLiteStatus AddOpsAndParams( static_cast(tensor->dims->size), reinterpret_cast(tensor->dims->data), tensor->params.scale, tensor->params.zero_point}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + CHECK_NN(nnapi->ANeuralNetworksModel_addOperand(nn_model, &operand_type)); augmented_outputs.insert(augmented_outputs.begin(), next_id++); }; @@ -427,15 +401,16 @@ TfLiteStatus AddOpsAndParams( }; // Handle optional input tensors. - auto add_optional_tensors = [&nn_model, &augmented_inputs, + auto add_optional_tensors = [nnapi, &nn_model, &augmented_inputs, &next_id](int nn_type) { for (size_t idx = 0; idx < augmented_inputs.size(); idx++) { if (augmented_inputs[idx] == kOptionalTensor) { const std::vector dim = {0, 0}; ANeuralNetworksOperandType operand_type{nn_type, 2, dim.data(), 0, 0}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) - CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, - nullptr, 0)) + CHECK_NN( + nnapi->ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(nnapi->ANeuralNetworksModel_setOperandValue( + nn_model, next_id, nullptr, 0)) augmented_inputs[idx] = next_id++; } } @@ -635,6 +610,7 @@ TfLiteStatus AddOpsAndParams( case tflite::BuiltinOperator_SPLIT: case tflite::BuiltinOperator_STRIDED_SLICE: case tflite::BuiltinOperator_EXP: + case tflite::BuiltinOperator_COS: case tflite::BuiltinOperator_LOG_SOFTMAX: case tflite::BuiltinOperator_DEQUANTIZE: case tflite::BuiltinOperator_DELEGATE: @@ -687,6 +663,14 @@ TfLiteStatus AddOpsAndParams( case tflite::BuiltinOperator_ABS: case tflite::BuiltinOperator_SPLIT_V: case tflite::BuiltinOperator_UNIQUE: + case tflite::BuiltinOperator_CEIL: + 
case tflite::BuiltinOperator_REVERSE_V2: + case tflite::BuiltinOperator_ADD_N: + case tflite::BuiltinOperator_GATHER_ND: + case tflite::BuiltinOperator_WHERE: + case tflite::BuiltinOperator_RANK: + case tflite::BuiltinOperator_ELU: + case tflite::BuiltinOperator_REVERSE_SEQUENCE: logError("Op code %d is currently not delegated to NNAPI", builtin); return kTfLiteError; break; @@ -696,13 +680,13 @@ TfLiteStatus AddOpsAndParams( break; } - if (nnapi_version == 11 && GetAndroidSdkVersionCached() < 28) { + if (nnapi_version == 11 && nnapi->android_sdk_version < 28) { logError("Op %d needs NNAPI1.1", builtin); return kTfLiteError; } // Add the operation. - RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_addOperation( + RETURN_ERROR_IF_NN_FAILED(nnapi->ANeuralNetworksModel_addOperation( nn_model, nn_op_type, static_cast(augmented_inputs.size()), augmented_inputs.data(), static_cast(augmented_outputs.size()), @@ -714,9 +698,10 @@ TfLiteStatus AddOpsAndParams( TfLiteStatus NNAPIDelegate::BuildGraph(Subgraph* subgraph) { if (nn_model_ && nn_compiled_model_) return model_status_; + const NnApi* nnapi = NnApiImplementation(); // TODO(aselle): This is not correct. need to handle resize invalidation. if (!nn_model_) { - CHECK_NN(ANeuralNetworksModel_create(&nn_model_)); + CHECK_NN(nnapi->ANeuralNetworksModel_create(&nn_model_)); // Find which tensors should be added to NNAPI. TFLite has temporaries // and RNN back-edges which are are not valid for NNAPI. 
We look through all @@ -763,21 +748,22 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Subgraph* subgraph) { model_states_outputs_.size(), &augmented_outputs, tensor_id_to_nnapi_id); - CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs( + CHECK_NN(nnapi->ANeuralNetworksModel_identifyInputsAndOutputs( nn_model_, static_cast(augmented_inputs.size()), reinterpret_cast(augmented_inputs.data()), static_cast(augmented_outputs.size()), reinterpret_cast(augmented_outputs.data()))); - if (GetAndroidSdkVersionCached() >= 28) { - CHECK_NN(ANeuralNetworksModel_relaxComputationFloat32toFloat16( + if (nnapi->android_sdk_version >= 28) { + CHECK_NN(nnapi->ANeuralNetworksModel_relaxComputationFloat32toFloat16( nn_model_, subgraph->GetAllowFp16PrecisionForFp32())); } - CHECK_NN(ANeuralNetworksModel_finish(nn_model_)); + CHECK_NN(nnapi->ANeuralNetworksModel_finish(nn_model_)); } if (!nn_compiled_model_) { - CHECK_NN(ANeuralNetworksCompilation_create(nn_model_, &nn_compiled_model_)); - CHECK_NN(ANeuralNetworksCompilation_finish(nn_compiled_model_)); + CHECK_NN(nnapi->ANeuralNetworksCompilation_create(nn_model_, + &nn_compiled_model_)); + CHECK_NN(nnapi->ANeuralNetworksCompilation_finish(nn_compiled_model_)); } return kTfLiteOk; } @@ -793,8 +779,10 @@ TfLiteStatus NNAPIDelegate::Invoke(Subgraph* subgraph) { return model_status_; } + const NnApi* nnapi = NnApiImplementation(); ANeuralNetworksExecution* execution = nullptr; - CHECK_NN(ANeuralNetworksExecution_create(nn_compiled_model_, &execution)); + CHECK_NN( + nnapi->ANeuralNetworksExecution_create(nn_compiled_model_, &execution)); // Currently perform deep copy of input buffer for (size_t i = 0; i < subgraph->inputs().size(); i++) { @@ -802,7 +790,7 @@ TfLiteStatus NNAPIDelegate::Invoke(Subgraph* subgraph) { // TODO(aselle): Is this what we want or do we want input instead? // TODO(aselle): This should be called setInputValue maybe to be cons. 
TfLiteTensor* tensor = subgraph->tensor(input); - CHECK_NN(ANeuralNetworksExecution_setInput( + CHECK_NN(nnapi->ANeuralNetworksExecution_setInput( execution, i, nullptr, tensor->data.raw, tensor->bytes)); } @@ -810,7 +798,7 @@ TfLiteStatus NNAPIDelegate::Invoke(Subgraph* subgraph) { for (size_t i = 0; i < subgraph->outputs().size(); i++) { int output = subgraph->outputs()[i]; TfLiteTensor* tensor = subgraph->tensor(output); - CHECK_NN(ANeuralNetworksExecution_setOutput( + CHECK_NN(nnapi->ANeuralNetworksExecution_setOutput( execution, i, nullptr, tensor->data.raw, tensor->bytes)); } @@ -822,21 +810,21 @@ TfLiteStatus NNAPIDelegate::Invoke(Subgraph* subgraph) { // Here we are using a deep copy for state_in tensors so that we are not // reading and writing into the same buffer during a invocation. // TODO(miaowang): using double shared buffer to minimize the copies. - CHECK_NN(ANeuralNetworksExecution_setInput( + CHECK_NN(nnapi->ANeuralNetworksExecution_setInput( execution, i + subgraph->inputs().size(), nullptr, tensor->data.raw, tensor->bytes)); // Tell NNAPI where to output the state_out. - CHECK_NN(ANeuralNetworksExecution_setOutput( + CHECK_NN(nnapi->ANeuralNetworksExecution_setOutput( execution, i + subgraph->outputs().size(), nullptr, tensor->data.raw, tensor->bytes)); } // Currently use blocking compute. 
ANeuralNetworksEvent* event = nullptr; - CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event)); - CHECK_NN(ANeuralNetworksEvent_wait(event)); - ANeuralNetworksEvent_free(event); - ANeuralNetworksExecution_free(execution); + CHECK_NN(nnapi->ANeuralNetworksExecution_startCompute(execution, &event)); + CHECK_NN(nnapi->ANeuralNetworksEvent_wait(event)); + nnapi->ANeuralNetworksEvent_free(event); + nnapi->ANeuralNetworksExecution_free(execution); #if 0 printf("From the NN API:\n"); @@ -854,6 +842,8 @@ TfLiteStatus NNAPIDelegate::Invoke(Subgraph* subgraph) { return kTfLiteOk; } -bool NNAPIDelegate::IsSupported() { return NNAPIExists(); } +bool NNAPIDelegate::IsSupported() { + return NnApiImplementation()->nnapi_exists; +} } // namespace tflite diff --git a/tensorflow/lite/profiling/BUILD b/tensorflow/lite/profiling/BUILD index 64723ba3856b75a614ded2a134a6a61254b38657..bbc252045baad0316333bf9bc19dd78b8bd58590 100644 --- a/tensorflow/lite/profiling/BUILD +++ b/tensorflow/lite/profiling/BUILD @@ -74,7 +74,6 @@ tf_cc_test( ":profile_summarizer", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", - "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:test_util", "//tensorflow/lite/testing:util", diff --git a/tensorflow/lite/profiling/profile_buffer.h b/tensorflow/lite/profiling/profile_buffer.h index 9aa9e411314b2f389fda1bedaa290a87021ee254..2202df2dbe3613aab1f700b88398cae5b9aa01a0 100644 --- a/tensorflow/lite/profiling/profile_buffer.h +++ b/tensorflow/lite/profiling/profile_buffer.h @@ -128,7 +128,7 @@ class ProfileBuffer { // Returns the profile event at the given index. If the index is invalid a // nullptr is returned. The return event may get overwritten if more events // are added to buffer. 
- const struct ProfileEvent* const At(int index) const { + const struct ProfileEvent* const At(size_t index) const { size_t size = Size(); if (index >= size) { return nullptr; diff --git a/tensorflow/lite/profiling/profile_buffer_test.cc b/tensorflow/lite/profiling/profile_buffer_test.cc index 6642a15884fdf57cb385e186fd75620183098375..92973302e73c22ad24707983df87001ec81db667 100644 --- a/tensorflow/lite/profiling/profile_buffer_test.cc +++ b/tensorflow/lite/profiling/profile_buffer_test.cc @@ -27,7 +27,7 @@ namespace { std::vector GetProfileEvents(const ProfileBuffer& buffer) { std::vector events; - for (auto i = 0; i < buffer.Size(); i++) { + for (size_t i = 0; i < buffer.Size(); i++) { events.push_back(buffer.At(i)); } return events; @@ -69,7 +69,7 @@ TEST(ProfileBufferTest, OverFlow) { EXPECT_EQ(expected_size, buffer.Size()); } EXPECT_EQ(max_size, buffer.Size()); - for (int j = 0; j < buffer.Size(); ++j) { + for (size_t j = 0; j < buffer.Size(); ++j) { auto event = buffer.At(j); EXPECT_EQ(eventNames[j % 4], event->tag); EXPECT_EQ(ProfileEvent::EventType::DEFAULT, event->event_type); diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 4949d7c92eb6fcf9fcb2c6244a9c888d3a0559d6..a31f6cec707718d0a9c9ba5a96c7625f09cd724e 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -4,12 +4,6 @@ package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "py_test") -filegroup( - name = "interpreter_test_data", - srcs = glob(["**/testdata/*"]), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "interpreter", srcs = [ @@ -19,7 +13,6 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper", - "//tensorflow/python:util", "//third_party/py/numpy", ], ) @@ -27,9 +20,11 @@ py_library( py_test( name = "interpreter_test", srcs = ["interpreter_test.py"], - data = [":interpreter_test_data"], + 
data = ["//tensorflow/lite/python/testdata:interpreter_test_data"], srcs_version = "PY2AND3", - tags = ["no_oss"], + tags = [ + "no_windows", + ], deps = [ ":interpreter", "//tensorflow/python:client_testlib", @@ -44,6 +39,22 @@ py_binary( srcs = ["tflite_convert.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], + deps = [":tflite_convert_main_lib"], +) + +py_library( + name = "tflite_convert_main_lib", + srcs = ["tflite_convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [":tflite_convert_lib"], +) + +py_library( + name = "tflite_convert_lib", + srcs = ["tflite_convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":lite", ], @@ -60,6 +71,8 @@ py_library( ":interpreter", ":lite_constants", ":op_hint", + "//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops", + "//tensorflow/lite/python/optimize:calibrator", "//tensorflow/python:graph_util", "//tensorflow/python:tf_optimizer", "//tensorflow/python/keras", @@ -75,6 +88,36 @@ py_test( shard_count = 4, srcs_version = "PY2AND3", tags = [ + "no_windows", + ], + deps = [ + ":lite", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( + name = "lite_v2_test", + srcs = ["lite_v2_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_windows", + ], + deps = [ + ":lite", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( + name = "lite_flex_test", + srcs = ["lite_flex_test.py"], + srcs_version = "PY2AND3", + tags = [ + # TODO(b/111881877): Enable in oss after resolving op registry issues. 
"no_oss", "no_windows", ], @@ -145,7 +188,6 @@ py_library( srcs = ["convert_saved_model.py"], srcs_version = "PY2AND3", visibility = [ - "//tensorflow/contrib/lite:__subpackages__", "//tensorflow/lite:__subpackages__", ], deps = [ diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 9c603998717019ac8624868b16d720e300a30efd..c3f15816e256a8da491dda4b702d68f12e2dc59c 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -214,7 +214,17 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): def tensor_name(x): - return x.name.split(":")[0] + """Returns name of the input tensor.""" + parts = x.name.split(":") + if len(parts) > 2: + raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format( + len(parts) - 1)) + + # To be consistent with the tensor naming scheme in tensorflow, we need + # drop the ':0' suffix for the first tensor. + if len(parts) > 1 and parts[1] != "0": + return x.name + return parts[0] # Don't expose these for now. @@ -244,7 +254,7 @@ def build_toco_convert_protos(input_tensors, Args: input_tensors: List of input tensors. Type and shape are computed using - `foo.get_shape()` and `foo.dtype`. + `foo.shape` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). inference_type: Target data type of real-number arrays in the output file. Must be `{tf.float32, tf.uint8}`. (default tf.float32) @@ -347,7 +357,7 @@ def build_toco_convert_protos(input_tensors, "inference_input_type is QUANTIZED_UINT8.") input_array.mean_value, input_array.std_value = quantized_input_stats[idx] if input_shapes is None: - shape = input_tensor.get_shape() + shape = input_tensor.shape else: shape = input_shapes[idx] input_array.shape.dims.extend(map(int, shape)) @@ -423,7 +433,7 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args, Args: input_data: Input data (i.e. 
often `sess.graph_def`), input_tensors: List of input tensors. Type and shape are computed using - `foo.get_shape()` and `foo.dtype`. + `foo.shape` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). *args: See `build_toco_convert_protos`, **kwargs: See `build_toco_convert_protos`. @@ -443,7 +453,7 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args, return data -@_tf_export("lite.toco_convert") +@_tf_export(v1=["lite.toco_convert"]) @deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.") def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): """Convert a model using TOCO. @@ -456,7 +466,7 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): Args: input_data: Input data (i.e. often `sess.graph_def`), input_tensors: List of input tensors. Type and shape are computed using - `foo.get_shape()` and `foo.dtype`. + `foo.shape` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). *args: See `build_toco_convert_protos`, **kwargs: See `build_toco_convert_protos`. diff --git a/tensorflow/lite/python/convert_saved_model.py b/tensorflow/lite/python/convert_saved_model.py index f8d986b746911c68e0589b587ce0beceafc0c534..b085a106f993b0bff63b3ce52ac45f19bbe7c4f2 100644 --- a/tensorflow/lite/python/convert_saved_model.py +++ b/tensorflow/lite/python/convert_saved_model.py @@ -215,8 +215,7 @@ def set_tensor_shapes(tensors, shapes): tensor.set_shape(shape) except ValueError as error: message = ("The shape of tensor '{0}' cannot be changed from {1} to " - "{2}. {3}".format(name, tensor.get_shape(), shape, - str(error))) + "{2}. 
{3}".format(name, tensor.shape, shape, str(error))) raise ValueError(message) diff --git a/tensorflow/lite/python/convert_test.py b/tensorflow/lite/python/convert_test.py index e270abaa5afa0f2b3bb255e896c706794277c26e..12d8d494c1f1845e5a5f3bd11307cc9c4c1a761e 100644 --- a/tensorflow/lite/python/convert_test.py +++ b/tensorflow/lite/python/convert_test.py @@ -55,6 +55,17 @@ class ConvertTest(test_util.TensorFlowTestCase): # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"): # result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor]) + def testTensorName(self): + in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32) + # out_tensors should have names: "split:0", "split:1", "split:2", "split:3". + out_tensors = array_ops.split( + value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0) + expect_names = ["split", "split:1", "split:2", "split:3"] + + for i in range(len(expect_names)): + got_name = convert.tensor_name(out_tensors[i]) + self.assertEqual(got_name, expect_names[i]) + def testQuantization(self): in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.float32) @@ -323,6 +334,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase): self.assertEqual(self._get_input_index(a), 0) self.assertEqual(self._get_sort_index(a), 0) self.assertEqual(self._get_input_index(b), 1) + self.assertEqual(self._get_sort_index(b), 0) self.assertEqual(self._get_input_index(c), 0) self.assertEqual(self._get_sort_index(c), 1) diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index a1325f0b1ff8bec11f0ad90846154401b1bb0134..9b9516f6d0bdedb30e9ddcb419639920fe6e000f 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -216,7 +216,8 @@ class Interpreter(object): def get_tensor(self, tensor_index): """Gets the value of the input tensor (get a copy). - If you wish to avoid the copy, use `tensor()`. 
+ If you wish to avoid the copy, use `tensor()`. This function cannot be used + to read intermediate results. Args: tensor_index: Tensor index of tensor to get. This value can be gotten from @@ -233,7 +234,8 @@ class Interpreter(object): This allows reading and writing to this tensors w/o copies. This more closely mirrors the C++ Interpreter class interface's tensor() member, hence the name. Be careful to not hold these output references through calls - to `allocate_tensors()` and `invoke()`. + to `allocate_tensors()` and `invoke()`. This function cannot be used to read + intermediate results. Usage: diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index 7ec56a21c9ffa82e1893d3846d92564539ac34ae..b21779226f62ead3fd4bde5aacdfc393a4d5bff9 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -91,6 +91,41 @@ class InterpreterTest(test_util.TensorFlowTestCase): output_data = interpreter.get_tensor(output_details[0]['index']) self.assertTrue((expected_output == output_data).all()) + def testString(self): + interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/gather_string.tflite')) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.string_, input_details[0]['dtype']) + self.assertTrue(([10] == input_details[0]['shape']).all()) + self.assertEqual((0.0, 0), input_details[0]['quantization']) + self.assertEqual('indices', input_details[1]['name']) + self.assertEqual(np.int64, input_details[1]['dtype']) + self.assertTrue(([3] == input_details[1]['shape']).all()) + self.assertEqual((0.0, 0), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('output', 
output_details[0]['name']) + self.assertEqual(np.string_, output_details[0]['dtype']) + self.assertTrue(([3] == output_details[0]['shape']).all()) + self.assertEqual((0.0, 0), output_details[0]['quantization']) + + test_input = np.array([1, 2, 3], dtype=np.int64) + interpreter.set_tensor(input_details[1]['index'], test_input) + + test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']) + expected_output = np.array([b'b', b'c', b'd']) + interpreter.set_tensor(input_details[0]['index'], test_input) + interpreter.invoke() + + output_data = interpreter.get_tensor(output_details[0]['index']) + self.assertTrue((expected_output == output_data).all()) + class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD index 767a9fc476398dd8fb60128f73f8ae7c518d9a21..6ec7ce497a51b9b7b66d680ea9a81ef47df51718 100644 --- a/tensorflow/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -6,12 +6,26 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +cc_library( + name = "numpy", + srcs = ["numpy.cc"], + hdrs = ["numpy.h"], + deps = [ + "//third_party/py/numpy:headers", + "//third_party/python_runtime:headers", + ], +) + cc_library( name = "interpreter_wrapper_lib", srcs = ["interpreter_wrapper.cc"], hdrs = ["interpreter_wrapper.h"], deps = [ + ":numpy", + ":python_error_reporter", + ":python_utils", "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", "//tensorflow/lite/kernels:builtin_ops", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", @@ -19,6 +33,29 @@ cc_library( ], ) +cc_library( + name = "python_error_reporter", + srcs = ["python_error_reporter.cc"], + hdrs = ["python_error_reporter.h"], + deps = [ + "//tensorflow/lite/core/api", + "//third_party/python_runtime:headers", + ], +) + +cc_library( + name = "python_utils", + 
srcs = ["python_utils.cc"], + hdrs = ["python_utils.h"], + deps = [ + ":numpy", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//third_party/py/numpy:headers", + "//third_party/python_runtime:headers", + ], +) + tf_py_wrap_cc( name = "tensorflow_wrap_interpreter_wrapper", srcs = [ diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index d14af439ec0ab600ea260da17ef0041cca25d629..6023587d3b191d8c486dac78b889510ff1c22805 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -21,22 +21,10 @@ limitations under the License. #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" - -// Disallow Numpy 1.7 deprecated symbols. -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - -#include - -#include "numpy/arrayobject.h" -#include "numpy/ufuncobject.h" - -#if PY_MAJOR_VERSION >= 3 -#define PY_TO_CPPSTRING PyBytes_AsStringAndSize -#define CPP_TO_PYSTRING PyBytes_FromStringAndSize -#else -#define PY_TO_CPPSTRING PyString_AsStringAndSize -#define CPP_TO_PYSTRING PyString_FromStringAndSize -#endif +#include "tensorflow/lite/python/interpreter_wrapper/numpy.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" +#include "tensorflow/lite/string_util.h" #define TFLITE_PY_CHECK(x) \ if ((x) != kTfLiteOk) { \ @@ -60,43 +48,9 @@ limitations under the License. 
namespace tflite { namespace interpreter_wrapper { -class PythonErrorReporter : public tflite::ErrorReporter { - public: - PythonErrorReporter() {} - - // Report an error message - int Report(const char* format, va_list args) override { - char buf[1024]; - int formatted = vsnprintf(buf, sizeof(buf), format, args); - buffer_ << buf; - return formatted; - } - - // Set's a Python runtime exception with the last error. - PyObject* exception() { - std::string last_message = message(); - PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); - return nullptr; - } - - // Gets the last error message and clears the buffer. - std::string message() { - std::string value = buffer_.str(); - buffer_.clear(); - return value; - } - - private: - std::stringstream buffer_; -}; - namespace { -// Calls PyArray's initialization to initialize all the API pointers. Note that -// this usage implies only this translation unit can use the pointers. See -// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend -// this further. 
-void ImportNumpy() { import_array1(); } +using python_utils::PyDecrefDeleter; std::unique_ptr CreateInterpreter( const tflite::FlatBufferModel* model, @@ -105,7 +59,7 @@ std::unique_ptr CreateInterpreter( return nullptr; } - ImportNumpy(); + ::tflite::python::ImportNumpy(); std::unique_ptr interpreter; if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { @@ -114,65 +68,6 @@ std::unique_ptr CreateInterpreter( return interpreter; } -int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { - switch (tf_lite_type) { - case kTfLiteFloat32: - return NPY_FLOAT32; - case kTfLiteInt32: - return NPY_INT32; - case kTfLiteInt16: - return NPY_INT16; - case kTfLiteUInt8: - return NPY_UINT8; - case kTfLiteInt8: - return NPY_INT8; - case kTfLiteInt64: - return NPY_INT64; - case kTfLiteString: - return NPY_OBJECT; - case kTfLiteBool: - return NPY_BOOL; - case kTfLiteComplex64: - return NPY_COMPLEX64; - case kTfLiteNoType: - return NPY_NOTYPE; - // Avoid default so compiler errors created when new types are made. - } - return NPY_NOTYPE; -} - -TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { - int pyarray_type = PyArray_TYPE(array); - switch (pyarray_type) { - case NPY_FLOAT32: - return kTfLiteFloat32; - case NPY_INT32: - return kTfLiteInt32; - case NPY_INT16: - return kTfLiteInt16; - case NPY_UINT8: - return kTfLiteUInt8; - case NPY_INT8: - return kTfLiteInt8; - case NPY_INT64: - return kTfLiteInt64; - case NPY_BOOL: - return kTfLiteBool; - case NPY_OBJECT: - case NPY_STRING: - case NPY_UNICODE: - return kTfLiteString; - case NPY_COMPLEX64: - return kTfLiteComplex64; - // Avoid default so compiler errors created when new types are made. 
- } - return kTfLiteNoType; -} - -struct PyDecrefDeleter { - void operator()(PyObject* p) const { Py_DECREF(p); } -}; - PyObject* PyArrayFromIntVector(const int* data, npy_intp size) { void* pydata = malloc(size * sizeof(int)); memcpy(pydata, data, size * sizeof(int)); @@ -307,7 +202,7 @@ PyObject* InterpreterWrapper::TensorType(int i) const { return nullptr; } - int code = TfLiteTypeToPyArrayType(tensor->type); + int code = python_utils::TfLiteTypeToPyArrayType(tensor->type); if (code == -1) { PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code); return nullptr; @@ -350,38 +245,53 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { } PyArrayObject* array = reinterpret_cast(array_safe.get()); - const TfLiteTensor* tensor = interpreter_->tensor(i); + TfLiteTensor* tensor = interpreter_->tensor(i); - if (TfLiteTypeFromPyArray(array) != tensor->type) { + if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) { PyErr_Format(PyExc_ValueError, "Cannot set tensor:" " Got tensor of type %d" " but expected type %d for input %d ", - TfLiteTypeFromPyArray(array), tensor->type, i); + python_utils::TfLiteTypeFromPyArray(array), tensor->type, i); return nullptr; } if (PyArray_NDIM(array) != tensor->dims->size) { - PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch"); + PyErr_Format(PyExc_ValueError, + "Cannot set tensor: Dimension mismatch." + " Got %d" + " but expected %d for input %d.", + PyArray_NDIM(array), tensor->dims->size, i); return nullptr; } for (int j = 0; j < PyArray_NDIM(array); j++) { if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) { - PyErr_SetString(PyExc_ValueError, - "Cannot set tensor: Dimension mismatch"); + PyErr_Format(PyExc_ValueError, + "Cannot set tensor: Dimension mismatch." 
+ " Got %ld" + " but expected %d for dimension %d of input %d.", + PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i); return nullptr; } } - size_t size = PyArray_NBYTES(array); - if (size != tensor->bytes) { - PyErr_Format(PyExc_ValueError, - "numpy array had %zu bytes but expected %zu bytes.", size, - tensor->bytes); - return nullptr; + if (tensor->type != kTfLiteString) { + size_t size = PyArray_NBYTES(array); + if (size != tensor->bytes) { + PyErr_Format(PyExc_ValueError, + "numpy array had %zu bytes but expected %zu bytes.", size, + tensor->bytes); + return nullptr; + } + memcpy(tensor->data.raw, PyArray_DATA(array), size); + } else { + DynamicBuffer dynamic_buffer; + if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) { + return nullptr; + } + dynamic_buffer.WriteToTensor(tensor, nullptr); } - memcpy(tensor->data.raw, PyArray_DATA(array), size); Py_RETURN_NONE; } @@ -400,7 +310,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, return nullptr; } - *type_num = TfLiteTypeToPyArrayType((*tensor)->type); + *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type); if (*type_num == -1) { PyErr_SetString(PyExc_ValueError, "Unknown tensor type."); return nullptr; @@ -428,19 +338,51 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { std::vector dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); - // Make a buffer copy but we must tell Numpy It owns that data or else - // it will leak. - void* data = malloc(tensor->bytes); - if (!data) { - PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); - return nullptr; + if (tensor->type != kTfLiteString) { + // Make a buffer copy but we must tell Numpy It owns that data or else + // it will leak. 
+ void* data = malloc(tensor->bytes); + if (!data) { + PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); + return nullptr; + } + memcpy(data, tensor->data.raw, tensor->bytes); + PyObject* np_array = + PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); + PyArray_ENABLEFLAGS(reinterpret_cast(np_array), + NPY_ARRAY_OWNDATA); + return PyArray_Return(reinterpret_cast(np_array)); + } else { + // Create a C-order array so the data is contiguous in memory. + const int32_t kCOrder = 0; + PyObject* py_object = + PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder); + + if (py_object == nullptr) { + PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray."); + return nullptr; + } + + PyArrayObject* py_array = reinterpret_cast(py_object); + PyObject** data = reinterpret_cast(PyArray_DATA(py_array)); + auto num_strings = GetStringCount(tensor->data.raw); + for (int j = 0; j < num_strings; ++j) { + auto ref = GetString(tensor->data.raw, j); + + PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len); + if (bytes == nullptr) { + Py_DECREF(py_object); + PyErr_Format(PyExc_ValueError, + "Could not create PyBytes from string %d of input %d.", j, + i); + return nullptr; + } + // PyArray_EMPTY produces an array full of Py_None, which we must decref. 
+ Py_DECREF(data[j]); + data[j] = bytes; + } + return py_object; } - memcpy(data, tensor->data.raw, tensor->bytes); - PyObject* np_array = - PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); - PyArray_ENABLEFLAGS(reinterpret_cast(np_array), - NPY_ARRAY_OWNDATA); - return PyArray_Return(reinterpret_cast(np_array)); } PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { @@ -477,7 +419,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( char * buf = nullptr; Py_ssize_t length; std::unique_ptr error_reporter(new PythonErrorReporter); - if (PY_TO_CPPSTRING(data, &buf, &length) == -1) { + + if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) { return nullptr; } std::unique_ptr model = diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i index f52ef1eeca7db397d84d249b74445a3276bc65fb..ef4b28f04723ab8d7f4f395a028bb565b4ca9cf3 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" %} diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff5403d2a60a66886681db73c4aa69bf43369170 --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc @@ -0,0 +1,25 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define TFLITE_IMPORT_NUMPY // See numpy.h for explanation. +#include "tensorflow/lite/python/interpreter_wrapper/numpy.h" + +namespace tflite { +namespace python { + +void ImportNumpy() { import_array1(); } + +} // namespace python +} // namespace tflite diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.h b/tensorflow/lite/python/interpreter_wrapper/numpy.h new file mode 100644 index 0000000000000000000000000000000000000000..a3b013fcb27ad1837dfb83efbcec2ae800850058 --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_ +#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_ + +#ifdef PyArray_Type +#error "Numpy cannot be included before numpy.h." +#endif + +// Disallow Numpy 1.7 deprecated symbols. 
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +// To handle PyArray_* calls, numpy defines a static lookup table called +// PyArray_API, or PY_ARRAY_UNIQUE_SYMBOL, if defined. This causes the +// PyArray_* pointers to be different for different translation units, unless +// we take care to selectively define NO_IMPORT_ARRAY. +// +// Virtually every usage will define NO_IMPORT_ARRAY, and will have access to +// the lookup table via: +// extern void **PyArray_API; +// In numpy.cc we will define TFLITE_IMPORT_NUMPY, effectively disabling that +// and instead using: +// void **PyArray_API; +// which is initialized when ImportNumpy() is called. +// +// If we don't define PY_ARRAY_UNIQUE_SYMBOL then PyArray_API is a static +// variable, which causes strange crashes when the pointers are used across +// translation unit boundaries. +// +// For more info see https://sourceforge.net/p/numpy/mailman/message/5700519 +// See also tensorflow/python/lib/core/numpy.h for a similar approach. +#define PY_ARRAY_UNIQUE_SYMBOL _tensorflow_numpy_api +#ifndef TFLITE_IMPORT_NUMPY +#define NO_IMPORT_ARRAY +#endif + +#include + +#include "numpy/arrayobject.h" +#include "numpy/ufuncobject.h" + +namespace tflite { +namespace python { + +void ImportNumpy(); + +} // namespace python +} // namespace tflite + +#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_ diff --git a/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc new file mode 100644 index 0000000000000000000000000000000000000000..803a4c29345a44bcdba41d851884fa86d6e87d3e --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.cc @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" + +namespace tflite { +namespace interpreter_wrapper { + +// Report an error message +int PythonErrorReporter::Report(const char* format, va_list args) { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << buf; + return formatted; +} + +// Sets a Python runtime exception with the last error. +PyObject* PythonErrorReporter::exception() { + std::string last_message = message(); + PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); + return nullptr; +} + +// Gets the last error message and clears the buffer. +std::string PythonErrorReporter::message() { + std::string value = buffer_.str(); + buffer_.clear(); + return value; +} +} // namespace interpreter_wrapper +} // namespace tflite diff --git a/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h new file mode 100644 index 0000000000000000000000000000000000000000..7d4e308834a21b795644f0c1f89607a3b75ad7ce --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ +#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ + +#include + +#include +#include + +#include "tensorflow/lite/core/api/error_reporter.h" + +namespace tflite { +namespace interpreter_wrapper { + +class PythonErrorReporter : public tflite::ErrorReporter { + public: + PythonErrorReporter() {} + + // Report an error message + int Report(const char* format, va_list args) override; + + // Sets a Python runtime exception with the last error and + // clears the error message buffer. + PyObject* exception(); + + // Gets the last error message and clears the buffer. + std::string message(); + + private: + std::stringstream buffer_; +}; + +} // namespace interpreter_wrapper +} // namespace tflite +#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3d713630f6d39dd21b3c01cc4c75d4408243827 --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc @@ -0,0 +1,180 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" + +#include + +#include "tensorflow/lite/python/interpreter_wrapper/numpy.h" + +namespace tflite { +namespace python_utils { + +struct PyObjectDereferencer { + void operator()(PyObject* py_object) const { Py_DECREF(py_object); } +}; + +using UniquePyObjectRef = std::unique_ptr; + +int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { + switch (tf_lite_type) { + case kTfLiteFloat32: + return NPY_FLOAT32; + case kTfLiteInt32: + return NPY_INT32; + case kTfLiteInt16: + return NPY_INT16; + case kTfLiteUInt8: + return NPY_UINT8; + case kTfLiteInt8: + return NPY_INT8; + case kTfLiteInt64: + return NPY_INT64; + case kTfLiteString: + return NPY_STRING; + case kTfLiteBool: + return NPY_BOOL; + case kTfLiteComplex64: + return NPY_COMPLEX64; + case kTfLiteNoType: + return NPY_NOTYPE; + // Avoid default so compiler errors created when new types are made. 
+ } + return NPY_NOTYPE; +} + +TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { + int pyarray_type = PyArray_TYPE(array); + switch (pyarray_type) { + case NPY_FLOAT32: + return kTfLiteFloat32; + case NPY_INT32: + return kTfLiteInt32; + case NPY_INT16: + return kTfLiteInt16; + case NPY_UINT8: + return kTfLiteUInt8; + case NPY_INT8: + return kTfLiteInt8; + case NPY_INT64: + return kTfLiteInt64; + case NPY_BOOL: + return kTfLiteBool; + case NPY_OBJECT: + case NPY_STRING: + case NPY_UNICODE: + return kTfLiteString; + case NPY_COMPLEX64: + return kTfLiteComplex64; + // Avoid default so compiler errors created when new types are made. + } + return kTfLiteNoType; +} + +#if PY_VERSION_HEX >= 0x03030000 +bool FillStringBufferFromPyUnicode(PyObject* value, + DynamicBuffer* dynamic_buffer) { + Py_ssize_t len = -1; + const char* buf = PyUnicode_AsUTF8AndSize(value, &len); + if (buf == NULL) { + PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed."); + return false; + } + dynamic_buffer->AddString(buf, len); + return true; +} +#else +bool FillStringBufferFromPyUnicode(PyObject* value, + DynamicBuffer* dynamic_buffer) { + UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value)); + if (!utemp) { + PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed."); + return false; + } + char* buf = nullptr; + Py_ssize_t len = -1; + if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) { + PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed."); + return false; + } + dynamic_buffer->AddString(buf, len); + return true; +} +#endif + +bool FillStringBufferFromPyString(PyObject* value, + DynamicBuffer* dynamic_buffer) { + if (PyUnicode_Check(value)) { + return FillStringBufferFromPyUnicode(value, dynamic_buffer); + } + + char* buf = nullptr; + Py_ssize_t len = -1; + if (PyBytes_AsStringAndSize(value, &buf, &len) == -1) { + PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed."); + return false; + } + 
dynamic_buffer->AddString(buf, len); + return true; +} + +bool FillStringBufferWithPyArray(PyObject* value, + DynamicBuffer* dynamic_buffer) { + PyArrayObject* array = reinterpret_cast(value); + switch (PyArray_TYPE(array)) { + case NPY_OBJECT: + case NPY_STRING: + case NPY_UNICODE: { + UniquePyObjectRef iter(PyArray_IterNew(value)); + while (PyArray_ITER_NOTDONE(iter.get())) { + UniquePyObjectRef item(PyArray_GETITEM( + array, reinterpret_cast(PyArray_ITER_DATA(iter.get())))); + + if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) { + return false; + } + + PyArray_ITER_NEXT(iter.get()); + } + return true; + } + default: + break; + } + + PyErr_Format(PyExc_ValueError, + "Cannot use numpy array of type %d for string tensor.", + PyArray_TYPE(array)); + return false; +} + +int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length) { +#if PY_MAJOR_VERSION >= 3 + return PyBytes_AsStringAndSize(obj, data, length); +#else + return PyString_AsStringAndSize(obj, data, length); +#endif +} + +PyObject* ConvertToPyString(const char* data, size_t length) { +#if PY_MAJOR_VERSION >= 3 + return PyBytes_FromStringAndSize(data, length); +#else + return PyString_FromStringAndSize(data, length); +#endif +} + +} // namespace python_utils +} // namespace tflite diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.h b/tensorflow/lite/python/interpreter_wrapper/python_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f4677378cbc177b42c1b802b40beeba86ed605c4 --- /dev/null +++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ +#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ + +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/python/interpreter_wrapper/numpy.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace python_utils { + +struct PyDecrefDeleter { + void operator()(PyObject* p) const { Py_DECREF(p); } +}; + +int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type); + +TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array); + +bool FillStringBufferWithPyArray(PyObject* value, + DynamicBuffer* dynamic_buffer); + +int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length); +PyObject* ConvertToPyString(const char* data, size_t length); + +} // namespace python_utils +} // namespace tflite +#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 3b0aa02b7c1c5215908c86b35525566669a0cd30..fc8ae7790fcc87713b171362516144376919ba6c 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -12,31 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorFlow Lite tooling helper functionality. +"""TensorFlow Lite tooling helper functionality.""" -EXPERIMENTAL: APIs here are unstable and likely to change without notice. 
- -@@TocoConverter -@@TFLiteConverter -@@toco_convert -@@toco_convert_protos -@@Interpreter -@@OpHint -@@convert_op_hints_to_stubs -@@build_toco_convert_protos - -@@TFLITE -@@GRAPHVIZ_DOT - -""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings +import enum from six import PY3 from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError +from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn # pylint: disable=unused-import +from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell # pylint: disable=unused-import +from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell # pylint: disable=unused-import from tensorflow.lite.python import lite_constants as constants from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import @@ -52,12 +42,17 @@ from tensorflow.lite.python.convert_saved_model import set_tensor_shapes as _set from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import +from tensorflow.lite.python.optimize import calibrator as _calibrator from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2 from tensorflow.core.protobuf import config_pb2 as _config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session +from tensorflow.python.eager import def_function as _def_function +from tensorflow.python.eager import function as _function +from 
tensorflow.python.framework import convert_to_constants as _convert_to_constants +from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import graph_util as _tf_graph_util from tensorflow.python.framework import ops as _ops from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError @@ -71,18 +66,20 @@ from tensorflow.python.util import deprecation as _deprecation from tensorflow.python.util.tf_export import tf_export as _tf_export -def _run_graph_optimizations(graph_def, input_arrays, output_arrays): +def _run_graph_optimizations(graph_def, input_arrays, output_arrays, + graph=None): """Apply standard TensorFlow optimizations to the graph_def. Args: graph_def: Frozen GraphDef to be optimized. input_arrays: List of arrays that are considered inputs of the graph. output_arrays: List of arrays that are considered outputs of the graph. + graph: TensorFlow Graph. Required when Eager mode is enabled. (default None) Returns: A new, optimized GraphDef. """ - meta_graph = _export_meta_graph(graph_def=graph_def) + meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph) # We need to add a collection called 'train_op' so that grappler # knows what the outputs are. @@ -100,7 +97,220 @@ def _run_graph_optimizations(graph_def, input_arrays, output_arrays): return _tf_optimizer.OptimizeGraph(config, meta_graph) -@_tf_export("lite.TFLiteConverter") +@_tf_export("lite.Optimize") +class Optimize(enum.Enum): + """Enum defining the optimizations to apply when generating tflite graphs. + + Some optimizations may come at the cost of accuracy. + """ + + # Optimize for size. + # + # Optimizations that reduce the size of the model. + # The model size will be reduced. Optimizations can include quantizing the + # weights of the floating point model. + OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE" + + # Optimize for latency. + # + # Optimizations that reduce the latency of the model. + # The model latency will be reduced. 
Optimizations can include quantizing the + # weights of the floating point model. + OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY" + + def __str__(self): + return self.value + + +@_tf_export("lite.RepresentativeDataset") +class RepresentativeDataset(object): + """Representative dataset to evaluate optimizations. + + A representative dataset that can be used to evaluate optimizations by the + converter. E.g. converter can use these examples to estimate (min, max) ranges + by calibrating the model on inputs. This can allow converter to quantize a + converted floating point model. + """ + + def __init__(self, input_gen, output_gen=None): + """Creates a representative dataset. + + Args: + input_gen: an input generator that can be used to generate input samples + for the model. This must be a callable object that returns an object + that supports the `iter()` protocol (e.g. a generator function). The + elements generated must have same type and shape as inputs to the model. + output_gen: (optional) an output generator that can be used to generate + output samples for the model. This must be a callable object that + returns an object that supports the `iter()` protocol (e.g. a generator + function). The elements generated must have same type and shape as + outputs to the model. (default None) + """ + self.input_gen = input_gen + self.output_gen = output_gen + + +@_tf_export("lite.TargetSpec") +class TargetSpec(object): + """Specification of target device. + + Details about target device. Converter optimizes the generated model for + specific device. + + Attributes: + supported_ops: Experimental flag, subject to change. Set of OpsSet options + supported by the device. 
(default set([OpsSet.TFLITE_BUILTINS])) + """ + + def __init__(self, supported_ops=None): + if supported_ops is None: + supported_ops = set([OpsSet.TFLITE_BUILTINS]) + self.supported_ops = supported_ops + + +@_tf_export("lite.TFLiteConverter", v1=[]) +class TFLiteConverterV2(object): + """Converts a TensorFlow model into TensorFlow Lite model. + + Attributes: + allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. (default + False) + target_spec: Experimental flag, subject to change. Specification of target + device. + optimizations: Experimental flag, subject to change, A list of optimizations + to apply when converting the model. The converter applies the + optimizations by giving priority to the optimizations specified earlier in + the list. E.g. `[optimize.OPTIMIZE_FOR_SIZE, + optimize.OPTIMIZE_FOR_LATENCY]` requires the converter to do both size and + latency optimizations giving priority to size optimizations over latency + optimizations. + representative_dataset: A representative dataset that can be used to + generate input and output samples for the model. The converter can use the + dataset to evaluate different optimizations. + + Example usage: + + ```python + # Converting a GraphDef from a ConcreteFunction. + converter = lite.TFLiteConverter.from_concrete_function(func) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) + ``` + """ + + def __init__(self, func): + """Constructor for TFLiteConverter. + + Args: + func: TensorFlow ConcreteFunction. 
+ """ + self._func = func + self.allow_custom_ops = False + self.target_spec = TargetSpec() + self.representative_dataset = None + self.optimizations = [] + + @classmethod + def from_concrete_function(cls, func): + """Creates a TFLiteConverter class from a ConcreteFunction. + + Args: + func: TensorFlow ConcreteFunction. + + Returns: + TFLiteConverter class. + """ + if not isinstance(func, _function.ConcreteFunction): + message = "This function takes in a ConcreteFunction." + if isinstance(func, _def_function.Function): + message += (" To get the ConcreteFunction from a Function," + " call from_concrete_function.") + raise ValueError(message) + return cls(func) + + def convert(self): + """Converts a TensorFlow GraphDef based on instance variables. + + Returns: + The converted data in serialized format. + + Raises: + ValueError: + Input shape is not specified. + None value for dimension in input_tensor. + """ + frozen_func = _convert_to_constants.convert_variables_to_constants_v2( + self._func) + input_tensors = [ + tensor for tensor in frozen_func.inputs + if tensor.dtype != _dtypes.resource + ] + output_tensors = frozen_func.outputs + + # Run a Grappler pass. + graph_def = _run_graph_optimizations(frozen_func.graph.as_graph_def(), + input_tensors, output_tensors, + frozen_func.graph) + + # Checks dimensions in input tensor. + for tensor in input_tensors: + # Note that shape_list might be empty for scalar shapes. + shape_list = tensor.shape.as_list() + if None in shape_list[1:]: + raise ValueError( + "None is only supported in the 1st dimension. Tensor '{0}' has " + "invalid shape '{1}'.".format(_tensor_name(tensor), shape_list)) + elif shape_list and shape_list[0] is None: + # Set the batch size to 1 if undefined. 
+ shape = tensor.shape.as_list() + shape[0] = 1 + tensor.set_shape(shape) + + if self.representative_dataset: + if not isinstance(self.representative_dataset, RepresentativeDataset): + raise TypeError("`representative_dataset` must be an instance of " + "`RepresentativeDataset`") + if self.representative_dataset.input_gen is None: + raise ValueError( + "Provide an input generator for `representative_dataset`") + + # TODO(shashishekhar): For now use optimizations order is ignored. + # Both size and latency optimizations decide whether to apply post + # training optimizations. + post_training_optimize = bool( + len( + set(self.optimizations) + & set([Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE]))) + # Do weights only quantization if there is no dataset for calibration. + weights_only_quantize_flag = ( + post_training_optimize and (self.representative_dataset is None)) + + converter_kwargs = { + "input_format": constants.TENSORFLOW_GRAPHDEF, + "allow_custom_ops": self.allow_custom_ops, + "post_training_quantize": weights_only_quantize_flag, + "target_ops": self.target_spec.supported_ops, + } + + # Converts model. + result = _toco_convert_impl( + input_data=graph_def, + input_tensors=input_tensors, + output_tensors=output_tensors, + **converter_kwargs) + + if self.representative_dataset and post_training_optimize: + calibrate_quantize = _calibrator.Calibrator(result) + result = calibrate_quantize.calibrate_and_quantize( + self.representative_dataset.input_gen) + + return result + + +@_tf_export(v1=["lite.TFLiteConverter"]) class TFLiteConverter(object): """Convert a TensorFlow model into `output_format` using TOCO. @@ -141,10 +351,11 @@ class TFLiteConverter(object): created for any op that is unknown. The developer will need to provide these to the TensorFlow Lite runtime with a custom resolver. (default False) - post_training_quantize: Boolean indicating whether to quantize the weights - of the converted float model. 
Model size will be reduced and there will be - latency improvements (at the cost of accuracy). - (default False) + post_training_quantize: deprecated, please specify + `[optimize.OPTIMIZE_FOR_SIZE]` for `optimizations` instead. Boolean + indicating whether to quantize the weights of the converted float model. + Model size will be reduced and there will be latency improvements + (at the cost of accuracy). (default False) dump_graphviz_dir: Full filepath of folder to dump the graphs at various stages of processing GraphViz .dot files. Preferred over --output_format=GRAPHVIZ_DOT in order to keep the requirements of the @@ -154,6 +365,16 @@ class TFLiteConverter(object): target_ops: Experimental flag, subject to change. Set of OpsSet options indicating which converter to use. (default set([OpsSet.TFLITE_BUILTINS])) + optimizations: Experimental flag, subject to change, A list of + optimizations to apply when converting the model. The converter applies + the optimizations by giving priority to the optimizations specified + earlier in the list. E.g. + `[optimize.OPTIMIZE_FOR_SIZE, optimize.OPTIMIZE_FOR_LATENCY]` requires + the converter to do both size and latency optimizations giving priority + to size optimizations over latency optimizations. + representative_dataset: A representative dataset that can be used to + generate input and output samples for the model. The converter can use + the dataset to evaluate different optimizations. Example usage: @@ -190,7 +411,7 @@ class TFLiteConverter(object): Args: graph_def: Frozen TensorFlow GraphDef. input_tensors: List of input tensors. Type and shape are computed using - `foo.get_shape()` and `foo.dtype`. + `foo.shape` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). 
input_arrays_with_shape: Tuple of strings representing input tensor names and list of integers representing input shapes @@ -216,10 +437,12 @@ class TFLiteConverter(object): self.reorder_across_fake_quant = False self.change_concat_input_ranges = False self.allow_custom_ops = False - self.post_training_quantize = False + self._post_training_quantize = False self.dump_graphviz_dir = None self.dump_graphviz_video = False self.target_ops = set([OpsSet.TFLITE_BUILTINS]) + self.representative_dataset = None + self.optimizations = [] # Attributes are used by models that cannot be loaded into TensorFlow. if not self._has_valid_tensors(): @@ -237,7 +460,7 @@ class TFLiteConverter(object): Args: sess: TensorFlow Session. input_tensors: List of input tensors. Type and shape are computed using - `foo.get_shape()` and `foo.dtype`. + `foo.shape` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). Returns: @@ -419,6 +642,27 @@ class TFLiteConverter(object): graph_def = _freeze_graph(sess, output_tensors) return cls(graph_def, input_tensors, output_tensors) + def __setattr__(self, name, value): + if name == "post_training_quantize": + warnings.warn("Property %s is deprecated, " + "please use optimizations=[Optimize.OPTIMIZE_FOR_SIZE]" + " instead." % name) + if value: + # Use OPTIMIZE_FOR_SIZE for post training for now. + self.optimizations = [Optimize.OPTIMIZE_FOR_SIZE] + else: + self.optimizations = [] + return + object.__setattr__(self, name, value) + + def __getattribute__(self, name): + if name == "post_training_quantize": + warnings.warn("Property %s is deprecated, " + "please use optimizations=[Optimize.OPTIMIZE_FOR_SIZE]" + " instead." % name) + return Optimize.OPTIMIZE_FOR_SIZE in set(self.optimizations) + return object.__getattribute__(self, name) + def convert(self): """Converts a TensorFlow GraphDef based on instance variables. @@ -434,7 +678,7 @@ class TFLiteConverter(object): # Checks dimensions in input tensor. 
if self._has_valid_tensors(): for tensor in self._input_tensors: - shape = tensor.get_shape() + shape = tensor.shape if not shape: raise ValueError("Provide an input shape for input array " "'{0}'.".format(_tensor_name(tensor))) @@ -463,6 +707,24 @@ class TFLiteConverter(object): "tensors '{0}'.".format(",".join(invalid_stats))) else: quantized_stats = None + if self.representative_dataset: + if not isinstance(self.representative_dataset, RepresentativeDataset): + raise TypeError( + "representative_dataset must be an instance of " + "RepresentativeDataset") + if self.representative_dataset.input_gen is None: + raise ValueError( + "Provide an input generator for representative_dataset") + + # TODO(shashishekhar): For now use optimizations order is ignored. + # Both size and latency optimizations decide whether to apply post + # training optimizations. + post_training_optimize = bool( + len(set(self.optimizations) & set([Optimize.OPTIMIZE_FOR_LATENCY, + Optimize.OPTIMIZE_FOR_SIZE]))) + # Do weights only quantization if there is no dataset for calibration. 
+ weights_only_quantize_flag = ( + post_training_optimize and (self.representative_dataset is None)) converter_kwargs = { "inference_type": self.inference_type, @@ -475,7 +737,7 @@ class TFLiteConverter(object): "reorder_across_fake_quant": self.reorder_across_fake_quant, "change_concat_input_ranges": self.change_concat_input_ranges, "allow_custom_ops": self.allow_custom_ops, - "post_training_quantize": self.post_training_quantize, + "post_training_quantize": weights_only_quantize_flag, "target_ops": self.target_ops, "dump_graphviz_dir": self.dump_graphviz_dir, "dump_graphviz_video": self.dump_graphviz_video @@ -504,6 +766,12 @@ class TFLiteConverter(object): input_arrays_with_shape=self._input_arrays_with_shape, output_arrays=self._output_arrays, **converter_kwargs) + + if self.representative_dataset and post_training_optimize: + calibrate_quantize = _calibrator.Calibrator(result) + result = calibrate_quantize.calibrate_and_quantize( + self.representative_dataset.input_gen) + return result def get_input_arrays(self): @@ -540,12 +808,12 @@ class TFLiteConverter(object): "use input_shapes parameter.") for tensor in self._input_tensors: - shape = tensor.get_shape().as_list() + shape = tensor.shape.as_list() shape[0] = batch_size tensor.set_shape(shape) -@_tf_export("lite.TocoConverter") +@_tf_export(v1=["lite.TocoConverter"]) class TocoConverter(object): """Convert a TensorFlow model into `output_format` using TOCO. diff --git a/tensorflow/lite/python/lite_flex_test.py b/tensorflow/lite/python/lite_flex_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ae629413e782d011fafdb3b7e294cd884a301c --- /dev/null +++ b/tensorflow/lite/python/lite_flex_test.py @@ -0,0 +1,58 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lite.py functionality related to select TF op usage.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.lite.python import lite +from tensorflow.lite.python.interpreter import Interpreter +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +@test_util.run_v1_only('b/120545219') +class FromSessionTest(test_util.TensorFlowTestCase): + + def testFlexMode(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverter.from_session(sess, [in_tensor], + [out_tensor]) + converter.target_ops = set([lite.OpsSet.SELECT_TF_OPS]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensures the model contains TensorFlow ops. + # TODO(nupurgarg): Check values once there is a Python delegate interface. + interpreter = Interpreter(model_content=tflite_model) + with self.assertRaises(RuntimeError) as error: + interpreter.allocate_tensors() + self.assertIn( + 'Regular TensorFlow ops are not supported by this interpreter. 
Make ' + 'sure you invoke the Flex delegate before inference.', + str(error.exception)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 83fd56bf1d2617b7132d0eb2314c80460e968c18..14d08ec70a6f05fd8a971dce3caf8740375e7f3a 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer from tensorflow.python.platform import gfile @@ -131,13 +132,13 @@ class FromSessionTest(test_util.TensorFlowTestCase): input_details = interpreter.get_input_details() self.assertEqual(1, len(input_details)) self.assertEqual('Placeholder', input_details[0]['name']) - self.assertEqual(np.object_, input_details[0]['dtype']) + self.assertEqual(np.string_, input_details[0]['dtype']) self.assertTrue(([4] == input_details[0]['shape']).all()) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('Reshape', output_details[0]['name']) - self.assertEqual(np.object_, output_details[0]['dtype']) + self.assertEqual(np.string_, output_details[0]['dtype']) self.assertTrue(([2, 2] == output_details[0]['shape']).all()) # TODO(b/122659643): Test setting/getting string data via the python # interpreter API after support has been added. 
@@ -481,6 +482,29 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + def testPostTrainingQuantizeDeprecatedAttribute(self): + in_tensor_1 = array_ops.placeholder( + shape=[33, 33], dtype=dtypes.float32, name='inputA') + in_tensor_2 = constant_op.constant( + np.random.uniform(low=-10., high=10., size=(33, 33)), + shape=[33, 33], + dtype=dtypes.float32, + name='inputB') + out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') + sess = session.Session() + + quantized_converter = lite.TFLiteConverter.from_session( + sess, [in_tensor_1], [out_tensor]) + self.assertFalse(quantized_converter.post_training_quantize) + + quantized_converter.post_training_quantize = True + self.assertTrue(quantized_converter.post_training_quantize) + self.assertEqual(quantized_converter.optimizations, + [lite.Optimize.OPTIMIZE_FOR_SIZE]) + + quantized_tflite = quantized_converter.convert() + self.assertTrue(quantized_tflite) + def testPostTrainingQuantize(self): np.random.seed(0) # We need the tensor to have more than 1024 elements for quantize_weights @@ -504,35 +528,58 @@ class FromSessionTest(test_util.TensorFlowTestCase): # Convert quantized weights model. quantized_converter = lite.TFLiteConverter.from_session( sess, [in_tensor_1], [out_tensor]) - quantized_converter.post_training_quantize = True + quantized_converter.optimizations = [lite.Optimize.OPTIMIZE_FOR_SIZE] quantized_tflite = quantized_converter.convert() self.assertTrue(quantized_tflite) # Ensure that the quantized weights tflite model is smaller. self.assertTrue(len(quantized_tflite) < len(float_tflite)) - def testFlexMode(self): - in_tensor = array_ops.placeholder( - shape=[1, 16, 16, 3], dtype=dtypes.float32) - out_tensor = in_tensor + in_tensor + def testPostTrainingCalibrateAndQuantize(self): + np.random.seed(0) + # Create a mobilenet like model. 
+ output_channel = 16 + depth_multiplier = 1 + inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3)) + conv = nn_ops.conv2d( + inp, + filter=array_ops.zeros([3, 3, 3, output_channel]), + strides=[1, 1, 1, 1], + padding='SAME') + dconv = nn_ops.depthwise_conv2d_native( + conv, + filter=array_ops.zeros( + [16, 16, output_channel, output_channel * depth_multiplier]), + strides=[1, 1, 1, 1], + padding='SAME') + pool = nn_ops.pool( + dconv, window_shape=[2, 2], pooling_type='AVG', padding='SAME') + max_pool = nn_ops.pool( + pool, window_shape=[2, 2], pooling_type='MAX', padding='SAME') + output = nn_ops.softmax(max_pool) + + def calibration_gen(): + for _ in range(10): + yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)] + sess = session.Session() - # Convert model and ensure model is not None. - converter = lite.TFLiteConverter.from_session(sess, [in_tensor], - [out_tensor]) - converter.target_ops = set([lite.OpsSet.SELECT_TF_OPS]) - tflite_model = converter.convert() - self.assertTrue(tflite_model) + # Convert float model. + float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) + float_tflite = float_converter.convert() + self.assertTrue(float_tflite) - # Ensures the model contains TensorFlow ops. - # TODO(nupurgarg): Check values once there is a Python delegate interface. - interpreter = Interpreter(model_content=tflite_model) - with self.assertRaises(RuntimeError) as error: - interpreter.allocate_tensors() - self.assertIn( - 'Regular TensorFlow ops are not supported by this interpreter. Make ' - 'sure you invoke the Flex delegate before inference.', - str(error.exception)) + # Convert quantized weights model. 
+ quantized_converter = lite.TFLiteConverter.from_session( + sess, [inp], [output]) + quantized_converter.optimizations = [lite.Optimize.OPTIMIZE_FOR_SIZE] + quantized_converter.representative_dataset = lite.RepresentativeDataset( + calibration_gen) + quantized_tflite = quantized_converter.convert() + self.assertTrue(quantized_tflite) + + # Ensure that the quantized weights tflite model is smaller. + self.assertTrue(len(quantized_tflite) < len(float_tflite)) def testFloatTocoConverter(self): """Tests deprecated test TocoConverter.""" @@ -550,6 +597,35 @@ class FromSessionTest(test_util.TensorFlowTestCase): interpreter = Interpreter(model_content=tflite_model) interpreter.allocate_tensors() + def testMultipleOutputNodeNames(self): + """Tests converting a graph with an op that have multiple outputs.""" + input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32) + out0, out1, out2, out3 = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0) + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverter.from_session(sess, [input_tensor], + [out0, out1, out2, out3]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + interpreter.set_tensor(input_details[0]['index'], + np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) + interpreter.invoke() + + output_details = interpreter.get_output_details() + self.assertEqual(4, len(output_details)) + self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index'])) + self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index'])) + self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index'])) + self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index'])) + @test_util.run_v1_only('b/120545219') class FromFrozenGraphFile(test_util.TensorFlowTestCase): @@ -694,7 +770,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): # TODO(nupurgarg): Test model loading in open source. def _initObjectDetectionArgs(self): # Initializes the arguments required for the object detection model. - # Looks for the model file which is saved in a different location interally + # Looks for the model file which is saved in a different location internally # and externally. filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb') if not os.path.exists(filename): diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c2b0d8235a192abc0a459a7f85e40b1c893ec99c --- /dev/null +++ b/tensorflow/lite/python/lite_v2_test.py @@ -0,0 +1,215 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lite.py functionality related to TensorFlow 2.0.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.lite.python import lite +from tensorflow.lite.python.interpreter import Interpreter +from tensorflow.python import keras +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.saved_model.load import load +from tensorflow.python.saved_model.save import save +from tensorflow.python.training.tracking import tracking + + +class FromConcreteFunctionTest(test_util.TensorFlowTestCase): + + def _evaluateTFLiteModel(self, tflite_model, input_data): + """Evaluates the model on the `input_data`.""" + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + for input_tensor, tensor_data in zip(input_details, input_data): + interpreter.set_tensor(input_tensor['index'], tensor_data.numpy()) + interpreter.invoke() + return interpreter.get_tensor(output_details[0]['index']) + + @test_util.run_v2_only + def testTypeInvalid(self): + root = 
tracking.AutoTrackable() + root.v1 = variables.Variable(3.) + root.v2 = variables.Variable(2.) + root.f = def_function.function(lambda x: root.v1 * root.v2 * x) + + with self.assertRaises(ValueError) as error: + _ = lite.TFLiteConverterV2.from_concrete_function(root.f) + self.assertIn('call from_concrete_function', str(error.exception)) + + @test_util.run_v2_only + def testFloat(self): + input_data = constant_op.constant(1., shape=[1]) + root = tracking.AutoTrackable() + root.v1 = variables.Variable(3.) + root.v2 = variables.Variable(2.) + root.f = def_function.function(lambda x: root.v1 * root.v2 * x) + concrete_func = root.f.get_concrete_function(input_data) + + # Convert model. + converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) + tflite_model = converter.convert() + + # Check values from converted model. + expected_value = root.f(input_data) + actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) + self.assertEqual(expected_value.numpy(), actual_value) + + @test_util.run_v2_only + def testSizeNone(self): + # Test with a shape of None + input_data = constant_op.constant(1., shape=None) + root = tracking.AutoTrackable() + root.v1 = variables.Variable(3.) + root.f = def_function.function(lambda x: root.v1 * x) + concrete_func = root.f.get_concrete_function(input_data) + + # Convert model. + converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) + tflite_model = converter.convert() + + # Check values from converted model. + expected_value = root.f(input_data) + actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) + self.assertEqual(expected_value.numpy(), actual_value) + + @test_util.run_v2_only + def testConstSavedModel(self): + """Test a basic model with functions to make sure functions are inlined.""" + self.skipTest('b/124205572') + input_data = constant_op.constant(1., shape=[1]) + root = tracking.AutoTrackable() + root.f = def_function.function(lambda x: 2. 
* x) + to_save = root.f.get_concrete_function(input_data) + + save_dir = os.path.join(self.get_temp_dir(), 'saved_model') + save(root, save_dir, to_save) + saved_model = load(save_dir) + concrete_func = saved_model.signatures['serving_default'] + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) + tflite_model = converter.convert() + + # Check values from converted model. + expected_value = root.f(input_data) + actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) + self.assertEqual(expected_value.numpy(), actual_value) + + @test_util.run_v2_only + def testVariableSavedModel(self): + """Test a basic model with Variables with saving/loading the SavedModel.""" + self.skipTest('b/124205572') + input_data = constant_op.constant(1., shape=[1]) + root = tracking.AutoTrackable() + root.v1 = variables.Variable(3.) + root.v2 = variables.Variable(2.) + root.f = def_function.function(lambda x: root.v1 * root.v2 * x) + to_save = root.f.get_concrete_function(input_data) + + save_dir = os.path.join(self.get_temp_dir(), 'saved_model') + save(root, save_dir, to_save) + saved_model = load(save_dir) + concrete_func = saved_model.signatures['serving_default'] + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) + tflite_model = converter.convert() + + # Check values from converted model. + expected_value = root.f(input_data) + actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) + self.assertEqual(expected_value.numpy(), actual_value) + + @test_util.run_v2_only + def testMultiFunctionModel(self): + """Test a basic model with Variables.""" + + class BasicModel(tracking.AutoTrackable): + + def __init__(self): + self.y = None + self.z = None + + @def_function.function + def add(self, x): + if self.y is None: + self.y = variables.Variable(2.) 
+ return x + self.y + + @def_function.function + def sub(self, x): + if self.z is None: + self.z = variables.Variable(3.) + return x - self.z + + input_data = constant_op.constant(1., shape=[1]) + root = BasicModel() + concrete_func = root.add.get_concrete_function(input_data) + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) + tflite_model = converter.convert() + + # Check values from converted model. + expected_value = root.add(input_data) + actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) + self.assertEqual(expected_value.numpy(), actual_value) + + @test_util.run_v2_only + def testKerasModel(self): + input_data = constant_op.constant(1., shape=[1, 1]) + + # Create a simple Keras model. + x = [-1, 0, 1, 2, 3, 4] + y = [-3, -1, 1, 3, 5, 7] + + model = keras.models.Sequential( + [keras.layers.Dense(units=1, input_shape=[1])]) + model.compile(optimizer='sgd', loss='mean_squared_error') + model.fit(x, y, epochs=1) + + # Get the concrete function from the Keras model. + @def_function.function + def to_save(x): + return model(x) + + concrete_func = to_save.get_concrete_function( + tensor_spec.TensorSpec([None, 1], dtypes.float32)) + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) + tflite_model = converter.convert() + + # Check values from converted model. 
+ expected_value = to_save(input_data) + actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) + self.assertEqual(expected_value.numpy(), actual_value) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/lite/python/op_hint.py b/tensorflow/lite/python/op_hint.py index 6ec050171fc39308c36ec8f43af639f59f4f387c..aec4b28a33e1e6a3242929a40de9f809c64d23c4 100644 --- a/tensorflow/lite/python/op_hint.py +++ b/tensorflow/lite/python/op_hint.py @@ -36,9 +36,7 @@ Example: session = tf.Session() graphdef_to_convert = tf.lite.convert_op_hints_to_stubs(session) - tflite_graph = tf.lite.toco_convert(graphdef_to_convert, - [image], [output]) - [image], [output]) + tflite_graph = tf.lite.toco_convert(graphdef_to_convert, [image], [output]) with open("/tmp/graph.fb", "wb") as fp: fp.write(tflite_graph) @@ -73,6 +71,7 @@ from __future__ import print_function import collections as _collections import copy as _copy +import json as _json import uuid as _uuid import six as _six @@ -134,6 +133,14 @@ class OpHint(object): # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this # attribute to [2, 0, 1, -1]. TFLITE_INPUT_INDICES = "_tflite_input_indices" + # OpHint level. + FUNCTION_LEVEL_ATTR = "_tflite_ophint_level" + # Ophint internal mapping, this is for high level Ophint only. + # This basically contains three kinds of mapping: + # 1) How parental ophinted inputs map to the first child ophinted inputs; + # 2) How internal children nodes are connected; + # 3) How parental ophinted outputs map to the last child ophinted outputs. + CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping" # Types of aggregations # stack: stacks all ophints with matching tags. i.e. for a static rnn. @@ -151,10 +158,16 @@ class OpHint(object): """Conceptually tracks indices of arguments of "OpHint functions". 
The inputs and arguments of these functions both use an instance - of the class so they can have independent numbering.""" + of the class so they can have independent numbering. + """ - def __init__(self, function_name, unique_function_id, node_name_prefix, - attr_name): + def __init__(self, + function_name, + unique_function_id, + node_name_prefix, + attr_name, + level=1, + children_inputs_mappings=None): """Initialize ophint argument. Args: @@ -163,6 +176,8 @@ class OpHint(object): node_name_prefix: How identities that are created are named. attr_name: Name of attribute to use to store the index for this hint. i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX + level: Hierarchical level of the Ophint node, a number. + children_inputs_mappings: Inputs/Outputs mapping for children hints. """ # The global index is the argument index of the op. This is in contrast @@ -178,6 +193,8 @@ class OpHint(object): self._tag_to_next_sort_index = {} # The current index for each tag self._node_name_prefix = node_name_prefix self._attr_name = attr_name + self._level = level + self._children_inputs_mappings = children_inputs_mappings def _get_new_global_index(self, index_override): """Return the next unused argument index in order or use an override. @@ -217,7 +234,7 @@ class OpHint(object): and OpHint.AGGREGATE_STACK. Note, aggregate is only valid if tag is specified. index_override: Specify what input/output index should this be in the - final stub. i.e. add(arg0, index=1); add(arg1, index=0) wil make the + final stub. i.e. add(arg0, index=1); add(arg1, index=0) will make the final stub be as stub_func(inputs[arg1, arg0], outputs=[]) rather than the default call order based ordering. 
@@ -253,6 +270,7 @@ class OpHint(object): uuid = self._unique_function_id name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name, uuid, global_index, sort_index, name) + identity_op = _array_ops.identity(arg, name=name) # pylint: disable=protected-access @@ -266,6 +284,15 @@ class OpHint(object): s=_compat.as_bytes(self._unique_function_id))) identity_op.op._set_attr( self._attr_name, _attr_value_pb2.AttrValue(i=global_index)) + identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR, + _attr_value_pb2.AttrValue(i=self._level)) + if self._children_inputs_mappings: + identity_op.op._set_attr( + OpHint.CHILDREN_INPUTS_MAPPINGS, + _attr_value_pb2.AttrValue( + s=_compat.as_bytes(_json.dumps( + self._children_inputs_mappings)))) + if sort_index is not None: identity_op.op._set_attr( OpHint.FUNCTION_SORT_INDEX_ATTR, @@ -277,23 +304,74 @@ class OpHint(object): # pylint: enable=protected-access return identity_op - def __init__(self, function_name, **kwargs): + def __init__(self, + function_name, + level=1, + children_inputs_mappings=None, + **kwargs): """Create a OpHint. Args: function_name: Name of the function (the custom op name in tflite) + level: OpHint level. + children_inputs_mappings: Children OpHint inputs/outputs mapping. + children_inputs_mappings should like below: + "parent_first_child_input": + [{"parent_input_index": num, "child_input_index": num}, ...] + "parent_last_child_output": + [{"parent_output_index": num, "child_output_index": num}, ...] + "internal_children_input_output": + [{"child_input_index": num, "child_output_index": num}, ...] **kwargs: Keyword arguments of any constant attributes for the function. 
""" self._function_name = function_name + self._level = level + if self._level == 1: + assert children_inputs_mappings is None + else: + assert isinstance(children_inputs_mappings, dict) + self._children_inputs_mappings = children_inputs_mappings + if self._children_inputs_mappings is not None: + self._validate_children_inputs_mappings(self._children_inputs_mappings) self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough? self._attrs_to_store_later = kwargs self._stored_attrs = False self._inputs = OpHint.OpHintArgumentTracker( self._function_name, self._unique_function_id, "InputHint", - OpHint.FUNCTION_INPUT_INDEX_ATTR) + OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings) self._outputs = OpHint.OpHintArgumentTracker( self._function_name, self._unique_function_id, "OutputHint", - OpHint.FUNCTION_OUTPUT_INDEX_ATTR) + OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level, + self._children_inputs_mappings) + + def _validate_children_inputs_mappings(self, children_inputs_mappings): + """Validate children inputs mappings is in the right format. + + Args: + children_inputs_mappings: the Children ophint inputs/outputs mapping. + """ + assert isinstance(children_inputs_mappings, dict) + assert "parent_first_child_input" in children_inputs_mappings + assert "parent_last_child_output" in children_inputs_mappings + assert "internal_children_input_output" in children_inputs_mappings + + # validate parent_first_child_input. 
+ + def assert_dictlist_has_keys(dictlist, keys): + for dikt in dictlist: + assert isinstance(dikt, dict) + for key in keys: + assert key in dikt + + assert_dictlist_has_keys( + children_inputs_mappings["parent_first_child_input"], + ["parent_ophint_input_index", "first_child_ophint_input_index"]) + assert_dictlist_has_keys( + children_inputs_mappings["parent_last_child_output"], + ["parent_output_index", "child_output_index"]) + assert_dictlist_has_keys( + children_inputs_mappings["internal_children_input_output"], + ["child_input_index", "child_output_index"]) def _setattr(self, dest_op, name, value): tensor_value = _ops.convert_to_tensor(value) @@ -384,7 +462,7 @@ class OpHint(object): class _LiteOperand(object): - """Abstract operand for a tflite hint function. + """Abstract operand for a tflite hint function._dynamic_rnn_loop. This is a base class that handles representing arguments to an OpHint. It also is able to serialize operands to the stubbed graph_def. @@ -582,15 +660,18 @@ class _LiteFuncCall(object): This is uses to accumulate found hints in the graphdef into a single conceptual unit. - Properties: - self.inputs: inputs to the op (hash from index # to argument) - self.outputs: outputs to the op (hash from index # to argument) - self.function_name: the tflite custom op name to use - self.uuid: a unique call id for this particular call (i.e. + Attributes: + inputs: inputs to the op (hash from index # to argument) + outputs: outputs to the op (hash from index # to argument) + function_name: the tflite custom op name to use + uuid: a unique call id for this particular call (i.e. multiple function calls would have the same function_name but different uuids. - self.params: A param name to key value for op constant data. I.e. for + params: A param name to key value for op constant data. I.e. for axis on a reduction, strides on a convolution, etc. + level: Level of the OpHint. 
+ children_inputs_mappings: If the Ophint has children, children inputs + mappings indicate how their inputs & outputs are mapped. """ def __init__(self): @@ -599,6 +680,8 @@ class _LiteFuncCall(object): self.function_name = None self.uuid = None self.params = {} + self.level = -1 + self.children_inputs_mappings = {} def flattened_inputs_and_outputs(self): """Return a list of inputs and outputs in a flattened format. @@ -624,22 +707,25 @@ class _LiteFuncCall(object): inputs_str = "\tInputs\n" + format_args(self.inputs) outputs_str = "\tOutputs\n" + format_args(self.outputs) - return ("tflite function %s call %s\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" - % (self.function_name, self.uuid, inputs_str, outputs_str)) + return ( + "tflite function %s call %s level %d " + "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" % + (self.function_name, self.uuid, self.level, inputs_str, outputs_str)) -def _find_all_hints_in_graph_def(graphdef): - """Look at the current default graph and return a list of LiteFuncCall objs. +def _find_all_hints_in_nodes(nodes): + """Look at the all the input nodes and return a list of LiteFuncCall objs. Args: - graphdef: A TensorFlow graph_def to look for LiteFuncCalls. + nodes: A TensorFlow graph_def to look for LiteFuncCalls. 
+ Returns: a list of `LifeFuncCall` objects in the form """ func_calls = _collections.defaultdict(_LiteFuncCall) - for node in graphdef.node: + for node in nodes: attr = node.attr # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip uuid = attr[OpHint.FUNCTION_UUID_ATTR].s @@ -651,6 +737,7 @@ def _find_all_hints_in_graph_def(graphdef): call_def = func_calls[uuid] call_def.uuid = uuid call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s + call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i # Get sorting and aggregation information sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i @@ -660,6 +747,10 @@ def _find_all_hints_in_graph_def(graphdef): if OpHint.FUNCTION_AGGREGATE_ATTR in attr: aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s) + if OpHint.CHILDREN_INPUTS_MAPPINGS in attr: + call_def.children_inputs_mappings = _json.loads( + _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s)) + # Add the input or output def put_operand(stuff, index, sort, operand, aggregation): """Add a given index into the function structure.""" @@ -685,6 +776,98 @@ def _find_all_hints_in_graph_def(graphdef): return func_calls +def _extract_topology_sequence_mapping(nodes): + return dict( + (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes)) + + +def _find_children_hints_in_while_loop(function_def, nodes_mapping): + """Find children hints and all nodes inside the while loop. + + Args: + function_def: Function def of the while loop. + nodes_mapping: While loop input_arg : real node name. + + Returns: + Ordered children hints and all re-mapped nodes inside the while loop. + """ + new_nodes = [] + + # Make nodes inside function def inputs point to the real nodes. 
+ for node in function_def.node_def: + for i in range(len(node.input)): + if node.input[i] in nodes_mapping: + node.input[i] = nodes_mapping[node.input[i]] + new_nodes.append(_copy.deepcopy(node)) + name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def) + children_hints = _find_all_hints_in_nodes(new_nodes) + children_hints_q = [] + # Ordered by the outputs. + for hint in _six.itervalues(children_hints): + _, output_names = hint.flattened_inputs_and_outputs() + seq = name_to_seq_num[output_names[0]] + for output_name in output_names: + seq = min(seq, name_to_seq_num[output_name]) + children_hints_q.append((seq, hint)) + children_hints_q.sort(key=lambda tup: tup[0]) + ordered_children_hints = [x[1] for x in children_hints_q] + return ordered_children_hints, new_nodes + + +def _find_children_hints(call, graph_def): + """Find all children hints. + + For a given OpHint, we find all children hints inside it, we also copy all the + nodes inside function defs (if applicable) to the original graph_def, they are + returned in a list as well. + + Args: + call: Parent OpHint that contains children ophints. + graph_def: Original graph def. + + Returns: + Ordered children hints inside the parent ophint; new graph def that contains + nodes inside function defs (if applicable); nodes inside function defs. 
+ """ + name_to_input_name, _, _ = _extract_graph_summary(graph_def) + input_names, output_names = call.flattened_inputs_and_outputs() + + reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name) + reachable_by_output = _bfs_for_reachable_nodes(output_names, + name_to_input_name) + output_nodes_set = set(output_names) + children_hints = [] + out = _graph_pb2.GraphDef() + out.library.CopyFrom(graph_def.library) + out.versions.CopyFrom(graph_def.versions) + function_def_nodes = set() + for node in graph_def.node: + out.node.extend([_copy.deepcopy(node)]) + n = _tensor_name_base(node.name) + if n in reachable_by_output: + if n not in reachable_by_input and n not in output_nodes_set: + # special handle for while loop function def. + if node.op == "While": + body_name = node.attr["body"].func.name + inputs_outside_loop = node.input + for function_def in graph_def.library.function: + if function_def.signature.name == body_name: + function_inputs = function_def.signature.input_arg + assert len(inputs_outside_loop) == len(function_inputs) + nodes_mapping = {} + for i in range(len(function_inputs)): + nodes_mapping[function_inputs[i].name] = inputs_outside_loop[i] + # TODO(b/123050804): Consider use grappler. + (children_hints_in_loop, + new_nodes) = _find_children_hints_in_while_loop( + function_def, nodes_mapping) + function_def_nodes.update([x.name for x in new_nodes]) + children_hints.extend(children_hints_in_loop) + out.node.extend(new_nodes) + + return children_hints, out, function_def_nodes + + def _tensor_name_base(full_tensor_name): """Removes the device assignment code from a tensor. @@ -737,12 +920,20 @@ def _check_subgraph_closed(n, reachable_by_input, input_nodes_set, # TODO(aselle): This should be converted to grappler in the future. 
-def _convert_single_op_hint_to_stub(call, graph_def): +def _convert_single_op_hint_to_stub(call, + graph_def, + function_def_nodes=None, + is_last_run=True): """Given a graph_def, converts `call` into a stub and returns a new graph_def. Args: call: A single function call to be converted. - graph_def: A graph_def to use as input (that hass call obviously). + graph_def: A graph_def to use as input (that has call obviously). + function_def_nodes: Nodes inside the function def those are not connected to + the graph. + is_last_run: Whether it is the last run for a given pass (for OpHint has + children). + Returns: A new transformed graph-def that has call as a stub (single op). @@ -750,6 +941,8 @@ def _convert_single_op_hint_to_stub(call, graph_def): the tensorflow runtime, so all future manipulations are done in graph_def level. """ + if function_def_nodes is None: + function_def_nodes = set() name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( graph_def) input_names, output_names = call.flattened_inputs_and_outputs() @@ -757,7 +950,6 @@ def _convert_single_op_hint_to_stub(call, graph_def): reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name) reachable_by_output = _bfs_for_reachable_nodes(output_names, name_to_input_name) - input_nodes_set = set(input_names) output_nodes_set = set(output_names) nodes_after_fuse = [] nodes_deleted_by_fuse = set() @@ -768,19 +960,16 @@ def _convert_single_op_hint_to_stub(call, graph_def): n = _tensor_name_base(node.name) if n in reachable_by_output: if n not in reachable_by_input and n not in output_nodes_set: - # n is an internal node. Check to make sure it is really internal. - # TODO(aselle): this could be done more efficiently by flooding - # the graph first. 
- _check_subgraph_closed(n, reachable_by_input, input_nodes_set, - name_to_input_name) nodes_deleted_by_fuse.add(n) - elif n not in reachable_by_input: + elif n not in reachable_by_input and n not in function_def_nodes: # n is a node that after all the fusings, so keep it. nodes_after_fuse.append(n) else: - # n is a node that is randomly in the graph but not connected to - # the chain of dependencies. - pass + # In the last run, n is a node that is randomly in the graph but not + # connected to the chain of dependencies, we will delete n, otherwise + # we keep them. + if not is_last_run: + nodes_after_fuse.append(n) # Make a new graphdef with all the pre-input and input nodes out = _graph_pb2.GraphDef() @@ -802,10 +991,11 @@ def _convert_single_op_hint_to_stub(call, graph_def): # non-fused things. for input_index in sorted_input_indices: inputs = call.inputs[input_index] - new_node.input.append(inputs.aggregate_and_return_name_for_input(out)) + input_name = inputs.aggregate_and_return_name_for_input(out) + new_node.input.append(input_name) new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices) - # Ceate the function + # Create the function new_node.op = call.function_name new_node.name = call.uuid out.node.extend([new_node]) @@ -938,7 +1128,18 @@ def _remove_redundant_stack_unstack(graph_def): return curr -@_tf_export("lite.convert_op_hints_to_stubs") +def _get_correct_mapping(original_index, nodes): + # Special handle for the index is -1 case. + # If it is -1, return the last index. + if original_index == -1: + node_indices = nodes.keys() + node_indices = sorted(node_indices) + return node_indices[-1] + else: + return original_index + return original_index + + def _convert_op_hints_to_stubs_helper( graph_def, write_callback=lambda sess, graph_def: None): """Converts a graph_def to a new graph_def where all op hints are stubbed. @@ -950,14 +1151,67 @@ def _convert_op_hints_to_stubs_helper( Returns: A new stubbed graph_def. 
""" + hints = _find_all_hints_in_nodes(graph_def.node) + + hints_q = [] + for hint in _six.itervalues(hints): + hints_q.append((hint.level, hint.uuid)) + + hints_q.sort(key=lambda tup: tup[0]) + for i in range(len(hints_q) - 1, -1, -1): + level, hint_uuid = hints_q[i] - hints = _find_all_hints_in_graph_def(graph_def) curr_graph_def = graph_def del graph_def # prevent using graph_def again (common source of error) - for hint in _six.itervalues(hints): - curr_graph_def = _convert_single_op_hint_to_stub( - hint, curr_graph_def) - write_callback(curr_graph_def, "initial") + for i in range(len(hints_q) - 1, -1, -1): + level, hint_uuid = hints_q[i] + if level >= 2: + children_hints, curr_graph_def, function_def_nodes = _find_children_hints( + hints[hint_uuid], curr_graph_def) + # pylint: disable=superfluous-parens + assert (len(children_hints) > 0) # pylint: disable=g-explicit-length-test + # pylint: enable=superfluous-parens + + # Re-wire the children hints inputs/outputs, so latter child's inputs + # connect to previous child node's outputs. 
+ children_inputs_mappings = hints[hint_uuid].children_inputs_mappings + for j in range(len(children_hints)): + child_hint = children_hints[j] + if j == 0: + for mapping in children_inputs_mappings["parent_first_child_input"]: + parent_input_index = _get_correct_mapping( + mapping["parent_ophint_input_index"], hints[hint_uuid].inputs) + child_input_index = _get_correct_mapping( + mapping["first_child_ophint_input_index"], child_hint.inputs) + child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[ + parent_input_index] + else: + for mapping in children_inputs_mappings[ + "internal_children_input_output"]: + input_index = _get_correct_mapping(mapping["child_input_index"], + child_hint.inputs) + output_index = _get_correct_mapping(mapping["child_output_index"], + children_hints[j - 1].outputs) + child_hint.inputs[input_index] = children_hints[ + j - 1].outputs[output_index] + if j == len(children_hints) - 1: + for mapping in children_inputs_mappings["parent_last_child_output"]: + parent_output_index = _get_correct_mapping( + mapping["parent_output_index"], hints[hint_uuid].outputs) + child_output_index = _get_correct_mapping( + mapping["child_output_index"], child_hint.outputs) + child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[ + parent_output_index] + + for j in range(len(children_hints)): + child_hint = children_hints[j] + curr_graph_def = _convert_single_op_hint_to_stub( + child_hint, curr_graph_def, function_def_nodes, + j == len(children_hints) - 1) + else: + curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid], + curr_graph_def) + write_callback(curr_graph_def, "initial") # The stubbing process can create stacks/unstacks in the case of LSTMs # remove them. 
curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def) @@ -984,15 +1238,16 @@ def find_all_hinted_output_nodes(session=None, graph_def=None): raise ValueError("Provide only one of session and graph_def.") hinted_outputs_nodes = [] if session is not None: - hints = _find_all_hints_in_graph_def(session.graph_def) + hints = _find_all_hints_in_nodes(session.graph_def.node) elif graph_def is not None: - hints = _find_all_hints_in_graph_def(graph_def) + hints = _find_all_hints_in_nodes(graph_def.node) for hint in _six.itervalues(hints): _, ouput_nodes = hint.flattened_inputs_and_outputs() hinted_outputs_nodes.extend(ouput_nodes) return hinted_outputs_nodes +@_tf_export("lite.experimental.convert_op_hints_to_stubs") def convert_op_hints_to_stubs(session=None, graph_def=None, write_callback=lambda graph_def, comments: None): diff --git a/tensorflow/lite/python/optimize/BUILD b/tensorflow/lite/python/optimize/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..069612ba9f4cf00b445e40d459adb5244d3fc218 --- /dev/null +++ b/tensorflow/lite/python/optimize/BUILD @@ -0,0 +1,70 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") + +cc_library( + name = "calibration_wrapper_lib", + srcs = ["calibration_wrapper.cc"], + hdrs = ["calibration_wrapper.h"], + deps = [ + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/python/interpreter_wrapper:numpy", + "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", + "//tensorflow/lite/python/interpreter_wrapper:python_utils", + "//tensorflow/lite/tools/optimize:quantize_model", + "//tensorflow/lite/tools/optimize/calibration:calibration_reader", + "//tensorflow/lite/tools/optimize/calibration:calibrator_lib", + "//third_party/py/numpy:headers", + "//third_party/python_runtime:headers", + "@com_google_absl//absl/memory", + ], +) + +tf_py_wrap_cc( + 
name = "tensorflow_lite_wrap_calibration_wrapper", + srcs = [ + "calibration_wrapper.i", + ], + deps = [ + ":calibration_wrapper_lib", + "//third_party/python_runtime:headers", + ], +) + +py_library( + name = "calibrator", + srcs = [ + "calibrator.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/lite/python/optimize:tensorflow_lite_wrap_calibration_wrapper", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_test( + name = "calibrator_test", + srcs = ["calibrator_test.py"], + data = [ + ":test_data", + "//tensorflow/lite:testdata/multi_add.bin", + ], + srcs_version = "PY2AND3", + tags = ["no_oss"], + deps = [ + ":calibrator", + "//tensorflow/lite/python/optimize:tensorflow_lite_wrap_calibration_wrapper", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..12bcd6a6283ccc71f4df7758b46aec298a87bb7d --- /dev/null +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -0,0 +1,212 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/python/optimize/calibration_wrapper.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/python/interpreter_wrapper/numpy.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" +#include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h" +#include "tensorflow/lite/tools/optimize/calibration/calibrator.h" +#include "tensorflow/lite/tools/optimize/quantize_model.h" + +#define TFLITE_PY_CHECK(x) \ + if ((x) != kTfLiteOk) { \ + return error_reporter_->exception(); \ + } + +#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \ + if (!interpreter_) { \ + PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \ + return nullptr; \ + } + +namespace tflite { +namespace calibration_wrapper { + +namespace { + +using python_utils::PyDecrefDeleter; + +std::unique_ptr CreateMutableModel(const tflite::Model& model) { + std::unique_ptr copied_model = + absl::make_unique(); + model.UnPackTo(copied_model.get(), nullptr); + return copied_model; +} + +} // namespace + +CalibrationWrapper::CalibrationWrapper( + std::unique_ptr interpreter, + std::unique_ptr resolver, + std::unique_ptr + error_reporter, + std::unique_ptr model, + std::unique_ptr reader) + : interpreter_(std::move(interpreter)), + error_reporter_(std::move(error_reporter)), + resolver_(std::move(resolver)), + model_(std::move(model)), + reader_(std::move(reader)) {} + +CalibrationWrapper::~CalibrationWrapper() {} + +PyObject* CalibrationWrapper::Prepare() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->AllocateTensors()); + TFLITE_PY_CHECK(interpreter_->ResetVariableTensors()); + Py_RETURN_NONE; +} + 
+PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + if (!PyList_Check(input_value)) { + PyErr_Format(PyExc_ValueError, + "Invalid input type: expected input to be a list."); + return nullptr; + } + + const size_t inputs_size = PyList_Size(input_value); + + if (inputs_size != interpreter_->inputs().size()) { + PyErr_Format(PyExc_ValueError, + "Invalid input size: expected %ld items got %ld items.", + interpreter_->inputs().size(), inputs_size); + return nullptr; + } + + for (size_t i = 0; i < inputs_size; i++) { + PyObject* input = PyList_GetItem(input_value, i); + if (!input) { + return nullptr; + } + int input_tensor_idx = interpreter_->inputs()[i]; + if (!SetTensor(input_tensor_idx, input)) { + return nullptr; + } + } + + TFLITE_PY_CHECK(interpreter_->Invoke()); + Py_RETURN_NONE; +} + +PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + + std::unique_ptr array_safe( + PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); + if (!array_safe) { + PyErr_SetString(PyExc_ValueError, + "Failed to convert value into readable tensor."); + return nullptr; + } + + PyArrayObject* array = reinterpret_cast(array_safe.get()); + const TfLiteTensor* tensor = interpreter_->tensor(index); + + if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) { + PyErr_Format(PyExc_ValueError, + "Cannot set tensor:" + " Got tensor of type %d" + " but expected type %d for input %d, name: %s ", + python_utils::TfLiteTypeFromPyArray(array), tensor->type, + index, tensor->name); + return nullptr; + } + + if (PyArray_NDIM(array) != tensor->dims->size) { + PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch"); + return nullptr; + } + + for (int j = 0; j < PyArray_NDIM(array); j++) { + if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) { + PyErr_SetString(PyExc_ValueError, + "Cannot set tensor: Dimension mismatch"); + return nullptr; + } + } 
+ + size_t size = PyArray_NBYTES(array); + if (size != tensor->bytes) { + PyErr_Format(PyExc_ValueError, + "numpy array had %zu bytes but expected %zu bytes.", size, + tensor->bytes); + return nullptr; + } + memcpy(tensor->data.raw, PyArray_DATA(array), size); + Py_RETURN_NONE; +} + +PyObject* CalibrationWrapper::QuantizeModel() { + auto tflite_model = CreateMutableModel(*model_->GetModel()); + reader_->AddCalibrationToModel(tflite_model.get()); + flatbuffers::FlatBufferBuilder builder; + auto status = tflite::optimize::QuantizeModel(&builder, tflite_model.get(), + error_reporter_.get()); + if (status != kTfLiteOk) { + error_reporter_->exception(); + return nullptr; + } + + return python_utils::ConvertToPyString( + reinterpret_cast(builder.GetCurrentBufferPointer()), + builder.GetSize()); +} + +/*static*/ CalibrationWrapper* CalibrationWrapper::CreateWrapperCPPFromBuffer( + PyObject* data) { + using tflite::interpreter_wrapper::PythonErrorReporter; + char* buf = nullptr; + Py_ssize_t length; + std::unique_ptr error_reporter(new PythonErrorReporter); + ::tflite::python::ImportNumpy(); + + if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) { + return nullptr; + } + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromBuffer(buf, length, + error_reporter.get()); + if (!model) { + PyErr_Format(PyExc_ValueError, "Invalid model"); + return nullptr; + } + auto resolver = absl::make_unique(); + std::unique_ptr interpreter; + std::unique_ptr reader; + auto status = tflite::optimize::calibration::BuildLoggingInterpreter( + *model, *resolver, &interpreter, &reader); + if (status != kTfLiteOk) { + error_reporter->exception(); + return nullptr; + } + + auto wrapper = new CalibrationWrapper( + std::move(interpreter), std::move(resolver), std::move(error_reporter), + std::move(model), std::move(reader)); + return wrapper; +} + +} // namespace calibration_wrapper +} // namespace tflite diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.h 
b/tensorflow/lite/python/optimize/calibration_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..213bc4a182d348e5a19b5c2624cca375d367aba7 --- /dev/null +++ b/tensorflow/lite/python/optimize/calibration_wrapper.h @@ -0,0 +1,90 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_PYTHON_OPTIMIZE_CALIBRATION_WRAPPER_H_ +#define TENSORFLOW_LITE_PYTHON_OPTIMIZE_CALIBRATION_WRAPPER_H_ + +#include +#include +#include + +// Place `` before to avoid build failures in macOS. +#include + +// The empty line above is on purpose as otherwise clang-format will +// automatically move before . +#include + +// We forward declare TFLite classes here to avoid exposing them to SWIG. +namespace tflite { +namespace ops { +namespace builtin { +class BuiltinOpResolver; +} // namespace builtin +} // namespace ops + +class FlatBufferModel; +class Interpreter; + +namespace interpreter_wrapper { +class PythonErrorReporter; +} // namespace interpreter_wrapper + +namespace optimize { +namespace calibration { +class CalibrationReader; +} // namespace calibration +} // namespace optimize + +namespace calibration_wrapper { + +class CalibrationWrapper { + public: + // SWIG caller takes ownership of pointer. 
+ static CalibrationWrapper* CreateWrapperCPPFromBuffer(PyObject* data); + ~CalibrationWrapper(); + + PyObject* Prepare(); + + PyObject* FeedTensor(PyObject* input_value); + + PyObject* QuantizeModel(); + + private: + // CalibrationWrapper is not copyable or assignable. We avoid the use of + // CalibrationWrapper() = delete here for SWIG compatibility. + CalibrationWrapper( + std::unique_ptr interpreter, + std::unique_ptr resolver, + std::unique_ptr + error_reporter, + std::unique_ptr model, + std::unique_ptr reader); + + CalibrationWrapper(const CalibrationWrapper& rhs); + + PyObject* SetTensor(int index, PyObject* value); + + std::unique_ptr interpreter_; + std::unique_ptr + error_reporter_; + std::unique_ptr resolver_; + std::unique_ptr model_; + std::unique_ptr reader_; +}; + +} // namespace calibration_wrapper +} // namespace tflite + +#endif // TENSORFLOW_LITE_PYTHON_OPTIMIZE_CALIBRATION_WRAPPER_H_ diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.i b/tensorflow/lite/python/optimize/calibration_wrapper.i new file mode 100644 index 0000000000000000000000000000000000000000..094ac20733abc3797d5d325b838215c2909045ba --- /dev/null +++ b/tensorflow/lite/python/optimize/calibration_wrapper.i @@ -0,0 +1,27 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +%include "std_string.i" + + +%{ +#define SWIG_FILE_WITH_INIT +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/python/optimize/calibration_wrapper.h" +%} + + +%include "tensorflow/lite/python/optimize/calibration_wrapper.h" \ No newline at end of file diff --git a/tensorflow/lite/python/optimize/calibrator.py b/tensorflow/lite/python/optimize/calibrator.py new file mode 100644 index 0000000000000000000000000000000000000000..940987932e5469065e01e86e18a35be81990ff5a --- /dev/null +++ b/tensorflow/lite/python/optimize/calibrator.py @@ -0,0 +1,68 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrapper for post training quantization with calibration.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util.lazy_loader import LazyLoader + +# Lazy load since some of the performance benchmark skylark rules +# break dependencies. Must use double quotes to match code internal rewrite +# rule. +_calibration_wrapper = LazyLoader( + "_calibration_wrapper", globals(), + "tensorflow.lite.python.optimize." 
+ "tensorflow_lite_wrap_calibration_wrapper") + + +class Calibrator(object): + """Calibrates a floating point model and then quantizes it. + + This is an internal class, not a public interface. + """ + + def __init__(self, model_content): + """Constructor. + + Args: + model_content: Content of a TF-Lite Flatbuffer file. + + Raises: + ValueError: If the calibrator was unable to open the model. + """ + if not model_content: + raise ValueError("`model_content` must be specified.") + try: + self._calibrator = (_calibration_wrapper.CalibrationWrapper + .CreateWrapperCPPFromBuffer(model_content)) + except Exception as e: + raise ValueError("Failed to parse the model: %s." % e) + if not self._calibrator: + raise ValueError("Failed to parse the model.") + + def calibrate_and_quantize(self, dataset_gen): + """Calibrates the model with specified generator and then quantizes it. + + Returns: + A quantized model. + + Args: + dataset_gen: A generator that generates calibration samples. + """ + self._calibrator.Prepare() + for calibration_sample in dataset_gen(): + self._calibrator.FeedTensor(calibration_sample) + return self._calibrator.QuantizeModel() diff --git a/tensorflow/lite/python/optimize/calibrator_test.py b/tensorflow/lite/python/optimize/calibrator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..25e0c1efbff3023b0386c53b8eb612bb89c2f19b --- /dev/null +++ b/tensorflow/lite/python/optimize/calibrator_test.py @@ -0,0 +1,109 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Calibrator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np + +from tensorflow.lite.python.optimize import calibrator as _calibrator +from tensorflow.python.framework import test_util +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test + + +class CalibratorTest(test_util.TensorFlowTestCase): + + def test_calibration_with_quantization(self): + model_path = resource_loader.get_path_to_datafile( + 'test_data/mobilenet_like_model.bin') + float_model = open(model_path, 'rb').read() + quantizer = _calibrator.Calibrator(float_model) + + # Input generator for the model. + def input_gen(): + for _ in range(10): + yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)] + + quantized_model = quantizer.calibrate_and_quantize(input_gen) + self.assertIsNotNone(quantized_model) + + def test_calibration_with_quantization_multiple_inputs(self): + # Load multi add model from test data. + # This model has 4 inputs of size (1, 8, 8, 3). + model_path = resource_loader.get_path_to_datafile( + '../../testdata/multi_add.bin') + float_model = open(model_path, 'rb').read() + quantizer = _calibrator.Calibrator(float_model) + + # Input generator for the model. 
+ def input_gen(): + for _ in range(10): + yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)] + + quantized_model = quantizer.calibrate_and_quantize(input_gen) + self.assertIsNotNone(quantized_model) + + def test_invalid_model_buffer(self): + float_model = b'\0' * 100 + with self.assertRaisesWithRegexpMatch(ValueError, + 'Failed to parse the model'): + _calibrator.Calibrator(float_model) + + def test_empty_calibrator_gen(self): + model_path = resource_loader.get_path_to_datafile( + 'test_data/mobilenet_like_model.bin') + float_model = open(model_path, 'rb').read() + quantizer = _calibrator.Calibrator(float_model) + + def empty_input_gen(): + for i in (): + yield i + + with self.assertRaises(RuntimeError): + quantizer.calibrate_and_quantize(empty_input_gen) + + def test_invalid_shape_calibrator_gen(self): + model_path = resource_loader.get_path_to_datafile( + 'test_data/mobilenet_like_model.bin') + float_model = open(model_path, 'rb').read() + quantizer = _calibrator.Calibrator(float_model) + + # Input generator with incorrect shape. + def input_gen(): + for _ in range(10): + yield [np.ones(shape=(1, 2, 2, 3), dtype=np.float32)] + + with self.assertRaisesWithRegexpMatch(ValueError, 'Dimension mismatch'): + quantizer.calibrate_and_quantize(input_gen) + + def test_invalid_type_calibrator_gen(self): + model_path = resource_loader.get_path_to_datafile( + 'test_data/mobilenet_like_model.bin') + float_model = open(model_path, 'rb').read() + quantizer = _calibrator.Calibrator(float_model) + + # Input generator with incorrect shape. 
+ def input_gen(): + for _ in range(10): + yield np.ones(shape=(1, 5, 5, 3), dtype=np.int32) + + with self.assertRaises(ValueError): + quantizer.calibrate_and_quantize(input_gen) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/lite/python/optimize/test_data/mobilenet_like_model.bin b/tensorflow/lite/python/optimize/test_data/mobilenet_like_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..a2909249ffd2675fad9c0cd60a6ff75f940b3fb0 Binary files /dev/null and b/tensorflow/lite/python/optimize/test_data/mobilenet_like_model.bin differ diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2fa08e5326990ecda1857fec8eb9caadac1f4102 --- /dev/null +++ b/tensorflow/lite/python/testdata/BUILD @@ -0,0 +1,53 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow/lite:build_def.bzl", "tf_to_tflite") + +exports_files(glob(["*.pb"])) + +tf_to_tflite( + name = "permute_float", + src = "permute.pbtxt", + out = "permute_float.tflite", + options = [ + "--input_arrays=input", + "--output_arrays=output", + ], +) + +tf_to_tflite( + name = "permute_uint8", + src = "permute.pbtxt", + out = "permute_uint8.tflite", + options = [ + "--input_arrays=input", + "--output_arrays=output", + "--inference_type=QUANTIZED_UINT8", + "--std_values=1", + "--mean_values=0", + "--default_ranges_min=0", + "--default_ranges_max=255", + ], +) + +tf_to_tflite( + name = "gather_string", + src = "gather.pbtxt", + out = "gather_string.tflite", + options = [ + "--input_arrays=input,indices", + "--output_arrays=output", + ], +) + +filegroup( + name = "interpreter_test_data", + srcs = [ + ":gather_string", + ":permute_float", + ":permute_uint8", + ], + visibility = ["//tensorflow:__subpackages__"], +) diff --git 
a/tensorflow/lite/python/testdata/gather.pbtxt b/tensorflow/lite/python/testdata/gather.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0b1193c475d3b4b663accf036753bfbe9d8adb7d --- /dev/null +++ b/tensorflow/lite/python/testdata/gather.pbtxt @@ -0,0 +1,93 @@ +node { + name: "input" + op: "Placeholder" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } +} +node { + name: "indices" + op: "Placeholder" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + } + } + } +} +node { + name: "axis" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "output" + op: "GatherV2" + input: "input" + input: "indices" + input: "axis" + device: "/device:CPU:0" + attr { + key: "Taxis" + value { + type: DT_INT32 + } + } + attr { + key: "Tindices" + value { + type: DT_INT64 + } + } + attr { + key: "Tparams" + value { + type: DT_STRING + } + } +} +versions { + producer: 27 +} diff --git a/tensorflow/lite/python/testdata/permute.pbtxt b/tensorflow/lite/python/testdata/permute.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..482b5c64828f4f5ef2057b4552a81425485d0841 --- /dev/null +++ b/tensorflow/lite/python/testdata/permute.pbtxt @@ -0,0 +1,98 @@ +node { + name: "input" + op: "Placeholder" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 4 + } + } + } + } +} +node { + name: "Const" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + 
tensor_shape { + dim { + size: 4 + } + dim { + size: 4 + } + } + float_val: 0.0 + float_val: 0.0 + float_val: 0.0 + float_val: 1.0 + + float_val: 0.0 + float_val: 0.0 + float_val: 1.0 + float_val: 0.0 + + float_val: 0.0 + float_val: 1.0 + float_val: 0.0 + float_val: 0.0 + + float_val: 1.0 + float_val: 0.0 + float_val: 0.0 + float_val: 0.0 + } + } + } +} +node { + name: "output" + op: "MatMul" + input: "input" + input: "Const" + device: "/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +versions { + producer: 27 +} diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py index 401a592273c9c76f1f371bb8972f7f9a3d494278..ad4b04c4f3e3aefee69085856f62868ba667e5e7 100644 --- a/tensorflow/lite/python/tflite_convert.py +++ b/tensorflow/lite/python/tflite_convert.py @@ -25,6 +25,7 @@ import sys from tensorflow.lite.python import lite from tensorflow.lite.python import lite_constants from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.python import tf2 from tensorflow.python.platform import app @@ -258,6 +259,11 @@ def _check_flags(flags, unparsed): def run_main(_): """Main in toco_convert.py.""" + if tf2.enabled(): + raise ValueError("tflite_convert is currently unsupported in 2.0. 
" + "Please use the Python API " + "tf.lite.TFLiteConverter.from_concrete_function().") + parser = argparse.ArgumentParser( description=("Command line tool to run TensorFlow Lite Optimizing " "Converter (TOCO).")) diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD index ea516764c929080bc42e48a7cfcdd171f2d6cc57..e55419186e16f62f27f9df0201e814cb8936fc27 100644 --- a/tensorflow/lite/schema/BUILD +++ b/tensorflow/lite/schema/BUILD @@ -9,6 +9,12 @@ load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") py_binary( name = "upgrade_schema", + srcs = ["upgrade_schema.py"], + deps = [":upgrade_schema_main_lib"], +) + +py_library( + name = "upgrade_schema_main_lib", srcs = [ "upgrade_schema.py", ], @@ -39,7 +45,7 @@ py_test( "notap", ], deps = [ - ":upgrade_schema", + ":upgrade_schema_main_lib", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 13f47d2cfcc8bae23ffc34183adde5e8770cc8e7..c6c61a602a8191e2d9c611338c1c5cf4b7cd814d 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -217,6 +217,15 @@ enum BuiltinOperator : byte { ABS = 101, SPLIT_V = 102, UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, } // Options for the builtin operators. 
@@ -301,6 +310,13 @@ union BuiltinOptions { AbsOptions, SplitVOptions, UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, } enum Padding : byte { SAME, VALID } @@ -544,6 +560,9 @@ table TransposeOptions { table ExpOptions { } +table CosOptions { +} + table ReducerOptions { keep_dims: bool; } @@ -638,6 +657,9 @@ table ShapeOptions { out_type : TensorType; } +table RankOptions { +} + table PowOptions { } @@ -718,6 +740,22 @@ table UniqueOptions { idx_out_type:TensorType = INT32; } +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index a1174e79f722dc2e6dca59a5b2071b8d461378f1..2a55698a616a52af50e19f9ae0ef14e70e14930e 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -139,6 +139,9 @@ struct TransposeOptionsT; struct ExpOptions; struct ExpOptionsT; +struct CosOptions; +struct CosOptionsT; + struct ReducerOptions; struct ReducerOptionsT; @@ -214,6 +217,9 @@ struct NotEqualOptionsT; struct ShapeOptions; struct ShapeOptionsT; +struct RankOptions; +struct RankOptionsT; + struct PowOptions; struct PowOptionsT; @@ -271,6 +277,21 @@ struct MirrorPadOptionsT; struct UniqueOptions; struct UniqueOptionsT; +struct ReverseV2Options; +struct ReverseV2OptionsT; + +struct AddNOptions; +struct AddNOptionsT; + +struct GatherNdOptions; +struct GatherNdOptionsT; + +struct WhereOptions; +struct WhereOptionsT; + +struct ReverseSequenceOptions; +struct ReverseSequenceOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -524,11 +545,20 @@ enum BuiltinOperator { 
BuiltinOperator_ABS = 101, BuiltinOperator_SPLIT_V = 102, BuiltinOperator_UNIQUE = 103, + BuiltinOperator_CEIL = 104, + BuiltinOperator_REVERSE_V2 = 105, + BuiltinOperator_ADD_N = 106, + BuiltinOperator_GATHER_ND = 107, + BuiltinOperator_COS = 108, + BuiltinOperator_WHERE = 109, + BuiltinOperator_RANK = 110, + BuiltinOperator_ELU = 111, + BuiltinOperator_REVERSE_SEQUENCE = 112, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_UNIQUE + BuiltinOperator_MAX = BuiltinOperator_REVERSE_SEQUENCE }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[103] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[112] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -632,7 +662,16 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[103] { BuiltinOperator_MIRROR_PAD, BuiltinOperator_ABS, BuiltinOperator_SPLIT_V, - BuiltinOperator_UNIQUE + BuiltinOperator_UNIQUE, + BuiltinOperator_CEIL, + BuiltinOperator_REVERSE_V2, + BuiltinOperator_ADD_N, + BuiltinOperator_GATHER_ND, + BuiltinOperator_COS, + BuiltinOperator_WHERE, + BuiltinOperator_RANK, + BuiltinOperator_ELU, + BuiltinOperator_REVERSE_SEQUENCE }; return values; } @@ -743,6 +782,15 @@ inline const char * const *EnumNamesBuiltinOperator() { "ABS", "SPLIT_V", "UNIQUE", + "CEIL", + "REVERSE_V2", + "ADD_N", + "GATHER_ND", + "COS", + "WHERE", + "RANK", + "ELU", + "REVERSE_SEQUENCE", nullptr }; return names; @@ -835,11 +883,18 @@ enum BuiltinOptions { BuiltinOptions_AbsOptions = 78, BuiltinOptions_SplitVOptions = 79, BuiltinOptions_UniqueOptions = 80, + BuiltinOptions_ReverseV2Options = 81, + BuiltinOptions_AddNOptions = 82, + BuiltinOptions_GatherNdOptions = 83, + BuiltinOptions_CosOptions = 84, + BuiltinOptions_WhereOptions = 85, + BuiltinOptions_RankOptions = 86, + BuiltinOptions_ReverseSequenceOptions = 87, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_UniqueOptions + 
BuiltinOptions_MAX = BuiltinOptions_ReverseSequenceOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[81] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[88] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -921,7 +976,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[81] { BuiltinOptions_MirrorPadOptions, BuiltinOptions_AbsOptions, BuiltinOptions_SplitVOptions, - BuiltinOptions_UniqueOptions + BuiltinOptions_UniqueOptions, + BuiltinOptions_ReverseV2Options, + BuiltinOptions_AddNOptions, + BuiltinOptions_GatherNdOptions, + BuiltinOptions_CosOptions, + BuiltinOptions_WhereOptions, + BuiltinOptions_RankOptions, + BuiltinOptions_ReverseSequenceOptions }; return values; } @@ -1009,6 +1071,13 @@ inline const char * const *EnumNamesBuiltinOptions() { "AbsOptions", "SplitVOptions", "UniqueOptions", + "ReverseV2Options", + "AddNOptions", + "GatherNdOptions", + "CosOptions", + "WhereOptions", + "RankOptions", + "ReverseSequenceOptions", nullptr }; return names; @@ -1343,6 +1412,34 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_AddNOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CosOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RankOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions 
enum_value = BuiltinOptions_ReverseSequenceOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2014,6 +2111,62 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_UniqueOptions ? reinterpret_cast(value) : nullptr; } + ReverseV2OptionsT *AsReverseV2Options() { + return type == BuiltinOptions_ReverseV2Options ? + reinterpret_cast(value) : nullptr; + } + const ReverseV2OptionsT *AsReverseV2Options() const { + return type == BuiltinOptions_ReverseV2Options ? + reinterpret_cast(value) : nullptr; + } + AddNOptionsT *AsAddNOptions() { + return type == BuiltinOptions_AddNOptions ? + reinterpret_cast(value) : nullptr; + } + const AddNOptionsT *AsAddNOptions() const { + return type == BuiltinOptions_AddNOptions ? + reinterpret_cast(value) : nullptr; + } + GatherNdOptionsT *AsGatherNdOptions() { + return type == BuiltinOptions_GatherNdOptions ? + reinterpret_cast(value) : nullptr; + } + const GatherNdOptionsT *AsGatherNdOptions() const { + return type == BuiltinOptions_GatherNdOptions ? + reinterpret_cast(value) : nullptr; + } + CosOptionsT *AsCosOptions() { + return type == BuiltinOptions_CosOptions ? + reinterpret_cast(value) : nullptr; + } + const CosOptionsT *AsCosOptions() const { + return type == BuiltinOptions_CosOptions ? + reinterpret_cast(value) : nullptr; + } + WhereOptionsT *AsWhereOptions() { + return type == BuiltinOptions_WhereOptions ? + reinterpret_cast(value) : nullptr; + } + const WhereOptionsT *AsWhereOptions() const { + return type == BuiltinOptions_WhereOptions ? + reinterpret_cast(value) : nullptr; + } + RankOptionsT *AsRankOptions() { + return type == BuiltinOptions_RankOptions ? + reinterpret_cast(value) : nullptr; + } + const RankOptionsT *AsRankOptions() const { + return type == BuiltinOptions_RankOptions ? + reinterpret_cast(value) : nullptr; + } + ReverseSequenceOptionsT *AsReverseSequenceOptions() { + return type == BuiltinOptions_ReverseSequenceOptions ? 
+ reinterpret_cast(value) : nullptr; + } + const ReverseSequenceOptionsT *AsReverseSequenceOptions() const { + return type == BuiltinOptions_ReverseSequenceOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -4946,6 +5099,46 @@ inline flatbuffers::Offset CreateExpOptions( flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct CosOptionsT : public flatbuffers::NativeTable { + typedef CosOptions TableType; + CosOptionsT() { + } +}; + +struct CosOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CosOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + CosOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CosOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CosOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit CosOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CosOptionsBuilder &operator=(const CosOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateCosOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + CosOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateCosOptions(flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct ReducerOptionsT : 
public flatbuffers::NativeTable { typedef ReducerOptions TableType; bool keep_dims; @@ -6192,6 +6385,46 @@ inline flatbuffers::Offset CreateShapeOptions( flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct RankOptionsT : public flatbuffers::NativeTable { + typedef RankOptions TableType; + RankOptionsT() { + } +}; + +struct RankOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef RankOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + RankOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RankOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RankOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit RankOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + RankOptionsBuilder &operator=(const RankOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateRankOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + RankOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateRankOptions(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct PowOptionsT : public flatbuffers::NativeTable { typedef PowOptions TableType; PowOptionsT() { @@ -7110,6 +7343,232 @@ inline flatbuffers::Offset CreateUniqueOptions( flatbuffers::Offset 
CreateUniqueOptions(flatbuffers::FlatBufferBuilder &_fbb, const UniqueOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ReverseV2OptionsT : public flatbuffers::NativeTable { + typedef ReverseV2Options TableType; + ReverseV2OptionsT() { + } +}; + +struct ReverseV2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReverseV2OptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ReverseV2OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReverseV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReverseV2OptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit ReverseV2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ReverseV2OptionsBuilder &operator=(const ReverseV2OptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateReverseV2Options( + flatbuffers::FlatBufferBuilder &_fbb) { + ReverseV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateReverseV2Options(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct AddNOptionsT : public flatbuffers::NativeTable { + typedef AddNOptions TableType; + AddNOptionsT() { + } +}; + +struct AddNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef AddNOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return 
VerifyTableStart(verifier) && + verifier.EndTable(); + } + AddNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(AddNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct AddNOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit AddNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + AddNOptionsBuilder &operator=(const AddNOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateAddNOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + AddNOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateAddNOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct GatherNdOptionsT : public flatbuffers::NativeTable { + typedef GatherNdOptions TableType; + GatherNdOptionsT() { + } +}; + +struct GatherNdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef GatherNdOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + GatherNdOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GatherNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GatherNdOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t 
start_; + explicit GatherNdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + GatherNdOptionsBuilder &operator=(const GatherNdOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateGatherNdOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + GatherNdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateGatherNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct WhereOptionsT : public flatbuffers::NativeTable { + typedef WhereOptions TableType; + WhereOptionsT() { + } +}; + +struct WhereOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef WhereOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + WhereOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(WhereOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const WhereOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct WhereOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit WhereOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + WhereOptionsBuilder &operator=(const WhereOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateWhereOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + WhereOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset 
CreateWhereOptions(flatbuffers::FlatBufferBuilder &_fbb, const WhereOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ReverseSequenceOptionsT : public flatbuffers::NativeTable { + typedef ReverseSequenceOptions TableType; + int32_t seq_dim; + int32_t batch_dim; + ReverseSequenceOptionsT() + : seq_dim(0), + batch_dim(0) { + } +}; + +struct ReverseSequenceOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReverseSequenceOptionsT NativeTableType; + enum { + VT_SEQ_DIM = 4, + VT_BATCH_DIM = 6 + }; + int32_t seq_dim() const { + return GetField(VT_SEQ_DIM, 0); + } + int32_t batch_dim() const { + return GetField(VT_BATCH_DIM, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_SEQ_DIM) && + VerifyField(verifier, VT_BATCH_DIM) && + verifier.EndTable(); + } + ReverseSequenceOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReverseSequenceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReverseSequenceOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_seq_dim(int32_t seq_dim) { + fbb_.AddElement(ReverseSequenceOptions::VT_SEQ_DIM, seq_dim, 0); + } + void add_batch_dim(int32_t batch_dim) { + fbb_.AddElement(ReverseSequenceOptions::VT_BATCH_DIM, batch_dim, 0); + } + explicit ReverseSequenceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ReverseSequenceOptionsBuilder &operator=(const ReverseSequenceOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset 
CreateReverseSequenceOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t seq_dim = 0, + int32_t batch_dim = 0) { + ReverseSequenceOptionsBuilder builder_(_fbb); + builder_.add_batch_dim(batch_dim); + builder_.add_seq_dim(seq_dim); + return builder_.Finish(); +} + +flatbuffers::Offset CreateReverseSequenceOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -7483,6 +7942,27 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const UniqueOptions *builtin_options_as_UniqueOptions() const { return builtin_options_type() == BuiltinOptions_UniqueOptions ? static_cast(builtin_options()) : nullptr; } + const ReverseV2Options *builtin_options_as_ReverseV2Options() const { + return builtin_options_type() == BuiltinOptions_ReverseV2Options ? static_cast(builtin_options()) : nullptr; + } + const AddNOptions *builtin_options_as_AddNOptions() const { + return builtin_options_type() == BuiltinOptions_AddNOptions ? static_cast(builtin_options()) : nullptr; + } + const GatherNdOptions *builtin_options_as_GatherNdOptions() const { + return builtin_options_type() == BuiltinOptions_GatherNdOptions ? static_cast(builtin_options()) : nullptr; + } + const CosOptions *builtin_options_as_CosOptions() const { + return builtin_options_type() == BuiltinOptions_CosOptions ? static_cast(builtin_options()) : nullptr; + } + const WhereOptions *builtin_options_as_WhereOptions() const { + return builtin_options_type() == BuiltinOptions_WhereOptions ? static_cast(builtin_options()) : nullptr; + } + const RankOptions *builtin_options_as_RankOptions() const { + return builtin_options_type() == BuiltinOptions_RankOptions ? 
static_cast(builtin_options()) : nullptr; + } + const ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const { + return builtin_options_type() == BuiltinOptions_ReverseSequenceOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -7834,6 +8314,34 @@ template<> inline const UniqueOptions *Operator::builtin_options_as inline const ReverseV2Options *Operator::builtin_options_as() const { + return builtin_options_as_ReverseV2Options(); +} + +template<> inline const AddNOptions *Operator::builtin_options_as() const { + return builtin_options_as_AddNOptions(); +} + +template<> inline const GatherNdOptions *Operator::builtin_options_as() const { + return builtin_options_as_GatherNdOptions(); +} + +template<> inline const CosOptions *Operator::builtin_options_as() const { + return builtin_options_as_CosOptions(); +} + +template<> inline const WhereOptions *Operator::builtin_options_as() const { + return builtin_options_as_WhereOptions(); +} + +template<> inline const RankOptions *Operator::builtin_options_as() const { + return builtin_options_as_RankOptions(); +} + +template<> inline const ReverseSequenceOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReverseSequenceOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -9379,6 +9887,29 @@ inline flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferB _fbb); } +inline CosOptionsT *CosOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new CosOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void CosOptions::UnPackTo(CosOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset CosOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT* _o, const 
flatbuffers::rehasher_function_t *_rehasher) { + return CreateCosOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateCosOptions(flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CosOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateCosOptions( + _fbb); +} + inline ReducerOptionsT *ReducerOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ReducerOptionsT(); UnPackTo(_o, _resolver); @@ -10008,6 +10539,29 @@ inline flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBuf _out_type); } +inline RankOptionsT *RankOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new RankOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void RankOptions::UnPackTo(RankOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset RankOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateRankOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateRankOptions(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RankOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateRankOptions( + _fbb); +} + inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new PowOptionsT(); UnPackTo(_o, _resolver); @@ -10481,6 +11035,127 @@ inline flatbuffers::Offset 
CreateUniqueOptions(flatbuffers::FlatB _idx_out_type); } +inline ReverseV2OptionsT *ReverseV2Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ReverseV2OptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ReverseV2Options::UnPackTo(ReverseV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset ReverseV2Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateReverseV2Options(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateReverseV2Options(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReverseV2OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateReverseV2Options( + _fbb); +} + +inline AddNOptionsT *AddNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new AddNOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void AddNOptions::UnPackTo(AddNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset AddNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateAddNOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateAddNOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AddNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; 
+ return tflite::CreateAddNOptions( + _fbb); +} + +inline GatherNdOptionsT *GatherNdOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new GatherNdOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void GatherNdOptions::UnPackTo(GatherNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset GatherNdOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateGatherNdOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateGatherNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GatherNdOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateGatherNdOptions( + _fbb); +} + +inline WhereOptionsT *WhereOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new WhereOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void WhereOptions::UnPackTo(WhereOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset WhereOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const WhereOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateWhereOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateWhereOptions(flatbuffers::FlatBufferBuilder &_fbb, const WhereOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const WhereOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return 
tflite::CreateWhereOptions( + _fbb); +} + +inline ReverseSequenceOptionsT *ReverseSequenceOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ReverseSequenceOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ReverseSequenceOptions::UnPackTo(ReverseSequenceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = seq_dim(); _o->seq_dim = _e; }; + { auto _e = batch_dim(); _o->batch_dim = _e; }; +} + +inline flatbuffers::Offset ReverseSequenceOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateReverseSequenceOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateReverseSequenceOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReverseSequenceOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _seq_dim = _o->seq_dim; + auto _batch_dim = _o->batch_dim; + return tflite::CreateReverseSequenceOptions( + _fbb, + _seq_dim, + _batch_dim); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -11059,6 +11734,34 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_ReverseV2Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AddNOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(obj); + return 
verifier.VerifyTable(ptr); + } + case BuiltinOptions_CosOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_WhereOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RankOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReverseSequenceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -11397,6 +12100,34 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_ReverseV2Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_AddNOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_CosOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_WhereOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_RankOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ReverseSequenceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -11723,6 +12454,34 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateUniqueOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_ReverseV2Options: { + auto ptr = reinterpret_cast(value); + return CreateReverseV2Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_AddNOptions: { + auto ptr = reinterpret_cast(value); + return CreateAddNOptions(_fbb, ptr, _rehasher).Union(); + } + case 
BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(value); + return CreateGatherNdOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_CosOptions: { + auto ptr = reinterpret_cast(value); + return CreateCosOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_WhereOptions: { + auto ptr = reinterpret_cast(value); + return CreateWhereOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_RankOptions: { + auto ptr = reinterpret_cast(value); + return CreateRankOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ReverseSequenceOptions: { + auto ptr = reinterpret_cast(value); + return CreateReverseSequenceOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -12049,6 +12808,34 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new UniqueOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_ReverseV2Options: { + value = new ReverseV2OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_AddNOptions: { + value = new AddNOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_GatherNdOptions: { + value = new GatherNdOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_CosOptions: { + value = new CosOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_WhereOptions: { + value = new WhereOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_RankOptions: { + value = new RankOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ReverseSequenceOptions: { + value = new ReverseSequenceOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -12456,6 +13243,41 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_ReverseV2Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_AddNOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } 
+ case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_CosOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_WhereOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_RankOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ReverseSequenceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/stderr_reporter.cc b/tensorflow/lite/stderr_reporter.cc index 09eb1d254a608ba2d19c824a323f0b5173afe15f..366a1816ef2b2ef62e093bbe99690eae52fdc8c4 100644 --- a/tensorflow/lite/stderr_reporter.cc +++ b/tensorflow/lite/stderr_reporter.cc @@ -13,28 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/lite/stderr_reporter.h" -#include -#include -#ifdef __ANDROID__ -#include -#endif +#include "tensorflow/lite/minimal_logging.h" namespace tflite { int StderrReporter::Report(const char* format, va_list args) { -#ifdef __ANDROID__ - // On Android stderr is not captured for applications, only for code run from - // the shell. 
Rather than assume all users will set up a custom error - // reporter, let's output to logcat here - va_list args_for_log; - va_copy(args_for_log, args); - __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, args_for_log); - va_end(args_for_log); -#endif - const int result = vfprintf(stderr, format, args); - fputc('\n', stderr); - return result; + logging_internal::MinimalLogger::VLog(TFLITE_LOG_ERROR, format, args); + return 0; } ErrorReporter* DefaultErrorReporter() { diff --git a/tensorflow/lite/string_util.h b/tensorflow/lite/string_util.h index adb67c6d4e748ab98e71e6519bddb237e985ebfb..4b56bcaca0157b715704a57560fa2394c6fa3be1 100644 --- a/tensorflow/lite/string_util.h +++ b/tensorflow/lite/string_util.h @@ -35,7 +35,7 @@ limitations under the License. // buf.AddString("AB", 2); // # Write content of DynamicBuffer to tensor in format of string tensor // # described above. -// buf.WriteToTensor(tensor) +// buf.WriteToTensor(tensor, nullptr) #ifndef TENSORFLOW_LITE_STRING_UTIL_H_ #define TENSORFLOW_LITE_STRING_UTIL_H_ @@ -83,10 +83,6 @@ class DynamicBuffer { // Fill content into a string tensor. Set shape to {num_strings}. void WriteToTensorAsVector(TfLiteTensor* tensor); - // Deprecated. Use WriteToTensorAsVector() or pass in the new shpe. - // TODO(b/120230709): remove when people migrate away. - void WriteToTensor(TfLiteTensor* tensor) { WriteToTensorAsVector(tensor); } - private: // Data buffer to store contents of strings, not including headers. 
std::vector data_; diff --git a/tensorflow/lite/string_util_test.cc b/tensorflow/lite/string_util_test.cc index cbf1d7b226af20251d5f70a354a21f1eb40ae1c6..6fc7de90ea534f9c8c4f61b4607ff7d2d8647d00 100644 --- a/tensorflow/lite/string_util_test.cc +++ b/tensorflow/lite/string_util_test.cc @@ -35,8 +35,11 @@ TEST(StringUtil, TestStringUtil) { char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'X', 'Y', 'Z'}; - interpreter.SetTensorParametersReadOnly(2, kTfLiteString, "", {1}, {}, data, - 15); + TfLiteQuantization quant; + quant.type = kTfLiteNoQuantization; + quant.params = nullptr; + interpreter.SetTensorParametersReadOnly(2, kTfLiteString, "", {1}, quant, + data, 15); TfLiteTensor* t2 = interpreter.tensor(2); interpreter.AllocateTensors(); diff --git a/tensorflow/lite/testdata/test_input.csv b/tensorflow/lite/testdata/test_input.csv new file mode 100644 index 0000000000000000000000000000000000000000..33894d3063f35a885fb34c3c5b85bb6a4d8e711e --- /dev/null +++ b/tensorflow/lite/testdata/test_input.csv @@ -0,0 +1 @@ +1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 \ No newline at end of file diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index ce85a393851898fc6eb27ffbf37b3e7cbb44e2d7..68512b952a9a34c67452b676db97534b1fb3c733 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -10,6 +10,7 @@ load( "generated_test_models_all", ) load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( "//tensorflow:tensorflow.bzl", "tf_cc_binary", @@ -78,6 +79,7 @@ py_binary( srcs_version = "PY2AND3", deps = [ 
":generate_examples_report", + ":string_util_wrapper", "//tensorflow:tensorflow_py", "//tensorflow/python:graph_util", "//third_party/py/numpy", @@ -159,6 +161,7 @@ cc_library( srcs = ["tflite_driver.cc"], hdrs = ["tflite_driver.h"], deps = [ + ":join", ":split", ":test_runner", "//tensorflow/lite:builtin_op_data", @@ -237,7 +240,7 @@ tf_cc_binary( deps = [ ":parse_testdata_lib", ":tflite_driver", - "//tensorflow/lite/nnapi:nnapi_lib", + "//tensorflow/lite/nnapi:nnapi_implementation", ], ) @@ -391,4 +394,29 @@ tf_cc_binary( ], ) +cc_library( + name = "string_util_lib", + srcs = ["string_util.cc"], + hdrs = ["string_util.h"], + deps = [ + "//tensorflow/lite:string_util", + "//tensorflow/lite/python/interpreter_wrapper:numpy", + "//tensorflow/lite/python/interpreter_wrapper:python_utils", + "//third_party/py/numpy:headers", + "//third_party/python_runtime:headers", + "@com_google_absl//absl/strings", + ], +) + +tf_py_wrap_cc( + name = "string_util_wrapper", + srcs = [ + "string_util.i", + ], + deps = [ + ":string_util_lib", + "//third_party/python_runtime:headers", + ], +) + tflite_portable_test_suite() diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py index 12b5a8b210137ce19e1321042293b8ac6375be37..213d214c132948910eec923dc18b60703a8ca571 100644 --- a/tensorflow/lite/testing/generate_examples.py +++ b/tensorflow/lite/testing/generate_examples.py @@ -36,6 +36,7 @@ import operator import os import random import re +import string import sys import tempfile import traceback @@ -52,6 +53,8 @@ import tensorflow as tf from google.protobuf import text_format # TODO(aselle): switch to TensorFlow's resource_loader from tensorflow.lite.testing import generate_examples_report as report_lib +from tensorflow.lite.testing import string_util_wrapper +from tensorflow.python.framework import test_util from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.ops import rnn @@ -163,6 +166,16 @@ def 
toco_options(data_types, return s +def format_result(t): + """Convert a tensor to a format that can be used in test specs.""" + if t.dtype.kind not in [np.dtype(np.string_).kind, np.dtype(np.object_).kind]: + # Output 9 digits after the point to ensure the precision is good enough. + values = ["{:.9f}".format(value) for value in list(t.flatten())] + return ",".join(values) + else: + return string_util_wrapper.SerializeAsHexString(t.flatten()) + + def write_examples(fp, examples): """Given a list `examples`, write a text format representation. @@ -179,9 +192,7 @@ def write_examples(fp, examples): """Write tensor in file format supported by TFLITE example.""" fp.write("dtype,%s\n" % x.dtype) fp.write("shape," + ",".join(map(str, x.shape)) + "\n") - # Output 9 digits after the point to ensure the precision is good enough. - values = ["{:.9f}".format(value) for value in list(x.flatten())] - fp.write("values," + ",".join(values) + "\n") + fp.write("values," + format_result(x) + "\n") fp.write("test_cases,%d\n" % len(examples)) for example in examples: @@ -214,11 +225,9 @@ def write_test_cases(fp, model_name, examples): fp.write("invoke {\n") for t in example["inputs"]: - values = ["{:.9f}".format(value) for value in list(t.flatten())] - fp.write(" input: \"" + ",".join(values) + "\"\n") + fp.write(" input: \"" + format_result(t) + "\"\n") for t in example["outputs"]: - values = ["{:.9f}".format(value) for value in list(t.flatten())] - fp.write(" output: \"" + ",".join(values) + "\"\n") + fp.write(" output: \"" + format_result(t) + "\"\n") fp.write("}\n") @@ -230,6 +239,7 @@ _TF_TYPE_INFO = { tf.int16: (np.int16, "QUANTIZED_INT16"), tf.int64: (np.int64, "INT64"), tf.bool: (np.bool, "BOOL"), + tf.string: (np.string_, "STRING"), } @@ -245,6 +255,10 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): value = np.random.randint(min_value, max_value+1, shape) elif dtype == tf.bool: value = np.random.choice([True, False], size=shape) + elif dtype == 
np.string_: + # Not the best strings, but they will do for some basic testing. + letters = list(string.ascii_uppercase) + return np.random.choice(letters, size=shape).astype(dtype) return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype( dtype) @@ -300,8 +314,13 @@ def make_control_dep_tests(zip_path): extra_toco_options = ExtraTocoOptions() extra_toco_options.drop_control_dependency = True - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs, - extra_toco_options) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + extra_toco_options, + expected_tf_failures=3) def toco_convert(graph_def_str, input_tensors, output_tensors, @@ -369,7 +388,7 @@ def make_zip_of_tests(zip_path, make_test_inputs, extra_toco_options=ExtraTocoOptions(), use_frozen_graph=False, - expected_tf_success=None): + expected_tf_failures=0): """Helper to make a zip file of a bunch of TensorFlow models. This does a cartestian product of the dictionary of test_parameters and @@ -389,8 +408,9 @@ def make_zip_of_tests(zip_path, `output_tensors` and returns tuple `(input_values, output_values)`. extra_toco_options: Additional toco options. use_frozen_graph: Whether or not freeze graph before toco converter. - expected_tf_success: Number of times tensorflow is supposed to succeed in - executing the input graphs. `None` means "unknown". + expected_tf_failures: Number of times tensorflow is expected to fail in + executing the input graphs. In some cases it is OK for TensorFlow to + fail because the one or more combination of parameters is invalid. Raises: RuntimeError: if there are toco errors that can't be ignored. 
@@ -472,9 +492,8 @@ def make_zip_of_tests(zip_path, report["toco"] = report_lib.FAILED report["tf"] = report_lib.SUCCESS # Convert graph to toco - input_tensors = [(input_tensor.name.split(":")[0], - input_tensor.get_shape(), input_tensor.dtype) - for input_tensor in inputs] + input_tensors = [(input_tensor.name.split(":")[0], input_tensor.shape, + input_tensor.dtype) for input_tensor in inputs] output_tensors = [normalize_output_name(out.name) for out in outputs] graph_def = freeze_graph( sess, @@ -485,6 +504,10 @@ def make_zip_of_tests(zip_path, extra_toco_options.split_tflite_lstm_inputs = param_dict_real[ "split_tflite_lstm_inputs"] + # Convert ophint ops if presented. + graph_def = tf.lite.experimental.convert_op_hints_to_stubs( + graph_def=graph_def) + graph_def = tf.graph_util.remove_training_nodes(graph_def) tflite_model_binary, toco_log = toco_convert( graph_def.SerializeToString(), input_tensors, output_tensors, extra_toco_options) @@ -551,10 +574,17 @@ def make_zip_of_tests(zip_path, " and %d TOCO converted graphs (%.1f%%"), zip_path, total_conversions, tf_success, toco_success, percent) - if expected_tf_success is not None and tf_success != expected_tf_success: - raise RuntimeError( - "Expected TF to succeed %d times, but that happened %d times" % - (expected_tf_success, tf_success)) + tf_failures = parameter_count - tf_success + + if tf_failures / parameter_count > 0.8: + raise RuntimeError(("Test for '%s' is not very useful. 
" + "TensorFlow fails in %d percent of the cases.") % + (zip_path, int(100 * tf_failures / parameter_count))) + + if tf_failures != expected_tf_failures: + raise RuntimeError(("Expected TF to fail %d times while generating '%s', " + "but that happened %d times") % (expected_tf_failures, + zip_path, tf_failures)) if not FLAGS.ignore_toco_errors and toco_errors > 0: raise RuntimeError( @@ -573,11 +603,12 @@ def make_pool_tests(pool_op_in): pool_op = pool_op_in - def f(zip_path): + def f(zip_path, expected_tf_failures=0): """Actual function that generates examples. Args: zip_path: path to write zip to. + expected_tf_failures: number of expected tensorflow failures. """ # Chose a set of parameters @@ -606,20 +637,26 @@ def make_pool_tests(pool_op_in): return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=expected_tf_failures) + return f def make_l2_pool_tests(zip_path): - make_pool_tests(make_l2_pool)(zip_path) + make_pool_tests(make_l2_pool)(zip_path, expected_tf_failures=80) def make_avg_pool_tests(zip_path): - make_pool_tests(tf.nn.avg_pool)(zip_path) + make_pool_tests(tf.nn.avg_pool)(zip_path, expected_tf_failures=80) def make_max_pool_tests(zip_path): - make_pool_tests(tf.nn.max_pool)(zip_path) + make_pool_tests(tf.nn.max_pool)(zip_path, expected_tf_failures=80) def make_abs_tests(zip_path): @@ -645,6 +682,32 @@ def make_abs_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_elu_tests(zip_path): + """Make a set of tests to do (float) tf.nn.elu.""" + + test_parameters = [ + { + "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3], + [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + }, + ] + + def build_graph(parameters): + """Build the graph for the test case.""" + + input_tensor = 
tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.nn.elu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Build the inputs for the test case.""" + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-4, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) def make_relu_tests(zip_path): """Make a set of tests to do relu.""" @@ -814,6 +877,9 @@ def make_constant_tests(zip_path): "dtype": [tf.float32, tf.int32], "input_shape": [[], [1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], "constant_is_also_output": [True, False], + # This is a regression test for a bug where Toco rejects models with + # unread inputs. + "has_unread_input": [True, False], }] def build_graph(parameters): @@ -823,22 +889,28 @@ def make_constant_tests(zip_path): shape=parameters["input_shape"]) constant = tf.constant( create_tensor_data(parameters["dtype"], parameters["input_shape"])) - out = [tf.maximum(dummy_input, constant)] + outputs = [tf.maximum(dummy_input, constant)] if parameters["constant_is_also_output"]: - out.append(constant) + outputs.append(constant) + inputs = [dummy_input] + if parameters["has_unread_input"]: + unread_input = tf.placeholder( + dtype=parameters["dtype"], + name="unread_input", + shape=parameters["input_shape"]) + inputs.append(unread_input) - return [dummy_input], out + return inputs, outputs def build_inputs(parameters, sess, inputs, outputs): dummy_input = np.zeros( parameters["input_shape"], dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) return [dummy_input], sess.run(outputs, feed_dict={inputs[0]: dummy_input}) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs, - expected_tf_success=20) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def 
make_binary_op_tests(zip_path, binary_operator): +def make_binary_op_tests(zip_path, binary_operator, expected_tf_failures=0): """Make a set of tests to do binary ops with and without broadcast.""" test_parameters = [ @@ -908,7 +980,12 @@ def make_binary_op_tests(zip_path, binary_operator): inputs[1]: input2 }) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=expected_tf_failures) def make_reduce_tests(reduce_op, @@ -1074,6 +1151,34 @@ def make_exp_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_cos_tests(zip_path): + """Make a set of tests to do cos.""" + + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + }] + + def build_graph(parameters): + """Build the cos op testing graph.""" + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + + out = tf.cos(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["input_dtype"], parameters["input_shape"], + min_value=-np.pi, max_value=np.pi) + ] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_log_softmax_tests(zip_path): """Make a set of tests to do log_softmax.""" @@ -1137,7 +1242,12 @@ def make_maximum_tests(zip_path): ] return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=8) def make_minimum_tests(zip_path): @@ -1172,7 +1282,12 @@ def make_minimum_tests(zip_path): ] return values, sess.run(outputs, 
feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=8) def make_binary_op_tests_func(binary_operator): @@ -1184,6 +1299,51 @@ def make_add_tests(zip_path): make_binary_op_tests(zip_path, tf.add) +def make_add_n_tests(zip_path): + """Make a set of tests for AddN op.""" + + test_parameters = [ + { + "dtype": [tf.float32, tf.int32], + "input_shape": [[2, 5, 3, 1]], + "num_inputs": [2, 3, 4, 5], + }, + { + "dtype": [tf.float32, tf.int32], + "input_shape": [[5]], + "num_inputs": [2, 3, 4, 5], + }, + { + "dtype": [tf.float32, tf.int32], + "input_shape": [[]], + "num_inputs": [2, 3, 4, 5], + }, + ] + + def build_graph(parameters): + """Builds the graph given the current parameters.""" + input_tensors = [] + for i in range(parameters["num_inputs"]): + input_tensors.append( + tf.placeholder( + dtype=parameters["dtype"], + name="input_{}".format(i), + shape=parameters["input_shape"])) + out = tf.add_n(input_tensors) + return input_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Builds operand inputs for op.""" + input_data = [] + for i in range(parameters["num_inputs"]): + input_data.append( + create_tensor_data(parameters["dtype"], parameters["input_shape"])) + return input_data, sess.run( + outputs, feed_dict={i: d for i, d in zip(inputs, input_data)}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_div_tests(zip_path): make_binary_op_tests(zip_path, tf.div) @@ -1197,7 +1357,7 @@ def make_mul_tests(zip_path): def make_pow_tests(zip_path): - make_binary_op_tests(zip_path, tf.pow) + make_binary_op_tests(zip_path, tf.pow, expected_tf_failures=7) def make_floor_div_tests(zip_path): @@ -1215,16 +1375,23 @@ def make_squared_difference_tests(zip_path): def make_gather_tests(zip_path): """Make a set of tests to do gather.""" - test_parameters = [{ - 
# TODO(mgubin): add string tests when they are supported by Toco. - # TODO(mgubin): add tests for Nd indices when they are supported by - # TfLite. - "params_dtype": [tf.float32, tf.int32, tf.int64], - "params_shape": [[10], [1, 2, 20]], - "indices_dtype": [tf.int32, tf.int64], - "indices_shape": [[3], [5]], - "axis": [-1, 0, 1], - }] + test_parameters = [ + { + "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_shape": [[10], [1, 2, 20]], + "indices_dtype": [tf.int32, tf.int64], + "indices_shape": [[3], [5]], + "axis": [-1, 0, 1], + }, + { + # TODO(b/123895910): add Nd support for strings. + "params_dtype": [tf.string], + "params_shape": [[8]], + "indices_dtype": [tf.int32], + "indices_shape": [[3]], + "axis": [0], + } + ] def build_graph(parameters): """Build the gather op testing graph.""" @@ -1255,7 +1422,56 @@ def make_gather_tests(zip_path): test_parameters, build_graph, build_inputs, - expected_tf_success=60) + expected_tf_failures=12) + + +def make_gather_nd_tests(zip_path): + """Make a set of tests to do gather_nd.""" + + test_parameters = [ + { + "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_shape": [[5, 1]], + "indices_dtype": [tf.int32, tf.int64], + "indices_shape": [[1, 1]], + }, + { + "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_shape": [[5, 5]], + "indices_dtype": [tf.int32, tf.int64], + "indices_shape": [[2, 1], [2, 2]], + }, + { + "params_dtype": [tf.float32, tf.int32, tf.int64], + "params_shape": [[5, 5, 10]], + "indices_dtype": [tf.int32, tf.int64], + "indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]], + }, + ] + + def build_graph(parameters): + """Build the gather_nd op testing graph.""" + params = tf.placeholder( + dtype=parameters["params_dtype"], + name="params", + shape=parameters["params_shape"]) + indices = tf.placeholder( + dtype=parameters["indices_dtype"], + name="indices", + shape=parameters["indices_shape"]) + out = tf.gather_nd(params, indices) + return [params, indices], [out] + + def 
build_inputs(parameters, sess, inputs, outputs): + params = create_tensor_data(parameters["params_dtype"], + parameters["params_shape"]) + indices = create_tensor_data(parameters["indices_dtype"], + parameters["indices_shape"], 0, + parameters["params_shape"][0] - 1) + return [params, indices], sess.run( + outputs, feed_dict=dict(zip(inputs, [params, indices]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) def make_gather_with_constant_tests(zip_path): @@ -1284,8 +1500,7 @@ def make_gather_with_constant_tests(zip_path): return [reference_values], sess.run( outputs, feed_dict={inputs[0]: reference_values}) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs, - expected_tf_success=2) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) def make_global_batch_norm_tests(zip_path): @@ -1421,7 +1636,12 @@ def make_conv_tests(zip_path): values.append(create_tensor_data(np.float32, filter_shape)) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=40) # Note: This is a regression test for a bug (b/122651451) that Toco incorrectly @@ -1663,7 +1883,12 @@ def make_depthwiseconv_tests(zip_path): values.append(create_tensor_data(np.float32, filter_shape)) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=4) def make_split_tests(zip_path): @@ -1686,7 +1911,12 @@ def make_split_tests(zip_path): values = [create_tensor_data(np.float32, parameters["input_shape"])] return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + 
make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=112) def make_splitv_tests(zip_path): @@ -1709,7 +1939,12 @@ def make_splitv_tests(zip_path): values = [create_tensor_data(np.float32, parameters["input_shape"])] return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=158) def make_concat_tests(zip_path): @@ -1751,7 +1986,12 @@ def make_concat_tests(zip_path): return all_values, sess.run( outputs, feed_dict=dict(zip(inputs, all_values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=60) def make_fully_connected_tests(zip_path): @@ -1812,7 +2052,12 @@ def make_fully_connected_tests(zip_path): values.append(create_tensor_data(np.float32, parameters["shape2"])) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=10) def make_l2norm_tests(zip_path): @@ -1842,7 +2087,12 @@ def make_l2norm_tests(zip_path): return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=9) def make_local_response_norm_tests(zip_path): @@ -2068,6 +2318,29 @@ def make_shape_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_rank_tests(zip_path): + """Make a set of tests to do rank.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[], [0], 
[1, 1, 1, 3], [2, 3, 4, 5], [5, 5], [10]], + }] + + def build_graph(parameters): + """Build the rank op testing graph.""" + input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input") + out = tf.rank(input_value) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_one_hot_tests(zip_path): """Make a set of tests to do one_hot.""" @@ -2327,7 +2600,12 @@ def make_space_to_batch_nd_tests(zip_path): values.append(np.array(parameters["paddings"])) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=56) def make_batch_to_space_nd_tests(zip_path): @@ -2440,7 +2718,12 @@ def make_transpose_tests(zip_path): values.append(np.array(parameters["perm"])) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=9) def make_squeeze_tests(zip_path): @@ -2478,10 +2761,48 @@ def make_squeeze_tests(zip_path): return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=12) + +def make_squeeze_transpose_tests(zip_path): + """Make a set of tests to do squeeze followed by transpose.""" -def _make_strided_slice_tests(zip_path, test_parameters): + test_parameters = [{ + "dtype": [tf.int32, 
tf.float32, tf.int64], + "input_shape": [[1, 4, 10, 1]], + "axis": [[-1], [3]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=parameters["dtype"], + name="input", + shape=parameters["input_shape"]) + out = tf.squeeze(input_tensor, axis=parameters["axis"]) + out = tf.transpose(out, perm=[1, 2]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=0) + + +def _make_strided_slice_tests(zip_path, test_parameters, + expected_tf_failures=0): """Utility function to make strided_slice_tests based on parameters.""" def build_graph(parameters): @@ -2541,7 +2862,12 @@ def _make_strided_slice_tests(zip_path, test_parameters): return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=expected_tf_failures) def make_strided_slice_tests(zip_path): @@ -2615,7 +2941,7 @@ def make_strided_slice_tests(zip_path): "constant_indices": [False], }, ] - _make_strided_slice_tests(zip_path, test_parameters) + _make_strided_slice_tests(zip_path, test_parameters, expected_tf_failures=2) def make_strided_slice_1d_exhaustive_tests(zip_path): @@ -2638,7 +2964,10 @@ def make_strided_slice_1d_exhaustive_tests(zip_path): _make_strided_slice_tests(zip_path, test_parameters) -def make_strided_slice_buggy_tests(zip_path): +# For verifying https://github.com/tensorflow/tensorflow/issues/23599 +# TODO(chaomei): refactor the test to cover more cases, like negative stride, +# negative array index etc. 
+def make_resolve_constant_strided_slice_tests(zip_path): """Make a set of tests to show strided_slice yields incorrect results.""" test_parameters = [{ @@ -2693,7 +3022,7 @@ def make_lstm_tests(zip_path): shape=[num_batchs, input_vec_size]) inputs_after_split.append(one_timestamp_input) # Currently lstm identifier has a few limitations: only supports - # forget_bias == 0, inner state activiation == tanh. + # forget_bias == 0, inner state activation == tanh. # TODO(zhixianyan): Add another test with forget_bias == 1. # TODO(zhixianyan): Add another test with relu as activation. lstm_cell = tf.contrib.rnn.BasicLSTMCell( @@ -2817,7 +3146,12 @@ def make_arg_min_max_tests(zip_path): return [input_value], sess.run( outputs, feed_dict=dict(zip(inputs, [input_value]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=4) def make_equal_tests(zip_path): @@ -2852,7 +3186,12 @@ def make_equal_tests(zip_path): return [input_value1, input_value2], sess.run( outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=3) def make_not_equal_tests(zip_path): @@ -2886,7 +3225,12 @@ def make_not_equal_tests(zip_path): return [input_value1, input_value2], sess.run( outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=3) def make_greater_tests(zip_path): @@ -2920,7 +3264,12 @@ def make_greater_tests(zip_path): return [input_value1, input_value2], sess.run( outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, 
build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=3) def make_greater_equal_tests(zip_path): @@ -2954,7 +3303,12 @@ def make_greater_equal_tests(zip_path): return [input_value1, input_value2], sess.run( outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=3) def make_less_tests(zip_path): @@ -2988,7 +3342,12 @@ def make_less_tests(zip_path): return [input_value1, input_value2], sess.run( outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=3) def make_less_equal_tests(zip_path): @@ -3022,7 +3381,12 @@ def make_less_equal_tests(zip_path): return [input_value1, input_value2], sess.run( outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=3) def make_floor_tests(zip_path): @@ -3030,7 +3394,7 @@ def make_floor_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32], - "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], }] def build_graph(parameters): @@ -3042,6 +3406,31 @@ def make_floor_tests(zip_path): out = tf.floor(input_value) return [input_value], [out] + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run(outputs, feed_dict={inputs[0]: input_value}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, 
build_inputs) + + +def make_ceil_tests(zip_path): + """Make a set of tests to do ceil.""" + + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + }] + + def build_graph(parameters): + """Build the ceil op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape"]) + out = tf.ceil(input_value) + return [input_value], [out] + def build_inputs(parameters, sess, inputs, outputs): input_value = create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) @@ -3249,7 +3638,12 @@ def make_slice_tests(zip_path): return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=18) def make_conv2d_transpose_tests(zip_path): @@ -3286,12 +3680,7 @@ def make_conv2d_transpose_tests(zip_path): ] return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests( - zip_path, - test_parameters, - build_graph, - build_inputs, - expected_tf_success=4) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) # Since compute output_shape is fairly complicated for @@ -3541,7 +3930,12 @@ def make_pack_tests(zip_path): return all_values, sess.run( outputs, feed_dict=dict(zip(inputs, all_values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=72) def make_unpack_tests(zip_path): @@ -3632,13 +4026,18 @@ def make_fill_tests(zip_path): return [input1, input2], sess.run( outputs, feed_dict=dict(zip(inputs, [input1, input2]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + 
expected_tf_failures=12) def _make_logical_tests(op): """Make a set of tests to do logical operations.""" - def logical(zip_path): + def logical(zip_path, expected_tf_failures=0): """Generate examples.""" test_parameters = [{ "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]), @@ -3663,19 +4062,24 @@ def _make_logical_tests(op): return [input_value1, input_value2], sess.run( outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=expected_tf_failures) return logical def make_logical_or_tests(zip_path): """Make a set of tests to do logical_or.""" - return _make_logical_tests(tf.logical_or)(zip_path) + return _make_logical_tests(tf.logical_or)(zip_path, expected_tf_failures=1) def make_logical_and_tests(zip_path): """Make a set of tests to do logical_and.""" - return _make_logical_tests(tf.logical_and)(zip_path) + return _make_logical_tests(tf.logical_and)(zip_path, expected_tf_failures=1) def make_logical_xor_tests(zip_path): @@ -3683,7 +4087,7 @@ def make_logical_xor_tests(zip_path): Test logical_not as well. 
""" - return _make_logical_tests(tf.logical_xor)(zip_path) + return _make_logical_tests(tf.logical_xor)(zip_path, expected_tf_failures=1) def make_mirror_pad_tests(zip_path): @@ -3732,6 +4136,12 @@ def make_mirror_pad_tests(zip_path): "mode": ["REFLECT"], "type": ["const"] }, + { + "input_shape": [[3, 2, 4, 5]], + "padding_matrix": [[[1, 1], [2, 2], [1, 1], [1, 1]]], + "mode": ["SYMMETRIC"], + "type": ["placeholder"] + }, ] def build_graph(parameters): @@ -3760,18 +4170,23 @@ def make_mirror_pad_tests(zip_path): return input_values, sess.run( outputs, feed_dict=dict(zip(inputs, input_values))) - make_zip_of_tests( - zip_path, - test_parameters, - build_graph, - build_inputs, - expected_tf_success=7) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) def make_unroll_batch_matmul_tests(zip_path): """Make a set of tests to test unroll_batch_matmul.""" - test_parameters = [{"dtype": [tf.float32], "shape": [[(2, 2, 3), (2, 3, 2)]]}] + test_parameters = [{ + "dtype": [tf.float32], + "shape": [[(2, 2, 3), (2, 3, 2), False, False], + [(2, 2, 3), (2, 3, 2), True, True], + [(2, 2, 3), (2, 2, 3), False, True], + [(2, 2, 3), (2, 2, 3), True, False], + [(4, 2, 2, 3), (4, 2, 3, 2), False, False], + [(4, 2, 2, 3), (4, 2, 3, 2), True, True], + [(4, 2, 2, 3), (4, 2, 2, 3), False, True], + [(4, 2, 2, 3), (4, 2, 2, 3), True, False]] + }] def build_graph(parameters): """Build the batch_matmul op testing graph.""" @@ -3780,7 +4195,11 @@ def make_unroll_batch_matmul_tests(zip_path): input_tensor2 = tf.placeholder( dtype=parameters["dtype"], shape=parameters["shape"][1]) # Should be unrolled and replaced with fully_connected ops in the end. 
- out = tf.matmul(input_tensor1, input_tensor2) + out = tf.matmul( + input_tensor1, + input_tensor2, + transpose_a=parameters["shape"][2], + transpose_b=parameters["shape"][3]) return [input_tensor1, input_tensor2], [out] def build_inputs(parameters, sess, inputs, outputs): @@ -3817,8 +4236,7 @@ def make_placeholder_with_default_tests(zip_path): return [input_value], sess.run( outputs, feed_dict=dict(zip(inputs, [input_value]))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs, - expected_tf_success=3) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) def make_unique_tests(zip_path): @@ -3863,12 +4281,235 @@ def make_unique_tests(zip_path): return input_values, sess.run( outputs, feed_dict=dict(zip(inputs, input_values))) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_reverse_v2_tests(zip_path): + """Make a set of tests to do reverse_v2.""" + + test_parameters = [{ + "base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]], + "axis": [0, 1, 2, 3], + }] + + def get_valid_axis(parameters): + """Return a tweaked version of 'axis'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + while axis > len(shape) - 1: + axis -= 1 + return axis + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name=("input"), shape=parameters["base_shape"]) + outs = tf.reverse(input_tensor, axis=[get_valid_axis(parameters)]) + return [input_tensor], [outs] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(np.float32, shape=parameters["base_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_reverse_sequence_tests(zip_path): + """Make a set of tests to do reverse_sequence.""" + + test_parameters = [ + { + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape": [[8, 4, 5, 5, 6], 
[4, 4, 3, 5]], + "seq_lengths": [[2, 2, 2, 2], [2, 1, 1, 0]], + "seq_axis": [0, 3], + "batch_axis": [1] + }, + { + "input_dtype": [tf.float32], + "input_shape": [[2, 4, 5, 5, 6]], + "seq_lengths": [[2, 1]], + "seq_axis": [2], + "batch_axis": [0] + }, + { + "input_dtype": [tf.float32], + "input_shape": [[4, 2]], + "seq_lengths": [[3, 1]], + "seq_axis": [0], + "batch_axis": [1] + }] + + def build_graph(parameters): + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + outs = tf.reverse_sequence( + input_value, + seq_lengths=parameters["seq_lengths"], + batch_axis=parameters["batch_axis"], + seq_axis=parameters["seq_axis"]) + return [input_value], [outs] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +@test_util.enable_control_flow_v2 +def make_unidirectional_sequence_lstm_tests(zip_path): + """Make a set of tests to do unidirectional_sequence_lstm.""" + + test_parameters = [{ + "batch_size": [2, 4, 6], + "seq_length": [1, 3], + "units": [4, 5], + "use_peepholes": [False, True], + "is_dynamic_rnn": [False, True] + }] + + def build_graph(parameters): + input_values = [] + if parameters["is_dynamic_rnn"]: + shape = [ + parameters["seq_length"], parameters["batch_size"], + parameters["units"] + ] + input_value = tf.placeholder(dtype=tf.float32, name="input", shape=shape) + input_values.append(input_value) + lstm_cell = tf.lite.experimental.nn.TFLiteLSTMCell( + parameters["units"], + use_peepholes=parameters["use_peepholes"]) + outs, _ = tf.lite.experimental.nn.dynamic_rnn( + lstm_cell, input_value, dtype=tf.float32, time_major=True) + outs = tf.unstack(outs, axis=1) + else: + shape = [parameters["batch_size"], parameters["units"]] + for i in 
range(parameters["seq_length"]): + input_value = tf.placeholder( + dtype=tf.float32, name=("input_%d" % i), shape=shape) + input_values.append(input_value) + lstm_cell = tf.lite.experimental.nn.TFLiteLSTMCell( + parameters["units"], use_peepholes=parameters["use_peepholes"]) + outs, _ = tf.nn.static_rnn(lstm_cell, input_values, dtype=tf.float32) + + real_output = tf.zeros([1], dtype=tf.float32) + outs[-1] + return input_values, [real_output] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = [] + if parameters["is_dynamic_rnn"]: + shape = [ + parameters["seq_length"], parameters["batch_size"], + parameters["units"] + ] + input_value = create_tensor_data(tf.float32, shape) + input_values.append(input_value) + else: + shape = [parameters["batch_size"], parameters["units"]] + for i in range(parameters["seq_length"]): + input_value = create_tensor_data(tf.float32, shape) + input_values.append(input_value) + init = tf.global_variables_initializer() + sess.run(init) + # Tflite fused kernel takes input as [time, batch, input]. + # For static unidirectional sequence lstm, the input is an array sized of + # time, and pack the array together, however, for time = 1, the input is + # not packed. 
+ tflite_input_values = input_values + if not parameters["is_dynamic_rnn"] and parameters["seq_length"] == 1: + tflite_input_values = [ + input_values[0].reshape((1, parameters["batch_size"], + parameters["units"])) + ] + return tflite_input_values, sess.run( + outputs, feed_dict=dict(zip(inputs, input_values))) + + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + use_frozen_graph=True) + + +@test_util.enable_control_flow_v2 +def make_unidirectional_sequence_rnn_tests(zip_path): + """Make a set of tests to do unidirectional_sequence_rnn.""" + + test_parameters = [{ + "batch_size": [2, 4, 6], + "seq_length": [1, 3], + "units": [4, 5], + "is_dynamic_rnn": [False, True] + }] + + def build_graph(parameters): + input_values = [] + if parameters["is_dynamic_rnn"]: + shape = [ + parameters["seq_length"], parameters["batch_size"], + parameters["units"] + ] + input_value = tf.placeholder(dtype=tf.float32, name="input", shape=shape) + input_values.append(input_value) + rnn_cell = tf.lite.experimental.nn.TfLiteRNNCell(parameters["units"]) + outs, _ = tf.lite.experimental.nn.dynamic_rnn( + rnn_cell, input_value, dtype=tf.float32, time_major=True) + outs = tf.unstack(outs, axis=1) + else: + shape = [parameters["batch_size"], parameters["units"]] + for i in range(parameters["seq_length"]): + input_value = tf.placeholder( + dtype=tf.float32, name=("input_%d" % i), shape=shape) + input_values.append(input_value) + rnn_cell = tf.lite.experimental.nn.TfLiteRNNCell(parameters["units"]) + outs, _ = tf.nn.static_rnn(rnn_cell, input_values, dtype=tf.float32) + + real_output = tf.zeros([1], dtype=tf.float32) + outs[-1] + return input_values, [real_output] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = [] + if parameters["is_dynamic_rnn"]: + shape = [ + parameters["seq_length"], parameters["batch_size"], + parameters["units"] + ] + input_value = create_tensor_data(tf.float32, shape) + input_values.append(input_value) + else: + 
shape = [parameters["batch_size"], parameters["units"]] + for i in range(parameters["seq_length"]): + input_value = create_tensor_data(tf.float32, shape) + input_values.append(input_value) + init = tf.global_variables_initializer() + sess.run(init) + # Tflite fused kernel takes input as [time, batch, input]. + # For static unidirectional sequence rnn, the input is an array sized of + # time, and pack the array together, however, for time = 1, the input is + # not packed. + tflite_input_values = input_values + if not parameters["is_dynamic_rnn"] and parameters["seq_length"] == 1: + tflite_input_values = [ + input_values[0].reshape((1, parameters["batch_size"], + parameters["units"])) + ] + return tflite_input_values, sess.run( + outputs, feed_dict=dict(zip(inputs, input_values))) + make_zip_of_tests( zip_path, test_parameters, build_graph, build_inputs, - expected_tf_success=9) + use_frozen_graph=True) + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc index 45bd59a67d10baf61ad981f2fef29e948c2e77d2..fb98cc9b1725f8295bb060ae60ceb151569616e6 100644 --- a/tensorflow/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/lite/testing/generated_examples_zip_test.cc @@ -102,9 +102,6 @@ std::map kBrokenTests = { {R"(^\/add.*dtype=tf\.int64)", "119126484"}, {R"(^\/floor_div.*dtype=tf\.int64)", "119126484"}, {R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"}, - - // Strided Slice chooses the wrong dimension. - {R"(^\/strided_slice_buggy)", "119786029"}, }; // Additional list of tests that are expected to fail when diff --git a/tensorflow/lite/testing/join.h b/tensorflow/lite/testing/join.h index d1c314608687f045b346cc5526ea46c8149c2755..d10d2909b5ec4a269fd1a67d7a22f4c1e76f707e 100644 --- a/tensorflow/lite/testing/join.h +++ b/tensorflow/lite/testing/join.h @@ -24,7 +24,21 @@ limitations under the License. 
namespace tflite { namespace testing { -// Join a list of data separated by delimiter. +// Join a list of data with default precision separated by delimiter. +template +string JoinDefault(T* data, size_t len, const string& delimiter) { + if (len == 0 || data == nullptr) { + return ""; + } + std::stringstream result; + result << data[0]; + for (int i = 1; i < len; i++) { + result << delimiter << data[i]; + } + return result.str(); +} + +// Join a list of data with fixed precision separated by delimiter. template string Join(T* data, size_t len, const string& delimiter) { if (len == 0 || data == nullptr) { diff --git a/tensorflow/lite/testing/join_test.cc b/tensorflow/lite/testing/join_test.cc index 0b3c07f37e14e3815ac1eb4acd0aefac3515064c..476a7f20591691ccddff6829c894c640608f6471 100644 --- a/tensorflow/lite/testing/join_test.cc +++ b/tensorflow/lite/testing/join_test.cc @@ -26,6 +26,11 @@ TEST(JoinTest, JoinInt) { EXPECT_EQ(Join(data.data(), data.size(), ","), "1,2,3"); } +TEST(JoinDefaultTest, JoinFloat) { + float data[] = {1.0, -3, 2.3, 1e-5}; + EXPECT_EQ(JoinDefault(data, 4, " "), "1 -3 2.3 1e-05"); +} + TEST(JoinTest, JoinFloat) { float data[] = {1.0, -3, 2.3, 1e-5}; EXPECT_EQ(Join(data, 4, " "), "1 -3 2.29999995 9.99999975e-06"); diff --git a/tensorflow/lite/testing/kernel_test/BUILD b/tensorflow/lite/testing/kernel_test/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c46e80cc360043158928544a54c0221a7b405ad0 --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/BUILD @@ -0,0 +1,124 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_binary", + "tf_cc_test", +) + +cc_library( + name = "util", + hdrs = ["util.h"], + deps = [ + ":input_generator", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/testing:split", + "//tensorflow/lite/testing:tflite_driver", + ] + select({ + "//conditions:default": [ + 
"//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + }), +) + +tf_cc_test( + name = "util_test", + size = "small", + srcs = ["util_test.cc"], + data = [ + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/test_input.csv", + ], + tags = [ + "no_oss", + ], + deps = [ + ":util", + "//tensorflow/lite/testing:tflite_driver", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_binary( + name = "tflite_kernel_runner", + srcs = ["tflite_kernel_runner.cc"], + deps = [ + ":util", + ], +) + +tf_cc_binary( + name = "generate_diff_report", + srcs = ["generate_diff_report.cc"], + deps = [ + ":diff_analyzer", + "//tensorflow/core:framework_internal", + ], +) + +cc_library( + name = "input_generator", + srcs = ["input_generator.cc"], + hdrs = ["input_generator.h"], + deps = [ + "//tensorflow/lite:framework", + "//tensorflow/lite:string", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/testing:join", + "//tensorflow/lite/testing:split", + ], +) + +tf_cc_test( + name = "input_generator_test", + size = "small", + srcs = ["input_generator_test.cc"], + data = [ + "//tensorflow/lite:testdata/multi_add.bin", + "//tensorflow/lite:testdata/test_input.csv", + ], + tags = [ + "no_oss", + ], + deps = [ + ":input_generator", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "diff_analyzer", + srcs = ["diff_analyzer.cc"], + hdrs = ["diff_analyzer.h"], + deps = [ + "//tensorflow/lite:string", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/testing:split", + ], +) + +tf_cc_test( + name = "diff_analyzer_test", + size = "small", + srcs = ["diff_analyzer_test.cc"], + data = [ + "//tensorflow/lite:testdata/test_input.csv", + ], + tags = [ + "no_oss", + ], + deps = [ + ":diff_analyzer", + "//tensorflow/core:lib", + "@com_google_googletest//:gtest_main", + ], +) diff --git 
a/tensorflow/lite/testing/kernel_test/diff_analyzer.cc b/tensorflow/lite/testing/kernel_test/diff_analyzer.cc new file mode 100644 index 0000000000000000000000000000000000000000..7d6fcc80be17b4020f53dddb8215a083031fd501 --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/diff_analyzer.cc @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/testing/kernel_test/diff_analyzer.h" + +#include +#include +#include "tensorflow/lite/testing/split.h" + +namespace tflite { +namespace testing { + +namespace { +float CalculateNormalizedMaxDiff(const std::vector& base, + const std::vector& test) { + float diff = 0; + // For numerical stability in case the tensor is all 0. + float base_max = 1e-6; + + for (int i = 0; i < base.size(); i++) { + diff = std::max(diff, std::abs(base[i] - test[i])); + base_max = std::max(base_max, base[i]); + } + + return diff / base_max; +} + +float CalculateNormalizedL2Norm(const std::vector& base, + const std::vector& test) { + float l2_error = 0; + // For numerical stability in case the tensor is all 0. 
+ float base_max = 1e-6; + + for (int i = 0; i < base.size(); i++) { + float diff = base[i] - test[i]; + l2_error += diff * diff; + base_max = std::max(base_max, base[i]); + } + + l2_error /= base.size(); + + return std::sqrt(l2_error) / base_max; +} + +TfLiteStatus Populate(const string& filename, + std::vector>* tensors) { + if (filename.empty()) { + fprintf(stderr, "Empty input file name."); + return kTfLiteError; + } + + std::ifstream file(filename); + string content; + while (std::getline(file, content, '\n')) { + tensors->push_back(Split(content, ",")); + } + + file.close(); + return kTfLiteOk; +} +} // namespace + +TfLiteStatus DiffAnalyzer::ReadFiles(const string& base, const string& test) { + TF_LITE_ENSURE_STATUS(Populate(base, &base_tensors_)); + TF_LITE_ENSURE_STATUS(Populate(test, &test_tensors_)); + + if (base_tensors_.size() != test_tensors_.size()) { + fprintf(stderr, "Golden and test tensor dimensions don't match."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +TfLiteStatus DiffAnalyzer::WriteReport(const string& filename) { + if (filename.empty()) { + fprintf(stderr, "Empty output file name."); + return kTfLiteError; + } + + std::ofstream output_file; + output_file.open(filename, std::fstream::out | std::fstream::trunc); + if (!output_file) { + fprintf(stderr, "Failed to open output file %s.", filename.c_str()); + return kTfLiteError; + } + + output_file << "Normalized L2 Error" + << "," + << "Normalized Max Diff" + << "\n"; + for (int i = 0; i < base_tensors_.size(); i++) { + float l2_error = + CalculateNormalizedL2Norm(base_tensors_[i], test_tensors_[i]); + float max_diff = + CalculateNormalizedMaxDiff(base_tensors_[i], test_tensors_[i]); + output_file << l2_error << "," << max_diff << "\n"; + } + + output_file.close(); + return kTfLiteOk; +} +} // namespace testing +} // namespace tflite diff --git a/tensorflow/lite/testing/kernel_test/diff_analyzer.h b/tensorflow/lite/testing/kernel_test/diff_analyzer.h new file mode 100644 index 
0000000000000000000000000000000000000000..aecbaea449bda3edd1e5176b9a91b4542afc64f3 --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/diff_analyzer.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_TESTING_KERNEL_TEST_DIFF_ANALYZER_H_ +#define TENSORFLOW_LITE_TESTING_KERNEL_TEST_DIFF_ANALYZER_H_ + +#include + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/string.h" + +namespace tflite { +namespace testing { + +// Reads the baseline and test files with output tensor values, and calculates +// the diff metrics. +class DiffAnalyzer { + public: + DiffAnalyzer() = default; + TfLiteStatus ReadFiles(const string& base, const string& test); + TfLiteStatus WriteReport(const string& filename); + + private: + std::vector> base_tensors_; + std::vector> test_tensors_; +}; + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_LITE_TESTING_KERNEL_TEST_DIFF_ANALYZER_H_ diff --git a/tensorflow/lite/testing/kernel_test/diff_analyzer_test.cc b/tensorflow/lite/testing/kernel_test/diff_analyzer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..da54b38acf82a9647c545e8577b09188ea54934f --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/diff_analyzer_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/testing/kernel_test/diff_analyzer.h" + +#include + +#include +#include +#include "tensorflow/core/lib/io/path.h" + +namespace tflite { +namespace testing { + +namespace { + +TEST(DiffAnalyzerTest, ZeroDiff) { + DiffAnalyzer diff_analyzer; + string filename = "tensorflow/lite/testdata/test_input.csv"; + ASSERT_EQ(diff_analyzer.ReadFiles(filename, filename), kTfLiteOk); + + string output_file = + tensorflow::io::JoinPath(FLAGS_test_tmpdir + "diff_report.csv"); + ASSERT_EQ(diff_analyzer.WriteReport(output_file), kTfLiteOk); + + std::string content; + std::ifstream file(output_file); + std::getline(file, content); + std::getline(file, content); + ASSERT_EQ(content, "0,0"); +} + +} // namespace + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/lite/testing/kernel_test/generate_diff_report.cc b/tensorflow/lite/testing/kernel_test/generate_diff_report.cc new file mode 100644 index 0000000000000000000000000000000000000000..afa6a9a94ec2ffd824d66a363c53b69455706d06 --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/generate_diff_report.cc @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/lite/testing/kernel_test/diff_analyzer.h" + +int main(int argc, char** argv) { + string base, test, output; + std::vector flag_list = { + tensorflow::Flag("base", &base, "Path to the base serialized tensor."), + tensorflow::Flag("test", &test, "Path to the test serialized tensor."), + tensorflow::Flag("output", &output, "Path to the output file."), + }; + tensorflow::Flags::Parse(&argc, argv, flag_list); + + tflite::testing::DiffAnalyzer diff_analyzer; + diff_analyzer.ReadFiles(base, test); + diff_analyzer.WriteReport(output); + return 0; +} diff --git a/tensorflow/lite/testing/kernel_test/input_generator.cc b/tensorflow/lite/testing/kernel_test/input_generator.cc new file mode 100644 index 0000000000000000000000000000000000000000..5c69bdff86bbacbbb343e48b8c45d20811463620 --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/input_generator.cc @@ -0,0 +1,208 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/testing/kernel_test/input_generator.h" + +#include +#include +#include + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/testing/join.h" +#include "tensorflow/lite/testing/split.h" + +namespace tflite { +namespace testing { + +namespace { + +template +std::vector GenerateRandomTensor(TfLiteIntArray* dims, + const std::function& random_func) { + int64_t num_elements = 1; + for (int i = 0; i < dims->size; i++) { + num_elements *= dims->data[i]; + } + + std::vector result(num_elements); + for (int i = 0; i < num_elements; i++) { + result[i] = random_func(i); + } + return result; +} + +template +std::vector GenerateUniform(TfLiteIntArray* dims, float min, float max) { + auto random_float = [](float min, float max) { + // TODO(yunluli): Change seed for each invocation if needed. + // Used rand() instead of rand_r() here to make it runnable on android. + return min + (max - min) * static_cast(rand()) / RAND_MAX; + }; + + std::function random_t = [&](int) { + return static_cast(random_float(min, max)); + }; + std::vector data = GenerateRandomTensor(dims, random_t); + return data; +} + +template +std::vector GenerateGaussian(TfLiteIntArray* dims, float min, float max) { + auto random_float = [](float min, float max) { + static std::default_random_engine generator; + // We generate a float number within [0, 1) following a mormal distribution + // with mean = 0.5 and stddev = 1/3, and use it to scale the final random + // number into the desired range. 
+ static std::normal_distribution distribution(0.5, 1.0 / 3); + auto rand_n = distribution(generator); + while (rand_n < 0 || rand_n >= 1) { + rand_n = distribution(generator); + } + + return min + (max - min) * static_cast(rand_n); + }; + + std::function random_t = [&](int) { + return static_cast(random_float(min, max)); + }; + std::vector data = GenerateRandomTensor(dims, random_t); + return data; +} + +} // namespace + +TfLiteStatus InputGenerator::LoadModel(const string& model_dir) { + model_ = FlatBufferModel::BuildFromFile(model_dir.c_str()); + if (!model_) { + fprintf(stderr, "Cannot load model %s", model_dir.c_str()); + return kTfLiteError; + } + + ::tflite::ops::builtin::BuiltinOpResolver builtin_ops; + InterpreterBuilder(*model_, builtin_ops)(&interpreter_); + if (!interpreter_) { + fprintf(stderr, "Failed to build interpreter."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +TfLiteStatus InputGenerator::ReadInputsFromFile(const string& filename) { + if (filename.empty()) { + fprintf(stderr, "Empty input file name."); + return kTfLiteError; + } + + std::ifstream input_file(filename); + string input; + while (std::getline(input_file, input, '\n')) { + inputs_.push_back(input); + } + input_file.close(); + return kTfLiteOk; +} + +TfLiteStatus InputGenerator::WriteInputsToFile(const string& filename) { + if (filename.empty()) { + fprintf(stderr, "Empty input file name."); + return kTfLiteError; + } + + std::ofstream output_file; + output_file.open(filename, std::fstream::out | std::fstream::trunc); + if (!output_file) { + fprintf(stderr, "Failed to open output file %s.", filename.c_str()); + return kTfLiteError; + } + + for (const auto& input : inputs_) { + output_file << input << "\n"; + } + output_file.close(); + + return kTfLiteOk; +} + +// TODO(yunluli): Support more tensor types when needed. 
+TfLiteStatus InputGenerator::GenerateInput(const string& distribution) { + auto input_tensor_ids = interpreter_->inputs(); + for (auto id : input_tensor_ids) { + auto* tensor = interpreter_->tensor(id); + if (distribution == "UNIFORM") { + switch (tensor->type) { + case kTfLiteInt8: { + auto data = GenerateUniform( + tensor->dims, std::numeric_limits::min(), + std::numeric_limits::max()); + inputs_.push_back(Join(data.data(), data.size(), ",")); + break; + } + case kTfLiteUInt8: { + auto data = GenerateUniform( + tensor->dims, std::numeric_limits::min(), + std::numeric_limits::max()); + inputs_.push_back(Join(data.data(), data.size(), ",")); + break; + } + case kTfLiteFloat32: { + auto data = GenerateUniform(tensor->dims, -1, 1); + inputs_.push_back(JoinDefault(data.data(), data.size(), ",")); + break; + } + default: + fprintf(stderr, "Unsupported input tensor type %s.", + TfLiteTypeGetName(tensor->type)); + break; + } + } else if (distribution == "GAUSSIAN") { + switch (tensor->type) { + case kTfLiteInt8: { + auto data = GenerateGaussian( + tensor->dims, std::numeric_limits::min(), + std::numeric_limits::max()); + inputs_.push_back(Join(data.data(), data.size(), ",")); + break; + } + case kTfLiteUInt8: { + auto data = GenerateGaussian( + tensor->dims, std::numeric_limits::min(), + std::numeric_limits::max()); + inputs_.push_back(Join(data.data(), data.size(), ",")); + break; + } + case kTfLiteFloat32: { + auto data = GenerateGaussian(tensor->dims, -1, 1); + inputs_.push_back(JoinDefault(data.data(), data.size(), ",")); + break; + } + default: + fprintf(stderr, "Unsupported input tensor type %s.", + TfLiteTypeGetName(tensor->type)); + break; + } + } else { + fprintf(stderr, "Unsupported distribution %s.", distribution.c_str()); + return kTfLiteError; + } + } + + return kTfLiteOk; +} + +std::vector InputGenerator::GetInputs() { return inputs_; } + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/lite/testing/kernel_test/input_generator.h 
b/tensorflow/lite/testing/kernel_test/input_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..859c7068e5448c837580fe79e89918fbd34c2a66 --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/input_generator.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_TESTING_KERNEL_TEST_INPUT_GENERATOR_H_ +#define TENSORFLOW_LITE_TESTING_KERNEL_TEST_INPUT_GENERATOR_H_ + +#include +#include + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string.h" + +namespace tflite { +namespace testing { + +// Generate random input, or read input from a file for kernel diff test. +// Needs to load the tflite graph to get information like tensor shape and +// data type. 
+class InputGenerator { + public: + InputGenerator() = default; + TfLiteStatus LoadModel(const string& model_dir); + TfLiteStatus ReadInputsFromFile(const string& filename); + TfLiteStatus GenerateInput(const string& distribution); + std::vector GetInputs(); + TfLiteStatus WriteInputsToFile(const string& filename); + + private: + std::unique_ptr model_; + std::unique_ptr interpreter_; + std::vector inputs_; +}; + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_LITE_TESTING_KERNEL_TEST_INPUT_GENERATOR_H_ diff --git a/tensorflow/lite/testing/kernel_test/input_generator_test.cc b/tensorflow/lite/testing/kernel_test/input_generator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2223a9196ff7d6b790e2e7a0170d42c5d5468a5f --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/input_generator_test.cc @@ -0,0 +1,81 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/testing/kernel_test/input_generator.h" + +#include +#include + +#include +#include "testing/base/public/googletest.h" +#include + +namespace tflite { +namespace testing { + +namespace { + +TEST(InputGeneratorTest, LoadModel) { + InputGenerator input_generator; + ASSERT_EQ(input_generator.LoadModel( + "tensorflow/lite/testdata/multi_add.bin"), + kTfLiteOk); +} + +TEST(InputGeneratorTest, ReadWriteSimpleFile) { + InputGenerator input_generator; + ASSERT_EQ(input_generator.ReadInputsFromFile( + "tensorflow/lite/testdata/test_input.csv"), + kTfLiteOk); + + std::vector inputs; + std::string content = "1"; + for (int i = 0; i < 1 * 8 * 8 * 3 - 1; i++) { + content.append(",1"); + } + inputs.push_back(content); + ASSERT_EQ(input_generator.GetInputs(), inputs); + + auto output_filename = FLAGS_test_tmpdir + "/out.csv"; + ASSERT_EQ(input_generator.WriteInputsToFile(output_filename), kTfLiteOk); + + std::ifstream in(output_filename); + std::string out; + std::getline(in, out, '\n'); + ASSERT_EQ(out, content); +} + +TEST(InputGeneratorTest, GenerateUniformInput) { + InputGenerator input_generator; + ASSERT_EQ(input_generator.LoadModel( + "tensorflow/lite/testdata/multi_add.bin"), + kTfLiteOk); + input_generator.GenerateInput("UNIFORM"); + auto inputs = input_generator.GetInputs(); + ASSERT_EQ(inputs.size(), 4); +} + +TEST(InputGeneratorTest, GenerateGaussianInput) { + InputGenerator input_generator; + ASSERT_EQ(input_generator.LoadModel( + "tensorflow/lite/testdata/multi_add.bin"), + kTfLiteOk); + input_generator.GenerateInput("GAUSSIAN"); + auto inputs = input_generator.GetInputs(); + ASSERT_EQ(inputs.size(), 4); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/lite/testing/kernel_test/tflite_kernel_runner.cc b/tensorflow/lite/testing/kernel_test/tflite_kernel_runner.cc new file mode 100644 index 
0000000000000000000000000000000000000000..34c1728ed1da6ec962989479dccfdc64bc8ca6cd --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/tflite_kernel_runner.cc @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/testing/kernel_test/util.h" + +int main(int argc, char** argv) { + tflite::testing::kernel_test::TestOptions options = + tflite::testing::kernel_test::ParseTfliteKernelTestFlags(&argc, argv); + const bool run_reference_kernel = options.kernel_type == "REFERENCE"; + const bool use_nnapi = options.kernel_type == "NNAPI"; + + auto runner = absl::make_unique( + use_nnapi, "", run_reference_kernel); + if (tflite::testing::kernel_test::RunKernelTest(options, runner.get()) == + kTfLiteOk) { + return 0; + } + + return -1; +} diff --git a/tensorflow/lite/testing/kernel_test/util.h b/tensorflow/lite/testing/kernel_test/util.h new file mode 100644 index 0000000000000000000000000000000000000000..d940e5ad12f497ec827ce0dc6be9e6311078b1a9 --- /dev/null +++ b/tensorflow/lite/testing/kernel_test/util.h @@ -0,0 +1,122 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_TESTING_KERNEL_TEST_UTIL_H_ +#define TENSORFLOW_LITE_TESTING_KERNEL_TEST_UTIL_H_ + +#include + +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/testing/kernel_test/input_generator.h" +#include "tensorflow/lite/testing/split.h" +#include "tensorflow/lite/testing/tflite_driver.h" + +namespace tflite { +namespace testing { +namespace kernel_test { + +struct TestOptions { + // Path of tensorflow lite model. + string tflite_model; + // Path of the input file. If empty, generate at runtime. + string read_input_from_file; + // Path to dump the input file. + string dump_input_to_file; + // Path to dump the output. + string dump_output_to_file; + // Input distribution. + string input_distribution; + // Kernel type. + string kernel_type; +}; + +TestOptions ParseTfliteKernelTestFlags(int* argc, char** argv) { + TestOptions options; + std::vector flags = { + tensorflow::Flag("tflite_model", &options.tflite_model, + "Path of tensorflow lite model."), + tensorflow::Flag("read_input_from_file", &options.read_input_from_file, + "File to read input data from. If empty, generates " + "input at runtime."), + tensorflow::Flag("dump_input_to_file", &options.dump_input_to_file, + "File to dump randomly generated input."), + tensorflow::Flag("dump_output_to_file", &options.dump_output_to_file, + "File to dump output."), + tensorflow::Flag("input_distribution", &options.input_distribution, + "Input distribution. 
Default: Gaussian."), + tensorflow::Flag("kernel_type", &options.kernel_type, "Kernel type."), + }; + + tensorflow::Flags::Parse(argc, argv, flags); + + return options; +} + +TfLiteStatus RunKernelTest(const kernel_test::TestOptions& options, + TestRunner* runner) { + InputGenerator input_generator; + + if (options.read_input_from_file.empty()) { + TF_LITE_ENSURE_STATUS(input_generator.LoadModel(options.tflite_model)); + TF_LITE_ENSURE_STATUS( + input_generator.GenerateInput(options.input_distribution)); + } else { + TF_LITE_ENSURE_STATUS( + input_generator.ReadInputsFromFile(options.read_input_from_file)); + } + + runner->LoadModel(options.tflite_model); + runner->AllocateTensors(); + if (!runner->IsValid()) return kTfLiteError; + auto input_tensor_ids = runner->GetInputs(); + auto inputs = input_generator.GetInputs(); + if (inputs.size() != input_tensor_ids.size()) { + fprintf(stderr, + "Number of input tensors generated doesn't match what the model " + "asks for."); + } + for (int i = 0; i < inputs.size(); i++) { + runner->SetInput(input_tensor_ids[i], inputs[i]); + } + + runner->Invoke(); + + if (!options.dump_input_to_file.empty()) { + TF_LITE_ENSURE_STATUS( + input_generator.WriteInputsToFile(options.dump_input_to_file)); + } + + if (!options.dump_output_to_file.empty()) { + std::ofstream output_file; + output_file.open(options.dump_output_to_file, + std::fstream::out | std::fstream::trunc); + if (!output_file) { + return kTfLiteError; + } + + for (auto id : runner->GetOutputs()) { + output_file << runner->ReadOutput(id) << "\n"; + } + output_file.close(); + } + + return kTfLiteOk; +} + +} // namespace kernel_test +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_LITE_TESTING_KERNEL_TEST_UTIL_H_ diff --git a/tensorflow/lite/testing/kernel_test/util_test.cc b/tensorflow/lite/testing/kernel_test/util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cbec66092837726e90fb6d37135ba990c7cbdb86 --- /dev/null +++ 
b/tensorflow/lite/testing/kernel_test/util_test.cc @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/testing/kernel_test/util.h" + +#include +#include + +#include +#include "testing/base/public/googletest.h" +#include +#include "tensorflow/lite/testing/tflite_driver.h" + +namespace tflite { +namespace testing { +namespace kernel_test { +namespace { + +TEST(UtilTest, SimpleE2ETest) { + TestOptions options; + options.tflite_model = "tensorflow/lite/testdata/add.bin"; + options.read_input_from_file = + "tensorflow/lite/testdata/test_input.csv"; + options.dump_output_to_file = FLAGS_test_tmpdir + "/test_out.csv"; + options.kernel_type = "REFERENCE"; + std::unique_ptr runner(new TfLiteDriver(false, "", true)); + RunKernelTest(options, runner.get()); + std::string expected = "3"; + for (int i = 0; i < 1 * 8 * 8 * 3 - 1; i++) { + expected.append(",3"); + } + std::string content; + std::ifstream file(options.dump_output_to_file); + std::getline(file, content); + EXPECT_EQ(content, expected); +} + +} // namespace +} // namespace kernel_test +} // namespace testing +} // namespace tflite diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py index 804e328d9da248859e806bd070de26a8f5aa37b4..95bda42cd71abe4cecb41f34859cf6990e851ed7 
100644 --- a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.core.framework import graph_pb2 as _graph_pb2 @@ -25,13 +26,50 @@ from tensorflow.lite.python import convert_saved_model as _convert_saved_model from tensorflow.lite.python import lite as _lite from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session +from tensorflow.python.framework import constant_op from tensorflow.python.framework.importer import import_graph_def as _import_graph_def +from tensorflow.python.keras.preprocessing import image from tensorflow.python.lib.io import file_io as _file_io +from tensorflow.python.platform import resource_loader as _resource_loader +from tensorflow.python.saved_model import load as _load from tensorflow.python.saved_model import loader as _loader from tensorflow.python.saved_model import signature_constants as _signature_constants from tensorflow.python.saved_model import tag_constants as _tag_constants +def get_filepath(filename, base_dir=None): + """Returns the full path of the filename. + + Args: + filename: Subdirectory and name of the model file. + base_dir: Base directory containing model file. + + Returns: + str. + """ + if base_dir is None: + base_dir = "learning/brain/mobile/tflite_compat_models" + return os.path.join(_resource_loader.get_root_dir_with_all_resources(), + base_dir, filename) + + +def get_image(size): + """Returns an image loaded into an np.ndarray with dims [1, size, size, 3]. + + Args: + size: Size of image. + + Returns: + np.ndarray. 
+ """ + img_filename = _resource_loader.get_path_to_datafile( + "testdata/grace_hopper.jpg") + img = image.load_img(img_filename, target_size=(size, size)) + img_array = image.img_to_array(img) + img_array = np.expand_dims(img_array, axis=0) + return img_array + + def _convert(converter, **kwargs): """Converts the model. @@ -186,6 +224,32 @@ def compare_models(tflite_model, tf_eval_func, input_data=None, tolerance=5): np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) +def compare_models_v2(tflite_model, concrete_func, input_data=None, + tolerance=5): + """Compares TensorFlow and TFLite models for TensorFlow 2.0. + + Unless the input data is provided, the models are compared with random data. + Currently only 1 input and 1 output are supported by this function. + + Args: + tflite_model: Serialized TensorFlow Lite model. + concrete_func: TensorFlow ConcreteFunction. + input_data: np.ndarray to pass into models during inference. (default None) + tolerance: Decimal place to check accuracy to. (default 5) + """ + if input_data is None: + input_data = _generate_random_input_data(tflite_model) + input_data_func = constant_op.constant(input_data[0]) + + # Gets the TensorFlow results as a map from the output names to outputs. + # Converts the map into a list that is equivalent to the TFLite list. + tf_results_map = concrete_func(input_data_func) + tf_results = [tf_results_map[tf_results_map.keys()[0]]] + tflite_results = _evaluate_tflite_model(tflite_model, input_data) + for tf_result, tflite_result in zip(tf_results, tflite_results): + np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) + + def test_frozen_graph_quant(filename, input_arrays, output_arrays, @@ -315,6 +379,39 @@ def test_saved_model(directory, compare_models(tflite_model, tf_eval_func, input_data=input_data) +# TODO(nupurgarg): Remove input_shape parameter after bug with shapes is fixed. 
+def test_saved_model_v2(directory, + input_shape=None, + tag_set=None, + signature_key=None, + input_data=None, + **kwargs): + """Validates the TensorFlow SavedModel converts to a TFLite model. + + Converts the TensorFlow SavedModel to TFLite and checks the accuracy of the + model on random data. + + Args: + directory: SavedModel directory to convert. + input_shape: Input shape for the single input array as a list of integers. + tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to + analyze. All tags in the tag set must be present. + signature_key: Key identifying SignatureDef containing inputs and outputs. + input_data: np.ndarray to pass into models during inference. (default None) + **kwargs: Additional arguments to be passed into the converter. + """ + model = _load.load(directory, tags=tag_set) + if not signature_key: + signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + concrete_func = model.signatures[signature_key] + concrete_func.inputs[0].set_shape(input_shape) + + converter = _lite.TFLiteConverterV2.from_concrete_function(concrete_func) + tflite_model = _convert(converter, **kwargs) + + compare_models_v2(tflite_model, concrete_func, input_data=input_data) + + def test_keras_model(filename, input_arrays=None, input_shapes=None, diff --git a/tensorflow/lite/testing/nnapi_example.cc b/tensorflow/lite/testing/nnapi_example.cc index 0ff1fbcd411d86738faaddfb740ddf1fc500590a..309cb19628cd54a39ea926a6f3506cf570ff3679 100644 --- a/tensorflow/lite/testing/nnapi_example.cc +++ b/tensorflow/lite/testing/nnapi_example.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include #include #include -#include "tensorflow/lite/nnapi/NeuralNetworksShim.h" +#include "tensorflow/lite/nnapi/nnapi_implementation.h" #include "tensorflow/lite/testing/parse_testdata.h" #include "tensorflow/lite/testing/tflite_driver.h" diff --git a/tensorflow/lite/testing/string_util.cc b/tensorflow/lite/testing/string_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf9d5087644cc52415a83dd80b457249b85765b5 --- /dev/null +++ b/tensorflow/lite/testing/string_util.cc @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/lite/testing/string_util.h" + +#include "absl/strings/escaping.h" +#include "tensorflow/lite/python/interpreter_wrapper/numpy.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace testing { +namespace python { + +PyObject* SerializeAsHexString(PyObject* value) { + DynamicBuffer dynamic_buffer; + if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) { + return nullptr; + } + + char* char_buffer = nullptr; + size_t size = dynamic_buffer.WriteToBuffer(&char_buffer); + string s = absl::BytesToHexString({char_buffer, size}); + free(char_buffer); + + return python_utils::ConvertToPyString(s.data(), s.size()); +} + +} // namespace python +} // namespace testing +} // namespace tflite diff --git a/tensorflow/lite/testing/string_util.h b/tensorflow/lite/testing/string_util.h new file mode 100644 index 0000000000000000000000000000000000000000..56c024d918df37641c12851a2a02187d12e03b7d --- /dev/null +++ b/tensorflow/lite/testing/string_util.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_TESTING_STRING_UTIL_H_ +#define TENSORFLOW_LITE_TESTING_STRING_UTIL_H_ + +#include +#include + +namespace tflite { +namespace testing { +namespace python { + +// Take a python string array, convert it to TF Lite dynamic buffer format and +// serialize it as a HexString. +PyObject* SerializeAsHexString(PyObject* value); + +} // namespace python +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_LITE_TESTING_STRING_UTIL_H_ diff --git a/tensorflow/lite/testing/string_util.i b/tensorflow/lite/testing/string_util.i new file mode 100644 index 0000000000000000000000000000000000000000..574abb79653ff858721e28d0d33225e3e24cbbfd --- /dev/null +++ b/tensorflow/lite/testing/string_util.i @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +%{ + +#define SWIG_FILE_WITH_INIT +#include "tensorflow/lite/testing/string_util.h" + +%} + +namespace tflite { +namespace testing { +namespace python { + +PyObject* SerializeAsHexString(PyObject* string_tensor); + +} // namespace python +} // namespace testing +} // namespace tflite diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc index a637dc86c020d4e16fb4fc02e9f62e8dec6a3a25..55670858338bda0bfe04828c33da6c64982a6656 100644 --- a/tensorflow/lite/testing/tflite_driver.cc +++ b/tensorflow/lite/testing/tflite_driver.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register_ref.h" #include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/testing/join.h" #include "tensorflow/lite/testing/split.h" namespace tflite { @@ -383,5 +384,34 @@ void TfLiteDriver::ResetLSTMStateTensors() { interpreter_->ResetVariableTensors(); } +string TfLiteDriver::ReadOutput(int id) { + auto* tensor = interpreter_->tensor(id); + int num_elements = 1; + + for (int i = 0; i < tensor->dims->size; ++i) { + num_elements *= tensor->dims->data[i]; + } + + switch (tensor->type) { + case kTfLiteFloat32: + return JoinDefault(tensor->data.f, num_elements, ","); + case kTfLiteInt32: + return JoinDefault(tensor->data.i32, num_elements, ","); + case kTfLiteInt64: + return JoinDefault(tensor->data.i64, num_elements, ","); + case kTfLiteUInt8: + return Join(tensor->data.uint8, num_elements, ","); + case kTfLiteInt8: + return JoinDefault(tensor->data.int8, num_elements, ","); + case kTfLiteBool: + return JoinDefault(tensor->data.b, num_elements, ","); + default: + Invalidate(absl::StrCat("Unsupported tensor type ", + TfLiteTypeGetName(tensor->type), + " in TfLiteDriver::ReadOutput")); + return ""; + } +} + } // namespace testing } // namespace tflite diff --git 
a/tensorflow/lite/testing/tflite_driver.h b/tensorflow/lite/testing/tflite_driver.h index 537f20dfbfd6c6fe0fbefd854358146129d33b7a..3cce6c4222ec36f5eac2f144062b5b850c326345 100644 --- a/tensorflow/lite/testing/tflite_driver.h +++ b/tensorflow/lite/testing/tflite_driver.h @@ -49,7 +49,7 @@ class TfLiteDriver : public TestRunner { void SetExpectation(int id, const string& csv_values) override; void Invoke() override; bool CheckResults() override; - string ReadOutput(int id) override { return "no-op"; } + string ReadOutput(int id) override; private: void DeallocateStringTensor(TfLiteTensor* t) { diff --git a/tensorflow/lite/testing/tflite_driver_test.cc b/tensorflow/lite/testing/tflite_driver_test.cc index 81bf6700cb898796a72bea38ea0711556a7215a5..e80816bdf5ecd21d4f147e824188dd3a206d68dd 100644 --- a/tensorflow/lite/testing/tflite_driver_test.cc +++ b/tensorflow/lite/testing/tflite_driver_test.cc @@ -54,6 +54,8 @@ TEST(TfliteDriverTest, SimpleTest) { ASSERT_TRUE(runner->IsValid()); ASSERT_TRUE(runner->CheckResults()); + EXPECT_EQ(runner->ReadOutput(5), "0.101,0.202,0.303,0.404"); + EXPECT_EQ(runner->ReadOutput(6), "0.011,0.022,0.033,0.044"); } TEST(TfliteDriverTest, SingleAddOpTest) { @@ -88,6 +90,8 @@ TEST(TfliteDriverTest, SingleAddOpTest) { ASSERT_TRUE(runner->IsValid()); ASSERT_TRUE(runner->CheckResults()); + EXPECT_EQ(runner->ReadOutput(5), "0.101,0.202,0.303,0.404"); + EXPECT_EQ(runner->ReadOutput(6), "0.011,0.022,0.033,0.044"); } } // namespace diff --git a/tensorflow/lite/tflite_exported_symbols.lds b/tensorflow/lite/tflite_exported_symbols.lds new file mode 100644 index 0000000000000000000000000000000000000000..b145204aa1e2b039aa7075047b1fd9ca73157320 --- /dev/null +++ b/tensorflow/lite/tflite_exported_symbols.lds @@ -0,0 +1,3 @@ +*TfLite* +*tflite* +*TFL_* diff --git a/tensorflow/lite/tflite_static.bp b/tensorflow/lite/tflite_static.bp index e9a8378b810b411cd0d0daaba437a4337131393b..3aca8f9c68c8c3ccde454b77fa974c72bf41638a 100644 --- 
a/tensorflow/lite/tflite_static.bp +++ b/tensorflow/lite/tflite_static.bp @@ -28,6 +28,8 @@ cc_library_static { "delegates/nnapi/nnapi_delegate.cc", "graph_info.cc", "interpreter.cc", + "minimal_logging.cc", + "minimal_logging_android.cc", "mmap_allocation.cc", "model.cc", "mutable_op_resolver.cc", @@ -42,12 +44,14 @@ cc_library_static { "kernels/topk_v2.cc", "kernels/activations.cc", "kernels/add.cc", + "kernels/add_n.cc", "kernels/arg_min_max.cc", "kernels/basic_rnn.cc", "kernels/batch_to_space_nd.cc", "kernels/bidirectional_sequence_lstm.cc", "kernels/bidirectional_sequence_rnn.cc", "kernels/cast.cc", + "kernels/ceil.cc", "kernels/comparisons.cc", "kernels/concatenation.cc", "kernels/conv.cc", @@ -67,11 +71,12 @@ cc_library_static { "kernels/floor_mod.cc", "kernels/fully_connected.cc", "kernels/gather.cc", + "kernels/gather_nd.cc", "kernels/gemm_support.cc", "kernels/hashtable_lookup.cc", + "kernels/if.cc", "kernels/kernel_util.cc", "kernels/l2norm.cc", - "kernels/layer_norm_lstm.cc", "kernels/local_response_norm.cc", "kernels/logical.cc", "kernels/lsh_projection.cc", @@ -87,12 +92,14 @@ cc_library_static { "kernels/pooling.cc", "kernels/pow.cc", "kernels/range.cc", + "kernels/rank.cc", "kernels/reduce.cc", - "kernels/relu1.cc", "kernels/register.cc", "kernels/reshape.cc", "kernels/resize_bilinear.cc", "kernels/resize_nearest_neighbor.cc", + "kernels/reverse.cc", + "kernels/reverse_sequence.cc", "kernels/select.cc", "kernels/shape.cc", "kernels/skip_gram.cc", @@ -113,12 +120,15 @@ cc_library_static { "kernels/unidirectional_sequence_rnn.cc", "kernels/unique.cc", "kernels/unpack.cc", + "kernels/where.cc", + "kernels/while.cc", "kernels/zeros_like.cc", "kernels/internal/kernel_utils.cc", "kernels/internal/tensor_utils.cc", "kernels/internal/quantization_util.cc", "kernels/internal/reference/portable_tensor_utils.cc", "kernels/internal/optimized/neon_tensor_utils.cc", + "nnapi/nnapi_implementation.cc", ], include_dirs: [ "external/eigen", @@ -143,9 +153,11 @@ 
cc_library_static { "-Wno-missing-field-initializers", "-Wno-sign-compare", "-Wno-typedef-redefinition", + "-Wno-unused-function", "-Wno-unused-lambda-capture", "-Wno-unused-local-typedef", "-Wno-unused-parameter", + "-Wno-unused-private-field", "-Wno-unused-variable", "-Wno-invalid-partial-specialization", "-Wno-mismatched-tags", diff --git a/tensorflow/lite/tflite_version_script.lds b/tensorflow/lite/tflite_version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..1df70705ebf4a85d2f4c9f2301c53d48e623dff7 --- /dev/null +++ b/tensorflow/lite/tflite_version_script.lds @@ -0,0 +1,8 @@ +VERS_1.0 { + global: + *TfLite*; + *tflite*; + *TFL_*; + local: + *; +}; diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 40bceedd6a1e8398d25a4c58a3ee69228ae8d868..8481b0b754c370934860212b4d9d2e45899bb6cd 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -133,6 +133,7 @@ cc_library( cc_library( name = "model_cmdline_flags", srcs = [ + "args.cc", "model_cmdline_flags.cc", ], hdrs = [ @@ -192,6 +193,7 @@ cc_library( "graph_transformations/fuse_binary_into_preceding_affine.cc", "graph_transformations/fuse_broadcast_into_following_binary.cc", "graph_transformations/graph_transformations.cc", + "graph_transformations/group_bidirectional_sequence_ops.cc", "graph_transformations/hardcode_min_max.cc", "graph_transformations/identify_dilated_conv.cc", "graph_transformations/identify_l2_normalization.cc", @@ -308,7 +310,7 @@ cc_library( "toco_tooling.h", ], copts = tf_copts() + select({ - "//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"], + "//tensorflow:macos": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"], "//conditions:default": [], }), visibility = ["//visibility:public"], @@ -333,7 +335,7 @@ cc_library( "//tensorflow/lite/toco/tflite:export", "//tensorflow/lite/toco/tflite:import", ] + select({ - # Placeholder for internal darwin rule. + # Placeholder for internal macOS rule. 
"//conditions:default": [], }), ) @@ -376,6 +378,7 @@ cc_library( ":types_proto_cc", "//tensorflow/core:lib", "//tensorflow/lite/kernels/internal:types", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", "@protobuf_archive//:protobuf_headers", @@ -477,3 +480,16 @@ tf_cc_test( "@com_google_googletest//:gtest", ], ) + +tf_cc_test( + name = "model_cmdline_flags_test", + srcs = [ + "model_cmdline_flags_test.cc", + ], + deps = [ + ":model_cmdline_flags", + ":model_flags_proto_cc", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/lite/toco/args.cc b/tensorflow/lite/toco/args.cc new file mode 100644 index 0000000000000000000000000000000000000000..da8debc49a697fb77832c93940b60c0bebe1a7f9 --- /dev/null +++ b/tensorflow/lite/toco/args.cc @@ -0,0 +1,169 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/toco/args.h" +#include "absl/strings/str_split.h" + +namespace toco { +namespace { + +// Helper class for SplitStructuredLine parsing. +class ClosingSymbolLookup { + public: + explicit ClosingSymbolLookup(const char* symbol_pairs) + : closing_(), valid_closing_() { + // Initialize the opening/closing arrays. 
+ for (const char* symbol = symbol_pairs; *symbol != 0; ++symbol) { + unsigned char opening = *symbol; + ++symbol; + // If the string ends before the closing character has been found, + // use the opening character as the closing character. + unsigned char closing = *symbol != 0 ? *symbol : opening; + closing_[opening] = closing; + valid_closing_[closing] = true; + if (*symbol == 0) break; + } + } + + ClosingSymbolLookup(const ClosingSymbolLookup&) = delete; + ClosingSymbolLookup& operator=(const ClosingSymbolLookup&) = delete; + + // Returns the closing character corresponding to an opening one, + // or 0 if the argument is not an opening character. + char GetClosingChar(char opening) const { + return closing_[static_cast(opening)]; + } + + // Returns true if the argument is a closing character. + bool IsClosing(char c) const { + return valid_closing_[static_cast(c)]; + } + + private: + // Maps an opening character to its closing. If the entry contains 0, + // the character is not in the opening set. + char closing_[256]; + // Valid closing characters. + bool valid_closing_[256]; +}; + +bool SplitStructuredLine(absl::string_view line, char delimiter, + const char* symbol_pairs, + std::vector* cols) { + ClosingSymbolLookup lookup(symbol_pairs); + + // Stack of symbols expected to close the current opened expressions. + std::vector expected_to_close; + + ABSL_RAW_CHECK(cols != nullptr, ""); + cols->push_back(line); + for (size_t i = 0; i < line.size(); ++i) { + char c = line[i]; + if (expected_to_close.empty() && c == delimiter) { + // We don't have any open expression, this is a valid separator. + cols->back().remove_suffix(line.size() - i); + cols->push_back(line.substr(i + 1)); + } else if (!expected_to_close.empty() && c == expected_to_close.back()) { + // Can we close the currently open expression? 
+ expected_to_close.pop_back(); + } else if (lookup.GetClosingChar(c)) { + // If this is an opening symbol, we open a new expression and push + // the expected closing symbol on the stack. + expected_to_close.push_back(lookup.GetClosingChar(c)); + } else if (lookup.IsClosing(c)) { + // Error: mismatched closing symbol. + return false; + } + } + if (!expected_to_close.empty()) { + return false; // Missing closing symbol(s) + } + return true; // Success +} + +inline bool TryStripPrefixString(absl::string_view str, + absl::string_view prefix, string* result) { + bool res = absl::ConsumePrefix(&str, prefix); + result->assign(str.begin(), str.end()); + return res; +} + +inline bool TryStripSuffixString(absl::string_view str, + absl::string_view suffix, string* result) { + bool res = absl::ConsumeSuffix(&str, suffix); + result->assign(str.begin(), str.end()); + return res; +} + +} // namespace + +bool Arg::Parse(string text) { + parsed_value_.elements.clear(); + specified_ = true; + // strings::Split("") produces {""}, but we need {} on empty input. + // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could + // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements) + if (!text.empty()) { + int32 element; + for (absl::string_view part : absl::StrSplit(text, ',')) { + if (!SimpleAtoi(part, &element)) return false; + parsed_value_.elements.push_back(element); + } + } + return true; +} + +bool Arg::Parse(string text) { + parsed_value_.elements.clear(); + specified_ = true; + + if (text.empty()) { + return true; + } + + std::vector outer_vector; + absl::string_view text_disposable_copy = text; + // TODO(aselle): Change argument parsing when absl supports structuredline. 
+ SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector); + for (const absl::string_view& outer_member_stringpiece : outer_vector) { + string outer_member(outer_member_stringpiece); + if (outer_member.empty()) { + continue; + } + string outer_member_copy = outer_member; + absl::StripAsciiWhitespace(&outer_member); + if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false; + if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false; + const std::vector inner_fields_vector = + absl::StrSplit(outer_member, ','); + + std::unordered_map element; + for (const string& member_field : inner_fields_vector) { + std::vector outer_member_key_value = + absl::StrSplit(member_field, ':'); + if (outer_member_key_value.size() != 2) return false; + string& key = outer_member_key_value[0]; + string& value = outer_member_key_value[1]; + absl::StripAsciiWhitespace(&key); + absl::StripAsciiWhitespace(&value); + if (element.count(key) != 0) return false; + element[key] = value; + } + parsed_value_.elements.push_back(element); + } + return true; +} + +} // namespace toco diff --git a/tensorflow/lite/toco/args.h b/tensorflow/lite/toco/args.h index 188f2f7e7af61c6c9e94da42d528d3fcff4b5e39..c6eeb2859a91643c3e87bdeb25c32a8ef5611c87 100644 --- a/tensorflow/lite/toco/args.h +++ b/tensorflow/lite/toco/args.h @@ -22,10 +22,6 @@ limitations under the License. 
#include #include #include "tensorflow/lite/toco/toco_port.h" -#if defined(PLATFORM_GOOGLE) -#include "strings/split.h" -#include "strings/strip.h" -#endif #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" #include "tensorflow/lite/toco/toco_types.h" @@ -64,7 +60,7 @@ class Arg final { const T& value() const { return value_; } // Parsing callback for the tensorflow::Flags code - bool parse(T value_in) { + bool Parse(T value_in) { value_ = value_in; specified_ = true; return true; @@ -72,7 +68,7 @@ class Arg final { // Bind the parse member function so tensorflow::Flags can call it. std::function bind() { - return std::bind(&Arg::parse, this, std::placeholders::_1); + return std::bind(&Arg::Parse, this, std::placeholders::_1); } private: @@ -90,24 +86,10 @@ class Arg final { // Return true if the command line argument was specified on the command line. bool specified() const { return specified_; } // Bind the parse member function so tensorflow::Flags can call it. - bool parse(string text) { - parsed_value_.elements.clear(); - specified_ = true; - // strings::Split("") produces {""}, but we need {} on empty input. - // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could - // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements) - if (!text.empty()) { - int32 element; - for (absl::string_view part : absl::StrSplit(text, ',')) { - if (!SimpleAtoi(part, &element)) return false; - parsed_value_.elements.push_back(element); - } - } - return true; - } + bool Parse(string text); std::function bind() { - return std::bind(&Arg::parse, this, std::placeholders::_1); + return std::bind(&Arg::Parse, this, std::placeholders::_1); } const toco::IntList& value() const { return parsed_value_; } @@ -126,57 +108,10 @@ class Arg final { bool specified() const { return specified_; } // Bind the parse member function so tensorflow::Flags can call it. 
- bool parse(string text) { - parsed_value_.elements.clear(); - specified_ = true; - - if (text.empty()) { - return true; - } - -#if defined(PLATFORM_GOOGLE) - std::vector outer_vector; - absl::string_view text_disposable_copy = text; - SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector); - for (const absl::string_view& outer_member_stringpiece : outer_vector) { - string outer_member(outer_member_stringpiece); - if (outer_member.empty()) { - continue; - } - string outer_member_copy = outer_member; - absl::StripAsciiWhitespace(&outer_member); - if (!strings::TryStripPrefixString(outer_member, "{", &outer_member)) - return false; - if (!strings::TryStripSuffixString(outer_member, "}", &outer_member)) - return false; - const std::vector inner_fields_vector = - absl::StrSplit(outer_member, ','); - - std::unordered_map element; - for (const string& member_field : inner_fields_vector) { - std::vector outer_member_key_value = - absl::StrSplit(member_field, ':'); - if (outer_member_key_value.size() != 2) return false; - string& key = outer_member_key_value[0]; - string& value = outer_member_key_value[1]; - absl::StripAsciiWhitespace(&key); - absl::StripAsciiWhitespace(&value); - if (element.count(key) != 0) return false; - element[key] = value; - } - parsed_value_.elements.push_back(element); - } - return true; -#else - // TODO(aselle): Fix argument parsing when absl supports structuredline - fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__, - __LINE__); - abort(); -#endif - } + bool Parse(string text); std::function bind() { - return std::bind(&Arg::parse, this, std::placeholders::_1); + return std::bind(&Arg::Parse, this, std::placeholders::_1); } const toco::StringMapList& value() const { return parsed_value_; } diff --git a/tensorflow/lite/toco/dump_graphviz.cc b/tensorflow/lite/toco/dump_graphviz.cc index 8896893f3579abcefa87e3411f9b186ca7a45a1b..ad69e4f7b7a4285f36750c60291d7a6a97e7e9f7 100644 --- 
a/tensorflow/lite/toco/dump_graphviz.cc +++ b/tensorflow/lite/toco/dump_graphviz.cc @@ -15,17 +15,21 @@ limitations under the License. #include "tensorflow/lite/toco/dump_graphviz.h" #include +#include #include #include +#include "absl/memory/memory.h" #include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" #include "absl/strings/strip.h" +#include "re2/re2.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/lite/toco/toco_port.h" #include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" using toco::port::AppendF; using toco::port::StringF; @@ -33,72 +37,158 @@ using toco::port::StringF; namespace toco { namespace { +// 'nslimit' is a graphviz (dot) paramater that limits the iterations during +// the layout phase. Omitting it allows infinite iterations, causing some +// complex graphs to never finish. A value of 125 produces good graphs +// while allowing complex graphs to finish. +constexpr char kGraphFmt[] = R"CODE(digraph Computegraph { tooltip = "/" + nslimit=125 margin=36 ranksep = 2 labelloc="t" label=%s +)CODE"; +// Note: tooltip's are only supported on SVGs in Chrome. 
+constexpr char kSubgraphFmt[] = + R"CODE( subgraph "cluster_%s" { style=rounded bgcolor="%s" penwidth=0.0 label=%s +)CODE"; +constexpr char kArrayNodeFmt[] = + R"CODE( "%s" [label=%s tooltip="%s" shape=%s style=filled fillcolor="%s" fontcolor="%sDD"]; +)CODE"; +constexpr char kOpNodeFmt[] = + R"CODE( %s [label=%s tooltip=" " shape=box margin=0 style=filled fillcolor="%s" fontcolor="%sDD"]; +)CODE"; +constexpr char kInputEdgeFmt[] = + R"CODE( "%s"%s -> %s:i%d:n [penwidth=%f weight=%f]; +)CODE"; +constexpr char kOutputEdgeFmt[] = + R"CODE( %s:o%d:s -> "%s"%s [penwidth=%f weight=%f]; +)CODE"; +constexpr char kRNNBackEdgeFmt[] = + R"CODE( "%s":s -> "%s":n [color="#0F9D58" constraint=false]; +)CODE"; +constexpr char kUnicodeMult[] = "\u00D7"; +constexpr char kUnicodeEllipsis[] = " \u2026 "; + class Color { public: Color() {} Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {} + explicit Color(uint32 word) + : r_((word & 0x00FF0000) >> 16), + g_((word & 0x0000FF00) >> 8), + b_((word & 0x000000FF) >> 0) {} + // Returns the string serialization of this color in graphviz format, // for use as 'fillcolor' in boxes. - string FillColorString() const { return StringF("%.2X%.2X%.2X", r_, g_, b_); } + string AsHexString() const { return StringF("#%.2X%.2X%.2X", r_, g_, b_); } + // The color to use for this node; will be used as 'fillcolor' + // for its box. See Color::AsHexString. A suitable, different + // color will be chosen for the 'fontcolor' for the inside text + // label, see Color::TextColorString. // Returns the serialization in graphviz format of a suitable color to use // 'fontcolor' in the same boxes. It should black or white, whichever offers - // the better contrast from FillColorString(). + // the better contrast from AsHexString(). string TextColorString() const { // https://en.wikipedia.org/wiki/Relative_luminance const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_; const uint8 l = luminance > 128.f ? 
0 : 255; - return StringF("%.2X%.2X%.2X", l, l, l); + return StringF("#%.2X%.2X%.2X", l, l, l); } private: uint8 r_ = 0, g_ = 0, b_ = 0; }; -struct NodeProperties { - // The text to display inside the box for this node. - string label; - // The color to use for this node; will be used as 'fillcolor' - // for its box. See Color::FillColorString. A suitable, different - // color will be chosen for the 'fontcolor' for the inside text - // label, see Color::TextColorString. - Color color; - float log2_buffer_size; -}; - -// All colors in this file are from: -// https://material.io/guidelines/style/color.html +Color HashStringToColor(string s) { + // Return a unique color for a name. + // + // This function removes Tensorflow anti-collision suffixes (eg "_2"), hashes + // the string to a uint_32, then twiddles some bits to get a light and subtle + // color. This seems to be a good heuristic for keeping enough of the name to + // hash to a unique color while still revealing structure through naming + // similarities. + // + // The regular expression "_\d+" matches any underscore followed by numbers, + // which we strip out. Examples: + // + // "Conv" -> "Conv" + // "Conv_2" -> "Conv" + // "Conv_72" -> "Conv" + // "Pad_1_bias -> "Pad_bias" + // "Conv_abc" -> "Conv_abc" + + RE2::GlobalReplace(&s, R"CODE(_\d+)CODE", ""); + uint32 color_word = std::hash{}(s); + color_word |= 0x00E0E0E0; + return Color(color_word); +} -Color GetColorForArray(const Model& model, const string& array_name) { +void GetArrayColorAndShape(const Model& model, const string& array_name, + Color* color, string* shape) { + // All colors in this file are from: + // https://material.io/guidelines/style/color.html // Arrays involved in RNN back-edges have a different color for (const auto& rnn_state : model.flags.rnn_states()) { // RNN state, fed by a back-edge. Bold color. 
if (array_name == rnn_state.state_array()) { - return Color(0x0F, 0x9D, 0x58); + *color = Color(0x0F, 0x9D, 0x58); + *shape = "invhouse"; + return; } // RNN back-edge source, feeding a RNN state. // Light tone of the same color as RNN states. if (array_name == rnn_state.back_edge_source_array()) { - return Color(0xB7, 0xE1, 0xCD); + *color = Color(0xB7, 0xE1, 0xCD); + *shape = "house"; + return; } } // Constant parameter arrays have their own bold color if (model.GetArray(array_name).buffer) { - return Color(0x42, 0x85, 0xF4); + *color = Color(0x42, 0x85, 0xF4); + *shape = "cylinder"; + return; } // Remaining arrays are activations. // We use gray colors for them because they are the majority // of arrays so we want to highlight other arrays instead of them. // First, we use a bolder gray for input/output arrays: if (IsInputArray(model, array_name)) { - return Color(0x9E, 0x9E, 0x9E); + *color = Color(0x9E, 0x9E, 0x9E); + *shape = "invhouse"; + return; } if (IsOutputArray(model, array_name)) { - return Color(0x9E, 0x9E, 0x9E); + *color = Color(0x9E, 0x9E, 0x9E); + *shape = "house"; + return; } // Remaining arrays are intermediate activation arrays. // Lighter tone of the same grey as for input/output arrays: // We want these to be very discrete. - return Color(0xF5, 0xF5, 0xF5); + *color = Color(0xF5, 0xF5, 0xF5); + *shape = "box"; +} + +string GetArrayCompassPt(const Model& model, const string& array_name) { + // The "compass point" is the point on the node where edge connections are + // made. For most arrays we don't care, but input's and outputs look better + // connected at the tip of the "house" and "invhouse" shapes used. So we + // append ":n" and ":s" respectively for those. 
+ for (const auto& rnn_state : model.flags.rnn_states()) { + // RNN state is essentially an input + if (array_name == rnn_state.state_array()) { + return ":s"; + } + // RNN back-edge source is essentially an output + if (array_name == rnn_state.back_edge_source_array()) { + return ":n"; + } + } + if (IsInputArray(model, array_name)) { + return ":s"; + } + if (IsOutputArray(model, array_name)) { + return ":n"; + } + return ""; } void AppendArrayVal(string* string, Array const& array, int index) { @@ -141,239 +231,550 @@ void AppendArrayVal(string* string, Array const& array, int index) { } } -NodeProperties GetPropertiesForArray(const Model& model, - const string& array_name) { - NodeProperties node_properties; - node_properties.color = GetColorForArray(model, array_name); - node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}}); - node_properties.log2_buffer_size = 0.0f; +typedef std::map Attributes; + +string AttributesToHtml(Attributes attributes) { + string html; + for (const auto& attr : attributes) { + html += R"CODE()CODE"; + html += attr.first; + html += R"CODE(:)CODE"; + html += attr.second; + html += ""; + } + return html; +} + +string GetArrayLabel(const Model& model, const string& array_id) { + string html; - // Append array shape to the label. - auto& array = model.GetArray(array_name); - AppendF(&node_properties.label, "\\nType: %s", - ArrayDataTypeName(array.data_type)); + // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html) + html += "<"; + // Begin Table + html += R"CODE()CODE"; + html += R"CODE()CODE"; + + auto& array = model.GetArray(array_id); + if (array.buffer) { + // "cylinder" shapes require some extra head room. + html += R"CODE()CODE"; + } + + // "Primary" name of array (last non-slash delimited group of characters). 
+ html += R"CODE("; + + // Array data type and dimensions + html += R"CODE("; + + // Large buffer samples get their own line + if (buffer_size > 4) { + html += R"CODE("; } + // Other array properties + Attributes attrs; if (array.minmax) { - AppendF(&node_properties.label, "\\nMinMax: [%.7g, %.7g]", - array.minmax->min, array.minmax->max); + attrs["minmax"] = + StringF("[%.7g, %.7g]", array.minmax->min, array.minmax->max); } - if (array.quantization_params) { - AppendF(&node_properties.label, "\\nQuantization: %7g * (x - %d)", - array.quantization_params->scale, - array.quantization_params->zero_point); + attrs["quant"] = StringF("%7g\u00B7(x-%d)", // Unicode "cdot" + array.quantization_params->scale, + array.quantization_params->zero_point); } - if (array.alloc) { - AppendF(&node_properties.label, "\\nTransient Alloc: [%d, %d)", - array.alloc->start, array.alloc->end); + attrs["alloc"] = StringF("[%d, %d)", array.alloc->start, array.alloc->end); } - - return node_properties; + html += AttributesToHtml(attrs); + + // output array_id in ultra-small font so it can be searched and copied. + html += R"CODE("; + + // End Table and HTML-like label + html += R"CODE(
)CODE"; + html += R"CODE()CODE"; + AppendF(&html, R"CODE(%s)CODE", + std::vector(absl::StrSplit(array_id, '/')).back()); + html += R"CODE()CODE"; + html += "
)CODE"; + html += R"CODE()CODE"; + // Type + html += ArrayDataTypeName(array.data_type); + // Shape if (array.has_shape()) { auto& array_shape = array.shape(); - node_properties.label += "\\n["; - for (int id = 0; id < array_shape.dimensions_count(); id++) { - if (id == 0) { - AppendF(&node_properties.label, "%d", array_shape.dims(id)); - } else { - // 0x00D7 is the unicode multiplication symbol - AppendF(&node_properties.label, "\u00D7%d", array_shape.dims(id)); + html += "["; + for (int dim = 0; dim < array_shape.dimensions_count(); dim++) { + AppendF(&html, "%d", array_shape.dims(dim)); + if (dim + 1 < array_shape.dimensions_count()) { + html += kUnicodeMult; } } - node_properties.label += "]"; + html += "]"; + } - int buffer_size = 0; - if (IsNonEmpty(array.shape())) { - buffer_size = RequiredBufferSizeForShape(array.shape()); - node_properties.log2_buffer_size = - std::log2(static_cast(buffer_size)); + // Small buffer sample + int buffer_size = 0; + if (array.buffer) { + buffer_size = RequiredBufferSizeForShape(array.shape()); + } + if ((buffer_size > 0) && (buffer_size <= 4)) { + html += " = "; + if (array.shape().dimensions_count() > 0) { + html += "{"; } - - if (array.buffer) { - const auto& array = model.GetArray(array_name); - if (buffer_size <= 4) { - AppendF(&node_properties.label, " = "); - if (array.shape().dimensions_count() > 0) { - AppendF(&node_properties.label, "{"); - } - for (int i = 0; i < buffer_size; i++) { - AppendArrayVal(&node_properties.label, array, i); - if (i + 1 < buffer_size) { - AppendF(&node_properties.label, ", "); - } - } - } else { - AppendF(&node_properties.label, "\\n = "); - if (array.shape().dimensions_count() > 0) { - AppendF(&node_properties.label, "{"); - } - AppendArrayVal(&node_properties.label, array, 0); - AppendF(&node_properties.label, ", "); - AppendArrayVal(&node_properties.label, array, 1); - // 0x2026 is the unicode ellipsis symbol - AppendF(&node_properties.label, " \u2026 "); - 
AppendArrayVal(&node_properties.label, array, buffer_size - 2); - AppendF(&node_properties.label, ", "); - AppendArrayVal(&node_properties.label, array, buffer_size - 1); - } - if (array.shape().dimensions_count() > 0) { - AppendF(&node_properties.label, "}"); + for (int i = 0; i < buffer_size; i++) { + AppendArrayVal(&html, array, i); + if (i + 1 < buffer_size) { + html += ", "; } } + if (array.shape().dimensions_count() > 0) { + html += "}"; + } + } + html += R"CODE()CODE"; + html += "
= {)CODE"; + AppendArrayVal(&html, array, 0); + html += ", "; + AppendArrayVal(&html, array, 1); + html += kUnicodeEllipsis; + AppendArrayVal(&html, array, buffer_size - 2); + html += ", "; + AppendArrayVal(&html, array, buffer_size - 1); + html += "}
)CODE"; + html += R"CODE()CODE"; + AppendF(&html, R"CODE("%s")CODE", array_id); + html += R"CODE()CODE"; + html += "
)CODE"; + html += ">"; + return html; } -NodeProperties GetPropertiesForOperator(const Operator& op) { - NodeProperties node_properties; - if (op.type == OperatorType::kUnsupported) { - node_properties.label = - static_cast(op).tensorflow_op; - } else { - node_properties.label = - string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow")); - } +Attributes GetOpAttributes(const Model& model, const Operator& op) { + Attributes attrs; switch (op.fused_activation_function) { case FusedActivationFunctionType::kRelu: - AppendF(&node_properties.label, "\\nReLU"); + attrs["func"] = "ReLU"; break; case FusedActivationFunctionType::kRelu6: - AppendF(&node_properties.label, "\\nReLU6"); + attrs["func"] = "ReLU6"; break; case FusedActivationFunctionType::kRelu1: - AppendF(&node_properties.label, "\\nReLU1"); + attrs["func"] = "ReLU1"; break; default: break; } - // Additional information for some of the operators. + // Output state of member vars on derived operators. switch (op.type) { case OperatorType::kConv: { const auto& conv_op = static_cast(op); - node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color - AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width, - conv_op.stride_height, - conv_op.padding.type == PaddingType::kSame ? "S" : "V"); + string stride; + AppendF(&stride, "%d", conv_op.stride_width); + stride += kUnicodeMult; + AppendF(&stride, "%d", conv_op.stride_height); + attrs["stride"] = stride; + attrs["padding"] = + (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid"; break; } case OperatorType::kDepthwiseConv: { - const auto& conv_op = static_cast(op); - node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color - AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width, - conv_op.stride_height, - conv_op.padding.type == PaddingType::kSame ? 
"S" : "V"); - break; - } - case OperatorType::kFullyConnected: { - node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color + const auto& depthconv_op = static_cast(op); + string stride; + AppendF(&stride, "%d", depthconv_op.stride_width); + stride += kUnicodeMult; + AppendF(&stride, "%d", depthconv_op.stride_height); + attrs["stride"] = stride; + attrs["padding"] = + (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid"; break; } case OperatorType::kFakeQuant: { const auto& fakequant_op = static_cast(op); - node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color + attrs["bits"] = StringF("%d", fakequant_op.num_bits); if (fakequant_op.minmax) { - AppendF(&node_properties.label, "\\n%dbit [%g,%g]", - fakequant_op.num_bits, fakequant_op.minmax->min, - fakequant_op.minmax->max); + attrs["range"] = StringF("[%g,%g]", fakequant_op.minmax->min, + fakequant_op.minmax->max); } else { - AppendF(&node_properties.label, "\\n%dbit [?,?]", - fakequant_op.num_bits); + attrs["range"] = "[?,?]"; } break; } default: - node_properties.color = Color(0xDB, 0x44, 0x37); break; } + int64 math_ops_count; + if (EstimateArithmeticOpsCount(model, op, &math_ops_count) && + (math_ops_count != 0)) { + attrs["math"] = FormattedNumber(math_ops_count) + "ops"; + } - return node_properties; + return attrs; } -} // namespace +Color GetOpColor(const Operator& op) { + if ((op.type == OperatorType::kDepthwiseConv) || + (op.type == OperatorType::kConv) || + (op.type == OperatorType::kFullyConnected) || + (op.type == OperatorType::kFakeQuant)) { + // Give some ops a bolder red + return Color(0xC5, 0x39, 0x29); + } else { + return Color(0xDB, 0x44, 0x37); + } +} + +string GetOpLabel(const Model& model, const Operator& op) { + // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html) + string html; + html += "<"; + + // Begin Table + html += R"CODE()CODE"; + html += + R"CODE()CODE"; + + // Input Ports + if (!op.inputs.empty()) { + html += R"CODE()CODE"; + 
} -void DumpGraphviz(const Model& model, string* output_file_contents) { - AppendF(output_file_contents, "digraph Computegraph {\n"); - // 'nslimit' is a graphviz (dot) paramater that limits the iterations during - // the layout phase. Omitting it allows infinite iterations, causing some - // complex graphs to never finish. A value of 125 produces good graphs - // while allowing complex graphs to finish. - AppendF(output_file_contents, "\t nslimit=125;\n"); - - constexpr char kNodeFormat[] = - "\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", " - "fontcolor = \"#%sDD\"];\n"; - - constexpr char kEdgeFormat[] = - "\t \"%s\" -> \"%s\" [penwidth=%f, weight=%f];\n"; - - constexpr char kRNNBackEdgeFormat[] = - "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n"; - - for (const auto& array_kv : model.GetArrayMap()) { - // Add node for array. - const string& array_name = array_kv.first; - const auto& array_properties = GetPropertiesForArray(model, array_name); - AppendF(output_file_contents, kNodeFormat, array_name, - array_properties.label, "octagon", - array_properties.color.FillColorString().c_str(), - array_properties.color.TextColorString().c_str()); + // Name + html += R"CODE("; + + // Attributes + Attributes attrs = GetOpAttributes(model, op); + html += AttributesToHtml(attrs); + + // Output Ports + if (!op.outputs.empty()) { + html += R"CODE()CODE"; } + + // End Table and HTML-like label + html += R"CODE(
)CODE"; + // Distribute evenly using a sub-table + html += R"CODE()CODE"; + html += R"CODE()CODE"; + for (int i = 0; i < op.inputs.size(); i++) { + html += R"CODE("; + } + html += ""; + html += R"CODE(
)CODE"; + if (op.inputs.size() > 1) { + // Only number inputs when op has two or more inputs + AppendF(&html, "%d", i); + } + html += "
)CODE"; + html += R"CODE()CODE"; + if (op.type == OperatorType::kUnsupported) { + html += static_cast(op).tensorflow_op; + } else { + html += string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow")); + } + html += R"CODE()CODE"; + html += "
)CODE"; + // Distribute evenly using a sub-table + html += R"CODE()CODE"; + html += R"CODE()CODE"; + for (int i = 0; i < op.outputs.size(); i++) { + html += R"CODE("; + } + html += ""; + html += R"CODE(
)CODE"; + if (op.outputs.size() > 1) { + // Only number outputs when op has two or more outputs + AppendF(&html, "%d", i); + } + html += "
)CODE"; + html += ">"; + + return html; +} + +float GetLog2BufferSize(const Model& model, const string& array_id) { + auto& array = model.GetArray(array_id); + if (array.has_shape()) { + int buffer_size = 0; + if (IsNonEmpty(array.shape())) { + buffer_size = RequiredBufferSizeForShape(array.shape()); + return std::log2(static_cast(buffer_size)); + } + } + return 0.0f; +} + +string GetOpId(int op_index) { return StringF("op%05d", op_index); } + +void DumpOperator(const Model& model, string* output_file, int op_index) { + // Dump node for operator. + const Operator& op = *model.operators[op_index]; + Color color = GetOpColor(op); + string label = GetOpLabel(model, op); + string op_id = GetOpId(op_index); + AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(), + color.TextColorString()); +} + +void DumpOperatorEdges(const Model& model, string* output_file, int op_index) { + // Inputs + const Operator& op = *model.operators[op_index]; + string op_id = GetOpId(op_index); + for (int i = 0; i < op.inputs.size(); i++) { + const auto& input = op.inputs[i]; + if (!model.HasArray(input)) { + // Connected arrays should _always_ exist. Except, perhaps, during + // development. + continue; + } + float log2_buffer_size = GetLog2BufferSize(model, input); + // Draw lines that transport more data thicker (Otherwise, where would the + // data fit? right?). + float line_width = std::max(0.5f, log2_buffer_size / 3.0f); + // Keep edges that transport more data shorter than those with less. + float weight = std::max(1.0f, log2_buffer_size); + if (!IsInputArray(model, input) && + GetOpWithOutput(model, input) == nullptr) { + // Give the main line of data flow a straighter path by penalizing edges + // to standalone buffers. Weights are generally very large buffers that + // would otherwise skew the layout. 
+ weight = 1.0f; + } + string compass_pt = GetArrayCompassPt(model, input); + AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width, + weight); + } + // Outputs + for (int i = 0; i < op.outputs.size(); i++) { + const auto& output = op.outputs[i]; + if (!model.HasArray(output)) { + continue; + } + float log2_buffer_size = GetLog2BufferSize(model, output); + // See comments above regarding weight and line_width calculations. + float line_width = std::max(0.5f, log2_buffer_size / 3.0f); + float weight = std::max(1.0f, log2_buffer_size); + if (!IsArrayConsumed(model, output)) { + weight = 1.0f; + } + string compass_pt = GetArrayCompassPt(model, output); + AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt, + line_width, weight); + } +} + +struct Node { + Node() : math_ops(0) {} + // Name used as a key in the model's array map + string array_id; + + // Estimated number of math ops incurred by this node (the sum of the op + // with this array as 1st output, plus all children nodes). + int64 math_ops; + + // A map of child nodes keyed by name. + std::map> children; +}; + +string GetSubgraphLabel(Node const& node, const string& subgraph) { + // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html) + string html; + html += "<"; + + // Begin Table + html += R"CODE()CODE"; + html += + R"CODE()CODE"; + + // Name + html += R"CODE("; + + // Attributes + Attributes attrs; + if (node.math_ops > 0) { + attrs["math"] = FormattedNumber(node.math_ops) + "ops"; + } + html += AttributesToHtml(attrs); + + // End Table and HTML-like label + html += R"CODE(
)CODE"; + html += R"CODE()CODE"; + html += subgraph; + html += R"CODE()CODE"; + html += "
)CODE"; + html += ">"; + + return html; +} + +void DumpSubgraphHeader(string* output_file, Node const& node, + const string& node_name) { + Color color = HashStringToColor(node_name); + string label = GetSubgraphLabel(node, node_name); + AppendF(output_file, kSubgraphFmt, node_name, color.AsHexString(), label); +} + +void DumpArray(const Model& model, string* output_file, + const string& array_id) { + Color color; + string shape; + GetArrayColorAndShape(model, array_id, &color, &shape); + string label = GetArrayLabel(model, array_id); + AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape, + color.AsHexString(), color.TextColorString()); + + // Ops are placed in the same subgraph as their first output. for (int op_index = 0; op_index < model.operators.size(); op_index++) { const Operator& op = *model.operators[op_index]; - // Add node for operator. - auto op_properties = GetPropertiesForOperator(op); - string operator_id = StringF("op%05d", op_index); - AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label, - "box", op_properties.color.FillColorString().c_str(), - op_properties.color.TextColorString().c_str()); - // Add edges for all inputs of the operator. - for (const auto& input : op.inputs) { - if (!model.HasArray(input)) { - // Arrays should _always_ exist. Except, perhaps, during development. - continue; - } - auto array_properties = GetPropertiesForArray(model, input); - // Draw lines that transport more data thicker (Otherwise, where would the - // data fit? right?). - float line_width = - std::max(0.5f, array_properties.log2_buffer_size / 3.0f); - // Keep edges that transport more data shorter than those with less. - float weight = std::max(1.0f, array_properties.log2_buffer_size); - if (!IsInputArray(model, input) && - GetOpWithOutput(model, input) == nullptr) { - // Give the main line of data flow a straighter path by penalizing edges - // to standalone buffers. 
Weights are generally very large buffers that - // otherwise skew the layout without this. - weight = 1.0f; - } - AppendF(output_file_contents, kEdgeFormat, input, operator_id, line_width, - weight); + if (!op.outputs.empty() && (op.outputs[0] == array_id)) { + DumpOperator(model, output_file, op_index); } - // Add edges for all outputs of the operator. - for (const auto& output : op.outputs) { - if (!model.HasArray(output)) { - // Arrays should _always_ exist. Except, perhaps, during development. - continue; - } - auto array_properties = GetPropertiesForArray(model, output); - // See comments above regarding weight and line_width calculations. - float line_width = - std::max(0.5f, array_properties.log2_buffer_size / 3.0f); - float weight = std::max(1.0f, array_properties.log2_buffer_size); - if (!IsArrayConsumed(model, output)) { - weight = 1.0f; + } +} + +void DumpNode(const Model& model, string* output_file, const string& node_name, + Node const& node) { + bool not_root = !node_name.empty(); + if (not_root) { + DumpSubgraphHeader(output_file, node, node_name); + } + + for (const auto& child : node.children) { + if (!child.second->array_id.empty()) { + // Dump array if this node posesses one. + DumpArray(model, output_file, child.second->array_id); + } + // Note that it is always possible to have children. Unlike a filesystem, + // the existence of array "foo/bar" does _not_ prevent other arrays, such as + // and "foo/bar/baz", from being nested beneath it. 
+ DumpNode(model, output_file, child.first, *child.second); + } + + if (not_root) { + // End subgraph + AppendF(output_file, " }\n"); + } +} + +int64 GetArithmeticOpsCount(const Model& model, const string& array_id) { + for (const auto& op : model.operators) { + if (!op->outputs.empty() && op->outputs[0] == array_id) { + int64 count; + if (EstimateArithmeticOpsCount(model, *op, &count)) { + return count; + } else { + return 0; } - AppendF(output_file_contents, kEdgeFormat, operator_id, output, - line_width, weight); } } + return 0; +} - for (const auto& rnn_state : model.flags.rnn_states()) { - AppendF(output_file_contents, kRNNBackEdgeFormat, - rnn_state.back_edge_source_array(), rnn_state.state_array()); +void InsertNode(const Model& model, const string& array_id, Node* node, + std::vector prefixes, int64* math_ops) { + if (prefixes.empty()) { + // Base case: store array in this node. + node->array_id = array_id; + *math_ops = GetArithmeticOpsCount(model, array_id); + } else { + // Insert into the sub-tree for that prefix. + string prefix = prefixes.back(); + prefixes.pop_back(); + if (node->children.count(prefix) == 0) { + // Create a new node if this prefix is unseen. + node->children[prefix] = absl::make_unique(); + } + InsertNode(model, array_id, node->children[prefix].get(), prefixes, + math_ops); } + // Sum estimated math ops into all nodes. + node->math_ops += *math_ops; +} - AppendF(output_file_contents, "}\n"); +void BuildArrayTree(const Model& model, Node* tree) { + // Delimit array names by path "/", then place into a tree based on this path. + for (const auto& array_id : model.GetArrayMap()) { + std::vector prefixes = absl::StrSplit(array_id.first, '/'); + std::reverse(prefixes.begin(), prefixes.end()); + int64 math_ops; // Temporary storage for math ops used during recursion. 
+ InsertNode(model, array_id.first, tree, prefixes, &math_ops); + } +} + +string GetGraphLabel(const Model& model, const string& graph_name) { + // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html) + string html; + html += "<"; + + // Begin Table + html += R"CODE()CODE"; + html += + R"CODE()CODE"; + + // Name + html += R"CODE("; + + // Attributes + Attributes attrs; + attrs["arrays"] = StringF("%d", model.GetArrayMap().size()); + if (!model.optional_arrays.empty()) { + attrs["optional arrays"] = StringF("%d", model.optional_arrays.size()); + } + attrs["operators"] = StringF("%d", model.operators.size()); + int64 ops_count; + if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) { + attrs["math"] = FormattedNumber(ops_count) + "ops"; + } + if (model.transient_data_size > 0) { + attrs["transient data size"] = + StringF("%d KiB", model.transient_data_size / 1024); + } + if (model.transient_data_alignment > 0) { + attrs["transient data alignment"] = + StringF("%d bytes", model.transient_data_alignment); + } + html += AttributesToHtml(attrs); + + // End Table and HTML-like label + html += R"CODE(
)CODE"; + html += R"CODE()CODE"; + html += graph_name; + html += R"CODE()CODE"; + html += "
)CODE"; + html += ">"; + + return html; +} +} // namespace + +void DumpGraphviz(const Model& model, string* output_file, + const string& graph_name) { + // Start graphviz format + AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name)); + + // Organize arrays into a tree for subgraphing + Node tree; + BuildArrayTree(model, &tree); + DumpNode(model, output_file, "", tree); + + // Dump edges outside all subgraphs (otherwise the referred-to nodes are + // implicitly included in that subgraph). + for (int op_index = 0; op_index < model.operators.size(); op_index++) { + DumpOperatorEdges(model, output_file, op_index); + } + + // Dump RNN Backedges + for (const auto& rnn_state : model.flags.rnn_states()) { + AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(), + rnn_state.state_array()); + } + // End graphviz format + AppendF(output_file, "}\n"); } } // namespace toco diff --git a/tensorflow/lite/toco/dump_graphviz.h b/tensorflow/lite/toco/dump_graphviz.h index 9697bd6f0dc434aaf98762698c64fb60cb97f2ee..9bb74dac3f8fb34fb2a440e499c4ed0066ffea4d 100644 --- a/tensorflow/lite/toco/dump_graphviz.h +++ b/tensorflow/lite/toco/dump_graphviz.h @@ -21,7 +21,8 @@ limitations under the License. namespace toco { -void DumpGraphviz(const Model& model, string* output_file_contents); +void DumpGraphviz(const Model& model, string* output_file_contents, + const string& graph_name); } // namespace toco diff --git a/tensorflow/lite/toco/export_tensorflow.cc b/tensorflow/lite/toco/export_tensorflow.cc index 9fff0015527ebadf501f571bdd5ed0a7643d66e0..d426a690678fabf2ca344d2fc80cef88b08f196a 100644 --- a/tensorflow/lite/toco/export_tensorflow.cc +++ b/tensorflow/lite/toco/export_tensorflow.cc @@ -22,11 +22,6 @@ limitations under the License. 
#include "google/protobuf/text_format.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" -#include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/model_flags.pb.h" -#include "tensorflow/lite/toco/runtime/types.h" -#include "tensorflow/lite/toco/tensorflow_util.h" -#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -34,6 +29,11 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tensorflow_util.h" +#include "tensorflow/lite/toco/tooling_util.h" using tensorflow::DT_BOOL; using tensorflow::DT_COMPLEX64; @@ -1205,6 +1205,16 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT); } +void ConvertCeilOperator(const Model& model, const CeilOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* ceil_op = tensorflow_graph->add_node(); + ceil_op->set_op("Ceil"); + ceil_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *ceil_op->add_input() = src_op.inputs[0]; + (*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); @@ -1295,7 +1305,8 @@ void ConvertTensorFlowShapeOperator(const Model& model, GetTensorFlowDataType(model, src_op.outputs[0])); } -void ConvertRankOperator(const Model& model, const RankOperator& src_op, +void ConvertRankOperator(const Model& model, + const TensorFlowRankOperator& src_op, GraphDef* 
tensorflow_graph) { tensorflow::NodeDef* rank_op = tensorflow_graph->add_node(); rank_op->set_op("Rank"); @@ -2052,6 +2063,33 @@ void ConvertZerosLikeOperator(const Model& model, (*zeros_like_op->mutable_attr())["T"].set_type(data_type); } +void ConvertReverseV2Operator(const Model& model, + const ReverseV2Operator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node(); + reverse_v2_op->set_op(op_name); + reverse_v2_op->set_name(src_op.outputs[0]); + DCHECK_EQ(src_op.inputs.size(), 2); + *reverse_v2_op->add_input() = src_op.inputs[0]; + *reverse_v2_op->add_input() = src_op.inputs[1]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*reverse_v2_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertReverseSequenceOperator(const Model& model, + const ReverseSequenceOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* reverse_seq_op = tensorflow_graph->add_node(); + reverse_seq_op->set_op("ReverseSequence"); + reverse_seq_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *reverse_seq_op->add_input() = src_op.inputs[0]; + *reverse_seq_op->add_input() = src_op.inputs[1]; + (*reverse_seq_op->mutable_attr())["seq_dim"].set_i(src_op.seq_dim); + (*reverse_seq_op->mutable_attr())["batch_dim"].set_i(src_op.batch_dim); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -2169,6 +2207,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kFloor) { ConvertFloorOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kCeil) { + ConvertCeilOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kGather) { ConvertGatherOperator(model, 
static_cast(src_op), tensorflow_graph); @@ -2247,7 +2288,8 @@ void ConvertOperator(const Model& model, const Operator& src_op, model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kRank) { - ConvertRankOperator(model, static_cast(src_op), + ConvertRankOperator(model, + static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kRange) { ConvertRangeOperator(model, static_cast(src_op), @@ -2328,6 +2370,14 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertZerosLikeOperator( model, static_cast(src_op), "ZerosLike", tensorflow_graph); + } else if (src_op.type == OperatorType::kReverseV2) { + ConvertReverseV2Operator(model, + static_cast(src_op), + "Reverse_V2", tensorflow_graph); + } else if (src_op.type == OperatorType::kReverseSequence) { + ConvertReverseSequenceOperator( + model, static_cast(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc index a0260e24013bfda8718e0dc04052abb49b65debf..e4eb7698597f588947bc19f5ab449c9d3ff14adc 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc @@ -128,7 +128,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) { } bool GraphTransformationsPass(int increment, Model* model, - const GraphTransformationsSet& transformations) { + const GraphTransformationsSet& transformations, + tensorflow::Status* status) { CHECK(increment == 1 || increment == -1); bool changed = false; if (model->operators.empty()) { @@ -142,7 +143,10 @@ bool GraphTransformationsPass(int increment, Model* model, for (const auto& transformation : transformations) { CHECK(!changed_now); CHECK(transformation->Messages().empty()); - 
CHECK(transformation->Run(model, op_index, &changed_now).ok()); + *status = transformation->Run(model, op_index, &changed_now); + if (!status->ok()) { + return false; + } const char* made_a_change_msg = changed_now ? "made a change" : "did NOT make a change"; const int log_level = @@ -186,18 +190,21 @@ bool GraphTransformationsPass(int increment, Model* model, } // namespace -void RunGraphTransformations(Model* model, const string& msg, - const GraphTransformationsSet& transformations) { +tensorflow::Status RunGraphTransformationsWithStatus( + Model* model, const string& msg, + const GraphTransformationsSet& transformations) { PrintModelStats(toco::port::StringF("Before %s", msg), *model); int pass_index = 0; + tensorflow::Status status; while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model, - transformations)) { + transformations, &status)) { pass_index++; const auto& label = toco::port::StringF("After %s pass %d", msg, pass_index); PrintModelStats(label, *model); CheckInvariants(*model); } + return status; } } // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h index 187b584b6989cc55894160fc5508c13474a1d2d3..d92733ba3b5490b0b77e88e3beb1bbe9d4508a3a 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h @@ -102,8 +102,16 @@ class GraphTransformationsSet { // construct GraphTransformation objects by using 'new', pass us // the resulting raw pointers, and this RunGraphTransformations // takes care of delete'ing these pointers. 
-void RunGraphTransformations(Model* model, const string& message, - const GraphTransformationsSet& transformations); +tensorflow::Status RunGraphTransformationsWithStatus( + Model* model, const string& msg, + const GraphTransformationsSet& transformations); + +inline void RunGraphTransformations( + Model* model, const string& msg, + const GraphTransformationsSet& transformations) { + auto s = RunGraphTransformationsWithStatus(model, msg, transformations); + CHECK(s.ok()) << s.error_message(); +} #define DECLARE_GRAPH_TRANSFORMATION(GTName) \ class GTName : public GraphTransformation { \ @@ -127,6 +135,10 @@ DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) DECLARE_GRAPH_TRANSFORMATION(FuseBroadcastIntoFollowingBinary) +DECLARE_GRAPH_TRANSFORMATION(GroupBidirectionalSequenceLstm) +DECLARE_GRAPH_TRANSFORMATION(GroupBidirectionalSequenceRnn) +DECLARE_GRAPH_TRANSFORMATION(GroupDynamicBidirectionalSequenceLstm) +DECLARE_GRAPH_TRANSFORMATION(GroupDynamicBidirectionalSequenceRnn) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) diff --git a/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..03dbf3cb1db6949c865dced5d0b4b8a34f86c406 --- /dev/null +++ b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc @@ -0,0 +1,654 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" + +namespace toco { +namespace { + +std::vector>::iterator FindOperator( + Model* model, const Operator& op) { + return std::find_if( + model->operators.begin(), model->operators.end(), + [&op](const std::unique_ptr& ptr) { return ptr.get() == &op; }); +} + +bool MatchTwoUnpackOps(const Operator& op, const Model& model, + Operator** fw_output, Operator** bw_output) { + if (op.inputs.size() != 2) { + return false; + } + + *fw_output = GetOpWithOutput(model, op.inputs[0]); + *bw_output = GetOpWithOutput(model, op.inputs[1]); + if (*fw_output == nullptr || *bw_output == nullptr) { + return false; + } + + if ((*fw_output)->type != OperatorType::kUnpack || + (*bw_output)->type != OperatorType::kUnpack) { + return false; + } + + // TODO(renjieliu): Check the shapes are matching. + + return true; +} + +bool MatchDynamicBidirectionalSequenceOutputs(Operator* op, const Model& model, + Operator** fw_output, + Operator** bw_output) { + if (op->inputs.size() != 2) { + return false; + } + + // The concat op is already the fw_rnn_output. 
+ *fw_output = op; + auto* reverse_output = GetOpWithOutput(model, op->inputs[1]); + if (*fw_output == nullptr || reverse_output == nullptr) { + return false; + } + + if (reverse_output->type != OperatorType::kReverseV2 && + reverse_output->type != OperatorType::kReverseSequence) { + return false; + } + + *bw_output = reverse_output; + + return true; +} + +bool FindUnidirectionalSequenceOp(const Model& model, const Operator& output_op, + OperatorType operator_type, + std::stack* sequence_ops, + Operator** input_op) { + Operator* op_it = nullptr; + op_it = GetOpWithOutput(model, output_op.inputs[0]); + if (op_it == nullptr) { + return false; + } + + while (op_it->type == operator_type) { + sequence_ops->push(op_it); + // Check the first input of the unidirectional sequence op. + op_it = GetOpWithOutput(model, op_it->inputs[0]); + if (op_it == nullptr) { + return false; + } + } + + *input_op = op_it; + return true; +} + +bool CheckTwoUnidirectionalSequenceOpsAreValid( + const Model& model, std::stack fw_unidirectional_sequence_ops, + std::stack bw_unidirectional_sequence_ops, + const Operator* first_fw_sequence_op_input, + const Operator* first_bw_sequence_op_input, bool is_dynamic_rnn) { + if (fw_unidirectional_sequence_ops.size() != + bw_unidirectional_sequence_ops.size() || + fw_unidirectional_sequence_ops.empty()) { + return false; + } + + // Fw & bw sequence ops are allowed to have different input shapes, but they + // need to have the same data type. + while (!fw_unidirectional_sequence_ops.empty()) { + Operator* fw_sequence_op = fw_unidirectional_sequence_ops.top(); + Operator* bw_sequence_op = bw_unidirectional_sequence_ops.top(); + + if (fw_sequence_op->inputs.size() != bw_sequence_op->inputs.size() || + fw_sequence_op->outputs.size() != bw_sequence_op->outputs.size()) + return false; + + // Make sure the inputs datatype matches. 
+ for (int i = 0; i < fw_sequence_op->inputs.size(); ++i) { + const auto& fw_input_array_name = fw_sequence_op->inputs[i]; + const auto& bw_input_array_name = bw_sequence_op->inputs[i]; + if (model.HasArray(fw_input_array_name) && + model.HasArray(bw_input_array_name)) { + if (model.GetArray(fw_input_array_name).data_type != + model.GetArray(bw_input_array_name).data_type) + return false; + } + } + + // Make sure the outputs datatype matches. + for (int i = 0; i < fw_sequence_op->outputs.size(); ++i) { + const auto& fw_output_array_name = fw_sequence_op->outputs[i]; + const auto& bw_output_array_name = bw_sequence_op->outputs[i]; + if (model.HasArray(fw_output_array_name) && + model.HasArray(bw_output_array_name)) { + if (model.GetArray(fw_output_array_name).data_type != + model.GetArray(bw_output_array_name).data_type) + return false; + } + } + + fw_unidirectional_sequence_ops.pop(); + bw_unidirectional_sequence_ops.pop(); + } + + if (is_dynamic_rnn) { + // For dynamic bidirectional sequence ops, bw_sequence will have a reverse + // op. + if (first_bw_sequence_op_input->type != OperatorType::kReverseV2 && + first_bw_sequence_op_input->type != OperatorType::kReverseSequence) { + return false; + } + + const auto* bw_real_input_op = + GetOpWithOutput(model, first_bw_sequence_op_input->inputs[0]); + if (first_fw_sequence_op_input != bw_real_input_op) { + return false; + } + + } else { + // For static bidirectional sequence ops, we should have two pack ops. + if (first_fw_sequence_op_input->type != OperatorType::kPack || + first_bw_sequence_op_input->type != OperatorType::kPack) { + return false; + } + + // fw_lstm & bw_lstm should point to the same input, but reversed sequence. 
+ for (size_t i = 0; i < first_fw_sequence_op_input->inputs.size(); ++i) { + if (first_fw_sequence_op_input->inputs[i] != + first_bw_sequence_op_input + ->inputs[first_fw_sequence_op_input->inputs.size() - i - 1]) { + return false; + } + } + } + + return true; +} + +void ConstructBidirectionalSequenceOp( + const Operator& fw_lstm_op, const Operator& bw_lstm_op, Model* model, + BidirectionalSequenceLstmOperator** bi_op) { + // TODO(renjieliu): Check the shapes & configurations are equal. + constexpr int kBidirectionalSequenceLstmInputsCount = 47; + constexpr int kFwLstmInputsStartIndex = 1; + constexpr int kBwLstmInputsStartIndex = 18; + constexpr int kFwInputActivationStartIndex = 35; + constexpr int kBwInputActivationStartIndex = 37; + constexpr int kAuxInputStartIndex = 39; + (*bi_op)->inputs.reserve(kBidirectionalSequenceLstmInputsCount); + const string& input_array_name = + AvailableArrayName(*model, "bidirectional_sequence_lstm_input_0"); + model->GetOrCreateArray(input_array_name); + // The input will be changed later. + (*bi_op)->inputs.push_back(input_array_name); + int i = 1; + // Fill in the fw_lstm weights. + for (; i < kBwLstmInputsStartIndex; ++i) { + (*bi_op)->inputs.push_back(fw_lstm_op.inputs[i]); + } + + // Fill in the bw_lstm weights. bidirectional lstm backward weights start + // from 18. + for (; i < kFwInputActivationStartIndex; ++i) { + (*bi_op)->inputs.push_back( + bw_lstm_op + .inputs[i - (kBwLstmInputsStartIndex - kFwLstmInputsStartIndex)]); + } + + // Fill in fw_lstm previous states. + for (; i < kBwInputActivationStartIndex; ++i) { + (*bi_op)->inputs.push_back( + fw_lstm_op.inputs[i - (kFwInputActivationStartIndex - + kBwLstmInputsStartIndex)]); + } + + // Fill in bw_lstm previous states. + for (; i < kAuxInputStartIndex; ++i) { + (*bi_op)->inputs.push_back( + bw_lstm_op.inputs[i - (kBwInputActivationStartIndex - + kBwLstmInputsStartIndex)]); + } + + // TODO(renjieliu): Deal with Auxiliary input and weights for 39 - 47. 
+ for (; i <= kBidirectionalSequenceLstmInputsCount; ++i) { + const string& temp_array_name = AvailableArrayName( + *model, "bidirectional_sequence_lstm_temp_" + std::to_string(i)); + model->CreateOptionalArray(temp_array_name); + (*bi_op)->inputs.push_back(temp_array_name); + } + + // Deal with outputs. + (*bi_op)->outputs.reserve(2); + const string& fw_output_array_name = + AvailableArrayName(*model, "bidirectional_sequence_lstm_fw_output_0"); + const string& bw_output_array_name = + AvailableArrayName(*model, "bidirectional_sequence_lstm_bw_output_0"); + model->GetOrCreateArray(fw_output_array_name); + model->GetOrCreateArray(bw_output_array_name); + (*bi_op)->outputs.push_back(fw_output_array_name); + (*bi_op)->outputs.push_back(bw_output_array_name); + (*bi_op)->merge_outputs = false; +} + +void ConstructBidirectionalSequenceOp( + const Operator& fw_rnn_op, const Operator& bw_rnn_op, Model* model, + BidirectionalSequenceRnnOperator** bi_op) { + // TODO(renjieliu): Check the shapes & configurations are equal. + constexpr int kBidirectionalSequenceRnnInputsCount = 12; + constexpr int kFwInputsStartIndex = 1; + constexpr int kBwInputsStartIndex = 5; + constexpr int kAuxInputsStartIndex = 9; + (*bi_op)->inputs.reserve(kBidirectionalSequenceRnnInputsCount); + const string& input_array_name = + AvailableArrayName(*model, "bidirectional_sequence_rnn_input_0"); + model->GetOrCreateArray(input_array_name); + // The input will be changed later. + (*bi_op)->inputs.push_back(input_array_name); + int i = 1; + + // Fill in the fw_rnn weights. + for (; i < kBwInputsStartIndex; ++i) { + (*bi_op)->inputs.push_back(fw_rnn_op.inputs[i]); + } + + // Fill in the bw_rnn weights. + for (; i < kAuxInputsStartIndex; ++i) { + (*bi_op)->inputs.push_back( + bw_rnn_op.inputs[i - (kBwInputsStartIndex - kFwInputsStartIndex)]); + } + + // TODO(renjieliu): Deal with optional weights. 
+ for (; i < kBidirectionalSequenceRnnInputsCount; ++i) { + const string& temp_array_name = AvailableArrayName( + *model, "bidirectional_sequence_rnn_temp_" + std::to_string(i)); + model->CreateOptionalArray(temp_array_name); + (*bi_op)->inputs.push_back(temp_array_name); + } + + // Deal with outputs. + (*bi_op)->outputs.reserve(2); + const string& fw_output_array_name = + AvailableArrayName(*model, "bidirectional_sequence_rnn_fw_output_0"); + const string& bw_output_array_name = + AvailableArrayName(*model, "bidirectional_sequence_rnn_bw_output_0"); + model->GetOrCreateArray(fw_output_array_name); + model->GetOrCreateArray(bw_output_array_name); + (*bi_op)->outputs.push_back(fw_output_array_name); + (*bi_op)->outputs.push_back(bw_output_array_name); + (*bi_op)->merge_outputs = false; +} + +template +void GroupFwBwSequenceOps(Model* model, std::stack fw_sequence_ops, + std::stack bw_sequence_ops, + std::vector* bidirectional_sequence_ops) { + while (!fw_sequence_ops.empty()) { + Operator* fw_sequence_op = fw_sequence_ops.top(); + Operator* bw_sequence_op = bw_sequence_ops.top(); + T* bidirectional_sequence_op = new T; + ConstructBidirectionalSequenceOp(*fw_sequence_op, *bw_sequence_op, model, + &bidirectional_sequence_op); + + bidirectional_sequence_ops->push_back(bidirectional_sequence_op); + fw_sequence_ops.pop(); + bw_sequence_ops.pop(); + } +} + +template +void RewireBidirectionalSequenceSequenceOpsConnections( + OperatorType operator_type, const string& input_array_name, + const std::vector& bidirectional_sequence_ops, + std::vector>::iterator* op_it, Model* model) { + int aux_input_index = -1; + switch (operator_type) { + case OperatorType::kBidirectionalSequenceLstm: + aux_input_index = 39; + break; + case OperatorType::kBidirectionalSequenceRnn: + aux_input_index = 9; + break; + default: + // Should not reach here. 
+ DCHECK(false); + } + string cur_fw_input = input_array_name; + string cur_bw_input = input_array_name; + for (size_t i = 0; i < bidirectional_sequence_ops.size(); ++i) { + DeleteArrayIfUsedOnce(bidirectional_sequence_ops[i]->inputs[0], model); + bidirectional_sequence_ops[i]->inputs[0] = cur_fw_input; + if (i != 0) { + DeleteArrayIfUsedOnce( + bidirectional_sequence_ops[i]->inputs[aux_input_index], model); + bidirectional_sequence_ops[i]->inputs[aux_input_index] = cur_bw_input; + } + cur_fw_input = bidirectional_sequence_ops[i]->outputs[0]; + cur_bw_input = bidirectional_sequence_ops[i]->outputs[1]; + if (i != (bidirectional_sequence_ops.size() - 1)) { + bidirectional_sequence_ops[i]->merge_outputs = false; + } else { + // TODO(renjieliu): We need to check whether the outputs of the last bidi + // lstms needs merged outputs or not. + bidirectional_sequence_ops[i]->merge_outputs = true; + DeleteArrayIfUnused(bidirectional_sequence_ops[i]->outputs[1], model); + bidirectional_sequence_ops[i]->outputs.pop_back(); + } + model->operators.emplace(*op_it, bidirectional_sequence_ops[i]); + *op_it += 1; + } +} + +template +void RewireFinalUnpackOutputs(const UnpackOperator& original_unpack_operator, + UnpackOperator** final_unpack_operator, + T** final_bidi_sequence_operator, Model* model) { + (*final_unpack_operator) + ->inputs.push_back((*final_bidi_sequence_operator)->outputs[0]); + (*final_unpack_operator)->axis = original_unpack_operator.axis; + (*final_unpack_operator)->num = original_unpack_operator.num; + + for (size_t i = 0; i < original_unpack_operator.outputs.size(); ++i) { + const string& output_array_name = original_unpack_operator.outputs[i]; + const string& final_unpack_output_array_name = AvailableArrayName( + *model, "bidirectional_sequence_unpack_" + std::to_string(i)); + model->GetOrCreateArray(final_unpack_output_array_name); + (*final_unpack_operator)->outputs.push_back(final_unpack_output_array_name); + Operator* unpack_following_op = 
GetOpWithInput(*model, output_array_name); + if (unpack_following_op != nullptr) { + // If there's a following op after the unpack, it must be a concat op. + DCHECK(unpack_following_op->type == OperatorType::kConcatenation); + // For every output of the concat, rewire the outputs. + for (const string& concat_output : unpack_following_op->outputs) { + (*final_unpack_operator)->outputs[i] = concat_output; + } + // Remove the concat op. + model->operators.erase(FindOperator(model, *unpack_following_op)); + } + } +} + +void RemoveUnpackOperator(const Operator& unpack_op, Model* model) { + for (const string& output_array_name : unpack_op.outputs) { + DeleteArrayIfUnused(output_array_name, model); + } + model->operators.erase(FindOperator(model, unpack_op)); +} + +void RemoveUnidirectionalSequenceOps(std::stack uni_sequence_ops, + Model* model) { + while (!uni_sequence_ops.empty()) { + Operator* uni_sequence_op = uni_sequence_ops.top(); + DeleteArrayIfUnused(uni_sequence_op->outputs[0], model); + model->operators.erase(FindOperator(model, *uni_sequence_op)); + uni_sequence_ops.pop(); + } +} + +template +::tensorflow::Status GroupDynamicSequenceOps(Model* model, std::size_t op_index, + OperatorType operator_type, + bool* modified) { + *modified = false; + + // We assume there's a concatenation right after the bidirectional sequence + // ops, it may not be the case. + auto op_it = model->operators.begin() + op_index; + Operator* final_concat_op = op_it->get(); + if (final_concat_op->type != OperatorType::kConcatenation && + final_concat_op->type != OperatorType::kConcat && + final_concat_op->type != OperatorType::kConcatV2) { + return ::tensorflow::Status::OK(); + } + + // for bw, there will be a reverse op at the end. 
+ Operator *fw_sequence_output, *bw_sequence_output; + if (!MatchDynamicBidirectionalSequenceOutputs( + final_concat_op, *model, &fw_sequence_output, &bw_sequence_output)) { + return ::tensorflow::Status::OK(); + } + + // Find all upstream unidirectional sequence ops. + std::stack fw_unidirectional_sequence_ops, + bw_unidirectional_sequence_ops; + OperatorType unidirectional_op_type; + if (operator_type == OperatorType::kBidirectionalSequenceLstm) { + unidirectional_op_type = OperatorType::kUnidirectionalSequenceLstm; + } else { + unidirectional_op_type = OperatorType::kUnidirectionalSequenceRnn; + } + Operator *first_fw_sequence_input, *first_bw_sequence_input; + if (!FindUnidirectionalSequenceOp( + *model, *fw_sequence_output, unidirectional_op_type, + &fw_unidirectional_sequence_ops, &first_fw_sequence_input) || + !FindUnidirectionalSequenceOp( + *model, *bw_sequence_output, unidirectional_op_type, + &bw_unidirectional_sequence_ops, &first_bw_sequence_input)) { + return ::tensorflow::Status::OK(); + } + + if (!CheckTwoUnidirectionalSequenceOpsAreValid( + *model, fw_unidirectional_sequence_ops, + bw_unidirectional_sequence_ops, first_fw_sequence_input, + first_bw_sequence_input, /*is_dynamic_rnn=*/true)) { + return ::tensorflow::Status::OK(); + } + + std::vector bidirectional_sequence_ops; + GroupFwBwSequenceOps(model, fw_unidirectional_sequence_ops, + bw_unidirectional_sequence_ops, + &bidirectional_sequence_ops); + + // Rewire the inputs & outputs. + string current_input = first_fw_sequence_input->outputs[0]; + RewireBidirectionalSequenceSequenceOpsConnections( + operator_type, current_input, bidirectional_sequence_ops, &op_it, model); + + // Change last bidirectional sequence rnn output to the concat output. + bidirectional_sequence_ops[bidirectional_sequence_ops.size() - 1] + ->outputs[0] = final_concat_op->outputs[0]; + + // Delete unused ops. 
+ RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_ops, model); + RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_ops, model); + + DeleteArrayIfUnused(final_concat_op->inputs[0], model); + DeleteArrayIfUnused(final_concat_op->inputs[1], model); + model->operators.erase(FindOperator(model, *final_concat_op)); + + // Only keep the fw lstm's input. + DeleteArrayIfUnused(first_bw_sequence_input->outputs[0], model); + model->operators.erase(FindOperator(model, *first_bw_sequence_input)); + *modified = true; + return ::tensorflow::Status::OK(); +} + +} // namespace + +::tensorflow::Status GroupBidirectionalSequenceLstm::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; + // Bidirectional sequence lstm will generate two separate unidirectional + // sequence lstm ops, for static bidirectional sequence lstm, there will be + // a concatenation op at very end; for dynamic bidirectional sequence lstm, + // it is not guaranteed, but currently we do not support that. + auto op_it = model->operators.begin() + op_index; + Operator* final_concat_op = op_it->get(); + if (final_concat_op->type != OperatorType::kConcatenation && + final_concat_op->type != OperatorType::kConcat && + final_concat_op->type != OperatorType::kConcatV2) { + return ::tensorflow::Status::OK(); + } + + // Match fw unidirectional lstm outputs and bw unidirectional lstm outputs: + // should be two unstack ops. + Operator *fw_lstm_output, *bw_lstm_output; + if (!MatchTwoUnpackOps(*final_concat_op, *model, &fw_lstm_output, + &bw_lstm_output)) { + return ::tensorflow::Status::OK(); + } + + // Find all upstream unidirectional lstm ops. 
+ std::stack fw_unidirectional_sequence_lstm_ops, + bw_unidirectional_sequence_lstm_ops; + Operator *first_fw_lstm_input, *first_bw_lstm_input; + if (!FindUnidirectionalSequenceOp( + *model, *fw_lstm_output, OperatorType::kUnidirectionalSequenceLstm, + &fw_unidirectional_sequence_lstm_ops, &first_fw_lstm_input) || + !FindUnidirectionalSequenceOp( + *model, *bw_lstm_output, OperatorType::kUnidirectionalSequenceLstm, + &bw_unidirectional_sequence_lstm_ops, &first_bw_lstm_input)) { + return ::tensorflow::Status::OK(); + } + + if (!CheckTwoUnidirectionalSequenceOpsAreValid( + *model, fw_unidirectional_sequence_lstm_ops, + bw_unidirectional_sequence_lstm_ops, first_fw_lstm_input, + first_bw_lstm_input, /*is_dynamic_rnn=*/false)) { + return ::tensorflow::Status::OK(); + } + + std::vector + bidirectional_sequence_lstm_ops; + GroupFwBwSequenceOps(model, fw_unidirectional_sequence_lstm_ops, + bw_unidirectional_sequence_lstm_ops, + &bidirectional_sequence_lstm_ops); + + // Rewire the inputs & outputs. + string current_input = first_fw_lstm_input->outputs[0]; + RewireBidirectionalSequenceSequenceOpsConnections( + OperatorType::kBidirectionalSequenceLstm, current_input, + bidirectional_sequence_lstm_ops, &op_it, model); + + // Insert a unpack op for the output. + UnpackOperator* unpack_operator = new UnpackOperator; + + RewireFinalUnpackOutputs( + static_cast(*fw_lstm_output), &unpack_operator, + &bidirectional_sequence_lstm_ops[bidirectional_sequence_lstm_ops.size() - + 1], + model); + model->operators.emplace(op_it, unpack_operator); + + // Delete unused ops. + RemoveUnpackOperator(*fw_lstm_output, model); + RemoveUnpackOperator(*bw_lstm_output, model); + RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_lstm_ops, model); + RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_lstm_ops, model); + // Only keep the fw lstm's pack input. 
+ DeleteArrayIfUnused(first_bw_lstm_input->outputs[0], model); + model->operators.erase(FindOperator(model, *first_bw_lstm_input)); + *modified = true; + return ::tensorflow::Status::OK(); +} + +::tensorflow::Status GroupBidirectionalSequenceRnn::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; + // Bidirectional sequence rnn will generate two separate unidirectional + // sequence rnn ops, for static bidirectional sequence rnn, there will be + // a concatenation op at very end; for dynamic bidirectional sequence rnn, + // it is not guaranteed, but currently we do not support that. + auto op_it = model->operators.begin() + op_index; + Operator* final_concat_op = op_it->get(); + if (final_concat_op->type != OperatorType::kConcatenation && + final_concat_op->type != OperatorType::kConcat && + final_concat_op->type != OperatorType::kConcatV2) { + return ::tensorflow::Status::OK(); + } + + // Match fw unidirectional rnn outputs and bw unidirectional rnn outputs: + // should be two unstack ops. + Operator *fw_rnn_output, *bw_rnn_output; + if (!MatchTwoUnpackOps(*final_concat_op, *model, &fw_rnn_output, + &bw_rnn_output)) { + return ::tensorflow::Status::OK(); + } + + // Find all upstream unidirectional rnn ops. 
+ std::stack fw_unidirectional_sequence_rnn_ops, + bw_unidirectional_sequence_rnn_ops; + Operator *first_fw_rnn_input, *first_bw_rnn_input; + if (!FindUnidirectionalSequenceOp( + *model, *fw_rnn_output, OperatorType::kUnidirectionalSequenceRnn, + &fw_unidirectional_sequence_rnn_ops, &first_fw_rnn_input) || + !FindUnidirectionalSequenceOp( + *model, *bw_rnn_output, OperatorType::kUnidirectionalSequenceRnn, + &bw_unidirectional_sequence_rnn_ops, &first_bw_rnn_input)) { + return ::tensorflow::Status::OK(); + } + + if (!CheckTwoUnidirectionalSequenceOpsAreValid( + *model, fw_unidirectional_sequence_rnn_ops, + bw_unidirectional_sequence_rnn_ops, first_fw_rnn_input, + first_bw_rnn_input, /*is_dynamic_rnn=*/false)) { + return ::tensorflow::Status::OK(); + } + + std::vector bidirectional_sequence_rnn_ops; + GroupFwBwSequenceOps(model, fw_unidirectional_sequence_rnn_ops, + bw_unidirectional_sequence_rnn_ops, + &bidirectional_sequence_rnn_ops); + + // Rewire the inputs & outputs. + string current_input = first_fw_rnn_input->outputs[0]; + RewireBidirectionalSequenceSequenceOpsConnections( + OperatorType::kBidirectionalSequenceRnn, current_input, + bidirectional_sequence_rnn_ops, &op_it, model); + + // Insert a unpack op for the output. + UnpackOperator* unpack_operator = new UnpackOperator; + RewireFinalUnpackOutputs( + static_cast(*fw_rnn_output), &unpack_operator, + &bidirectional_sequence_rnn_ops[bidirectional_sequence_rnn_ops.size() - + 1], + model); + model->operators.emplace(op_it, unpack_operator); + + // Delete unused ops. + RemoveUnpackOperator(*fw_rnn_output, model); + RemoveUnpackOperator(*bw_rnn_output, model); + RemoveUnidirectionalSequenceOps(fw_unidirectional_sequence_rnn_ops, model); + RemoveUnidirectionalSequenceOps(bw_unidirectional_sequence_rnn_ops, model); + // Only keep the fw rnn's pack input. 
+ DeleteArrayIfUnused(first_bw_rnn_input->outputs[0], model); + model->operators.erase(FindOperator(model, *first_bw_rnn_input)); + *modified = true; + return ::tensorflow::Status::OK(); +} + +::tensorflow::Status GroupDynamicBidirectionalSequenceRnn::Run( + Model* model, std::size_t op_index, bool* modified) { + return GroupDynamicSequenceOps( + model, op_index, OperatorType::kBidirectionalSequenceRnn, modified); +} + +::tensorflow::Status GroupDynamicBidirectionalSequenceLstm::Run( + Model* model, std::size_t op_index, bool* modified) { + return GroupDynamicSequenceOps( + model, op_index, OperatorType::kBidirectionalSequenceLstm, modified); +} + +} // namespace toco diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc index 089ecee959a3ab80474782a88fa176b7a9f42001..65dbb8a1766a6aae4347435b392ff4af49e3d44e 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc @@ -147,12 +147,26 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, if (final_output_mul->type != OperatorType::kMul) { return ::tensorflow::Status::OK(); } + // final_output_mul->outputs[0] would be one of the two outputs of our + // LstmCell. Exit if it does not already have a data type. + // We won't be able to propagate data types through a fused LstmCell. + if (model->GetArray(final_output_mul->outputs[0]).data_type == + ArrayDataType::kNone) { + return ::tensorflow::Status::OK(); + } Operator *state_output_tanh, *fc_output_sig; if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh, &state_output_tanh, OperatorType::kLogistic, &fc_output_sig)) { return ::tensorflow::Status::OK(); } + // state_output_tanh->inputs[0] would be one of the two outputs of our + // LstmCell. Exit if it does not already have a data type. + // We won't be able to propagate data types through a fused LstmCell. 
+ if (model->GetArray(state_output_tanh->inputs[0]).data_type == + ArrayDataType::kNone) { + return ::tensorflow::Status::OK(); + } // State output TanH // (We don't count an operator as ID'd until we verify it has the correct @@ -262,11 +276,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT])); const string& concat_temp_array_name = AvailableArrayName(*model, base_name + "concat_temp"); - model->GetOrCreateArray(concat_temp_array_name); + auto& concat_temp_array = model->GetOrCreateArray(concat_temp_array_name); + concat_temp_array.data_type = + model->GetArray(concat_inputs->outputs[0]).data_type; lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name; const string& activ_temp_array_name = AvailableArrayName(*model, base_name + "activ_temp"); - model->GetOrCreateArray(activ_temp_array_name); + auto& activ_temp_array = model->GetOrCreateArray(activ_temp_array_name); + activ_temp_array.data_type = + model->GetArray(fully_connected->outputs[0]).data_type; lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name; AddMessageF("Created temp outputs %s and %s on operator %s", concat_temp_array_name, activ_temp_array_name, diff --git a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc index 7aec6728da6bc51779e0ec15377fd0c12696e94c..cb66a2372fdd3edf484902c336821b35befae48d 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -266,6 +266,26 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op, model->GetArray(op->outputs[1]).data_type = unique_op->idx_out_type; break; } + case OperatorType::kBidirectionalSequenceLstm: { + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + if (data_type != ArrayDataType::kFloat) return 
::tensorflow::Status::OK(); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kBidirectionalSequenceRnn: { + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + if (data_type != ArrayDataType::kFloat) return ::tensorflow::Status::OK(); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kLstmCell: { + // It's tricky to propagate data types through a LstmCell, as that has + // multiple inputs and outputs, and there are quantized cases with + // mixed (8bit vs 16bit) cases. Fortunately, that should never be needed, + // as the data formats, such as TFLITE, that have LstmCell nodes, also + // have data type fields for all their arrays. + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 04a5a1c1687b4caae2f31548ec549cb95e153df5..38becd6f63386f568ab2d9ff6244ecabb84f8e7f 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -110,6 +110,13 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) { case OperatorType::kSelect: case OperatorType::kTile: // Reshapes and transposes don't change values. + case OperatorType::kRelu: + case OperatorType::kRelu1: + case OperatorType::kRelu6: + // Relus only clamp the output. If min/max of parent is unknown, just + // prop the range backward. This only happens for cases where activations + // are not fused to avoid a default being set on the RELU input and + // propagating forward to the RELU output. 
return false; default: return true; diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 1b1780a73b057bc2da91af65e342aff77546a11a..ca72d0037a9e67d549ae3e337a2891c75648c021 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include "absl/strings/str_join.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { @@ -1081,6 +1081,18 @@ void ProcessUnidirectionalSequenceLstmOperator( // TODO(renjieliu): check the inputs, as well as all kinds of weights. const auto& input_array = model->GetArray(op->inputs[0]); + + constexpr int kInputActivationStateTensor = 18; + constexpr int kInputCellStateTensor = 19; + + // TFlite intepreter does not support array which is variable and contains a + // buffer (see b/115961645 for more discussion). + // The follow block remove buffer from the array to work around the + // restriction, as a consequence, downstream applications should not + // read lstm state as input to other operations. + model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset(); + model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset(); + // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -1096,12 +1108,6 @@ void ProcessUnidirectionalSequenceLstmOperator( return; } - constexpr int kInputActivationStateTensor = 18; - constexpr int kInputCellStateTensor = 19; - // b(115961645): This is a hack to work around. 
- model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset(); - model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset(); - const auto& output_weights_shape = recurrent_to_output_weights_array.shape(); const int output_size = output_weights_shape.dims(1); @@ -1122,6 +1128,14 @@ void ProcessUnidirectionalSequenceRnnOperator( return; } + constexpr int kHiddenStateTensor = 4; + // TFlite intepreter does not support array which is variable and contains a + // buffer (see b/115961645 for more discussion). + // The follow block remove buffer from the array to work around the + // restriction, as a consequence, downstream applications should not + // read lstm state as input to other operations. + model->GetArray(op->inputs[kHiddenStateTensor]).buffer.reset(); + // TODO(renjieliu): check the inputs, as well as all kinds of weights. const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. @@ -1138,10 +1152,6 @@ void ProcessUnidirectionalSequenceRnnOperator( return; } - constexpr int kHiddenStateTensor = 4; - // b(115961645): This is a hack to work around. - model->GetArray(op->inputs[kHiddenStateTensor]).buffer.reset(); - const auto& bias_shape = bias_array.shape(); const int output_size = bias_shape.dims(0); @@ -1149,6 +1159,114 @@ void ProcessUnidirectionalSequenceRnnOperator( output_shape->ReplaceDims({timestamp, batch_size, output_size}); } +void ProcessBidirectionalSequenceLstmOperator( + Model* model, BidirectionalSequenceLstmOperator* op) { + // We assume time major. + auto& fw_output_array = model->GetArray(op->outputs[0]); + auto& bw_output_array = model->GetArray(op->outputs[1]); + if (fw_output_array.has_shape()) { + // Shape already propagated + return; + } + + if (fw_output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return; + } + + // TODO(renjieliu): check the inputs, as well as all kinds of weights. 
+ const auto& input_array = model->GetArray(op->inputs[0]); + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + const int batch_size = input_shape.dims(1); + const int timestamp = input_shape.dims(0); + + constexpr int kBwRecurrentToOutputWeightsTensor = 25; + const auto& recurrent_to_output_weights_array = + model->GetArray(op->inputs[kBwRecurrentToOutputWeightsTensor]); + // Yield until input dims have been resolved. + if (!recurrent_to_output_weights_array.has_shape()) { + return; + } + + constexpr int kFwInputActivationStateTensor = 35; + constexpr int kFwInputCellStateTensor = 36; + constexpr int kBwInputActivationStateTensor = 37; + constexpr int kBwInputCellStateTensor = 38; + // b(115961645): This is a hack to work around. + model->GetArray(op->inputs[kFwInputActivationStateTensor]).buffer.reset(); + model->GetArray(op->inputs[kFwInputCellStateTensor]).buffer.reset(); + model->GetArray(op->inputs[kBwInputActivationStateTensor]).buffer.reset(); + model->GetArray(op->inputs[kBwInputCellStateTensor]).buffer.reset(); + + const auto& output_weights_shape = recurrent_to_output_weights_array.shape(); + const int output_size = output_weights_shape.dims(1); + + Shape* fw_output_shape = fw_output_array.mutable_shape(); + if (op->merge_outputs) { + fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size}); + } else { + fw_output_shape->ReplaceDims({timestamp, batch_size, output_size}); + Shape* bw_output_shape = bw_output_array.mutable_shape(); + bw_output_shape->ReplaceDims({timestamp, batch_size, output_size}); + } +} + +void ProcessBidirectionalSequenceRnnOperator( + Model* model, BidirectionalSequenceRnnOperator* op) { + // We assume time major. 
+ auto& fw_output_array = model->GetArray(op->outputs[0]); + auto& bw_output_array = model->GetArray(op->outputs[1]); + if (fw_output_array.has_shape()) { + // Shape already propagated + return; + } + + if (fw_output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return; + } + + // TODO(renjieliu): check the inputs, as well as all kinds of weights. + const auto& input_array = model->GetArray(op->inputs[0]); + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + const int batch_size = input_shape.dims(1); + const int timestamp = input_shape.dims(0); + + constexpr int kFwWeightsTensor = 1; + const auto& forward_weights_array = + model->GetArray(op->inputs[kFwWeightsTensor]); + // Yield until input dims have been resolved. + if (!forward_weights_array.has_shape()) { + return; + } + + constexpr int kFwHiddenStateTensor = 4; + constexpr int kBwHiddenStateTensor = 8; + // b(115961645): This is a hack to work around. + model->GetArray(op->inputs[kFwHiddenStateTensor]).buffer.reset(); + model->GetArray(op->inputs[kBwHiddenStateTensor]).buffer.reset(); + + const auto& output_weights_shape = forward_weights_array.shape(); + const int output_size = output_weights_shape.dims(0); + + Shape* fw_output_shape = fw_output_array.mutable_shape(); + if (op->merge_outputs) { + fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size}); + } else { + fw_output_shape->ReplaceDims({timestamp, batch_size, output_size}); + Shape* bw_output_shape = bw_output_array.mutable_shape(); + bw_output_shape->ReplaceDims({timestamp, batch_size, output_size}); + } +} + void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. 
@@ -1292,6 +1410,38 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { } } +void ProcessGatherNdOperator(Model* model, GatherNdOperator* op) { + const auto& input_array = model->GetArray(op->inputs[0]); + const auto& indices_array = model->GetArray(op->inputs[1]); + auto& output_array = model->GetArray(op->outputs[0]); + + // Bail if we already know the output shape. + if (output_array.has_shape()) { + return; + } + + // Yield until input dims have been resolved. + if (!input_array.has_shape() || !indices_array.has_shape()) { + return; + } + + const auto& input_shape = input_array.shape(); + const auto& indices_shape = indices_array.shape(); + QCHECK_GE(input_shape.dimensions_count(), 1); + QCHECK_GE(indices_shape.dimensions_count(), 1); + const int indices_nd = + indices_shape.dims(indices_shape.dimensions_count() - 1); + QCHECK_LE(indices_nd, input_shape.dimensions_count()); + + auto output_dims = output_array.mutable_shape()->mutable_dims(); + for (int dim = 0; dim < indices_shape.dimensions_count() - 1; ++dim) { + output_dims->push_back(indices_shape.dims(dim)); + } + for (int dim = indices_nd; dim < input_shape.dimensions_count(); ++dim) { + output_dims->push_back(input_shape.dims(dim)); + } +} + void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) { const auto& input_values = model->GetArray(op->inputs[0]); const auto& input_k = model->GetArray(op->inputs[1]); @@ -1377,7 +1527,7 @@ void ProcessPadV2Operator(Model* model, PadV2Operator* op) { output_array.copy_shape(output_shape); } -void ProcessRankOperator(Model* model, RankOperator* op) { +void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); @@ -1567,11 +1717,16 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) { const std::vector& input_dims = input_array.shape().dims(); std::vector output_dims; - for (int i = 0; i < input_dims.size(); 
++i) { + std::vector squeeze_dims; + const int input_num_dims = input_dims.size(); + for (int i : op->squeeze_dims) { + squeeze_dims.push_back(i < 0 ? i + input_num_dims : i); + } + for (int i = 0; i < input_num_dims; ++i) { if (input_dims[i] != 1 || - (!op->squeeze_dims.empty() && - std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) == - op->squeeze_dims.end())) { + (!squeeze_dims.empty() && + std::find(squeeze_dims.begin(), squeeze_dims.end(), i) == + squeeze_dims.end())) { output_dims.push_back(input_dims[i]); } } @@ -1656,14 +1811,37 @@ void ProcessArgMinMaxOperator(Model* model, Op* op) { return; } + const Array& axis_array = model->GetArray(op->inputs[1]); + // Yield until input axis array shape has been resolved. + if (!axis_array.has_shape()) { + return; + } + const std::vector& input_dims = input_array.shape().dims(); + + CHECK(axis_array.data_type == ArrayDataType::kInt32 || + axis_array.data_type == ArrayDataType::kInt64) + << "axis_array must be int32, int64"; + + CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1) + << "Axis array must be scalar."; + + int64 axis; + if (axis_array.data_type == ArrayDataType::kInt32) { + axis = axis_array.GetBuffer().data[0]; + } else { + axis = axis_array.GetBuffer().data[0]; + } + std::vector output_dims; - output_dims.reserve(input_dims.size()); - for (int i = 0; i < input_dims.size() - 1; ++i) { - output_dims.push_back(input_dims[i]); + output_dims.reserve(input_dims.size() - 1); + for (int i = 0; i < input_dims.size(); ++i) { + if (i != axis) { + output_dims.push_back(input_dims[i]); + } } - output_dims.push_back(1); + const string& output_name = op->outputs[0]; auto& output_array = model->GetArray(output_name); if (output_array.has_shape()) { @@ -1902,6 +2080,7 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) { case OperatorType::kBatchNormalization: case OperatorType::kL2Normalization: case OperatorType::kDequantize: + case OperatorType::kElu: case OperatorType::kRelu: case 
OperatorType::kRelu1: case OperatorType::kRelu6: @@ -1923,17 +2102,24 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) { case OperatorType::kAssert: case OperatorType::kCast: case OperatorType::kFloor: + case OperatorType::kCeil: case OperatorType::kExp: case OperatorType::kSin: + case OperatorType::kCos: case OperatorType::kLogicalAnd: case OperatorType::kLogicalNot: case OperatorType::kLogicalOr: case OperatorType::kZerosLike: + case OperatorType::kReverseV2: + case OperatorType::kReverseSequence: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: ProcessGatherOperator(model, static_cast(op)); break; + case OperatorType::kGatherNd: + ProcessGatherNdOperator(model, static_cast(op)); + break; case OperatorType::kTopK_V2: ProcessTopkV2Operator(model, static_cast(op)); break; @@ -2050,7 +2236,7 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) { ProcessRangeOperator(model, static_cast(op)); break; case OperatorType::kRank: - ProcessRankOperator(model, static_cast(op)); + ProcessRankOperator(model, static_cast(op)); break; case OperatorType::kShape: ProcessShapeOperator(model, static_cast(op)); @@ -2081,6 +2267,14 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) { ProcessUnidirectionalSequenceRnnOperator( model, static_cast(op)); break; + case OperatorType::kBidirectionalSequenceLstm: + ProcessBidirectionalSequenceLstmOperator( + model, static_cast(op)); + break; + case OperatorType::kBidirectionalSequenceRnn: + ProcessBidirectionalSequenceRnnOperator( + model, static_cast(op)); + break; case OperatorType::kLstmCell: ProcessLstmCellOperator(model, static_cast(op)); break; @@ -2164,6 +2358,11 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) { case OperatorType::kUnique: ProcessUniqueOperator(model, static_cast(op)); break; + case OperatorType::kWhere: + // The size of the output can only be known after evaluating the cond + // tensor. 
Ignore shape propagation here and defer that to the + // interpreter. + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc index ee65f92e00cd9f9347e62db314ca3a3f5e8bb396..cef1774d4131fed2ab52850a2ebe53634f34a15a 100644 --- a/tensorflow/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/lite/toco/graph_transformations/quantize.cc @@ -106,7 +106,7 @@ const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { // We always want [min, max] to contain 0. float min = 0.f; float max = 0.f; - for (auto val : data) { + for (const auto& val : data) { min = std::min(min, val); max = std::max(max, val); } @@ -121,7 +121,7 @@ const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { // weights arrays for which fake-quantization would make sense, rather // they tend to be hardcoded arrays of zeros or ones used in some graphs. bool is_quantization_trivially_exact = true; - for (auto val : data) { + for (const auto& val : data) { is_quantization_trivially_exact &= (val == min || val == max); } if (!is_quantization_trivially_exact) { @@ -489,20 +489,20 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation, } } if (!SupportsQuantization(op)) { - LOG(FATAL) << "Unimplemented: this graph contains an operator of type " - << HelpfulOperatorTypeName(op) - << " for which the quantized form is not yet implemented. " - "Sorry, and patches welcome (that's a relatively fun patch " - "to write, mostly providing the actual quantized arithmetic " - "code for this op)."; + return tensorflow::errors::InvalidArgument( + "Unimplemented: this graph contains an operator of type ", + HelpfulOperatorTypeName(op), + " for which the quantized form is not yet implemented. 
Sorry, and " + "patches welcome (that's a relatively fun patch to write, mostly " + "providing the actual quantized arithmetic code for this op)."); } for (const auto& input : op.inputs) { const auto& array = model->GetArray(input); if (array.data_type == ArrayDataType::kFloat) { if (!array.minmax && !array.buffer) { - LOG(ERROR) << "Can't quantize input array " << input - << " because it lacks min/max info"; + LOG(WARNING) << "Can't quantize input array " << input + << " because it lacks min/max info"; return ::tensorflow::Status::OK(); } const auto* other_op = GetOpWithOutput(*model, input); diff --git a/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc index 8879a7cd2664ed3f32e32435f9d45c0744dfbea2..b9405e1fa057944bd2498ce196ae8ee5e357d872 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -31,7 +31,7 @@ namespace { template bool AreAllBufferElementsEqualTo(const std::vector& buffer_data, Scalar value) { - for (auto x : buffer_data) { + for (const auto& x : buffer_data) { if (x != value) { return false; } diff --git a/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc index 6a4b9198548956217d24693bceff2bd6b3b8f0a6..98105d384e176573b248ffc3fd75710768002750 100644 --- a/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc +++ b/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -30,6 +30,7 @@ namespace { bool IsElementwiseOperator(OperatorType optype) { switch (optype) { case OperatorType::kCast: + case OperatorType::kCeil: case OperatorType::kExp: case OperatorType::kFloor: case OperatorType::kNeg: diff --git a/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc 
b/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc index fdd411c84c2678bc483b00849d5142665e706fac..77803d580e98aea94f0a7191666212cb15f58a7a 100644 --- a/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc @@ -218,6 +218,7 @@ std::vector ComputeNewPerm(std::vector input_dims, CHECK_EQ(input_dims.size(), new_perm.size()); auto& transpose_array = model->GetOrCreateArray(transpose_op->inputs[1]); + transpose_array.data_type = ArrayDataType::kInt32; transpose_array.GetMutableBuffer().data = new_perm; *(transpose_array.mutable_shape()->mutable_dims()) = { static_cast(new_perm.size())}; diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 0c9effee1fd364fa83f61339251e48070f503d1e..a46bb803eba49f8488c83993f79c3d69bd91ef8e 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -71,34 +71,29 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, // Each "digit" is incremented individually (by the stride). When it overflows // (becomes greater than the stop), that digit is reset and a carry flag is // used to increment the next digit. - int dst_offset = 0; - do { + for (size_t dst_offset = 0; dst_offset < output_data.size(); ++dst_offset) { // Copy element. output_data[dst_offset] = input_buffer.data[Offset(input_shape, src_coord)]; - // Compute next source input coordinates. - bool carry = true; - for (int axis = 0; axis < num_input_axes; axis++) { + // Note we consider elements in the highest dimension are stored + // contiguously. So, we increment the stride starting from the highest + // dimension. 
+ for (int axis = num_input_axes - 1; axis >= 0; --axis) { int stride = op.strides[axis]; - // Increment this axis if we carried from the previous one - if (carry) { - src_coord[axis] += stride; - } + src_coord[axis] += stride; - // Check if we've overflowed. + // Check if we've overflowed. If not, we just break from the loop to + // continue w/ the element copy. Otherwise, reset the starting coordinate + // for this axis and move to the next lower axis. int stop = stop_for_axis[axis]; - if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) { - // Reset axis and set carry - src_coord[axis] = tflite::strided_slice::StartForAxis( - strided_slice_params, ToRuntimeShape(input_shape), axis); - carry = true; - } else { - carry = false; + if (!tflite::strided_slice::LoopCondition(src_coord[axis], stop, + stride)) { + break; } + src_coord[axis] = tflite::strided_slice::StartForAxis( + strided_slice_params, ToRuntimeShape(input_shape), axis); } - // increment destination buffer offset - dst_offset++; - } while (dst_offset < output_data.size()); + } } } // anonymous namespace diff --git a/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc index 43070b063c4a426907e80f444e00da44417c0e18..ce0854b4721ce5878b0f91f114aa535784e81cff 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -30,7 +30,7 @@ namespace toco { namespace { // Using the function reducer, reduce input along all axes in axes. -// Put the reduced data in output, which should aleady be appropriately sized. +// Put the reduced data in output, which should already be appropriately sized. // check_output_shape is set to what this code computes the final shape // to be, so it can be cross checked with the shape computation logic. 
void ReduceGeneric(bool keep_dims, const std::vector& axes, @@ -54,7 +54,7 @@ void ReduceGeneric(bool keep_dims, const std::vector& axes, // Reduction mask will be elementwise multiplied against the input // indices to figure out the output index for the element. std::vector reduction_mask(input_shape.dimensions_count(), 1); - for (int axis : axes) { + for (const auto& axis : axes) { CHECK_GE(axis, 0); CHECK_LT(axis, input_shape.dimensions_count()); reduction_mask[axis] = 0; diff --git a/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc b/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc index c0becaf7d39cdbc01217bbb9b5a6b50017cc2eaa..2c860c30974766a093ef1bf2d9a93fb29bb65949 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc @@ -61,11 +61,11 @@ namespace toco { minmax.max = max_array.GetBuffer().data[0]; // We always want [min, max] to contain 0. if (minmax.min > 0 || minmax.max < 0) { - LOG(ERROR) << "For " << LogName(*fakequant_op) << " the MinMax range " - << "[" << minmax.min << ", " << minmax.max - << "] does not contain 0. " - << "Proceeding by tweaking it to contain 0, which will result " - "in poor accuracy."; + LOG(WARNING) << "For " << LogName(*fakequant_op) << " the MinMax range " + << "[" << minmax.min << ", " << minmax.max + << "] does not contain 0. 
" + << "Proceeding by tweaking it to contain 0, which will result " + "in poor accuracy."; } minmax.min = std::min(minmax.min, 0.); minmax.max = std::max(minmax.max, 0.); diff --git a/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc index 51c724dd1ab058e08a3a29a5408e5d584831a3d9..e65a0dc48581dc6c5768b94cb9b45bcc99cb9cd1 100644 --- a/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc @@ -49,8 +49,8 @@ void FillArrayWithZeros(Array* array) { } // namespace // Removes a multiplication by array of constant zeros by making the output -// array an array of constant zeros and removing the input arrays if they are no -// longer needed. +// array to an array of constant zeros and removing the input arrays if they +// are no longer needed. ::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, std::size_t op_index, bool* modified) { diff --git a/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc index 41a735394d714b65a4c9fc309927e34a7f610431..7492f3e116c60ca2c574bf8d2fd4b08f5914f3d0 100644 --- a/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc +++ b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -13,17 +13,192 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include +#include #include #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { +namespace { + +void UnrollBatchMatMul3D( + const string& input_lhs, const string& input_rhs, + const BatchMatMulOperator* batch_op, const std::vector batch, + Model* model, std::vector>::iterator* tail_it, + std::vector* pack_inputs) { + const std::string batch_name = + absl::StrCat(batch_op->outputs[0], "_b", absl::StrJoin(batch, "-")); + const auto& input_array_a = model->GetArray(input_lhs); + const auto& input_array_b = model->GetArray(input_rhs); + const int dims_count = input_array_a.shape().dimensions_count(); + + // tf.slice(a, ...). + std::vector begin_indices_a = batch; + begin_indices_a.resize(dims_count); + std::vector slice_size_a = input_array_a.shape().dims(); + for (int i = 0; i < batch.size(); ++i) { + slice_size_a[i] = 1; + } + auto* slice_a_op = new SliceOperator; + slice_a_op->inputs = { + input_lhs, + CreateInt32Array(model, batch_name + "/slice_a/slice/begin", + begin_indices_a), + CreateInt32Array(model, batch_name + "/slice_a/slice/size", slice_size_a), + }; + slice_a_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_a")}; + auto& slice_a_op_output = model->GetOrCreateArray(slice_a_op->outputs[0]); + slice_a_op_output.data_type = input_array_a.data_type; + *tail_it = model->operators.emplace(*tail_it, slice_a_op) + 1; + + // Reshape to remove the first dimension ([1,M,N] -> [M,N]). 
+ auto* slice_a_reshape_op = new TensorFlowReshapeOperator; + slice_a_reshape_op->inputs = { + slice_a_op->outputs[0], + CreateInt32Array(model, batch_name + "/slice_a/reshape/shape", + {-1, input_array_a.shape().dims(dims_count - 1)})}; + slice_a_reshape_op->outputs = { + AvailableArrayName(*model, batch_name + "/slice_a/reshape")}; + auto& slice_a_reshape_op_output = + model->GetOrCreateArray(slice_a_reshape_op->outputs[0]); + slice_a_reshape_op_output.data_type = input_array_a.data_type; + *tail_it = model->operators.emplace(*tail_it, slice_a_reshape_op) + 1; + + // tf.slice(b, ...). + std::vector begin_indices_b = batch; + begin_indices_b.resize(dims_count); + std::vector slice_size_b = input_array_b.shape().dims(); + for (int i = 0; i < batch.size(); ++i) { + slice_size_b[i] = 1; + } + auto* slice_b_op = new SliceOperator; + slice_b_op->inputs = { + input_rhs, + CreateInt32Array(model, batch_name + "/slice_b/slice/begin", + begin_indices_b), + CreateInt32Array(model, batch_name + "/slice_b/slice/size", slice_size_b), + }; + slice_b_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_b")}; + auto& slice_b_op_output = model->GetOrCreateArray(slice_b_op->outputs[0]); + slice_b_op_output.data_type = input_array_b.data_type; + *tail_it = model->operators.emplace(*tail_it, slice_b_op) + 1; + + // Reshape to remove the first dimension ([1,M,N] -> [M,N]). + auto* slice_b_reshape_op = new TensorFlowReshapeOperator; + slice_b_reshape_op->inputs = { + slice_b_op->outputs[0], + CreateInt32Array(model, batch_name + "/slice_b/reshape/shape", + {-1, input_array_b.shape().dims(dims_count - 1)})}; + slice_b_reshape_op->outputs = { + AvailableArrayName(*model, batch_name + "/slice_b/reshape")}; + auto& slice_b_reshape_op_output = + model->GetOrCreateArray(slice_b_reshape_op->outputs[0]); + slice_b_reshape_op_output.data_type = input_array_b.data_type; + *tail_it = model->operators.emplace(*tail_it, slice_b_reshape_op) + 1; + + // tf.matmul(slice_a, slice_b). 
+ auto* matmul_op = new TensorFlowMatMulOperator; + matmul_op->inputs = {slice_a_reshape_op->outputs[0], + slice_b_reshape_op->outputs[0]}; + matmul_op->outputs = {AvailableArrayName(*model, batch_name)}; + auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]); + matmul_op_output.data_type = input_array_a.data_type; + *tail_it = model->operators.emplace(*tail_it, matmul_op) + 1; + + // Add to stack. + pack_inputs->push_back(matmul_op->outputs[0]); +} + +std::vector UnrollBatchMatMulRecursion( + const string& input_lhs, const string& input_rhs, + const BatchMatMulOperator* batch_op, Model* model, + std::vector>::iterator* tail_it, + const std::vector& batch_prefix) { + const auto& input_array_a = model->GetArray(input_lhs); + const auto& dims_vec = input_array_a.shape().dims(); + const int current_dim_size = dims_vec[batch_prefix.size()]; + std::vector batch_pack_inputs; + + if (batch_prefix.size() + 3 == dims_vec.size()) { + // Base case + for (int batch = 0; batch < current_dim_size; ++batch) { + std::vector new_batch_prefix = batch_prefix; + new_batch_prefix.emplace_back(batch); + UnrollBatchMatMul3D(input_lhs, input_rhs, batch_op, new_batch_prefix, + model, tail_it, &batch_pack_inputs); + } + } else { + // Recursion + for (int batch = 0; batch < current_dim_size; ++batch) { + std::vector new_batch_prefix = batch_prefix; + new_batch_prefix.emplace_back(batch); + std::vector pack_inputs = UnrollBatchMatMulRecursion( + input_lhs, input_rhs, batch_op, model, tail_it, new_batch_prefix); + + // The pack that will join all the individual matmul results together. 
+ auto* pack_op = new PackOperator; + std::string batch_name = absl::StrCat( + batch_op->outputs[0], "_b", absl::StrJoin(new_batch_prefix, "-")); + pack_op->inputs = pack_inputs; + pack_op->outputs = {AvailableArrayName(*model, batch_name + "/pack")}; + auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]); + pack_op_output.data_type = input_array_a.data_type; + pack_op->axis = 0; + pack_op->values_count = pack_inputs.size(); + *tail_it = model->operators.emplace(*tail_it, pack_op) + 1; + + batch_pack_inputs.push_back(pack_op->outputs[0]); + } + } + return batch_pack_inputs; +} + +std::vector GetTransposePerm(const Array& input_array) { + const int32 dims = input_array.shape().dimensions_count(); + std::vector perm_array_val(dims); + for (int i = 0; i < dims; ++i) { + perm_array_val[i] = i; + } + perm_array_val[dims - 2] = dims - 1; + perm_array_val[dims - 1] = dims - 2; + return perm_array_val; +} + +std::vector GetTransposeShape(const Shape& input_shape, + const std::vector& perm_array_val) { + const int32 dims = input_shape.dimensions_count(); + std::vector output_shape(dims); + for (int i = 0; i < dims; ++i) { + output_shape[i] = input_shape.dims(perm_array_val[i]); + } + return output_shape; +} + +TransposeOperator* TransposeInput(const string& input, Model* model) { + const auto& input_array = model->GetArray(input); + const auto perm_array = GetTransposePerm(input_array); + const string perm_array_name = CreateInt32Array( + model, AvailableArrayName(*model, input + "/transpose/perm"), perm_array); + auto* transpose_op = new TransposeOperator; + transpose_op->inputs = {input, perm_array_name}; + transpose_op->outputs = {AvailableArrayName(*model, input + "/transpose")}; + auto& transpose_array = model->GetOrCreateArray(transpose_op->outputs[0]); + *transpose_array.mutable_shape()->mutable_dims() = + GetTransposeShape(input_array.shape(), perm_array); + model->GetOrCreateArray(transpose_op->outputs[0]); + return transpose_op; +} + +} // namespace 
+ // Unrolls a BatchMatMul on the batch dimension. // We need to slice each batch out of the inputs, matmul them individually, then // stack them all back together at the end. @@ -46,115 +221,67 @@ namespace toco { const auto* batch_op = static_cast(batch_op_it->get()); - // We must have the shape of at least one input to know our batch size. - const auto& input_array_a = model->GetArray(batch_op->inputs[0]); - const auto& input_array_b = model->GetArray(batch_op->inputs[1]); - if (!input_array_a.has_shape() || !input_array_b.has_shape()) + auto& tail_it = batch_op_it; + + string input_lhs = batch_op->inputs[0]; + string input_rhs = batch_op->inputs[1]; + const auto& input_lhs_array = model->GetArray(input_lhs); + const auto& input_rhs_array = model->GetArray(input_rhs); + if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape()) return ::tensorflow::Status::OK(); - // We only support the rank 3 case. If you are batching on rank > 3 you'll - // have to figure that out. - CHECK_EQ(input_array_a.shape().dimensions_count(), - input_array_b.shape().dimensions_count()) - << "Input dimensions must have the same rank"; - if (input_array_a.shape().dimensions_count() == 2) { + // Transpose LHS input if necessary. + if (batch_op->adj_x) { + TransposeOperator* transpose_op = TransposeInput(input_lhs, model); + tail_it = model->operators.emplace(tail_it, transpose_op) + 1; + input_lhs = transpose_op->outputs[0]; + } + const auto& input_array_a = model->GetArray(input_lhs); + + // Transpose RHS input if necessary. 
+ if (batch_op->adj_y) { + TransposeOperator* transpose_op = TransposeInput(input_rhs, model); + tail_it = model->operators.emplace(tail_it, transpose_op) + 1; + input_rhs = transpose_op->outputs[0]; + } + const auto& input_array_b = model->GetArray(input_rhs); + + const int dims = input_array_a.shape().dimensions_count(); + for (int i = 0; i < dims - 2; ++i) { + CHECK_EQ(input_array_a.shape().dims(i), input_array_b.shape().dims(i)) + << "input array not consistent at index " << i; + } + CHECK_EQ(input_array_a.shape().dims(dims - 1), + input_array_b.shape().dims(dims - 2)) + << "Input dimensions must be compatible for multipication. shape a = [" + << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = [" + << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]"; + + if (dims == 2) { // This is really just a MatMul. This likely means that someone hand-crafted // a graphdef with a BatchMatMul when they really wanted a MatMul. AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator", LogName(*batch_op)); auto* matmul_op = new TensorFlowMatMulOperator; - matmul_op->inputs = batch_op->inputs; + matmul_op->inputs = {input_lhs, input_rhs}; matmul_op->outputs = batch_op->outputs; - const auto matmul_op_it = model->operators.emplace(batch_op_it, matmul_op); - batch_op_it = matmul_op_it + 1; - CHECK_EQ(batch_op_it->get(), batch_op); - model->operators.erase(batch_op_it); + tail_it = model->operators.emplace(tail_it, matmul_op) + 1; + CHECK_EQ(tail_it->get(), batch_op); + model->operators.erase(tail_it); *modified = true; return ::tensorflow::Status::OK(); } - CHECK_EQ(input_array_a.shape().dimensions_count(), 3) - << "Input arrays must have rank 3"; - // Perform the matmul for each slice of the batch. 
- int batch_count = input_array_a.shape().dims(0); - AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op), - batch_count); - auto tail_it = batch_op_it; - std::vector pack_inputs; - for (int batch = 0; batch < batch_count; ++batch) { - std::string batch_name = - std::string(batch_op->outputs[0]) + "_b" + std::to_string(batch); - - // tf.slice(a, ...). - auto* slice_a_op = new SliceOperator; - slice_a_op->inputs = { - batch_op->inputs[0], - CreateInt32Array(model, batch_name + "/slice_a/slice/begin", - {batch, 0, 0}), - CreateInt32Array( - model, batch_name + "/slice_a/slice/size", - {1, input_array_a.shape().dims(1), input_array_a.shape().dims(2)}), - }; - slice_a_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_a")}; - auto& slice_a_op_output = model->GetOrCreateArray(slice_a_op->outputs[0]); - slice_a_op_output.data_type = input_array_a.data_type; - tail_it = model->operators.emplace(tail_it, slice_a_op) + 1; - - // Reshape to remove the first dimension ([1,M,N] -> [M,N]). - auto* slice_a_reshape_op = new TensorFlowReshapeOperator; - slice_a_reshape_op->inputs = { - slice_a_op->outputs[0], - CreateInt32Array(model, batch_name + "/slice_a/reshape/shape", - {-1, input_array_a.shape().dims(2)})}; - slice_a_reshape_op->outputs = { - AvailableArrayName(*model, batch_name + "/slice_a/reshape")}; - auto& slice_a_reshape_op_output = - model->GetOrCreateArray(slice_a_reshape_op->outputs[0]); - slice_a_reshape_op_output.data_type = input_array_a.data_type; - tail_it = model->operators.emplace(tail_it, slice_a_reshape_op) + 1; - - // tf.slice(b, ...). 
- auto* slice_b_op = new SliceOperator; - slice_b_op->inputs = { - batch_op->inputs[1], - CreateInt32Array(model, batch_name + "/slice_b/slice/begin", - {batch, 0, 0}), - CreateInt32Array( - model, batch_name + "/slice_b/slice/size", - {1, input_array_b.shape().dims(1), input_array_b.shape().dims(2)}), - }; - slice_b_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_b")}; - auto& slice_b_op_output = model->GetOrCreateArray(slice_b_op->outputs[0]); - slice_b_op_output.data_type = input_array_b.data_type; - tail_it = model->operators.emplace(tail_it, slice_b_op) + 1; - - // Reshape to remove the first dimension ([1,M,N] -> [M,N]). - auto* slice_b_reshape_op = new TensorFlowReshapeOperator; - slice_b_reshape_op->inputs = { - slice_b_op->outputs[0], - CreateInt32Array(model, batch_name + "/slice_b/reshape/shape", - {-1, input_array_b.shape().dims(2)})}; - slice_b_reshape_op->outputs = { - AvailableArrayName(*model, batch_name + "/slice_b/reshape")}; - auto& slice_b_reshape_op_output = - model->GetOrCreateArray(slice_b_reshape_op->outputs[0]); - slice_b_reshape_op_output.data_type = input_array_b.data_type; - tail_it = model->operators.emplace(tail_it, slice_b_reshape_op) + 1; - - // tf.matmul(slice_a, slice_b). - auto* matmul_op = new TensorFlowMatMulOperator; - matmul_op->inputs = {slice_a_reshape_op->outputs[0], - slice_b_reshape_op->outputs[0]}; - matmul_op->outputs = {AvailableArrayName(*model, batch_name)}; - auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]); - matmul_op_output.data_type = input_array_a.data_type; - tail_it = model->operators.emplace(tail_it, matmul_op) + 1; + CHECK_GE(input_array_a.shape().dimensions_count(), 3) + << "Input arrays must have rank >= 3"; - // Add to stack. 
- pack_inputs.push_back(matmul_op->outputs[0]); - } + const auto& dims_vec = input_array_a.shape().dims(); + AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op), + std::accumulate(dims_vec.begin(), dims_vec.end() - 2, 1, + std::multiplies())); - // The pack that will join all the individual matmul results together. + std::vector pack_inputs = UnrollBatchMatMulRecursion( + input_lhs, input_rhs, batch_op, model, &tail_it, {}); auto* pack_op = new PackOperator; pack_op->inputs = pack_inputs; pack_op->outputs = {batch_op->outputs[0]}; diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index dac106b398870f497087fc14fa4c318eb8156408..fdf72bde05754c4af07f368ffac15366be303557 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -1092,11 +1092,14 @@ tensorflow::Status ConvertBatchMatMulOperator( Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions - CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false)); - CHECK(!HasAttr(node, "adj_b") || (GetBoolAttr(node, "adj_b") == false)); - auto* batch_matmul = new BatchMatMulOperator; + // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions + if (HasAttr(node, "adj_x")) { + batch_matmul->adj_x = GetBoolAttr(node, "adj_x"); + } + if (HasAttr(node, "adj_y")) { + batch_matmul->adj_y = GetBoolAttr(node, "adj_y"); + } batch_matmul->inputs = {node.input(0), node.input(1)}; batch_matmul->outputs = {node.name()}; @@ -1346,7 +1349,7 @@ tensorflow::Status ConvertUnsupportedOperator( } // Parse outputs. Name them after the node's name, plus an ordinal suffix. - // Note that some outputs are to be multipled by a named attribute. + // Note that some outputs are to be multiplied by a named attribute. 
const tensorflow::OpDef* op_def = nullptr; if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) { GetOutputNamesFromNodeDef(node, *op_def, op); @@ -1480,7 +1483,7 @@ tensorflow::Status ConvertPlaceholderOperator( if (node.attr().count("shape")) { const auto& shape = GetShapeAttr(node, "shape"); auto num_dims = shape.dim_size(); - // TODO(b/62716978): This logic needs to be revisted. During dims + // TODO(b/62716978): This logic needs to be revisited. During dims // refactoring it is an interim fix. if (num_dims > 0 && !HasWildcardDimension(shape)) { auto& dst_array_dims = *array.mutable_shape()->mutable_dims(); @@ -1529,6 +1532,20 @@ tensorflow::Status ConvertFloorOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertCeilOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Ceil"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); + const auto data_type = GetDataTypeAttr(node, "T"); + CHECK(data_type == DT_FLOAT); + auto* op = new CeilOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertGatherOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1558,6 +1575,21 @@ tensorflow::Status ConvertGatherOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertGatherNdOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "GatherNd"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); + CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64); + auto* op = new GatherNdOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + 
model->operators.emplace_back(op); + return tensorflow::Status::OK(); +} + template tensorflow::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1993,6 +2025,27 @@ tensorflow::Status ConvertShapeOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertReverseSequenceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "ReverseSequence"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + auto op = absl::make_unique(); + if (HasAttr(node, "seq_dim")) { + op->seq_dim = GetIntAttr(node, "seq_dim"); + } + // In tf.reverse_sequence, batch_dim defaults to 0. + op->batch_dim = + HasAttr(node, "batch_dim") ? GetIntAttr(node, "batch_dim") : 0; + const int num_inputs = GetInputsCount(node, tf_import_flags); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); +} + void StripCaretFromArrayNames(Model* model) { for (auto& op : model->operators) { for (auto& input : op->inputs) { @@ -2361,7 +2414,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { return std::unordered_map({ {"Abs", ConvertSimpleOperator}, {"Add", ConvertSimpleOperator}, - {"AddN", ConvertSimpleOperatorFlexOk}, + {"AddN", ConvertSimpleOperator}, {"All", ConvertSimpleOperator}, {"Any", ConvertReduceOperator}, {"ArgMax", ConvertArgMaxOperator}, @@ -2375,18 +2428,21 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"BatchToSpaceND", ConvertBatchToSpaceNDOperator}, {"BiasAdd", ConvertBiasAddOperator}, {"Cast", ConvertCastOperator}, + {"Ceil", ConvertCeilOperator}, {"CheckNumerics", ConvertIdentityOperator}, {"Concat", ConvertConcatOperator}, {"ConcatV2", ConvertConcatOperator}, {"Const", ConvertConstOperator}, {"Conv2D", ConvertConvOperator}, {"Conv2DBackpropInput", ConvertTransposeConvOperator}, + {"Cos", 
ConvertSimpleOperator}, {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator}, {"DepthToSpace", ConvertDepthToSpaceOperator}, {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator}, {"Div", ConvertSimpleOperator}, {"DynamicPartition", ConvertDynamicPartitionOperator}, {"DynamicStitch", ConvertDynamicStitchOperator}, + {"Elu", ConvertSimpleOperator}, {"Equal", ConvertSimpleOperator}, {"Exp", ConvertSimpleOperator}, {"ExpandDims", ConvertSimpleOperator}, @@ -2399,6 +2455,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"FusedBatchNorm", ConvertFusedBatchNormOperator}, {"Gather", ConvertGatherOperator}, {"GatherV2", ConvertGatherOperator}, + {"GatherNd", ConvertGatherNdOperator}, {"Greater", ConvertSimpleOperator}, {"GreaterEqual", ConvertSimpleOperator}, @@ -2418,7 +2475,8 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator}, {"Mean", ConvertReduceOperator}, - {"Merge", ConvertSimpleOperator}, + {"Merge", + ConvertSimpleOperator}, {"Min", ConvertReduceOperator}, {"Minimum", ConvertSimpleOperator}, {"Mul", ConvertSimpleOperator}, @@ -2437,13 +2495,15 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Prod", ConvertReduceOperator}, {"RandomUniform", ConvertRandomUniform}, {"Range", ConvertRangeOperator}, - {"Rank", ConvertSimpleOperator}, + {"Rank", ConvertSimpleOperator}, {"RealDiv", ConvertSimpleOperator}, {"Relu", ConvertSimpleOperator}, {"Relu6", ConvertSimpleOperator}, {"Reshape", ConvertSimpleOperator}, {"ResizeBilinear", ConvertResizeBilinearOperator}, {"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator}, + {"ReverseSequence", ConvertReverseSequenceOperator}, + {"ReverseV2", ConvertSimpleOperator}, {"Rsqrt", ConvertSimpleOperator}, {"Select", ConvertSimpleOperator}, {"Shape", ConvertShapeOperator}, @@ -2478,6 +2538,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn}, {"MirrorPad", 
ConvertMirrorPadOperator}, {"Unique", ConvertSimpleOperator}, + {"Where", ConvertSimpleOperator}, }); } diff --git a/tensorflow/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc index 8ff3f7733afb4355a8e7863594633a6555287c10..b620ade756e457c0b25829d282ea4bc027fc2681 100644 --- a/tensorflow/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/lite/toco/import_tensorflow_test.cc @@ -115,7 +115,6 @@ void BuildConstNode(std::initializer_list shape, s->add_dim()->set_size(d); } - // TODO(ahentz): also need to test via tensor_content() switch (dtype) { case DT_FLOAT: for (int64_t i = 0; i < num_elements; ++i) { @@ -385,6 +384,127 @@ std::vector> UnaryTestTypes() { {DT_INT64, ArrayDataType::kInt64}}; } +class TensorContentTest : public ::testing::Test { + public: + template + std::vector> ImportAndGetData(const NodeDef& node) { + Model model; + auto status = ImportNode(node, &model); + CHECK(status.ok()) << status.error_message(); + const auto& nodearray = model.GetArray("Node1"); + return nodearray.GetBuffer().data; + } + template + void NodeWithTensorContent(std::initializer_list shape, + tensorflow::DataType dtype, int64_t num_elements, + NodeDef* node) { + node->set_op("Const"); + node->set_name("Node1"); + + // An attribute describing the type of this const node. + AttrValue dtype_attr; + SetAttrValue(dtype, &dtype_attr); + (*node->mutable_attr())["dtype"] = dtype_attr; + + auto allocated_content = absl::make_unique(num_elements); + + // An attribute describing the content of this const node. 
+ tensorflow::TensorProto t; + t.set_dtype(dtype); + auto* s = t.mutable_tensor_shape(); + for (const auto& d : shape) { + s->add_dim()->set_size(d); + } + + switch (dtype) { + case DT_FLOAT: + for (int64_t i = 0; i < num_elements; ++i) { + allocated_content[i] = i / 10000.0 + 1; + } + break; + case DT_INT32: + for (int64_t i = 0; i < num_elements; ++i) { + allocated_content[i] = i % std::numeric_limits::max() + 1; + } + break; + case DT_QUINT8: + for (int64_t i = 0; i < num_elements; ++i) { + allocated_content[i] = i % std::numeric_limits::max() + 1; + } + break; + case DT_INT64: + for (int64_t i = 0; i < num_elements; ++i) { + allocated_content[i] = i + 1; + } + break; + case DT_STRING: + break; + case DT_BOOL: + for (int64_t i = 0; i < num_elements; ++i) { + allocated_content[i] = ((i % 2) == 0); + } + break; + default: + break; + } + t.set_tensor_content( + string(reinterpret_cast(allocated_content.get()), + num_elements * sizeof(T))); + + AttrValue value_attr; + SetAttrValue(t, &value_attr); + (*node->mutable_attr())["value"] = value_attr; + + allocated_content.reset(); + } +}; + +TEST_F(TensorContentTest, Int64) { + constexpr ArrayDataType kType = ArrayDataType::kInt64; + + NodeDef node; + NodeWithTensorContent({1, 2, 3}, DT_INT64, 6, &node); + + EXPECT_THAT(ImportAndGetData(node), ElementsAre(1, 2, 3, 4, 5, 6)); +} + +TEST_F(TensorContentTest, Int32) { + constexpr ArrayDataType kType = ArrayDataType::kInt32; + + NodeDef node; + NodeWithTensorContent({1, 2, 3}, DT_INT32, 6, &node); + + EXPECT_THAT(ImportAndGetData(node), ElementsAre(1, 2, 3, 4, 5, 6)); +} + +TEST_F(TensorContentTest, Float) { + constexpr ArrayDataType kType = ArrayDataType::kFloat; + + NodeDef node; + NodeWithTensorContent({1, 2, 3}, DT_FLOAT, 6, &node); + + EXPECT_THAT(ImportAndGetData(node), + ElementsAre(1.0000, 1.0001, 1.0002, 1.0003, 1.0004, 1.0005)); +} + +TEST_F(TensorContentTest, Quint8) { + constexpr ArrayDataType kType = ArrayDataType::kUint8; + + NodeDef node; + 
NodeWithTensorContent({1, 2, 3}, DT_QUINT8, 6, &node); + + EXPECT_THAT(ImportAndGetData(node), ElementsAre(1, 2, 3, 4, 5, 6)); +} + +TEST_F(TensorContentTest, Bool) { + constexpr ArrayDataType kType = ArrayDataType::kBool; + + NodeDef node; + NodeWithTensorContent({1, 2, 3}, DT_BOOL, 6, &node); + + EXPECT_THAT(ImportAndGetData(node), ElementsAre(1, 0, 1, 0, 1, 0)); +} + class TypeImportTest : public ::testing::TestWithParam< std::pair> { protected: diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 296ed9fc747aa2e2d79e3d706d27ff3bdaacdd30..46f70c9e379de6fcfa5405fd9ef26fb3819432df 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -24,11 +24,11 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/toco_port.h" #include "tensorflow/lite/toco/toco_types.h" -#include "tensorflow/core/platform/logging.h" namespace toco { @@ -42,8 +42,10 @@ enum class OperatorType : uint8 { kAveragePool, kBatchMatMul, kBatchNormalization, + kCeil, kConv, kConcatenation, + kCos, kDepthwiseConv, kDepthToSpace, kSpaceToDepth, @@ -159,7 +161,14 @@ enum class OperatorType : uint8 { kAbs, kMirrorPad, kUnique, - kUnidirectionalSequenceRnn + kUnidirectionalSequenceRnn, + kBidirectionalSequenceLstm, + kReverseV2, + kBidirectionalSequenceRnn, + kGatherNd, + kWhere, + kElu, + kReverseSequence }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -649,6 +658,18 @@ struct UnidirectionalSequenceLstmOperator : Operator { : Operator(OperatorType::kUnidirectionalSequenceLstm) {} }; +struct BidirectionalSequenceLstmOperator : Operator { + BidirectionalSequenceLstmOperator() + : Operator(OperatorType::kBidirectionalSequenceLstm) {} + bool merge_outputs; +}; + +struct BidirectionalSequenceRnnOperator : Operator { + 
BidirectionalSequenceRnnOperator() + : Operator(OperatorType::kBidirectionalSequenceRnn) {} + bool merge_outputs; +}; + // Element-wise multiplication operator. // // Inputs: @@ -671,6 +692,17 @@ struct AbsOperator : Operator { AbsOperator() : Operator(OperatorType::kAbs) {} }; +// Elu +// f(x) -> exp(x) - 1 for x < 0, x for x >= 0. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Elu +struct EluOperator : Operator { + EluOperator() : Operator(OperatorType::kElu) {} +}; + // Element-wise Relu operator: // x -> max(0, x) // @@ -948,6 +980,8 @@ struct TensorFlowIdentityOperator : Operator { // TensorFlow equivalent: MatMul struct BatchMatMulOperator : Operator { BatchMatMulOperator() : Operator(OperatorType::kBatchMatMul) {} + bool adj_x = false; + bool adj_y = false; }; // General matrix multiplication operator. We don't want to support general @@ -1150,6 +1184,17 @@ struct ExpOperator : Operator { ExpOperator() : Operator(OperatorType::kExp) {} }; +// Given a tensor input, this operation calculates element-wise exponential +// (y = cos(x)). +// +// Inputs: +// inputs[0]: required: input tensor +// +// TensorFlow equivalent: Cos +struct CosOperator : Operator { + CosOperator() : Operator(OperatorType::kCos) {} +}; + // Given a tensor input, this operation inserts a dimension of 1 at the // dimension index axis of input's shape. The dimension index axis starts at // zero; if you specify a negative number for axis it is counted backward from @@ -1228,13 +1273,12 @@ struct RangeOperator : Operator { // Inputs: // inputs[0]: required: the input array // -// This operation outputs a 0-D integer tensor representing the rank of -// the input. +// This operation outputs a 0-D int32 Tensor representing the rank of input. // -// TensorFlow equivalent: Rank. We currently assume that the output is int32 -// and not int64. The output type could be stored herein. 
-struct RankOperator : Operator { - RankOperator() : Operator(OperatorType::kRank) {} +// TensorFlow equivalent: Rank. +struct TensorFlowRankOperator : Operator { + TensorFlowRankOperator() : Operator(OperatorType::kRank) {} + ArrayDataType output_data_type = ArrayDataType::kInt32; }; // Element-wise negation (-x) operator. @@ -1660,6 +1704,16 @@ struct FloorOperator : Operator { FloorOperator() : Operator(OperatorType::kFloor) {} }; +// Ceil operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Ceil +struct CeilOperator : Operator { + CeilOperator() : Operator(OperatorType::kCeil) {} +}; + // Gather operator. It gathers slices from params according to indices. // Only 1-D indices are supported at the moment. // @@ -1681,10 +1735,22 @@ struct GatherOperator : Operator { int input_rank = 0; }; +// GatherNd operator. It gathers slices from params according to indices. +// +// Inputs: +// inputs[0]: required: the params array +// inputs[1]: required: the indices to gather +// +// TensorFlow equivalent: GatherNd +struct GatherNdOperator : Operator { + GatherNdOperator() : Operator(OperatorType::kGatherNd) {} +}; + // ArgMax operator. It returns the index of the maximum value along axis. // // Inputs: // inputs[0]: required: the input tensor +// inputs[1]: optional: 0-D (scalar) axis // // TensorFlow equivalent: ArgMax struct ArgMaxOperator : Operator { @@ -1696,6 +1762,7 @@ struct ArgMaxOperator : Operator { // // Inputs: // inputs[0]: required: the input tensor +// inputs[1]: optional: 0-D (scalar) axis // // TensorFlow equivalent: ArgMin struct ArgMinOperator : Operator { @@ -1938,6 +2005,16 @@ struct TensorFlowZerosLikeOperator : Operator { TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {} }; +// ReverseV2 operator: +// +// Inputs: +// Inputs[0]: required: the input array. +// +// TensorFlow equivalent: ReverseV2. 
+struct ReverseV2Operator : Operator { + ReverseV2Operator() : Operator(OperatorType::kReverseV2) {} +}; + enum class MirrorPadMode { kNone, kSymmetric, kReflect }; // MirrorPad Operator: @@ -1955,6 +2032,19 @@ struct MirrorPadOperator : Operator { MirrorPadMode mode; }; +// ReverseSequence operator: +// +// Inputs: +// Inputs[0]: required: the input array. +// Inputs[1]: required: the lengths of the elements to be reversed. +// +// TensorFlow equivalent: tf.reverse_sequence. +struct ReverseSequenceOperator : Operator { + ReverseSequenceOperator() : Operator(OperatorType::kReverseSequence) {} + int seq_dim; + int batch_dim = 0; +}; + // Unique Operator: // // Inputs: @@ -1973,6 +2063,18 @@ struct UnidirectionalSequenceRnnOperator : Operator { FusedActivationFunctionType fused_activation_function; }; +// Where Operator: +// Return the coordinates of the true values in condition tensor in row-major +// order. +// +// Inputs: +// inputs[0]: required: boolean condition tensor +// +// TensorFlow equivalent: Where +struct WhereOperator : Operator { + WhereOperator() : Operator(OperatorType::kWhere) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. 
The 'start' and 'end' values are diff --git a/tensorflow/lite/toco/model_cmdline_flags.cc b/tensorflow/lite/toco/model_cmdline_flags.cc index 717a28bc615e0a142c41efb3afaa49f64d2a1e14..7e48bd9542b0cc0de4c0218465edfc75e97d0049 100644 --- a/tensorflow/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/lite/toco/model_cmdline_flags.cc @@ -261,7 +261,7 @@ void ReadModelFlagsFromCommandLineFlags( std::vector mean_values = absl::StrSplit(parsed_model_flags.mean_values.value(), ','); QCHECK(mean_values.size() == model_flags->input_arrays_size()); - for (int i = 0; i < mean_values.size(); ++i) { + for (size_t i = 0; i < mean_values.size(); ++i) { char* last = nullptr; model_flags->mutable_input_arrays(i)->set_mean_value( strtod(mean_values[i].data(), &last)); @@ -278,7 +278,7 @@ void ReadModelFlagsFromCommandLineFlags( std::vector std_values = absl::StrSplit(parsed_model_flags.std_values.value(), ','); QCHECK(std_values.size() == model_flags->input_arrays_size()); - for (int i = 0; i < std_values.size(); ++i) { + for (size_t i = 0; i < std_values.size(); ++i) { char* last = nullptr; model_flags->mutable_input_arrays(i)->set_std_value( strtod(std_values[i].data(), &last)); @@ -296,7 +296,7 @@ void ReadModelFlagsFromCommandLineFlags( std::vector input_data_types = absl::StrSplit(parsed_model_flags.input_data_types.value(), ','); QCHECK(input_data_types.size() == model_flags->input_arrays_size()); - for (int i = 0; i < input_data_types.size(); ++i) { + for (size_t i = 0; i < input_data_types.size(); ++i) { IODataType type; QCHECK(IODataType_Parse(input_data_types[i], &type)); model_flags->mutable_input_arrays(i)->set_data_type(type); @@ -319,7 +319,7 @@ void ReadModelFlagsFromCommandLineFlags( std::vector input_shapes = absl::StrSplit(parsed_model_flags.input_shapes.value(), ':'); QCHECK(input_shapes.size() == model_flags->input_arrays_size()); - for (int i = 0; i < input_shapes.size(); ++i) { + for (size_t i = 0; i < input_shapes.size(); ++i) { auto* shape = 
model_flags->mutable_input_arrays(i)->mutable_shape(); shape->clear_dims(); // Treat an empty input shape as a scalar. diff --git a/tensorflow/lite/toco/model_cmdline_flags_test.cc b/tensorflow/lite/toco/model_cmdline_flags_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2f8dc59852d06af001c7e084f1eeedcb040b7a8 --- /dev/null +++ b/tensorflow/lite/toco/model_cmdline_flags_test.cc @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include +#include +#include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/toco/args.h" +#include "tensorflow/lite/toco/model_cmdline_flags.h" + +namespace toco { +namespace { + +TEST(ModelCmdlineFlagsTest, ParseArgsStringMapList) { + int args_count = 3; + const char* args[] = { + "toco", + "--input_arrays=input_1", + "--rnn_states={state_array:rnn/BasicLSTMCellZeroState/zeros," + "back_edge_source_array:rnn/basic_lstm_cell/Add_1,size:4}," + "{state_array:rnn/BasicLSTMCellZeroState/zeros_1," + "back_edge_source_array:rnn/basic_lstm_cell/Mul_2,size:4}", + }; + + string expected_input_arrays = "input_1"; + std::vector> expected_rnn_states; + expected_rnn_states.push_back( + {{"state_array", "rnn/BasicLSTMCellZeroState/zeros"}, + {"back_edge_source_array", "rnn/basic_lstm_cell/Add_1"}, + {"size", "4"}}); + expected_rnn_states.push_back( + {{"state_array", "rnn/BasicLSTMCellZeroState/zeros_1"}, + {"back_edge_source_array", "rnn/basic_lstm_cell/Mul_2"}, + {"size", "4"}}); + + string message; + ParsedModelFlags result_flags; + + EXPECT_TRUE(ParseModelFlagsFromCommandLineFlags( + &args_count, const_cast(args), &message, &result_flags)); + EXPECT_EQ(result_flags.input_arrays.value(), expected_input_arrays); + EXPECT_EQ(result_flags.rnn_states.value().elements, expected_rnn_states); +} + +} // namespace +} // namespace toco + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + ::toco::port::InitGoogleWasDoneElsewhere(); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/toco/model_flags.proto b/tensorflow/lite/toco/model_flags.proto index bcdac295d261c0e7cc04c5a8c3e2e5d88736cd88..dfc425073f51333c8ab4bf9d1dc62b0e88b92ce5 100644 --- a/tensorflow/lite/toco/model_flags.proto +++ b/tensorflow/lite/toco/model_flags.proto @@ -85,6 +85,7 @@ message RnnState { // Will be expanded with 1's to fit the 
model. // TODO(benoitjacob): should allow a generic, explicit shape. optional int32 size = 3; + optional int32 num_dims = 4; } // An ArraysExtraInfo message stores a collection of additional Information diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index 8a6e82ec46445b5ec5440de129177eae836f8db8..2f5654c56e0acca57a2d644a7c50e87c185f721b 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -1,5 +1,4 @@ package(default_visibility = [ - "//tensorflow/contrib/lite:__subpackages__", "//tensorflow/lite:__subpackages__", "//tensorflow/tools/pip_package:__subpackages__", ]) @@ -26,6 +25,7 @@ cc_library( deps = [ "//third_party/python_runtime:headers", "//tensorflow/core:lib", + "//tensorflow/lite/python/interpreter_wrapper:python_utils", "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:toco_graphviz_dump_options", @@ -47,7 +47,6 @@ tf_py_wrap_cc( visibility = [ "//learning/expander/pod/deep_pod/utils:__subpackages__", "//research/handwriting/converters/tflite:__subpackages__", - "//tensorflow/contrib/lite:__subpackages__", "//tensorflow/lite:__subpackages__", ], deps = [ diff --git a/tensorflow/lite/toco/python/toco_from_protos_test.py b/tensorflow/lite/toco/python/toco_from_protos_test.py index 34cfd2c59fdc3aa3c83728f622fbf5b8d02d7e00..cc0d6f748acb94f8f5c8248bd4cdcc78129e0e17 100644 --- a/tensorflow/lite/toco/python/toco_from_protos_test.py +++ b/tensorflow/lite/toco/python/toco_from_protos_test.py @@ -54,7 +54,7 @@ class TocoFromProtosTest(googletest.TestCase): model_flags = model_flags_pb2.ModelFlags() input_array = model_flags.input_arrays.add() input_array.name = TensorName(in_tensor) - input_array.shape.dims.extend(map(int, in_tensor.get_shape())) + input_array.shape.dims.extend(map(int, in_tensor.shape)) model_flags.output_arrays.append(TensorName(out_tensor)) # Shell out to run toco (in case it crashes) with 
tempfile.NamedTemporaryFile() as fp_toco, \ diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index ce8e3c9df88ba511fcca9d9a256896624194463b..6fad092f35aa386757885f9320f47e9f372e9f47 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/python/toco_python_api.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -26,14 +27,6 @@ limitations under the License. namespace toco { -#if PY_MAJOR_VERSION >= 3 -#define TOCO_PY_TO_CPPSTRING PyBytes_AsStringAndSize -#define TOCO_FROM_CPPSTRING_TO_PY PyBytes_FromStringAndSize -#else -#define TOCO_PY_TO_CPPSTRING PyString_AsStringAndSize -#define TOCO_FROM_CPPSTRING_TO_PY PyString_FromStringAndSize -#endif - // NOTE(aselle): We are using raw PyObject's here because we want to make // sure we input and output bytes rather than unicode strings for Python3. 
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, @@ -44,7 +37,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, auto ConvertArg = [&](PyObject* obj, bool* error) { char* buf; Py_ssize_t len; - if (TOCO_PY_TO_CPPSTRING(obj, &buf, &len) == -1) { + if (::tflite::python_utils::ConvertFromPyString(obj, &buf, &len) == -1) { *error = true; return std::string(); } else { @@ -96,15 +89,15 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* dict = PyDict_New(); PyDict_SetItemString( dict, "flatbuffer", - TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(), - output_file_contents_txt.size())); + ::tflite::python_utils::ConvertToPyString( + output_file_contents_txt.data(), output_file_contents_txt.size())); PyDict_SetItemString(dict, "arithmetic_ops", PyLong_FromLong(model->ArithmeticOpsCount())); return dict; } // Convert arguments back to byte (py3) or str (py2) - return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(), - output_file_contents_txt.size()); + return ::tflite::python_utils::ConvertToPyString( + output_file_contents_txt.data(), output_file_contents_txt.size()); } } // namespace toco diff --git a/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc index fcd9ee45d984f05eabf5d51c223b45433e801308..2f9f9a8c9b0b179c62c9dafb23edc1cdc5f108a3 100644 --- a/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc +++ b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc @@ -186,7 +186,7 @@ void SvdfCluster::MaybeMergeConstNodes( allocated_tensor->mutable_tensor_shape(); auto tensor_shape_dim0 = allocated_tensor_shape->add_dim(); int allocated_content_flat_size = 0; - for (int i = 0; i < const_node_parts.size(); i++) { + for (size_t i = 0; i < const_node_parts.size(); i++) { const auto& value_attr = const_node_parts[i]->attr().at("value"); const tensorflow::TensorProto& tensor = value_attr.tensor(); if (i == 0) { @@ -214,7 
+214,7 @@ void SvdfCluster::MaybeMergeConstNodes( std::unique_ptr allocated_content( new char[allocated_content_flat_size]); char* content_ptr = allocated_content.get(); - for (int i = 0; i < const_node_parts.size(); i++) { + for (size_t i = 0; i < const_node_parts.size(); i++) { const auto& value_attr = const_node_parts[i]->attr().at("value"); const tensorflow::TensorProto& tensor = value_attr.tensor(); port::CopyToBuffer(tensor.tensor_content(), content_ptr); diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index 53f5ecef872774e83fbcb1abe394a2dbaa189f4c..8b0d38da0688a998ca99d3b50e217a89d1d8fe84 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -63,12 +63,12 @@ bool IsControlFlowOp(const string& tensorflow_op) { return false; } -// Check if a TensorFlow Op is unsupportred by the Flex runtime. +// Check if a TensorFlow Op is unsupported by the Flex runtime. bool IsUnsupportedFlexOp(const string& tensorflow_op) { if (IsControlFlowOp(tensorflow_op)) { return true; } - // `HashTableV2` isn't supported for now since it requires an additinonal + // `HashTableV2` isn't supported for now since it requires an additional // initialization step. // TODO(b/117651199): Support `HashTableV2` with Flex runtime. if (tensorflow_op == "HashTableV2") { @@ -157,7 +157,7 @@ OperatorKey::OperatorKey( string(::tflite::kFlexCustomCodePrefix) + flex_tensorflow_op_; } else { // If Flex is disabled or the original TensorFlow NodeDef isn't available, - // we produce a custom op. This gives developers a chance to implemenr + // we produce a custom op. This gives developers a chance to implement // custom ops. 
custom_code_ = name; } @@ -222,7 +222,8 @@ Offset>> ExportTensors( std::vector shape; if (array.has_shape()) { - for (int d : array.shape().dims()) { + shape.reserve(array.shape().dims().size()); + for (const auto& d : array.shape().dims()) { shape.push_back(d); } } @@ -384,7 +385,7 @@ Offset>> ExportOperators( mutating_input_variables = tflite_op->GetMutatingInputVariables(*op); if (!mutating_input_variables.empty()) { - for (int i = 0; i < op->inputs.size(); ++i) { + for (size_t i = 0; i < op->inputs.size(); ++i) { if (!mutating_input_variables[i]) { continue; } diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index 58cfb4987ff67d87d330688cd7bf75ca2eb98ebd..fb640f776abdef3e5a59d075d3bc15e8d0f9565f 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -51,9 +51,27 @@ class ExportTest : public ::testing::Test { output_array.data_type = ArrayDataType::kFloat; input_model_.operators.emplace_back(op); } else if (name == "Add") { - input_model_.operators.emplace_back(new AddOperator); + auto* op = new AddOperator; + op->inputs = {"input1", "input2"}; + op->outputs = {"output"}; + Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]); + Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]); + Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]); + input1_array.data_type = ArrayDataType::kFloat; + input2_array.data_type = ArrayDataType::kFloat; + output_array.data_type = ArrayDataType::kFloat; + input_model_.operators.emplace_back(op); } else if (name == "Sub") { - input_model_.operators.emplace_back(new SubOperator); + auto* op = new SubOperator; + op->inputs = {"input1", "input2"}; + op->outputs = {"output"}; + Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]); + Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]); + Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]); + 
input1_array.data_type = ArrayDataType::kFloat; + input2_array.data_type = ArrayDataType::kFloat; + output_array.data_type = ArrayDataType::kFloat; + input_model_.operators.emplace_back(op); } else if (name == "Assert") { auto* op = new TensorFlowAssertOperator; @@ -114,7 +132,18 @@ class ExportTest : public ::testing::Test { output_array.data_type = ArrayDataType::kFloat; input_model_.operators.emplace_back(op); } - input_model_.operators.emplace_back(new AddOperator); + { + auto* op = new AddOperator; + op->inputs = {"input1", "input2"}; + op->outputs = {"output"}; + Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]); + Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]); + Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]); + input1_array.data_type = ArrayDataType::kFloat; + input2_array.data_type = ArrayDataType::kFloat; + output_array.data_type = ArrayDataType::kFloat; + input_model_.operators.emplace_back(op); + } } std::vector ExportAndSummarizeOperators(const ExportParams& params) { @@ -445,7 +474,7 @@ TEST_F(VersionedOpExportTest, Export) { auto* model = ::tflite::GetModel(result.data()); auto operator_codes = model->operator_codes(); - // Verify that 2 operator codes are populdated. Both are CONV_2D but with + // Verify that 2 operator codes are populated. Both are CONV_2D but with // different versions. 
EXPECT_EQ(2, operator_codes->size()); EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, diff --git a/tensorflow/lite/toco/tflite/import_test.cc b/tensorflow/lite/toco/tflite/import_test.cc index 93ab5141abe81c4ed4c1ff0ac7ca5e89577c71fb..b00c4124d83ae558b4aa6f5ecc2ba9eb06e5dac0 100644 --- a/tensorflow/lite/toco/tflite/import_test.cc +++ b/tensorflow/lite/toco/tflite/import_test.cc @@ -60,7 +60,7 @@ class ImportTest : public ::testing::Test { builder_.CreateString("tensor_one"), q); auto t2 = ::tflite::CreateTensor(builder_, builder_.CreateVector({2, 1}), - ::tflite::TensorType_FLOAT32, 2, + ::tflite::TensorType_FLOAT32, 0, builder_.CreateString("tensor_two"), q); return builder_.CreateVector( std::vector>({t1, t2})); diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 088673fd95460519d655ff86861bbee24ecf2d75..2919f81571a2c23dce09476268769aa20e07a537 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/util/ptr_util.h" + // TODO(ycling): Consider refactoring to extract the LSTM definition out of // graph_transformation module. #include "tensorflow/lite/schema/schema_generated.h" @@ -108,6 +109,12 @@ class Convolution const Array& input_array = op_signature.model->GetArray(input_name); const Array& filter_array = op_signature.model->GetArray(filter_name); const Array& output_array = op_signature.model->GetArray(output_name); + // If the op has signed int8 inputs and outputs, its version 3. + if (input_array.data_type == ArrayDataType::kInt8 && + filter_array.data_type == ArrayDataType::kInt8 && + output_array.data_type == ArrayDataType::kInt8) { + return 3; + } // If the op is a signed int8 hybrid operation, we need to return // version 2. 
if (input_array.data_type == ArrayDataType::kFloat && @@ -153,6 +160,18 @@ class DepthwiseConvolution int GetVersion(const OperatorSignature& op_signature) const override { const auto& conv_op = static_cast(*op_signature.op); + const string& input_name = op_signature.op->inputs[0]; + const string& filter_name = op_signature.op->inputs[1]; + const string& output_name = op_signature.op->outputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + const Array& filter_array = op_signature.model->GetArray(filter_name); + const Array& output_array = op_signature.model->GetArray(output_name); + // If the op has signed int8 inputs and outputs, its version 3. + if (input_array.data_type == ArrayDataType::kInt8 && + filter_array.data_type == ArrayDataType::kInt8 && + output_array.data_type == ArrayDataType::kInt8) { + return 3; + } if (conv_op.dilation_width_factor != 1 || conv_op.dilation_height_factor != 1) { return 2; @@ -180,6 +199,31 @@ class Add : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class AddN : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateAddNOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + int GetVersion(const OperatorSignature& op_signature) const override { return 1; } @@ -202,6 +246,12 @@ class SpaceToBatchND TocoOperator* op) const override {} int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. 
+ if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -226,6 +276,12 @@ class Sub : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -271,6 +327,12 @@ class BatchToSpaceND TocoOperator* op) const override {} int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -316,6 +378,12 @@ class Concatenation } int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. 
+ if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -409,18 +477,29 @@ class FullyConnected } } + // +-----------------+--------------------+--------------------------+ + // | | Weight::Default | Weight::Shuffled4x16Int8 | + // +-----------------+--------------------+--------------------------+ + // | Float | 1 | 2 | + // | Quantized Uint8 | 1 | 2 | + // | Hybrid | 3 | 3 | + // | Quantized Int8 | 4 | 4 | + // +-----------------+--------------------+--------------------------+ int GetVersion(const OperatorSignature& op_signature) const override { const auto& fc_op = static_cast(*op_signature.op); - if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) { - return 1; - } const string& input_name = op_signature.op->inputs[0]; const string& weights_name = op_signature.op->inputs[1]; const string& output_name = op_signature.op->outputs[0]; const Array& input_array = op_signature.model->GetArray(input_name); const Array& weights_array = op_signature.model->GetArray(weights_name); const Array& output_array = op_signature.model->GetArray(output_name); + // Int8 fully fixed point kernel is at version 4. + if (input_array.data_type == ArrayDataType::kInt8 && + weights_array.data_type == ArrayDataType::kInt8 && + output_array.data_type == ArrayDataType::kInt8) { + return 4; + } // If the op is a signed int8 hybrid operation, we need to return // version 3. if (input_array.data_type == ArrayDataType::kFloat && @@ -428,7 +507,15 @@ class FullyConnected output_array.data_type == ArrayDataType::kFloat) { return 3; } - return 2; + // For float and uint8 fixed point kernels, if the weight is + // Shuffled4x16Int8, is is version 2. + if (fc_op.weights_format == + FullyConnectedWeightsFormat::kShuffled4x16Int8) { + return 2; + } + + // Otherwise (weight is default), the version is 1. 
+ return 1; } }; @@ -448,6 +535,32 @@ class Gather : public BuiltinOperatoraxis = {options.axis()}; } + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class GatherNd + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateGatherNdOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + int GetVersion(const OperatorSignature& op_signature) const override { return 1; } @@ -511,6 +624,12 @@ class L2Normalization } int GetVersion(const OperatorSignature& op_signature) const override { + const string& output_name = op_signature.op->outputs[0]; + const Array& output_array = op_signature.model->GetArray(output_name); + // Version 2 supports signed int8 input types. + if (output_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -600,6 +719,39 @@ class MaxPool : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class Maximum : public SimpleOperator { + public: + explicit Maximum() : SimpleOperator("MAXIMUM", OperatorType::kMaximum) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. 
+ if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class Minimum : public SimpleOperator { + public: + explicit Minimum() : SimpleOperator("MINIMUM", OperatorType::kMinimum) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -624,6 +776,12 @@ class Mul : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -643,6 +801,12 @@ class Pad : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -680,6 +844,12 @@ class PadV2 : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -753,6 +923,12 @@ class SpaceToDepth } int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. 
+ if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -772,6 +948,12 @@ class Transpose TocoOperator* op) const override {} int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -783,7 +965,7 @@ class Lstm : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - ::tflite::LSTMKernelType kernel_type; + ::tflite::LSTMKernelType kernel_type = ::tflite::LSTMKernelType_FULL; switch (op.kernel_type) { case LstmCellOperator::KERNEL_BASIC: kernel_type = ::tflite::LSTMKernelType_BASIC; @@ -791,6 +973,8 @@ class Lstm : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + // Current toco converter only supports tanh, no clip. + return ::tflite::CreateBidirectionalSequenceLSTMOptions( + *builder, /*fused_activation_function=*/ + ::tflite::ActivationFunctionType_TANH, + /*cell_clip=*/0.0, + /*proj_clip=*/0.0, + /*merge_outputs=*/op.merge_outputs, + /*time_major=*/true); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + // Only support tanh activation, so check that tflite type is tanh. + DCHECK(options.fused_activation_function() == + ::tflite::ActivationFunctionType_TANH); + op->merge_outputs = options.merge_outputs(); + } + + int GetVersion(const OperatorSignature& op_signature) const override { + return 1; + } + + std::vector GetMutatingInputVariables( + const Operator& op) const override { + std::vector mutating_input_variables(op.inputs.size(), false); + // Forward input activation state. 
+ mutating_input_variables[35] = true; + // Forward input cell state. + mutating_input_variables[36] = true; + // Backward input activation state. + mutating_input_variables[37] = true; + // Backward input cell state. + mutating_input_variables[38] = true; + return mutating_input_variables; + } +}; + +class BidirectionalSequenceRnn + : public BuiltinOperator< + BidirectionalSequenceRnnOperator, + ::tflite::BidirectionalSequenceRNNOptions, + ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + // Current toco converter only supports tanh, no clip. + return ::tflite::CreateBidirectionalSequenceRNNOptions( + *builder, /*time_major=*/true, + /*fused_activation_function=*/ + ::tflite::ActivationFunctionType_TANH, + /*merge_outputs=*/op.merge_outputs); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + // Only support tanh activation, so check that tflite type is tanh. + DCHECK(options.fused_activation_function() == + ::tflite::ActivationFunctionType_TANH); + op->merge_outputs = options.merge_outputs(); + } + + int GetVersion(const OperatorSignature& op_signature) const override { + return 1; + } + + std::vector GetMutatingInputVariables( + const Operator& op) const override { + std::vector mutating_input_variables(op.inputs.size(), false); + // Forward hidden state. + mutating_input_variables[4] = true; + // Backward hidden state. + mutating_input_variables[8] = true; + return mutating_input_variables; + } +}; + class Mean : public BuiltinOperator { public: @@ -973,6 +1245,12 @@ class ReduceMax } int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. 
+ if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -994,6 +1272,12 @@ class ReduceMin } int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -1040,6 +1324,20 @@ class ReduceAny } }; +class Relu6 : public SimpleOperator { + public: + explicit Relu6() : SimpleOperator("RELU6", OperatorType::kRelu6) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + class ResizeBilinear : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op takes int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -1081,6 +1385,12 @@ class ResizeNearestNeighbor } int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -1128,6 +1438,14 @@ class Split } int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2, for int32 it's version 3. 
+ if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } else if (input_array.data_type == ArrayDataType::kInt32) { + return 3; + } return 1; } }; @@ -1178,6 +1496,12 @@ class StridedSlice } int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -1196,6 +1520,11 @@ class TopK_V2 : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -1217,6 +1546,12 @@ class ArgMax : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; } }; @@ -1238,6 +1573,12 @@ class ArgMin : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; } }; @@ -1330,6 +1671,12 @@ class Pack : public BuiltinOperatorinputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; @@ -1356,6 +1703,34 @@ class Shape } }; +class Slice : public SimpleOperator { + public: + explicit Slice() : SimpleOperator("SLICE", OperatorType::kSlice) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. 
+ if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class Tanh : public SimpleOperator { + public: + explicit Tanh() : SimpleOperator("TANH", OperatorType::kTanh) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + class OneHot : public BuiltinOperator { public: @@ -1438,6 +1813,35 @@ class LeakyRelu } }; +class Logistic : public SimpleOperator { + public: + explicit Logistic() : SimpleOperator("LOGISTIC", OperatorType::kLogistic) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class LogSoftmax : public SimpleOperator { + public: + explicit LogSoftmax() + : SimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + class SquaredDifference : public BuiltinOperator< SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions, @@ -1525,7 +1929,7 @@ class UnidirectionalSequenceRnn } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - // Only support tanh actication, so check that tflite type is tanh. + // Only support tanh activation, so check that tflite type is tanh. 
DCHECK(options.fused_activation_function() == ::tflite::ActivationFunctionType_TANH); } @@ -1542,6 +1946,25 @@ class UnidirectionalSequenceRnn } }; +class Where : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateWhereOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + + int GetVersion(const OperatorSignature& op_signature) const override { + return 1; + } +}; + std::unique_ptr WriteFlexOpOptions( const string& tensorflow_node_def) { auto fbb = absl::make_unique(); @@ -1640,7 +2063,14 @@ class TensorFlowUnsupported : public BaseOperator { has_valid_attr = true; break; case tensorflow::AttrValue::kList: - if (attr.list().i_size() > 0) { + if (attr.list().s_size() > 0) { + auto start = fbb->StartVector(key); + for (const string& v : attr.list().s()) { + fbb->Add(v); + } + fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); + has_valid_attr = true; + } else if (attr.list().i_size() > 0) { auto start = fbb->StartVector(key); for (const int64_t v : attr.list().i()) { fbb->Add(v); @@ -1722,6 +2152,14 @@ class TensorFlowUnsupported : public BaseOperator { } break; } + case 15: { // flexbuffers::FBT_VECTOR_STRING: { + auto* list = (*attr)[key].mutable_list(); + const auto& vector = value.AsTypedVector(); + for (size_t i = 0; i < vector.size(); i++) { + list->add_s(vector[i].AsString().str()); + } + break; + } default: LOG(WARNING) << "Ignoring unsupported attribute type with key '" << key << "'"; @@ -1767,6 +2205,131 @@ class Dequantize } }; +class ReverseSequence + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim, + 
op.batch_dim); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->seq_dim = options.seq_dim(); + op->batch_dim = options.batch_dim(); + } + + int GetVersion(const OperatorSignature& op_signature) const override { + return 1; + } +}; + +class Equal : public SimpleOperator { + public: + explicit Equal() : SimpleOperator("EQUAL", OperatorType::kEqual) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class NotEqual : public SimpleOperator { + public: + explicit NotEqual() : SimpleOperator("NOT_EQUAL", OperatorType::kNotEqual) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class Greater : public SimpleOperator { + public: + explicit Greater() : SimpleOperator("GREATER", OperatorType::kGreater) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. 
+ if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class GreaterEqual : public SimpleOperator { + public: + explicit GreaterEqual() + : SimpleOperator("GREATER_EQUAL", OperatorType::kGreaterEqual) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class Less : public SimpleOperator { + public: + explicit Less() : SimpleOperator("LESS", OperatorType::kLess) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class LessEqual : public SimpleOperator { + public: + explicit LessEqual() + : SimpleOperator("LESS_EQUAL", OperatorType::kLessEqual) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + +class Select : public SimpleOperator { + public: + explicit Select() : SimpleOperator("SELECT", OperatorType::kSelect) {} + int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // Version 2 supports signed int8 input types. 
+ if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } + return 1; + } +}; + namespace { // Build a vector containing all the known operators. std::vector> BuildOperatorList( @@ -1776,6 +2339,8 @@ std::vector> BuildOperatorList( // Builtin Operators. ops.push_back( MakeUnique(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN)); ops.push_back( MakeUnique
(::tflite::BuiltinOperator_DIV, OperatorType::kDiv)); ops.push_back( @@ -1802,6 +2367,8 @@ std::vector> BuildOperatorList( OperatorType::kFullyConnected)); ops.push_back(MakeUnique(::tflite::BuiltinOperator_GATHER, OperatorType::kGather)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_GATHER_ND, + OperatorType::kGatherNd)); ops.push_back( MakeUnique(::tflite::BuiltinOperator_L2_NORMALIZATION, OperatorType::kL2Normalization)); @@ -1882,6 +2449,12 @@ std::vector> BuildOperatorList( ops.emplace_back(MakeUnique( ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, OperatorType::kUnidirectionalSequenceLstm)); + ops.emplace_back(MakeUnique( + ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, + OperatorType::kBidirectionalSequenceLstm)); + ops.emplace_back(MakeUnique( + ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + OperatorType::kBidirectionalSequenceRnn)); ops.push_back(MakeUnique(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); ops.push_back(MakeUnique(::tflite::BuiltinOperator_UNPACK, @@ -1898,6 +2471,11 @@ std::vector> BuildOperatorList( ops.push_back(MakeUnique( ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, OperatorType::kUnidirectionalSequenceRnn)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_REVERSE_SEQUENCE, + OperatorType::kReverseSequence)); // Custom Operators. ops.push_back( @@ -1915,44 +2493,36 @@ std::vector> BuildOperatorList( // builtins. 
ops.push_back( MakeUnique>("FLOOR", OperatorType::kFloor)); + ops.push_back( + MakeUnique>("CEIL", OperatorType::kCeil)); + ops.push_back( + MakeUnique>("ELU", OperatorType::kElu)); ops.push_back( MakeUnique>("RELU", OperatorType::kRelu)); ops.push_back(MakeUnique>( "RELU_N1_TO_1", OperatorType::kRelu1)); - ops.push_back( - MakeUnique>("RELU6", OperatorType::kRelu6)); + ops.push_back(MakeUnique()); ops.push_back( MakeUnique>("PRELU", OperatorType::kPRelu)); - ops.push_back(MakeUnique>( - "LOGISTIC", OperatorType::kLogistic)); - ops.push_back( - MakeUnique>("TANH", OperatorType::kTanh)); + ops.push_back(MakeUnique()); + ops.push_back(MakeUnique()); ops.push_back( MakeUnique>("EXP", OperatorType::kExp)); - ops.push_back(MakeUnique>( - "LOG_SOFTMAX", OperatorType::kLogSoftmax)); - ops.push_back(MakeUnique>( - "MAXIMUM", OperatorType::kMaximum)); // Element-wise Maximum - ops.push_back(MakeUnique>( - "MINIMUM", OperatorType::kMinimum)); // Element-wise Minimum - ops.push_back(MakeUnique>( - "GREATER", OperatorType::kGreater)); - ops.push_back(MakeUnique>( - "GREATER_EQUAL", OperatorType::kGreaterEqual)); - ops.push_back(MakeUnique>( - "LESS", OperatorType::kLess)); - ops.push_back(MakeUnique>( - "LESS_EQUAL", OperatorType::kLessEqual)); - ops.push_back(MakeUnique>( - "EQUAL", OperatorType::kEqual)); - ops.push_back(MakeUnique>( - "NOT_EQUAL", OperatorType::kNotEqual)); ops.push_back( - MakeUnique>("NEG", OperatorType::kNeg)); - ops.push_back(MakeUnique>( - "SELECT", OperatorType::kSelect)); + MakeUnique>("COS", OperatorType::kCos)); + ops.push_back(MakeUnique()); + ops.push_back(MakeUnique()); // Element-wise Maximum + ops.push_back(MakeUnique()); // Element-wise Minimum + ops.push_back(MakeUnique()); + ops.push_back(MakeUnique()); + ops.push_back(MakeUnique()); + ops.push_back(MakeUnique()); + ops.push_back(MakeUnique()); + ops.push_back(MakeUnique()); ops.push_back( - MakeUnique>("SLICE", OperatorType::kSlice)); + MakeUnique>("NEG", OperatorType::kNeg)); + 
ops.push_back(MakeUnique